diff --git a/app/dns/dns.go b/app/dns/dns.go index 1ad75ada..b76db087 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -167,33 +167,36 @@ func (s *DNS) IsOwnLink(ctx context.Context) bool { // LookupIP implements dns.Client. func (s *DNS) LookupIP(domain string) ([]net.IP, error) { - return s.lookupIPInternal(domain, *s.ipOption) + return s.lookupIPInternal(domain, s.ipOption.Copy()) } // LookupOptions implements dns.Client. -func (s *DNS) LookupOptions(domain string, opt dns.IPOption) ([]net.IP, error) { +func (s *DNS) LookupOptions(domain string, opts ...dns.Option) ([]net.IP, error) { + opt := s.ipOption.Copy() + for _, o := range opts { + if o != nil { + o(opt) + } + } + return s.lookupIPInternal(domain, opt) } // LookupIPv4 implements dns.IPv4Lookup. func (s *DNS) LookupIPv4(domain string) ([]net.IP, error) { - return s.lookupIPInternal(domain, dns.IPOption{ + return s.lookupIPInternal(domain, &dns.IPOption{ IPv4Enable: true, - IPv6Enable: false, - FakeEnable: false, }) } // LookupIPv6 implements dns.IPv6Lookup. func (s *DNS) LookupIPv6(domain string) ([]net.IP, error) { - return s.lookupIPInternal(domain, dns.IPOption{ - IPv4Enable: false, + return s.lookupIPInternal(domain, &dns.IPOption{ IPv6Enable: true, - FakeEnable: false, }) } -func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, error) { +func (s *DNS) lookupIPInternal(domain string, option *dns.IPOption) ([]net.IP, error) { if domain == "" { return nil, newError("empty domain name") } @@ -228,7 +231,7 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er errs := []error{} ctx := session.ContextWithInbound(s.ctx, &session.Inbound{Tag: s.tag}) for _, client := range s.sortClients(domain, option) { - ips, err := client.QueryIP(ctx, domain, option, s.cs) + ips, err := client.QueryIP(ctx, domain, *option, s.cs) if len(ips) > 0 { return ips, nil } @@ -244,7 +247,7 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...)) } -func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client { +func (s *DNS) sortClients(domain string, option *dns.IPOption) []*Client { clients := make([]*Client, 0, len(s.clients)) clientUsed := make([]bool, len(s.clients)) clientNames := make([]string, 0, len(s.clients)) diff --git a/app/dns/hosts.go b/app/dns/hosts.go index af315803..44b79c47 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -74,7 +74,7 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma return sh, nil } -func filterIP(ips []net.Address, option dns.IPOption) []net.Address { +func filterIP(ips []net.Address, option *dns.IPOption) []net.Address { filtered := make([]net.Address, 0, len(ips)) for _, ip := range ips { if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) { @@ -95,7 +95,7 @@ func (h *StaticHosts) lookupInternal(domain string) []net.Address { return ips } -func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address { +func (h *StaticHosts) lookup(domain string, option *dns.IPOption, maxDepth int) []net.Address { switch addrs := h.lookupInternal(domain); { case len(addrs) == 0: // Not recorded in static hosts, return nil return nil @@ -113,6 +113,6 @@ func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) [ } // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts. -func (h *StaticHosts) Lookup(domain string, option dns.IPOption) []net.Address { +func (h *StaticHosts) Lookup(domain string, option *dns.IPOption) []net.Address { return h.lookup(domain, option, 5) } diff --git a/app/dns/hosts_test.go b/app/dns/hosts_test.go index 87e487bc..786aeca3 100644 --- a/app/dns/hosts_test.go +++ b/app/dns/hosts_test.go @@ -40,7 +40,7 @@ func TestStaticHosts(t *testing.T) { common.Must(err) { - ips := hosts.Lookup("example.com", dns.IPOption{ + ips := hosts.Lookup("example.com", &dns.IPOption{ IPv4Enable: true, IPv6Enable: true, }) @@ -53,7 +53,7 @@ func TestStaticHosts(t *testing.T) { } { - ips := hosts.Lookup("www.example.cn", dns.IPOption{ + ips := hosts.Lookup("www.example.cn", &dns.IPOption{ IPv4Enable: true, IPv6Enable: true, }) @@ -66,7 +66,7 @@ func TestStaticHosts(t *testing.T) { } { - ips := hosts.Lookup("baidu.com", dns.IPOption{ + ips := hosts.Lookup("baidu.com", &dns.IPOption{ IPv4Enable: false, IPv6Enable: true, }) diff --git a/app/dns/options.go b/app/dns/options.go index 2bedbe9f..06d93e29 100644 --- a/app/dns/options.go +++ b/app/dns/options.go @@ -2,23 +2,15 @@ package dns import "github.com/xtls/xray-core/features/dns" -type Option interface { - queryIPv4() bool - queryIPv6() bool - queryIP() bool - queryFake() bool - canDoQuery(c *Client) bool -} - -func isIPQuery(o dns.IPOption) bool { +func isIPQuery(o *dns.IPOption) bool { return o.IPv4Enable || o.IPv6Enable } -func canQueryOnClient(o dns.IPOption, c *Client) bool { +func canQueryOnClient(o *dns.IPOption, c *Client) bool { isIPClient := !(c.Name() == FakeDNSName) return isIPClient && isIPQuery(o) } -func isQuery(o dns.IPOption) bool { +func isQuery(o *dns.IPOption) bool { return !(o.IPv4Enable || o.IPv6Enable || o.FakeEnable) } diff --git a/features/dns/client.go b/features/dns/client.go index 93f8a68b..3fc9dce0 100644 --- a/features/dns/client.go +++ b/features/dns/client.go @@ -14,6 +14,12 @@ type IPOption struct { FakeEnable bool } +func (p *IPOption) Copy() *IPOption { + return &IPOption{p.IPv4Enable, p.IPv6Enable, p.FakeEnable} +} + +type Option func(dopt *IPOption) *IPOption + // Client is a Xray feature for querying DNS information. // // xray:api:stable @@ -23,8 +29,8 @@ type Client interface { // LookupIP returns IP address for the given domain. IPs may contain IPv4 and/or IPv6 addresses. LookupIP(domain string) ([]net.IP, error) - // LookupOptions query IP address for domain with IPOption. - LookupOptions(domain string, opt IPOption) ([]net.IP, error) + // LookupOptions query IP address for domain with *IPOption. + LookupOptions(domain string, opt ...Option) ([]net.IP, error) } // IPv4Lookup is an optional feature for querying IPv4 addresses only. @@ -69,9 +75,33 @@ func RCodeFromError(err error) uint16 { } var ( - LookupIPv4 = IPOption{IPv4Enable: true} - LookupIPv6 = IPOption{IPv6Enable: true} - LookupIP = IPOption{IPv4Enable: true, IPv6Enable: true} - LookupFake = IPOption{FakeEnable: true} - LookupAll = IPOption{true, true, true} + LookupIPv4Only = func(d *IPOption) *IPOption { + d.IPv4Enable = true + d.IPv6Enable = false + return d + } + LookupIPv6Only = func(d *IPOption) *IPOption { + d.IPv4Enable = false + d.IPv6Enable = true + return d + } + LookupIP = func(d *IPOption) *IPOption { + d.IPv4Enable = true + d.IPv6Enable = true + return d + } + LookupFake = func(d *IPOption) *IPOption { + d.FakeEnable = true + return d + } + LookupNoFake = func(d *IPOption) *IPOption { + d.FakeEnable = false + return d + } + + LookupAll = func(d *IPOption) *IPOption { + LookupIP(d) + LookupFake(d) + return d + } ) diff --git a/features/dns/localdns/client.go b/features/dns/localdns/client.go index 47d3590e..d47d29c0 100644 --- a/features/dns/localdns/client.go +++ b/features/dns/localdns/client.go @@ -39,7 +39,7 @@ func (*Client) LookupIP(host string) ([]net.IP, error) { } // LookupOptions implements Client. -func (c *Client) LookupOptions(host string, _ dns.IPOption) ([]net.IP, error) { +func (c *Client) LookupOptions(host string, _ ...dns.Option) ([]net.IP, error) { return c.LookupIP(host) } diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 109af830..8b9fd957 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -199,18 +199,16 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, var err error var ttl uint32 = 600 - var opt = dns.LookupIP + var opt dns.Option switch qType { case dnsmessage.TypeA: - opt = dns.LookupIPv4 + opt = dns.LookupIPv4Only case dnsmessage.TypeAAAA: - opt = dns.LookupIPv6 + opt = dns.LookupIPv6Only } - opt.FakeEnable = true - - ips, err = h.client.LookupOptions(domain, opt) + ips, err = h.client.LookupOptions(domain, opt, dns.LookupFake) rcode := dns.RCodeFromError(err) if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { newError("ip query").Base(err).WriteToLog() diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index fe6db062..de01e072 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -59,15 +59,14 @@ func (h *Handler) policy() policy.Session { } func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address { - var opt = dns.LookupIP + var opt dns.Option if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) { - opt = dns.LookupIPv4 + opt = dns.LookupIPv4Only } else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) { - opt = dns.LookupIPv6 + opt = dns.LookupIPv6Only } - opt.FakeEnable = true - ips, err := h.dns.LookupOptions(domain, opt) + ips, err := h.dns.LookupOptions(domain, opt, dns.LookupNoFake) if err != nil { newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx)) } diff --git a/testing/mocks/dns.go b/testing/mocks/dns.go index 21ccaeae..1178790c 100644 --- a/testing/mocks/dns.go +++ b/testing/mocks/dns.go @@ -65,18 +65,23 @@ func (mr *DNSClientMockRecorder) LookupIP(arg0 interface{}) *gomock.Call { } // LookupOptions mocks base method. -func (m *DNSClient) LookupOptions(arg0 string, arg1 dns.IPOption) ([]net.IP, error) { +func (m *DNSClient) LookupOptions(arg0 string, arg1 ...dns.Option) ([]net.IP, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LookupOptions", arg0, arg1) + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "LookupOptions", varargs...) ret0, _ := ret[0].([]net.IP) ret1, _ := ret[1].(error) return ret0, ret1 } // LookupOptions indicates an expected call of LookupOptions. -func (mr *DNSClientMockRecorder) LookupOptions(arg0, arg1 interface{}) *gomock.Call { +func (mr *DNSClientMockRecorder) LookupOptions(arg0 interface{}, arg1 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupOptions", reflect.TypeOf((*DNSClient)(nil).LookupOptions), arg0, arg1) + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupOptions", reflect.TypeOf((*DNSClient)(nil).LookupOptions), varargs...) } // Start mocks base method. diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 7cc48645..22a9de8d 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -63,17 +63,17 @@ func (d *DefaultSystemDialer) lookupIP(domain string, strategy DomainStrategy, l return nil, nil } - var opt = dns.LookupIP + var opt dns.Option switch { case strategy == DomainStrategy_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()): - opt = dns.LookupIPv4 + opt = dns.LookupIPv4Only case strategy == DomainStrategy_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()): - opt = dns.LookupIPv6 + opt = dns.LookupIPv6Only case strategy == DomainStrategy_AS_IS: return nil, nil } - return d.dns.LookupOptions(domain, opt) + return d.dns.LookupOptions(domain, opt, dns.LookupNoFake) } func (d *DefaultSystemDialer) canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig) bool {