package ssserver
import (
	"container/list"
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"sync"
	"time"

	"golang.getoutline.org/sdk/transport"
	"golang.getoutline.org/sdk/transport/shadowsocks"
	onet "golang.getoutline.org/tunnel-server/net"
	"golang.getoutline.org/tunnel-server/service"

	"go.bigb.es/shroud/internal/metrics"
	"go.bigb.es/shroud/internal/store"
)
// tcpReadTimeout is a fixed 59-second duration, presumably used as a read
// timeout for TCP connections.
// NOTE(review): it is not referenced anywhere in this file; confirm another
// file in the package uses it, otherwise it can be removed.
const tcpReadTimeout time.Duration = 59 * time.Second
// Server runs Shadowsocks TCP and UDP services for a set of access keys and
// replaces that set in place via SyncKeys.
type Server struct {
// lnManager opens the TCP/UDP listeners; it is shared by every key set
// started through SyncKeys.
lnManager service.ListenerManager
// natTimeout is passed to the UDP target listener built in SyncKeys.
natTimeout time.Duration
// replayCache is shared (by pointer) across all ports and all key sets.
replayCache service.ReplayCache
serverMetrics *metrics.ServerMetrics
serviceMetrics *metrics.TransferTracker
logger *slog.Logger
// mu serializes SyncKeys and Stop and guards stopFunc.
mu sync.Mutex
// stopFunc stops the listeners of the currently running key set; it is nil
// until the first successful SyncKeys.
stopFunc func() error
}
// New returns a Server with no listeners running; call SyncKeys to start
// serving access keys. natTimeout is handed to the UDP target listener and
// replayHistory is passed to service.NewReplayCache for the shared replay
// cache.
func New(
	natTimeout time.Duration,
	replayHistory int,
	serverMetrics *metrics.ServerMetrics,
	serviceMetrics *metrics.TransferTracker,
	logger *slog.Logger,
) *Server {
	srv := new(Server)
	srv.lnManager = service.NewListenerManager()
	srv.natTimeout = natTimeout
	srv.replayCache = service.NewReplayCache(replayHistory)
	srv.serverMetrics = serverMetrics
	srv.serviceMetrics = serviceMetrics
	srv.logger = logger
	return srv
}
// SyncKeys rebuilds the Shadowsocks service with the given access keys.
// It starts new listeners before closing old ones for zero-downtime reload.
//
// Keys are grouped by port. Each port gets its own cipher list plus a TCP and
// a UDP listener opened through the shared ListenerManager. If any key is
// invalid or any listener fails to open, the previously running listeners are
// left untouched and every listener opened during this call is closed before
// the error is returned.
func (s *Server) SyncKeys(keys []store.AccessKey) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Group keys by port.
	portCiphers := make(map[int]*list.List)
	for _, key := range keys {
		cl, ok := portCiphers[key.Port]
		if !ok {
			cl = list.New()
			portCiphers[key.Port] = cl
		}
		cryptoKey, err := shadowsocks.NewEncryptionKey(key.Method, key.Password)
		if err != nil {
			return fmt.Errorf("creating encryption key for %s: %w", key.ID, err)
		}
		entry := service.MakeCipherEntry(key.ID, cryptoKey, key.Password)
		cl.PushBack(&entry)
	}

	// Both result channels hold one value so the serving goroutine never
	// blocks when reporting its start or close result.
	startErrCh := make(chan error, 1)
	stopErrCh := make(chan error, 1)
	stopCh := make(chan struct{})

	go func() {
		lnSet := &listenerSet{
			manager:    s.lnManager,
			closeFuncs: make(map[string]func() error),
		}
		// Release this key set's listeners whenever the goroutine exits,
		// whether startup failed or a stop was requested.
		defer func() {
			stopErrCh <- lnSet.Close()
		}()

		startErr := func() error {
			for portNum, cipherEntries := range portCiphers {
				addr := fmt.Sprintf(":%d", portNum)
				ciphers := service.NewCipherList()
				ciphers.Update(cipherEntries)
				streamHandler, associationHandler := service.NewShadowsocksHandlers(
					service.WithCiphers(ciphers),
					service.WithMetrics(s.serviceMetrics),
					service.WithReplayCache(&s.replayCache),
					service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0)),
					service.WithPacketListener(service.MakeTargetUDPListener(onet.RequirePublicIP, s.natTimeout, 0)),
					service.WithLogger(s.logger),
				)

				ln, err := lnSet.ListenStream(addr)
				if err != nil {
					return fmt.Errorf("listen TCP %s: %w", addr, err)
				}
				s.logger.Info("TCP service started.", "address", ln.Addr().String())
				go service.StreamServe(ln.AcceptStream, func(ctx context.Context, conn transport.StreamConn) {
					streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn))
				})

				pc, err := lnSet.ListenPacket(addr)
				if err != nil {
					return fmt.Errorf("listen UDP %s: %w", addr, err)
				}
				s.logger.Info("UDP service started.", "address", pc.LocalAddr().String())
				go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) {
					associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn))
				}, s.serverMetrics)
			}
			s.serverMetrics.SetNumAccessKeys(len(keys), len(portCiphers))
			return nil
		}()

		startErrCh <- startErr
		if startErr != nil {
			// No stop func is handed out for a failed start, so waiting on
			// stopCh would leak this goroutine and keep the listeners opened
			// so far bound. Return so the deferred Close releases them.
			return
		}
		<-stopCh
	}()

	if err := <-startErrCh; err != nil {
		// Wait until the partially started listeners are closed so the caller
		// can retry without racing them.
		if closeErr := <-stopErrCh; closeErr != nil {
			s.logger.Warn("Failed to close listeners after failed start.", "err", closeErr)
		}
		return err
	}

	// Stop old listeners.
	if s.stopFunc != nil {
		if err := s.stopFunc(); err != nil {
			s.logger.Warn("Failed to stop old listeners.", "err", err)
		}
	}
	// The serving goroutine receives from stopCh exactly once, so a second
	// send would block forever. OnceValue makes repeated calls return the
	// first result instead.
	s.stopFunc = sync.OnceValue(func() error {
		stopCh <- struct{}{}
		return <-stopErrCh
	})
	return nil
}
// Stop closes the listeners started by the most recent successful SyncKeys
// call and returns any error from closing them. It returns nil when nothing
// is running, including on every call after the first.
func (s *Server) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.stopFunc == nil {
		return nil
	}
	stop := s.stopFunc
	// Clear before calling: the stop func signals a goroutine that exits after
	// its first stop, so invoking it a second time (from another Stop or a
	// later SyncKeys) can block forever.
	s.stopFunc = nil
	return stop()
}
// listenerSet tracks the listeners opened for one key set so they can be
// closed together.
type listenerSet struct {
// manager opens the underlying stream and packet listeners.
manager service.ListenerManager
// closeFuncs maps "stream/<addr>" or "packet/<addr>" to that listener's
// Close method. Close may reset it to nil, after which no new listeners can
// be added (writing to a nil map panics).
closeFuncs map[string]func() error
// mu guards closeFuncs.
mu sync.Mutex
}
// ListenStream opens a stream (TCP) listener on addr through the shared
// manager and remembers its Close method for Close. Opening a second stream
// listener for the same addr in one set is rejected.
func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) {
	id := "stream/" + addr
	ls.mu.Lock()
	defer ls.mu.Unlock()
	if _, dup := ls.closeFuncs[id]; dup {
		return nil, fmt.Errorf("stream listener for %s already exists", addr)
	}
	listener, err := ls.manager.ListenStream(addr)
	if err != nil {
		return nil, err
	}
	ls.closeFuncs[id] = listener.Close
	return listener, nil
}
// ListenPacket opens a packet (UDP) listener on addr through the shared
// manager and remembers its Close method for Close. Opening a second packet
// listener for the same addr in one set is rejected.
func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) {
	id := "packet/" + addr
	ls.mu.Lock()
	defer ls.mu.Unlock()
	if _, dup := ls.closeFuncs[id]; dup {
		return nil, fmt.Errorf("packet listener for %s already exists", addr)
	}
	conn, err := ls.manager.ListenPacket(addr)
	if err != nil {
		return nil, err
	}
	ls.closeFuncs[id] = conn.Close
	return conn, nil
}
// Close closes every listener in the set. It keeps going past failures so one
// listener that fails to close does not leave the others open, and it returns
// all failures joined together (nil if every close succeeded). Afterwards the
// set is empty. Calling Close again is a no-op, but the set must not be used
// to open new listeners.
func (ls *listenerSet) Close() error {
	ls.mu.Lock()
	defer ls.mu.Unlock()
	var errs []error
	for key, closeFunc := range ls.closeFuncs {
		if err := closeFunc(); err != nil {
			errs = append(errs, fmt.Errorf("listener %s failed to close: %w", key, err))
		}
	}
	ls.closeFuncs = nil
	return errors.Join(errs...)
}