From b8bd243df5f6ff4929681c82ed945b27f1597d25 Mon Sep 17 00:00:00 2001 From: dyhkwong <50692134+dyhkwong@users.noreply.github.com> Date: Tue, 29 Aug 2023 15:12:36 +0800 Subject: [PATCH] Fix buffer.UDP destination override (#2356) --- app/dispatcher/default.go | 107 ++++++++----------------------- app/proxyman/outbound/handler.go | 7 +- common/buf/override.go | 38 +++++++++++ common/session/session.go | 5 +- transport/pipe/impl.go | 5 -- transport/pipe/pipe.go | 7 -- 6 files changed, 75 insertions(+), 94 deletions(-) create mode 100644 common/buf/override.go diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 35307cef..5a71ad41 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -4,7 +4,6 @@ package dispatcher import ( "context" - "fmt" "strings" "sync" "time" @@ -135,77 +134,10 @@ func (*DefaultDispatcher) Start() error { // Close implements common.Closable. func (*DefaultDispatcher) Close() error { return nil } -func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) { - downOpt := pipe.OptionsFromContext(ctx) - upOpt := downOpt - - if network == net.Network_UDP { - var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns - // Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs. - // When target replies, server will restore the domain and send back to client. - // Note: this map is not global but per connection context - upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer { - for i, buffer := range mb { - if buffer.UDP == nil { - continue - } - addr := buffer.UDP.Address - if addr.Family().IsIP() { - if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled { - domain := fkr0.GetDomainFromFakeDNS(addr) - if len(domain) > 0 { - buffer.UDP.Address = net.DomainAddress(domain) - newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) - } else { - newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx)) - } - } - } else { - if ip2domain == nil { - ip2domain = new(sync.Map) - newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx)) - } - domain := addr.Domain() - ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false}) - if err == nil { - for _, ip := range ips { - ip2domain.Store(ip.String(), domain) - } - newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) - } else { - newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx)) - } - } - } - return mb - })) - downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer { - for i, buffer := range mb { - if buffer.UDP == nil { - continue - } - addr := buffer.UDP.Address - if addr.Family().IsIP() { - if ip2domain == nil { - continue - } - if domain, found := ip2domain.Load(addr.IP().String()); found { - buffer.UDP.Address = net.DomainAddress(domain.(string)) - newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) - } - } else { - if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok { - fakeIp := fkr0.GetFakeIPForDomain(addr.Domain()) - buffer.UDP.Address = fakeIp[0] - newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) - } - } - } - return mb - })) - } - uplinkReader, uplinkWriter := pipe.New(upOpt...) - downlinkReader, downlinkWriter := pipe.New(downOpt...) +func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { + opt := pipe.OptionsFromContext(ctx) + uplinkReader, uplinkWriter := pipe.New(opt...) + downlinkReader, downlinkWriter := pipe.New(opt...) inboundLink := &transport.Link{ Reader: downlinkReader, @@ -263,7 +195,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu protocolString = resComp.ProtocolForDomainResult() } for _, p := range request.OverrideDestinationForProtocol { - if strings.HasPrefix(protocolString, p) { + if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) { return true } if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && @@ -287,7 +219,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin panic("Dispatcher: Invalid destination.") } ob := &session.Outbound{ - Target: destination, + OriginalTarget: destination, + Target: destination, } ctx = session.ContextWithOutbound(ctx, ob) content := session.ContentFromContext(ctx) @@ -295,9 +228,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin content = new(session.Content) ctx = session.ContextWithContent(ctx, content) } - sniffingRequest := content.SniffingRequest - inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest) + inbound, outbound := d.getLink(ctx) if !sniffingRequest.Enabled { go d.routedDispatch(ctx, outbound, destination) } else { @@ -314,7 +246,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) - if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { + protocol := result.Protocol() + if resComp, ok := result.(SnifferResultComposite); ok { + protocol = resComp.ProtocolForDomainResult() + } + isFakeIP := false + if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) { + isFakeIP = true + } + if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { ob.RouteTarget = destination } else { ob.Target = destination @@ -332,7 +272,8 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De return newError("Dispatcher: Invalid destination.") } ob := &session.Outbound{ - Target: destination, + OriginalTarget: destination, + Target: destination, } ctx = session.ContextWithOutbound(ctx, ob) content := session.ContentFromContext(ctx) @@ -356,7 +297,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) - if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { + protocol := result.Protocol() + if resComp, ok := result.(SnifferResultComposite); ok { + protocol = resComp.ProtocolForDomainResult() + } + isFakeIP := false + if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) { + isFakeIP = true + } + if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { ob.RouteTarget = destination } else { ob.Target = destination diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index b477dd6b..adf6537a 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -8,6 +8,7 @@ import ( "github.com/xtls/xray-core/app/proxyman" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/mux" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net/cnc" @@ -166,6 +167,11 @@ func (h *Handler) Tag() string { // Dispatch implements proxy.Outbound.Dispatch. func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { + outbound := session.OutboundFromContext(ctx) + if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address { + link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} + link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address} + } if h.mux != nil { test := func(err error) { if err != nil { @@ -175,7 +181,6 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { common.Interrupt(link.Writer) } } - outbound := session.OutboundFromContext(ctx) if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 { switch h.udp443 { case "reject": diff --git a/common/buf/override.go b/common/buf/override.go new file mode 100644 index 00000000..7b2f1554 --- /dev/null +++ b/common/buf/override.go @@ -0,0 +1,38 @@ +package buf + +import ( + "github.com/xtls/xray-core/common/net" +) + +type EndpointOverrideReader struct { + Reader + Dest net.Address + OriginalDest net.Address +} + +func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) { + mb, err := r.Reader.ReadMultiBuffer() + if err == nil { + for _, b := range mb { + if b.UDP != nil && b.UDP.Address == r.OriginalDest { + b.UDP.Address = r.Dest + } + } + } + return mb, err +} + +type EndpointOverrideWriter struct { + Writer + Dest net.Address + OriginalDest net.Address +} + +func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error { + for _, b := range mb { + if b.UDP != nil && b.UDP.Address == w.Dest { + b.UDP.Address = w.OriginalDest + } + } + return w.Writer.WriteMultiBuffer(mb) +} diff --git a/common/session/session.go b/common/session/session.go index 83c48fde..b9609e86 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -55,8 +55,9 @@ type Inbound struct { // Outbound is the metadata of an outbound connection. type Outbound struct { // Target address of the outbound connection. - Target net.Destination - RouteTarget net.Destination + OriginalTarget net.Destination + Target net.Destination + RouteTarget net.Destination // Gateway address Gateway net.Address } diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go index a60bc485..dbdb050e 100644 --- a/transport/pipe/impl.go +++ b/transport/pipe/impl.go @@ -24,7 +24,6 @@ const ( type pipeOption struct { limit int32 // maximum buffer size in bytes discardOverflow bool - onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer } func (o *pipeOption) isFull(curSize int32) bool { @@ -141,10 +140,6 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error { return nil } - if p.option.onTransmission != nil { - mb = p.option.onTransmission(mb) - } - for { err := p.writeMultiBufferInternal(mb) if err == nil { diff --git a/transport/pipe/pipe.go b/transport/pipe/pipe.go index 735cc091..f4b78303 100644 --- a/transport/pipe/pipe.go +++ b/transport/pipe/pipe.go @@ -3,7 +3,6 @@ package pipe import ( "context" - "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/features/policy" @@ -26,12 +25,6 @@ func WithSizeLimit(limit int32) Option { } } -func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option { - return func(option *pipeOption) { - option.onTransmission = hook - } -} - // DiscardOverflow returns an Option for Pipe to discard writes if full. func DiscardOverflow() Option { return func(opt *pipeOption) {