From 1ced7985d59f89673f7e17409180c93ac1885afd Mon Sep 17 00:00:00 2001 From: hmol233 <82594500+hmol233@users.noreply.github.com> Date: Sun, 11 Jul 2021 22:02:55 +0800 Subject: [PATCH] Refine `PrioritizedDomain`, should fix https://github.com/XTLS/Xray-core/issues/638 --- app/dns/dns.go | 50 ++++++++++------------------- app/dns/nameserver.go | 74 ++++++++++++++++++++----------------------- 2 files changed, 52 insertions(+), 72 deletions(-) diff --git a/app/dns/dns.go b/app/dns/dns.go index c638735d..68cb29b0 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -12,7 +12,6 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/matcher/geoip" - "github.com/xtls/xray-core/common/matcher/str" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/features" @@ -29,8 +28,6 @@ type DNS struct { hosts *StaticHosts clients []*Client ctx context.Context - domainMatcher str.IndexMatcher - matcherInfos []DomainMatcherInfo } // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher @@ -89,9 +86,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) { domainRuleCount += len(ns.PrioritizedDomain) } - // MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1 - matcherInfos := make([]DomainMatcherInfo, domainRuleCount+1) - domainMatcher := &str.MatcherGroup{} geoipContainer := geoip.GeoIPMatcherContainer{} for _, endpoint := range config.NameServers { @@ -104,22 +98,13 @@ func New(ctx context.Context, config *Config) (*DNS, error) { } for _, ns := range config.NameServer { - clientIdx := len(clients) - updateDomain := func(domainRule str.Matcher, originalRuleIdx int, matcherInfos []DomainMatcherInfo) error { - midx := domainMatcher.Add(domainRule) - matcherInfos[midx] = DomainMatcherInfo{ - clientIdx: uint16(clientIdx), - domainRuleIdx: uint16(originalRuleIdx), - } - return nil - } myClientIP := clientIP switch len(ns.ClientIp) { case net.IPv4len, net.IPv6len: - myClientIP = net.IP(ns.ClientIp) + myClientIP = ns.ClientIp } - client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain) + client, err := NewClient(ctx, ns, myClientIP, geoipContainer) if err != nil { return nil, newError("failed to create client").Base(err) } @@ -137,8 +122,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) { ipOption: ipOption, clients: clients, ctx: ctx, - domainMatcher: domainMatcher, - matcherInfos: matcherInfos, cacheStrategy: config.CacheStrategy, disableFallback: config.DisableFallback, }, nil @@ -268,21 +251,22 @@ func (s *DNS) sortClients(domain string, option *dns.IPOption) []*Client { }() // Priority domain matching - for _, match := range s.domainMatcher.Match(domain) { - info := s.matcherInfos[match] - client := s.clients[info.clientIdx] - domainRule := client.domains[info.domainRuleIdx] - if !canQueryOnClient(option, client) { - newError("skipping the client " + client.Name()).AtDebug().WriteToLog() - continue + for clientIdx, client := range s.clients { + if ids := client.domainMatcher.Match(domain); len(ids) > 0 { + if !canQueryOnClient(option, client) { + newError("skipping the client " + client.Name()).AtDebug().WriteToLog() + continue + } + for _, id := range ids { + rule := client.findRule(id) + domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", rule, clientIdx)) + } + if clientUsed[clientIdx] { + continue + } + clients = append(clients, client) + clientNames = append(clientNames, client.Name()) } - domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx)) - if clientUsed[info.clientIdx] { - continue - } - clientUsed[info.clientIdx] = true - clients = append(clients, client) - clientNames = append(clientNames, client.Name()) } if !s.disableFallback { diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index ac0e405a..acc842b3 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -25,11 +25,23 @@ type Server interface { // Client is the interface for DNS client. type Client struct { - server Server - clientIP net.IP - skipFallback bool - domains []string - expectIPs []*geoip.GeoIPMatcher + server Server + clientIP net.IP + skipFallback bool + expectIPs []*geoip.GeoIPMatcher + domainMatcher str.MatcherGroup + originRules []*NameServer_OriginalRule +} + +func (c Client) findRule(idx uint32) string { + for _, r := range c.originRules { + if idx <= r.Size { + return r.Rule + } + idx -= r.Size + } + + return "unknown rule" } var errExpectedIPNonMatch = errors.New("expectIPs not match") @@ -64,7 +76,7 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err } // NewClient creates a DNS client managing a name server with client IP, domain rules and expected IPs. -func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container geoip.GeoIPMatcherContainer, matcherInfos *[]DomainMatcherInfo, updateDomainRule func(str.Matcher, int, []DomainMatcherInfo) error) (*Client, error) { +func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container geoip.GeoIPMatcherContainer) (*Client, error) { client := &Client{} err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error { @@ -79,55 +91,38 @@ func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container g ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...) ns.OriginalRules = append(ns.OriginalRules, localTLDsAndDotlessDomainsRule) // The following lines is a solution to avoid core panics(rule index out of range) when setting `localhost` DNS client in config. - // Because the `localhost` DNS client will apend len(localTLDsAndDotlessDomains) rules into matcherInfos to match `geosite:private` default rule. + // Because the `localhost` DNS client will append len(localTLDsAndDotlessDomains) rules into matcherInfos to match `geosite:private` default rule. // But `matcherInfos` has no enough length to add rules, which leads to core panics (rule index out of range). // To avoid this, the length of `matcherInfos` must be equal to the expected, so manually append it with Golang default zero value first for later modification. - for i := 0; i < len(localTLDsAndDotlessDomains); i++ { - *matcherInfos = append(*matcherInfos, DomainMatcherInfo{ - clientIdx: uint16(0), - domainRuleIdx: uint16(0), - }) - } + // ;) + /* + for i := 0; i < len(localTLDsAndDotlessDomains); i++ { + *matcherInfos = append(*matcherInfos, DomainMatcherInfo{ + clientIdx: uint16(0), + domainRuleIdx: uint16(0), + }) + } + */ } // Establish domain rules - var rules []string - ruleCurr := 0 - ruleIter := 0 + var domainMatcher = str.MatcherGroup{} for _, domain := range ns.PrioritizedDomain { domainRule, err := toStrMatcher(domain.Type, domain.Value) if err != nil { return newError("failed to create prioritized domain").Base(err).AtWarning() } - originalRuleIdx := ruleCurr - if ruleCurr < len(ns.OriginalRules) { - rule := ns.OriginalRules[ruleCurr] - if ruleCurr >= len(rules) { - rules = append(rules, rule.Rule) - } - ruleIter++ - if ruleIter >= int(rule.Size) { - ruleIter = 0 - ruleCurr++ - } - } else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests) - rules = append(rules, domainRule.String()) - ruleCurr++ - } - err = updateDomainRule(domainRule, originalRuleIdx, *matcherInfos) - if err != nil { - return newError("failed to create prioritized domain").Base(err).AtWarning() - } + domainMatcher.Add(domainRule) } // Establish expected IPs - var matchers []*geoip.GeoIPMatcher + var ipMatchers []*geoip.GeoIPMatcher for _, geoip := range ns.Geoip { matcher, err := container.Add(geoip) if err != nil { return newError("failed to create ip matcher").Base(err).AtWarning() } - matchers = append(matchers, matcher) + ipMatchers = append(ipMatchers, matcher) } if len(clientIP) > 0 { @@ -141,8 +136,9 @@ func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container g client.server = server client.clientIP = clientIP - client.domains = rules - client.expectIPs = matchers + client.expectIPs = ipMatchers + client.originRules = ns.OriginalRules + client.domainMatcher = domainMatcher return nil }) return client, err