diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index b064cc56..48a29c9c 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -58,6 +58,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in } tlsConfig := tls.ConfigFromStreamSettings(streamSettings) + isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1") var gotlsConfig *gotls.Config @@ -88,7 +89,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in var uploadTransport http.RoundTripper var downloadTransport http.RoundTripper - if tlsConfig != nil { + if isH2 { downloadTransport = &http2.Transport{ DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) { return dialContext(ctxInner) @@ -121,7 +122,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in upload: &http.Client{ Transport: uploadTransport, }, - isH2: tlsConfig != nil, + isH2: isH2, uploadRawPool: &sync.Pool{}, dialUploadConn: dialContext, } diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 412f686a..5181b19b 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -19,6 +19,8 @@ import ( "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" v2tls "github.com/xtls/xray-core/transport/internet/tls" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) type requestHandler struct { @@ -268,16 +270,21 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet } } + handler := &requestHandler{ + host: shSettings.Host, + path: shSettings.GetNormalizedPath(), + ln: l, + sessions: sync.Map{}, + localAddr: localAddr, + } + + // h2cHandler can handle both plaintext HTTP/1.1 and h2c + h2cHandler := h2c.NewHandler(handler, &http2.Server{}) + l.listener = listener l.server = http.Server{ - Handler: &requestHandler{ - host: shSettings.Host, - path: shSettings.GetNormalizedPath(), - ln: l, - sessions: sync.Map{}, - localAddr: localAddr, - }, + Handler: h2cHandler, ReadHeaderTimeout: time.Second * 4, MaxHeaderBytes: 8192, } diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index 3d3d387c..52ced9a0 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -2,7 +2,10 @@ package splithttp_test import ( "context" + gotls "crypto/tls" "fmt" + gonet "net" + "net/http" "runtime" "testing" "time" @@ -15,6 +18,7 @@ import ( . "github.com/xtls/xray-core/transport/internet/splithttp" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" + "golang.org/x/net/http2" ) func Test_listenSHAndDial(t *testing.T) { @@ -152,3 +156,50 @@ func Test_listenSHAndDial_TLS(t *testing.T) { t.Error("end: ", end, " start: ", start) } } + +func Test_listenSHAndDial_H2C(t *testing.T) { + if runtime.GOARCH == "arm64" { + return + } + + listenPort := tcp.PickPort() + + streamSettings := &internet.MemoryStreamConfig{ + ProtocolName: "splithttp", + ProtocolSettings: &Config{ + Path: "shs", + }, + } + listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { + go func() { + _ = conn.Close() + }() + }) + common.Must(err) + defer listen.Close() + + client := http.Client{ + Transport: &http2.Transport{ + // So http2.Transport doesn't complain the URL scheme isn't 'https' + AllowHTTP: true, + // even with AllowHTTP, http2.Transport will attempt to establish + // the connection using DialTLSContext. Disable TLS with custom + // dial context. + DialTLSContext: func(ctx context.Context, network, addr string, cfg *gotls.Config) (gonet.Conn, error) { + var d gonet.Dialer + return d.DialContext(ctx, network, addr) + }, + }, + } + + resp, err := client.Get("http://" + net.LocalHostIP.String() + ":" + listenPort.String()) + common.Must(err) + + if resp.StatusCode != 404 { + t.Error("Expected 404 but got:", resp.StatusCode) + } + + if resp.ProtoMajor != 2 { + t.Error("Expected h2 but got:", resp.ProtoMajor) + } +}