diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index b31ec707..1ce8da6b 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -11,6 +11,8 @@ import ( "sync" "time" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" @@ -233,10 +235,13 @@ func (c *httpResponseBodyWriter) Close() error { type Listener struct { sync.Mutex - server http.Server - listener net.Listener - config *Config - addConn internet.ConnHandler + server http.Server + h3server *http3.Server + listener net.Listener + h3listener *quic.EarlyListener + config *Config + addConn internet.ConnHandler + isH3 bool } func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { @@ -253,6 +258,17 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet var listener net.Listener var err error var localAddr = gonet.TCPAddr{} + handler := &requestHandler{ + host: shSettings.Host, + path: shSettings.GetNormalizedPath(), + ln: l, + sessionMu: &sync.Mutex{}, + sessions: sync.Map{}, + localAddr: localAddr, + } + tlsConfig := getTLSConfig(streamSettings) + l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3" + if port == net.Port(0) { // unix listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ @@ -263,6 +279,29 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err) } errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address) + } else if l.isH3 { // quic + Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{ + IP: address.IP(), + Port: int(port), + }, streamSettings.SocketSettings) + if err != nil { + return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err) + } + h3listener, err := quic.ListenEarly(Conn,tlsConfig, nil) + if err != nil { + return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err) + } + l.h3listener = h3listener + errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port) + + l.h3server = &http3.Server{ + Handler: handler, + } + go func() { + if err := l.h3server.ServeListener(l.h3listener); err != nil { + errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp") + } + }() } else { // tcp localAddr = gonet.TCPAddr{ IP: address.IP(), @@ -275,41 +314,29 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet if err != nil { return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err) } + l.listener = listener errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port) - } + // h2cHandler can handle both plaintext HTTP/1.1 and h2c + h2cHandler := h2c.NewHandler(handler, &http2.Server{}) + l.server = http.Server{ + Handler: h2cHandler, + ReadHeaderTimeout: time.Second * 4, + MaxHeaderBytes: 8192, + } + go func() { + if err := l.server.Serve(l.listener); err != nil { + errors.LogWarningInner(ctx, err, "failed to serve http for splithttp") + } + }() + } + l.listener = listener if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { listener = tls.NewListener(listener, tlsConfig) } } - handler := &requestHandler{ - host: shSettings.Host, - path: shSettings.GetNormalizedPath(), - ln: l, - sessionMu: &sync.Mutex{}, - sessions: sync.Map{}, - localAddr: localAddr, - } - - // h2cHandler can handle both plaintext HTTP/1.1 and h2c - h2cHandler := h2c.NewHandler(handler, &http2.Server{}) - - l.listener = listener - - l.server = http.Server{ - Handler: h2cHandler, - ReadHeaderTimeout: time.Second * 4, - MaxHeaderBytes: 8192, - } - - go func() { - if err := l.server.Serve(l.listener); err != nil { - errors.LogWarningInner(ctx, err, "failed to serve http for splithttp") - } - }() - return l, err } @@ -320,9 +347,22 @@ func (ln *Listener) Addr() net.Addr { // Close implements net.Listener.Close(). func (ln *Listener) Close() error { - return ln.listener.Close() + if ln.h3server != nil { + if err := ln.h3server.Close(); err != nil { + return err + } + } else if ln.listener != nil { + return ln.listener.Close() + } + return errors.New("listener does not have an HTTP/3 server or a net.listener") +} +func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config { + config := v2tls.ConfigFromStreamSettings(streamSettings) + if config == nil { + return &tls.Config{} + } + return config.GetTLSConfig() } - func init() { common.Must(internet.RegisterTransportListener(protocolName, ListenSH)) } diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index c1c527ef..5f59a738 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -14,6 +14,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/tls/cert" "github.com/xtls/xray-core/testing/servers/tcp" + "github.com/xtls/xray-core/testing/servers/udp" "github.com/xtls/xray-core/transport/internet" . "github.com/xtls/xray-core/transport/internet/splithttp" "github.com/xtls/xray-core/transport/internet/stat" @@ -204,3 +205,42 @@ func Test_listenSHAndDial_H2C(t *testing.T) { t.Error("Expected h2 but got:", resp.ProtoMajor) } } + +func Test_listenSHAndDial_QUIC(t *testing.T) { + if runtime.GOARCH == "arm64" { + return + } + + listenPort := udp.PickPort() + + start := time.Now() + + streamSettings := &internet.MemoryStreamConfig{ + ProtocolName: "splithttp", + ProtocolSettings: &Config{ + Path: "shs", + }, + SecurityType: "tls", + SecuritySettings: &tls.Config{ + AllowInsecure: true, + Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, + NextProtocol: []string{"h3"}, + }, + } + listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { + go func() { + _ = conn.Close() + }() + }) + common.Must(err) + defer listen.Close() + + conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) + common.Must(err) + _ = conn.Close() + + end := time.Now() + if !end.Before(start.Add(time.Second * 5)) { + t.Error("end: ", end, " start: ", start) + } +} \ No newline at end of file