Added utls to websocket (#1256)

* Added utls to websocket

* Slightly better code

One less allocation
This commit is contained in:
Hirbod Behnam 2022-10-18 18:04:41 +03:30 committed by GitHub
parent 149e2247e8
commit 1f93cbbc5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 1 deletions

View File

@ -66,6 +66,33 @@ func (c *UConn) HandshakeAddress() net.Address {
return net.ParseAddress(state.ServerName) return net.ParseAddress(state.ServerName)
} }
// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
// http/1.1 in its ALPN.
func (c *UConn) WebsocketHandshake() error {
// Build the handshake state. This will apply every variable of the TLS of the
// fingerprint in the UConn
if err := c.BuildHandshakeState(); err != nil {
return err
}
// Iterate over extensions and check for utls.ALPNExtension
hasALPNExtension := false
for _, extension := range c.Extensions {
if alpn, ok := extension.(*utls.ALPNExtension); ok {
hasALPNExtension = true
alpn.AlpnProtocols = []string{"http/1.1"}
break
}
}
if !hasALPNExtension { // Append extension if doesn't exists
c.Extensions = append(c.Extensions, &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}})
}
// Rebuild the client hello and do the handshake
if err := c.BuildHandshakeState(); err != nil {
return err
}
return c.Handshake()
}
func (c *UConn) NegotiatedProtocol() (name string, mutual bool) { func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
state := c.ConnectionState() state := c.ConnectionState()
return state.NegotiatedProtocol, state.NegotiatedProtocolIsMutual return state.NegotiatedProtocol, state.NegotiatedProtocolIsMutual

View File

@ -6,6 +6,7 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
gonet "net"
"net/http" "net/http"
"os" "os"
"time" "time"
@ -83,7 +84,31 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
protocol = "wss" protocol = "wss"
dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
dialer.TLSClientConfig = tlsConfig
if fingerprint, exists := tls.Fingerprints[config.Fingerprint]; exists {
dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) {
// Like the NetDial in the dialer
pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
// TLS and apply the handshake
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
if err := cn.WebsocketHandshake(); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
}
return cn, nil
}
}
} }
host := dest.NetAddr() host := dest.NetAddr()