package middleware
import (
"bytes"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestRecoveryCatchesPanic verifies that Recovery converts a panicking handler
// into a 500 response whose body contains "internal server error", and that
// the recovery is logged with the message "panic recovered" and the request
// path.
func TestRecoveryCatchesPanic(t *testing.T) {
	var logBuf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&logBuf, nil))

	panicking := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		panic("boom")
	})
	h := Recovery(logger)(panicking)

	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/explode", http.NoBody)
	rr := httptest.NewRecorder()
	// If Recovery failed to recover, the panic would propagate here and abort
	// the test binary, which also counts as a failure.
	h.ServeHTTP(rr, req)

	if rr.Code != http.StatusInternalServerError {
		t.Errorf("status = %d, want %d", rr.Code, http.StatusInternalServerError)
	}
	const wantBody = "internal server error"
	if body := rr.Body.String(); !strings.Contains(body, wantBody) {
		t.Errorf("body = %q, want it to contain %q", body, wantBody)
	}

	logged := logBuf.String()
	if !strings.Contains(logged, "panic recovered") {
		t.Errorf("expected log to mention 'panic recovered', got %q", logged)
	}
	if !strings.Contains(logged, "/explode") {
		t.Errorf("expected log to include path, got %q", logged)
	}
}
// TestRecoveryPassesThroughWhenNoPanic verifies that Recovery is transparent
// for handlers that return normally: the handler's status code and body reach
// the client unchanged, and nothing is written to the logger.
func TestRecoveryPassesThroughWhenNoPanic(t *testing.T) {
	var logBuf bytes.Buffer
	// nil options: the TextHandler's default minimum level is Info, so this
	// check covers Info-and-above records only.
	logger := slog.New(slog.NewTextHandler(&logBuf, nil))

	normal := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
		// The write error is deliberately ignored; the recorder below captures
		// whatever was written, and the assertions check it.
		_, _ = w.Write([]byte("hi"))
	})
	h := Recovery(logger)(normal)

	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)

	if rr.Code != http.StatusTeapot {
		t.Errorf("status = %d, want %d", rr.Code, http.StatusTeapot)
	}
	const wantBody = "hi"
	if body := rr.Body.String(); body != wantBody {
		t.Errorf("body = %q, want %q", body, wantBody)
	}
	if logBuf.Len() != 0 {
		t.Errorf("nothing should have been logged, got %q", logBuf.String())
	}
}