From df39991bb328638503b52bac7f68c993f71e3b24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A7=8B=E3=81=AE=E3=81=8B=E3=81=88=E3=81=A7?= Date: Fri, 12 Feb 2021 23:17:31 +0800 Subject: [PATCH] Refactor: Add Shadowsocks Validator (#233) --- proxy/shadowsocks/config.go | 27 ++++++ proxy/shadowsocks/protocol.go | 129 ++++++++++++----------------- proxy/shadowsocks/protocol_test.go | 8 +- proxy/shadowsocks/server.go | 46 ++++++---- proxy/shadowsocks/validator.go | 113 +++++++++++++++++++++++++ 5 files changed, 230 insertions(+), 93 deletions(-) create mode 100644 proxy/shadowsocks/validator.go diff --git a/proxy/shadowsocks/config.go b/proxy/shadowsocks/config.go index dfc3ce3d..607894b1 100644 --- a/proxy/shadowsocks/config.go +++ b/proxy/shadowsocks/config.go @@ -7,6 +7,8 @@ import ( "crypto/md5" "crypto/sha1" "io" + "reflect" + "strconv" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/hkdf" @@ -31,6 +33,31 @@ func (a *MemoryAccount) Equals(another protocol.Account) bool { return false } +func (a *MemoryAccount) GetCipherName() string { + switch a.Cipher.(type) { + case *AesCfb: + keyBytes := a.Cipher.(*AesCfb).KeyBytes + return "AES_" + strconv.FormatInt(int64(keyBytes*8), 10) + "_CFB" + case *ChaCha20: + if a.Cipher.(*ChaCha20).IVBytes == 8 { + return "CHACHA20" + } + return "CHACHA20_IETF" + case *AEADCipher: + switch reflect.ValueOf(a.Cipher.(*AEADCipher).AEADAuthCreator).Pointer() { + case reflect.ValueOf(createAesGcm).Pointer(): + keyBytes := a.Cipher.(*AEADCipher).KeyBytes + return "AES_" + strconv.FormatInt(int64(keyBytes*8), 10) + "_GCM" + case reflect.ValueOf(createChacha20Poly1305).Pointer(): + return "CHACHA20_POLY1305" + } + case *NoneCipher: + return "NONE" + } + + return "" +} + func createAesGcm(key []byte) cipher.AEAD { block, err := aes.NewCipher(key) common.Must(err) diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index d85fabfc..69db3ba2 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -54,12 +54,9 @@ func (r *FullReader) Read(p []byte) (n int, err error) { } // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts. -func ReadTCPSession(users []*protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) { - user := users[0] - account := user.Account.(*MemoryAccount) +func ReadTCPSession(validator *Validator, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) { hashkdf := hmac.New(sha256.New, []byte("SSBSKDF")) - hashkdf.Write(account.Key) behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil)) @@ -71,10 +68,20 @@ func ReadTCPSession(users []*protocol.MemoryUser, reader io.Reader) (*protocol.R readSizeRemain := DrainSize var r2 buf.Reader + buffer := buf.New() + defer buffer.Release() - if len(users) > 1 { - buffer := buf.New() - defer buffer.Release() + var user *protocol.MemoryUser + var ivLen int32 + var err error + + count := validator.Count() + if count == 0 { + readSizeRemain -= int(buffer.Len()) + DrainConnN(reader, readSizeRemain) + return nil, nil, newError("invalid user") + } else if count > 1 { + var aead cipher.AEAD if _, err := buffer.ReadFullFrom(reader, 50); err != nil { readSizeRemain -= int(buffer.Len()) @@ -83,45 +90,26 @@ func ReadTCPSession(users []*protocol.MemoryUser, reader io.Reader) (*protocol.R } bs := buffer.Bytes() + user, aead, _, ivLen, err = validator.Get(bs, protocol.RequestCommandTCP) - var aeadCipher *AEADCipher - var ivLen int32 - subkey := make([]byte, 32) - length := make([]byte, 16) - var aead cipher.AEAD - var err error - for _, user = range users { - account = user.Account.(*MemoryAccount) - aeadCipher = account.Cipher.(*AEADCipher) - ivLen = aeadCipher.IVSize() - subkey = subkey[:aeadCipher.KeyBytes] - hkdfSHA1(account.Key, bs[:ivLen], subkey) - aead = aeadCipher.AEADAuthCreator(subkey) - _, err = aead.Open(length[:0], length[4:16], bs[ivLen:ivLen+18], nil) - if err == nil { - reader = &FullReader{reader, bs[ivLen:]} - auth := &crypto.AEADAuthenticator{ - AEAD: aead, - NonceGenerator: crypto.GenerateInitialAEADNonce(), - } - r2 = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{ - Auth: auth, - }, reader, protocol.TransferTypeStream, nil) - break + if user != nil { + reader = &FullReader{reader, bs[ivLen:]} + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: crypto.GenerateInitialAEADNonce(), } - } - if err != nil { + r2 = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{ + Auth: auth, + }, reader, protocol.TransferTypeStream, nil) + } else { readSizeRemain -= int(buffer.Len()) DrainConnN(reader, readSizeRemain) return nil, nil, newError("failed to match an user").Base(err) } - } - - buffer := buf.New() - defer buffer.Release() - - if r2 == nil { - ivLen := account.Cipher.IVSize() + } else { + user, ivLen = validator.GetOnlyUser() + account := user.Account.(*MemoryAccount) + hashkdf.Write(account.Key) var iv []byte if ivLen > 0 { if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil { @@ -261,40 +249,31 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff return buffer, nil } -func DecodeUDPPacket(users []*protocol.MemoryUser, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { +func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { + bs := payload.Bytes() + if len(bs) <= 32 { + return nil, nil, newError("len(bs) <= 32") + } + var user *protocol.MemoryUser - var account *MemoryAccount var err error - if len(users) > 1 { - bs := payload.Bytes() - if len(bs) <= 32 { - return nil, nil, newError("len(bs) <= 32") - } - - var aeadCipher *AEADCipher - var ivLen int32 - subkey := make([]byte, 32) - data := make([]byte, 8192) - var aead cipher.AEAD + count := validator.Count() + if count == 0 { + return nil, nil, newError("invalid user") + } else if count > 1 { var d []byte - for _, user = range users { - account = user.Account.(*MemoryAccount) - aeadCipher = account.Cipher.(*AEADCipher) - ivLen = aeadCipher.IVSize() - subkey = subkey[:aeadCipher.KeyBytes] - hkdfSHA1(account.Key, bs[:ivLen], subkey) - aead = aeadCipher.AEADAuthCreator(subkey) - d, err = aead.Open(data[:0], data[8180:8192], bs[ivLen:], nil) - if err == nil { - payload.Clear() - payload.Write(d) - break - } + user, _, d, _, err = validator.Get(bs, protocol.RequestCommandUDP) + + if user != nil { + payload.Clear() + payload.Write(d) + } else { + return nil, nil, newError("failed to decrypt UDP payload").Base(err) } } else { - user = users[0] - account = user.Account.(*MemoryAccount) + user, _ = validator.GetOnlyUser() + account := user.Account.(*MemoryAccount) var iv []byte if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 { @@ -302,12 +281,9 @@ func DecodeUDPPacket(users []*protocol.MemoryUser, payload *buf.Buffer) (*protoc iv = make([]byte, account.Cipher.IVSize()) copy(iv, payload.BytesTo(account.Cipher.IVSize())) } - - err = account.Cipher.DecodePacket(account.Key, payload) - } - - if err != nil { - return nil, nil, newError("failed to decrypt UDP payload").Base(err) + if err = account.Cipher.DecodePacket(account.Key, payload); err != nil { + return nil, nil, newError("failed to decrypt UDP payload").Base(err) + } } request := &protocol.RequestHeader{ @@ -341,7 +317,10 @@ func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { buffer.Release() return nil, err } - u, payload, err := DecodeUDPPacket([]*protocol.MemoryUser{v.User}, buffer) + validator := new(Validator) + validator.Add(v.User) + + u, payload, err := DecodeUDPPacket(validator, buffer) if err != nil { buffer.Release() return nil, err diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index 3183a39f..5663a5d9 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -38,7 +38,9 @@ func TestUDPEncoding(t *testing.T) { encodedData, err := EncodeUDPPacket(request, data.Bytes()) common.Must(err) - decodedRequest, decodedData, err := DecodeUDPPacket([]*protocol.MemoryUser{request.User}, encodedData) + validator := new(Validator) + validator.Add(request.User) + decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData) common.Must(err) if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" { @@ -117,7 +119,9 @@ func TestTCPRequest(t *testing.T) { common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{data})) - decodedRequest, reader, err := ReadTCPSession([]*protocol.MemoryUser{request.User}, cache) + validator := new(Validator) + validator.Add(request.User) + decodedRequest, reader, err := ReadTCPSession(validator, cache) common.Must(err) if r := cmp.Diff(decodedRequest, request, cmp.Comparer(func(a1, a2 protocol.Account) bool { return a1.Equals(a2) })); r != "" { t.Error("request: ", r) diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 62e2a5e4..c34798f6 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -22,35 +22,46 @@ import ( type Server struct { config *ServerConfig - users []*protocol.MemoryUser + validator *Validator policyManager policy.Manager cone bool } // NewServer create a new Shadowsocks server. func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { - if config.Users == nil { - return nil, newError("empty users") + validator := new(Validator) + for _, user := range config.Users { + u, err := user.ToMemoryUser() + if err != nil { + return nil, newError("failed to get shadowsocks user").Base(err).AtError() + } + + if err := validator.Add(u); err != nil { + return nil, newError("failed to add user").Base(err).AtError() + } } v := core.MustFromContext(ctx) s := &Server{ config: config, + validator: validator, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), cone: ctx.Value("cone").(bool), } - for _, user := range config.Users { - u, err := user.ToMemoryUser() - if err != nil { - return nil, newError("failed to parse user account").Base(err) - } - s.users = append(s.users, u) - } - return s, nil } +// AddUser implements proxy.UserManager.AddUser(). +func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { + return s.validator.Add(u) +} + +// RemoveUser implements proxy.UserManager.RemoveUser(). +func (s *Server) RemoveUser(ctx context.Context, e string) error { + return s.validator.Del(e) +} + func (s *Server) Network() []net.Network { list := s.config.Network if len(list) == 0 { @@ -102,8 +113,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, if inbound == nil { panic("no inbound metadata") } - if len(s.users) == 1 { - inbound.User = s.users[0] + + if s.validator.Count() == 1 { + inbound.User, _ = s.validator.GetOnlyUser() } var dest *net.Destination @@ -121,9 +133,11 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, var err error if inbound.User != nil { - request, data, err = DecodeUDPPacket([]*protocol.MemoryUser{inbound.User}, payload) + validator := new(Validator) + validator.Add(inbound.User) + request, data, err = DecodeUDPPacket(validator, payload) } else { - request, data, err = DecodeUDPPacket(s.users, payload) + request, data, err = DecodeUDPPacket(s.validator, payload) if err == nil { inbound.User = request.User } @@ -178,7 +192,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, } bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)} - request, bodyReader, err := ReadTCPSession(s.users, &bufferedReader) + request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader) if err != nil { log.Record(&log.AccessMessage{ From: conn.RemoteAddr(), diff --git a/proxy/shadowsocks/validator.go b/proxy/shadowsocks/validator.go new file mode 100644 index 00000000..7d796bc2 --- /dev/null +++ b/proxy/shadowsocks/validator.go @@ -0,0 +1,113 @@ +package shadowsocks + +import ( + "crypto/cipher" + "strings" + "sync" + + "github.com/xtls/xray-core/common/protocol" +) + +// Validator stores valid Shadowsocks users. +type Validator struct { + // Considering email's usage here, map + sync.Mutex/RWMutex may have better performance. + email sync.Map + users sync.Map +} + +// Add a Shadowsocks user, Email must be empty or unique. +func (v *Validator) Add(u *protocol.MemoryUser) error { + account := u.Account.(*MemoryAccount) + + if !account.Cipher.IsAEAD() && v.Count() > 0 { + return newError("The cipher do not support Single-port Multi-user") + } + + if u.Email != "" { + _, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u) + if loaded { + return newError("User ", u.Email, " already exists.") + } + } + + v.users.Store(string(account.Key)+"&"+account.GetCipherName(), u) + return nil +} + +// Del a Shadowsocks user with a non-empty Email. +func (v *Validator) Del(e string) error { + if e == "" { + return newError("Email must not be empty.") + } + le := strings.ToLower(e) + u, _ := v.email.Load(le) + if u == nil { + return newError("User ", e, " not found.") + } + account := u.(*protocol.MemoryUser).Account.(*MemoryAccount) + v.email.Delete(le) + v.users.Delete(string(account.Key) + "&" + account.GetCipherName()) + return nil +} + +// Count the number of Shadowsocks users +func (v *Validator) Count() int { + length := 0 + v.users.Range(func(_, _ interface{}) bool { + length++ + + return true + }) + return length +} + +// Get a Shadowsocks user and the user's cipher. +func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol.MemoryUser, aead cipher.AEAD, ret []byte, ivLen int32, err error) { + var dataSize int + + switch command { + case protocol.RequestCommandTCP: + dataSize = 16 + case protocol.RequestCommandUDP: + dataSize = 8192 + } + + var aeadCipher *AEADCipher + subkey := make([]byte, 32) + data := make([]byte, dataSize) + + v.users.Range(func(key, user interface{}) bool { + account := user.(*protocol.MemoryUser).Account.(*MemoryAccount) + aeadCipher = account.Cipher.(*AEADCipher) + ivLen = aeadCipher.IVSize() + subkey = subkey[:aeadCipher.KeyBytes] + hkdfSHA1(account.Key, bs[:ivLen], subkey) + aead = aeadCipher.AEADAuthCreator(subkey) + + switch command { + case protocol.RequestCommandTCP: + ret, err = aead.Open(data[:0], data[4:16], bs[ivLen:ivLen+18], nil) + case protocol.RequestCommandUDP: + ret, err = aead.Open(data[:0], data[8180:8192], bs[ivLen:], nil) + } + + if err == nil { + u = user.(*protocol.MemoryUser) + return false + } + return true + }) + + return +} + +// Get the only user without authentication +func (v *Validator) GetOnlyUser() (u *protocol.MemoryUser, ivLen int32) { + v.users.Range(func(_, user interface{}) bool { + u = user.(*protocol.MemoryUser) + return false + }) + ivLen = u.Account.(*MemoryAccount).Cipher.IVSize() + + return +}