~bigbes/huntsman

ref: 783841b91eafd678cb3895cfcc8dfd89f290ece7 huntsman/internal/server/middleware/requestid_test.go -rw-r--r-- 1.7 KiB
783841b9 — Eugene Blikh Initial commit: multi-provider search router 6 days ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestRequestIDGeneratesWhenAbsent(t *testing.T) {
	var seenInHandler string
	h := RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		seenInHandler = GetRequestID(r.Context())
	}))

	req := httptest.NewRequestWithContext(t.Context(), "GET", "/", http.NoBody)
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)

	got := rr.Header().Get("X-Request-ID")
	if got == "" {
		t.Fatal("X-Request-ID header should be set")
	}
	if got != seenInHandler {
		t.Errorf("response header %q != context value %q", got, seenInHandler)
	}
	// Generated UUID should be 36 chars (8-4-4-4-12).
	if len(got) != 36 {
		t.Errorf("generated id length = %d, want 36", len(got))
	}
}

// TestRequestIDPreservesIncoming checks that a client-supplied X-Request-ID is
// reused rather than replaced. The wrapped handler must see it in the request
// context, and the response must echo it back unchanged.
func TestRequestIDPreservesIncoming(t *testing.T) {
	const incoming = "trace-abc-123"

	var fromCtx string
	handler := RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		fromCtx = GetRequestID(r.Context())
	}))

	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
	req.Header.Set("X-Request-ID", incoming)

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)

	if fromCtx != incoming {
		t.Errorf("context id = %q, want %q", fromCtx, incoming)
	}
	echoed := recorder.Header().Get("X-Request-ID")
	if echoed != incoming {
		t.Errorf("response id = %q, want %q", echoed, incoming)
	}
}

// TestGetRequestIDEmptyContext checks that GetRequestID yields the empty
// string for a context that has no request ID stored in it.
func TestGetRequestIDEmptyContext(t *testing.T) {
	ctx := context.Background()
	got := GetRequestID(ctx)
	if got != "" {
		t.Errorf("expected empty string, got %q", got)
	}
}

// TestGetRequestIDWrongType checks that GetRequestID yields the empty string
// when the value stored under RequestIDKey is not a string (here, an int).
func TestGetRequestIDWrongType(t *testing.T) {
	const notAString = 42
	ctx := context.WithValue(context.Background(), RequestIDKey, notAString)

	got := GetRequestID(ctx)
	if got != "" {
		t.Errorf("expected empty string, got %q", got)
	}
}