From 6b25c7e500e14075cbd3c15c40e7f393ba3c3682 Mon Sep 17 00:00:00 2001 From: Allo Date: Sun, 21 Apr 2024 21:12:00 +0800 Subject: [PATCH] refactor: temporarily cache geo data --- infra/conf/dns.go | 14 ++- infra/conf/router.go | 181 ++++++++++++++++++++++++++------------ infra/conf/router_test.go | 5 +- 3 files changed, 138 insertions(+), 62 deletions(-) diff --git a/infra/conf/dns.go b/infra/conf/dns.go index a0f3155c..cee414c8 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -73,8 +73,11 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) { var domains []*dns.NameServer_PriorityDomain var originalRules []*dns.NameServer_OriginalRule + cache := NewGeoCache() + defer cache.Clear() + for _, rule := range c.Domains { - parsedDomain, err := parseDomainRule(rule) + parsedDomain, err := cache.ParseDomainRule(rule) if err != nil { return nil, newError("invalid domain rule: ", rule).Base(err) } @@ -91,7 +94,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) { }) } - geoipList, err := ToCidrList(c.ExpectIPs) + geoipList, err := cache.ToCidrList(c.ExpectIPs) if err != nil { return nil, newError("invalid IP rule: ", c.ExpectIPs).Base(err) } @@ -209,6 +212,9 @@ func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) { } sort.Strings(domains) + cache := newSiteCache() + defer cache.reset() + for _, domain := range domains { switch { case strings.HasPrefix(domain, "domain:"): @@ -226,7 +232,7 @@ func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) { if len(listName) == 0 { return nil, newError("empty geosite rule: ", domain) } - geositeList, err := loadGeositeWithAttr("geosite.dat", listName) + geositeList, err := cache.loadGeositeWithAttr("geosite.dat", listName) if err != nil { return nil, newError("failed to load geosite: ", listName).Base(err) } @@ -287,7 +293,7 @@ func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) { } filename := kv[0] list := kv[1] - geositeList, err := loadGeositeWithAttr(filename, list) + geositeList, err := cache.loadGeositeWithAttr(filename, list) if err != nil { return nil, newError("failed to load domain list: ", list, " from ", filename).Base(err) } diff --git a/infra/conf/router.go b/infra/conf/router.go index 6ae39d45..6d051105 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -116,8 +116,11 @@ func (c *RouterConfig) Build() (*router.Config, error) { } } + cache := NewGeoCache() + defer cache.Clear() + for _, rawRule := range rawRuleList { - rule, err := ParseRule(rawRule) + rule, err := cache.ParseRule(rawRule) if err != nil { return nil, err } @@ -195,39 +198,80 @@ func ParseIP(s string) (*router.CIDR, error) { } } -func loadGeoIP(code string) ([]*router.CIDR, error) { - return loadIP("geoip.dat", code) -} - -var ( - FileCache = make(map[string][]byte) - IPCache = make(map[string]*router.GeoIP) - SiteCache = make(map[string]*router.GeoSite) -) - func loadFile(file string) ([]byte, error) { - if FileCache[file] == nil { - bs, err := filesystem.ReadAsset(file) - if err != nil { - return nil, newError("failed to open file: ", file).Base(err) - } - if len(bs) == 0 { - return nil, newError("empty file: ", file) - } - // Do not cache file, may save RAM when there - // are many files, but consume CPU each time. - return bs, nil - FileCache[file] = bs + bs, err := filesystem.ReadAsset(file) + if err != nil { + return nil, newError("failed to open file: ", file).Base(err) } - return FileCache[file], nil + if len(bs) == 0 { + return nil, newError("empty file: ", file) + } + return bs, nil } -func loadIP(file, code string) ([]*router.CIDR, error) { +type ipCache struct { + fileContent map[string][]byte + geoIPs map[string]*router.GeoIP +} + +type siteCache struct { + fileContent map[string][]byte + geoSites map[string]*router.GeoSite +} + +// GeoCache is a cache for GeoIP and GeoSite. +type GeoCache struct { + ipCache *ipCache + siteCache *siteCache +} + +// NewGeoCache creates a new GeoCache. +func NewGeoCache() *GeoCache { + return &GeoCache{ + ipCache: newIPCache(), + siteCache: newSiteCache(), + } +} + +// Clear clears the cache. +func (c *GeoCache) Clear() { + if c.ipCache.reset() || c.siteCache.reset() { + // only trigger GC if there is something to release + runtime.GC() + } +} + +// newIPCache creates a new ipCache. +// after use, call close() to release memory. +func newIPCache() *ipCache { + return &ipCache{ + fileContent: make(map[string][]byte), + geoIPs: make(map[string]*router.GeoIP), + } +} + +// reset remove the content of ipCache. +// returns true if the cache is not empty. +func (c *ipCache) reset() bool { + if len(c.fileContent) == 0 && len(c.geoIPs) == 0 { + return false + } + c.fileContent = make(map[string][]byte) + c.geoIPs = make(map[string]*router.GeoIP) + return true +} + +func (c *ipCache) loadIP(file, code string) ([]*router.CIDR, error) { index := file + ":" + code - if IPCache[index] == nil { - bs, err := loadFile(file) - if err != nil { - return nil, newError("failed to load file: ", file).Base(err) + if c.geoIPs[index] == nil { + bs := c.fileContent[file] + if bs == nil { + var err error + bs, err = loadFile(file) + if err != nil { + return nil, newError("failed to load file: ", file).Base(err) + } + c.fileContent[file] = bs } bs = find(bs, []byte(code)) if bs == nil { @@ -237,19 +281,43 @@ func loadIP(file, code string) ([]*router.CIDR, error) { if err := proto.Unmarshal(bs, &geoip); err != nil { return nil, newError("error unmarshal IP in ", file, ": ", code).Base(err) } - defer runtime.GC() // or debug.FreeOSMemory() - return geoip.Cidr, nil // do not cache geoip - IPCache[index] = &geoip + c.geoIPs[index] = &geoip + return geoip.Cidr, nil } - return IPCache[index].Cidr, nil + return c.geoIPs[index].Cidr, nil } -func loadSite(file, code string) ([]*router.Domain, error) { +// newSiteCache creates a new siteCache. +// after use, call close() to release memory. +func newSiteCache() *siteCache { + return &siteCache{ + fileContent: make(map[string][]byte), + geoSites: make(map[string]*router.GeoSite), + } +} + +// reset remove the content of siteCache. +// returns true if the cache is not empty. +func (c *siteCache) reset() bool { + if len(c.fileContent) == 0 && len(c.geoSites) == 0 { + return false + } + c.fileContent = make(map[string][]byte) + c.geoSites = make(map[string]*router.GeoSite) + return true +} + +func (c *siteCache) loadSite(file, code string) ([]*router.Domain, error) { index := file + ":" + code - if SiteCache[index] == nil { - bs, err := loadFile(file) - if err != nil { - return nil, newError("failed to load file: ", file).Base(err) + if c.geoSites[index] == nil { + bs := c.fileContent[file] + if bs == nil { + var err error + bs, err = loadFile(file) + if err != nil { + return nil, newError("failed to load file: ", file).Base(err) + } + c.fileContent[file] = bs } bs = find(bs, []byte(code)) if bs == nil { @@ -259,11 +327,10 @@ func loadSite(file, code string) ([]*router.Domain, error) { if err := proto.Unmarshal(bs, &geosite); err != nil { return nil, newError("error unmarshal Site in ", file, ": ", code).Base(err) } - defer runtime.GC() // or debug.FreeOSMemory() - return geosite.Domain, nil // do not cache geosite - SiteCache[index] = &geosite + c.geoSites[index] = &geosite + return geosite.Domain, nil } - return SiteCache[index].Domain, nil + return c.geoSites[index].Domain, nil } func DecodeVarint(buf []byte) (x uint64, n int) { @@ -358,14 +425,14 @@ func parseAttrs(attrs []string) *AttributeList { return al } -func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, error) { +func (c *siteCache) loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, error) { parts := strings.Split(siteWithAttr, "@") if len(parts) == 0 { return nil, newError("empty site") } country := strings.ToUpper(parts[0]) attrs := parseAttrs(parts[1:]) - domains, err := loadSite(file, country) + domains, err := c.loadSite(file, country) if err != nil { return nil, err } @@ -384,10 +451,10 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er return filteredDomains, nil } -func parseDomainRule(domain string) ([]*router.Domain, error) { +func (c *GeoCache) ParseDomainRule(domain string) ([]*router.Domain, error) { if strings.HasPrefix(domain, "geosite:") { country := strings.ToUpper(domain[8:]) - domains, err := loadGeositeWithAttr("geosite.dat", country) + domains, err := c.siteCache.loadGeositeWithAttr("geosite.dat", country) if err != nil { return nil, newError("failed to load geosite: ", country).Base(err) } @@ -411,7 +478,7 @@ func parseDomainRule(domain string) ([]*router.Domain, error) { } filename := kv[0] country := kv[1] - domains, err := loadGeositeWithAttr(filename, country) + domains, err := c.siteCache.loadGeositeWithAttr(filename, country) if err != nil { return nil, newError("failed to load external sites: ", country, " from ", filename).Base(err) } @@ -454,7 +521,7 @@ func parseDomainRule(domain string) ([]*router.Domain, error) { return []*router.Domain{domainRule}, nil } -func ToCidrList(ips StringList) ([]*router.GeoIP, error) { +func (c *GeoCache) ToCidrList(ips StringList) ([]*router.GeoIP, error) { var geoipList []*router.GeoIP var customCidrs []*router.CIDR @@ -469,7 +536,7 @@ func ToCidrList(ips StringList) ([]*router.GeoIP, error) { if len(country) == 0 { return nil, newError("empty country name in rule") } - geoip, err := loadGeoIP(strings.ToUpper(country)) + geoip, err := c.ipCache.loadIP("geoip.dat", strings.ToUpper(country)) if err != nil { return nil, newError("failed to load GeoIP: ", country).Base(err) } @@ -509,7 +576,7 @@ func ToCidrList(ips StringList) ([]*router.GeoIP, error) { country = country[1:] isReverseMatch = true } - geoip, err := loadIP(filename, strings.ToUpper(country)) + geoip, err := c.ipCache.loadIP(filename, strings.ToUpper(country)) if err != nil { return nil, newError("failed to load IPs: ", country, " from ", filename).Base(err) } @@ -539,7 +606,7 @@ func ToCidrList(ips StringList) ([]*router.GeoIP, error) { return geoipList, nil } -func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { +func (c *GeoCache) parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { type RawFieldRule struct { RouterRule Domain *StringList `json:"domain"` @@ -581,7 +648,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { if rawFieldRule.Domain != nil { for _, domain := range *rawFieldRule.Domain { - rules, err := parseDomainRule(domain) + rules, err := c.ParseDomainRule(domain) if err != nil { return nil, newError("failed to parse domain rule: ", domain).Base(err) } @@ -591,7 +658,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { if rawFieldRule.Domains != nil { for _, domain := range *rawFieldRule.Domains { - rules, err := parseDomainRule(domain) + rules, err := c.ParseDomainRule(domain) if err != nil { return nil, newError("failed to parse domain rule: ", domain).Base(err) } @@ -600,7 +667,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { } if rawFieldRule.IP != nil { - geoipList, err := ToCidrList(*rawFieldRule.IP) + geoipList, err := c.ToCidrList(*rawFieldRule.IP) if err != nil { return nil, err } @@ -616,7 +683,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { } if rawFieldRule.SourceIP != nil { - geoipList, err := ToCidrList(*rawFieldRule.SourceIP) + geoipList, err := c.ToCidrList(*rawFieldRule.SourceIP) if err != nil { return nil, err } @@ -652,14 +719,14 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { return rule, nil } -func ParseRule(msg json.RawMessage) (*router.RoutingRule, error) { +func (c *GeoCache) ParseRule(msg json.RawMessage) (*router.RoutingRule, error) { rawRule := new(RouterRule) err := json.Unmarshal(msg, rawRule) if err != nil { return nil, newError("invalid router rule").Base(err) } if rawRule.Type == "" || strings.EqualFold(rawRule.Type, "field") { - fieldrule, err := parseFieldRule(msg) + fieldrule, err := c.parseFieldRule(msg) if err != nil { return nil, newError("invalid field rule").Base(err) } diff --git a/infra/conf/router_test.go b/infra/conf/router_test.go index 0af1b3e3..27aec0ae 100644 --- a/infra/conf/router_test.go +++ b/infra/conf/router_test.go @@ -44,7 +44,10 @@ func TestToCidrList(t *testing.T) { "ext-ip:geoiptestrouter.dat:!ca", }) - _, err := ToCidrList(ips) + cache := NewGeoCache() + defer cache.Clear() + + _, err := cache.ToCidrList(ips) if err != nil { t.Fatalf("Failed to parse geoip list, got %s", err) }