package database
import (
"context"
"errors"
"testing"
"time"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
"sourcecraft.dev/bigbes/lethe/internal/config"
)
// openTestDB connects to a ":memory:" SQLite database, applies Migrate to it,
// and hands back the ready-to-use handle.
//
// The DSN is produced by buildDSN with a 5s busy timeout, so whatever
// connection parameters buildDSN encodes apply here as well. Per the original
// note these are the modernc.org/sqlite "?_pragma=" parameters (FK
// enforcement, WAL, busy timeout) matching production, and each call yields a
// distinct in-memory database — confirm against buildDSN.
//
// The handle is closed through t.Cleanup; a connect or migrate failure aborts
// the calling test.
func openTestDB(t *testing.T) *sqlx.DB {
	t.Helper()

	handle, connErr := sqlx.Connect("sqlite", buildDSN(":memory:", 5*time.Second))
	if connErr != nil {
		t.Fatalf("connect: %v", connErr)
	}
	t.Cleanup(func() {
		// Best-effort close: a teardown error is not actionable here.
		_ = handle.Close()
	})

	migrateErr := Migrate(handle)
	if migrateErr != nil {
		t.Fatalf("migrate: %v", migrateErr)
	}
	return handle
}
// insertSession adds one sessions row keyed by (owner, tool, host, sessionID)
// and aborts the test if the insert fails.
//
// The remaining columns are fixed fixtures: started_at=1700000000,
// ended_at=1700000100, source_file="/tmp/x.jsonl", and NULL for working_dir
// and metadata.
func insertSession(t *testing.T, db *sqlx.DB, owner, tool, host, sessionID string) {
	t.Helper()

	const stmt = `
INSERT INTO sessions
(owner, tool, host, session_id, started_at, ended_at, working_dir, source_file, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	if _, execErr := db.Exec(stmt,
		owner, tool, host, sessionID,
		1700000000, 1700000100,
		nil, "/tmp/x.jsonl", nil,
	); execErr != nil {
		t.Fatalf("insert session: %v", execErr)
	}
}
// insertTurn adds one turns row for the given session and aborts the test if
// the insert fails.
//
// seq is always 1, role is always "user" and timestamp is always 1700000050.
// toolCalls is passed to the driver as-is, so a nil pointer stores NULL in
// tool_calls.
func insertTurn(t *testing.T, db *sqlx.DB, owner, tool, host, sessionID, turnID, content string, toolCalls *string) {
	t.Helper()

	const stmt = `
INSERT INTO turns
(owner, tool, host, session_id, turn_id, seq, role, timestamp, content, tool_calls)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	if _, execErr := db.Exec(stmt,
		owner, tool, host, sessionID, turnID,
		1, "user", 1700000050,
		content, toolCalls,
	); execErr != nil {
		t.Fatalf("insert turn: %v", execErr)
	}
}
func TestMigrateIsIdempotent(t *testing.T) {
dsn := buildDSN(":memory:", 5*time.Second)
db, err := sqlx.Connect("sqlite", dsn)
if err != nil {
t.Fatalf("connect: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := Migrate(db); err != nil {
t.Fatalf("first migrate: %v", err)
}
// Second invocation must succeed (golang-migrate returns ErrNoChange,
// which Migrate translates into nil).
if err := Migrate(db); err != nil {
t.Fatalf("second migrate: %v", err)
}
}
// TestTurnInsertPopulatesTurnsFTSWithOwner checks that after a turn is
// inserted, turns_fts holds exactly one row for that owner matching a word
// from the turn's content.
func TestTurnInsertPopulatesTurnsFTSWithOwner(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")
	insertTurn(t, db, "alice", "cc", "phoebe", "s1", "t1", "hello world from alice", nil)

	const ftsQuery = `SELECT COUNT(*) FROM turns_fts WHERE owner = 'alice' AND turns_fts MATCH 'hello'`

	var matches int
	queryErr := db.Get(&matches, ftsQuery)
	if queryErr != nil {
		t.Fatalf("query fts: %v", queryErr)
	}
	if matches == 1 {
		return
	}
	t.Fatalf("expected 1 fts row for owner=alice matching 'hello', got %d", matches)
}
// TestTurnUpdateUpdatesTurnsFTS checks that updating a turn's content replaces
// its searchable text in turns_fts: the old word stops matching and the new
// word matches exactly once.
func TestTurnUpdateUpdatesTurnsFTS(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")
	insertTurn(t, db, "alice", "cc", "phoebe", "s1", "t1", "original phrase", nil)

	const updateStmt = `
UPDATE turns SET content = ? WHERE owner = ? AND tool = ? AND host = ? AND session_id = ? AND turn_id = ?
`
	_, updateErr := db.Exec(updateStmt, "replaced phrase", "alice", "cc", "phoebe", "s1", "t1")
	if updateErr != nil {
		t.Fatalf("update: %v", updateErr)
	}

	// Run both lookups before asserting on either result.
	var staleCount int
	if err := db.Get(&staleCount, `SELECT COUNT(*) FROM turns_fts WHERE turns_fts MATCH 'original'`); err != nil {
		t.Fatalf("query old: %v", err)
	}
	var freshCount int
	if err := db.Get(&freshCount, `SELECT COUNT(*) FROM turns_fts WHERE turns_fts MATCH 'replaced'`); err != nil {
		t.Fatalf("query new: %v", err)
	}

	switch {
	case staleCount != 0:
		t.Fatalf("expected 0 hits for 'original' after update, got %d", staleCount)
	case freshCount != 1:
		t.Fatalf("expected 1 hit for 'replaced' after update, got %d", freshCount)
	}
}
// TestTurnDeleteRemovesFromTurnsFTS checks that deleting the only turn leaves
// turns_fts with no rows.
func TestTurnDeleteRemovesFromTurnsFTS(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")
	insertTurn(t, db, "alice", "cc", "phoebe", "s1", "t1", "doomed content", nil)

	_, deleteErr := db.Exec(`DELETE FROM turns WHERE turn_id = ?`, "t1")
	if deleteErr != nil {
		t.Fatalf("delete: %v", deleteErr)
	}

	var remaining int
	countErr := db.Get(&remaining, `SELECT COUNT(*) FROM turns_fts`)
	if countErr != nil {
		t.Fatalf("query fts: %v", countErr)
	}
	if remaining > 0 || remaining < 0 {
		t.Fatalf("expected empty turns_fts after delete, got %d rows", remaining)
	}
}
// TestToolOutputsFTSInsertUpdateDeleteWhenToolCallsPresent walks a turn with a
// non-NULL tool_calls value through insert, update and delete, checking after
// each step that tool_outputs_fts reflects the current tool_calls text only.
func TestToolOutputsFTSInsertUpdateDeleteWhenToolCallsPresent(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")

	// count runs a single-value COUNT query; a query error aborts the test
	// with the given label as message prefix.
	count := func(label, query string) int {
		t.Helper()
		var c int
		if err := db.Get(&c, query); err != nil {
			t.Fatalf("%s: %v", label, err)
		}
		return c
	}

	// Insert: the tool_calls JSON becomes searchable for its owner.
	initialCalls := `{"name":"shell","args":"ls /tmp"}`
	insertTurn(t, db, "alice", "cc", "phoebe", "s1", "t1", "running tool", &initialCalls)
	if got := count("query insert", `SELECT COUNT(*) FROM tool_outputs_fts WHERE owner = 'alice' AND tool_outputs_fts MATCH 'shell'`); got != 1 {
		t.Fatalf("expected 1 tool_outputs_fts row, got %d", got)
	}

	// Update: the new text matches, the old text no longer does.
	updatedCalls := `{"name":"editor","args":"open"}`
	if _, err := db.Exec(`UPDATE turns SET tool_calls = ? WHERE turn_id = ?`, updatedCalls, "t1"); err != nil {
		t.Fatalf("update tool_calls: %v", err)
	}
	if got := count("query update", `SELECT COUNT(*) FROM tool_outputs_fts WHERE tool_outputs_fts MATCH 'editor'`); got != 1 {
		t.Fatalf("expected 1 hit after update, got %d", got)
	}
	if got := count("query update old", `SELECT COUNT(*) FROM tool_outputs_fts WHERE tool_outputs_fts MATCH 'shell'`); got != 0 {
		t.Fatalf("expected 0 hits for old tool_calls after update, got %d", got)
	}

	// Delete: the index ends up empty.
	if _, err := db.Exec(`DELETE FROM turns WHERE turn_id = ?`, "t1"); err != nil {
		t.Fatalf("delete: %v", err)
	}
	if got := count("query delete", `SELECT COUNT(*) FROM tool_outputs_fts`); got != 0 {
		t.Fatalf("expected empty tool_outputs_fts after delete, got %d", got)
	}
}
// TestToolOutputsFTSSkipsNullToolCalls checks that a turn inserted with a NULL
// tool_calls value adds nothing to tool_outputs_fts.
func TestToolOutputsFTSSkipsNullToolCalls(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")
	insertTurn(t, db, "alice", "cc", "phoebe", "s1", "t1", "no tool call here", nil)

	var indexed int
	countErr := db.Get(&indexed, `SELECT COUNT(*) FROM tool_outputs_fts`)
	if countErr != nil {
		t.Fatalf("count: %v", countErr)
	}
	if indexed == 0 {
		return
	}
	t.Fatalf("expected tool_outputs_fts empty when tool_calls is NULL, got %d", indexed)
}
// TestForeignKeyRejectsOrphanTurn verifies that a turn cannot be inserted
// without its parent sessions row.
//
// A bare "some error occurred" check would also pass on an unrelated failure
// (for example a column mismatch), so the test adds two controls: the rejected
// statement must leave no row behind, and the very same statement must succeed
// once the parent session exists.
func TestForeignKeyRejectsOrphanTurn(t *testing.T) {
	db := openTestDB(t)

	const insertGhostTurn = `
INSERT INTO turns
(owner, tool, host, session_id, turn_id, seq, role, timestamp, content)
VALUES ('alice', 'cc', 'phoebe', 'ghost', 't1', 1, 'user', 1700000050, 'no parent')
`

	// No sessions row inserted yet: the insert must be rejected.
	if _, err := db.Exec(insertGhostTurn); err == nil {
		t.Fatalf("expected FK violation, got nil")
	}

	// The rejected statement must not have persisted anything.
	var n int
	if err := db.Get(&n, `SELECT COUNT(*) FROM turns`); err != nil {
		t.Fatalf("count after rejected insert: %v", err)
	}
	if n != 0 {
		t.Fatalf("expected no turns after rejected insert, got %d", n)
	}

	// Positive control: with the parent session present, the identical
	// statement is accepted, so the earlier failure was caused by the missing
	// parent rather than by the statement itself.
	insertSession(t, db, "alice", "cc", "phoebe", "ghost")
	if _, err := db.Exec(insertGhostTurn); err != nil {
		t.Fatalf("insert with parent session present: %v", err)
	}
}
// TestTwoOwnersSameSessionTriple checks that two different owners can each
// hold a session with the same (tool, host, session_id) triple.
func TestTwoOwnersSameSessionTriple(t *testing.T) {
	db := openTestDB(t)
	for _, owner := range []string{"alice", "bob"} {
		insertSession(t, db, owner, "cc", "phoebe", "s1")
	}

	const countQuery = `SELECT COUNT(*) FROM sessions WHERE tool = 'cc' AND host = 'phoebe' AND session_id = 's1'`

	var sessions int
	countErr := db.Get(&sessions, countQuery)
	if countErr != nil {
		t.Fatalf("count: %v", countErr)
	}
	if sessions == 2 {
		return
	}
	t.Fatalf("expected 2 sessions across owners, got %d", sessions)
}
// TestFTSQueryFiltersByOwner stores identical content for two owners and
// checks that an owner-filtered turns_fts query returns one hit per owner
// while the unfiltered query returns both.
func TestFTSQueryFiltersByOwner(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")
	insertSession(t, db, "bob", "cc", "phoebe", "s2")
	insertTurn(t, db, "alice", "cc", "phoebe", "s1", "t1", "the quick brown fox", nil)
	insertTurn(t, db, "bob", "cc", "phoebe", "s2", "t1", "the quick brown fox", nil)

	// count runs a single-value COUNT query; a query error aborts the test
	// with the given label as message prefix.
	count := func(label, query string) int {
		t.Helper()
		var c int
		if err := db.Get(&c, query); err != nil {
			t.Fatalf("%s: %v", label, err)
		}
		return c
	}

	aliceHits := count("alice", `SELECT COUNT(*) FROM turns_fts WHERE owner = 'alice' AND turns_fts MATCH 'quick'`)
	bobHits := count("bob", `SELECT COUNT(*) FROM turns_fts WHERE owner = 'bob' AND turns_fts MATCH 'quick'`)
	if !(aliceHits == 1 && bobHits == 1) {
		t.Fatalf("expected 1 hit per owner, got alice=%d bob=%d", aliceHits, bobHits)
	}

	// And cross-check overall row count is exactly 2.
	if all := count("total", `SELECT COUNT(*) FROM turns_fts WHERE turns_fts MATCH 'quick'`); all != 2 {
		t.Fatalf("expected 2 total fts hits, got %d", all)
	}
}
// TestUpsertFiresUpdateTriggerAndKeepsFTSCoherent inserts a turn, then
// overwrites it through INSERT ... ON CONFLICT DO UPDATE, and checks that both
// turns_fts and tool_outputs_fts contain only the second version.
//
// Phase 7's ingest path uses INSERT ... ON CONFLICT DO UPDATE, which fires the
// UPDATE trigger (not INSERT) when the conflict branch is taken. Pinning that
// contract here means a future SQLite/FTS5 regression trips a test instead of
// corrupting the index in production.
func TestUpsertFiresUpdateTriggerAndKeepsFTSCoherent(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")

	const firstInsert = `
INSERT INTO turns
(owner, tool, host, session_id, turn_id, seq, role, timestamp, content, tool_calls)
VALUES ('alice','cc','phoebe','s1','t1',1,'user',1700000050,'first version', '{"name":"first"}')
`
	// Same shape the ingest service will use: ON CONFLICT on the full
	// composite PK with DO UPDATE SET on the mutating columns.
	const upsert = `
INSERT INTO turns
(owner, tool, host, session_id, turn_id, seq, role, timestamp, content, tool_calls)
VALUES ('alice','cc','phoebe','s1','t1',1,'user',1700000060,'second version', '{"name":"second"}')
ON CONFLICT (owner, tool, host, session_id, turn_id) DO UPDATE SET
content = excluded.content,
tool_calls = excluded.tool_calls,
timestamp = excluded.timestamp
`

	if _, err := db.Exec(firstInsert); err != nil {
		t.Fatalf("first insert: %v", err)
	}
	if _, err := db.Exec(upsert); err != nil {
		t.Fatalf("upsert: %v", err)
	}

	// count runs a single-value COUNT query; a query error aborts the test
	// with the given label as message prefix.
	count := func(label, query string) int {
		t.Helper()
		var c int
		if err := db.Get(&c, query); err != nil {
			t.Fatalf("%s: %v", label, err)
		}
		return c
	}

	// turns_fts: only "second" should match; "first" should not.
	if got := count("query first", `SELECT COUNT(*) FROM turns_fts WHERE turns_fts MATCH 'first'`); got != 0 {
		t.Fatalf("expected old 'first' content gone after upsert, got %d hits", got)
	}
	if got := count("query second", `SELECT COUNT(*) FROM turns_fts WHERE turns_fts MATCH 'second'`); got != 1 {
		t.Fatalf("expected 1 hit for new 'second' content after upsert, got %d", got)
	}

	// tool_outputs_fts: same expectation on the JSON column.
	if got := count("query tc first", `SELECT COUNT(*) FROM tool_outputs_fts WHERE tool_outputs_fts MATCH 'first'`); got != 0 {
		t.Fatalf("expected old tool_calls gone after upsert, got %d", got)
	}
	if got := count("query tc second", `SELECT COUNT(*) FROM tool_outputs_fts WHERE tool_outputs_fts MATCH 'second'`); got != 1 {
		t.Fatalf("expected 1 hit for new tool_calls after upsert, got %d", got)
	}
}
// TestDatabaseInitDestroyOnMemoryDSN drives the Database service end-to-end
// through the steward Init/Destroy contract (DSN build + connect + migrate +
// close) against a ":memory:" path, and checks that Destroy can be called
// twice without error.
func TestDatabaseInitDestroyOnMemoryDSN(t *testing.T) {
	d := &Database{
		Cfg: config.DatabaseConfig{
			Path:        ":memory:",
			BusyTimeout: 5 * time.Second,
		},
	}
	ctx := context.Background()
	if err := d.Init(ctx); err != nil {
		t.Fatalf("Init: %v", err)
	}
	// Safety net: if any assertion below aborts the test before the explicit
	// Destroy calls, the opened database would otherwise never be released.
	// The error is ignored because teardown cannot act on it; on the happy
	// path this is one more repeated Destroy, which the test itself requires
	// to be a no-op.
	t.Cleanup(func() { _ = d.Destroy(ctx) })

	if d.DB == nil {
		t.Fatalf("expected DB populated after Init")
	}

	// Confirm migrations ran end-to-end through Init.
	var n int
	if err := d.DB.Get(&n, `SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = 'sessions'`); err != nil {
		t.Fatalf("query schema: %v", err)
	}
	if n != 1 {
		t.Fatalf("expected sessions table to exist, got %d", n)
	}

	if err := d.Destroy(ctx); err != nil {
		t.Fatalf("Destroy: %v", err)
	}
	// Idempotent on second Destroy.
	if err := d.Destroy(ctx); err != nil {
		t.Fatalf("second Destroy must be a no-op, got %v", err)
	}
}
// TestInTxCommitAndRollback exercises both outcomes of InTx: a callback that
// returns nil leaves its insert visible afterwards, and a callback that
// returns an error has that error surfaced by InTx and leaves no row behind.
func TestInTxCommitAndRollback(t *testing.T) {
	db := openTestDB(t)
	insertSession(t, db, "alice", "cc", "phoebe", "s1")
	ctx := context.Background()

	// Commit path: insert a turn inside InTx, expect it visible after.
	insertCommitted := func(tx *sqlx.Tx) error {
		_, execErr := tx.Exec(`
INSERT INTO turns
(owner, tool, host, session_id, turn_id, seq, role, timestamp, content)
VALUES ('alice', 'cc', 'phoebe', 's1', 't1', 1, 'user', 1700000050, 'committed')
`)
		return execErr
	}
	if commitErr := InTx(ctx, db, insertCommitted); commitErr != nil {
		t.Fatalf("InTx commit: %v", commitErr)
	}

	var rows int
	if countErr := db.Get(&rows, `SELECT COUNT(*) FROM turns WHERE turn_id = 't1'`); countErr != nil {
		t.Fatalf("count: %v", countErr)
	}
	if rows != 1 {
		t.Fatalf("commit path: expected 1 turn, got %d", rows)
	}

	// Rollback path: error inside fn rolls back; InTx returns the error.
	sentinel := errors.New("rollback me")
	insertThenFail := func(tx *sqlx.Tx) error {
		_, execErr := tx.Exec(`
INSERT INTO turns
(owner, tool, host, session_id, turn_id, seq, role, timestamp, content)
VALUES ('alice', 'cc', 'phoebe', 's1', 't2', 2, 'user', 1700000060, 'rolled back')
`)
		if execErr != nil {
			return execErr
		}
		// The insert itself worked; fail deliberately to force the rollback.
		return sentinel
	}
	rollbackErr := InTx(ctx, db, insertThenFail)
	if !errors.Is(rollbackErr, sentinel) {
		t.Fatalf("expected sentinel error from InTx, got %v", rollbackErr)
	}

	if countErr := db.Get(&rows, `SELECT COUNT(*) FROM turns WHERE turn_id = 't2'`); countErr != nil {
		t.Fatalf("count after rollback: %v", countErr)
	}
	if rows != 0 {
		t.Fatalf("rollback path: expected 0 rows for t2, got %d", rows)
	}
}