package awgserver import ( "fmt" "net" "net/netip" "sync" "github.com/amnezia-vpn/amneziawg-go/conn" ) // quicPacket holds a non-AWG packet to be delivered to the QUIC server. type quicPacket struct { data []byte addr *net.UDPAddr } // MuxBind implements conn.Bind. It owns a single *net.UDPConn on the listen port. // Its ReceiveFunc checks each incoming packet against AWG signatures: // - AWG packets are delivered to device.Device via the ReceiveFunc return. // - Non-AWG packets are sent to a channel that feeds FilteredConn → QUIC server. type MuxBind struct { mu sync.Mutex conn *net.UDPConn cfg *Config quicCh chan quicPacket closed chan struct{} } // NewMuxBind creates a MuxBind. Call Open() to start listening. func NewMuxBind(cfg *Config, quicCh chan quicPacket) *MuxBind { return &MuxBind{ cfg: cfg, quicCh: quicCh, closed: make(chan struct{}), } } // FilteredConn returns a net.PacketConn that receives only non-AWG packets. func (b *MuxBind) FilteredConn() *FilteredConn { b.mu.Lock() defer b.mu.Unlock() return newFilteredConn(b.quicCh, b.conn) } // Open creates the shared UDP socket and returns a ReceiveFunc that // demultiplexes AWG vs non-AWG traffic. func (b *MuxBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { b.mu.Lock() defer b.mu.Unlock() if b.conn != nil { return nil, 0, fmt.Errorf("bind already open") } laddr := &net.UDPAddr{Port: int(port)} c, err := net.ListenUDP("udp", laddr) if err != nil { return nil, 0, fmt.Errorf("listen UDP :%d: %w", port, err) } b.conn = c actualPort := uint16(c.LocalAddr().(*net.UDPAddr).Port) recv := func(packets [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { for { select { case <-b.closed: return 0, net.ErrClosed default: } buf := packets[0] n, addr, err := b.conn.ReadFromUDP(buf) if err != nil { return 0, err } pkt := buf[:n] if isAWGPacket(pkt, b.cfg) { sizes[0] = n eps[0] = &MuxEndpoint{addr: addr} return 1, nil } // Non-AWG packet → send to QUIC server via channel. pktCopy := make([]byte, n) copy(pktCopy, pkt) select { case b.quicCh <- quicPacket{data: pktCopy, addr: addr}: case <-b.closed: return 0, net.ErrClosed } } } return []conn.ReceiveFunc{recv}, actualPort, nil } // Close shuts down the bind. func (b *MuxBind) Close() error { b.mu.Lock() defer b.mu.Unlock() select { case <-b.closed: default: close(b.closed) } if b.conn != nil { err := b.conn.Close() b.conn = nil return err } return nil } // SetMark sets the socket mark. No-op on platforms that don't support it. func (b *MuxBind) SetMark(mark uint32) error { return nil } // Send writes AWG response packets through the shared socket. func (b *MuxBind) Send(bufs [][]byte, ep conn.Endpoint) error { b.mu.Lock() c := b.conn b.mu.Unlock() if c == nil { return net.ErrClosed } mep := ep.(*MuxEndpoint) for _, buf := range bufs { if _, err := c.WriteToUDP(buf, mep.addr); err != nil { return err } } return nil } // ParseEndpoint creates a MuxEndpoint from a string like "1.2.3.4:51820". func (b *MuxBind) ParseEndpoint(s string) (conn.Endpoint, error) { addr, err := net.ResolveUDPAddr("udp", s) if err != nil { return nil, err } return &MuxEndpoint{addr: addr}, nil } // BatchSize returns 1 (no batching). func (b *MuxBind) BatchSize() int { return 1 } // MuxEndpoint implements conn.Endpoint for the multiplexed socket. type MuxEndpoint struct { addr *net.UDPAddr } func (e *MuxEndpoint) ClearSrc() {} func (e *MuxEndpoint) SrcToString() string { return "" } func (e *MuxEndpoint) DstToString() string { return e.addr.String() } func (e *MuxEndpoint) DstToBytes() []byte { ap := e.addr.AddrPort() b, _ := ap.MarshalBinary() return b } func (e *MuxEndpoint) DstIP() netip.Addr { return e.addr.AddrPort().Addr() } func (e *MuxEndpoint) SrcIP() netip.Addr { return netip.Addr{} }