From ff35118af50d8eba8b5a3233ae0bd9dba6849642 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Thu, 29 Apr 2021 06:29:42 +0800 Subject: [PATCH] VMess AEAD based packet length (cherry picked from commit 08221600082a79376bdc262f2ffec1a3129ae98d) --- common/protocol/headers.go | 2 + infra/conf/vmess.go | 8 ++- proxy/vmess/account.go | 15 ++++- proxy/vmess/encoding/auth.go | 10 +++ proxy/vmess/encoding/client.go | 46 +++++++++++++ proxy/vmess/encoding/server.go | 46 +++++++++++++ proxy/vmess/outbound/outbound.go | 4 ++ testing/scenarios/vmess_test.go | 107 +++++++++++++++++++++++++++++++ 8 files changed, 232 insertions(+), 6 deletions(-) diff --git a/common/protocol/headers.go b/common/protocol/headers.go index 921b4d32..1dcc467e 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -38,6 +38,8 @@ const ( RequestOptionChunkMasking bitmask.Byte = 0x04 RequestOptionGlobalPadding bitmask.Byte = 0x08 + + RequestOptionAuthenticatedLength bitmask.Byte = 0x10 ) type RequestHeader struct { diff --git a/infra/conf/vmess.go b/infra/conf/vmess.go index 37557c17..9759890d 100644 --- a/infra/conf/vmess.go +++ b/infra/conf/vmess.go @@ -15,9 +15,10 @@ import ( ) type VMessAccount struct { - ID string `json:"id"` - AlterIds uint16 `json:"alterId"` - Security string `json:"security"` + ID string `json:"id"` + AlterIds uint16 `json:"alterId"` + Security string `json:"security"` + Experiments string `json:"experiments"` } // Build implements Buildable @@ -43,6 +44,7 @@ func (a *VMessAccount) Build() *vmess.Account { SecuritySettings: &protocol.SecurityConfig{ Type: st, }, + TestsEnabled: a.Experiments, } } diff --git a/proxy/vmess/account.go b/proxy/vmess/account.go index a95576b9..8f5e6004 100644 --- a/proxy/vmess/account.go +++ b/proxy/vmess/account.go @@ -1,6 +1,8 @@ package vmess import ( + "strings" + "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/uuid" @@ -14,6 +16,8 @@ type MemoryAccount struct { AlterIDs []*protocol.ID // Security type of the account. Used for client connections. Security protocol.SecurityType + + AuthenticatedLengthExperiment bool } // AnyValidID returns an ID that is either the main ID or one of the alternative IDs if any. @@ -41,9 +45,14 @@ func (a *Account) AsAccount() (protocol.Account, error) { return nil, newError("failed to parse ID").Base(err).AtError() } protoID := protocol.NewID(id) + var AuthenticatedLength bool + if strings.Contains(a.TestsEnabled, "AuthenticatedLength") { + AuthenticatedLength = true + } return &MemoryAccount{ - ID: protoID, - AlterIDs: protocol.NewAlterIDs(protoID, uint16(a.AlterId)), - Security: a.SecuritySettings.GetSecurityType(), + ID: protoID, + AlterIDs: protocol.NewAlterIDs(protoID, uint16(a.AlterId)), + Security: a.SecuritySettings.GetSecurityType(), + AuthenticatedLengthExperiment: AuthenticatedLength, }, nil } diff --git a/proxy/vmess/encoding/auth.go b/proxy/vmess/encoding/auth.go index 6233643f..23536de6 100644 --- a/proxy/vmess/encoding/auth.go +++ b/proxy/vmess/encoding/auth.go @@ -7,6 +7,8 @@ import ( "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/crypto" + "golang.org/x/crypto/sha3" ) @@ -116,3 +118,11 @@ func (s *ShakeSizeParser) NextPaddingLen() uint16 { func (s *ShakeSizeParser) MaxPaddingLen() uint16 { return 64 } + +type AEADSizeParser struct { + crypto.AEADChunkSizeParser +} + +func NewAEADSizeParser(auth *crypto.AEADAuthenticator) *AEADSizeParser { + return &AEADSizeParser{crypto.AEADChunkSizeParser{Auth: auth}} +} diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index d46e8fb6..82ab5dee 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -171,6 +171,17 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write NonceGenerator: GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) case protocol.SecurityType_CHACHA20_POLY1305: aead, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.requestBodyKey[:])) @@ -181,6 +192,18 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write NonceGenerator: GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey)) + common.Must(err) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) default: panic("Unknown security type.") @@ -312,6 +335,17 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read NonceGenerator: GenerateChunkNonce(c.responseBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) case protocol.SecurityType_CHACHA20_POLY1305: aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.responseBodyKey[:])) @@ -321,6 +355,18 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read NonceGenerator: GenerateChunkNonce(c.responseBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey)) + common.Must(err) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) default: panic("Unknown security type.") diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index f4773894..428d0e69 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -362,6 +362,17 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) case protocol.SecurityType_CHACHA20_POLY1305: @@ -372,6 +383,18 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey)) + common.Must(err) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) default: @@ -480,6 +503,17 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ NonceGenerator: GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) case protocol.SecurityType_CHACHA20_POLY1305: @@ -490,6 +524,18 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ NonceGenerator: GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } + if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { + AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") + AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey)) + common.Must(err) + + lengthAuth := &crypto.AEADAuthenticator{ + AEAD: AuthenticatedLengthKeyAEAD, + NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), + AdditionalDataGenerator: crypto.GenerateEmptyBytes(), + } + sizeParser = NewAEADSizeParser(lengthAuth) + } return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) default: diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 7265e9f3..a618daea 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -119,6 +119,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte request.Option.Clear(protocol.RequestOptionChunkMasking) } + if account.AuthenticatedLengthExperiment { + request.Option.Set(protocol.RequestOptionAuthenticatedLength) + } + input := link.Reader output := link.Writer diff --git a/testing/scenarios/vmess_test.go b/testing/scenarios/vmess_test.go index 45641deb..2f174cd0 100644 --- a/testing/scenarios/vmess_test.go +++ b/testing/scenarios/vmess_test.go @@ -1330,3 +1330,110 @@ func TestVMessZero(t *testing.T) { t.Error(err) } } + +func TestVMessGCMLengthAuth(t *testing.T) { + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + common.Must(err) + defer tcpServer.Close() + + userID := protocol.NewID(uuid.New()) + serverPort := tcp.PickPort() + serverConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(serverPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&inbound.Config{ + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + }), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + clientPort := tcp.PickPort() + clientConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(clientPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + NetworkList: &net.NetworkList{ + Network: []net.Network{net.Network_TCP}, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&outbound.Config{ + Receiver: []*protocol.ServerEndpoint{ + { + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(serverPort), + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + SecuritySettings: &protocol.SecurityConfig{ + Type: protocol.SecurityType_AES128_GCM, + }, + TestsEnabled: "AuthenticatedLength", + }), + }, + }, + }, + }, + }), + }, + }, + } + + servers, err := InitializeServerConfigs(serverConfig, clientConfig) + if err != nil { + t.Fatal("Failed to initialize all servers: ", err.Error()) + } + defer CloseAllServers(servers) + + var errg errgroup.Group + for i := 0; i < 10; i++ { + errg.Go(testTCPConn(clientPort, 10240*1024, time.Second*40)) + } + + if err := errg.Wait(); err != nil { + t.Error(err) + } +}