mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-25 06:09:19 +02:00
Fix buffer.UDP destination override (#2356)
This commit is contained in:
parent
e013dce1df
commit
b8bd243df5
|
@ -4,7 +4,6 @@ package dispatcher
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -135,77 +134,10 @@ func (*DefaultDispatcher) Start() error {
|
||||||
// Close implements common.Closable.
|
// Close implements common.Closable.
|
||||||
func (*DefaultDispatcher) Close() error { return nil }
|
func (*DefaultDispatcher) Close() error { return nil }
|
||||||
|
|
||||||
func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
|
func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
|
||||||
downOpt := pipe.OptionsFromContext(ctx)
|
opt := pipe.OptionsFromContext(ctx)
|
||||||
upOpt := downOpt
|
uplinkReader, uplinkWriter := pipe.New(opt...)
|
||||||
|
downlinkReader, downlinkWriter := pipe.New(opt...)
|
||||||
if network == net.Network_UDP {
|
|
||||||
var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
|
|
||||||
// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
|
|
||||||
// When target replies, server will restore the domain and send back to client.
|
|
||||||
// Note: this map is not global but per connection context
|
|
||||||
upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
|
|
||||||
for i, buffer := range mb {
|
|
||||||
if buffer.UDP == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addr := buffer.UDP.Address
|
|
||||||
if addr.Family().IsIP() {
|
|
||||||
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
|
|
||||||
domain := fkr0.GetDomainFromFakeDNS(addr)
|
|
||||||
if len(domain) > 0 {
|
|
||||||
buffer.UDP.Address = net.DomainAddress(domain)
|
|
||||||
newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
||||||
} else {
|
|
||||||
newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ip2domain == nil {
|
|
||||||
ip2domain = new(sync.Map)
|
|
||||||
newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
|
|
||||||
}
|
|
||||||
domain := addr.Domain()
|
|
||||||
ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
|
|
||||||
if err == nil {
|
|
||||||
for _, ip := range ips {
|
|
||||||
ip2domain.Store(ip.String(), domain)
|
|
||||||
}
|
|
||||||
newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
||||||
} else {
|
|
||||||
newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mb
|
|
||||||
}))
|
|
||||||
downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
|
|
||||||
for i, buffer := range mb {
|
|
||||||
if buffer.UDP == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addr := buffer.UDP.Address
|
|
||||||
if addr.Family().IsIP() {
|
|
||||||
if ip2domain == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if domain, found := ip2domain.Load(addr.IP().String()); found {
|
|
||||||
buffer.UDP.Address = net.DomainAddress(domain.(string))
|
|
||||||
newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
|
|
||||||
fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
|
|
||||||
buffer.UDP.Address = fakeIp[0]
|
|
||||||
newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mb
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
uplinkReader, uplinkWriter := pipe.New(upOpt...)
|
|
||||||
downlinkReader, downlinkWriter := pipe.New(downOpt...)
|
|
||||||
|
|
||||||
inboundLink := &transport.Link{
|
inboundLink := &transport.Link{
|
||||||
Reader: downlinkReader,
|
Reader: downlinkReader,
|
||||||
|
@ -263,7 +195,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu
|
||||||
protocolString = resComp.ProtocolForDomainResult()
|
protocolString = resComp.ProtocolForDomainResult()
|
||||||
}
|
}
|
||||||
for _, p := range request.OverrideDestinationForProtocol {
|
for _, p := range request.OverrideDestinationForProtocol {
|
||||||
if strings.HasPrefix(protocolString, p) {
|
if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
|
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
|
||||||
|
@ -287,7 +219,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
|
||||||
panic("Dispatcher: Invalid destination.")
|
panic("Dispatcher: Invalid destination.")
|
||||||
}
|
}
|
||||||
ob := &session.Outbound{
|
ob := &session.Outbound{
|
||||||
Target: destination,
|
OriginalTarget: destination,
|
||||||
|
Target: destination,
|
||||||
}
|
}
|
||||||
ctx = session.ContextWithOutbound(ctx, ob)
|
ctx = session.ContextWithOutbound(ctx, ob)
|
||||||
content := session.ContentFromContext(ctx)
|
content := session.ContentFromContext(ctx)
|
||||||
|
@ -295,9 +228,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
|
||||||
content = new(session.Content)
|
content = new(session.Content)
|
||||||
ctx = session.ContextWithContent(ctx, content)
|
ctx = session.ContextWithContent(ctx, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
sniffingRequest := content.SniffingRequest
|
sniffingRequest := content.SniffingRequest
|
||||||
inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
|
inbound, outbound := d.getLink(ctx)
|
||||||
if !sniffingRequest.Enabled {
|
if !sniffingRequest.Enabled {
|
||||||
go d.routedDispatch(ctx, outbound, destination)
|
go d.routedDispatch(ctx, outbound, destination)
|
||||||
} else {
|
} else {
|
||||||
|
@ -314,7 +246,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
|
||||||
domain := result.Domain()
|
domain := result.Domain()
|
||||||
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
||||||
destination.Address = net.ParseAddress(domain)
|
destination.Address = net.ParseAddress(domain)
|
||||||
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
protocol := result.Protocol()
|
||||||
|
if resComp, ok := result.(SnifferResultComposite); ok {
|
||||||
|
protocol = resComp.ProtocolForDomainResult()
|
||||||
|
}
|
||||||
|
isFakeIP := false
|
||||||
|
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
|
||||||
|
isFakeIP = true
|
||||||
|
}
|
||||||
|
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
|
||||||
ob.RouteTarget = destination
|
ob.RouteTarget = destination
|
||||||
} else {
|
} else {
|
||||||
ob.Target = destination
|
ob.Target = destination
|
||||||
|
@ -332,7 +272,8 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
|
||||||
return newError("Dispatcher: Invalid destination.")
|
return newError("Dispatcher: Invalid destination.")
|
||||||
}
|
}
|
||||||
ob := &session.Outbound{
|
ob := &session.Outbound{
|
||||||
Target: destination,
|
OriginalTarget: destination,
|
||||||
|
Target: destination,
|
||||||
}
|
}
|
||||||
ctx = session.ContextWithOutbound(ctx, ob)
|
ctx = session.ContextWithOutbound(ctx, ob)
|
||||||
content := session.ContentFromContext(ctx)
|
content := session.ContentFromContext(ctx)
|
||||||
|
@ -356,7 +297,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
|
||||||
domain := result.Domain()
|
domain := result.Domain()
|
||||||
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
|
||||||
destination.Address = net.ParseAddress(domain)
|
destination.Address = net.ParseAddress(domain)
|
||||||
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
|
protocol := result.Protocol()
|
||||||
|
if resComp, ok := result.(SnifferResultComposite); ok {
|
||||||
|
protocol = resComp.ProtocolForDomainResult()
|
||||||
|
}
|
||||||
|
isFakeIP := false
|
||||||
|
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
|
||||||
|
isFakeIP = true
|
||||||
|
}
|
||||||
|
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
|
||||||
ob.RouteTarget = destination
|
ob.RouteTarget = destination
|
||||||
} else {
|
} else {
|
||||||
ob.Target = destination
|
ob.Target = destination
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/xtls/xray-core/app/proxyman"
|
"github.com/xtls/xray-core/app/proxyman"
|
||||||
"github.com/xtls/xray-core/common"
|
"github.com/xtls/xray-core/common"
|
||||||
|
"github.com/xtls/xray-core/common/buf"
|
||||||
"github.com/xtls/xray-core/common/mux"
|
"github.com/xtls/xray-core/common/mux"
|
||||||
"github.com/xtls/xray-core/common/net"
|
"github.com/xtls/xray-core/common/net"
|
||||||
"github.com/xtls/xray-core/common/net/cnc"
|
"github.com/xtls/xray-core/common/net/cnc"
|
||||||
|
@ -166,6 +167,11 @@ func (h *Handler) Tag() string {
|
||||||
|
|
||||||
// Dispatch implements proxy.Outbound.Dispatch.
|
// Dispatch implements proxy.Outbound.Dispatch.
|
||||||
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
||||||
|
outbound := session.OutboundFromContext(ctx)
|
||||||
|
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
|
||||||
|
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
|
||||||
|
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
|
||||||
|
}
|
||||||
if h.mux != nil {
|
if h.mux != nil {
|
||||||
test := func(err error) {
|
test := func(err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -175,7 +181,6 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
||||||
common.Interrupt(link.Writer)
|
common.Interrupt(link.Writer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outbound := session.OutboundFromContext(ctx)
|
|
||||||
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
|
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
|
||||||
switch h.udp443 {
|
switch h.udp443 {
|
||||||
case "reject":
|
case "reject":
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
package buf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/xtls/xray-core/common/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EndpointOverrideReader struct {
|
||||||
|
Reader
|
||||||
|
Dest net.Address
|
||||||
|
OriginalDest net.Address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) {
|
||||||
|
mb, err := r.Reader.ReadMultiBuffer()
|
||||||
|
if err == nil {
|
||||||
|
for _, b := range mb {
|
||||||
|
if b.UDP != nil && b.UDP.Address == r.OriginalDest {
|
||||||
|
b.UDP.Address = r.Dest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mb, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type EndpointOverrideWriter struct {
|
||||||
|
Writer
|
||||||
|
Dest net.Address
|
||||||
|
OriginalDest net.Address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error {
|
||||||
|
for _, b := range mb {
|
||||||
|
if b.UDP != nil && b.UDP.Address == w.Dest {
|
||||||
|
b.UDP.Address = w.OriginalDest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w.Writer.WriteMultiBuffer(mb)
|
||||||
|
}
|
|
@ -55,8 +55,9 @@ type Inbound struct {
|
||||||
// Outbound is the metadata of an outbound connection.
|
// Outbound is the metadata of an outbound connection.
|
||||||
type Outbound struct {
|
type Outbound struct {
|
||||||
// Target address of the outbound connection.
|
// Target address of the outbound connection.
|
||||||
Target net.Destination
|
OriginalTarget net.Destination
|
||||||
RouteTarget net.Destination
|
Target net.Destination
|
||||||
|
RouteTarget net.Destination
|
||||||
// Gateway address
|
// Gateway address
|
||||||
Gateway net.Address
|
Gateway net.Address
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,6 @@ const (
|
||||||
type pipeOption struct {
|
type pipeOption struct {
|
||||||
limit int32 // maximum buffer size in bytes
|
limit int32 // maximum buffer size in bytes
|
||||||
discardOverflow bool
|
discardOverflow bool
|
||||||
onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *pipeOption) isFull(curSize int32) bool {
|
func (o *pipeOption) isFull(curSize int32) bool {
|
||||||
|
@ -141,10 +140,6 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.option.onTransmission != nil {
|
|
||||||
mb = p.option.onTransmission(mb)
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
err := p.writeMultiBufferInternal(mb)
|
err := p.writeMultiBufferInternal(mb)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
@ -3,7 +3,6 @@ package pipe
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common/buf"
|
|
||||||
"github.com/xtls/xray-core/common/signal"
|
"github.com/xtls/xray-core/common/signal"
|
||||||
"github.com/xtls/xray-core/common/signal/done"
|
"github.com/xtls/xray-core/common/signal/done"
|
||||||
"github.com/xtls/xray-core/features/policy"
|
"github.com/xtls/xray-core/features/policy"
|
||||||
|
@ -26,12 +25,6 @@ func WithSizeLimit(limit int32) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option {
|
|
||||||
return func(option *pipeOption) {
|
|
||||||
option.onTransmission = hook
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DiscardOverflow returns an Option for Pipe to discard writes if full.
|
// DiscardOverflow returns an Option for Pipe to discard writes if full.
|
||||||
func DiscardOverflow() Option {
|
func DiscardOverflow() Option {
|
||||||
return func(opt *pipeOption) {
|
return func(opt *pipeOption) {
|
||||||
|
|
Loading…
Reference in New Issue