// Package mmdb handles downloading, caching, and auto-updating MaxMind MMDB files. package mmdb import ( "context" "fmt" "io" "log/slog" "net/http" "os" "path/filepath" "sync" "time" ) const ( DefaultCountryURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb" DefaultASNURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-ASN.mmdb" DefaultCacheDir = "/var/lib/shroud/mmdb" countryFile = "GeoLite2-Country.mmdb" asnFile = "GeoLite2-ASN.mmdb" updateInterval = 24 * time.Hour ) // Config holds MMDB source configuration. type Config struct { CountryURL string // URL or local path for country DB. Empty disables country lookups. ASNURL string // URL or local path for ASN DB. Empty disables ASN lookups. CacheDir string // Directory to cache downloaded files. AutoUpdate bool // Whether to auto-update from URLs daily. } // Manager handles MMDB file lifecycle: download, cache, and periodic updates. type Manager struct { cfg Config logger *slog.Logger cancel context.CancelFunc wg sync.WaitGroup } // Result holds resolved local paths to MMDB files. type Result struct { CountryPath string ASNPath string } // NewManager creates a new MMDB manager and ensures files are available. // It downloads missing files from URLs if needed. func NewManager(cfg Config, logger *slog.Logger) (*Manager, error) { if cfg.CacheDir == "" { cfg.CacheDir = DefaultCacheDir } if err := os.MkdirAll(cfg.CacheDir, 0o755); err != nil { return nil, fmt.Errorf("creating MMDB cache dir %s: %w", cfg.CacheDir, err) } m := &Manager{cfg: cfg, logger: logger} return m, nil } // Resolve ensures MMDB files are available locally and returns their paths. // For URLs, it downloads the file if not cached. For local paths, it returns them as-is. func (m *Manager) Resolve() (Result, error) { var res Result var err error if m.cfg.CountryURL != "" { res.CountryPath, err = m.resolve(m.cfg.CountryURL, countryFile) if err != nil { return res, fmt.Errorf("resolving country DB: %w", err) } } if m.cfg.ASNURL != "" { res.ASNPath, err = m.resolve(m.cfg.ASNURL, asnFile) if err != nil { return res, fmt.Errorf("resolving ASN DB: %w", err) } } return res, nil } // StartAutoUpdate launches a background goroutine that re-downloads MMDB files daily. // Call Stop() to cancel. func (m *Manager) StartAutoUpdate() { if !m.cfg.AutoUpdate || !m.hasURLSources() { return } ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel m.wg.Add(1) go func() { defer m.wg.Done() m.logger.Info("MMDB auto-update enabled.", "interval", updateInterval) ticker := time.NewTicker(updateInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: m.update() } } }() } // Stop cancels auto-update and waits for the goroutine to exit. func (m *Manager) Stop() { if m.cancel != nil { m.cancel() m.wg.Wait() } } // resolve returns a local path for the given source. // If source is a URL (starts with http:// or https://), it downloads to cacheDir/filename. // Otherwise it treats source as a local file path. func (m *Manager) resolve(source, filename string) (string, error) { if !isURL(source) { // Local file path — just verify it exists. if _, err := os.Stat(source); err != nil { return "", fmt.Errorf("local file %s: %w", source, err) } return source, nil } dest := filepath.Join(m.cfg.CacheDir, filename) if _, err := os.Stat(dest); err == nil { return dest, nil // Already cached. } if err := download(source, dest); err != nil { return "", err } m.logger.Info("Downloaded MMDB file.", "url", source, "path", dest) return dest, nil } func (m *Manager) update() { if m.cfg.CountryURL != "" && isURL(m.cfg.CountryURL) { dest := filepath.Join(m.cfg.CacheDir, countryFile) if err := download(m.cfg.CountryURL, dest); err != nil { m.logger.Error("Failed to update country DB.", "err", err) } else { m.logger.Info("Updated country DB.", "url", m.cfg.CountryURL) } } if m.cfg.ASNURL != "" && isURL(m.cfg.ASNURL) { dest := filepath.Join(m.cfg.CacheDir, asnFile) if err := download(m.cfg.ASNURL, dest); err != nil { m.logger.Error("Failed to update ASN DB.", "err", err) } else { m.logger.Info("Updated ASN DB.", "url", m.cfg.ASNURL) } } } func (m *Manager) hasURLSources() bool { return (m.cfg.CountryURL != "" && isURL(m.cfg.CountryURL)) || (m.cfg.ASNURL != "" && isURL(m.cfg.ASNURL)) } func isURL(s string) bool { return len(s) > 8 && (s[:7] == "http://" || s[:8] == "https://") } func download(url, dest string) error { resp, err := http.Get(url) if err != nil { return fmt.Errorf("downloading %s: %w", url, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("downloading %s: HTTP %d", url, resp.StatusCode) } // Write to temp file, then rename for atomicity. tmp := dest + ".tmp" f, err := os.Create(tmp) if err != nil { return fmt.Errorf("creating temp file %s: %w", tmp, err) } if _, err := io.Copy(f, resp.Body); err != nil { f.Close() os.Remove(tmp) return fmt.Errorf("writing %s: %w", tmp, err) } if err := f.Close(); err != nil { os.Remove(tmp) return err } if err := os.Rename(tmp, dest); err != nil { os.Remove(tmp) return fmt.Errorf("renaming %s to %s: %w", tmp, dest, err) } return nil }