mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-12-22 11:23:32 +02:00
Fixing tcp connestions leak
- always use HandshakeContext instead of Handshake - pickup dailer dropped ctx - rename HandshakeContextAddress to HandshakeAddressContext
This commit is contained in:
parent
5ea1315b85
commit
cae94570df
7 changed files with 38 additions and 18 deletions
|
@ -71,8 +71,8 @@ func (d *DokodemoDoor) policy() policy.Session {
|
|||
return p
|
||||
}
|
||||
|
||||
type hasHandshakeAddress interface {
|
||||
HandshakeAddress() net.Address
|
||||
type hasHandshakeAddressContext interface {
|
||||
HandshakeAddressContext(ctx context.Context) net.Address
|
||||
}
|
||||
|
||||
// Process implements proxy.Inbound.
|
||||
|
@ -89,8 +89,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
|
|||
if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
|
||||
dest = outbound.Target
|
||||
destinationOverridden = true
|
||||
} else if handshake, ok := conn.(hasHandshakeAddress); ok {
|
||||
addr := handshake.HandshakeAddress()
|
||||
} else if handshake, ok := conn.(hasHandshakeAddressContext); ok {
|
||||
addr := handshake.HandshakeAddressContext(ctx)
|
||||
if addr != nil {
|
||||
dest.Address = addr
|
||||
destinationOverridden = true
|
||||
|
|
|
@ -308,7 +308,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
|
|||
|
||||
nextProto := ""
|
||||
if tlsConn, ok := iConn.(*tls.Conn); ok {
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
|
|||
} else {
|
||||
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
|
||||
}
|
||||
if err := cn.Handshake(); err != nil {
|
||||
if err := cn.HandshakeContext(ctx); err != nil {
|
||||
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
|
|||
tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
|
||||
if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
|
||||
conn = tls.UClient(conn, tlsConfig, fingerprint)
|
||||
if err := conn.(*tls.UConn).Handshake(); err != nil {
|
||||
if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -65,7 +65,7 @@ func (c *grpcUtls) ClientHandshake(ctx context.Context, authority string, rawCon
|
|||
conn := UClient(rawConn, cfg, c.fingerprint).(*UConn)
|
||||
errChannel := make(chan error, 1)
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
errChannel <- conn.HandshakeContext(ctx)
|
||||
close(errChannel)
|
||||
}()
|
||||
select {
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
utls "github.com/refraction-networking/utls"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
|
@ -14,7 +16,7 @@ import (
|
|||
|
||||
type Interface interface {
|
||||
net.Conn
|
||||
Handshake() error
|
||||
HandshakeContext(ctx context.Context) error
|
||||
VerifyHostname(host string) error
|
||||
NegotiatedProtocol() (name string, mutual bool)
|
||||
}
|
||||
|
@ -25,6 +27,16 @@ type Conn struct {
|
|||
*tls.Conn
|
||||
}
|
||||
|
||||
const tlsCloseTimeout = 250 * time.Millisecond
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
timer := time.AfterFunc(tlsCloseTimeout, func() {
|
||||
c.Conn.NetConn().Close()
|
||||
})
|
||||
defer timer.Stop()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
mb = buf.Compact(mb)
|
||||
mb, err := buf.WriteMultiBuffer(c, mb)
|
||||
|
@ -32,8 +44,8 @@ func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) HandshakeAddress() net.Address {
|
||||
if err := c.Handshake(); err != nil {
|
||||
func (c *Conn) HandshakeAddressContext(ctx context.Context) net.Address {
|
||||
if err := c.HandshakeContext(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
state := c.ConnectionState()
|
||||
|
@ -64,8 +76,16 @@ type UConn struct {
|
|||
*utls.UConn
|
||||
}
|
||||
|
||||
func (c *UConn) HandshakeAddress() net.Address {
|
||||
if err := c.Handshake(); err != nil {
|
||||
func (c *UConn) Close() error {
|
||||
timer := time.AfterFunc(tlsCloseTimeout, func() {
|
||||
c.Conn.NetConn().Close()
|
||||
})
|
||||
defer timer.Stop()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *UConn) HandshakeAddressContext(ctx context.Context) net.Address {
|
||||
if err := c.HandshakeContext(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
state := c.ConnectionState()
|
||||
|
@ -77,7 +97,7 @@ func (c *UConn) HandshakeAddress() net.Address {
|
|||
|
||||
// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
|
||||
// http/1.1 in its ALPN.
|
||||
func (c *UConn) WebsocketHandshake() error {
|
||||
func (c *UConn) WebsocketHandshakeContext(ctx context.Context) error {
|
||||
// Build the handshake state. This will apply every variable of the TLS of the
|
||||
// fingerprint in the UConn
|
||||
if err := c.BuildHandshakeState(); err != nil {
|
||||
|
@ -99,7 +119,7 @@ func (c *UConn) WebsocketHandshake() error {
|
|||
if err := c.BuildHandshakeState(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Handshake()
|
||||
return c.HandshakeContext(ctx)
|
||||
}
|
||||
|
||||
func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
|
||||
|
@ -118,7 +138,7 @@ func copyConfig(c *tls.Config) *utls.Config {
|
|||
ServerName: c.ServerName,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
KeyLogWriter: c.KeyLogWriter,
|
||||
KeyLogWriter: c.KeyLogWriter,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
|
|||
}
|
||||
// TLS and apply the handshake
|
||||
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
|
||||
if err := cn.WebsocketHandshake(); err != nil {
|
||||
if err := cn.WebsocketHandshakeContext(ctx); err != nil {
|
||||
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
|
||||
return nil, err
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
|
|||
header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
|
||||
}
|
||||
|
||||
conn, resp, err := dialer.Dial(uri, header)
|
||||
conn, resp, err := dialer.DialContext(ctx, uri, header)
|
||||
if err != nil {
|
||||
var reason string
|
||||
if resp != nil {
|
||||
|
|
Loading…
Reference in a new issue