mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-12-22 11:23:32 +02:00
Add DispatchLink
This commit is contained in:
parent
625cf7361a
commit
50e576081e
5 changed files with 90 additions and 2 deletions
|
@ -271,6 +271,67 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
// DispatchLink implements routing.Dispatcher.
|
||||
func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
|
||||
if !destination.IsValid() {
|
||||
return newError("Dispatcher: Invalid destination.")
|
||||
}
|
||||
ob := &session.Outbound{
|
||||
Target: destination,
|
||||
}
|
||||
ctx = session.ContextWithOutbound(ctx, ob)
|
||||
content := session.ContentFromContext(ctx)
|
||||
if content == nil {
|
||||
content = new(session.Content)
|
||||
ctx = session.ContextWithContent(ctx, content)
|
||||
}
|
||||
sniffingRequest := content.SniffingRequest
|
||||
switch {
|
||||
case !sniffingRequest.Enabled:
|
||||
go d.routedDispatch(ctx, outbound, destination)
|
||||
case destination.Network != net.Network_TCP:
|
||||
// Only metadata sniff will be used for non tcp connection
|
||||
result, err := sniffer(ctx, nil, true)
|
||||
if err == nil {
|
||||
content.Protocol = result.Protocol()
|
||||
if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
|
||||
domain := result.Domain()
|
||||
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
||||
destination.Address = net.ParseAddress(domain)
|
||||
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
||||
ob.RouteTarget = destination
|
||||
} else {
|
||||
ob.Target = destination
|
||||
}
|
||||
}
|
||||
}
|
||||
go d.routedDispatch(ctx, outbound, destination)
|
||||
default:
|
||||
go func() {
|
||||
cReader := &cachedReader{
|
||||
reader: outbound.Reader.(*pipe.Reader),
|
||||
}
|
||||
outbound.Reader = cReader
|
||||
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly)
|
||||
if err == nil {
|
||||
content.Protocol = result.Protocol()
|
||||
}
|
||||
if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
|
||||
domain := result.Domain()
|
||||
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
||||
destination.Address = net.ParseAddress(domain)
|
||||
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
||||
ob.RouteTarget = destination
|
||||
} else {
|
||||
ob.Target = destination
|
||||
}
|
||||
}
|
||||
d.routedDispatch(ctx, outbound, destination)
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) {
|
||||
payload := buf.New()
|
||||
defer payload.Release()
|
||||
|
|
|
@ -147,7 +147,7 @@ func (w *BridgeWorker) Connections() uint32 {
|
|||
return w.worker.ActiveConnections()
|
||||
}
|
||||
|
||||
func (w *BridgeWorker) handleInternalConn(link transport.Link) {
|
||||
func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
|
||||
go func() {
|
||||
reader := link.Reader
|
||||
for {
|
||||
|
@ -181,7 +181,7 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra
|
|||
uplinkReader, uplinkWriter := pipe.New(opt...)
|
||||
downlinkReader, downlinkWriter := pipe.New(opt...)
|
||||
|
||||
w.handleInternalConn(transport.Link{
|
||||
w.handleInternalConn(&transport.Link{
|
||||
Reader: downlinkReader,
|
||||
Writer: uplinkWriter,
|
||||
})
|
||||
|
@ -191,3 +191,16 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra
|
|||
Writer: downlinkWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
|
||||
if !isInternalDomain(dest) {
|
||||
ctx = session.ContextWithInbound(ctx, &session.Inbound{
|
||||
Tag: w.tag,
|
||||
})
|
||||
return w.dispatcher.DispatchLink(ctx, dest, link)
|
||||
}
|
||||
|
||||
w.handleInternalConn(link)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -56,6 +56,15 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*transport
|
|||
return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil
|
||||
}
|
||||
|
||||
// DispatchLink implements routing.Dispatcher
|
||||
func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
|
||||
if dest.Address != muxCoolAddress {
|
||||
return s.dispatcher.DispatchLink(ctx, dest, link)
|
||||
}
|
||||
_, err := NewServerWorker(ctx, s.dispatcher, link)
|
||||
return err
|
||||
}
|
||||
|
||||
// Start implements common.Runnable.
|
||||
func (s *Server) Start() error {
|
||||
return nil
|
||||
|
|
|
@ -17,6 +17,7 @@ type Dispatcher interface {
|
|||
|
||||
// Dispatch returns a Ray for transporting data for the given request.
|
||||
Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error)
|
||||
DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error
|
||||
}
|
||||
|
||||
// DispatcherType returns the type of Dispatcher interface. Can be used to implement common.HasType.
|
||||
|
|
|
@ -24,6 +24,10 @@ func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*t
|
|||
return d.OnDispatch(ctx, dest)
|
||||
}
|
||||
|
||||
func (d *TestDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *TestDispatcher) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue