~bigbes/shroud

ref: 321879085cefec9f799abe941b2b76b17ad4a1db shroud/internal/mmdb/mmdb.go -rw-r--r-- 5.3 KiB
32187908 — Eugene Blikh refactor: rename Go module to go.bigb.es/shroud a month ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
// Package mmdb handles downloading, caching, and auto-updating MaxMind MMDB files.
package mmdb

import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"time"
)

// Default MMDB sources and cache location.
const (
	// DefaultCountryURL is a download URL for the GeoLite2 Country database,
	// served from the P3TERX/GeoLite.mmdb mirror on GitHub.
	DefaultCountryURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb"

	// DefaultASNURL is a download URL for the GeoLite2 ASN database, served
	// from the same mirror as DefaultCountryURL.
	DefaultASNURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-ASN.mmdb"

	// DefaultCacheDir is used by NewManager when Config.CacheDir is empty.
	DefaultCacheDir = "/var/lib/shroud/mmdb"
)

// File names inside the cache directory and the background refresh period.
const (
	countryFile = "GeoLite2-Country.mmdb"
	asnFile     = "GeoLite2-ASN.mmdb"

	updateInterval = 24 * time.Hour
)

// Config describes where the MMDB files come from and how they are cached.
// Each source is either an http(s) URL, downloaded into CacheDir, or a path
// to a file that must already exist on local disk.
type Config struct {
	// CountryURL is the URL or local path of the country database.
	// Empty disables country lookups.
	CountryURL string

	// ASNURL is the URL or local path of the ASN database.
	// Empty disables ASN lookups.
	ASNURL string

	// CacheDir is the directory downloaded files are stored in. NewManager
	// substitutes DefaultCacheDir when it is empty.
	CacheDir string

	// AutoUpdate enables the daily background re-download of URL sources.
	AutoUpdate bool
}

// Manager owns the lifecycle of the MMDB files: it resolves them to local
// paths (downloading into the cache directory when needed) and can refresh
// URL-sourced files periodically in the background.
type Manager struct {
	// cfg is the configuration passed to NewManager, with CacheDir defaulted.
	cfg Config

	// logger receives progress and error messages.
	logger *slog.Logger

	// cancel stops the auto-update goroutine; nil unless StartAutoUpdate has
	// launched it.
	cancel context.CancelFunc

	// wg tracks the auto-update goroutine so Stop can wait for it to exit.
	wg sync.WaitGroup
}

// Result carries the local filesystem paths produced by Resolve. A path is
// left empty when the corresponding source is not configured.
type Result struct {
	// CountryPath is the local path of the country database, or "".
	CountryPath string

	// ASNPath is the local path of the ASN database, or "".
	ASNPath string
}

// NewManager returns a Manager for cfg. It only prepares the cache directory
// and performs no downloads; call Resolve to fetch or locate the MMDB files.
//
// An empty cfg.CacheDir defaults to DefaultCacheDir. The directory (and any
// missing parents) is created with mode 0755 only when at least one source is
// an http(s) URL. Local-path sources never touch the cache, so a local-only
// setup does not need write access to DefaultCacheDir. A nil logger is
// replaced with slog.Default() so later log calls cannot panic.
func NewManager(cfg Config, logger *slog.Logger) (*Manager, error) {
	if cfg.CacheDir == "" {
		cfg.CacheDir = DefaultCacheDir
	}
	if logger == nil {
		logger = slog.Default()
	}
	if isURL(cfg.CountryURL) || isURL(cfg.ASNURL) {
		if err := os.MkdirAll(cfg.CacheDir, 0o755); err != nil {
			return nil, fmt.Errorf("creating MMDB cache dir %s: %w", cfg.CacheDir, err)
		}
	}
	return &Manager{cfg: cfg, logger: logger}, nil
}

// Resolve makes the configured MMDB files available on local disk and returns
// their paths.
//
// URL sources are downloaded into the cache directory unless a cached copy
// already exists. Local paths are checked for existence and returned
// unchanged. A source left empty in Config is skipped, and its Result path
// stays "".
//
// The country database is resolved before the ASN database. The returned
// error names the database that failed, and any path resolved before the
// failure is still set in the returned Result.
func (m *Manager) Resolve() (Result, error) {
	var out Result

	if src := m.cfg.CountryURL; src != "" {
		path, err := m.resolve(src, countryFile)
		out.CountryPath = path
		if err != nil {
			return out, fmt.Errorf("resolving country DB: %w", err)
		}
	}

	if src := m.cfg.ASNURL; src != "" {
		path, err := m.resolve(src, asnFile)
		out.ASNPath = path
		if err != nil {
			return out, fmt.Errorf("resolving ASN DB: %w", err)
		}
	}

	return out, nil
}

// StartAutoUpdate launches a background goroutine that re-downloads the
// URL-sourced MMDB files every updateInterval. It does nothing when
// AutoUpdate is disabled or no source is an http(s) URL. Call Stop() to
// cancel.
//
// If a cached file is missing or already older than updateInterval when the
// loop starts, it is refreshed immediately. Otherwise a process restarted
// more often than once per interval would reset the ticker each time and
// never update.
//
// Calling StartAutoUpdate again first stops the previous loop. Without that,
// the previous goroutine would lose its cancel func and Stop would wait on
// it forever.
func (m *Manager) StartAutoUpdate() {
	if !m.cfg.AutoUpdate || !m.hasURLSources() {
		return
	}
	m.Stop() // No-op if never started; otherwise retire the previous loop.

	ctx, cancel := context.WithCancel(context.Background())
	m.cancel = cancel
	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		m.logger.Info("MMDB auto-update enabled.", "interval", updateInterval)
		if ctx.Err() == nil && m.cacheStale() {
			m.update()
		}
		ticker := time.NewTicker(updateInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				m.update()
			}
		}
	}()
}

// cacheStale reports whether any URL-sourced MMDB file is missing from the
// cache directory or was last modified at least updateInterval ago.
func (m *Manager) cacheStale() bool {
	sources := [...]struct{ source, file string }{
		{m.cfg.CountryURL, countryFile},
		{m.cfg.ASNURL, asnFile},
	}
	for _, s := range sources {
		if !isURL(s.source) {
			continue
		}
		info, err := os.Stat(filepath.Join(m.cfg.CacheDir, s.file))
		if err != nil || time.Since(info.ModTime()) >= updateInterval {
			return true
		}
	}
	return false
}

// Stop cancels the auto-update loop started by StartAutoUpdate and blocks
// until its goroutine has returned. It does nothing if the loop was never
// launched.
func (m *Manager) Stop() {
	cancel := m.cancel
	if cancel == nil {
		return
	}
	cancel()
	m.wg.Wait()
}

// resolve maps one source to a local file path.
//
// For an http(s) source the file lives at CacheDir/filename. An existing
// cached copy is returned as-is, whatever its age. Otherwise the file is
// downloaded first. Any other source is taken as a local path, which must
// already exist and is returned unchanged.
func (m *Manager) resolve(source, filename string) (string, error) {
	if isURL(source) {
		cached := filepath.Join(m.cfg.CacheDir, filename)
		if _, statErr := os.Stat(cached); statErr == nil {
			return cached, nil
		}
		if err := download(source, cached); err != nil {
			return "", err
		}
		m.logger.Info("Downloaded MMDB file.", "url", source, "path", cached)
		return cached, nil
	}

	// Not a URL: the caller pointed at a file on disk; only check it exists.
	if _, err := os.Stat(source); err != nil {
		return "", fmt.Errorf("local file %s: %w", source, err)
	}
	return source, nil
}

// update re-downloads every URL-sourced database over its cached copy.
// The databases are handled independently: a failure is logged and the next
// one is still attempted. Empty and local-path sources are skipped.
func (m *Manager) update() {
	jobs := [...]struct {
		source, file, failMsg, okMsg string
	}{
		{m.cfg.CountryURL, countryFile, "Failed to update country DB.", "Updated country DB."},
		{m.cfg.ASNURL, asnFile, "Failed to update ASN DB.", "Updated ASN DB."},
	}
	for _, job := range jobs {
		if job.source == "" || !isURL(job.source) {
			continue
		}
		err := download(job.source, filepath.Join(m.cfg.CacheDir, job.file))
		if err != nil {
			m.logger.Error(job.failMsg, "err", err)
			continue
		}
		m.logger.Info(job.okMsg, "url", job.source)
	}
}

// hasURLSources reports whether at least one configured source is an http(s)
// URL, i.e. whether there is anything for auto-update to download.
func (m *Manager) hasURLSources() bool {
	for _, src := range [...]string{m.cfg.CountryURL, m.cfg.ASNURL} {
		if src != "" && isURL(src) {
			return true
		}
	}
	return false
}

// isURL reports whether s starts with "http://" or "https://" followed by at
// least one more character. Callers treat every other string as a local file
// path. The scheme comparison is case-sensitive.
func isURL(s string) bool {
	// Check the length against each scheme separately. The old single
	// `len(s) > 8` test wrongly rejected "http://" plus one character.
	for _, scheme := range [...]string{"http://", "https://"} {
		if len(s) > len(scheme) && s[:len(scheme)] == scheme {
			return true
		}
	}
	return false
}

// downloadTimeout caps a single download end to end (connect, redirects and
// reading the body), so a stalled mirror cannot hang Resolve or keep Stop
// waiting on the auto-update goroutine forever.
const downloadTimeout = 5 * time.Minute

// httpClient is shared by all downloads so connections can be reused.
// Unlike http.DefaultClient, which has no timeout, it enforces
// downloadTimeout.
var httpClient = &http.Client{Timeout: downloadTimeout}

// download fetches url and atomically replaces dest with the response body.
//
// The body is streamed into dest+".tmp", synced to disk, and renamed over
// dest, so readers of dest see either the old file or the complete new one.
// Any of the following returns an error, removes the temporary file and
// leaves dest untouched:
//   - a non-200 status,
//   - an empty body,
//   - a timeout,
//   - an I/O failure.
func download(url, dest string) error {
	resp, err := httpClient.Get(url)
	if err != nil {
		return fmt.Errorf("downloading %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("downloading %s: HTTP %d", url, resp.StatusCode)
	}

	// Write to a temp file in the same directory, then rename: rename is only
	// atomic within a single filesystem.
	tmp := dest + ".tmp"
	f, err := os.Create(tmp)
	if err != nil {
		return fmt.Errorf("creating temp file %s: %w", tmp, err)
	}
	n, err := io.Copy(f, resp.Body)
	if err == nil && n == 0 {
		// A 200 with no body would otherwise replace a good database with an
		// empty file.
		err = fmt.Errorf("empty response body from %s", url)
	}
	if err == nil {
		// Flush to stable storage before the rename makes the file visible.
		err = f.Sync()
	}
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(tmp)
		return fmt.Errorf("writing %s: %w", tmp, err)
	}
	if err := os.Rename(tmp, dest); err != nil {
		os.Remove(tmp)
		return fmt.Errorf("renaming %s to %s: %w", tmp, dest, err)
	}
	return nil
}