package awgserver
import (
"fmt"
"log/slog"
"strings"
"sync"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/quic-go/quic-go/http3"
"sourcecraft.dev/bigbes/shroud/internal/store"
)
// Config holds all settings for the AmneziaWG server.
//
// Start hands a pointer to the Server's copy of Config to both NewMuxBind and
// startH3Server, so fields not read in this file may be consumed there.
type Config struct {
// ListenPort is written to the device as the UAPI "listen_port" value.
ListenPort int
// TUNName is the TUN interface name passed to tun.CreateTUN; it is also
// the name handed to ipc.UAPIOpen / ipc.UAPIListen for the UAPI socket.
TUNName string
Address string // server CIDR, e.g. "10.14.0.1/24"; assigned to the TUN only when non-empty
MTU int // passed to tun.CreateTUN
PrivateKey string // server interface key, base64 (converted to hex before being sent over UAPI)
// Domain enables the HTTP/3 server when non-empty.
Domain string
// CertCache and ACMEHTTPPort are not read in this file; presumably they
// configure certificate issuance for Domain — confirm in startH3Server.
CertCache string
ACMEHTTPPort int
// Obfuscation parameters, sent to the device as jc/jmin/jmax, s1–s4 and
// h1–h4. Jc, Jmin and Jmax are only sent when Jc > 0; the S and H values
// are always sent (H values are formatted with %s).
Jc, Jmin, Jmax int
S1, S2, S3, S4 int
H1, H2, H3, H4 HeaderRange
}
// Server wraps an AmneziaWG device with MuxBind + HTTP/3 multiplexing.
//
// Construct it with New, then call Start. Start, SyncKeys and Stop each hold
// mu for their whole duration.
type Server struct {
cfg Config // copy of the config given to New; MuxBind and the HTTP/3 server get a pointer to it
dev *device.Device // nil before Start; SyncKeys refuses to run while it is nil
tunDev tun.Device // TUN device created by Start
muxBind *MuxBind // bind used by the AWG device; also supplies the HTTP/3 server's connection
h3srv *http3.Server // nil unless Domain is set and the HTTP/3 server started successfully
logger *slog.Logger
mu sync.Mutex // serializes Start, SyncKeys and Stop
}
// New creates a new AWG server (not yet started).
//
// cfg is copied into the Server. A nil logger is replaced with slog.Default():
// every Server method logs through it, so storing nil would make the first
// call panic.
func New(cfg Config, logger *slog.Logger) *Server {
	if logger == nil {
		logger = slog.Default()
	}
	return &Server{
		cfg:    cfg,
		logger: logger,
	}
}
// Start creates the TUN device, the AWG device driven by a MuxBind, and —
// when cfg.Domain is set — an HTTP/3 server on the MuxBind's filtered
// connection.
//
// Errors while creating, configuring or bringing up the TUN/AWG device are
// returned. In that case everything built so far is closed and the Server's
// fields are left untouched, so a failed Start neither leaks devices nor
// leaves SyncKeys/Stop operating on closed ones. HTTP/3 and TUN-address
// failures are logged and tolerated. The UAPI socket is served by a
// background goroutine.
func (s *Server) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// 1. Create TUN device.
	tunDev, err := tun.CreateTUN(s.cfg.TUNName, s.cfg.MTU)
	if err != nil {
		return fmt.Errorf("create TUN %q: %w", s.cfg.TUNName, err)
	}
	// The interface may not get exactly the requested name (e.g. a name
	// template or a platform-assigned name). Prefer the name the device
	// reports so the UAPI socket and address setup target the real interface.
	ifName := s.cfg.TUNName
	if name, nameErr := tunDev.Name(); nameErr == nil && name != "" {
		ifName = name
	}
	s.logger.Info("TUN device created.", "name", ifName, "mtu", s.cfg.MTU)

	// 2. Create MuxBind.
	quicCh := make(chan quicPacket, 1024)
	muxBind := NewMuxBind(&s.cfg, quicCh)

	// 3. Create AWG device.
	devLogger := device.NewLogger(device.LogLevelError, "[awg] ")
	dev := device.NewDevice(tunDev, muxBind, devLogger)

	// abort releases the AWG and TUN devices created above and returns err.
	abort := func(err error) error {
		dev.Close()
		_ = tunDev.Close() // best effort: err is the failure worth reporting
		return err
	}

	// 4. Configure via UAPI (interface only, peers added by SyncKeys).
	ifaceConf, err := s.buildInterfaceConfig()
	if err != nil {
		return abort(fmt.Errorf("build UAPI interface config: %w", err))
	}
	if err := dev.IpcSet(ifaceConf); err != nil {
		return abort(fmt.Errorf("IpcSet interface config: %w", err))
	}
	s.logger.Info("AWG device configured.")

	// 5. Bring up — triggers MuxBind.Open(), starts listening. Up returns an
	// error when bringing the device up fails (e.g. the bind cannot listen).
	// Ignoring it would report the server as up while nothing is listening.
	if err := dev.Up(); err != nil {
		return abort(fmt.Errorf("bring AWG device up: %w", err))
	}
	s.logger.Info("AWG device is up.", "port", s.cfg.ListenPort)

	// The device is running: publish it so SyncKeys and Stop can use it.
	s.cfg.TUNName = ifName
	s.tunDev = tunDev
	s.muxBind = muxBind
	s.dev = dev

	// 6. Set up UAPI socket for awg show/set.
	go s.setupUAPI()

	// 7. Get FilteredConn and start HTTP/3 server.
	fc := muxBind.FilteredConn()
	if s.cfg.Domain != "" {
		h3srv, err := startH3Server(fc, &s.cfg, s.logger)
		if err != nil {
			s.logger.Warn("HTTP/3 server failed to start (non-fatal).", "err", err)
		} else {
			s.h3srv = h3srv
		}
	}

	// 8. Configure TUN address.
	if s.cfg.Address != "" {
		if err := setTUNAddress(s.cfg.TUNName, s.cfg.Address, s.logger); err != nil {
			s.logger.Warn("Failed to configure TUN address (non-fatal).", "err", err)
		}
	}
	return nil
}
// SyncKeys rebuilds the AWG peer list from access keys that have AWG data.
//
// The whole peer list is replaced with the peers rendered by buildPeerConfig;
// keys without AWG data, and entries buildPeerConfig cannot encode, end up
// with no peer. It returns an error if the device has not been created yet or
// if the device rejects the configuration.
func (s *Server) SyncKeys(keys []store.AccessKey) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.dev == nil {
		return fmt.Errorf("AWG device not started")
	}

	conf := s.buildPeerConfig(keys)
	if err := s.dev.IpcSet(conf); err != nil {
		return fmt.Errorf("IpcSet peers: %w", err)
	}

	// Count the peers actually sent to the device, not the keys that merely
	// carry AWG data: buildPeerConfig skips entries it cannot encode. Each
	// emitted peer begins with its own "public_key=" line, and conf always
	// starts with the "replace_peers=true" line, so every peer line is
	// preceded by a newline.
	count := strings.Count(conf, "\npublic_key=")
	s.logger.Info("AWG peers synced.", "count", count)
	return nil
}
// Stop shuts down the AWG device, TUN, and HTTP/3 server.
//
// The HTTP/3 server is closed first, then the AWG device, then the TUN
// device. Shutdown always runs to completion. Errors from closing the HTTP/3
// server or the TUN device are returned (combined when both fail) rather than
// silently discarded.
func (s *Server) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var h3Err, tunErr error
	if s.h3srv != nil {
		if err := s.h3srv.Close(); err != nil {
			h3Err = fmt.Errorf("close HTTP/3 server: %w", err)
		}
	}
	if s.dev != nil {
		s.dev.Close()
	}
	if s.tunDev != nil {
		if err := s.tunDev.Close(); err != nil {
			tunErr = fmt.Errorf("close TUN: %w", err)
		}
	}
	s.logger.Info("AWG server stopped.")

	switch {
	case h3Err != nil && tunErr != nil:
		return fmt.Errorf("%w; %w", h3Err, tunErr)
	case h3Err != nil:
		return h3Err
	default:
		return tunErr // nil when the TUN closed cleanly (or was never created)
	}
}
// buildInterfaceConfig renders the interface-level part of a UAPI "set"
// request. It contains the server private key (base64 in Config, hex on the
// wire), the listen port and the obfuscation parameters. jc/jmin/jmax are
// included only when Jc is positive; s1–s4 and h1–h4 are always included.
func (s *Server) buildInterfaceConfig() (string, error) {
	keyHex, err := base64ToHex(s.cfg.PrivateKey)
	if err != nil {
		return "", fmt.Errorf("PrivateKey: %w", err)
	}

	c := &s.cfg

	// Each entry produces one "key=value\n" line, emitted in slice order.
	type line struct {
		format string
		arg    any
	}
	lines := []line{
		{"private_key=%s\n", keyHex},
		{"listen_port=%d\n", c.ListenPort},
	}
	if c.Jc > 0 {
		lines = append(lines,
			line{"jc=%d\n", c.Jc},
			line{"jmin=%d\n", c.Jmin},
			line{"jmax=%d\n", c.Jmax},
		)
	}
	lines = append(lines,
		line{"s1=%d\n", c.S1},
		line{"s2=%d\n", c.S2},
		line{"s3=%d\n", c.S3},
		line{"s4=%d\n", c.S4},
		line{"h1=%s\n", c.H1},
		line{"h2=%s\n", c.H2},
		line{"h3=%s\n", c.H3},
		line{"h4=%s\n", c.H4},
	)

	var out strings.Builder
	for _, l := range lines {
		fmt.Fprintf(&out, l.format, l.arg)
	}
	return out.String(), nil
}
// buildPeerConfig renders a UAPI "set" request that replaces every existing
// peer with one peer per access key carrying AWG data.
//
// Keys without AWG data are ignored. A key is skipped with a warning, rather
// than failing the whole sync, when its public key cannot be converted from
// base64 or when its allowed IP is empty or contains whitespace.
func (s *Server) buildPeerConfig(keys []store.AccessKey) string {
	var b strings.Builder
	b.WriteString("replace_peers=true\n")
	for _, k := range keys {
		if k.AWG == nil {
			continue
		}
		pubHex, err := base64ToHex(k.AWG.PublicKey)
		if err != nil {
			s.logger.Warn("Skipping AWG peer: bad public key.", "keyID", k.ID, "err", err)
			continue
		}
		// The value is spliced verbatim into a line-oriented key=value
		// request. An embedded newline would inject extra directives (e.g.
		// another peer with arbitrary allowed IPs). An empty value is not a
		// valid prefix and would make the device reject the entire request.
		allowedIP := strings.TrimSpace(k.AWG.AllowedIP)
		if allowedIP == "" || strings.ContainsAny(allowedIP, " \t\r\n") {
			s.logger.Warn("Skipping AWG peer: bad allowed IP.", "keyID", k.ID, "allowedIP", k.AWG.AllowedIP)
			continue
		}
		fmt.Fprintf(&b, "public_key=%s\n", pubHex)
		fmt.Fprintf(&b, "allowed_ip=%s\n", allowedIP)
	}
	return b.String()
}
// setupUAPI serves the userspace API socket (used by `awg show` / `awg set`)
// for the TUN interface until Accept fails. Start runs it in its own
// goroutine. Failure to create the socket is logged at debug level and
// otherwise ignored. Each accepted connection is handled by the AWG device in
// a separate goroutine.
func (s *Server) setupUAPI() {
	// Snapshot under the lock: Start writes these fields while holding s.mu.
	s.mu.Lock()
	name, dev := s.cfg.TUNName, s.dev
	s.mu.Unlock()

	fileUAPI, err := ipc.UAPIOpen(name)
	if err != nil {
		s.logger.Debug("UAPI open (non-fatal).", "err", err)
		return
	}
	// Released when this function returns, i.e. after the listener is gone.
	defer fileUAPI.Close()

	uapiListener, err := ipc.UAPIListen(name, fileUAPI)
	if err != nil {
		s.logger.Debug("UAPI open (non-fatal).", "err", err)
		return
	}
	defer uapiListener.Close()

	s.logger.Info("UAPI listening.", "interface", name)
	for {
		conn, err := uapiListener.Accept()
		if err != nil {
			s.logger.Debug("UAPI listener stopped.", "err", err)
			return
		}
		go dev.IpcHandle(conn)
	}
}