package awgserver
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
// quicPacket holds a non-AWG packet to be delivered to the QUIC server.
// Values of this type travel over the channel shared between MuxBind (sender)
// and the FilteredConn handed to the QUIC side.
type quicPacket struct {
// data is the datagram payload. MuxBind's receive loop fills it with a fresh
// copy of the bytes it read, so the receiver may retain it.
data []byte
// addr is the remote UDP address the datagram was read from.
addr *net.UDPAddr
}
// MuxBind implements conn.Bind. It owns a single *net.UDPConn on the listen port.
// The ReceiveFunc returned by Open classifies each incoming packet with
// isAWGPacket:
// - AWG packets are delivered to device.Device via the ReceiveFunc return.
// - Non-AWG packets are copied and sent on quicCh, which feeds
//   FilteredConn → QUIC server.
type MuxBind struct {
// mu guards conn and serializes Open, Close, Send's socket lookup and
// FilteredConn.
mu sync.Mutex
// conn is the shared UDP socket; nil before Open succeeds and after Close.
conn *net.UDPConn
// cfg is passed to isAWGPacket to decide whether a datagram is AWG traffic.
cfg *Config
// quicCh carries copies of non-AWG datagrams towards the QUIC server.
quicCh chan quicPacket
// closed is closed by Close to tell receive loops to stop with net.ErrClosed.
closed chan struct{}
}
// NewMuxBind returns an unopened MuxBind that classifies packets using cfg and
// forwards non-AWG datagrams on quicCh. No socket exists until Open is called.
func NewMuxBind(cfg *Config, quicCh chan quicPacket) *MuxBind {
	bind := new(MuxBind)
	bind.cfg = cfg
	bind.quicCh = quicCh
	// The shutdown signal starts un-fired; Close fires it.
	bind.closed = make(chan struct{})
	return bind
}
// FilteredConn returns a net.PacketConn-style connection that receives only the
// non-AWG datagrams forwarded on quicCh. It is built around the socket that is
// current at call time.
//
// NOTE(review): before Open (or after Close) that socket is nil, so callers
// presumably should only call this after Open has succeeded — confirm with
// newFilteredConn's handling of a nil *net.UDPConn.
func (b *MuxBind) FilteredConn() *FilteredConn {
	b.mu.Lock()
	defer b.mu.Unlock()
	sock, packets := b.conn, b.quicCh
	return newFilteredConn(packets, sock)
}
// Open creates the shared UDP socket on port (0 selects an ephemeral port) and
// returns a single ReceiveFunc that demultiplexes AWG vs non-AWG traffic,
// together with the port actually bound.
//
// A Bind may be opened again after Close. wireguard-go style devices call
// Close before every Open when (re)binding, so Open re-arms the shutdown
// signal if a previous Close already fired it. Without that, the new
// ReceiveFunc would return net.ErrClosed immediately.
//
// The returned ReceiveFunc captures the socket and shutdown channel of this
// particular Open call instead of re-reading the struct fields. A receive loop
// left over from an earlier session can therefore never observe a later
// socket, and it never races with Close on b.conn.
//
// AWG packets are returned to the caller in packets[0]/sizes[0]/eps[0].
// Non-AWG packets are copied and sent on quicCh, and the loop then reads
// again. After Close, the ReceiveFunc returns net.ErrClosed (or the socket's
// closed-connection read error).
func (b *MuxBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.conn != nil {
		return nil, 0, fmt.Errorf("bind already open")
	}

	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: int(port)})
	if err != nil {
		return nil, 0, fmt.Errorf("listen UDP :%d: %w", port, err)
	}

	// Re-arm the shutdown signal if an earlier Close already fired it.
	select {
	case <-b.closed:
		b.closed = make(chan struct{})
	default:
	}

	b.conn = udpConn
	closed := b.closed
	actualPort := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)

	recv := func(packets [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
		for {
			select {
			case <-closed:
				return 0, net.ErrClosed
			default:
			}

			buf := packets[0]
			n, addr, err := udpConn.ReadFromUDP(buf)
			if err != nil {
				return 0, err
			}
			pkt := buf[:n]

			if isAWGPacket(pkt, b.cfg) {
				sizes[0] = n
				eps[0] = &MuxEndpoint{addr: addr}
				return 1, nil
			}

			// Non-AWG packet → QUIC server. Copy it first: buf is overwritten
			// by the next iteration's read.
			pktCopy := make([]byte, n)
			copy(pktCopy, pkt)
			select {
			case b.quicCh <- quicPacket{data: pktCopy, addr: addr}:
			case <-closed:
				return 0, net.ErrClosed
			}
		}
	}
	return []conn.ReceiveFunc{recv}, actualPort, nil
}
// Close shuts down the bind. It fires the shutdown signal watched by the
// ReceiveFunc, then closes the UDP socket and forgets it. The shutdown signal
// is fired at most once. When no socket is open, Close returns nil; otherwise
// it returns the socket's Close error.
func (b *MuxBind) Close() error {
	b.mu.Lock()
	defer b.mu.Unlock()

	alreadySignalled := false
	select {
	case <-b.closed:
		alreadySignalled = true
	default:
	}
	if !alreadySignalled {
		close(b.closed)
	}

	sock := b.conn
	if sock == nil {
		return nil
	}
	b.conn = nil
	return sock.Close()
}
// SetMark is required by conn.Bind. MuxBind does not apply socket marks on any
// platform: the requested mark is ignored and nil is always returned.
func (b *MuxBind) SetMark(mark uint32) error {
	// Intentionally a no-op; see the doc comment.
	return nil
}
// Send writes each buffer in bufs as a separate UDP datagram to ep through the
// shared socket.
//
// It returns net.ErrClosed if the bind is not open. ep must be a non-nil
// *MuxEndpoint, as produced by this bind's ReceiveFunc or ParseEndpoint. Any
// other Endpoint yields an error instead of a panic from a failed type
// assertion. Sending stops at the first write error, which is returned as-is.
func (b *MuxBind) Send(bufs [][]byte, ep conn.Endpoint) error {
	// Snapshot the socket under the lock; the write itself happens unlocked.
	b.mu.Lock()
	c := b.conn
	b.mu.Unlock()
	if c == nil {
		return net.ErrClosed
	}

	mep, ok := ep.(*MuxEndpoint)
	if !ok || mep == nil {
		return fmt.Errorf("send: unsupported endpoint type %T", ep)
	}

	for _, buf := range bufs {
		if _, err := c.WriteToUDP(buf, mep.addr); err != nil {
			return err
		}
	}
	return nil
}
// ParseEndpoint turns an address string such as "1.2.3.4:51820" into a
// *MuxEndpoint. The string is resolved with net.ResolveUDPAddr, so a host name
// (which triggers a DNS lookup) is accepted as well as a literal IP:port.
// On failure the resolver error is returned unchanged with a nil Endpoint.
func (b *MuxBind) ParseEndpoint(s string) (conn.Endpoint, error) {
	udpAddr, err := net.ResolveUDPAddr("udp", s)
	if err == nil {
		return &MuxEndpoint{addr: udpAddr}, nil
	}
	return nil, err
}
// BatchSize reports how many packets the bind handles per receive or send
// call. MuxBind does not batch, so this is always 1; the ReceiveFunc from Open
// only ever fills packets[0].
func (b *MuxBind) BatchSize() int {
	const noBatching = 1
	return noBatching
}
// MuxEndpoint implements conn.Endpoint for the multiplexed socket. It records
// only the remote (destination) UDP address. No local source address is
// tracked, so the Src* methods return zero values.
type MuxEndpoint struct {
	addr *net.UDPAddr
}

// ClearSrc is a no-op because MuxEndpoint keeps no source address.
func (e *MuxEndpoint) ClearSrc() {}

// SrcToString always returns "" because no source address is tracked.
func (e *MuxEndpoint) SrcToString() string {
	return ""
}

// DstToString formats the remote address with net.UDPAddr.String.
func (e *MuxEndpoint) DstToString() string {
	return e.addr.String()
}

// DstToBytes returns the remote address in netip.AddrPort binary form: the IP
// bytes followed by the port in little-endian order.
func (e *MuxEndpoint) DstToBytes() []byte {
	// The error is ignored: netip.AddrPort.MarshalBinary does not fail.
	raw, _ := e.addr.AddrPort().MarshalBinary()
	return raw
}

// DstIP returns the remote IP address.
//
// NOTE(review): addresses read from a dual-stack socket or produced by
// net.ResolveUDPAddr may be 16-byte IPv4-mapped IPv6 addresses. No Unmap is
// applied here — confirm callers expect that.
func (e *MuxEndpoint) DstIP() netip.Addr {
	dst := e.addr.AddrPort()
	return dst.Addr()
}

// SrcIP returns the zero netip.Addr because no source address is tracked.
func (e *MuxEndpoint) SrcIP() netip.Addr {
	var none netip.Addr
	return none
}