M cmd/shroud/main.go => cmd/shroud/main.go +236 -8
@@ 25,12 25,16 @@ import (
"golang.org/x/term"
"gopkg.in/yaml.v3"
+ "encoding/base64"
+
"sourcecraft.dev/bigbes/shroud/internal/api"
"sourcecraft.dev/bigbes/shroud/internal/awgserver"
"sourcecraft.dev/bigbes/shroud/internal/config"
"sourcecraft.dev/bigbes/shroud/internal/metrics"
+ "sourcecraft.dev/bigbes/shroud/internal/mmdb"
"sourcecraft.dev/bigbes/shroud/internal/ssserver"
"sourcecraft.dev/bigbes/shroud/internal/store"
+ "sourcecraft.dev/bigbes/shroud/internal/vless"
)
var version = "dev"
@@ 49,7 53,7 @@ func newRootCmd() *cobra.Command {
root := &cobra.Command{
Use: "shroud",
- Short: "Shadowsocks + AmneziaWG VPN server",
+ Short: "Shadowsocks + AmneziaWG + VLESS VPN server",
Version: version,
RunE: func(cmd *cobra.Command, args []string) error {
return runServe(configFile, verbose)
@@ 66,6 70,7 @@ func newRootCmd() *cobra.Command {
newCompletionCmd(),
newKeyCmd(&configFile),
newServerCmd(&configFile),
+ newVLESSCmd(&configFile),
)
return root
@@ 160,8 165,12 @@ func newKeyListCmd(configFile *string) *cobra.Command {
if k.AWG != nil {
awgInfo = fmt.Sprintf(" awg=%s", k.AWG.AllowedIP)
}
- fmt.Printf("%-6s %-20s port=%-6d cipher=%s limit=%s%s\n",
- k.ID, k.Name, k.Port, k.Method, limit, awgInfo)
+ vlessInfo := ""
+ if k.VLESS != nil {
+ vlessInfo = fmt.Sprintf(" vless=%s", k.VLESS.UUID)
+ }
+ fmt.Printf("%-6s %-20s port=%-6d cipher=%s limit=%s%s%s\n",
+ k.ID, k.Name, k.Port, k.Method, limit, awgInfo, vlessInfo)
}
return nil
},
@@ 248,6 257,13 @@ func newKeyAddCmd(configFile *string) *cobra.Command {
}
}
+ // Generate VLESS credentials if VLESS is enabled.
+ if cfg.VLESS.Enabled {
+ ak.VLESS = &store.VLESSKeyData{
+ UUID: uuid.New().String(),
+ }
+ }
+
if err := st.CreateKey(ak); err != nil {
return err
}
@@ 256,6 272,9 @@ func newKeyAddCmd(configFile *string) *cobra.Command {
if ak.AWG != nil {
fmt.Printf(" awg: %s (pubkey=%s)\n", ak.AWG.AllowedIP, ak.AWG.PublicKey)
}
+ if ak.VLESS != nil {
+ fmt.Printf(" vless: %s\n", ak.VLESS.UUID)
+ }
return nil
},
}
@@ 490,15 509,27 @@ func runServe(configFile string, verbose bool) error {
logger.Info("Generated API secret.", "secret", apiSecret)
}
+ // Set up MMDB files (download if needed, start auto-update).
+ mmdbMgr, err := mmdb.NewManager(cfg.Shadowsocks.MMDBConfig(), logger)
+ if err != nil {
+ return fmt.Errorf("setting up MMDB manager: %w", err)
+ }
+ mmdbPaths, err := mmdbMgr.Resolve()
+ if err != nil {
+ return fmt.Errorf("resolving MMDB files: %w", err)
+ }
+ mmdbMgr.StartAutoUpdate()
+ defer mmdbMgr.Stop()
+
// Set up IP info for geo-metrics.
var ip2info ipinfo.IPInfoMap
- if cfg.Shadowsocks.IPCountryDB != "" || cfg.Shadowsocks.IPASNDB != "" {
- mmdb, err := ipinfo.NewMMDBIPInfoMap(cfg.Shadowsocks.IPCountryDB, cfg.Shadowsocks.IPASNDB)
+ if mmdbPaths.CountryPath != "" || mmdbPaths.ASNPath != "" {
+ mmdbMap, err := ipinfo.NewMMDBIPInfoMap(mmdbPaths.CountryPath, mmdbPaths.ASNPath)
if err != nil {
return fmt.Errorf("loading IP info databases: %w", err)
}
- defer mmdb.Close()
- ip2info = mmdb
+ defer mmdbMap.Close()
+ ip2info = mmdbMap
}
// Set up metrics.
@@ 649,6 680,69 @@ func runServe(configFile string, verbose bool) error {
logger.Info("AmneziaWG server started.", "port", cfg.AmneziaWG.ListenPort)
}
+ // Set up VLESS+REALITY server (optional).
+ var vlessServer *vless.Server
+ if cfg.VLESS.Enabled {
+ srv := st.GetServer()
+ if srv.VLESSPrivateKey == "" {
+ priv, pub, err := vless.GenerateX25519Keypair()
+ if err != nil {
+ return fmt.Errorf("generating REALITY keypair: %w", err)
+ }
+ shortID, err := vless.GenerateShortID()
+ if err != nil {
+ return fmt.Errorf("generating REALITY short ID: %w", err)
+ }
+ if err := st.UpdateServer(func(s *store.ServerState) {
+ s.VLESSPrivateKey = priv
+ s.VLESSPublicKey = pub
+ s.VLESSShortID = shortID
+ }); err != nil {
+ return fmt.Errorf("persisting REALITY keys: %w", err)
+ }
+ logger.Info("Generated REALITY keypair.")
+ srv = st.GetServer()
+ }
+
+ privKey, err := base64.StdEncoding.DecodeString(srv.VLESSPrivateKey)
+ if err != nil {
+ return fmt.Errorf("decoding REALITY private key: %w", err)
+ }
+ var shortID [8]byte
+ shortIDBytes, err := hex.DecodeString(srv.VLESSShortID)
+ if err != nil {
+ return fmt.Errorf("decoding REALITY short ID: %w", err)
+ }
+ copy(shortID[:], shortIDBytes)
+
+ pubKey, err := base64.StdEncoding.DecodeString(srv.VLESSPublicKey)
+ if err != nil {
+ return fmt.Errorf("decoding REALITY public key: %w", err)
+ }
+
+ vlessCfg := vless.Config{
+ ListenAddr: cfg.VLESS.ListenAddr,
+ PrivateKey: privKey,
+ PublicKey: pubKey,
+ ShortID: shortID,
+ ServerNames: cfg.VLESS.ServerNames,
+ Dest: cfg.VLESS.Dest,
+ Show: cfg.VLESS.Show,
+ }
+
+ vlessServer = vless.New(vlessCfg, logger)
+ if err := vlessServer.Start(); err != nil {
+ return fmt.Errorf("starting VLESS server: %w", err)
+ }
+
+ if keys := st.ListKeys(); len(keys) > 0 {
+ if err := vlessServer.SyncKeys(keys); err != nil {
+ return fmt.Errorf("syncing VLESS keys: %w", err)
+ }
+ }
+ logger.Info("VLESS+REALITY server started.", "addr", cfg.VLESS.ListenAddr)
+ }
+
// Start Prometheus metrics + /healthz endpoint.
metricsHandler := promhttp.HandlerFor(registry, promhttp.HandlerOpts{})
metricsMux := http.NewServeMux()
@@ 675,7 769,7 @@ func runServe(configFile string, verbose bool) error {
if cfg.AmneziaWG.Enabled {
awgCfgPtr = &cfg.AmneziaWG
}
- handler := api.NewHandler(st, ss, awg, awgCfgPtr, transferTracker, version, logger)
+ handler := api.NewHandler(st, ss, awg, awgCfgPtr, vlessServer, transferTracker, version, logger)
router := api.NewRouter(apiSecret, handler)
apiServer := &http.Server{Addr: cfg.API.ListenAddr, Handler: router}
go func() {
@@ 705,6 799,9 @@ func runServe(configFile string, verbose bool) error {
<-sigCh
logger.Info("Shutting down...")
+ if vlessServer != nil {
+ vlessServer.Stop()
+ }
if awg != nil {
awg.Stop()
}
@@ 717,6 814,137 @@ func runServe(configFile string, verbose bool) error {
return nil
}
+// --- vless commands ---
+
+func newVLESSCmd(configFile *string) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "vless",
+ Short: "VLESS+REALITY management commands",
+ }
+ cmd.AddCommand(
+ newVLESSKeygenCmd(),
+ newVLESSInfoCmd(configFile),
+ newVLESSShareCmd(configFile),
+ )
+ return cmd
+}
+
+func newVLESSKeygenCmd() *cobra.Command {
+ return &cobra.Command{
+ Use: "keygen",
+ Short: "Generate a new REALITY x25519 keypair and short ID",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ priv, pub, err := vless.GenerateX25519Keypair()
+ if err != nil {
+ return err
+ }
+ shortID, err := vless.GenerateShortID()
+ if err != nil {
+ return err
+ }
+ fmt.Printf("Private Key: %s\n", priv)
+ fmt.Printf("Public Key: %s\n", pub)
+ fmt.Printf("Short ID: %s\n", shortID)
+ return nil
+ },
+ }
+}
+
+func newVLESSInfoCmd(configFile *string) *cobra.Command {
+ return &cobra.Command{
+ Use: "info",
+ Short: "Show VLESS+REALITY server info",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ cfg, err := config.Load(*configFile)
+ if err != nil {
+ return fmt.Errorf("loading config: %w", err)
+ }
+ if !cfg.VLESS.Enabled {
+ fmt.Println("VLESS is not enabled.")
+ return nil
+ }
+ st, err := openStore(*configFile)
+ if err != nil {
+ return err
+ }
+ srv := st.GetServer()
+ fmt.Printf("Listen Addr: %s\n", cfg.VLESS.ListenAddr)
+ fmt.Printf("Server Names: %v\n", cfg.VLESS.ServerNames)
+ fmt.Printf("Dest: %s\n", cfg.VLESS.Dest)
+ if srv.VLESSPublicKey != "" {
+ fmt.Printf("Public Key: %s\n", srv.VLESSPublicKey)
+ fmt.Printf("Short ID: %s\n", srv.VLESSShortID)
+ } else {
+ fmt.Println("REALITY keypair not yet generated (start the server first).")
+ }
+ return nil
+ },
+ }
+}
+
+func newVLESSShareCmd(configFile *string) *cobra.Command {
+ return &cobra.Command{
+ Use: "share <key-id>",
+ Short: "Generate a VLESS share link for a key",
+ Args: cobra.ExactArgs(1),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ cfg, err := config.Load(*configFile)
+ if err != nil {
+ return fmt.Errorf("loading config: %w", err)
+ }
+ if !cfg.VLESS.Enabled {
+ return fmt.Errorf("VLESS is not enabled")
+ }
+ st, err := openStore(*configFile)
+ if err != nil {
+ return err
+ }
+ srv := st.GetServer()
+ if srv.VLESSPublicKey == "" {
+ return fmt.Errorf("REALITY keypair not yet generated (start the server first)")
+ }
+
+ key, ok := st.GetKey(args[0])
+ if !ok {
+ return fmt.Errorf("key %q not found", args[0])
+ }
+ if key.VLESS == nil {
+ return fmt.Errorf("key %q has no VLESS credentials", args[0])
+ }
+
+ hostname := srv.Hostname
+ if hostname == "" {
+ hostname = "localhost"
+ }
+
+ // Extract port from listen address.
+ _, portStr, err := net.SplitHostPort(cfg.VLESS.ListenAddr)
+ if err != nil {
+ portStr = "443"
+ }
+
+ sni := ""
+ if len(cfg.VLESS.ServerNames) > 0 {
+ sni = cfg.VLESS.ServerNames[0]
+ }
+
+ // Base64url encode the public key (no padding).
+ pubKeyB64URL := base64.RawURLEncoding.EncodeToString([]byte(srv.VLESSPublicKey))
+
+ name := key.Name
+ if name == "" {
+ name = key.ID
+ }
+
+ fmt.Printf("vless://%s@%s:%s?encryption=none&security=reality&sni=%s&fp=chrome&pbk=%s&sid=%s&type=tcp#%s\n",
+ key.VLESS.UUID, hostname, portStr, sni, pubKeyB64URL, srv.VLESSShortID, name)
+ return nil
+ },
+ }
+}
+
func pickRandomPort() (int, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
M config.example.yaml => config.example.yaml +14 -2
@@ 25,8 25,12 @@ shadowsocks:
default_cipher: chacha20-ietf-poly1305
nat_timeout: 5m
replay_history: 10000
- ip_country_db: "" # optional: MaxMind GeoLite2-Country.mmdb for per-country metrics
- ip_asn_db: "" # optional: MaxMind GeoLite2-ASN.mmdb for per-ASN metrics
+ # GeoIP databases: URL or local path. Defaults to GeoLite2 from github.com/P3TERX/GeoLite.mmdb.
+ # Set to "none" to disable. URLs are downloaded and cached automatically.
+ ip_country_db: "" # default: https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb
+ ip_asn_db: "" # default: https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-ASN.mmdb
+ ip_db_cache_dir: "" # default: /var/lib/shroud/mmdb
+ ip_db_auto_update: true # auto-update MMDB files daily from URLs
amneziawg:
enabled: true
@@ 52,6 56,14 @@ amneziawg:
h3: "250000-300000"
h4: "350000-400000"
+vless:
+ enabled: false
+ listen_addr: ":443" # REALITY listener address
+ server_names: # SNIs accepted by REALITY handshake
+ - www.microsoft.com
+ dest: "www.microsoft.com:443" # decoy forward target for unauthenticated probes
+ show: false # debug REALITY handshakes
+
# ACME (Let's Encrypt) certificate settings.
# Used by AmneziaWG HTTP/3 cover server for DPI resistance.
acme:
M go.mod => go.mod +2 -0
@@ 46,6 46,7 @@ require (
github.com/opencontainers/selinux v1.12.0 // indirect
github.com/oschwald/geoip2-golang v1.11.0 // indirect
github.com/oschwald/maxminddb-golang v1.13.1 // indirect
+ github.com/pires/go-proxyproto v0.7.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/prometheus-community/go-runit v0.1.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
@@ 57,6 58,7 @@ require (
github.com/spf13/pflag v1.0.10 // indirect
github.com/tevino/abool v1.2.0 // indirect
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
+ github.com/xtls/reality v0.0.0-20240712055506-48f0b2d5ed6d // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
M go.sum => go.sum +4 -0
@@ 78,6 78,8 @@ github.com/oschwald/geoip2-golang v1.11.0 h1:hNENhCn1Uyzhf9PTmquXENiWS6AlxAEnBII
github.com/oschwald/geoip2-golang v1.11.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo=
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
+github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs=
+github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
@@ 129,6 131,8 @@ github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
+github.com/xtls/reality v0.0.0-20240712055506-48f0b2d5ed6d h1:+B97uD9uHLgAAulhigmys4BVwZZypzK7gPN3WtpgRJg=
+github.com/xtls/reality v0.0.0-20240712055506-48f0b2d5ed6d/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
M internal/api/handlers.go => internal/api/handlers.go +14 -0
@@ 6,6 6,8 @@ import (
"strings"
"text/template"
+ "github.com/google/uuid"
+
"sourcecraft.dev/bigbes/shroud/internal/awgserver"
"sourcecraft.dev/bigbes/shroud/internal/config"
"sourcecraft.dev/bigbes/shroud/internal/store"
@@ 142,6 144,13 @@ func (h *Handler) doCreateKey(w http.ResponseWriter, req CreateAccessKeyRequest)
ak.AWG = awgData
}
+ // Generate VLESS credentials if VLESS is enabled.
+ if h.vlessServer != nil {
+ ak.VLESS = &store.VLESSKeyData{
+ UUID: uuid.New().String(),
+ }
+ }
+
if err := h.store.CreateKey(ak); err != nil {
writeError(w, http.StatusConflict, err.Error())
return
@@ 307,6 316,11 @@ func (h *Handler) keyToResponse(k store.AccessKey) AccessKeyResponse {
if k.DataLimit != nil {
resp.DataLimit = &DataLimitJSON{Bytes: k.DataLimit.Bytes}
}
+ if k.VLESS != nil && h.vlessServer != nil {
+ resp.VLESS = &VLESSKeyResponse{
+ UUID: k.VLESS.UUID,
+ }
+ }
if k.AWG != nil && h.awgConfig != nil {
srv := h.store.GetServer()
hostname := h.awgHostname(srv.Hostname)
M internal/api/models.go => internal/api/models.go +6 -1
@@ 8,7 8,12 @@ type AccessKeyResponse struct {
Method string `json:"method"`
AccessURL string `json:"accessUrl"`
DataLimit *DataLimitJSON `json:"dataLimit,omitempty"`
- AWG *AWGKeyResponse `json:"awg,omitempty"`
+ AWG *AWGKeyResponse `json:"awg,omitempty"`
+ VLESS *VLESSKeyResponse `json:"vless,omitempty"`
+}
+
+type VLESSKeyResponse struct {
+ UUID string `json:"uuid"`
}
type AWGKeyResponse struct {
M internal/api/router.go => internal/api/router.go +23 -15
@@ 12,27 12,30 @@ import (
"sourcecraft.dev/bigbes/shroud/internal/metrics"
"sourcecraft.dev/bigbes/shroud/internal/ssserver"
"sourcecraft.dev/bigbes/shroud/internal/store"
+ "sourcecraft.dev/bigbes/shroud/internal/vless"
)
type Handler struct {
- store store.Store
- ssServer *ssserver.Server
- awgServer *awgserver.Server
- awgConfig *config.AmneziaWGConfig
- tracker *metrics.TransferTracker
- version string
- logger *slog.Logger
+ store store.Store
+ ssServer *ssserver.Server
+ awgServer *awgserver.Server
+ awgConfig *config.AmneziaWGConfig
+ vlessServer *vless.Server
+ tracker *metrics.TransferTracker
+ version string
+ logger *slog.Logger
}
-func NewHandler(s store.Store, ss *ssserver.Server, awg *awgserver.Server, awgCfg *config.AmneziaWGConfig, tracker *metrics.TransferTracker, version string, logger *slog.Logger) *Handler {
+func NewHandler(s store.Store, ss *ssserver.Server, awg *awgserver.Server, awgCfg *config.AmneziaWGConfig, vl *vless.Server, tracker *metrics.TransferTracker, version string, logger *slog.Logger) *Handler {
return &Handler{
- store: s,
- ssServer: ss,
- awgServer: awg,
- awgConfig: awgCfg,
- tracker: tracker,
- version: version,
- logger: logger,
+ store: s,
+ ssServer: ss,
+ awgServer: awg,
+ awgConfig: awgCfg,
+ vlessServer: vl,
+ tracker: tracker,
+ version: version,
+ logger: logger,
}
}
@@ 101,6 104,11 @@ func (h *Handler) syncKeys() error {
return err
}
}
+ if h.vlessServer != nil {
+ if err := h.vlessServer.SyncKeys(keys); err != nil {
+ return err
+ }
+ }
return nil
}
M internal/config/config.go => internal/config/config.go +67 -2
@@ 5,6 5,8 @@ import (
"os"
"path/filepath"
+ "sourcecraft.dev/bigbes/shroud/internal/mmdb"
+
"gopkg.in/yaml.v3"
)
@@ 19,6 21,8 @@ type Config struct {
Shadowsocks ShadowsocksConfig `yaml:"shadowsocks"`
// AmneziaWG VPN settings.
AmneziaWG AmneziaWGConfig `yaml:"amneziawg"`
+ // VLESS over REALITY transport settings.
+ VLESS VLESSConfig `yaml:"vless"`
// ACME (Let's Encrypt) certificate settings.
ACME ACMEConfig `yaml:"acme"`
// Path to persistent state file.
@@ 100,6 104,19 @@ type AmneziaWGConfig struct {
H4 string `yaml:"h4"`
}
+type VLESSConfig struct {
+ // Enable VLESS over REALITY transport.
+ Enabled bool `yaml:"enabled"`
+ // Listen address for the REALITY listener (e.g., ":443").
+ ListenAddr string `yaml:"listen_addr"`
+ // Accepted SNIs for REALITY handshake (e.g., ["www.microsoft.com"]).
+ ServerNames []string `yaml:"server_names"`
+ // Decoy forward target for unauthenticated probes (e.g., "www.microsoft.com:443").
+ Dest string `yaml:"dest"`
+ // Enable debug logging for REALITY handshakes.
+ Show bool `yaml:"show"`
+}
+
type ACMEConfig struct {
// Directory for cached TLS certificates.
CertCache string `yaml:"cert_cache"`
@@ 118,10 135,19 @@ type ShadowsocksConfig struct {
NATTimeout string `yaml:"nat_timeout"`
// Replay protection history size (0 = disabled).
ReplayHistory int `yaml:"replay_history"`
- // Path to IP-to-country MaxMind MMDB file.
+ // URL or local path to IP-to-country MaxMind MMDB file.
+ // Defaults to GeoLite2-Country.mmdb from github.com/P3TERX/GeoLite.mmdb.
+ // Set to "none" to disable.
IPCountryDB string `yaml:"ip_country_db"`
- // Path to IP-to-ASN MaxMind MMDB file.
+ // URL or local path to IP-to-ASN MaxMind MMDB file.
+ // Defaults to GeoLite2-ASN.mmdb from github.com/P3TERX/GeoLite.mmdb.
+ // Set to "none" to disable.
IPASNDB string `yaml:"ip_asn_db"`
+ // Directory to cache downloaded MMDB files.
+ // Defaults to "/var/lib/shroud/mmdb".
+ IPDBCacheDir string `yaml:"ip_db_cache_dir"`
+ // Auto-update MMDB files daily from URLs. Defaults to true.
+ IPDBAutoUpdate *bool `yaml:"ip_db_auto_update"`
}
func Load(filename string) (*Config, error) {
@@ 161,6 187,16 @@ func (c *Config) setDefaults() {
if c.Shadowsocks.NATTimeout == "" {
c.Shadowsocks.NATTimeout = "5m"
}
+ if c.Shadowsocks.IPCountryDB == "" {
+ c.Shadowsocks.IPCountryDB = mmdb.DefaultCountryURL
+ }
+ if c.Shadowsocks.IPASNDB == "" {
+ c.Shadowsocks.IPASNDB = mmdb.DefaultASNURL
+ }
+ if c.Shadowsocks.IPDBAutoUpdate == nil {
+ t := true
+ c.Shadowsocks.IPDBAutoUpdate = &t
+ }
if c.Server.Name == "" {
c.Server.Name = "Outline Server"
}
@@ 179,6 215,12 @@ func (c *Config) setDefaults() {
c.AmneziaWG.MTU = 1420
}
}
+ // VLESS defaults.
+ if c.VLESS.Enabled {
+ if c.VLESS.ListenAddr == "" {
+ c.VLESS.ListenAddr = ":443"
+ }
+ }
// AmneziaWG domain defaults to server hostname.
if c.AmneziaWG.Enabled && c.AmneziaWG.Domain == "" {
c.AmneziaWG.Domain = c.Server.Hostname
@@ 214,6 256,29 @@ func (c *AmneziaWGConfig) AWGHostname(serverHostname string) string {
return ""
}
+// MMDBConfig returns the MMDB manager configuration derived from shadowsocks settings.
+// Sources set to "none" are treated as disabled (empty).
+func (c *ShadowsocksConfig) MMDBConfig() mmdb.Config {
+ country := c.IPCountryDB
+ if country == "none" {
+ country = ""
+ }
+ asn := c.IPASNDB
+ if asn == "none" {
+ asn = ""
+ }
+ autoUpdate := true
+ if c.IPDBAutoUpdate != nil {
+ autoUpdate = *c.IPDBAutoUpdate
+ }
+ return mmdb.Config{
+ CountryURL: country,
+ ASNURL: asn,
+ CacheDir: c.IPDBCacheDir,
+ AutoUpdate: autoUpdate,
+ }
+}
+
// StateFile returns the absolute path to the state file.
func (c *Config) StateFile() string {
if filepath.IsAbs(c.State) {
A internal/mmdb/mmdb.go => internal/mmdb/mmdb.go +197 -0
@@ 0,0 1,197 @@
+// Package mmdb handles downloading, caching, and auto-updating MaxMind MMDB files.
+package mmdb
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+)
+
+const (
+ DefaultCountryURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb"
+ DefaultASNURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-ASN.mmdb"
+ DefaultCacheDir = "/var/lib/shroud/mmdb"
+
+ countryFile = "GeoLite2-Country.mmdb"
+ asnFile = "GeoLite2-ASN.mmdb"
+
+ updateInterval = 24 * time.Hour
+)
+
+// Config holds MMDB source configuration.
+type Config struct {
+ CountryURL string // URL or local path for country DB. Empty disables country lookups.
+ ASNURL string // URL or local path for ASN DB. Empty disables ASN lookups.
+ CacheDir string // Directory to cache downloaded files.
+ AutoUpdate bool // Whether to auto-update from URLs daily.
+}
+
+// Manager handles MMDB file lifecycle: download, cache, and periodic updates.
+type Manager struct {
+ cfg Config
+ logger *slog.Logger
+
+ cancel context.CancelFunc
+ wg sync.WaitGroup
+}
+
+// Result holds resolved local paths to MMDB files.
+type Result struct {
+ CountryPath string
+ ASNPath string
+}
+
+// NewManager creates a new MMDB manager and ensures files are available.
+// It downloads missing files from URLs if needed.
+func NewManager(cfg Config, logger *slog.Logger) (*Manager, error) {
+ if cfg.CacheDir == "" {
+ cfg.CacheDir = DefaultCacheDir
+ }
+ if err := os.MkdirAll(cfg.CacheDir, 0o755); err != nil {
+ return nil, fmt.Errorf("creating MMDB cache dir %s: %w", cfg.CacheDir, err)
+ }
+ m := &Manager{cfg: cfg, logger: logger}
+ return m, nil
+}
+
+// Resolve ensures MMDB files are available locally and returns their paths.
+// For URLs, it downloads the file if not cached. For local paths, it returns them as-is.
+func (m *Manager) Resolve() (Result, error) {
+ var res Result
+ var err error
+
+ if m.cfg.CountryURL != "" {
+ res.CountryPath, err = m.resolve(m.cfg.CountryURL, countryFile)
+ if err != nil {
+ return res, fmt.Errorf("resolving country DB: %w", err)
+ }
+ }
+ if m.cfg.ASNURL != "" {
+ res.ASNPath, err = m.resolve(m.cfg.ASNURL, asnFile)
+ if err != nil {
+ return res, fmt.Errorf("resolving ASN DB: %w", err)
+ }
+ }
+ return res, nil
+}
+
+// StartAutoUpdate launches a background goroutine that re-downloads MMDB files daily.
+// Call Stop() to cancel.
+func (m *Manager) StartAutoUpdate() {
+ if !m.cfg.AutoUpdate || !m.hasURLSources() {
+ return
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ m.cancel = cancel
+ m.wg.Add(1)
+ go func() {
+ defer m.wg.Done()
+ m.logger.Info("MMDB auto-update enabled.", "interval", updateInterval)
+ ticker := time.NewTicker(updateInterval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ m.update()
+ }
+ }
+ }()
+}
+
+// Stop cancels auto-update and waits for the goroutine to exit.
+func (m *Manager) Stop() {
+ if m.cancel != nil {
+ m.cancel()
+ m.wg.Wait()
+ }
+}
+
+// resolve returns a local path for the given source.
+// If source is a URL (starts with http:// or https://), it downloads to cacheDir/filename.
+// Otherwise it treats source as a local file path.
+func (m *Manager) resolve(source, filename string) (string, error) {
+ if !isURL(source) {
+ // Local file path — just verify it exists.
+ if _, err := os.Stat(source); err != nil {
+ return "", fmt.Errorf("local file %s: %w", source, err)
+ }
+ return source, nil
+ }
+ dest := filepath.Join(m.cfg.CacheDir, filename)
+ if _, err := os.Stat(dest); err == nil {
+ return dest, nil // Already cached.
+ }
+ if err := download(source, dest); err != nil {
+ return "", err
+ }
+ m.logger.Info("Downloaded MMDB file.", "url", source, "path", dest)
+ return dest, nil
+}
+
+func (m *Manager) update() {
+ if m.cfg.CountryURL != "" && isURL(m.cfg.CountryURL) {
+ dest := filepath.Join(m.cfg.CacheDir, countryFile)
+ if err := download(m.cfg.CountryURL, dest); err != nil {
+ m.logger.Error("Failed to update country DB.", "err", err)
+ } else {
+ m.logger.Info("Updated country DB.", "url", m.cfg.CountryURL)
+ }
+ }
+ if m.cfg.ASNURL != "" && isURL(m.cfg.ASNURL) {
+ dest := filepath.Join(m.cfg.CacheDir, asnFile)
+ if err := download(m.cfg.ASNURL, dest); err != nil {
+ m.logger.Error("Failed to update ASN DB.", "err", err)
+ } else {
+ m.logger.Info("Updated ASN DB.", "url", m.cfg.ASNURL)
+ }
+ }
+}
+
+func (m *Manager) hasURLSources() bool {
+ return (m.cfg.CountryURL != "" && isURL(m.cfg.CountryURL)) ||
+ (m.cfg.ASNURL != "" && isURL(m.cfg.ASNURL))
+}
+
+func isURL(s string) bool {
+ return len(s) > 8 && (s[:7] == "http://" || s[:8] == "https://")
+}
+
+func download(url, dest string) error {
+ resp, err := http.Get(url)
+ if err != nil {
+ return fmt.Errorf("downloading %s: %w", url, err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("downloading %s: HTTP %d", url, resp.StatusCode)
+ }
+
+ // Write to temp file, then rename for atomicity.
+ tmp := dest + ".tmp"
+ f, err := os.Create(tmp)
+ if err != nil {
+ return fmt.Errorf("creating temp file %s: %w", tmp, err)
+ }
+ if _, err := io.Copy(f, resp.Body); err != nil {
+ f.Close()
+ os.Remove(tmp)
+ return fmt.Errorf("writing %s: %w", tmp, err)
+ }
+ if err := f.Close(); err != nil {
+ os.Remove(tmp)
+ return err
+ }
+ if err := os.Rename(tmp, dest); err != nil {
+ os.Remove(tmp)
+ return fmt.Errorf("renaming %s to %s: %w", tmp, dest, err)
+ }
+ return nil
+}
M internal/store/store.go => internal/store/store.go +9 -1
@@ 13,7 13,8 @@ type AccessKey struct {
Port int `yaml:"port"`
Method string `yaml:"method"`
DataLimit *DataLimit `yaml:"data_limit,omitempty"`
- AWG *AWGKeyData `yaml:"awg,omitempty"`
+ AWG *AWGKeyData `yaml:"awg,omitempty"`
+ VLESS *VLESSKeyData `yaml:"vless,omitempty"`
}
type AWGKeyData struct {
@@ 22,6 23,10 @@ type AWGKeyData struct {
AllowedIP string `yaml:"allowed_ip"`
}
+type VLESSKeyData struct {
+ UUID string `yaml:"uuid"`
+}
+
type DataLimit struct {
Bytes int64 `yaml:"bytes"`
}
@@ 39,6 44,9 @@ type ServerState struct {
PortForNewAccessKeys int `yaml:"port_for_new_access_keys"`
AWGPrivateKey string `yaml:"awg_private_key,omitempty"`
AWGPublicKey string `yaml:"awg_public_key,omitempty"`
+ VLESSPrivateKey string `yaml:"vless_private_key,omitempty"`
+ VLESSPublicKey string `yaml:"vless_public_key,omitempty"`
+ VLESSShortID string `yaml:"vless_short_id,omitempty"`
}
type Store interface {
A internal/vless/keygen.go => internal/vless/keygen.go +37 -0
@@ 0,0 1,37 @@
+package vless
+
+import (
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+
+ "golang.org/x/crypto/curve25519"
+)
+
+// GenerateX25519Keypair generates a raw x25519 keypair for REALITY.
+// Unlike WireGuard, no clamping is applied to the private key.
+func GenerateX25519Keypair() (privBase64, pubBase64 string, err error) {
+ priv := make([]byte, 32)
+ if _, err := rand.Read(priv); err != nil {
+ return "", "", fmt.Errorf("generating private key: %w", err)
+ }
+
+ pub, err := curve25519.X25519(priv, curve25519.Basepoint)
+ if err != nil {
+ return "", "", fmt.Errorf("computing public key: %w", err)
+ }
+
+ return base64.StdEncoding.EncodeToString(priv),
+ base64.StdEncoding.EncodeToString(pub),
+ nil
+}
+
+// GenerateShortID generates a random 8-byte REALITY short ID as a hex string.
+func GenerateShortID() (string, error) {
+ var id [8]byte
+ if _, err := rand.Read(id[:]); err != nil {
+ return "", fmt.Errorf("generating short ID: %w", err)
+ }
+ return hex.EncodeToString(id[:]), nil
+}
A internal/vless/keygen_test.go => internal/vless/keygen_test.go +68 -0
@@ 0,0 1,68 @@
+package vless
+
+import (
+ "encoding/base64"
+ "encoding/hex"
+ "testing"
+)
+
+func TestGenerateX25519Keypair(t *testing.T) {
+ priv, pub, err := GenerateX25519Keypair()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ privBytes, err := base64.StdEncoding.DecodeString(priv)
+ if err != nil {
+ t.Fatalf("invalid base64 private key: %v", err)
+ }
+ if len(privBytes) != 32 {
+ t.Errorf("private key length: got %d, want 32", len(privBytes))
+ }
+
+ pubBytes, err := base64.StdEncoding.DecodeString(pub)
+ if err != nil {
+ t.Fatalf("invalid base64 public key: %v", err)
+ }
+ if len(pubBytes) != 32 {
+ t.Errorf("public key length: got %d, want 32", len(pubBytes))
+ }
+
+ // Keys should be different.
+ if priv == pub {
+ t.Error("private and public keys should differ")
+ }
+}
+
+func TestGenerateX25519Keypair_Unique(t *testing.T) {
+ priv1, _, _ := GenerateX25519Keypair()
+ priv2, _, _ := GenerateX25519Keypair()
+ if priv1 == priv2 {
+ t.Error("two generated keypairs should not be identical")
+ }
+}
+
+func TestGenerateShortID(t *testing.T) {
+ id, err := GenerateShortID()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(id) != 16 {
+ t.Errorf("short ID hex length: got %d, want 16", len(id))
+ }
+ b, err := hex.DecodeString(id)
+ if err != nil {
+ t.Fatalf("invalid hex short ID: %v", err)
+ }
+ if len(b) != 8 {
+ t.Errorf("short ID byte length: got %d, want 8", len(b))
+ }
+}
+
+func TestGenerateShortID_Unique(t *testing.T) {
+ id1, _ := GenerateShortID()
+ id2, _ := GenerateShortID()
+ if id1 == id2 {
+ t.Error("two generated short IDs should not be identical")
+ }
+}
A internal/vless/protocol.go => internal/vless/protocol.go +124 -0
@@ 0,0 1,124 @@
+package vless
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+)
+
+const (
+ Version = 0
+
+ CmdTCP byte = 0x01
+ CmdUDP byte = 0x02
+
+ AddrIPv4 byte = 0x01
+ AddrDomain byte = 0x02
+ AddrIPv6 byte = 0x03
+)
+
+// Request represents a parsed VLESS request header.
+type Request struct {
+ UUID [16]byte
+ Command byte
+ Port uint16
+ Address string
+ Addons []byte
+}
+
+// Target returns the address:port string for dialing.
+func (r *Request) Target() string {
+ return net.JoinHostPort(r.Address, fmt.Sprintf("%d", r.Port))
+}
+
+// ParseRequest reads a VLESS request header from r.
+func ParseRequest(r io.Reader) (*Request, error) {
+ // Version (1 byte).
+ var ver [1]byte
+ if _, err := io.ReadFull(r, ver[:]); err != nil {
+ return nil, fmt.Errorf("reading version: %w", err)
+ }
+ if ver[0] != Version {
+ return nil, fmt.Errorf("unsupported VLESS version: %d", ver[0])
+ }
+
+ // UUID (16 bytes).
+ var req Request
+ if _, err := io.ReadFull(r, req.UUID[:]); err != nil {
+ return nil, fmt.Errorf("reading UUID: %w", err)
+ }
+
+ // Addons length (1 byte) + addons.
+ var addonsLen [1]byte
+ if _, err := io.ReadFull(r, addonsLen[:]); err != nil {
+ return nil, fmt.Errorf("reading addons length: %w", err)
+ }
+ if addonsLen[0] > 0 {
+ req.Addons = make([]byte, addonsLen[0])
+ if _, err := io.ReadFull(r, req.Addons); err != nil {
+ return nil, fmt.Errorf("reading addons: %w", err)
+ }
+ }
+
+ // Command (1 byte).
+ var cmd [1]byte
+ if _, err := io.ReadFull(r, cmd[:]); err != nil {
+ return nil, fmt.Errorf("reading command: %w", err)
+ }
+ req.Command = cmd[0]
+ if req.Command != CmdTCP && req.Command != CmdUDP {
+ return nil, fmt.Errorf("unsupported command: 0x%02x", req.Command)
+ }
+
+ // Port (2 bytes, big-endian).
+ var portBuf [2]byte
+ if _, err := io.ReadFull(r, portBuf[:]); err != nil {
+ return nil, fmt.Errorf("reading port: %w", err)
+ }
+ req.Port = binary.BigEndian.Uint16(portBuf[:])
+
+ // Address type (1 byte) + address.
+ var addrType [1]byte
+ if _, err := io.ReadFull(r, addrType[:]); err != nil {
+ return nil, fmt.Errorf("reading address type: %w", err)
+ }
+
+ switch addrType[0] {
+ case AddrIPv4:
+ var ip [4]byte
+ if _, err := io.ReadFull(r, ip[:]); err != nil {
+ return nil, fmt.Errorf("reading IPv4 address: %w", err)
+ }
+ req.Address = net.IP(ip[:]).String()
+
+ case AddrDomain:
+ var domainLen [1]byte
+ if _, err := io.ReadFull(r, domainLen[:]); err != nil {
+ return nil, fmt.Errorf("reading domain length: %w", err)
+ }
+ domain := make([]byte, domainLen[0])
+ if _, err := io.ReadFull(r, domain); err != nil {
+ return nil, fmt.Errorf("reading domain: %w", err)
+ }
+ req.Address = string(domain)
+
+ case AddrIPv6:
+ var ip [16]byte
+ if _, err := io.ReadFull(r, ip[:]); err != nil {
+ return nil, fmt.Errorf("reading IPv6 address: %w", err)
+ }
+ req.Address = net.IP(ip[:]).String()
+
+ default:
+ return nil, fmt.Errorf("unsupported address type: 0x%02x", addrType[0])
+ }
+
+ return &req, nil
+}
+
+// WriteResponse writes the VLESS response header (version 0, no addons).
+func WriteResponse(w io.Writer) error {
+ _, err := w.Write([]byte{Version, 0x00})
+ return err
+}
A internal/vless/protocol_test.go => internal/vless/protocol_test.go +173 -0
@@ 0,0 1,173 @@
+package vless
+
+import (
+ "bytes"
+ "encoding/binary"
+ "net"
+ "testing"
+)
+
+func buildRequest(t *testing.T, uuid [16]byte, addons []byte, cmd byte, port uint16, addrType byte, addr []byte) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+
+ buf.WriteByte(Version) // version
+ buf.Write(uuid[:]) // UUID
+
+ if addons == nil {
+ buf.WriteByte(0) // addons length
+ } else {
+ buf.WriteByte(byte(len(addons)))
+ buf.Write(addons)
+ }
+
+ buf.WriteByte(cmd) // command
+
+ var portBuf [2]byte
+ binary.BigEndian.PutUint16(portBuf[:], port)
+ buf.Write(portBuf[:]) // port
+
+ buf.WriteByte(addrType) // address type
+ buf.Write(addr) // address
+
+ return buf.Bytes()
+}
+
+func TestParseRequest_TCPIPv4(t *testing.T) {
+ uuid := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
+ ip := net.ParseIP("192.168.1.1").To4()
+ data := buildRequest(t, uuid, nil, CmdTCP, 8080, AddrIPv4, ip)
+
+ req, err := ParseRequest(bytes.NewReader(data))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if req.UUID != uuid {
+ t.Errorf("UUID mismatch: got %v, want %v", req.UUID, uuid)
+ }
+ if req.Command != CmdTCP {
+ t.Errorf("Command mismatch: got %d, want %d", req.Command, CmdTCP)
+ }
+ if req.Port != 8080 {
+ t.Errorf("Port mismatch: got %d, want 8080", req.Port)
+ }
+ if req.Address != "192.168.1.1" {
+ t.Errorf("Address mismatch: got %q, want %q", req.Address, "192.168.1.1")
+ }
+ if req.Target() != "192.168.1.1:8080" {
+ t.Errorf("Target mismatch: got %q", req.Target())
+ }
+}
+
+func TestParseRequest_UDPDomain(t *testing.T) {
+ uuid := [16]byte{0xaa, 0xbb, 0xcc, 0xdd}
+ domain := "example.com"
+ addrBuf := append([]byte{byte(len(domain))}, []byte(domain)...)
+ data := buildRequest(t, uuid, nil, CmdUDP, 53, AddrDomain, addrBuf)
+
+ req, err := ParseRequest(bytes.NewReader(data))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if req.Command != CmdUDP {
+ t.Errorf("Command mismatch: got %d, want %d", req.Command, CmdUDP)
+ }
+ if req.Port != 53 {
+ t.Errorf("Port mismatch: got %d, want 53", req.Port)
+ }
+ if req.Address != "example.com" {
+ t.Errorf("Address mismatch: got %q, want %q", req.Address, "example.com")
+ }
+}
+
+func TestParseRequest_TCPIPv6(t *testing.T) {
+ uuid := [16]byte{0xff}
+ ip := net.ParseIP("2001:db8::1").To16()
+ data := buildRequest(t, uuid, nil, CmdTCP, 443, AddrIPv6, ip)
+
+ req, err := ParseRequest(bytes.NewReader(data))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if req.Address != "2001:db8::1" {
+ t.Errorf("Address mismatch: got %q, want %q", req.Address, "2001:db8::1")
+ }
+}
+
+func TestParseRequest_WithAddons(t *testing.T) {
+ uuid := [16]byte{0x01}
+ addons := []byte{0x0a, 0x0b, 0x0c}
+ ip := net.ParseIP("10.0.0.1").To4()
+ data := buildRequest(t, uuid, addons, CmdTCP, 80, AddrIPv4, ip)
+
+ req, err := ParseRequest(bytes.NewReader(data))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(req.Addons) != 3 {
+ t.Errorf("Addons length mismatch: got %d, want 3", len(req.Addons))
+ }
+ if req.Address != "10.0.0.1" {
+ t.Errorf("Address mismatch: got %q", req.Address)
+ }
+}
+
+func TestParseRequest_BadVersion(t *testing.T) {
+ data := []byte{0x01} // version 1, invalid
+ _, err := ParseRequest(bytes.NewReader(data))
+ if err == nil {
+ t.Fatal("expected error for bad version")
+ }
+}
+
+func TestParseRequest_Truncated(t *testing.T) {
+ // Only version + partial UUID.
+ data := []byte{0x00, 0x01, 0x02}
+ _, err := ParseRequest(bytes.NewReader(data))
+ if err == nil {
+ t.Fatal("expected error for truncated input")
+ }
+}
+
+func TestParseRequest_BadCommand(t *testing.T) {
+ uuid := [16]byte{}
+ ip := net.ParseIP("1.2.3.4").To4()
+ data := buildRequest(t, uuid, nil, 0x03, 80, AddrIPv4, ip) // cmd=Mux, unsupported
+
+ _, err := ParseRequest(bytes.NewReader(data))
+ if err == nil {
+ t.Fatal("expected error for unsupported command")
+ }
+}
+
+func TestParseRequest_BadAddrType(t *testing.T) {
+ uuid := [16]byte{}
+ // Build manually with bad address type.
+ var buf bytes.Buffer
+ buf.WriteByte(Version)
+ buf.Write(uuid[:])
+ buf.WriteByte(0) // no addons
+ buf.WriteByte(0x01) // TCP
+ var portBuf [2]byte
+ binary.BigEndian.PutUint16(portBuf[:], 80)
+ buf.Write(portBuf[:])
+ buf.WriteByte(0xFF) // bad address type
+
+ _, err := ParseRequest(bytes.NewReader(buf.Bytes()))
+ if err == nil {
+ t.Fatal("expected error for bad address type")
+ }
+}
+
+func TestWriteResponse(t *testing.T) {
+ var buf bytes.Buffer
+ if err := WriteResponse(&buf); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if buf.Len() != 2 {
+ t.Fatalf("response length mismatch: got %d, want 2", buf.Len())
+ }
+ if buf.Bytes()[0] != 0x00 || buf.Bytes()[1] != 0x00 {
+ t.Errorf("response bytes mismatch: got %v, want [0x00, 0x00]", buf.Bytes())
+ }
+}
A internal/vless/relay.go => internal/vless/relay.go +114 -0
@@ 0,0 1,114 @@
+package vless
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "sync"
+ "time"
+)
+
+const udpIdleTimeout = 5 * time.Minute
+
+// relayTCP dials the target and relays data bidirectionally.
+// The VLESS response header is written before relaying begins.
+func relayTCP(clientConn net.Conn, target string) error {
+ remote, err := net.DialTimeout("tcp", target, 10*time.Second)
+ if err != nil {
+ return err
+ }
+ defer remote.Close()
+
+ if err := WriteResponse(clientConn); err != nil {
+ return err
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ // client → remote
+ go func() {
+ defer wg.Done()
+ io.Copy(remote, clientConn)
+ remote.(*net.TCPConn).CloseWrite()
+ }()
+
+ // remote → client
+ go func() {
+ defer wg.Done()
+ io.Copy(clientConn, remote)
+ }()
+
+ wg.Wait()
+ return nil
+}
+
+// relayUDP handles VLESS UDP-over-TCP relay.
+// UDP datagrams are framed as [2B length BE][payload] on the TCP stream.
+func relayUDP(clientConn net.Conn, target string) error {
+ remoteAddr, err := net.ResolveUDPAddr("udp", target)
+ if err != nil {
+ return err
+ }
+
+ udpConn, err := net.DialUDP("udp", nil, remoteAddr)
+ if err != nil {
+ return err
+ }
+ defer udpConn.Close()
+
+ if err := WriteResponse(clientConn); err != nil {
+ return err
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ // client TCP → remote UDP: read [2B len][payload] frames, send as UDP datagrams.
+ go func() {
+ defer wg.Done()
+ defer udpConn.Close()
+ var lenBuf [2]byte
+ for {
+ if _, err := io.ReadFull(clientConn, lenBuf[:]); err != nil {
+ return
+ }
+ payloadLen := binary.BigEndian.Uint16(lenBuf[:])
+ if payloadLen == 0 {
+ continue
+ }
+ buf := make([]byte, payloadLen)
+ if _, err := io.ReadFull(clientConn, buf); err != nil {
+ return
+ }
+ udpConn.SetWriteDeadline(time.Now().Add(10 * time.Second))
+ if _, err := udpConn.Write(buf); err != nil {
+ return
+ }
+ }
+ }()
+
+ // remote UDP → client TCP: read UDP datagrams, write as [2B len][payload] frames.
+ go func() {
+ defer wg.Done()
+ buf := make([]byte, 64*1024)
+ var lenBuf [2]byte
+ for {
+ udpConn.SetReadDeadline(time.Now().Add(udpIdleTimeout))
+ n, err := udpConn.Read(buf)
+ if err != nil {
+ return
+ }
+ binary.BigEndian.PutUint16(lenBuf[:], uint16(n))
+ if _, err := clientConn.Write(lenBuf[:]); err != nil {
+ return
+ }
+ if _, err := clientConn.Write(buf[:n]); err != nil {
+ return
+ }
+ }
+ }()
+
+ wg.Wait()
+ return nil
+}
A internal/vless/server.go => internal/vless/server.go +169 -0
@@ 0,0 1,169 @@
+package vless
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/xtls/reality"
+
+ "sourcecraft.dev/bigbes/shroud/internal/store"
+)
+
+// Config holds the VLESS+REALITY server configuration.
+type Config struct {
+ ListenAddr string
+ PrivateKey []byte // 32-byte raw x25519 private key
+ PublicKey []byte // 32-byte raw x25519 public key (for share links)
+ ShortID [8]byte
+ ServerNames []string
+ Dest string
+ Show bool
+}
+
+// Server is a VLESS proxy server with REALITY obfuscation.
+type Server struct {
+ cfg Config
+ listener net.Listener
+ logger *slog.Logger
+
+ mu sync.RWMutex
+ uuids map[[16]byte]string // UUID → access key ID
+ cancel context.CancelFunc
+}
+
+// New creates a new VLESS+REALITY server.
+func New(cfg Config, logger *slog.Logger) *Server {
+ return &Server{
+ cfg: cfg,
+ logger: logger.With("component", "vless"),
+ uuids: make(map[[16]byte]string),
+ }
+}
+
+// Start initializes the REALITY listener and begins accepting connections.
+func (s *Server) Start() error {
+ shortIds := map[[8]byte]bool{s.cfg.ShortID: true}
+
+ serverNames := make(map[string]bool, len(s.cfg.ServerNames))
+ for _, name := range s.cfg.ServerNames {
+ serverNames[name] = true
+ }
+
+ realityCfg := &reality.Config{
+ PrivateKey: s.cfg.PrivateKey,
+ ShortIds: shortIds,
+ ServerNames: serverNames,
+ Dest: s.cfg.Dest,
+ Show: s.cfg.Show,
+ }
+
+ ln, err := reality.Listen("tcp", s.cfg.ListenAddr, realityCfg)
+ if err != nil {
+ return fmt.Errorf("reality listen: %w", err)
+ }
+ s.listener = ln
+
+ ctx, cancel := context.WithCancel(context.Background())
+ s.cancel = cancel
+
+ go s.acceptLoop(ctx)
+
+ return nil
+}
+
+// SyncKeys rebuilds the UUID authentication map from the given access keys.
+func (s *Server) SyncKeys(keys []store.AccessKey) error {
+ newUUIDs := make(map[[16]byte]string)
+ for _, k := range keys {
+ if k.VLESS == nil {
+ continue
+ }
+ parsed, err := uuid.Parse(k.VLESS.UUID)
+ if err != nil {
+ s.logger.Warn("Skipping VLESS key with invalid UUID.", "keyID", k.ID, "err", err)
+ continue
+ }
+ newUUIDs[parsed] = k.ID
+ }
+
+ s.mu.Lock()
+ s.uuids = newUUIDs
+ s.mu.Unlock()
+
+ s.logger.Info("VLESS keys synced.", "count", len(newUUIDs))
+ return nil
+}
+
+// Stop shuts down the server.
+func (s *Server) Stop() {
+ if s.cancel != nil {
+ s.cancel()
+ }
+ if s.listener != nil {
+ s.listener.Close()
+ }
+ s.logger.Info("VLESS server stopped.")
+}
+
+func (s *Server) acceptLoop(ctx context.Context) {
+ for {
+ conn, err := s.listener.Accept()
+ if err != nil {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ s.logger.Error("VLESS accept error.", "err", err)
+ continue
+ }
+ }
+ go s.handleConn(conn)
+ }
+}
+
+func (s *Server) handleConn(conn net.Conn) {
+ defer conn.Close()
+
+ // Deadline for header parsing.
+ conn.SetReadDeadline(time.Now().Add(10 * time.Second))
+
+ req, err := ParseRequest(conn)
+ if err != nil {
+ s.logger.Debug("Failed to parse VLESS request.", "err", err)
+ return
+ }
+
+ // Clear read deadline for relay.
+ conn.SetReadDeadline(time.Time{})
+
+ // Authenticate UUID.
+ s.mu.RLock()
+ keyID, ok := s.uuids[req.UUID]
+ s.mu.RUnlock()
+
+ if !ok {
+ s.logger.Debug("Unknown VLESS UUID, closing connection.")
+ return
+ }
+
+ target := req.Target()
+
+ switch req.Command {
+ case CmdTCP:
+ s.logger.Debug("VLESS TCP relay.", "keyID", keyID, "target", target)
+ if err := relayTCP(conn, target); err != nil {
+ s.logger.Debug("VLESS TCP relay ended.", "keyID", keyID, "err", err)
+ }
+
+ case CmdUDP:
+ s.logger.Debug("VLESS UDP relay.", "keyID", keyID, "target", target)
+ if err := relayUDP(conn, target); err != nil {
+ s.logger.Debug("VLESS UDP relay ended.", "keyID", keyID, "err", err)
+ }
+ }
+}