From b41306601269c12b96c5d688a7b3b33a9a7dacb0 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Sat, 23 Apr 2022 19:24:46 -0400 Subject: [PATCH] Fakedns fix xUDP destination override (#1011) * Fix UDP destination override * Fix code style * Fix fakedns object init Do type convertion at runtime in case if user don't use fakedns in config. Since dispatcher now depend on fakedns object, move the injection order of fakedns to top (As a temporary solution) * Amend logic for handing fakedns client A map is used by server side when client turn on fakedns Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs. When target replies, server will restore the domain and send back to client. Co-authored-by: hmol233 <82594500+hmol233@users.noreply.github.com> --- app/dispatcher/default.go | 112 ++++++++++++++++++++++++++++++-------- infra/conf/xray.go | 2 +- transport/pipe/impl.go | 5 ++ transport/pipe/pipe.go | 7 +++ 4 files changed, 102 insertions(+), 24 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index f4ceba9d..7143ccfc 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -4,6 +4,7 @@ package dispatcher import ( "context" + "fmt" "strings" "sync" "time" @@ -92,13 +93,17 @@ type DefaultDispatcher struct { router routing.Router policy policy.Manager stats stats.Manager - hosts dns.HostsLookup + dns dns.Client + fdns dns.FakeDNSEngine } func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { d := new(DefaultDispatcher) if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { + core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { + d.fdns = fdns + }) return d.Init(config.(*Config), om, router, pm, sm, dc) }); err != nil { return nil, err @@ -108,14 +113,12 @@ func init() { } // Init initializes DefaultDispatcher. -func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { +func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error { d.ohm = om d.router = router d.policy = pm d.stats = sm - if hosts, ok := dc.(dns.HostsLookup); ok { - d.hosts = hosts - } + d.dns = dns return nil } @@ -132,10 +135,77 @@ func (*DefaultDispatcher) Start() error { // Close implements common.Closable. func (*DefaultDispatcher) Close() error { return nil } -func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { - opt := pipe.OptionsFromContext(ctx) - uplinkReader, uplinkWriter := pipe.New(opt...) - downlinkReader, downlinkWriter := pipe.New(opt...) +func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) { + downOpt := pipe.OptionsFromContext(ctx) + upOpt := downOpt + + if network == net.Network_UDP { + var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns + // Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs. + // When target replies, server will restore the domain and send back to client. + // Note: this map is not global but per connection context + upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer { + for i, buffer := range mb { + if buffer.UDP == nil { + continue + } + addr := buffer.UDP.Address + if addr.Family().IsIP() { + if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled { + domain := fkr0.GetDomainFromFakeDNS(addr) + if len(domain) > 0 { + buffer.UDP.Address = net.DomainAddress(domain) + newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) + } else { + newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx)) + } + } + } else { + if (ip2domain == nil) { + ip2domain = new(sync.Map) + newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx)) + } + domain := addr.Domain() + ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false}) + if err == nil { + for _, ip := range ips { + ip2domain.Store(ip.String(), domain) + } + newError("[fakedns client] candidate ip: " + fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) + } else { + newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx)) + } + } + } + return mb + })) + downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer { + for i, buffer := range mb { + if buffer.UDP == nil { + continue + } + addr := buffer.UDP.Address + if addr.Family().IsIP() { + if ip2domain == nil { + continue + } + if domain, found := ip2domain.Load(addr.IP().String()); found { + buffer.UDP.Address = net.DomainAddress(domain.(string)) + newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) + } + } else { + if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok { + fakeIp := fkr0.GetFakeIPForDomain(addr.Domain()) + buffer.UDP.Address = fakeIp[0] + newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) + } + } + } + return mb + })) + } + uplinkReader, uplinkWriter := pipe.New(upOpt...) + downlinkReader, downlinkWriter := pipe.New(downOpt...) inboundLink := &transport.Link{ Reader: downlinkReader, @@ -178,17 +248,13 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran return inboundLink, outboundLink } -func shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { +func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { domain := result.Domain() for _, d := range request.ExcludeForDomain { if strings.ToLower(domain) == d { return false } } - var fakeDNSEngine dns.FakeDNSEngine - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { - fakeDNSEngine = fdns - }) protocolString := result.Protocol() if resComp, ok := result.(SnifferResultComposite); ok { protocolString = resComp.ProtocolForDomainResult() @@ -197,7 +263,7 @@ func shouldOverride(ctx context.Context, result SniffResult, request session.Sni if strings.HasPrefix(protocolString, p) { return true } - if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && + if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) { newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx)) return true @@ -221,14 +287,14 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin Target: destination, } ctx = session.ContextWithOutbound(ctx, ob) - - inbound, outbound := d.getLink(ctx) content := session.ContentFromContext(ctx) if content == nil { content = new(session.Content) ctx = session.ContextWithContent(ctx, content) } + sniffingRequest := content.SniffingRequest + inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest) switch { case !sniffingRequest.Enabled: go d.routedDispatch(ctx, outbound, destination) @@ -237,7 +303,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin result, err := sniffer(ctx, nil, true) if err == nil { content.Protocol = result.Protocol() - if shouldOverride(ctx, result, sniffingRequest, destination) { + if d.shouldOverride(ctx, result, sniffingRequest, destination) { domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) @@ -259,7 +325,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin if err == nil { content.Protocol = result.Protocol() } - if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) { + if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) @@ -298,7 +364,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De result, err := sniffer(ctx, nil, true) if err == nil { content.Protocol = result.Protocol() - if shouldOverride(ctx, result, sniffingRequest, destination) { + if d.shouldOverride(ctx, result, sniffingRequest, destination) { domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) @@ -320,7 +386,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De if err == nil { content.Protocol = result.Protocol() } - if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) { + if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) @@ -384,8 +450,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { ob := session.OutboundFromContext(ctx) - if d.hosts != nil && destination.Address.Family().IsDomain() { - proxied := d.hosts.LookupHosts(ob.Target.String()) + if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { + proxied := hosts.LookupHosts(ob.Target.String()) if proxied != nil { ro := ob.RouteTarget == destination destination.Address = *proxied diff --git a/infra/conf/xray.go b/infra/conf/xray.go index c67e6e09..d4895d84 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -632,7 +632,7 @@ func (c *Config) Build() (*core.Config, error) { if err != nil { return nil, err } - config.App = append(config.App, serial.ToTypedMessage(r)) + config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...) } if c.Observatory != nil { diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go index 903c5fdf..14a18e63 100644 --- a/transport/pipe/impl.go +++ b/transport/pipe/impl.go @@ -24,6 +24,7 @@ const ( type pipeOption struct { limit int32 // maximum buffer size in bytes discardOverflow bool + onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer } func (o *pipeOption) isFull(curSize int32) bool { @@ -137,6 +138,10 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error { return nil } + if p.option.onTransmission != nil { + mb = p.option.onTransmission(mb) + } + for { err := p.writeMultiBufferInternal(mb) if err == nil { diff --git a/transport/pipe/pipe.go b/transport/pipe/pipe.go index 1706fdab..0b22c2db 100644 --- a/transport/pipe/pipe.go +++ b/transport/pipe/pipe.go @@ -3,6 +3,7 @@ package pipe import ( "context" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/features/policy" @@ -25,6 +26,12 @@ func WithSizeLimit(limit int32) Option { } } +func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option { + return func(option *pipeOption) { + option.onTransmission = hook + } +} + // DiscardOverflow returns an Option for Pipe to discard writes if full. func DiscardOverflow() Option { return func(opt *pipeOption) {