diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 14c31a26..b32013e8 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -271,6 +271,67 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin return inbound, nil } +// DispatchLink implements routing.Dispatcher. +func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { + if !destination.IsValid() { + return newError("Dispatcher: Invalid destination.") + } + ob := &session.Outbound{ + Target: destination, + } + ctx = session.ContextWithOutbound(ctx, ob) + content := session.ContentFromContext(ctx) + if content == nil { + content = new(session.Content) + ctx = session.ContextWithContent(ctx, content) + } + sniffingRequest := content.SniffingRequest + switch { + case !sniffingRequest.Enabled: + go d.routedDispatch(ctx, outbound, destination) + case destination.Network != net.Network_TCP: + // Only metadata sniff will be used for non tcp connection + result, err := sniffer(ctx, nil, true) + if err == nil { + content.Protocol = result.Protocol() + if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { + domain := result.Domain() + newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) + destination.Address = net.ParseAddress(domain) + if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { + ob.RouteTarget = destination + } else { + ob.Target = destination + } + } + } + go d.routedDispatch(ctx, outbound, destination) + default: + go func() { + cReader := &cachedReader{ + reader: outbound.Reader.(*pipe.Reader), + } + outbound.Reader = cReader + result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly) + if err == nil { + content.Protocol = result.Protocol() + } + if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { + domain := result.Domain() + newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) + destination.Address = net.ParseAddress(domain) + if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { + ob.RouteTarget = destination + } else { + ob.Target = destination + } + } + d.routedDispatch(ctx, outbound, destination) + }() + } + return nil +} + func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) { payload := buf.New() defer payload.Release() diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index bef40a3a..cc8b4226 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -147,7 +147,7 @@ func (w *BridgeWorker) Connections() uint32 { return w.worker.ActiveConnections() } -func (w *BridgeWorker) handleInternalConn(link transport.Link) { +func (w *BridgeWorker) handleInternalConn(link *transport.Link) { go func() { reader := link.Reader for { @@ -181,7 +181,7 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra uplinkReader, uplinkWriter := pipe.New(opt...) downlinkReader, downlinkWriter := pipe.New(opt...) - w.handleInternalConn(transport.Link{ + w.handleInternalConn(&transport.Link{ Reader: downlinkReader, Writer: uplinkWriter, }) @@ -191,3 +191,16 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra Writer: downlinkWriter, }, nil } + +func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error { + if !isInternalDomain(dest) { + ctx = session.ContextWithInbound(ctx, &session.Inbound{ + Tag: w.tag, + }) + return w.dispatcher.DispatchLink(ctx, dest, link) + } + + w.handleInternalConn(link) + + return nil +} diff --git a/common/mux/server.go b/common/mux/server.go index e21cb618..3a913098 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -56,6 +56,15 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*transport return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil } +// DispatchLink implements routing.Dispatcher +func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error { + if dest.Address != muxCoolAddress { + return s.dispatcher.DispatchLink(ctx, dest, link) + } + _, err := NewServerWorker(ctx, s.dispatcher, link) + return err +} + // Start implements common.Runnable. func (s *Server) Start() error { return nil diff --git a/features/routing/dispatcher.go b/features/routing/dispatcher.go index cfc3111a..53d3bf90 100644 --- a/features/routing/dispatcher.go +++ b/features/routing/dispatcher.go @@ -17,6 +17,7 @@ type Dispatcher interface { // Dispatch returns a Ray for transporting data for the given request. Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) + DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error } // DispatcherType returns the type of Dispatcher interface. Can be used to implement common.HasType. diff --git a/transport/internet/udp/dispatcher_test.go b/transport/internet/udp/dispatcher_test.go index a9e3c1fd..d33a47be 100644 --- a/transport/internet/udp/dispatcher_test.go +++ b/transport/internet/udp/dispatcher_test.go @@ -24,6 +24,10 @@ func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*t return d.OnDispatch(ctx, dest) } +func (d *TestDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { + return nil +} + func (d *TestDispatcher) Start() error { return nil }