package mtproto

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"io"
	"sync"

	"github.com/xtls/xray-core/common"
)

const (
	// HeaderSize is the length in bytes of the handshake header stored in
	// Authentication.Header; ReadAuthentication reads exactly this many bytes.
	HeaderSize = 64
)

// SessionContext holds per-session parameters that can be attached to a
// context.Context with ContextWithSessionContext and read back with
// SessionContextFromContext.
type SessionContext struct {
	// ConnectionType is the 4-byte tag that NewAuthentication stamps into
	// bytes 56..59 of a freshly generated header.
	ConnectionType [4]byte
	// DataCenterID is not read by any code in this file.
	// NOTE(review): presumably the data center to route to — confirm against callers.
	DataCenterID   uint16
}

// DefaultSessionContext returns the SessionContext used when none has been
// attached to a context: every ConnectionType byte is 0xef and DataCenterID
// is zero.
func DefaultSessionContext() SessionContext {
	var sc SessionContext
	for i := range sc.ConnectionType {
		sc.ConnectionType[i] = 0xef
	}
	return sc
}

// contextKey is the private key type for values this package stores in a
// context.Context. Because the type is unexported, keys defined by other
// packages cannot compare equal to it.
type contextKey int32

const (
	// sessionContextKey is the key under which a SessionContext is stored.
	sessionContextKey contextKey = iota
)

// ContextWithSessionContext returns a context derived from ctx that carries c.
// The value is stored by copy and can be retrieved with
// SessionContextFromContext.
func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
	return context.WithValue(ctx, sessionContextKey, c)
}

// SessionContextFromContext returns the SessionContext stored in ctx, or
// DefaultSessionContext() when ctx carries none.
func SessionContextFromContext(ctx context.Context) SessionContext {
	stored := ctx.Value(sessionContextKey)
	if stored == nil {
		return DefaultSessionContext()
	}
	// Only ContextWithSessionContext writes under this key, so the value is
	// expected to be a SessionContext; any other type panics here.
	return stored.(SessionContext)
}

// Authentication holds a 64-byte handshake header together with the key and
// nonce material derived from it.
//
// The 48 bytes at Header[8:56] are split into a 32-byte key followed by a
// 16-byte nonce; the byte-reversed copy of the same 48 bytes yields the key
// and nonce for the opposite direction. NewAuthentication assigns the forward
// bytes to the Encoding* fields, ReadAuthentication assigns them to the
// Decoding* fields.
type Authentication struct {
	// Header is the raw handshake header.
	Header        [HeaderSize]byte
	// DecodingKey is the 32-byte key for the decoding direction.
	DecodingKey   [32]byte
	// EncodingKey is the 32-byte key for the encoding direction.
	EncodingKey   [32]byte
	// DecodingNonce is the 16-byte nonce paired with DecodingKey.
	DecodingNonce [16]byte
	// EncodingNonce is the 16-byte nonce paired with EncodingKey.
	EncodingNonce [16]byte
}

// DataCenterID decodes header bytes 60..61 as a little-endian signed 16-bit
// integer, takes its absolute value and subtracts one.
//
// The subtraction is done in uint16 arithmetic, so a raw value of 0 wraps
// around to 65535.
func (a *Authentication) DataCenterID() uint16 {
	raw := int16(uint16(a.Header[60]) | uint16(a.Header[61])<<8)
	if raw < 0 {
		raw = -raw
	}
	return uint16(raw) - 1
}

// ConnectionType returns a copy of the 4-byte connection type tag stored at
// header bytes 56..59.
func (a *Authentication) ConnectionType() [4]byte {
	return [4]byte{
		a.Header[56],
		a.Header[57],
		a.Header[58],
		a.Header[59],
	}
}

// ApplySecret mixes the secret b into both keys: each key is replaced by
// SHA-256(key || b). The decoding key is updated first, then the encoding key.
func (a *Authentication) ApplySecret(b []byte) {
	mix := func(key [32]byte) [32]byte {
		// Build key||b in a fresh buffer so neither input is modified.
		material := make([]byte, 0, len(key)+len(b))
		material = append(material, key[:]...)
		material = append(material, b...)
		return sha256.Sum256(material)
	}

	a.DecodingKey = mix(a.DecodingKey)
	a.EncodingKey = mix(a.EncodingKey)
}

// generateRandomBytes fills random with cryptographically random bytes,
// retrying until the leading bytes avoid a set of reserved patterns, and then
// overwrites random[56:60] with connType.
//
// random must be at least 60 bytes long. A failure of the random source
// panics via common.Must2.
func generateRandomBytes(random []byte, connType [4]byte) {
	for {
		common.Must2(rand.Read(random))

		// NOTE(review): 0xef as the first byte is presumably reserved by
		// another transport framing — confirm against the protocol spec.
		if random[0] == 0xef {
			continue
		}

		// First four bytes read as a little-endian word. The rejected values
		// spell "HEAD", "POST", "GET " and "OPTI" in ASCII, plus the
		// all-0xee pattern.
		switch uint32(random[0]) | uint32(random[1])<<8 | uint32(random[2])<<16 | uint32(random[3])<<24 {
		case 0x44414548, 0x54534f50, 0x20544547, 0x4954504f, 0xeeeeeeee:
			continue
		}

		// The second little-endian word must not be zero.
		if (uint32(random[4]) | uint32(random[5])<<8 | uint32(random[6])<<16 | uint32(random[7])<<24) == 0 {
			continue
		}

		break
	}

	copy(random[56:60], connType[:])
}

// NewAuthentication takes an Authentication from the pool, fills its header
// with fresh random bytes tagged with sc.ConnectionType, and derives the key
// material from it.
//
// Header[8:56] supplies the encoding key (first 32 bytes) and encoding nonce
// (last 16 bytes); the byte-reversed copy of the same range supplies the
// decoding key and nonce. sc.DataCenterID is not used here.
func NewAuthentication(sc SessionContext) *Authentication {
	auth := getAuthenticationObject()

	header := auth.Header[:]
	generateRandomBytes(header, sc.ConnectionType)

	keyIV := header[8 : 8+32+16]
	copy(auth.EncodingKey[:], keyIV[:32])
	copy(auth.EncodingNonce[:], keyIV[32:])

	reversed := Inverse(keyIV)
	copy(auth.DecodingKey[:], reversed[:32])
	copy(auth.DecodingNonce[:], reversed[32:])

	return auth
}

// ReadAuthentication reads a full HeaderSize-byte header from reader into a
// pooled Authentication and derives the key material from it.
//
// Header[8:56] supplies the decoding key (first 32 bytes) and decoding nonce
// (last 16 bytes); the byte-reversed copy of the same range supplies the
// encoding key and nonce. If the header cannot be read in full, the object is
// returned to the pool and the read error is passed through unchanged.
func ReadAuthentication(reader io.Reader) (*Authentication, error) {
	auth := getAuthenticationObject()

	_, err := io.ReadFull(reader, auth.Header[:])
	if err != nil {
		putAuthenticationObject(auth)
		return nil, err
	}

	keyIV := auth.Header[8 : 8+32+16]
	copy(auth.DecodingKey[:], keyIV[:32])
	copy(auth.DecodingNonce[:], keyIV[32:])

	reversed := Inverse(keyIV)
	copy(auth.EncodingKey[:], reversed[:32])
	copy(auth.EncodingNonce[:], reversed[32:])

	return auth, nil
}

// Inverse returns a newly allocated slice holding the bytes of b in reverse
// order. The input is left untouched; a nil or empty input yields an empty,
// non-nil slice.
// Visible for testing only.
func Inverse(b []byte) []byte {
	out := make([]byte, len(b))
	for src, dst := 0, len(b)-1; dst >= 0; src, dst = src+1, dst-1 {
		out[dst] = b[src]
	}
	return out
}

// authPool recycles Authentication objects. Objects handed out by the pool
// are not cleared here, so they may still hold data from a previous use.
var authPool = sync.Pool{
	New: func() interface{} {
		return &Authentication{}
	},
}

// getAuthenticationObject takes an Authentication from authPool, allocating a
// new one when the pool is empty. The returned object is not reset.
func getAuthenticationObject() *Authentication {
	pooled := authPool.Get()
	return pooled.(*Authentication)
}

// putAuthenticationObject hands auth back to authPool for reuse. The object is
// stored as-is, without clearing its header or key material, and must not be
// used by the caller afterwards.
func putAuthenticationObject(auth *Authentication) {
	authPool.Put(auth)
}