From efd32b0fb2a4f1b2d2e98b21916cfd9aa0e0b497 Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Wed, 3 May 2023 22:21:45 -0400 Subject: [PATCH] Enable splice for freedom outbound (downlink only) - Add outbound name - Add outbound conn in ctx - Refactor splice: it can be turn on from all inbounds and outbounds - Refactor splice: Add splice copy to vless inbound - Fix http error test - Add freedom splice toggle via env var - Populate outbound obj in context - Use CanSpliceCopy to mark a connection - Turn off splice by default --- app/dispatcher/default.go | 20 +++--- app/proxyman/inbound/worker.go | 6 +- app/proxyman/outbound/handler.go | 7 +- common/buf/copy.go | 12 ++++ common/session/session.go | 14 ++++ proxy/blackhole/blackhole.go | 6 ++ proxy/dns/dns.go | 1 + proxy/dokodemo/dokodemo.go | 9 ++- proxy/errors.generated.go | 9 +++ proxy/freedom/freedom.go | 27 ++++++-- proxy/http/client.go | 5 ++ proxy/http/server.go | 9 ++- proxy/loopback/loopback.go | 1 + proxy/proxy.go | 86 +++++++++++++++++++++++++ proxy/shadowsocks/client.go | 5 ++ proxy/shadowsocks/server.go | 10 ++- proxy/shadowsocks_2022/inbound.go | 1 + proxy/shadowsocks_2022/inbound_multi.go | 1 + proxy/shadowsocks_2022/inbound_relay.go | 1 + proxy/shadowsocks_2022/outbound.go | 2 + proxy/socks/client.go | 5 ++ proxy/socks/server.go | 10 +-- proxy/trojan/client.go | 5 ++ proxy/trojan/server.go | 4 +- proxy/vless/encoding/encoding.go | 68 ++++++------------- proxy/vless/inbound/inbound.go | 35 ++-------- proxy/vless/outbound/outbound.go | 47 ++++++-------- proxy/vmess/inbound/inbound.go | 4 +- proxy/vmess/outbound/outbound.go | 16 +++-- proxy/wireguard/wireguard.go | 14 ++-- testing/scenarios/http_test.go | 6 +- testing/scenarios/vmess_test.go | 4 +- 32 files changed, 282 insertions(+), 168 deletions(-) create mode 100644 proxy/errors.generated.go diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index aaa9b410..bfc43608 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -218,11 +218,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin if !destination.IsValid() { panic("Dispatcher: Invalid destination.") } - ob := &session.Outbound{ - OriginalTarget: destination, - Target: destination, + ob := session.OutboundFromContext(ctx) + if ob == nil { + ob = &session.Outbound{} + ctx = session.ContextWithOutbound(ctx, ob) } - ctx = session.ContextWithOutbound(ctx, ob) + ob.OriginalTarget = destination + ob.Target = destination content := session.ContentFromContext(ctx) if content == nil { content = new(session.Content) @@ -271,11 +273,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De if !destination.IsValid() { return newError("Dispatcher: Invalid destination.") } - ob := &session.Outbound{ - OriginalTarget: destination, - Target: destination, + ob := session.OutboundFromContext(ctx) + if ob == nil { + ob = &session.Outbound{} + ctx = session.ContextWithOutbound(ctx, ob) } - ctx = session.ContextWithOutbound(ctx, ob) + ob.OriginalTarget = destination + ob.Target = destination content := session.ContentFromContext(ctx) if content == nil { content = new(session.Content) diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 8ed4090a..1fe86655 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -60,6 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) { sid := session.NewID() ctx = session.ContextWithID(ctx, sid) + var outbound = &session.Outbound{} if w.recvOrigDest { var dest net.Destination switch getTProxyType(w.stream) { @@ -74,11 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) { dest = net.DestinationFromAddr(conn.LocalAddr()) } if dest.IsValid() { - ctx = session.ContextWithOutbound(ctx, &session.Outbound{ - Target: dest, - }) + outbound.Target = dest } } + ctx = session.ContextWithOutbound(ctx, outbound) if w.uplinkCounter != nil || w.downlinkCounter != nil { conn = &stat.CounterConnection{ diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index adf6537a..d290b016 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -274,7 +274,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti } conn, err := internet.Dial(ctx, dest, h.streamSettings) - return h.getStatCouterConnection(conn), err + conn = h.getStatCouterConnection(conn) + outbound := session.OutboundFromContext(ctx) + if outbound != nil { + outbound.Conn = conn + } + return conn, err } func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection { diff --git a/common/buf/copy.go b/common/buf/copy.go index 601771be..3096dc57 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -6,6 +6,7 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/signal" + "github.com/xtls/xray-core/features/stats" ) type dataHandler func(MultiBuffer) @@ -40,6 +41,17 @@ func CountSize(sc *SizeCounter) CopyOption { } } +// AddToStatCounter a CopyOption add to stat counter +func AddToStatCounter(sc stats.Counter) CopyOption { + return func(handler *copyHandler) { + handler.onData = append(handler.onData, func(b MultiBuffer) { + if sc != nil { + sc.Add(int64(b.Len())) + } + }) + } +} + type readError struct { error } diff --git a/common/session/session.go b/common/session/session.go index b9609e86..4af61015 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -50,6 +50,16 @@ type Inbound struct { Conn net.Conn // Timer of the inbound buf copier. May be nil. Timer *signal.ActivityTimer + // CanSpliceCopy is a property for this connection, set by both inbound and outbound + // 1 = can, 2 = after processing protocol info should be able to, 3 = cannot + CanSpliceCopy int +} + +func(i *Inbound) SetCanSpliceCopy(canSpliceCopy int) int { + if canSpliceCopy > i.CanSpliceCopy { + i.CanSpliceCopy = canSpliceCopy + } + return i.CanSpliceCopy } // Outbound is the metadata of an outbound connection. @@ -60,6 +70,10 @@ type Outbound struct { RouteTarget net.Destination // Gateway address Gateway net.Address + // Name of the outbound proxy that handles the connection. + Name string + // Conn is actually internet.Connection. May be nil. It is currently nil for outbound with proxySettings + Conn net.Conn } // SniffingRequest controls the behavior of content sniffing. diff --git a/proxy/blackhole/blackhole.go b/proxy/blackhole/blackhole.go index b17c60c4..4b819417 100644 --- a/proxy/blackhole/blackhole.go +++ b/proxy/blackhole/blackhole.go @@ -8,6 +8,7 @@ import ( "time" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" ) @@ -30,6 +31,11 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + outbound := session.OutboundFromContext(ctx) + if outbound != nil { + outbound.Name = "blackhole" + } + nBytes := h.response.WriteTo(link.Writer) if nBytes > 0 { // Sleep a little here to make sure the response is sent to client. diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index d8a3244d..415fe991 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -96,6 +96,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. if outbound == nil || !outbound.Target.IsValid() { return newError("invalid outbound") } + outbound.Name = "dns" srcNetwork := outbound.Target.Network diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 42d8256f..4a4735e8 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -102,11 +102,10 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st } inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.Name = "dokodemo-door" - inbound.User = &protocol.MemoryUser{ - Level: d.config.UserLevel, - } + inbound.Name = "dokodemo-door" + inbound.SetCanSpliceCopy(1) + inbound.User = &protocol.MemoryUser{ + Level: d.config.UserLevel, } ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ diff --git a/proxy/errors.generated.go b/proxy/errors.generated.go new file mode 100644 index 00000000..1a643896 --- /dev/null +++ b/proxy/errors.generated.go @@ -0,0 +1,9 @@ +package proxy + +import "github.com/xtls/xray-core/common/errors" + +type errPathObjHolder struct{} + +func newError(values ...interface{}) *errors.Error { + return errors.New(values...).WithPathObj(errPathObjHolder{}) +} diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index c6907b4c..808f837f 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -13,6 +13,7 @@ import ( "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/common/retry" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/signal" @@ -21,11 +22,14 @@ import ( "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/stats" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" ) +var useSplice bool + func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { h := new(Handler) @@ -36,6 +40,12 @@ func init() { } return h, nil })) + const defaultFlagValue = "NOT_DEFINED_AT_ALL" + value := platform.NewEnvFlag("xray.buf.splice").GetValue(func() string { return defaultFlagValue }) + switch value { + case "auto", "enable": + useSplice = true + } } // Handler handles Freedom connections. @@ -107,6 +117,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified.") } + outbound.Name = "freedom" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(1) + } destination := outbound.Target UDPOverride := net.UDPDestination(nil, 0) if h.config.DestinationOverride != nil { @@ -195,17 +210,17 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte responseDone := func() error { defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) - - var reader buf.Reader if destination.Network == net.Network_TCP { - reader = buf.NewReader(conn) - } else { - reader = NewPacketReader(conn, UDPOverride) + var writeConn net.Conn + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && useSplice { + writeConn = inbound.Conn + } + return proxy.CopyRawConnIfExist(ctx, conn, writeConn, link.Writer, timer) } + reader := NewPacketReader(conn, UDPOverride) if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil { return newError("failed to process response").Base(err) } - return nil } diff --git a/proxy/http/client.go b/proxy/http/client.go index f597a502..302e521d 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -73,6 +73,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified.") } + outbound.Name = "http" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(2) + } target := outbound.Target targetAddr := target.NetAddr() diff --git a/proxy/http/server.go b/proxy/http/server.go index 6b00fe2b..511d9b08 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -84,11 +84,10 @@ type readerOnly struct { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.Name = "http" - inbound.User = &protocol.MemoryUser{ - Level: s.config.UserLevel, - } + inbound.Name = "http" + inbound.SetCanSpliceCopy(2) + inbound.User = &protocol.MemoryUser{ + Level: s.config.UserLevel, } reader := bufio.NewReaderSize(readerOnly{conn}, buf.Size) diff --git a/proxy/loopback/loopback.go b/proxy/loopback/loopback.go index 946847f3..30c39bd9 100644 --- a/proxy/loopback/loopback.go +++ b/proxy/loopback/loopback.go @@ -26,6 +26,7 @@ func (l *Loopback) Process(ctx context.Context, link *transport.Link, _ internet if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified.") } + outbound.Name = "loopback" destination := outbound.Target newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/proxy.go b/proxy/proxy.go index fb52605c..12b9631b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -7,13 +7,24 @@ package proxy import ( "context" + gotls "crypto/tls" + "io" + "runtime" + "github.com/pires/go-proxyproto" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/transport/internet/tls" ) // An Inbound processes inbound connections. @@ -47,3 +58,78 @@ type GetInbound interface { type GetOutbound interface { GetOutbound() Outbound } + +// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it +func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) { + var readCounter, writerCounter stats.Counter + if conn != nil { + statConn, ok := conn.(*stat.CounterConnection) + if ok { + conn = statConn.Connection + readCounter = statConn.ReadCounter + writerCounter = statConn.WriteCounter + } + if xc, ok := conn.(*gotls.Conn); ok { + conn = xc.NetConn() + } else if utlsConn, ok := conn.(*tls.UConn); ok { + conn = utlsConn.NetConn() + } else if realityConn, ok := conn.(*reality.Conn); ok { + conn = realityConn.NetConn() + } else if realityUConn, ok := conn.(*reality.UConn); ok { + conn = realityUConn.NetConn() + } + if pc, ok := conn.(*proxyproto.Conn); ok { + conn = pc.Raw() + // 8192 > 4096, there is no need to process pc's bufReader + } + } + return conn, readCounter, writerCounter +} + +// CopyRawConnIfExist use the most efficient copy method. +// - If caller don't want to turn on splice, do not pass in both reader conn and writer conn +// - writer are from *transport.Link +func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net.Conn, writer buf.Writer, timer signal.ActivityUpdater) error { + readerConn, readCounter, _ := UnwrapRawConn(readerConn) + writerConn, _, writeCounter := UnwrapRawConn(writerConn) + reader := buf.NewReader(readerConn) + if inbound := session.InboundFromContext(ctx); inbound != nil { + if tc, ok := writerConn.(*net.TCPConn); ok && readerConn != nil && writerConn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { + for inbound.CanSpliceCopy != 3 { + if inbound.CanSpliceCopy == 1 { + newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx)) + runtime.Gosched() // necessary + w, err := tc.ReadFrom(readerConn) + if readCounter != nil { + readCounter.Add(w) + } + if writeCounter != nil { + writeCounter.Add(w) + } + if err != nil && errors.Cause(err) != io.EOF { + return err + } + return nil + } + buffer, err := reader.ReadMultiBuffer() + if !buffer.IsEmpty() { + if readCounter != nil { + readCounter.Add(int64(buffer.Len())) + } + timer.Update() + if werr := writer.WriteMultiBuffer(buffer); werr != nil { + return werr + } + } + if err != nil { + return err + } + } + } + } + newError("CopyRawConn readv").WriteToLog(session.ExportIDToError(ctx)) + if err := buf.Copy(reader, writer, buf.UpdateActivity(timer), buf.AddToStatCounter(readCounter)); err != nil { + return newError("failed to process response").Base(err) + } + return nil +} diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index e22b11c7..57d8f81c 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -53,6 +53,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified") } + outbound.Name = "shadowsocks" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(3) + } destination := outbound.Target network := destination.Network diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 1d89db5e..2975ba70 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -71,6 +71,10 @@ func (s *Server) Network() []net.Network { } func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { + inbound := session.InboundFromContext(ctx) + inbound.Name = "shadowsocks" + inbound.SetCanSpliceCopy(3) + switch network { case net.Network_TCP: return s.handleConnection(ctx, conn, dispatcher) @@ -110,13 +114,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis }) inbound := session.InboundFromContext(ctx) - if inbound == nil { - panic("no inbound metadata") - } - inbound.Name = "shadowsocks" - var dest *net.Destination - reader := buf.NewPacketReader(conn) for { mpayload, err := reader.ReadMultiBuffer() diff --git a/proxy/shadowsocks_2022/inbound.go b/proxy/shadowsocks_2022/inbound.go index bb298c09..246fc7f1 100644 --- a/proxy/shadowsocks_2022/inbound.go +++ b/proxy/shadowsocks_2022/inbound.go @@ -66,6 +66,7 @@ func (i *Inbound) Network() []net.Network { func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022" + inbound.SetCanSpliceCopy(3) var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_multi.go b/proxy/shadowsocks_2022/inbound_multi.go index c9927476..c3832a91 100644 --- a/proxy/shadowsocks_2022/inbound_multi.go +++ b/proxy/shadowsocks_2022/inbound_multi.go @@ -155,6 +155,7 @@ func (i *MultiUserInbound) Network() []net.Network { func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-multi" + inbound.SetCanSpliceCopy(3) var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_relay.go b/proxy/shadowsocks_2022/inbound_relay.go index c3f8e675..e2cb7d50 100644 --- a/proxy/shadowsocks_2022/inbound_relay.go +++ b/proxy/shadowsocks_2022/inbound_relay.go @@ -87,6 +87,7 @@ func (i *RelayInbound) Network() []net.Network { func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) inbound.Name = "shadowsocks-2022-relay" + inbound.SetCanSpliceCopy(3) var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/outbound.go b/proxy/shadowsocks_2022/outbound.go index 151ea0e2..a06daac7 100644 --- a/proxy/shadowsocks_2022/outbound.go +++ b/proxy/shadowsocks_2022/outbound.go @@ -66,12 +66,14 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int inbound := session.InboundFromContext(ctx) if inbound != nil { inboundConn = inbound.Conn + inbound.SetCanSpliceCopy(3) } outbound := session.OutboundFromContext(ctx) if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified") } + outbound.Name = "shadowsocks-2022" destination := outbound.Target network := destination.Network diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 1993aa0b..82591be4 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -61,6 +61,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified.") } + outbound.Name = "socks" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(2) + } // Destination of the inner request. destination := outbound.Target diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 184ecd08..6964fdf2 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -63,11 +63,11 @@ func (s *Server) Network() []net.Network { // Process implements proxy.Inbound. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { - if inbound := session.InboundFromContext(ctx); inbound != nil { - inbound.Name = "socks" - inbound.User = &protocol.MemoryUser{ - Level: s.config.UserLevel, - } + inbound := session.InboundFromContext(ctx) + inbound.Name = "socks" + inbound.SetCanSpliceCopy(2) + inbound.User = &protocol.MemoryUser{ + Level: s.config.UserLevel, } switch network { diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index 0c6f16d3..d6b95fc0 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -54,6 +54,11 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified") } + outbound.Name = "trojan" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(3) + } destination := outbound.Target network := destination.Network diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 41245ba4..5c3fcd91 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -214,10 +214,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con } inbound := session.InboundFromContext(ctx) - if inbound == nil { - panic("no inbound metadata") - } inbound.Name = "trojan" + inbound.SetCanSpliceCopy(3) inbound.User = user sessionPolicy = s.policyManager.ForLevel(user.Level) diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index cf962492..48bda497 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -8,9 +8,7 @@ import ( "crypto/rand" "io" "math/big" - "runtime" "strconv" - "syscall" "time" "github.com/xtls/xray-core/common/buf" @@ -20,10 +18,8 @@ import ( "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/features/stats" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy/vless" - "github.com/xtls/xray-core/transport/internet/reality" - "github.com/xtls/xray-core/transport/internet/stat" - "github.com/xtls/xray-core/transport/internet/tls" ) const ( @@ -206,13 +202,11 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn, - input *bytes.Reader, rawInput *bytes.Buffer, - counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, +func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, + ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32, ) error { err := func() error { - var ct stats.Counter withinPaddingBuffers := true shouldSwitchToDirectCopy := false var remainingContent int32 = -1 @@ -220,40 +214,14 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater currentCommand := 0 for { if shouldSwitchToDirectCopy { - shouldSwitchToDirectCopy = false - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - if _, ok := inbound.User.Account.(*vless.MemoryAccount); inbound.User.Account == nil || ok { - iConn := inbound.Conn - statConn, ok := iConn.(*stat.CounterConnection) - if ok { - iConn = statConn.Connection - } - if tlsConn, ok := iConn.(*tls.Conn); ok { - iConn = tlsConn.NetConn() - } else if realityConn, ok := iConn.(*reality.Conn); ok { - iConn = realityConn.NetConn() - } - if tc, ok := iConn.(*net.TCPConn); ok { - newError("XtlsRead splice").WriteToLog(session.ExportIDToError(ctx)) - runtime.Gosched() // necessary - w, err := tc.ReadFrom(conn) - if counter != nil { - counter.Add(w) - } - if statConn != nil && statConn.WriteCounter != nil { - statConn.WriteCounter.Add(w) - } - return err - } + var writerConn net.Conn + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil { + writerConn = inbound.Conn + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter } } - if rawConn != nil { - reader = buf.NewReadVReader(conn, rawConn, nil) - } else { - reader = buf.NewReader(conn) - } - ct = counter - newError("XtlsRead readV").WriteToLog(session.ExportIDToError(ctx)) + return proxy.CopyRawConnIfExist(ctx, conn, writerConn, writer, timer) } buffer, err := reader.ReadMultiBuffer() if !buffer.IsEmpty() { @@ -292,9 +260,6 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater if *numberOfPacketToFilter > 0 { XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx) } - if ct != nil { - ct.Add(int64(buffer.Len())) - } timer.Update() if werr := writer.WriteMultiBuffer(buffer); werr != nil { return werr @@ -312,7 +277,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter, +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, ctx context.Context, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32, ) error { @@ -349,18 +314,21 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate } if shouldSwitchToDirectCopy { encryptBuffer, directBuffer := buf.SplitMulti(buffer, xtlsSpecIndex+1) - length := encryptBuffer.Len() if !encryptBuffer.IsEmpty() { timer.Update() if werr := writer.WriteMultiBuffer(encryptBuffer); werr != nil { return werr } } - buffer = directBuffer - writer = buf.NewWriter(conn) - ct = counter - newError("XtlsWrite writeV ", xtlsSpecIndex, " ", length, " ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) time.Sleep(5 * time.Millisecond) // for some device, the first xtls direct packet fails without this delay + + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 // force the value to 1, don't use setter + } + buffer = directBuffer + rawConn, _, writerCounter := proxy.UnwrapRawConn(conn) + writer = buf.NewWriter(rawConn) + ct = writerCounter } } if !buffer.IsEmpty() { diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 8653e1e3..388aeecb 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -10,11 +10,9 @@ import ( "reflect" "strconv" "strings" - "syscall" "time" "unsafe" - "github.com/pires/go-proxyproto" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -30,7 +28,6 @@ import ( feature_inbound "github.com/xtls/xray-core/features/inbound" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" - "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/transport/internet/reality" @@ -182,8 +179,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s sid := session.ExportIDToError(ctx) iConn := connection - statConn, ok := iConn.(*stat.CounterConnection) - if ok { + if statConn, ok := iConn.(*stat.CounterConnection); ok { iConn = statConn.Connection } @@ -447,14 +443,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s // Flow: requestAddons.Flow, } - var netConn net.Conn - var rawConn syscall.RawConn var input *bytes.Reader var rawInput *bytes.Buffer - switch requestAddons.Flow { case vless.XRV: if account.Flow == requestAddons.Flow { + inbound.SetCanSpliceCopy(2) switch request.Command { case protocol.RequestCommandUDP: return newError(requestAddons.Flow + " doesn't support UDP").AtWarning() @@ -467,23 +461,14 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if tlsConn.ConnectionState().Version != gotls.VersionTLS13 { return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, tlsConn.ConnectionState().Version).AtWarning() } - netConn = tlsConn.NetConn() t = reflect.TypeOf(tlsConn.Conn).Elem() p = uintptr(unsafe.Pointer(tlsConn.Conn)) } else if realityConn, ok := iConn.(*reality.Conn); ok { - netConn = realityConn.NetConn() t = reflect.TypeOf(realityConn.Conn).Elem() p = uintptr(unsafe.Pointer(realityConn.Conn)) } else { return newError("XTLS only supports TLS and REALITY directly for now.").AtWarning() } - if pc, ok := netConn.(*proxyproto.Conn); ok { - netConn = pc.Raw() - // 8192 > 4096, there is no need to process pc's bufReader - } - if sc, ok := netConn.(syscall.Conn); ok { - rawConn, _ = sc.SyscallConn() - } i, _ := t.FieldByName("input") r, _ := t.FieldByName("rawInput") input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) @@ -493,6 +478,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning() } case "": + inbound.SetCanSpliceCopy(3) if account.Flow == vless.XRV && (request.Command == protocol.RequestCommandTCP || isMuxAndNotXUDP(request, first)) { return newError(account.ID.String() + " is not able to use \"\". Note that the pure TLS proxy has certain TLS in TLS characters.").AtWarning() } @@ -540,13 +526,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - var counter stats.Counter - if statConn != nil { - counter = statConn.ReadCounter - } - // TODO enable splice - ctx = session.ContextWithInbound(ctx, nil) - err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, input, rawInput, counter, ctx, account.ID.Bytes(), + ctx1 := session.ContextWithInbound(ctx, nil) // TODO enable splice + err = encoding.XtlsRead(clientReader, serverWriter, timer, connection, input, rawInput, ctx1, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer @@ -592,11 +573,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s var err error if requestAddons.Flow == vless.XRV { - var counter stats.Counter - if statConn != nil { - counter = statConn.WriteCounter - } - err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &numberOfPacketToFilter, + err = encoding.XtlsWrite(serverReader, clientWriter, timer, connection, ctx, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 12962a47..bc2e6625 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -7,7 +7,6 @@ import ( "context" gotls "crypto/tls" "reflect" - "syscall" "time" "unsafe" @@ -23,7 +22,6 @@ import ( "github.com/xtls/xray-core/common/xudp" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" - "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless/encoding" "github.com/xtls/xray-core/transport" @@ -71,9 +69,15 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { + return newError("target not specified").AtError() + } + outbound.Name = "vless" + inbound := session.InboundFromContext(ctx) + var rec *protocol.ServerSpec var conn stat.Connection - if err := retry.ExponentialBackoff(5, 200).On(func() error { rec = h.serverPicker.PickServer() var err error @@ -88,16 +92,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte defer conn.Close() iConn := conn - statConn, ok := iConn.(*stat.CounterConnection) - if ok { + if statConn, ok := iConn.(*stat.CounterConnection); ok { iConn = statConn.Connection } - - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { - return newError("target not specified").AtError() - } - target := outbound.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx)) @@ -123,8 +120,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte Flow: account.Flow, } - var netConn net.Conn - var rawConn syscall.RawConn var input *bytes.Reader var rawInput *bytes.Buffer allowUDP443 := false @@ -134,6 +129,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte requestAddons.Flow = requestAddons.Flow[:16] fallthrough case vless.XRV: + if inbound != nil { + inbound.SetCanSpliceCopy(2) + } switch request.Command { case protocol.RequestCommandUDP: if !allowUDP443 && request.Port == 443 { @@ -146,28 +144,26 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte var t reflect.Type var p uintptr if tlsConn, ok := iConn.(*tls.Conn); ok { - netConn = tlsConn.NetConn() t = reflect.TypeOf(tlsConn.Conn).Elem() p = uintptr(unsafe.Pointer(tlsConn.Conn)) } else if utlsConn, ok := iConn.(*tls.UConn); ok { - netConn = utlsConn.NetConn() t = reflect.TypeOf(utlsConn.Conn).Elem() p = uintptr(unsafe.Pointer(utlsConn.Conn)) } else if realityConn, ok := iConn.(*reality.UConn); ok { - netConn = realityConn.NetConn() t = reflect.TypeOf(realityConn.Conn).Elem() p = uintptr(unsafe.Pointer(realityConn.Conn)) } else { return newError("XTLS only supports TLS and REALITY directly for now.").AtWarning() } - if sc, ok := netConn.(syscall.Conn); ok { - rawConn, _ = sc.SyscallConn() - } i, _ := t.FieldByName("input") r, _ := t.FieldByName("rawInput") input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) } + default: + if inbound != nil { + inbound.SetCanSpliceCopy(3) + } } var newCtx context.Context @@ -257,11 +253,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning() } } - var counter stats.Counter - if statConn != nil { - counter = statConn.WriteCounter - } - err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &numberOfPacketToFilter, + ctx1 := session.ContextWithOutbound(ctx, nil) // TODO enable splice + err = encoding.XtlsWrite(clientReader, serverWriter, timer, conn, ctx1, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer @@ -293,11 +286,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } if requestAddons.Flow == vless.XRV { - var counter stats.Counter - if statConn != nil { - counter = statConn.ReadCounter - } - err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, input, rawInput, counter, ctx, account.ID.Bytes(), + err = encoding.XtlsRead(serverReader, clientWriter, timer, conn, input, rawInput, ctx, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index f48a26e1..679ea5da 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -256,10 +256,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s } inbound := session.InboundFromContext(ctx) - if inbound == nil { - panic("no inbound metadata") - } inbound.Name = "vmess" + inbound.SetCanSpliceCopy(3) inbound.User = request.User sessionPolicy = h.policyManager.ForLevel(request.User.Level) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index fc77f07f..5e228d68 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -60,9 +60,18 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Process implements proxy.Outbound.Process(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { + return newError("target not specified").AtError() + } + outbound.Name = "vmess" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(3) + } + var rec *protocol.ServerSpec var conn stat.Connection - err := retry.ExponentialBackoff(5, 200).On(func() error { rec = h.serverPicker.PickServer() rawConn, err := dialer.Dial(ctx, rec.Destination()) @@ -78,11 +87,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } defer conn.Close() - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { - return newError("target not specified").AtError() - } - target := outbound.Target newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/wireguard/wireguard.go b/proxy/wireguard/wireguard.go index 53e7dcd5..899dcac5 100644 --- a/proxy/wireguard/wireguard.go +++ b/proxy/wireguard/wireguard.go @@ -75,6 +75,16 @@ func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { // Process implements OutboundHandler.Dispatch(). func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { + return newError("target not specified") + } + outbound.Name = "wireguard" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(3) + } + if h.bind == nil || h.bind.dialer != dialer || h.net == nil { log.Record(&log.GeneralMessage{ Severity: log.Severity_Info, @@ -101,10 +111,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte h.bind = bind } - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { - return newError("target not specified") - } // Destination of the inner request. destination := outbound.Target command := protocol.RequestCommandTCP diff --git a/testing/scenarios/http_test.go b/testing/scenarios/http_test.go index d6a765bb..b9b112ff 100644 --- a/testing/scenarios/http_test.go +++ b/testing/scenarios/http_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "strings" "testing" "time" @@ -128,9 +129,8 @@ func TestHttpError(t *testing.T) { } resp, err := client.Get("http://127.0.0.1:" + dest.Port.String()) - common.Must(err) - if resp.StatusCode != 503 { - t.Error("status: ", resp.StatusCode) + if resp != nil && resp.StatusCode != 503 || err != nil && !strings.Contains(err.Error(), "malformed HTTP status code") { + t.Error("should not receive http response", err) } } } diff --git a/testing/scenarios/vmess_test.go b/testing/scenarios/vmess_test.go index 9f2b0abc..2239b13c 100644 --- a/testing/scenarios/vmess_test.go +++ b/testing/scenarios/vmess_test.go @@ -1174,10 +1174,10 @@ func TestVMessGCMMuxUDP(t *testing.T) { servers, err := InitializeServerConfigs(serverConfig, clientConfig) common.Must(err) - for range "abcd" { + for range "ab" { var errg errgroup.Group for i := 0; i < 16; i++ { - errg.Go(testTCPConn(clientPort, 10240, time.Second*20)) + errg.Go(testTCPConn(clientPort, 1024, time.Second*10)) errg.Go(testUDPConn(clientUDPPort, 1024, time.Second*10)) } if err := errg.Wait(); err != nil {