From 60b06877bfd6aae94600c29679591924b430da78 Mon Sep 17 00:00:00 2001 From: RPRX <63339210+rprx@users.noreply.github.com> Date: Sun, 14 Mar 2021 07:10:10 +0000 Subject: [PATCH] Add WebSocket 0-RTT support (#375) --- infra/conf/transport_internet.go | 13 ++++ transport/internet/websocket/config.pb.go | 15 +++- transport/internet/websocket/config.proto | 2 + transport/internet/websocket/connection.go | 3 +- transport/internet/websocket/dialer.go | 88 ++++++++++++++++++++-- transport/internet/websocket/hub.go | 11 ++- transport/internet/websocket/ws.go | 4 +- 7 files changed, 122 insertions(+), 14 deletions(-) diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 6a7e7694..c8f5c18f 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -3,6 +3,8 @@ package conf import ( "encoding/json" "math" + "net/url" + "strconv" "strings" "github.com/golang/protobuf/proto" @@ -155,9 +157,20 @@ func (c *WebSocketConfig) Build() (proto.Message, error) { Value: value, }) } + var ed uint32 + if u, err := url.Parse(path); err == nil { + if q := u.Query(); q.Get("ed") != "" { + Ed, _ := strconv.Atoi(q.Get("ed")) + ed = uint32(Ed) + q.Del("ed") + u.RawQuery = q.Encode() + path = u.String() + } + } config := &websocket.Config{ Path: path, Header: header, + Ed: ed, } if c.AcceptProxyProtocol { config.AcceptProxyProtocol = c.AcceptProxyProtocol diff --git a/transport/internet/websocket/config.pb.go b/transport/internet/websocket/config.pb.go index 1d3f8e22..9db30c3b 100644 --- a/transport/internet/websocket/config.pb.go +++ b/transport/internet/websocket/config.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.25.0 -// protoc v3.14.0 +// protoc v3.15.6 // source: transport/internet/websocket/config.proto package websocket @@ -89,6 +89,7 @@ type Config struct { Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"` Header []*Header `protobuf:"bytes,3,rep,name=header,proto3" json:"header,omitempty"` AcceptProxyProtocol bool `protobuf:"varint,4,opt,name=accept_proxy_protocol,json=acceptProxyProtocol,proto3" json:"accept_proxy_protocol,omitempty"` + Ed uint32 `protobuf:"varint,5,opt,name=ed,proto3" json:"ed,omitempty"` } func (x *Config) Reset() { @@ -144,6 +145,13 @@ func (x *Config) GetAcceptProxyProtocol() bool { return false } +func (x *Config) GetEd() uint32 { + if x != nil { + return x.Ed + } + return 0 +} + var File_transport_internet_websocket_config_proto protoreflect.FileDescriptor var file_transport_internet_websocket_config_proto_rawDesc = []byte{ @@ -155,7 +163,7 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{ 0x0a, 0x06, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, - 0x22, 0x99, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, + 0x22, 0xa9, 0x01, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x41, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, @@ -164,7 +172,8 @@ var file_transport_internet_websocket_config_proto_rawDesc = []byte{ 0x65, 0x72, 0x12, 0x32, 0x0a, 0x15, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x42, 0x85, 0x01, 0x0a, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0e, 0x0a, 0x02, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x0d, 0x52, 0x02, 0x65, 0x64, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x42, 0x85, 0x01, 0x0a, 0x25, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, diff --git a/transport/internet/websocket/config.proto b/transport/internet/websocket/config.proto index 0795c326..85365c33 100644 --- a/transport/internet/websocket/config.proto +++ b/transport/internet/websocket/config.proto @@ -20,4 +20,6 @@ message Config { repeated Header header = 3; bool accept_proxy_protocol = 4; + + uint32 ed = 5; } diff --git a/transport/internet/websocket/connection.go b/transport/internet/websocket/connection.go index a9b60e8e..7cc2ad9b 100644 --- a/transport/internet/websocket/connection.go +++ b/transport/internet/websocket/connection.go @@ -22,10 +22,11 @@ type connection struct { remoteAddr net.Addr } -func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection { +func newConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection { return &connection{ conn: conn, remoteAddr: remoteAddr, + reader: extraReader, } } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 5ee0b6cd..849beb5e 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -2,9 +2,12 @@ package websocket import ( "context" + "encoding/base64" + "io" "time" "github.com/gorilla/websocket" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/session" @@ -15,10 +18,21 @@ import ( // Dial dials a WebSocket connection to the given destination. func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) - - conn, err := dialWebsocket(ctx, dest, streamSettings) - if err != nil { - return nil, newError("failed to dial WebSocket").Base(err) + var conn net.Conn + if streamSettings.ProtocolSettings.(*Config).Ed > 0 { + ctx, cancel := context.WithCancel(ctx) + conn = &delayDialConn{ + dialed: make(chan bool, 1), + cancel: cancel, + ctx: ctx, + dest: dest, + streamSettings: streamSettings, + } + } else { + var err error + if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil { + return nil, newError("failed to dial WebSocket").Base(err) + } } return internet.Connection(conn), nil } @@ -27,7 +41,7 @@ func init() { common.Must(internet.RegisterTransportDialer(protocolName, Dial)) } -func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) { +func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) { wsSettings := streamSettings.ProtocolSettings.(*Config) dialer := &websocket.Dialer{ @@ -52,7 +66,12 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in } uri := protocol + "://" + host + wsSettings.GetNormalizedPath() - conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) + header := wsSettings.GetRequestHeader() + if ed != nil { + header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed)) + } + + conn, resp, err := dialer.Dial(uri, header) if err != nil { var reason string if resp != nil { @@ -61,5 +80,60 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in return nil, newError("failed to dial to (", uri, "): ", reason).Base(err) } - return newConnection(conn, conn.RemoteAddr()), nil + return newConnection(conn, conn.RemoteAddr(), nil), nil +} + +type delayDialConn struct { + net.Conn + closed bool + dialed chan bool + cancel context.CancelFunc + ctx context.Context + dest net.Destination + streamSettings *internet.MemoryStreamConfig +} + +func (d *delayDialConn) Write(b []byte) (int, error) { + if d.closed { + return 0, io.ErrClosedPipe + } + if d.Conn == nil { + ed := b + if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) { + ed = nil + } + var err error + if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil { + d.Close() + return 0, newError("failed to dial WebSocket").Base(err) + } + d.dialed <- true + if ed != nil { + return len(ed), nil + } + } + return d.Conn.Write(b) +} + +func (d *delayDialConn) Read(b []byte) (int, error) { + if d.closed { + return 0, io.ErrClosedPipe + } + if d.Conn == nil { + select { + case <-d.ctx.Done(): + return 0, io.ErrUnexpectedEOF + case <-d.dialed: + } + } + return d.Conn.Read(b) +} + +func (d *delayDialConn) Close() error { + d.closed = true + d.cancel() + if d.Conn == nil { + return nil + } + return d.Conn.Close() } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index f01e91aa..0fd84a52 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -1,8 +1,11 @@ package websocket import ( + "bytes" "context" "crypto/tls" + "encoding/base64" + "io" "net/http" "sync" "time" @@ -51,7 +54,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } } - h.ln.addConn(newConnection(conn, remoteAddr)) + var extraReader io.Reader + if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { + if ed, err := base64.StdEncoding.DecodeString(str); err == nil && len(ed) > 0 { + extraReader = bytes.NewReader(ed) + } + } + h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) } type Listener struct { diff --git a/transport/internet/websocket/ws.go b/transport/internet/websocket/ws.go index 661bb4c0..55387100 100644 --- a/transport/internet/websocket/ws.go +++ b/transport/internet/websocket/ws.go @@ -1,6 +1,6 @@ -/*Package websocket implements Websocket transport +/*Package websocket implements WebSocket transport -Websocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability. +WebSocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability. */ package websocket