package crypto_test

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"io"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/xtls/xray-core/common"
	"github.com/xtls/xray-core/common/buf"
	. "github.com/xtls/xray-core/common/crypto"
	"github.com/xtls/xray-core/common/protocol"
)

// TestAuthenticationReaderWriter round-trips an 80 KiB random payload through an
// AES-128-GCM AuthenticationWriter / AuthenticationReader pair in stream mode and
// checks that the decoded bytes equal the input and that the stream then ends
// with io.EOF.
func TestAuthenticationReaderWriter(t *testing.T) {
	key := make([]byte, 16)
	common.Must2(rand.Read(key))
	block, err := aes.NewCipher(key)
	common.Must(err)

	aead, err := cipher.NewGCM(block)
	common.Must(err)

	const payloadSize = 1024 * 80
	rawPayload := make([]byte, payloadSize)
	common.Must2(rand.Read(rawPayload))

	payload := buf.MergeBytes(nil, rawPayload)

	cache := bytes.NewBuffer(nil)
	iv := make([]byte, 12)
	common.Must2(rand.Read(iv))

	// Writer and reader each get their own authenticator built from the same
	// AEAD and static nonce so that the reader can open what the writer sealed.
	newAuthenticator := func() *AEADAuthenticator {
		return &AEADAuthenticator{
			AEAD:                    aead,
			NonceGenerator:          GenerateStaticBytes(iv),
			AdditionalDataGenerator: GenerateEmptyBytes(),
		}
	}

	writer := NewAuthenticationWriter(newAuthenticator(), PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)

	common.Must(writer.WriteMultiBuffer(payload))
	// Sealing adds per-chunk overhead (at least the GCM tag), so the encoded
	// stream must be strictly larger than the plaintext.
	if cache.Len() <= payloadSize {
		t.Error("cache len: ", cache.Len())
	}
	// An empty write presumably emits the end-of-stream marker: the final read
	// below expects io.EOF once the payload has been consumed.
	common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))

	reader := NewAuthenticationReader(newAuthenticator(), PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)

	// The payload may arrive split across several reads; accumulate until all
	// of it has been decoded.
	var mb buf.MultiBuffer
	for mb.Len() < payloadSize {
		mb2, err := reader.ReadMultiBuffer()
		if err != nil {
			t.Fatal("read: ", err)
		}
		mb, _ = buf.MergeMulti(mb, mb2)
	}

	if mb.Len() != payloadSize {
		t.Error("mb len: ", mb.Len())
	}

	mbContent := make([]byte, payloadSize)
	if _, n := buf.SplitBytes(mb, mbContent); n != payloadSize {
		t.Error("split bytes: ", n)
	}
	if r := cmp.Diff(mbContent, rawPayload); r != "" {
		t.Error(r)
	}

	_, err = reader.ReadMultiBuffer()
	if err != io.EOF {
		t.Error("error: ", err)
	}
}

// TestAuthenticationReaderWriterPacket writes two small buffers through an
// AES-128-GCM AuthenticationWriter in packet mode and checks that the reader
// returns them as two separate buffers ("abcd", then "efgh") followed by io.EOF.
func TestAuthenticationReaderWriterPacket(t *testing.T) {
	key := make([]byte, 16)
	common.Must2(rand.Read(key))
	block, err := aes.NewCipher(key)
	common.Must(err)

	aead, err := cipher.NewGCM(block)
	common.Must(err)

	cache := buf.New()
	iv := make([]byte, 12)
	common.Must2(rand.Read(iv))

	// Writer and reader each get their own authenticator built from the same
	// AEAD and static nonce so that the reader can open what the writer sealed.
	newAuthenticator := func() *AEADAuthenticator {
		return &AEADAuthenticator{
			AEAD:                    aead,
			NonceGenerator:          GenerateStaticBytes(iv),
			AdditionalDataGenerator: GenerateEmptyBytes(),
		}
	}

	writer := NewAuthenticationWriter(newAuthenticator(), PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)

	var payload buf.MultiBuffer
	pb1 := buf.New()
	common.Must2(pb1.Write([]byte("abcd")))
	payload = append(payload, pb1)

	pb2 := buf.New()
	common.Must2(pb2.Write([]byte("efgh")))
	payload = append(payload, pb2)

	common.Must(writer.WriteMultiBuffer(payload))

	if cache.Len() == 0 {
		t.Error("cache len: ", cache.Len())
	}

	// An empty write presumably emits the end-of-stream marker: the final read
	// below expects io.EOF after both packets.
	common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))

	reader := NewAuthenticationReader(newAuthenticator(), PlainChunkSizeParser{}, cache, protocol.TransferTypePacket, nil)

	mb, err := reader.ReadMultiBuffer()
	if err != nil {
		t.Fatal("read: ", err)
	}

	// Guard against a missing buffer before calling String on it, which would
	// otherwise dereference nil and crash the test binary.
	mb, b1 := buf.SplitFirst(mb)
	if b1 == nil {
		t.Fatal("b1: missing")
	}
	if b1.String() != "abcd" {
		t.Error("b1: ", b1.String())
	}

	mb, b2 := buf.SplitFirst(mb)
	if b2 == nil {
		t.Fatal("b2: missing")
	}
	if b2.String() != "efgh" {
		t.Error("b2: ", b2.String())
	}

	if !mb.IsEmpty() {
		t.Error("not empty")
	}

	_, err = reader.ReadMultiBuffer()
	if err != io.EOF {
		t.Error("error: ", err)
	}
}