package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestIDGeneratesWhenAbsent(t *testing.T) {
var seenInHandler string
h := RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
seenInHandler = GetRequestID(r.Context())
}))
req := httptest.NewRequestWithContext(t.Context(), "GET", "/", http.NoBody)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
got := rr.Header().Get("X-Request-ID")
if got == "" {
t.Fatal("X-Request-ID header should be set")
}
if got != seenInHandler {
t.Errorf("response header %q != context value %q", got, seenInHandler)
}
// Generated UUID should be 36 chars (8-4-4-4-12).
if len(got) != 36 {
t.Errorf("generated id length = %d, want 36", len(got))
}
}
func TestRequestIDPreservesIncoming(t *testing.T) {
const incoming = "trace-abc-123"
var seen string
h := RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
seen = GetRequestID(r.Context())
}))
req := httptest.NewRequestWithContext(t.Context(), "GET", "/", http.NoBody)
req.Header.Set("X-Request-ID", incoming)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if seen != incoming {
t.Errorf("context id = %q, want %q", seen, incoming)
}
if got := rr.Header().Get("X-Request-ID"); got != incoming {
t.Errorf("response id = %q, want %q", got, incoming)
}
}
// TestGetRequestIDEmptyContext checks that GetRequestID returns the empty
// string when given context.Background(), which carries no values.
func TestGetRequestIDEmptyContext(t *testing.T) {
	id := GetRequestID(context.Background())
	if id == "" {
		return
	}
	t.Errorf("expected empty string, got %q", id)
}
// TestGetRequestIDWrongType checks that GetRequestID returns the empty string
// when the value stored under RequestIDKey is not a string.
func TestGetRequestIDWrongType(t *testing.T) {
	// Deliberately store an int under the key to exercise the non-string path.
	badCtx := context.WithValue(context.Background(), RequestIDKey, 42)

	id := GetRequestID(badCtx)
	if id != "" {
		t.Errorf("expected empty string, got %q", id)
	}
}