From 1f93cbbc5d37959cf03ae3784f544bc13f6b162a Mon Sep 17 00:00:00 2001 From: Hirbod Behnam Date: Tue, 18 Oct 2022 18:04:41 +0330 Subject: [PATCH] Added utls to websocket (#1256) * Added utls to websocket * Slightly better code One less allocation --- transport/internet/tls/tls.go | 27 ++++++++++++++++++++++++++ transport/internet/websocket/dialer.go | 27 +++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index ea86c0ce..f1291e81 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -66,6 +66,33 @@ func (c *UConn) HandshakeAddress() net.Address { return net.ParseAddress(state.ServerName) } +// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send +// http/1.1 in its ALPN. +func (c *UConn) WebsocketHandshake() error { + // Build the handshake state. This will apply every variable of the TLS of the + // fingerprint in the UConn + if err := c.BuildHandshakeState(); err != nil { + return err + } + // Iterate over extensions and check for utls.ALPNExtension + hasALPNExtension := false + for _, extension := range c.Extensions { + if alpn, ok := extension.(*utls.ALPNExtension); ok { + hasALPNExtension = true + alpn.AlpnProtocols = []string{"http/1.1"} + break + } + } + if !hasALPNExtension { // Append extension if doesn't exists + c.Extensions = append(c.Extensions, &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}}) + } + // Rebuild the client hello and do the handshake + if err := c.BuildHandshakeState(); err != nil { + return err + } + return c.Handshake() +} + func (c *UConn) NegotiatedProtocol() (name string, mutual bool) { state := c.ConnectionState() return state.NegotiatedProtocol, state.NegotiatedProtocolIsMutual diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 284d8dee..a8f71264 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "fmt" "io" + gonet "net" "net/http" "os" "time" @@ -83,7 +84,31 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { protocol = "wss" - dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) + tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) + dialer.TLSClientConfig = tlsConfig + if fingerprint, exists := tls.Fingerprints[config.Fingerprint]; exists { + dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) { + // Like the NetDial in the dialer + pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + if err != nil { + newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() + return nil, err + } + // TLS and apply the handshake + cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) + if err := cn.WebsocketHandshake(); err != nil { + newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() + return nil, err + } + if !tlsConfig.InsecureSkipVerify { + if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil { + newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() + return nil, err + } + } + return cn, nil + } + } } host := dest.NetAddr()