package auth_test
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"testing"
"time"
"sourcecraft.dev/bigbes/lethe/internal/config"
"sourcecraft.dev/bigbes/lethe/internal/server/auth"
)
// oidcTestServer is a deliberately tiny stand-in for an OpenID Provider. It
// exposes only an OIDC discovery document and a JWKS endpoint, both backed by
// a single in-memory RSA key, which is just enough for the coreos/go-oidc
// verifier to fetch the key set and check RS256 signatures made by signToken.
type oidcTestServer struct {
	// srv is the HTTP test server that serves discovery and JWKS.
	srv *httptest.Server
	// key signs tokens; its public half is what the JWKS endpoint publishes.
	key *rsa.PrivateKey
	// keyID is the "kid" advertised in the JWKS and stamped into token headers.
	keyID string
	// issuer is the base URL of srv, used as the "iss" claim and as the
	// prefix of every endpoint URL in the discovery document.
	issuer string
}
// newOIDCTestServer starts an httptest.Server that serves an OIDC discovery
// document at /.well-known/openid-configuration and a one-key JWKS at /jwks,
// both derived from a freshly generated 2048-bit RSA key.
//
// The server is shut down automatically: its Close method is registered with
// t.Cleanup, so callers must not close it themselves. Key generation failure
// aborts the test via t.Fatalf; a failure to write a response is reported
// with t.Errorf.
func newOIDCTestServer(t *testing.T) *oidcTestServer {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("rsa.GenerateKey: %v", err)
	}
	o := &oidcTestServer{key: key, keyID: "test-key-1"}

	// writeJSON sends v as a JSON response and surfaces encoding/write
	// failures on the test instead of dropping them silently.
	writeJSON := func(w http.ResponseWriter, what string, v any) {
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(v); err != nil {
			t.Errorf("encode %s: %v", what, err)
		}
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
		writeJSON(w, "discovery document", map[string]any{
			"issuer":                                o.issuer,
			"jwks_uri":                              o.issuer + "/jwks",
			"authorization_endpoint":                o.issuer + "/auth",
			"token_endpoint":                        o.issuer + "/token",
			"id_token_signing_alg_values_supported": []string{"RS256"},
			"response_types_supported":              []string{"id_token"},
			"subject_types_supported":               []string{"public"},
		})
	})
	mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) {
		jwk := map[string]any{
			"kty": "RSA",
			"alg": "RS256",
			"use": "sig",
			"kid": o.keyID,
			"n":   base64URLUint(key.N),
			"e":   base64URLUintFromInt(key.E),
		}
		writeJSON(w, "JWKS", map[string]any{"keys": []any{jwk}})
	})

	// An unstarted server already owns its listener, so the issuer URL can be
	// fixed before Start spawns the goroutine that runs the handlers. That
	// way no handler can ever observe an empty or half-written issuer.
	o.srv = httptest.NewUnstartedServer(mux)
	o.issuer = "http://" + o.srv.Listener.Addr().String()
	o.srv.Start()
	t.Cleanup(o.srv.Close)
	return o
}
// signToken returns a compact RS256-signed JWT whose payload is claims
// layered over three defaults: iss (the server's issuer), iat (now) and exp
// (one hour from now). Any claim supplied by the caller, including those
// three, replaces the default. No aud claim is added automatically, so tests
// that exercise audience checks must pass one in claims.
//
// The header carries the server's key ID so a verifier can select the
// matching JWKS entry. Marshalling or signing failures abort the test via
// t.Fatalf.
func (o *oidcTestServer) signToken(t *testing.T, claims map[string]any) string {
	t.Helper()

	header := map[string]any{
		"alg": "RS256",
		"typ": "JWT",
		"kid": o.keyID,
	}

	// Read the clock once so iat and exp are always exactly one hour apart.
	now := time.Now()
	merged := make(map[string]any, len(claims)+3)
	merged["iss"] = o.issuer
	merged["iat"] = now.Unix()
	merged["exp"] = now.Add(time.Hour).Unix()
	// Caller-supplied claims are applied last so tests can override e.g. exp
	// or aud explicitly.
	for k, v := range claims {
		merged[k] = v
	}

	headerBytes, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("marshal header: %v", err)
	}
	payloadBytes, err := json.Marshal(merged)
	if err != nil {
		t.Fatalf("marshal payload: %v", err)
	}

	// The signature covers "<b64url(header)>.<b64url(payload)>".
	signingInput := base64.RawURLEncoding.EncodeToString(headerBytes) +
		"." + base64.RawURLEncoding.EncodeToString(payloadBytes)
	digest := sha256.Sum256([]byte(signingInput))
	sig, err := rsa.SignPKCS1v15(rand.Reader, o.key, crypto.SHA256, digest[:])
	if err != nil {
		t.Fatalf("rsa.SignPKCS1v15: %v", err)
	}
	return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig)
}
// newVerifier returns an initialised auth.OIDCVerifier whose issuer is this
// test server and whose expected audience is aud. The username is read from
// the "preferred_username" claim. If Init returns an error the test is
// aborted via t.Fatalf.
func (o *oidcTestServer) newVerifier(t *testing.T, aud string) *auth.OIDCVerifier {
	t.Helper()

	cfg := config.OIDCConfig{
		Enabled:       true,
		Issuer:        o.issuer,
		Audience:      aud,
		UsernameClaim: "preferred_username",
	}
	verifier := &auth.OIDCVerifier{Cfg: cfg}

	err := verifier.Init(context.Background())
	if err != nil {
		t.Fatalf("OIDCVerifier.Init: %v", err)
	}
	return verifier
}
// base64URLUint encodes n as a Base64urlUInt (JWA, RFC 7518 §2), the form
// used for the RSA modulus in §6.3.1.1: the minimal big-endian octets of the
// value, base64url-encoded without padding.
//
// Zero is encoded as a single zero octet ("AA") as the RFC requires. Only the
// magnitude of n is encoded; the sign is ignored. n must not be nil.
func base64URLUint(n *big.Int) string {
	raw := n.Bytes()
	if len(raw) == 0 {
		// big.Int.Bytes yields an empty slice for zero, which would encode
		// as "" instead of the mandated single zero-valued octet.
		raw = []byte{0}
	}
	return base64.RawURLEncoding.EncodeToString(raw)
}
// base64URLUintFromInt renders an RSA public exponent (65537 in practice) as
// unpadded base64url over its shortest big-endian byte form, as JWA
// §6.3.1.2 prescribes. It panics when e is zero or negative.
func base64URLUintFromInt(e int) string {
	if e < 1 {
		panic(fmt.Sprintf("invalid RSA exponent: %d", e))
	}

	var scratch [8]byte
	binary.BigEndian.PutUint64(scratch[:], uint64(e))

	// Drop leading zero octets, always keeping at least one byte.
	minimal := scratch[:]
	for len(minimal) > 1 && minimal[0] == 0 {
		minimal = minimal[1:]
	}
	return base64.RawURLEncoding.EncodeToString(minimal)
}