From d92002ad127f64bc1e740cb350eafd693ffadd6d Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Sun, 27 Aug 2023 05:55:58 +0000 Subject: [PATCH] Dialer: Set TimeoutOnly for `gctx` and `hctx` https://github.com/XTLS/Xray-core/issues/2232#issuecomment-1694570914 Thank @cty123 for testing Fixes https://github.com/XTLS/Xray-core/issues/2232 BTW: Use `uConn.HandshakeContext(ctx)` in REALITY --- transport/internet/grpc/dial.go | 12 +++++++----- transport/internet/http/dialer.go | 12 ++++++------ transport/internet/reality/reality.go | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 8fd544b5..16af63cd 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -98,16 +98,13 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in MinConnectTimeout: 5 * time.Second, }), grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) { - gctx = session.ContextWithID(gctx, session.IDFromContext(ctx)) - gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx)) - - rawHost, rawPort, err := net.SplitHostPort(s) select { case <-gctx.Done(): return nil, gctx.Err() default: } + rawHost, rawPort, err := net.SplitHostPort(s) if err != nil { return nil, err } @@ -119,9 +116,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in return nil, err } address := net.ParseAddress(rawHost) + + gctx = session.ContextWithID(gctx, session.IDFromContext(ctx)) + gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx)) + gctx = session.ContextWithTimeoutOnly(gctx, true) + c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt) if err == nil && realityConfig != nil { - return reality.UClient(c, realityConfig, ctx, dest) + return reality.UClient(c, realityConfig, gctx, dest) } return c, err }), diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index f2e55de8..1ea3a738 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -53,7 +53,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in } transport := &http2.Transport{ - DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { + DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { rawHost, rawPort, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -67,18 +67,18 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in } address := net.ParseAddress(rawHost) - dctx := context.Background() - dctx = session.ContextWithID(dctx, session.IDFromContext(ctx)) - dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx)) + hctx = session.ContextWithID(hctx, session.IDFromContext(ctx)) + hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx)) + hctx = session.ContextWithTimeoutOnly(hctx, true) - pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt) + pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt) if err != nil { newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() return nil, err } if realityConfigs != nil { - return reality.UClient(pconn, realityConfigs, ctx, dest) + return reality.UClient(pconn, realityConfigs, hctx, dest) } var cn tls.Interface diff --git a/transport/internet/reality/reality.go b/transport/internet/reality/reality.go index b430cccc..30d4e2ae 100644 --- a/transport/internet/reality/reality.go +++ b/transport/internet/reality/reality.go @@ -156,7 +156,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati aead.Seal(hello.SessionId[:0], hello.Random[20:], hello.SessionId[:16], hello.Raw) copy(hello.Raw[39:], hello.SessionId) } - if err := uConn.Handshake(); err != nil { + if err := uConn.HandshakeContext(ctx); err != nil { return nil, err } if config.Show {