Refine `PrioritizedDomain`, should fix https://github.com/XTLS/Xray-core/issues/638

This commit is contained in:
hmol233 2021-07-11 22:02:55 +08:00
parent b85eef0131
commit 1ced7985d5
No known key found for this signature in database
GPG Key ID: D617A9DAB0C992D5
2 changed files with 52 additions and 72 deletions

View File

@ -12,7 +12,6 @@ import (
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/matcher/geoip" "github.com/xtls/xray-core/common/matcher/geoip"
"github.com/xtls/xray-core/common/matcher/str"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/features" "github.com/xtls/xray-core/features"
@ -29,8 +28,6 @@ type DNS struct {
hosts *StaticHosts hosts *StaticHosts
clients []*Client clients []*Client
ctx context.Context ctx context.Context
domainMatcher str.IndexMatcher
matcherInfos []DomainMatcherInfo
} }
// DomainMatcherInfo contains information attached to index returned by Server.domainMatcher // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher
@ -89,9 +86,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
domainRuleCount += len(ns.PrioritizedDomain) domainRuleCount += len(ns.PrioritizedDomain)
} }
// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
matcherInfos := make([]DomainMatcherInfo, domainRuleCount+1)
domainMatcher := &str.MatcherGroup{}
geoipContainer := geoip.GeoIPMatcherContainer{} geoipContainer := geoip.GeoIPMatcherContainer{}
for _, endpoint := range config.NameServers { for _, endpoint := range config.NameServers {
@ -104,22 +98,13 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
} }
for _, ns := range config.NameServer { for _, ns := range config.NameServer {
clientIdx := len(clients)
updateDomain := func(domainRule str.Matcher, originalRuleIdx int, matcherInfos []DomainMatcherInfo) error {
midx := domainMatcher.Add(domainRule)
matcherInfos[midx] = DomainMatcherInfo{
clientIdx: uint16(clientIdx),
domainRuleIdx: uint16(originalRuleIdx),
}
return nil
}
myClientIP := clientIP myClientIP := clientIP
switch len(ns.ClientIp) { switch len(ns.ClientIp) {
case net.IPv4len, net.IPv6len: case net.IPv4len, net.IPv6len:
myClientIP = net.IP(ns.ClientIp) myClientIP = ns.ClientIp
} }
client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain) client, err := NewClient(ctx, ns, myClientIP, geoipContainer)
if err != nil { if err != nil {
return nil, newError("failed to create client").Base(err) return nil, newError("failed to create client").Base(err)
} }
@ -137,8 +122,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
ipOption: ipOption, ipOption: ipOption,
clients: clients, clients: clients,
ctx: ctx, ctx: ctx,
domainMatcher: domainMatcher,
matcherInfos: matcherInfos,
cacheStrategy: config.CacheStrategy, cacheStrategy: config.CacheStrategy,
disableFallback: config.DisableFallback, disableFallback: config.DisableFallback,
}, nil }, nil
@ -268,22 +251,23 @@ func (s *DNS) sortClients(domain string, option *dns.IPOption) []*Client {
}() }()
// Priority domain matching // Priority domain matching
for _, match := range s.domainMatcher.Match(domain) { for clientIdx, client := range s.clients {
info := s.matcherInfos[match] if ids := client.domainMatcher.Match(domain); len(ids) > 0 {
client := s.clients[info.clientIdx]
domainRule := client.domains[info.domainRuleIdx]
if !canQueryOnClient(option, client) { if !canQueryOnClient(option, client) {
newError("skipping the client " + client.Name()).AtDebug().WriteToLog() newError("skipping the client " + client.Name()).AtDebug().WriteToLog()
continue continue
} }
domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx)) for _, id := range ids {
if clientUsed[info.clientIdx] { rule := client.findRule(id)
domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", rule, clientIdx))
}
if clientUsed[clientIdx] {
continue continue
} }
clientUsed[info.clientIdx] = true
clients = append(clients, client) clients = append(clients, client)
clientNames = append(clientNames, client.Name()) clientNames = append(clientNames, client.Name())
} }
}
if !s.disableFallback { if !s.disableFallback {
// Default round-robin query // Default round-robin query

View File

@ -28,8 +28,20 @@ type Client struct {
server Server server Server
clientIP net.IP clientIP net.IP
skipFallback bool skipFallback bool
domains []string
expectIPs []*geoip.GeoIPMatcher expectIPs []*geoip.GeoIPMatcher
domainMatcher str.MatcherGroup
originRules []*NameServer_OriginalRule
}
func (c Client) findRule(idx uint32) string {
for _, r := range c.originRules {
if idx <= r.Size {
return r.Rule
}
idx -= r.Size
}
return "unknown rule"
} }
var errExpectedIPNonMatch = errors.New("expectIPs not match") var errExpectedIPNonMatch = errors.New("expectIPs not match")
@ -64,7 +76,7 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err
} }
// NewClient creates a DNS client managing a name server with client IP, domain rules and expected IPs. // NewClient creates a DNS client managing a name server with client IP, domain rules and expected IPs.
func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container geoip.GeoIPMatcherContainer, matcherInfos *[]DomainMatcherInfo, updateDomainRule func(str.Matcher, int, []DomainMatcherInfo) error) (*Client, error) { func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container geoip.GeoIPMatcherContainer) (*Client, error) {
client := &Client{} client := &Client{}
err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error { err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
@ -79,55 +91,38 @@ func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container g
ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...) ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...)
ns.OriginalRules = append(ns.OriginalRules, localTLDsAndDotlessDomainsRule) ns.OriginalRules = append(ns.OriginalRules, localTLDsAndDotlessDomainsRule)
// The following lines is a solution to avoid core panicsrule index out of range when setting `localhost` DNS client in config. // The following lines is a solution to avoid core panicsrule index out of range when setting `localhost` DNS client in config.
// Because the `localhost` DNS client will apend len(localTLDsAndDotlessDomains) rules into matcherInfos to match `geosite:private` default rule. // Because the `localhost` DNS client will append len(localTLDsAndDotlessDomains) rules into matcherInfos to match `geosite:private` default rule.
// But `matcherInfos` has no enough length to add rules, which leads to core panics (rule index out of range). // But `matcherInfos` has no enough length to add rules, which leads to core panics (rule index out of range).
// To avoid this, the length of `matcherInfos` must be equal to the expected, so manually append it with Golang default zero value first for later modification. // To avoid this, the length of `matcherInfos` must be equal to the expected, so manually append it with Golang default zero value first for later modification.
// ;)
/*
for i := 0; i < len(localTLDsAndDotlessDomains); i++ { for i := 0; i < len(localTLDsAndDotlessDomains); i++ {
*matcherInfos = append(*matcherInfos, DomainMatcherInfo{ *matcherInfos = append(*matcherInfos, DomainMatcherInfo{
clientIdx: uint16(0), clientIdx: uint16(0),
domainRuleIdx: uint16(0), domainRuleIdx: uint16(0),
}) })
} }
*/
} }
// Establish domain rules // Establish domain rules
var rules []string var domainMatcher = str.MatcherGroup{}
ruleCurr := 0
ruleIter := 0
for _, domain := range ns.PrioritizedDomain { for _, domain := range ns.PrioritizedDomain {
domainRule, err := toStrMatcher(domain.Type, domain.Value) domainRule, err := toStrMatcher(domain.Type, domain.Value)
if err != nil { if err != nil {
return newError("failed to create prioritized domain").Base(err).AtWarning() return newError("failed to create prioritized domain").Base(err).AtWarning()
} }
originalRuleIdx := ruleCurr domainMatcher.Add(domainRule)
if ruleCurr < len(ns.OriginalRules) {
rule := ns.OriginalRules[ruleCurr]
if ruleCurr >= len(rules) {
rules = append(rules, rule.Rule)
}
ruleIter++
if ruleIter >= int(rule.Size) {
ruleIter = 0
ruleCurr++
}
} else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests)
rules = append(rules, domainRule.String())
ruleCurr++
}
err = updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
if err != nil {
return newError("failed to create prioritized domain").Base(err).AtWarning()
}
} }
// Establish expected IPs // Establish expected IPs
var matchers []*geoip.GeoIPMatcher var ipMatchers []*geoip.GeoIPMatcher
for _, geoip := range ns.Geoip { for _, geoip := range ns.Geoip {
matcher, err := container.Add(geoip) matcher, err := container.Add(geoip)
if err != nil { if err != nil {
return newError("failed to create ip matcher").Base(err).AtWarning() return newError("failed to create ip matcher").Base(err).AtWarning()
} }
matchers = append(matchers, matcher) ipMatchers = append(ipMatchers, matcher)
} }
if len(clientIP) > 0 { if len(clientIP) > 0 {
@ -141,8 +136,9 @@ func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container g
client.server = server client.server = server
client.clientIP = clientIP client.clientIP = clientIP
client.domains = rules client.expectIPs = ipMatchers
client.expectIPs = matchers client.originRules = ns.OriginalRules
client.domainMatcher = domainMatcher
return nil return nil
}) })
return client, err return client, err