package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestRequestIDGeneratesWhenAbsent checks that, when the incoming request has
// no X-Request-ID header, the middleware generates one, exposes it on the
// response header, and stores the same value in the request context seen by
// the wrapped handler.
func TestRequestIDGeneratesWhenAbsent(t *testing.T) {
	t.Parallel()

	var seenInHandler string
	h := RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		seenInHandler = GetRequestID(r.Context())
	}))

	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)

	got := rr.Header().Get("X-Request-ID")
	if got == "" {
		t.Fatal("X-Request-ID header should be set")
	}
	// An empty seenInHandler here also catches the case where the wrapped
	// handler was never invoked.
	if got != seenInHandler {
		t.Errorf("response header %q != context value %q", got, seenInHandler)
	}

	// Generated UUID should be 36 chars (8-4-4-4-12). Fatal on a length
	// mismatch so the hyphen-position checks below cannot index out of range.
	if len(got) != 36 {
		t.Fatalf("generated id length = %d, want 36", len(got))
	}
	for _, i := range []int{8, 13, 18, 23} {
		if got[i] != '-' {
			t.Errorf("generated id %q: byte %d = %q, want '-' (8-4-4-4-12 layout)", got, i, got[i])
		}
	}
}

// TestRequestIDPreservesIncoming checks that a caller-supplied X-Request-ID
// is propagated unchanged into both the request context and the response.
func TestRequestIDPreservesIncoming(t *testing.T) {
	t.Parallel()

	const incoming = "trace-abc-123"

	var seen string
	h := RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		seen = GetRequestID(r.Context())
	}))

	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
	req.Header.Set("X-Request-ID", incoming)
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)

	if seen != incoming {
		t.Errorf("context id = %q, want %q", seen, incoming)
	}
	if got := rr.Header().Get("X-Request-ID"); got != incoming {
		t.Errorf("response id = %q, want %q", got, incoming)
	}
}

// TestGetRequestIDEmptyContext checks that GetRequestID returns "" for a
// context that carries no request ID.
func TestGetRequestIDEmptyContext(t *testing.T) {
	t.Parallel()

	if got := GetRequestID(context.Background()); got != "" {
		t.Errorf("expected empty string, got %q", got)
	}
}

// TestGetRequestIDWrongType checks that GetRequestID returns "" rather than
// panicking when the value stored under RequestIDKey is not a string.
func TestGetRequestIDWrongType(t *testing.T) {
	t.Parallel()

	ctx := context.WithValue(context.Background(), RequestIDKey, 42) // not a string
	if got := GetRequestID(ctx); got != "" {
		t.Errorf("expected empty string, got %q", got)
	}
}