From a343d689444e3ace08252b17ae1cf7f939a283ef Mon Sep 17 00:00:00 2001 From: cty123 Date: Sat, 19 Aug 2023 22:10:59 +0200 Subject: [PATCH] fix(proxy): removed the udp payload length check when encryption is disabled --- proxy/shadowsocks/protocol.go | 64 +++++++++++---------- proxy/shadowsocks/protocol_test.go | 91 ++++++++++++++++++++++-------- proxy/shadowsocks/validator.go | 5 ++ 3 files changed, 107 insertions(+), 53 deletions(-) diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 3176d118..3a0c7e22 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -4,6 +4,7 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" + "errors" "hash/crc32" "io" @@ -236,37 +237,37 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff } func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { - bs := payload.Bytes() - if len(bs) <= 32 { - return nil, nil, newError("len(bs) <= 32") - } + rawPayload := payload.Bytes() + user, _, d, _, err := validator.Get(rawPayload, protocol.RequestCommandUDP) - user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP) - switch err { - case ErrIVNotUnique: + if errors.Is(err, ErrIVNotUnique) { return nil, nil, newError("failed iv check").Base(err) - case ErrNotFound: - return nil, nil, newError("failed to match an user").Base(err) - default: - account := user.Account.(*MemoryAccount) - if account.Cipher.IsAEAD() { - payload.Clear() - payload.Write(d) - } else { - if account.Cipher.IVSize() > 0 { - iv := make([]byte, account.Cipher.IVSize()) - copy(iv, payload.BytesTo(account.Cipher.IVSize())) - } - if err = account.Cipher.DecodePacket(account.Key, payload); err != nil { - return nil, nil, newError("failed to decrypt UDP payload").Base(err) - } - } } - request := &protocol.RequestHeader{ - Version: Version, - User: user, - Command: protocol.RequestCommandUDP, + if errors.Is(err, ErrNotFound) { + return nil, nil, newError("failed to match an user").Base(err) + } + + if err != nil { + return nil, nil, newError("unexpected error").Base(err) + } + + account, ok := user.Account.(*MemoryAccount) + if !ok { + return nil, nil, newError("expected MemoryAccount returned from validator") + } + + if account.Cipher.IsAEAD() { + payload.Clear() + payload.Write(d) + } else { + if account.Cipher.IVSize() > 0 { + iv := make([]byte, account.Cipher.IVSize()) + copy(iv, payload.BytesTo(account.Cipher.IVSize())) + } + if err = account.Cipher.DecodePacket(account.Key, payload); err != nil { + return nil, nil, newError("failed to decrypt UDP payload").Base(err) + } } payload.SetByte(0, payload.Byte(0)&0x0F) @@ -276,8 +277,13 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque return nil, nil, newError("failed to parse address").Base(err) } - request.Address = addr - request.Port = port + request := &protocol.RequestHeader{ + Version: Version, + User: user, + Command: protocol.RequestCommandUDP, + Address: addr, + Port: port, + } return request, payload, nil } diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index e1b6495e..4083905d 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -23,37 +23,80 @@ func equalRequestHeader(x, y *protocol.RequestHeader) bool { })) } -func TestUDPEncoding(t *testing.T) { - request := &protocol.RequestHeader{ - Version: Version, - Command: protocol.RequestCommandUDP, - Address: net.LocalHostIP, - Port: 1234, - User: &protocol.MemoryUser{ - Email: "love@example.com", - Account: toAccount(&Account{ - Password: "password", - CipherType: CipherType_AES_128_GCM, - }), +func TestUDPEncodingDecoding(t *testing.T) { + testRequests := []protocol.RequestHeader{ + { + Version: Version, + Command: protocol.RequestCommandUDP, + Address: net.LocalHostIP, + Port: 1234, + User: &protocol.MemoryUser{ + Email: "love@example.com", + Account: toAccount(&Account{ + Password: "password", + CipherType: CipherType_AES_128_GCM, + }), + }, + }, + { + Version: Version, + Command: protocol.RequestCommandUDP, + Address: net.LocalHostIP, + Port: 1234, + User: &protocol.MemoryUser{ + Email: "love@example.com", + Account: toAccount(&Account{ + Password: "123", + CipherType: CipherType_NONE, + }), + }, }, } - data := buf.New() - common.Must2(data.WriteString("test string")) - encodedData, err := EncodeUDPPacket(request, data.Bytes()) - common.Must(err) + for _, request := range testRequests { + data := buf.New() + common.Must2(data.WriteString("test string")) + encodedData, err := EncodeUDPPacket(&request, data.Bytes()) + common.Must(err) - validator := new(Validator) - validator.Add(request.User) - decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData) - common.Must(err) + validator := new(Validator) + validator.Add(request.User) + decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData) + common.Must(err) - if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" { - t.Error("data: ", r) + if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" { + t.Error("data: ", r) + } + + if equalRequestHeader(decodedRequest, &request) == false { + t.Error("different request") + } + } +} + +func TestUDPDecodingWithPayloadTooShort(t *testing.T) { + testAccounts := []protocol.Account{ + toAccount(&Account{ + Password: "password", + CipherType: CipherType_AES_128_GCM, + }), + toAccount(&Account{ + Password: "password", + CipherType: CipherType_NONE, + }), } - if equalRequestHeader(decodedRequest, request) == false { - t.Error("different request") + for _, account := range testAccounts { + data := buf.New() + data.WriteString("short payload") + validator := new(Validator) + validator.Add(&protocol.MemoryUser{ + Account: account, + }) + _, _, err := DecodeUDPPacket(validator, data) + if err == nil { + t.Fatal("expected error") + } } } diff --git a/proxy/shadowsocks/validator.go b/proxy/shadowsocks/validator.go index 2aa62e06..8888a1c0 100644 --- a/proxy/shadowsocks/validator.go +++ b/proxy/shadowsocks/validator.go @@ -80,6 +80,11 @@ func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol for _, user := range v.users { if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() { + // AEAD payload decoding requires the payload to be over 32 bytes + if len(bs) < 32 { + continue + } + aeadCipher := account.Cipher.(*AEADCipher) ivLen = aeadCipher.IVSize() iv := bs[:ivLen]