~bigbes/shroud

ref: 5afb3bad1be8eb82352dde56e6836a0a8ea4ef7f shroud/internal/awgserver/muxbind.go -rw-r--r-- 3.9 KiB
5afb3bad — Eugene Blikh feat: add optional shadowsocks and outline smart dialer config 2 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package awgserver

import (
	"fmt"
	"net"
	"net/netip"
	"sync"

	"github.com/amnezia-vpn/amneziawg-go/conn"
)

// quicPacket is one datagram that did not match the AWG signatures, queued
// for the QUIC server together with the address it was received from.
type quicPacket struct {
	data []byte       // payload bytes (a private copy made by the receiver)
	addr *net.UDPAddr // sender address as reported by ReadFromUDP
}

// MuxBind is a conn.Bind that shares a single *net.UDPConn (bound to the
// listen port) between AWG and a QUIC server. Each datagram read by its
// ReceiveFunc is classified with isAWGPacket:
//   - matching datagrams are handed back to the AWG device as the ReceiveFunc result;
//   - all other datagrams are copied onto quicCh, which feeds FilteredConn and
//     from there the QUIC server.
type MuxBind struct {
	mu   sync.Mutex   // protects conn
	conn *net.UDPConn // nil until Open and again after Close
	cfg  *Config      // signature configuration passed to isAWGPacket

	quicCh chan quicPacket // non-AWG datagrams destined for the QUIC server
	closed chan struct{}   // closed by Close to make ReceiveFuncs return
}

// NewMuxBind returns a MuxBind that classifies packets using cfg and forwards
// non-AWG datagrams to quicCh. No socket exists until Open is called.
func NewMuxBind(cfg *Config, quicCh chan quicPacket) *MuxBind {
	b := new(MuxBind)
	b.cfg = cfg
	b.quicCh = quicCh
	b.closed = make(chan struct{})
	return b
}

// FilteredConn builds a *FilteredConn that is fed only the non-AWG datagrams
// delivered on the bind's QUIC channel. It is constructed with the socket that
// is current at call time; before Open that socket is nil.
// NOTE(review): presumably FilteredConn implements net.PacketConn and should
// only be obtained after Open — confirm against newFilteredConn.
func (b *MuxBind) FilteredConn() *FilteredConn {
	b.mu.Lock()
	defer b.mu.Unlock()

	ch, sock := b.quicCh, b.conn
	return newFilteredConn(ch, sock)
}

// Open creates the shared UDP socket on port (0 lets the OS choose) and
// returns a single ReceiveFunc that demultiplexes AWG vs non-AWG traffic,
// together with the port that was actually bound.
//
// Open works again after Close: a shutdown channel left closed by a previous
// Close is replaced, so a Close/Open cycle (rebinding) yields a working
// ReceiveFunc instead of one that reports net.ErrClosed immediately.
func (b *MuxBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.conn != nil {
		return nil, 0, fmt.Errorf("bind already open")
	}

	laddr := &net.UDPAddr{Port: int(port)}
	c, err := net.ListenUDP("udp", laddr)
	if err != nil {
		return nil, 0, fmt.Errorf("listen UDP :%d: %w", port, err)
	}

	// Close closes b.closed permanently; give this Open a fresh channel.
	select {
	case <-b.closed:
		b.closed = make(chan struct{})
	default:
	}
	b.conn = c
	closed := b.closed

	actualPort := uint16(c.LocalAddr().(*net.UDPAddr).Port)

	// The closure uses the captured c and closed instead of b.conn/b.closed:
	// Close sets b.conn to nil under b.mu, so reading the field here without
	// the lock would race and could dereference nil. Capturing also ties each
	// ReceiveFunc to the socket and shutdown channel of the Open that made it.
	recv := func(packets [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
		buf := packets[0] // BatchSize is 1, so only the first slot is used.
		for {
			select {
			case <-closed:
				return 0, net.ErrClosed
			default:
			}

			n, addr, err := c.ReadFromUDP(buf)
			if err != nil {
				return 0, err
			}

			pkt := buf[:n]

			if isAWGPacket(pkt, b.cfg) {
				sizes[0] = n
				eps[0] = &MuxEndpoint{addr: addr}
				return 1, nil
			}

			// Non-AWG packet: buf is reused by the next read, so the QUIC side
			// gets its own copy.
			// NOTE(review): this send blocks while quicCh is full, which also
			// stalls AWG reception; consider dropping instead if that matters.
			pktCopy := make([]byte, n)
			copy(pktCopy, pkt)
			select {
			case b.quicCh <- quicPacket{data: pktCopy, addr: addr}:
			case <-closed:
				return 0, net.ErrClosed
			}
		}
	}

	return []conn.ReceiveFunc{recv}, actualPort, nil
}

// Close stops the bind: it closes the shutdown channel (at most once, so
// ReceiveFuncs return net.ErrClosed) and closes the UDP socket, which unblocks
// any pending read. Calling Close again is harmless and returns nil.
func (b *MuxBind) Close() error {
	b.mu.Lock()
	defer b.mu.Unlock()

	alreadySignalled := false
	select {
	case <-b.closed:
		alreadySignalled = true
	default:
	}
	if !alreadySignalled {
		close(b.closed)
	}

	sock := b.conn
	if sock == nil {
		return nil
	}
	b.conn = nil
	return sock.Close()
}

// SetMark is a no-op on every platform: MuxBind never applies a socket mark to
// its UDP socket, so mark is ignored and nil is always returned.
// NOTE(review): callers relying on fwmark-based routing will not get it here.
func (b *MuxBind) SetMark(mark uint32) error {
	return nil
}

// Send writes each buffer in bufs as a separate datagram to ep through the
// shared socket.
//
// It returns:
//   - net.ErrClosed if the bind is not open;
//   - an error if ep is not a non-nil *MuxEndpoint;
//   - otherwise the first write error, in which case the remaining buffers are
//     not sent.
func (b *MuxBind) Send(bufs [][]byte, ep conn.Endpoint) error {
	b.mu.Lock()
	c := b.conn
	b.mu.Unlock()

	if c == nil {
		return net.ErrClosed
	}

	// A bare type assertion would panic the calling goroutine on a foreign or
	// nil endpoint; report it as an ordinary error instead.
	mep, ok := ep.(*MuxEndpoint)
	if !ok || mep == nil {
		return fmt.Errorf("unsupported endpoint type %T", ep)
	}

	for _, buf := range bufs {
		if _, err := c.WriteToUDP(buf, mep.addr); err != nil {
			return err
		}
	}
	return nil
}

// ParseEndpoint turns s (for example "1.2.3.4:51820") into a *MuxEndpoint.
// The string is resolved with net.ResolveUDPAddr, so host names are accepted
// as well and may trigger a DNS lookup.
func (b *MuxBind) ParseEndpoint(s string) (conn.Endpoint, error) {
	udpAddr, err := net.ResolveUDPAddr("udp", s)
	if err == nil {
		return &MuxEndpoint{addr: udpAddr}, nil
	}
	return nil, err
}

// BatchSize reports the bind's batch size. It is always 1: the ReceiveFunc
// fills only the first packet slot per call.
func (b *MuxBind) BatchSize() int {
	const noBatching = 1
	return noBatching
}

// MuxEndpoint implements conn.Endpoint for peers reached through MuxBind's
// shared socket. Only the remote (destination) address is tracked; no local
// source address is recorded, so the Src* methods return empty values.
type MuxEndpoint struct {
	addr *net.UDPAddr
}

// ClearSrc is a no-op because no source address is stored.
func (e *MuxEndpoint) ClearSrc() {}

// SrcToString always returns "" because no source address is stored.
func (e *MuxEndpoint) SrcToString() string { return "" }

// DstToString returns the remote address in host:port form.
func (e *MuxEndpoint) DstToString() string { return e.addr.String() }

// DstToBytes returns the binary form of the remote address and port as
// produced by netip.AddrPort.MarshalBinary (whose error is always nil).
// The address is not unmapped here, so the encoding matches what previous
// versions produced for the same endpoint.
// NOTE(review): presumably used only as an opaque per-endpoint key — confirm.
func (e *MuxEndpoint) DstToBytes() []byte {
	ap := e.addr.AddrPort()
	b, _ := ap.MarshalBinary()
	return b
}

// DstIP returns the remote IP address. On a dual-stack socket ReadFromUDP
// reports IPv4 peers as IPv4-mapped IPv6 (::ffff:a.b.c.d), and
// net.ResolveUDPAddr also yields 16-byte IPv4 addresses. Unmap converts both
// back to plain IPv4 so the result compares equal to ordinary IPv4 addresses.
func (e *MuxEndpoint) DstIP() netip.Addr {
	return e.addr.AddrPort().Addr().Unmap()
}

// SrcIP returns the zero netip.Addr because no source address is stored.
func (e *MuxEndpoint) SrcIP() netip.Addr {
	return netip.Addr{}
}