package ssserver import ( "container/list" "context" "fmt" "log/slog" "net" "sync" "time" "golang.getoutline.org/sdk/transport" "golang.getoutline.org/sdk/transport/shadowsocks" onet "golang.getoutline.org/tunnel-server/net" "golang.getoutline.org/tunnel-server/service" "sourcecraft.dev/bigbes/shroud/internal/metrics" "sourcecraft.dev/bigbes/shroud/internal/store" ) const tcpReadTimeout = 59 * time.Second type Server struct { lnManager service.ListenerManager natTimeout time.Duration replayCache service.ReplayCache serverMetrics *metrics.ServerMetrics serviceMetrics *metrics.TransferTracker logger *slog.Logger mu sync.Mutex stopFunc func() error } func New( natTimeout time.Duration, replayHistory int, serverMetrics *metrics.ServerMetrics, serviceMetrics *metrics.TransferTracker, logger *slog.Logger, ) *Server { return &Server{ lnManager: service.NewListenerManager(), natTimeout: natTimeout, replayCache: service.NewReplayCache(replayHistory), serverMetrics: serverMetrics, serviceMetrics: serviceMetrics, logger: logger, } } // SyncKeys rebuilds the Shadowsocks service with the given access keys. // It starts new listeners before closing old ones for zero-downtime reload. func (s *Server) SyncKeys(keys []store.AccessKey) error { s.mu.Lock() defer s.mu.Unlock() // Group keys by port. portCiphers := make(map[int]*list.List) for _, key := range keys { cl, ok := portCiphers[key.Port] if !ok { cl = list.New() portCiphers[key.Port] = cl } cryptoKey, err := shadowsocks.NewEncryptionKey(key.Method, key.Password) if err != nil { return fmt.Errorf("creating encryption key for %s: %w", key.ID, err) } entry := service.MakeCipherEntry(key.ID, cryptoKey, key.Password) cl.PushBack(&entry) } // Start new listeners. startErrCh := make(chan error, 1) stopErrCh := make(chan error, 1) stopCh := make(chan struct{}) go func() { lnSet := &listenerSet{ manager: s.lnManager, closeFuncs: make(map[string]func() error), } defer func() { stopErrCh <- lnSet.Close() }() startErrCh <- func() error { for portNum, cipherEntries := range portCiphers { addr := fmt.Sprintf(":%d", portNum) ciphers := service.NewCipherList() ciphers.Update(cipherEntries) streamHandler, associationHandler := service.NewShadowsocksHandlers( service.WithCiphers(ciphers), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0)), service.WithPacketListener(service.MakeTargetUDPListener(onet.RequirePublicIP, s.natTimeout, 0)), service.WithLogger(s.logger), ) ln, err := lnSet.ListenStream(addr) if err != nil { return fmt.Errorf("listen TCP %s: %w", addr, err) } s.logger.Info("TCP service started.", "address", ln.Addr().String()) go service.StreamServe(ln.AcceptStream, func(ctx context.Context, conn transport.StreamConn) { streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn)) }) pc, err := lnSet.ListenPacket(addr) if err != nil { return fmt.Errorf("listen UDP %s: %w", addr, err) } s.logger.Info("UDP service started.", "address", pc.LocalAddr().String()) go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) { associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) }, s.serverMetrics) } s.serverMetrics.SetNumAccessKeys(len(keys), len(portCiphers)) return nil }() <-stopCh }() if err := <-startErrCh; err != nil { return err } // Stop old listeners. if s.stopFunc != nil { if err := s.stopFunc(); err != nil { s.logger.Warn("Failed to stop old listeners.", "err", err) } } s.stopFunc = func() error { stopCh <- struct{}{} return <-stopErrCh } return nil } func (s *Server) Stop() error { s.mu.Lock() defer s.mu.Unlock() if s.stopFunc != nil { return s.stopFunc() } return nil } type listenerSet struct { manager service.ListenerManager closeFuncs map[string]func() error mu sync.Mutex } func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) { ls.mu.Lock() defer ls.mu.Unlock() key := "stream/" + addr if _, exists := ls.closeFuncs[key]; exists { return nil, fmt.Errorf("stream listener for %s already exists", addr) } ln, err := ls.manager.ListenStream(addr) if err != nil { return nil, err } ls.closeFuncs[key] = ln.Close return ln, nil } func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { ls.mu.Lock() defer ls.mu.Unlock() key := "packet/" + addr if _, exists := ls.closeFuncs[key]; exists { return nil, fmt.Errorf("packet listener for %s already exists", addr) } pc, err := ls.manager.ListenPacket(addr) if err != nil { return nil, err } ls.closeFuncs[key] = pc.Close return pc, nil } func (ls *listenerSet) Close() error { ls.mu.Lock() defer ls.mu.Unlock() for addr, closeFunc := range ls.closeFuncs { if err := closeFunc(); err != nil { return fmt.Errorf("listener %s failed to close: %w", addr, err) } } ls.closeFuncs = nil return nil }