184 lines
3.5 KiB
Go
184 lines
3.5 KiB
Go
package cache
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.ktf.ninja/tabledevil/gists/projects/go-tools/go/goipgrep/internal/ipinfo"
|
|
)
|
|
|
|
type entry struct {
|
|
Info ipinfo.Info `json:"info"`
|
|
FetchedAt time.Time `json:"fetched_at"`
|
|
}
|
|
|
|
type fileFormat struct {
|
|
Version int `json:"version"`
|
|
Entries map[string]entry `json:"entries"`
|
|
}
|
|
|
|
type Store struct {
|
|
path string
|
|
ttl time.Duration
|
|
max int
|
|
changed bool
|
|
data fileFormat
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// DefaultPath returns the default location of the ipinfo cache file.
//
// It prefers <user cache dir>/ipgrep/ipinfo.json (see os.UserCacheDir). If
// that directory cannot be determined it falls back to
// <home>/.cache/ipgrep/ipinfo.json, and if the home directory is unknown too,
// to ".ipgrep.ipinfo.json" relative to the working directory.
func DefaultPath() string {
	if cacheDir, err := os.UserCacheDir(); err == nil {
		return filepath.Join(cacheDir, "ipgrep", "ipinfo.json")
	}
	// The home-dir error is ignored on purpose: an empty result is handled
	// by the relative-path fallback below.
	if home, _ := os.UserHomeDir(); home != "" {
		return filepath.Join(home, ".cache", "ipgrep", "ipinfo.json")
	}
	return ".ipgrep.ipinfo.json"
}
|
|
|
|
// Clear deletes the cache file at path. The error from os.Remove is returned
// unchanged, so a cache that does not exist is reported as an error for which
// errors.Is(err, os.ErrNotExist) holds.
func Clear(path string) error {
	if err := os.Remove(path); err != nil {
		return err
	}
	return nil
}
|
|
|
|
func Load(path string, ttl time.Duration, maxEntries int) (*Store, error) {
|
|
s := &Store{
|
|
path: path,
|
|
ttl: ttl,
|
|
max: maxEntries,
|
|
data: fileFormat{Version: 1, Entries: make(map[string]entry)},
|
|
}
|
|
b, err := os.ReadFile(path)
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
return s, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Try v1 format.
|
|
var ff fileFormat
|
|
if err := json.Unmarshal(b, &ff); err == nil && ff.Version == 1 && ff.Entries != nil {
|
|
s.data = ff
|
|
return s, nil
|
|
}
|
|
|
|
// Try legacy format: map[ip]ipinfo.Info.
|
|
var legacy map[string]ipinfo.Info
|
|
if err := json.Unmarshal(b, &legacy); err == nil && legacy != nil {
|
|
now := time.Time{}
|
|
for ip, info := range legacy {
|
|
s.data.Entries[ip] = entry{Info: info, FetchedAt: now}
|
|
}
|
|
// Don't mark changed just because we loaded legacy.
|
|
return s, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("unrecognized cache format: %s", path)
|
|
}
|
|
|
|
func (s *Store) Get(ip string) (ipinfo.Info, bool) {
|
|
s.mu.RLock()
|
|
e, ok := s.data.Entries[ip]
|
|
s.mu.RUnlock()
|
|
if !ok {
|
|
return ipinfo.Info{}, false
|
|
}
|
|
if s.ttl > 0 && !e.FetchedAt.IsZero() && time.Since(e.FetchedAt) > s.ttl {
|
|
return ipinfo.Info{}, false
|
|
}
|
|
return e.Info, true
|
|
}
|
|
|
|
func (s *Store) Put(ip string, info ipinfo.Info) {
|
|
s.mu.Lock()
|
|
s.data.Entries[ip] = entry{Info: info, FetchedAt: time.Now().UTC()}
|
|
s.changed = true
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *Store) Changed() bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.changed
|
|
}
|
|
|
|
func (s *Store) Save() error {
|
|
if s.path == "" {
|
|
return errors.New("cache path is empty")
|
|
}
|
|
|
|
s.prune()
|
|
|
|
s.mu.RLock()
|
|
payload := s.data
|
|
s.mu.RUnlock()
|
|
|
|
dir := filepath.Dir(s.path)
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return err
|
|
}
|
|
|
|
tmp, err := os.CreateTemp(dir, "ipgrep-cache-*.json")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tmpName := tmp.Name()
|
|
enc := json.NewEncoder(tmp)
|
|
enc.SetIndent("", " ")
|
|
if err := enc.Encode(payload); err != nil {
|
|
_ = tmp.Close()
|
|
_ = os.Remove(tmpName)
|
|
return err
|
|
}
|
|
if err := tmp.Close(); err != nil {
|
|
_ = os.Remove(tmpName)
|
|
return err
|
|
}
|
|
if err := os.Rename(tmpName, s.path); err != nil {
|
|
_ = os.Remove(tmpName)
|
|
return err
|
|
}
|
|
s.mu.Lock()
|
|
s.changed = false
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) prune() {
|
|
if s.max <= 0 {
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if len(s.data.Entries) <= s.max {
|
|
return
|
|
}
|
|
|
|
type kv struct {
|
|
k string
|
|
t time.Time
|
|
}
|
|
all := make([]kv, 0, len(s.data.Entries))
|
|
for k, v := range s.data.Entries {
|
|
all = append(all, kv{k: k, t: v.FetchedAt})
|
|
}
|
|
sort.Slice(all, func(i, j int) bool {
|
|
// zero timestamps (legacy) sort oldest.
|
|
return all[i].t.Before(all[j].t)
|
|
})
|
|
for len(s.data.Entries) > s.max {
|
|
k := all[0].k
|
|
all = all[1:]
|
|
delete(s.data.Entries, k)
|
|
s.changed = true
|
|
}
|
|
}
|