fix(proxy): removed the udp payload length check when encryption is disabled

This commit is contained in:
cty123 2023-08-19 22:10:59 +02:00 committed by yuhan6665
parent f67167bb3b
commit a343d68944
3 changed files with 107 additions and 53 deletions

View File

@ -4,6 +4,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"errors"
"hash/crc32"
"io"
@ -236,19 +237,26 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
}
func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
bs := payload.Bytes()
if len(bs) <= 32 {
return nil, nil, newError("len(bs) <= 32")
rawPayload := payload.Bytes()
user, _, d, _, err := validator.Get(rawPayload, protocol.RequestCommandUDP)
if errors.Is(err, ErrIVNotUnique) {
return nil, nil, newError("failed iv check").Base(err)
}
user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP)
switch err {
case ErrIVNotUnique:
return nil, nil, newError("failed iv check").Base(err)
case ErrNotFound:
if errors.Is(err, ErrNotFound) {
return nil, nil, newError("failed to match an user").Base(err)
default:
account := user.Account.(*MemoryAccount)
}
if err != nil {
return nil, nil, newError("unexpected error").Base(err)
}
account, ok := user.Account.(*MemoryAccount)
if !ok {
return nil, nil, newError("expected MemoryAccount returned from validator")
}
if account.Cipher.IsAEAD() {
payload.Clear()
payload.Write(d)
@ -261,13 +269,6 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
}
}
}
request := &protocol.RequestHeader{
Version: Version,
User: user,
Command: protocol.RequestCommandUDP,
}
payload.SetByte(0, payload.Byte(0)&0x0F)
@ -276,8 +277,13 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
return nil, nil, newError("failed to parse address").Base(err)
}
request.Address = addr
request.Port = port
request := &protocol.RequestHeader{
Version: Version,
User: user,
Command: protocol.RequestCommandUDP,
Address: addr,
Port: port,
}
return request, payload, nil
}

View File

@ -23,8 +23,9 @@ func equalRequestHeader(x, y *protocol.RequestHeader) bool {
}))
}
func TestUDPEncoding(t *testing.T) {
request := &protocol.RequestHeader{
func TestUDPEncodingDecoding(t *testing.T) {
testRequests := []protocol.RequestHeader{
{
Version: Version,
Command: protocol.RequestCommandUDP,
Address: net.LocalHostIP,
@ -36,11 +37,26 @@ func TestUDPEncoding(t *testing.T) {
CipherType: CipherType_AES_128_GCM,
}),
},
},
{
Version: Version,
Command: protocol.RequestCommandUDP,
Address: net.LocalHostIP,
Port: 1234,
User: &protocol.MemoryUser{
Email: "love@example.com",
Account: toAccount(&Account{
Password: "123",
CipherType: CipherType_NONE,
}),
},
},
}
for _, request := range testRequests {
data := buf.New()
common.Must2(data.WriteString("test string"))
encodedData, err := EncodeUDPPacket(request, data.Bytes())
encodedData, err := EncodeUDPPacket(&request, data.Bytes())
common.Must(err)
validator := new(Validator)
@ -52,10 +68,37 @@ func TestUDPEncoding(t *testing.T) {
t.Error("data: ", r)
}
if equalRequestHeader(decodedRequest, request) == false {
if equalRequestHeader(decodedRequest, &request) == false {
t.Error("different request")
}
}
}
func TestUDPDecodingWithPayloadTooShort(t *testing.T) {
testAccounts := []protocol.Account{
toAccount(&Account{
Password: "password",
CipherType: CipherType_AES_128_GCM,
}),
toAccount(&Account{
Password: "password",
CipherType: CipherType_NONE,
}),
}
for _, account := range testAccounts {
data := buf.New()
data.WriteString("short payload")
validator := new(Validator)
validator.Add(&protocol.MemoryUser{
Account: account,
})
_, _, err := DecodeUDPPacket(validator, data)
if err == nil {
t.Fatal("expected error")
}
}
}
func TestTCPRequest(t *testing.T) {
cases := []struct {

View File

@ -80,6 +80,11 @@ func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol
for _, user := range v.users {
if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() {
// AEAD payload decoding requires the payload to be over 32 bytes
if len(bs) < 32 {
continue
}
aeadCipher := account.Cipher.(*AEADCipher)
ivLen = aeadCipher.IVSize()
iv := bs[:ivLen]