diff --git a/app/dns/dohdns.go b/app/dns/dohdns.go index b53ebe76..25d35074 100644 --- a/app/dns/dohdns.go +++ b/app/dns/dohdns.go @@ -55,19 +55,13 @@ func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net. ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { dispatcherCtx := context.Background() - if inbound := session.InboundFromContext(ctx); inbound != nil { - dispatcherCtx = session.ContextWithInbound(dispatcherCtx, inbound) - } - if content := session.ContentFromContext(ctx); content != nil { - dispatcherCtx = session.ContextWithContent(dispatcherCtx, content) - } - dispatcherCtx = internet.ContextWithLookupDomain(dispatcherCtx, internet.LookupDomainFromContext(ctx)) dest, err := net.ParseDestination(network + ":" + addr) if err != nil { return nil, err } + dispatcherCtx = session.ContextWithContent(dispatcherCtx, &session.Content{Protocol: "tls"}) dispatcherCtx = log.ContextWithAccessMessage(dispatcherCtx, &log.AccessMessage{ From: "DoH", To: s.dohURL, @@ -76,6 +70,12 @@ func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net. }) link, err := s.dispatcher.Dispatch(dispatcherCtx, dest) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + + } if err != nil { return nil, err } diff --git a/app/dns/server.go b/app/dns/server.go index 65aa7b7b..0090c998 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -30,7 +30,7 @@ type Server struct { sync.Mutex hosts *StaticHosts clientIP net.IP - clients []Client // clientIdx -> Client + clients []Client // clientIdx -> Client ctx context.Context ipIndexMap []*MultiGeoIPMatcher // clientIdx -> *MultiGeoIPMatcher domainRules [][]string // clientIdx -> domainRuleIdx -> DomainRule @@ -307,7 +307,7 @@ func (s *Server) queryIPTimeout(idx int, client Client, domain string, option dn Tag: s.tag, }) } - ctx = internet.ContextWithLookupDomain(ctx, Fqdn(domain)) + ctx = internet.ContextWithLookupDomain(ctx, domain) ips, err := client.QueryIP(ctx, domain, option) cancel() diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 3f0095d3..7256df9f 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -22,12 +22,12 @@ var ( // InitSystemDialer: It's private method and you are NOT supposed to use this function. func InitSystemDialer(dc dns.Client, om outbound.Manager) { - effectiveSystemDialer.init(dc, om) + effectiveSystemDialer.Init(dc, om) } type SystemDialer interface { Dial(ctx context.Context, source net.Address, destination net.Destination, sockopt *SocketConfig) (net.Conn, error) - init(dc dns.Client, om outbound.Manager) + Init(dc dns.Client, om outbound.Manager) } type DefaultSystemDialer struct { @@ -63,22 +63,30 @@ func (d *DefaultSystemDialer) lookupIP(domain string, strategy DomainStrategy, l return nil, nil } - var lookup = d.dns.LookupIP + var option = dns.IPOption{ + IPv4Enable: true, + IPv6Enable: true, + FakeEnable: false, + } switch { case strategy == DomainStrategy_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()): - if lookupIPv4, ok := d.dns.(dns.IPv4Lookup); ok { - lookup = lookupIPv4.LookupIPv4 + option = dns.IPOption{ + IPv4Enable: true, + IPv6Enable: false, + FakeEnable: false, } case strategy == DomainStrategy_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()): - if lookupIPv4, ok := d.dns.(dns.IPv4Lookup); ok { - lookup = lookupIPv4.LookupIPv4 + option = dns.IPOption{ + IPv4Enable: false, + IPv6Enable: true, + FakeEnable: false, } case strategy == DomainStrategy_AS_IS: return nil, nil } - return lookup(domain) + return d.dns.LookupIP(domain, option) } func (d *DefaultSystemDialer) canLookupIP(ctx context.Context, dst net.Destination, sockopt *SocketConfig) bool { @@ -184,7 +192,7 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr()) } -func (d *DefaultSystemDialer) init(dc dns.Client, om outbound.Manager) { +func (d *DefaultSystemDialer) Init(dc dns.Client, om outbound.Manager) { d.dns = dc d.obm = om } @@ -249,7 +257,7 @@ func WithAdapter(dialer SystemDialerAdapter) SystemDialer { } } -func (v *SimpleSystemDialer) init(_ dns.Client, _ outbound.Manager) {} +func (v *SimpleSystemDialer) Init(_ dns.Client, _ outbound.Manager) {} func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr())