From 3804bcf465a4eb00d1f03720316d0a176b0b97fb Mon Sep 17 00:00:00 2001 From: Eugene Blikh Date: Thu, 19 Mar 2026 10:29:34 +0300 Subject: [PATCH] feat(vless): add VLESS+REALITY transport support Add VLESS over REALITY as a third VPN transport alongside Shadowsocks and AmneziaWG. - Implement VLESS protocol parser, TCP relay, and REALITY-based TLS server. - Add x25519 keypair and short ID generation with tests. - Add MMDB manager for automatic download and daily update of GeoIP databases. - Integrate VLESS server lifecycle into main serve command and API handlers. - Add CLI subcommands: vless keygen, vless info, vless share. - Extend store with VLESSKeyData and REALITY server keys persistence. - Update config with vless section and MMDB URL/cache/auto-update settings. false --- cmd/shroud/main.go | 244 ++++++++++++++++++++++++++++++-- config.example.yaml | 16 ++- go.mod | 2 + go.sum | 4 + internal/api/handlers.go | 14 ++ internal/api/models.go | 7 +- internal/api/router.go | 38 +++-- internal/config/config.go | 69 ++++++++- internal/mmdb/mmdb.go | 197 ++++++++++++++++++++++++++ internal/store/store.go | 10 +- internal/vless/keygen.go | 37 +++++ internal/vless/keygen_test.go | 68 +++++++++ internal/vless/protocol.go | 124 ++++++++++++++++ internal/vless/protocol_test.go | 173 ++++++++++++++++++++++ internal/vless/relay.go | 114 +++++++++++++++ internal/vless/server.go | 169 ++++++++++++++++++++++ 16 files changed, 1257 insertions(+), 29 deletions(-) create mode 100644 internal/mmdb/mmdb.go create mode 100644 internal/vless/keygen.go create mode 100644 internal/vless/keygen_test.go create mode 100644 internal/vless/protocol.go create mode 100644 internal/vless/protocol_test.go create mode 100644 internal/vless/relay.go create mode 100644 internal/vless/server.go diff --git a/cmd/shroud/main.go b/cmd/shroud/main.go index 5315f5d3668272081b297fa60a4527ca59415c5e..51a46e3be5f69224046c01fd1eff376009ff58e6 100644 --- a/cmd/shroud/main.go +++ b/cmd/shroud/main.go @@ -25,12 +25,16 @@ import ( "golang.org/x/term" "gopkg.in/yaml.v3" + "encoding/base64" + "sourcecraft.dev/bigbes/shroud/internal/api" "sourcecraft.dev/bigbes/shroud/internal/awgserver" "sourcecraft.dev/bigbes/shroud/internal/config" "sourcecraft.dev/bigbes/shroud/internal/metrics" + "sourcecraft.dev/bigbes/shroud/internal/mmdb" "sourcecraft.dev/bigbes/shroud/internal/ssserver" "sourcecraft.dev/bigbes/shroud/internal/store" + "sourcecraft.dev/bigbes/shroud/internal/vless" ) var version = "dev" @@ -49,7 +53,7 @@ func newRootCmd() *cobra.Command { root := &cobra.Command{ Use: "shroud", - Short: "Shadowsocks + AmneziaWG VPN server", + Short: "Shadowsocks + AmneziaWG + VLESS VPN server", Version: version, RunE: func(cmd *cobra.Command, args []string) error { return runServe(configFile, verbose) @@ -66,6 +70,7 @@ func newRootCmd() *cobra.Command { newCompletionCmd(), newKeyCmd(&configFile), newServerCmd(&configFile), + newVLESSCmd(&configFile), ) return root @@ -160,8 +165,12 @@ func newKeyListCmd(configFile *string) *cobra.Command { if k.AWG != nil { awgInfo = fmt.Sprintf(" awg=%s", k.AWG.AllowedIP) } - fmt.Printf("%-6s %-20s port=%-6d cipher=%s limit=%s%s\n", - k.ID, k.Name, k.Port, k.Method, limit, awgInfo) + vlessInfo := "" + if k.VLESS != nil { + vlessInfo = fmt.Sprintf(" vless=%s", k.VLESS.UUID) + } + fmt.Printf("%-6s %-20s port=%-6d cipher=%s limit=%s%s%s\n", + k.ID, k.Name, k.Port, k.Method, limit, awgInfo, vlessInfo) } return nil }, @@ -248,6 +257,13 @@ func newKeyAddCmd(configFile *string) *cobra.Command { } } + // Generate VLESS credentials if VLESS is enabled. + if cfg.VLESS.Enabled { + ak.VLESS = &store.VLESSKeyData{ + UUID: uuid.New().String(), + } + } + if err := st.CreateKey(ak); err != nil { return err } @@ -256,6 +272,9 @@ func newKeyAddCmd(configFile *string) *cobra.Command { if ak.AWG != nil { fmt.Printf(" awg: %s (pubkey=%s)\n", ak.AWG.AllowedIP, ak.AWG.PublicKey) } + if ak.VLESS != nil { + fmt.Printf(" vless: %s\n", ak.VLESS.UUID) + } return nil }, } @@ -490,15 +509,27 @@ func runServe(configFile string, verbose bool) error { logger.Info("Generated API secret.", "secret", apiSecret) } + // Set up MMDB files (download if needed, start auto-update). + mmdbMgr, err := mmdb.NewManager(cfg.Shadowsocks.MMDBConfig(), logger) + if err != nil { + return fmt.Errorf("setting up MMDB manager: %w", err) + } + mmdbPaths, err := mmdbMgr.Resolve() + if err != nil { + return fmt.Errorf("resolving MMDB files: %w", err) + } + mmdbMgr.StartAutoUpdate() + defer mmdbMgr.Stop() + // Set up IP info for geo-metrics. var ip2info ipinfo.IPInfoMap - if cfg.Shadowsocks.IPCountryDB != "" || cfg.Shadowsocks.IPASNDB != "" { - mmdb, err := ipinfo.NewMMDBIPInfoMap(cfg.Shadowsocks.IPCountryDB, cfg.Shadowsocks.IPASNDB) + if mmdbPaths.CountryPath != "" || mmdbPaths.ASNPath != "" { + mmdbMap, err := ipinfo.NewMMDBIPInfoMap(mmdbPaths.CountryPath, mmdbPaths.ASNPath) if err != nil { return fmt.Errorf("loading IP info databases: %w", err) } - defer mmdb.Close() - ip2info = mmdb + defer mmdbMap.Close() + ip2info = mmdbMap } // Set up metrics. @@ -649,6 +680,69 @@ func runServe(configFile string, verbose bool) error { logger.Info("AmneziaWG server started.", "port", cfg.AmneziaWG.ListenPort) } + // Set up VLESS+REALITY server (optional). + var vlessServer *vless.Server + if cfg.VLESS.Enabled { + srv := st.GetServer() + if srv.VLESSPrivateKey == "" { + priv, pub, err := vless.GenerateX25519Keypair() + if err != nil { + return fmt.Errorf("generating REALITY keypair: %w", err) + } + shortID, err := vless.GenerateShortID() + if err != nil { + return fmt.Errorf("generating REALITY short ID: %w", err) + } + if err := st.UpdateServer(func(s *store.ServerState) { + s.VLESSPrivateKey = priv + s.VLESSPublicKey = pub + s.VLESSShortID = shortID + }); err != nil { + return fmt.Errorf("persisting REALITY keys: %w", err) + } + logger.Info("Generated REALITY keypair.") + srv = st.GetServer() + } + + privKey, err := base64.StdEncoding.DecodeString(srv.VLESSPrivateKey) + if err != nil { + return fmt.Errorf("decoding REALITY private key: %w", err) + } + var shortID [8]byte + shortIDBytes, err := hex.DecodeString(srv.VLESSShortID) + if err != nil { + return fmt.Errorf("decoding REALITY short ID: %w", err) + } + copy(shortID[:], shortIDBytes) + + pubKey, err := base64.StdEncoding.DecodeString(srv.VLESSPublicKey) + if err != nil { + return fmt.Errorf("decoding REALITY public key: %w", err) + } + + vlessCfg := vless.Config{ + ListenAddr: cfg.VLESS.ListenAddr, + PrivateKey: privKey, + PublicKey: pubKey, + ShortID: shortID, + ServerNames: cfg.VLESS.ServerNames, + Dest: cfg.VLESS.Dest, + Show: cfg.VLESS.Show, + } + + vlessServer = vless.New(vlessCfg, logger) + if err := vlessServer.Start(); err != nil { + return fmt.Errorf("starting VLESS server: %w", err) + } + + if keys := st.ListKeys(); len(keys) > 0 { + if err := vlessServer.SyncKeys(keys); err != nil { + return fmt.Errorf("syncing VLESS keys: %w", err) + } + } + logger.Info("VLESS+REALITY server started.", "addr", cfg.VLESS.ListenAddr) + } + // Start Prometheus metrics + /healthz endpoint. metricsHandler := promhttp.HandlerFor(registry, promhttp.HandlerOpts{}) metricsMux := http.NewServeMux() @@ -675,7 +769,7 @@ func runServe(configFile string, verbose bool) error { if cfg.AmneziaWG.Enabled { awgCfgPtr = &cfg.AmneziaWG } - handler := api.NewHandler(st, ss, awg, awgCfgPtr, transferTracker, version, logger) + handler := api.NewHandler(st, ss, awg, awgCfgPtr, vlessServer, transferTracker, version, logger) router := api.NewRouter(apiSecret, handler) apiServer := &http.Server{Addr: cfg.API.ListenAddr, Handler: router} go func() { @@ -705,6 +799,9 @@ func runServe(configFile string, verbose bool) error { <-sigCh logger.Info("Shutting down...") + if vlessServer != nil { + vlessServer.Stop() + } if awg != nil { awg.Stop() } @@ -717,6 +814,137 @@ func runServe(configFile string, verbose bool) error { return nil } +// --- vless commands --- + +func newVLESSCmd(configFile *string) *cobra.Command { + cmd := &cobra.Command{ + Use: "vless", + Short: "VLESS+REALITY management commands", + } + cmd.AddCommand( + newVLESSKeygenCmd(), + newVLESSInfoCmd(configFile), + newVLESSShareCmd(configFile), + ) + return cmd +} + +func newVLESSKeygenCmd() *cobra.Command { + return &cobra.Command{ + Use: "keygen", + Short: "Generate a new REALITY x25519 keypair and short ID", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + priv, pub, err := vless.GenerateX25519Keypair() + if err != nil { + return err + } + shortID, err := vless.GenerateShortID() + if err != nil { + return err + } + fmt.Printf("Private Key: %s\n", priv) + fmt.Printf("Public Key: %s\n", pub) + fmt.Printf("Short ID: %s\n", shortID) + return nil + }, + } +} + +func newVLESSInfoCmd(configFile *string) *cobra.Command { + return &cobra.Command{ + Use: "info", + Short: "Show VLESS+REALITY server info", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load(*configFile) + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + if !cfg.VLESS.Enabled { + fmt.Println("VLESS is not enabled.") + return nil + } + st, err := openStore(*configFile) + if err != nil { + return err + } + srv := st.GetServer() + fmt.Printf("Listen Addr: %s\n", cfg.VLESS.ListenAddr) + fmt.Printf("Server Names: %v\n", cfg.VLESS.ServerNames) + fmt.Printf("Dest: %s\n", cfg.VLESS.Dest) + if srv.VLESSPublicKey != "" { + fmt.Printf("Public Key: %s\n", srv.VLESSPublicKey) + fmt.Printf("Short ID: %s\n", srv.VLESSShortID) + } else { + fmt.Println("REALITY keypair not yet generated (start the server first).") + } + return nil + }, + } +} + +func newVLESSShareCmd(configFile *string) *cobra.Command { + return &cobra.Command{ + Use: "share ", + Short: "Generate a VLESS share link for a key", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load(*configFile) + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + if !cfg.VLESS.Enabled { + return fmt.Errorf("VLESS is not enabled") + } + st, err := openStore(*configFile) + if err != nil { + return err + } + srv := st.GetServer() + if srv.VLESSPublicKey == "" { + return fmt.Errorf("REALITY keypair not yet generated (start the server first)") + } + + key, ok := st.GetKey(args[0]) + if !ok { + return fmt.Errorf("key %q not found", args[0]) + } + if key.VLESS == nil { + return fmt.Errorf("key %q has no VLESS credentials", args[0]) + } + + hostname := srv.Hostname + if hostname == "" { + hostname = "localhost" + } + + // Extract port from listen address. + _, portStr, err := net.SplitHostPort(cfg.VLESS.ListenAddr) + if err != nil { + portStr = "443" + } + + sni := "" + if len(cfg.VLESS.ServerNames) > 0 { + sni = cfg.VLESS.ServerNames[0] + } + + // Base64url encode the public key (no padding). + pubKeyB64URL := base64.RawURLEncoding.EncodeToString([]byte(srv.VLESSPublicKey)) + + name := key.Name + if name == "" { + name = key.ID + } + + fmt.Printf("vless://%s@%s:%s?encryption=none&security=reality&sni=%s&fp=chrome&pbk=%s&sid=%s&type=tcp#%s\n", + key.VLESS.UUID, hostname, portStr, sni, pubKeyB64URL, srv.VLESSShortID, name) + return nil + }, + } +} + func pickRandomPort() (int, error) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/config.example.yaml b/config.example.yaml index 65ef279195375135427a52e505d4d85aa66a67f7..2fca16f6f8d240e208f1ffef5c266e08bc365982 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -25,8 +25,12 @@ shadowsocks: default_cipher: chacha20-ietf-poly1305 nat_timeout: 5m replay_history: 10000 - ip_country_db: "" # optional: MaxMind GeoLite2-Country.mmdb for per-country metrics - ip_asn_db: "" # optional: MaxMind GeoLite2-ASN.mmdb for per-ASN metrics + # GeoIP databases: URL or local path. Defaults to GeoLite2 from github.com/P3TERX/GeoLite.mmdb. + # Set to "none" to disable. URLs are downloaded and cached automatically. + ip_country_db: "" # default: https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb + ip_asn_db: "" # default: https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-ASN.mmdb + ip_db_cache_dir: "" # default: /var/lib/shroud/mmdb + ip_db_auto_update: true # auto-update MMDB files daily from URLs amneziawg: enabled: true @@ -52,6 +56,14 @@ amneziawg: h3: "250000-300000" h4: "350000-400000" +vless: + enabled: false + listen_addr: ":443" # REALITY listener address + server_names: # SNIs accepted by REALITY handshake + - www.microsoft.com + dest: "www.microsoft.com:443" # decoy forward target for unauthenticated probes + show: false # debug REALITY handshakes + # ACME (Let's Encrypt) certificate settings. # Used by AmneziaWG HTTP/3 cover server for DPI resistance. acme: diff --git a/go.mod b/go.mod index 3bc08fbd93edc2e4196148a3f63aff40b537ec9a..1f868b1709504e20a19bb6686744fd679de20193 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/opencontainers/selinux v1.12.0 // indirect github.com/oschwald/geoip2-golang v1.11.0 // indirect github.com/oschwald/maxminddb-golang v1.13.1 // indirect + github.com/pires/go-proxyproto v0.7.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/prometheus-community/go-runit v0.1.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect @@ -57,6 +58,7 @@ require ( github.com/spf13/pflag v1.0.10 // indirect github.com/tevino/abool v1.2.0 // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect + github.com/xtls/reality v0.0.0-20240712055506-48f0b2d5ed6d // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect diff --git a/go.sum b/go.sum index 0bcf495508fe3911c55bc113de9525dee82498b6..03ea1e74726fbf5998893760c369e2419dfb8b79 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/oschwald/geoip2-golang v1.11.0 h1:hNENhCn1Uyzhf9PTmquXENiWS6AlxAEnBII github.com/oschwald/geoip2-golang v1.11.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= @@ -129,6 +131,8 @@ github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg= github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc= github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU= +github.com/xtls/reality v0.0.0-20240712055506-48f0b2d5ed6d h1:+B97uD9uHLgAAulhigmys4BVwZZypzK7gPN3WtpgRJg= +github.com/xtls/reality v0.0.0-20240712055506-48f0b2d5ed6d/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 5c6a301b36fbe71518f31e75638983b2eb4c9c3f..3bffdd4cb67297f69053b27dd9221aa3fd769e7e 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -6,6 +6,8 @@ import ( "strings" "text/template" + "github.com/google/uuid" + "sourcecraft.dev/bigbes/shroud/internal/awgserver" "sourcecraft.dev/bigbes/shroud/internal/config" "sourcecraft.dev/bigbes/shroud/internal/store" @@ -142,6 +144,13 @@ func (h *Handler) doCreateKey(w http.ResponseWriter, req CreateAccessKeyRequest) ak.AWG = awgData } + // Generate VLESS credentials if VLESS is enabled. + if h.vlessServer != nil { + ak.VLESS = &store.VLESSKeyData{ + UUID: uuid.New().String(), + } + } + if err := h.store.CreateKey(ak); err != nil { writeError(w, http.StatusConflict, err.Error()) return @@ -307,6 +316,11 @@ func (h *Handler) keyToResponse(k store.AccessKey) AccessKeyResponse { if k.DataLimit != nil { resp.DataLimit = &DataLimitJSON{Bytes: k.DataLimit.Bytes} } + if k.VLESS != nil && h.vlessServer != nil { + resp.VLESS = &VLESSKeyResponse{ + UUID: k.VLESS.UUID, + } + } if k.AWG != nil && h.awgConfig != nil { srv := h.store.GetServer() hostname := h.awgHostname(srv.Hostname) diff --git a/internal/api/models.go b/internal/api/models.go index eed1dd438b54e4de5715b0dbd7f8b9e82ddc6663..fe847ba4ad23b6ead873cde6599e992dad487cb7 100644 --- a/internal/api/models.go +++ b/internal/api/models.go @@ -8,7 +8,12 @@ type AccessKeyResponse struct { Method string `json:"method"` AccessURL string `json:"accessUrl"` DataLimit *DataLimitJSON `json:"dataLimit,omitempty"` - AWG *AWGKeyResponse `json:"awg,omitempty"` + AWG *AWGKeyResponse `json:"awg,omitempty"` + VLESS *VLESSKeyResponse `json:"vless,omitempty"` +} + +type VLESSKeyResponse struct { + UUID string `json:"uuid"` } type AWGKeyResponse struct { diff --git a/internal/api/router.go b/internal/api/router.go index 8e8a361769bb8b17d2a04dc9e0f7ed75be08a761..bb2b0ac99d4de49f4dd09a39a35e6c10743463c2 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -12,27 +12,30 @@ import ( "sourcecraft.dev/bigbes/shroud/internal/metrics" "sourcecraft.dev/bigbes/shroud/internal/ssserver" "sourcecraft.dev/bigbes/shroud/internal/store" + "sourcecraft.dev/bigbes/shroud/internal/vless" ) type Handler struct { - store store.Store - ssServer *ssserver.Server - awgServer *awgserver.Server - awgConfig *config.AmneziaWGConfig - tracker *metrics.TransferTracker - version string - logger *slog.Logger + store store.Store + ssServer *ssserver.Server + awgServer *awgserver.Server + awgConfig *config.AmneziaWGConfig + vlessServer *vless.Server + tracker *metrics.TransferTracker + version string + logger *slog.Logger } -func NewHandler(s store.Store, ss *ssserver.Server, awg *awgserver.Server, awgCfg *config.AmneziaWGConfig, tracker *metrics.TransferTracker, version string, logger *slog.Logger) *Handler { +func NewHandler(s store.Store, ss *ssserver.Server, awg *awgserver.Server, awgCfg *config.AmneziaWGConfig, vl *vless.Server, tracker *metrics.TransferTracker, version string, logger *slog.Logger) *Handler { return &Handler{ - store: s, - ssServer: ss, - awgServer: awg, - awgConfig: awgCfg, - tracker: tracker, - version: version, - logger: logger, + store: s, + ssServer: ss, + awgServer: awg, + awgConfig: awgCfg, + vlessServer: vl, + tracker: tracker, + version: version, + logger: logger, } } @@ -101,6 +104,11 @@ func (h *Handler) syncKeys() error { return err } } + if h.vlessServer != nil { + if err := h.vlessServer.SyncKeys(keys); err != nil { + return err + } + } return nil } diff --git a/internal/config/config.go b/internal/config/config.go index 59686aa125a03681b7562ccca82d34e0b028ab9d..bd473818f545c0d3bd74537d755a5c36e9698cef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,8 @@ import ( "os" "path/filepath" + "sourcecraft.dev/bigbes/shroud/internal/mmdb" + "gopkg.in/yaml.v3" ) @@ -19,6 +21,8 @@ type Config struct { Shadowsocks ShadowsocksConfig `yaml:"shadowsocks"` // AmneziaWG VPN settings. AmneziaWG AmneziaWGConfig `yaml:"amneziawg"` + // VLESS over REALITY transport settings. + VLESS VLESSConfig `yaml:"vless"` // ACME (Let's Encrypt) certificate settings. ACME ACMEConfig `yaml:"acme"` // Path to persistent state file. @@ -100,6 +104,19 @@ type AmneziaWGConfig struct { H4 string `yaml:"h4"` } +type VLESSConfig struct { + // Enable VLESS over REALITY transport. + Enabled bool `yaml:"enabled"` + // Listen address for the REALITY listener (e.g., ":443"). + ListenAddr string `yaml:"listen_addr"` + // Accepted SNIs for REALITY handshake (e.g., ["www.microsoft.com"]). + ServerNames []string `yaml:"server_names"` + // Decoy forward target for unauthenticated probes (e.g., "www.microsoft.com:443"). + Dest string `yaml:"dest"` + // Enable debug logging for REALITY handshakes. + Show bool `yaml:"show"` +} + type ACMEConfig struct { // Directory for cached TLS certificates. CertCache string `yaml:"cert_cache"` @@ -118,10 +135,19 @@ type ShadowsocksConfig struct { NATTimeout string `yaml:"nat_timeout"` // Replay protection history size (0 = disabled). ReplayHistory int `yaml:"replay_history"` - // Path to IP-to-country MaxMind MMDB file. + // URL or local path to IP-to-country MaxMind MMDB file. + // Defaults to GeoLite2-Country.mmdb from github.com/P3TERX/GeoLite.mmdb. + // Set to "none" to disable. IPCountryDB string `yaml:"ip_country_db"` - // Path to IP-to-ASN MaxMind MMDB file. + // URL or local path to IP-to-ASN MaxMind MMDB file. + // Defaults to GeoLite2-ASN.mmdb from github.com/P3TERX/GeoLite.mmdb. + // Set to "none" to disable. IPASNDB string `yaml:"ip_asn_db"` + // Directory to cache downloaded MMDB files. + // Defaults to "/var/lib/shroud/mmdb". + IPDBCacheDir string `yaml:"ip_db_cache_dir"` + // Auto-update MMDB files daily from URLs. Defaults to true. + IPDBAutoUpdate *bool `yaml:"ip_db_auto_update"` } func Load(filename string) (*Config, error) { @@ -161,6 +187,16 @@ func (c *Config) setDefaults() { if c.Shadowsocks.NATTimeout == "" { c.Shadowsocks.NATTimeout = "5m" } + if c.Shadowsocks.IPCountryDB == "" { + c.Shadowsocks.IPCountryDB = mmdb.DefaultCountryURL + } + if c.Shadowsocks.IPASNDB == "" { + c.Shadowsocks.IPASNDB = mmdb.DefaultASNURL + } + if c.Shadowsocks.IPDBAutoUpdate == nil { + t := true + c.Shadowsocks.IPDBAutoUpdate = &t + } if c.Server.Name == "" { c.Server.Name = "Outline Server" } @@ -179,6 +215,12 @@ func (c *Config) setDefaults() { c.AmneziaWG.MTU = 1420 } } + // VLESS defaults. + if c.VLESS.Enabled { + if c.VLESS.ListenAddr == "" { + c.VLESS.ListenAddr = ":443" + } + } // AmneziaWG domain defaults to server hostname. if c.AmneziaWG.Enabled && c.AmneziaWG.Domain == "" { c.AmneziaWG.Domain = c.Server.Hostname @@ -214,6 +256,29 @@ func (c *AmneziaWGConfig) AWGHostname(serverHostname string) string { return "" } +// MMDBConfig returns the MMDB manager configuration derived from shadowsocks settings. +// Sources set to "none" are treated as disabled (empty). +func (c *ShadowsocksConfig) MMDBConfig() mmdb.Config { + country := c.IPCountryDB + if country == "none" { + country = "" + } + asn := c.IPASNDB + if asn == "none" { + asn = "" + } + autoUpdate := true + if c.IPDBAutoUpdate != nil { + autoUpdate = *c.IPDBAutoUpdate + } + return mmdb.Config{ + CountryURL: country, + ASNURL: asn, + CacheDir: c.IPDBCacheDir, + AutoUpdate: autoUpdate, + } +} + // StateFile returns the absolute path to the state file. func (c *Config) StateFile() string { if filepath.IsAbs(c.State) { diff --git a/internal/mmdb/mmdb.go b/internal/mmdb/mmdb.go new file mode 100644 index 0000000000000000000000000000000000000000..58c70493d2fc81aa5d6766ff3216e5e06265274a --- /dev/null +++ b/internal/mmdb/mmdb.go @@ -0,0 +1,197 @@ +// Package mmdb handles downloading, caching, and auto-updating MaxMind MMDB files. +package mmdb + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "sync" + "time" +) + +const ( + DefaultCountryURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb" + DefaultASNURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-ASN.mmdb" + DefaultCacheDir = "/var/lib/shroud/mmdb" + + countryFile = "GeoLite2-Country.mmdb" + asnFile = "GeoLite2-ASN.mmdb" + + updateInterval = 24 * time.Hour +) + +// Config holds MMDB source configuration. +type Config struct { + CountryURL string // URL or local path for country DB. Empty disables country lookups. + ASNURL string // URL or local path for ASN DB. Empty disables ASN lookups. + CacheDir string // Directory to cache downloaded files. + AutoUpdate bool // Whether to auto-update from URLs daily. +} + +// Manager handles MMDB file lifecycle: download, cache, and periodic updates. +type Manager struct { + cfg Config + logger *slog.Logger + + cancel context.CancelFunc + wg sync.WaitGroup +} + +// Result holds resolved local paths to MMDB files. +type Result struct { + CountryPath string + ASNPath string +} + +// NewManager creates a new MMDB manager and ensures files are available. +// It downloads missing files from URLs if needed. +func NewManager(cfg Config, logger *slog.Logger) (*Manager, error) { + if cfg.CacheDir == "" { + cfg.CacheDir = DefaultCacheDir + } + if err := os.MkdirAll(cfg.CacheDir, 0o755); err != nil { + return nil, fmt.Errorf("creating MMDB cache dir %s: %w", cfg.CacheDir, err) + } + m := &Manager{cfg: cfg, logger: logger} + return m, nil +} + +// Resolve ensures MMDB files are available locally and returns their paths. +// For URLs, it downloads the file if not cached. For local paths, it returns them as-is. +func (m *Manager) Resolve() (Result, error) { + var res Result + var err error + + if m.cfg.CountryURL != "" { + res.CountryPath, err = m.resolve(m.cfg.CountryURL, countryFile) + if err != nil { + return res, fmt.Errorf("resolving country DB: %w", err) + } + } + if m.cfg.ASNURL != "" { + res.ASNPath, err = m.resolve(m.cfg.ASNURL, asnFile) + if err != nil { + return res, fmt.Errorf("resolving ASN DB: %w", err) + } + } + return res, nil +} + +// StartAutoUpdate launches a background goroutine that re-downloads MMDB files daily. +// Call Stop() to cancel. +func (m *Manager) StartAutoUpdate() { + if !m.cfg.AutoUpdate || !m.hasURLSources() { + return + } + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + m.wg.Add(1) + go func() { + defer m.wg.Done() + m.logger.Info("MMDB auto-update enabled.", "interval", updateInterval) + ticker := time.NewTicker(updateInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + m.update() + } + } + }() +} + +// Stop cancels auto-update and waits for the goroutine to exit. +func (m *Manager) Stop() { + if m.cancel != nil { + m.cancel() + m.wg.Wait() + } +} + +// resolve returns a local path for the given source. +// If source is a URL (starts with http:// or https://), it downloads to cacheDir/filename. +// Otherwise it treats source as a local file path. +func (m *Manager) resolve(source, filename string) (string, error) { + if !isURL(source) { + // Local file path — just verify it exists. + if _, err := os.Stat(source); err != nil { + return "", fmt.Errorf("local file %s: %w", source, err) + } + return source, nil + } + dest := filepath.Join(m.cfg.CacheDir, filename) + if _, err := os.Stat(dest); err == nil { + return dest, nil // Already cached. + } + if err := download(source, dest); err != nil { + return "", err + } + m.logger.Info("Downloaded MMDB file.", "url", source, "path", dest) + return dest, nil +} + +func (m *Manager) update() { + if m.cfg.CountryURL != "" && isURL(m.cfg.CountryURL) { + dest := filepath.Join(m.cfg.CacheDir, countryFile) + if err := download(m.cfg.CountryURL, dest); err != nil { + m.logger.Error("Failed to update country DB.", "err", err) + } else { + m.logger.Info("Updated country DB.", "url", m.cfg.CountryURL) + } + } + if m.cfg.ASNURL != "" && isURL(m.cfg.ASNURL) { + dest := filepath.Join(m.cfg.CacheDir, asnFile) + if err := download(m.cfg.ASNURL, dest); err != nil { + m.logger.Error("Failed to update ASN DB.", "err", err) + } else { + m.logger.Info("Updated ASN DB.", "url", m.cfg.ASNURL) + } + } +} + +func (m *Manager) hasURLSources() bool { + return (m.cfg.CountryURL != "" && isURL(m.cfg.CountryURL)) || + (m.cfg.ASNURL != "" && isURL(m.cfg.ASNURL)) +} + +func isURL(s string) bool { + return len(s) > 8 && (s[:7] == "http://" || s[:8] == "https://") +} + +func download(url, dest string) error { + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("downloading %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("downloading %s: HTTP %d", url, resp.StatusCode) + } + + // Write to temp file, then rename for atomicity. + tmp := dest + ".tmp" + f, err := os.Create(tmp) + if err != nil { + return fmt.Errorf("creating temp file %s: %w", tmp, err) + } + if _, err := io.Copy(f, resp.Body); err != nil { + f.Close() + os.Remove(tmp) + return fmt.Errorf("writing %s: %w", tmp, err) + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return err + } + if err := os.Rename(tmp, dest); err != nil { + os.Remove(tmp) + return fmt.Errorf("renaming %s to %s: %w", tmp, dest, err) + } + return nil +} diff --git a/internal/store/store.go b/internal/store/store.go index d626a7d1062d1425a505bae08e2e5217760a57ba..e3d74a7be40d0c1fcc153756f479616cb7392582 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -13,7 +13,8 @@ type AccessKey struct { Port int `yaml:"port"` Method string `yaml:"method"` DataLimit *DataLimit `yaml:"data_limit,omitempty"` - AWG *AWGKeyData `yaml:"awg,omitempty"` + AWG *AWGKeyData `yaml:"awg,omitempty"` + VLESS *VLESSKeyData `yaml:"vless,omitempty"` } type AWGKeyData struct { @@ -22,6 +23,10 @@ type AWGKeyData struct { AllowedIP string `yaml:"allowed_ip"` } +type VLESSKeyData struct { + UUID string `yaml:"uuid"` +} + type DataLimit struct { Bytes int64 `yaml:"bytes"` } @@ -39,6 +44,9 @@ type ServerState struct { PortForNewAccessKeys int `yaml:"port_for_new_access_keys"` AWGPrivateKey string `yaml:"awg_private_key,omitempty"` AWGPublicKey string `yaml:"awg_public_key,omitempty"` + VLESSPrivateKey string `yaml:"vless_private_key,omitempty"` + VLESSPublicKey string `yaml:"vless_public_key,omitempty"` + VLESSShortID string `yaml:"vless_short_id,omitempty"` } type Store interface { diff --git a/internal/vless/keygen.go b/internal/vless/keygen.go new file mode 100644 index 0000000000000000000000000000000000000000..3111c8bef04aa1308c26f534a2ba237956592da4 --- /dev/null +++ b/internal/vless/keygen.go @@ -0,0 +1,37 @@ +package vless + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" + + "golang.org/x/crypto/curve25519" +) + +// GenerateX25519Keypair generates a raw x25519 keypair for REALITY. +// Unlike WireGuard, no clamping is applied to the private key. +func GenerateX25519Keypair() (privBase64, pubBase64 string, err error) { + priv := make([]byte, 32) + if _, err := rand.Read(priv); err != nil { + return "", "", fmt.Errorf("generating private key: %w", err) + } + + pub, err := curve25519.X25519(priv, curve25519.Basepoint) + if err != nil { + return "", "", fmt.Errorf("computing public key: %w", err) + } + + return base64.StdEncoding.EncodeToString(priv), + base64.StdEncoding.EncodeToString(pub), + nil +} + +// GenerateShortID generates a random 8-byte REALITY short ID as a hex string. +func GenerateShortID() (string, error) { + var id [8]byte + if _, err := rand.Read(id[:]); err != nil { + return "", fmt.Errorf("generating short ID: %w", err) + } + return hex.EncodeToString(id[:]), nil +} diff --git a/internal/vless/keygen_test.go b/internal/vless/keygen_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1254caafafbd6c14da0e1762048b4923787d6ac7 --- /dev/null +++ b/internal/vless/keygen_test.go @@ -0,0 +1,68 @@ +package vless + +import ( + "encoding/base64" + "encoding/hex" + "testing" +) + +func TestGenerateX25519Keypair(t *testing.T) { + priv, pub, err := GenerateX25519Keypair() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + privBytes, err := base64.StdEncoding.DecodeString(priv) + if err != nil { + t.Fatalf("invalid base64 private key: %v", err) + } + if len(privBytes) != 32 { + t.Errorf("private key length: got %d, want 32", len(privBytes)) + } + + pubBytes, err := base64.StdEncoding.DecodeString(pub) + if err != nil { + t.Fatalf("invalid base64 public key: %v", err) + } + if len(pubBytes) != 32 { + t.Errorf("public key length: got %d, want 32", len(pubBytes)) + } + + // Keys should be different. + if priv == pub { + t.Error("private and public keys should differ") + } +} + +func TestGenerateX25519Keypair_Unique(t *testing.T) { + priv1, _, _ := GenerateX25519Keypair() + priv2, _, _ := GenerateX25519Keypair() + if priv1 == priv2 { + t.Error("two generated keypairs should not be identical") + } +} + +func TestGenerateShortID(t *testing.T) { + id, err := GenerateShortID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(id) != 16 { + t.Errorf("short ID hex length: got %d, want 16", len(id)) + } + b, err := hex.DecodeString(id) + if err != nil { + t.Fatalf("invalid hex short ID: %v", err) + } + if len(b) != 8 { + t.Errorf("short ID byte length: got %d, want 8", len(b)) + } +} + +func TestGenerateShortID_Unique(t *testing.T) { + id1, _ := GenerateShortID() + id2, _ := GenerateShortID() + if id1 == id2 { + t.Error("two generated short IDs should not be identical") + } +} diff --git a/internal/vless/protocol.go b/internal/vless/protocol.go new file mode 100644 index 0000000000000000000000000000000000000000..1e5ffa3b6d12c9b0a63a400edc2f5428ba0e9349 --- /dev/null +++ b/internal/vless/protocol.go @@ -0,0 +1,124 @@ +package vless + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +const ( + Version = 0 + + CmdTCP byte = 0x01 + CmdUDP byte = 0x02 + + AddrIPv4 byte = 0x01 + AddrDomain byte = 0x02 + AddrIPv6 byte = 0x03 +) + +// Request represents a parsed VLESS request header. +type Request struct { + UUID [16]byte + Command byte + Port uint16 + Address string + Addons []byte +} + +// Target returns the address:port string for dialing. +func (r *Request) Target() string { + return net.JoinHostPort(r.Address, fmt.Sprintf("%d", r.Port)) +} + +// ParseRequest reads a VLESS request header from r. +func ParseRequest(r io.Reader) (*Request, error) { + // Version (1 byte). + var ver [1]byte + if _, err := io.ReadFull(r, ver[:]); err != nil { + return nil, fmt.Errorf("reading version: %w", err) + } + if ver[0] != Version { + return nil, fmt.Errorf("unsupported VLESS version: %d", ver[0]) + } + + // UUID (16 bytes). + var req Request + if _, err := io.ReadFull(r, req.UUID[:]); err != nil { + return nil, fmt.Errorf("reading UUID: %w", err) + } + + // Addons length (1 byte) + addons. + var addonsLen [1]byte + if _, err := io.ReadFull(r, addonsLen[:]); err != nil { + return nil, fmt.Errorf("reading addons length: %w", err) + } + if addonsLen[0] > 0 { + req.Addons = make([]byte, addonsLen[0]) + if _, err := io.ReadFull(r, req.Addons); err != nil { + return nil, fmt.Errorf("reading addons: %w", err) + } + } + + // Command (1 byte). + var cmd [1]byte + if _, err := io.ReadFull(r, cmd[:]); err != nil { + return nil, fmt.Errorf("reading command: %w", err) + } + req.Command = cmd[0] + if req.Command != CmdTCP && req.Command != CmdUDP { + return nil, fmt.Errorf("unsupported command: 0x%02x", req.Command) + } + + // Port (2 bytes, big-endian). + var portBuf [2]byte + if _, err := io.ReadFull(r, portBuf[:]); err != nil { + return nil, fmt.Errorf("reading port: %w", err) + } + req.Port = binary.BigEndian.Uint16(portBuf[:]) + + // Address type (1 byte) + address. + var addrType [1]byte + if _, err := io.ReadFull(r, addrType[:]); err != nil { + return nil, fmt.Errorf("reading address type: %w", err) + } + + switch addrType[0] { + case AddrIPv4: + var ip [4]byte + if _, err := io.ReadFull(r, ip[:]); err != nil { + return nil, fmt.Errorf("reading IPv4 address: %w", err) + } + req.Address = net.IP(ip[:]).String() + + case AddrDomain: + var domainLen [1]byte + if _, err := io.ReadFull(r, domainLen[:]); err != nil { + return nil, fmt.Errorf("reading domain length: %w", err) + } + domain := make([]byte, domainLen[0]) + if _, err := io.ReadFull(r, domain); err != nil { + return nil, fmt.Errorf("reading domain: %w", err) + } + req.Address = string(domain) + + case AddrIPv6: + var ip [16]byte + if _, err := io.ReadFull(r, ip[:]); err != nil { + return nil, fmt.Errorf("reading IPv6 address: %w", err) + } + req.Address = net.IP(ip[:]).String() + + default: + return nil, fmt.Errorf("unsupported address type: 0x%02x", addrType[0]) + } + + return &req, nil +} + +// WriteResponse writes the VLESS response header (version 0, no addons). +func WriteResponse(w io.Writer) error { + _, err := w.Write([]byte{Version, 0x00}) + return err +} diff --git a/internal/vless/protocol_test.go b/internal/vless/protocol_test.go new file mode 100644 index 0000000000000000000000000000000000000000..483ad290930824e47e5c2d8babc43cbefd91c883 --- /dev/null +++ b/internal/vless/protocol_test.go @@ -0,0 +1,173 @@ +package vless + +import ( + "bytes" + "encoding/binary" + "net" + "testing" +) + +func buildRequest(t *testing.T, uuid [16]byte, addons []byte, cmd byte, port uint16, addrType byte, addr []byte) []byte { + t.Helper() + var buf bytes.Buffer + + buf.WriteByte(Version) // version + buf.Write(uuid[:]) // UUID + + if addons == nil { + buf.WriteByte(0) // addons length + } else { + buf.WriteByte(byte(len(addons))) + buf.Write(addons) + } + + buf.WriteByte(cmd) // command + + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], port) + buf.Write(portBuf[:]) // port + + buf.WriteByte(addrType) // address type + buf.Write(addr) // address + + return buf.Bytes() +} + +func TestParseRequest_TCPIPv4(t *testing.T) { + uuid := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + ip := net.ParseIP("192.168.1.1").To4() + data := buildRequest(t, uuid, nil, CmdTCP, 8080, AddrIPv4, ip) + + req, err := ParseRequest(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.UUID != uuid { + t.Errorf("UUID mismatch: got %v, want %v", req.UUID, uuid) + } + if req.Command != CmdTCP { + t.Errorf("Command mismatch: got %d, want %d", req.Command, CmdTCP) + } + if req.Port != 8080 { + t.Errorf("Port mismatch: got %d, want 8080", req.Port) + } + if req.Address != "192.168.1.1" { + t.Errorf("Address mismatch: got %q, want %q", req.Address, "192.168.1.1") + } + if req.Target() != "192.168.1.1:8080" { + t.Errorf("Target mismatch: got %q", req.Target()) + } +} + +func TestParseRequest_UDPDomain(t *testing.T) { + uuid := [16]byte{0xaa, 0xbb, 0xcc, 0xdd} + domain := "example.com" + addrBuf := append([]byte{byte(len(domain))}, []byte(domain)...) + data := buildRequest(t, uuid, nil, CmdUDP, 53, AddrDomain, addrBuf) + + req, err := ParseRequest(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Command != CmdUDP { + t.Errorf("Command mismatch: got %d, want %d", req.Command, CmdUDP) + } + if req.Port != 53 { + t.Errorf("Port mismatch: got %d, want 53", req.Port) + } + if req.Address != "example.com" { + t.Errorf("Address mismatch: got %q, want %q", req.Address, "example.com") + } +} + +func TestParseRequest_TCPIPv6(t *testing.T) { + uuid := [16]byte{0xff} + ip := net.ParseIP("2001:db8::1").To16() + data := buildRequest(t, uuid, nil, CmdTCP, 443, AddrIPv6, ip) + + req, err := ParseRequest(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Address != "2001:db8::1" { + t.Errorf("Address mismatch: got %q, want %q", req.Address, "2001:db8::1") + } +} + +func TestParseRequest_WithAddons(t *testing.T) { + uuid := [16]byte{0x01} + addons := []byte{0x0a, 0x0b, 0x0c} + ip := net.ParseIP("10.0.0.1").To4() + data := buildRequest(t, uuid, addons, CmdTCP, 80, AddrIPv4, ip) + + req, err := ParseRequest(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(req.Addons) != 3 { + t.Errorf("Addons length mismatch: got %d, want 3", len(req.Addons)) + } + if req.Address != "10.0.0.1" { + t.Errorf("Address mismatch: got %q", req.Address) + } +} + +func TestParseRequest_BadVersion(t *testing.T) { + data := []byte{0x01} // version 1, invalid + _, err := ParseRequest(bytes.NewReader(data)) + if err == nil { + t.Fatal("expected error for bad version") + } +} + +func TestParseRequest_Truncated(t *testing.T) { + // Only version + partial UUID. + data := []byte{0x00, 0x01, 0x02} + _, err := ParseRequest(bytes.NewReader(data)) + if err == nil { + t.Fatal("expected error for truncated input") + } +} + +func TestParseRequest_BadCommand(t *testing.T) { + uuid := [16]byte{} + ip := net.ParseIP("1.2.3.4").To4() + data := buildRequest(t, uuid, nil, 0x03, 80, AddrIPv4, ip) // cmd=Mux, unsupported + + _, err := ParseRequest(bytes.NewReader(data)) + if err == nil { + t.Fatal("expected error for unsupported command") + } +} + +func TestParseRequest_BadAddrType(t *testing.T) { + uuid := [16]byte{} + // Build manually with bad address type. + var buf bytes.Buffer + buf.WriteByte(Version) + buf.Write(uuid[:]) + buf.WriteByte(0) // no addons + buf.WriteByte(0x01) // TCP + var portBuf [2]byte + binary.BigEndian.PutUint16(portBuf[:], 80) + buf.Write(portBuf[:]) + buf.WriteByte(0xFF) // bad address type + + _, err := ParseRequest(bytes.NewReader(buf.Bytes())) + if err == nil { + t.Fatal("expected error for bad address type") + } +} + +func TestWriteResponse(t *testing.T) { + var buf bytes.Buffer + if err := WriteResponse(&buf); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if buf.Len() != 2 { + t.Fatalf("response length mismatch: got %d, want 2", buf.Len()) + } + if buf.Bytes()[0] != 0x00 || buf.Bytes()[1] != 0x00 { + t.Errorf("response bytes mismatch: got %v, want [0x00, 0x00]", buf.Bytes()) + } +} diff --git a/internal/vless/relay.go b/internal/vless/relay.go new file mode 100644 index 0000000000000000000000000000000000000000..69d2ddc33f88868bc5731f555c75bc7c8e356352 --- /dev/null +++ b/internal/vless/relay.go @@ -0,0 +1,114 @@ +package vless + +import ( + "encoding/binary" + "io" + "net" + "sync" + "time" +) + +const udpIdleTimeout = 5 * time.Minute + +// relayTCP dials the target and relays data bidirectionally. +// The VLESS response header is written before relaying begins. +func relayTCP(clientConn net.Conn, target string) error { + remote, err := net.DialTimeout("tcp", target, 10*time.Second) + if err != nil { + return err + } + defer remote.Close() + + if err := WriteResponse(clientConn); err != nil { + return err + } + + var wg sync.WaitGroup + wg.Add(2) + + // client → remote + go func() { + defer wg.Done() + io.Copy(remote, clientConn) + remote.(*net.TCPConn).CloseWrite() + }() + + // remote → client + go func() { + defer wg.Done() + io.Copy(clientConn, remote) + }() + + wg.Wait() + return nil +} + +// relayUDP handles VLESS UDP-over-TCP relay. +// UDP datagrams are framed as [2B length BE][payload] on the TCP stream. +func relayUDP(clientConn net.Conn, target string) error { + remoteAddr, err := net.ResolveUDPAddr("udp", target) + if err != nil { + return err + } + + udpConn, err := net.DialUDP("udp", nil, remoteAddr) + if err != nil { + return err + } + defer udpConn.Close() + + if err := WriteResponse(clientConn); err != nil { + return err + } + + var wg sync.WaitGroup + wg.Add(2) + + // client TCP → remote UDP: read [2B len][payload] frames, send as UDP datagrams. + go func() { + defer wg.Done() + defer udpConn.Close() + var lenBuf [2]byte + for { + if _, err := io.ReadFull(clientConn, lenBuf[:]); err != nil { + return + } + payloadLen := binary.BigEndian.Uint16(lenBuf[:]) + if payloadLen == 0 { + continue + } + buf := make([]byte, payloadLen) + if _, err := io.ReadFull(clientConn, buf); err != nil { + return + } + udpConn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if _, err := udpConn.Write(buf); err != nil { + return + } + } + }() + + // remote UDP → client TCP: read UDP datagrams, write as [2B len][payload] frames. + go func() { + defer wg.Done() + buf := make([]byte, 64*1024) + var lenBuf [2]byte + for { + udpConn.SetReadDeadline(time.Now().Add(udpIdleTimeout)) + n, err := udpConn.Read(buf) + if err != nil { + return + } + binary.BigEndian.PutUint16(lenBuf[:], uint16(n)) + if _, err := clientConn.Write(lenBuf[:]); err != nil { + return + } + if _, err := clientConn.Write(buf[:n]); err != nil { + return + } + } + }() + + wg.Wait() + return nil +} diff --git a/internal/vless/server.go b/internal/vless/server.go new file mode 100644 index 0000000000000000000000000000000000000000..2bbcebc4bb3b8784ba5f2fcfbc5683eb18bac8b7 --- /dev/null +++ b/internal/vless/server.go @@ -0,0 +1,169 @@ +package vless + +import ( + "context" + "fmt" + "log/slog" + "net" + "sync" + "time" + + "github.com/google/uuid" + "github.com/xtls/reality" + + "sourcecraft.dev/bigbes/shroud/internal/store" +) + +// Config holds the VLESS+REALITY server configuration. +type Config struct { + ListenAddr string + PrivateKey []byte // 32-byte raw x25519 private key + PublicKey []byte // 32-byte raw x25519 public key (for share links) + ShortID [8]byte + ServerNames []string + Dest string + Show bool +} + +// Server is a VLESS proxy server with REALITY obfuscation. +type Server struct { + cfg Config + listener net.Listener + logger *slog.Logger + + mu sync.RWMutex + uuids map[[16]byte]string // UUID → access key ID + cancel context.CancelFunc +} + +// New creates a new VLESS+REALITY server. +func New(cfg Config, logger *slog.Logger) *Server { + return &Server{ + cfg: cfg, + logger: logger.With("component", "vless"), + uuids: make(map[[16]byte]string), + } +} + +// Start initializes the REALITY listener and begins accepting connections. +func (s *Server) Start() error { + shortIds := map[[8]byte]bool{s.cfg.ShortID: true} + + serverNames := make(map[string]bool, len(s.cfg.ServerNames)) + for _, name := range s.cfg.ServerNames { + serverNames[name] = true + } + + realityCfg := &reality.Config{ + PrivateKey: s.cfg.PrivateKey, + ShortIds: shortIds, + ServerNames: serverNames, + Dest: s.cfg.Dest, + Show: s.cfg.Show, + } + + ln, err := reality.Listen("tcp", s.cfg.ListenAddr, realityCfg) + if err != nil { + return fmt.Errorf("reality listen: %w", err) + } + s.listener = ln + + ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + + go s.acceptLoop(ctx) + + return nil +} + +// SyncKeys rebuilds the UUID authentication map from the given access keys. +func (s *Server) SyncKeys(keys []store.AccessKey) error { + newUUIDs := make(map[[16]byte]string) + for _, k := range keys { + if k.VLESS == nil { + continue + } + parsed, err := uuid.Parse(k.VLESS.UUID) + if err != nil { + s.logger.Warn("Skipping VLESS key with invalid UUID.", "keyID", k.ID, "err", err) + continue + } + newUUIDs[parsed] = k.ID + } + + s.mu.Lock() + s.uuids = newUUIDs + s.mu.Unlock() + + s.logger.Info("VLESS keys synced.", "count", len(newUUIDs)) + return nil +} + +// Stop shuts down the server. +func (s *Server) Stop() { + if s.cancel != nil { + s.cancel() + } + if s.listener != nil { + s.listener.Close() + } + s.logger.Info("VLESS server stopped.") +} + +func (s *Server) acceptLoop(ctx context.Context) { + for { + conn, err := s.listener.Accept() + if err != nil { + select { + case <-ctx.Done(): + return + default: + s.logger.Error("VLESS accept error.", "err", err) + continue + } + } + go s.handleConn(conn) + } +} + +func (s *Server) handleConn(conn net.Conn) { + defer conn.Close() + + // Deadline for header parsing. + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + + req, err := ParseRequest(conn) + if err != nil { + s.logger.Debug("Failed to parse VLESS request.", "err", err) + return + } + + // Clear read deadline for relay. + conn.SetReadDeadline(time.Time{}) + + // Authenticate UUID. + s.mu.RLock() + keyID, ok := s.uuids[req.UUID] + s.mu.RUnlock() + + if !ok { + s.logger.Debug("Unknown VLESS UUID, closing connection.") + return + } + + target := req.Target() + + switch req.Command { + case CmdTCP: + s.logger.Debug("VLESS TCP relay.", "keyID", keyID, "target", target) + if err := relayTCP(conn, target); err != nil { + s.logger.Debug("VLESS TCP relay ended.", "keyID", keyID, "err", err) + } + + case CmdUDP: + s.logger.Debug("VLESS UDP relay.", "keyID", keyID, "target", target) + if err := relayUDP(conn, target); err != nil { + s.logger.Debug("VLESS UDP relay ended.", "keyID", keyID, "err", err) + } + } +}