From d87758d46ff54a74efb82affc941a47ea2d3832b Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Mon, 21 Nov 2022 22:37:22 -0500 Subject: [PATCH] Parse big server hello properly --- proxy/vless/encoding/encoding.go | 67 +++++++++++++++++++------------- proxy/vless/inbound/inbound.go | 10 +++-- proxy/vless/outbound/outbound.go | 10 +++-- 3 files changed, 55 insertions(+), 32 deletions(-) diff --git a/proxy/vless/encoding/encoding.go b/proxy/vless/encoding/encoding.go index 836007bb..1817fa27 100644 --- a/proxy/vless/encoding/encoding.go +++ b/proxy/vless/encoding/encoding.go @@ -247,7 +247,9 @@ func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, c } // XtlsRead filter and read xtls protocol -func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn, counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool) error { +func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, rawConn syscall.RawConn, + counter stats.Counter, ctx context.Context, userUUID []byte, numberOfPacketToFilter *int, enableXtls *bool, + isTLS12orAbove *bool, isTLS *bool, cipher *uint16, remainingServerHello *int32) error { err := func() error { var ct stats.Counter filterUUID := true @@ -306,7 +308,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } } if *numberOfPacketToFilter > 0 { - XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, ctx) + XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx) } if ct != nil { ct.Add(int64(buffer.Len())) @@ -328,7 +330,9 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater } // XtlsWrite filter and write xtls protocol -func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter, ctx context.Context, userUUID *[]byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool) error { +func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn net.Conn, counter stats.Counter, + ctx context.Context, userUUID *[]byte, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, + cipher *uint16, remainingServerHello *int32) error { err := func() error { var ct stats.Counter filterTlsApplicationData := true @@ -337,7 +341,7 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate buffer, err := reader.ReadMultiBuffer() if !buffer.IsEmpty() { if *numberOfPacketToFilter > 0 { - XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, ctx) + XtlsFilterTls(buffer, numberOfPacketToFilter, enableXtls, isTLS12orAbove, isTLS, cipher, remainingServerHello, ctx) } if filterTlsApplicationData && *isTLS { buffer = ReshapeMultiBuffer(ctx, buffer) @@ -399,40 +403,51 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate } // XtlsFilterTls filter and recognize tls 1.3 and other info -func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, ctx context.Context) { +func XtlsFilterTls(buffer buf.MultiBuffer, numberOfPacketToFilter *int, enableXtls *bool, isTLS12orAbove *bool, isTLS *bool, + cipher *uint16, remainingServerHello *int32, ctx context.Context) { for _, b := range buffer { *numberOfPacketToFilter-- if b.Len() >= 6 { startsBytes := b.BytesTo(6) if bytes.Equal(tlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == 0x02 { - total := (int(startsBytes[3])<<8 | int(startsBytes[4])) + 5 - if b.Len() >= 74 && total >= 74 { - if bytes.Contains(b.BytesTo(int32(total)), tls13SupportedVersions) { - sessionIdLen := int32(b.Byte(43)) - cipherSuite := b.BytesRange(43 + sessionIdLen + 1, 43 + sessionIdLen + 3) - cipherNum := uint16(cipherSuite[0]) << 8 | uint16(cipherSuite[1]) - v, ok := Tls13CipherSuiteDic[cipherNum] - if !ok { - v = "Unknown cipher!" - } else if (v != "TLS_AES_128_CCM_8_SHA256") { - *enableXtls = true - } - newError("XtlsFilterTls found tls 1.3! ", buffer.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx)) - } else { - newError("XtlsFilterTls found tls 1.2! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) - } - *isTLS12orAbove = true - *isTLS = true - *numberOfPacketToFilter = 0 - return + *remainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5 + *isTLS12orAbove = true + *isTLS = true + if b.Len() >= 79 && *remainingServerHello >= 79 { + sessionIdLen := int32(b.Byte(43)) + cipherSuite := b.BytesRange(43 + sessionIdLen + 1, 43 + sessionIdLen + 3) + *cipher = uint16(cipherSuite[0]) << 8 | uint16(cipherSuite[1]) } else { - newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", total).WriteToLog(session.ExportIDToError(ctx)) + newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx)) } } else if bytes.Equal(tlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == 0x01 { *isTLS = true newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) } } + if *remainingServerHello > 0 { + end := *remainingServerHello + if end > b.Len() { + end = b.Len() + } + *remainingServerHello -= b.Len() + if bytes.Contains(b.BytesTo(end), tls13SupportedVersions) { + v, ok := Tls13CipherSuiteDic[*cipher] + if !ok { + v = "Old cipher: " + strconv.FormatUint(uint64(*cipher), 16) + } else if (v != "TLS_AES_128_CCM_8_SHA256") { + *enableXtls = true + } + newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx)) + *numberOfPacketToFilter = 0 + return + } else if *remainingServerHello <= 0 { + newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx)) + *numberOfPacketToFilter = 0 + return + } + newError("XtlsFilterTls inclusive server hello ", b.Len(), " ", *remainingServerHello).WriteToLog(session.ExportIDToError(ctx)) + } if *numberOfPacketToFilter <= 0 { newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx)) } diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index b11830ad..daa6cde9 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -511,6 +511,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s enableXtls := false isTLS12orAbove := false isTLS := false + var cipher uint16 = 0 + var remainingServerHello int32 = -1 numberOfPacketToFilter := 8 postRequest := func() error { @@ -529,7 +531,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s //TODO enable splice ctx = session.ContextWithInbound(ctx, nil) if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) + err = encoding.XtlsRead(clientReader, serverWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), + &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { err = encoding.ReadV(clientReader, serverWriter, timer, iConn.(*xtls.Conn), rawConn, counter, ctx) } @@ -561,7 +564,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s return err1 // ... } if requestAddons.Flow == vless.XRV { - encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, ctx) + encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx) if isTLS { multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer) for i, b := range multiBuffer { @@ -583,7 +586,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if statConn != nil { counter = statConn.WriteCounter } - err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) + err = encoding.XtlsWrite(serverReader, clientWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, + &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { // from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 1075f588..7cfbbfd0 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -193,6 +193,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte enableXtls := false isTLS12orAbove := false isTLS := false + var cipher uint16 = 0 + var remainingServerHello int32 = -1 numberOfPacketToFilter := 8 if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 { @@ -220,7 +222,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return err1 // ... } if requestAddons.Flow == vless.XRV { - encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, ctx) + encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx) if isTLS { for i, b := range multiBuffer { multiBuffer[i] = encoding.XtlsPadding(b, 0x00, &userUUID, ctx) @@ -241,7 +243,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if statConn != nil { counter = statConn.WriteCounter } - err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) + err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &userUUID, &numberOfPacketToFilter, + &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { // from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)) @@ -277,7 +280,8 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte counter = statConn.ReadCounter } if requestAddons.Flow == vless.XRV { - err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS) + err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, counter, ctx, account.ID.Bytes(), + &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello) } else { if requestAddons.Flow != vless.XRS { ctx = session.ContextWithInbound(ctx, nil)