// Package mmdb handles downloading, caching, and auto-updating MaxMind MMDB files.
package mmdb
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"sync"
"time"
)
// Default MMDB download sources (presumably used by callers to populate
// Config) and the cache directory NewManager falls back to when
// Config.CacheDir is empty.
const (
	DefaultCountryURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb"
	DefaultASNURL     = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-ASN.mmdb"
	DefaultCacheDir   = "/var/lib/shroud/mmdb"
)

// File names used inside the cache directory for URL-sourced databases, and
// the period between background refreshes.
const (
	countryFile    = "GeoLite2-Country.mmdb"
	asnFile        = "GeoLite2-ASN.mmdb"
	updateInterval = 24 * time.Hour
)
// Config holds MMDB source configuration.
//
// Each source may be an http:// or https:// URL, which is downloaded into
// CacheDir, or a local file path, which is used in place after an existence
// check.
type Config struct {
	CountryURL string // URL or local path for country DB. Empty disables country lookups.
	ASNURL     string // URL or local path for ASN DB. Empty disables ASN lookups.
	CacheDir   string // Directory to cache downloaded files. NewManager substitutes DefaultCacheDir when empty.
	AutoUpdate bool   // Whether to auto-update from URLs daily. Local-path sources are never refreshed.
}
// Manager handles MMDB file lifecycle: download, cache, and periodic updates.
//
// Construct it with NewManager. It embeds a sync.WaitGroup, so it must not be
// copied after use; always pass *Manager.
// NOTE(review): cancel is not guarded by a mutex, so StartAutoUpdate and Stop
// are presumably meant to be called from a single goroutine — confirm.
type Manager struct {
	cfg    Config
	logger *slog.Logger
	cancel context.CancelFunc // cancels the auto-update loop; nil until StartAutoUpdate launches one
	wg     sync.WaitGroup     // tracks the auto-update goroutine so Stop can wait for it to exit
}
// Result holds resolved local paths to MMDB files.
//
// A field is left as "" when the corresponding Config source is empty, since
// Resolve skips that source entirely.
type Result struct {
	CountryPath string // local path of the country DB, or "" if not configured
	ASNPath     string // local path of the ASN DB, or "" if not configured
}
// NewManager creates an MMDB manager and makes sure its cache directory
// exists (mode 0755).
//
// It does not fetch any databases itself; call Resolve to download missing
// files and StartAutoUpdate to enable periodic refreshes.
//
// An empty cfg.CacheDir defaults to DefaultCacheDir. A nil logger falls back
// to slog.Default() so later logging in Resolve and the updater cannot panic.
// The returned error wraps the os.MkdirAll failure, if any.
func NewManager(cfg Config, logger *slog.Logger) (*Manager, error) {
	if cfg.CacheDir == "" {
		cfg.CacheDir = DefaultCacheDir
	}
	if logger == nil {
		logger = slog.Default()
	}
	if err := os.MkdirAll(cfg.CacheDir, 0o755); err != nil {
		return nil, fmt.Errorf("creating MMDB cache dir %s: %w", cfg.CacheDir, err)
	}
	return &Manager{cfg: cfg, logger: logger}, nil
}
// Resolve makes the configured MMDB databases available on local disk and
// reports where they live.
//
// A source left empty in the Config is skipped, and its Result field stays "".
// URL sources are downloaded into the cache directory unless already cached.
// Local paths are returned unchanged after an existence check.
//
// If the ASN source fails, the returned Result still carries the already
// resolved country path alongside the error.
func (m *Manager) Resolve() (Result, error) {
	var out Result

	if src := m.cfg.CountryURL; src != "" {
		p, err := m.resolve(src, countryFile)
		if err != nil {
			return out, fmt.Errorf("resolving country DB: %w", err)
		}
		out.CountryPath = p
	}

	if src := m.cfg.ASNURL; src != "" {
		p, err := m.resolve(src, asnFile)
		if err != nil {
			return out, fmt.Errorf("resolving ASN DB: %w", err)
		}
		out.ASNPath = p
	}

	return out, nil
}
// StartAutoUpdate launches a background goroutine that re-downloads the
// URL-sourced MMDB files every updateInterval. The first refresh happens one
// interval after the call, not immediately.
//
// It is a no-op when AutoUpdate is disabled or no source is an HTTP(S) URL.
// Call Stop to cancel.
//
// Calling it again while a loop is running stops the previous loop first, so
// at most one updater goroutine exists and Stop can always reclaim it.
// NOTE(review): not safe to call concurrently with itself or Stop.
func (m *Manager) StartAutoUpdate() {
	if !m.cfg.AutoUpdate || !m.hasURLSources() {
		return
	}

	// Without this, a second call would overwrite m.cancel and orphan the first
	// goroutine, leaving Stop blocked forever in wg.Wait.
	if m.cancel != nil {
		m.cancel()
		m.wg.Wait()
	}

	ctx, cancel := context.WithCancel(context.Background())
	m.cancel = cancel
	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		m.logger.Info("MMDB auto-update enabled.", "interval", updateInterval)
		ticker := time.NewTicker(updateInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				m.update()
			}
		}
	}()
}
// Stop cancels the auto-update loop, if StartAutoUpdate launched one, and
// blocks until that goroutine has returned.
//
// It does nothing when no loop was ever started. Repeated calls are harmless
// because cancelling a context twice is a no-op.
func (m *Manager) Stop() {
	cancel := m.cancel
	if cancel == nil {
		return
	}
	cancel()
	m.wg.Wait()
}
// resolve maps one configured source to a readable local path.
//
// When isURL reports source as an http:// or https:// URL, the file is cached
// as filename inside the cache directory. An existing cached copy is reused
// as-is; otherwise it is downloaded and the download is logged.
//
// Any other source is treated as a local path. It is returned unchanged once
// os.Stat confirms it exists.
func (m *Manager) resolve(source, filename string) (string, error) {
	if isURL(source) {
		cached := filepath.Join(m.cfg.CacheDir, filename)
		if _, statErr := os.Stat(cached); statErr == nil {
			return cached, nil
		}
		if err := download(source, cached); err != nil {
			return "", err
		}
		m.logger.Info("Downloaded MMDB file.", "url", source, "path", cached)
		return cached, nil
	}

	if _, err := os.Stat(source); err != nil {
		return "", fmt.Errorf("local file %s: %w", source, err)
	}
	return source, nil
}
// update re-downloads every URL-backed source (country first, then ASN) into
// its cache file, replacing the cached copy.
//
// Local-path and empty sources are skipped. A failure is logged for that
// database only and does not prevent the next one from being attempted.
func (m *Manager) update() {
	targets := []struct {
		url, file, failMsg, okMsg string
	}{
		{m.cfg.CountryURL, countryFile, "Failed to update country DB.", "Updated country DB."},
		{m.cfg.ASNURL, asnFile, "Failed to update ASN DB.", "Updated ASN DB."},
	}

	for _, tgt := range targets {
		if !isURL(tgt.url) {
			continue
		}
		if err := download(tgt.url, filepath.Join(m.cfg.CacheDir, tgt.file)); err != nil {
			m.logger.Error(tgt.failMsg, "err", err)
			continue
		}
		m.logger.Info(tgt.okMsg, "url", tgt.url)
	}
}
// hasURLSources reports whether at least one configured source is an HTTP(S)
// URL, i.e. whether the auto-updater would have anything to refresh.
func (m *Manager) hasURLSources() bool {
	for _, src := range [...]string{m.cfg.CountryURL, m.cfg.ASNURL} {
		if isURL(src) {
			return true
		}
	}
	return false
}
// isURL reports whether s should be fetched over HTTP(S) rather than read as
// a local file path.
//
// Only the lowercase "http://" and "https://" prefixes are recognized, and at
// least one character must follow the scheme. Each prefix is checked against
// its own length; the previous shared `len(s) > 8` guard wrongly rejected
// short http URLs such as "http://a".
func isURL(s string) bool {
	const httpPrefix, httpsPrefix = "http://", "https://"
	return (len(s) > len(httpPrefix) && s[:len(httpPrefix)] == httpPrefix) ||
		(len(s) > len(httpsPrefix) && s[:len(httpsPrefix)] == httpsPrefix)
}
// downloadTimeout caps one MMDB download end to end (connect, headers, and
// body). Without it, a stalled mirror could hang Resolve, or keep Stop waiting
// on the auto-update goroutine forever.
const downloadTimeout = 10 * time.Minute

// downloadClient is reused across downloads so keep-alive connections can be
// shared. Unlike http.DefaultClient, it enforces a timeout.
var downloadClient = &http.Client{Timeout: downloadTimeout}

// download fetches url and atomically replaces dest with the response body.
//
// The body is streamed into a uniquely named temporary file next to dest,
// synced to disk, then renamed over dest. Readers therefore never see a
// partially written database, and concurrent downloads of the same dest do
// not collide on a shared temp name.
//
// On any failure (transport error, non-200 status, empty body, or filesystem
// error) the temporary file is removed and an existing dest is left untouched.
func download(url, dest string) error {
	resp, err := downloadClient.Get(url)
	if err != nil {
		return fmt.Errorf("downloading %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("downloading %s: HTTP %d", url, resp.StatusCode)
	}

	f, err := os.CreateTemp(filepath.Dir(dest), filepath.Base(dest)+".*.tmp")
	if err != nil {
		return fmt.Errorf("creating temp file for %s: %w", dest, err)
	}
	tmp := f.Name()
	// discard abandons the partially written temp file. Its errors are ignored
	// because a more relevant error is already being returned.
	discard := func() {
		_ = f.Close()
		_ = os.Remove(tmp)
	}

	n, err := io.Copy(f, resp.Body)
	if err != nil {
		discard()
		return fmt.Errorf("writing %s: %w", tmp, err)
	}
	// A 200 with no body would otherwise replace a good cached DB with an
	// empty, unreadable file.
	if n == 0 {
		discard()
		return fmt.Errorf("downloading %s: empty response body", url)
	}
	// os.CreateTemp creates files with mode 0600. Use 0644 so the cached DB is
	// as readable as a file made by os.Create under the common 022 umask.
	if err := f.Chmod(0o644); err != nil {
		discard()
		return fmt.Errorf("setting permissions on %s: %w", tmp, err)
	}
	// Flush to stable storage before the rename so a crash cannot leave dest
	// naming a truncated file.
	if err := f.Sync(); err != nil {
		discard()
		return fmt.Errorf("syncing %s: %w", tmp, err)
	}
	if err := f.Close(); err != nil {
		_ = os.Remove(tmp)
		return fmt.Errorf("closing %s: %w", tmp, err)
	}
	if err := os.Rename(tmp, dest); err != nil {
		_ = os.Remove(tmp)
		return fmt.Errorf("renaming %s to %s: %w", tmp, dest, err)
	}
	return nil
}