mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-12-22 19:33:32 +02:00
Disable VMess drain when not pure connection
This commit is contained in:
parent
ff9bb2d8df
commit
f390047b37
3 changed files with 16 additions and 7 deletions
|
@ -57,14 +57,14 @@ func TestRequestSerialization(t *testing.T) {
|
|||
defer common.Close(userValidator)
|
||||
|
||||
server := NewServerSession(userValidator, sessionHistory)
|
||||
actualRequest, err := server.DecodeRequestHeader(buffer)
|
||||
actualRequest, err := server.DecodeRequestHeader(buffer, false)
|
||||
common.Must(err)
|
||||
|
||||
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
|
||||
t.Error(r)
|
||||
}
|
||||
|
||||
_, err = server.DecodeRequestHeader(buffer2)
|
||||
_, err = server.DecodeRequestHeader(buffer2, false)
|
||||
// anti replay attack
|
||||
if err == nil {
|
||||
t.Error("nil error")
|
||||
|
@ -107,7 +107,7 @@ func TestInvalidRequest(t *testing.T) {
|
|||
defer common.Close(userValidator)
|
||||
|
||||
server := NewServerSession(userValidator, sessionHistory)
|
||||
_, err := server.DecodeRequestHeader(buffer)
|
||||
_, err := server.DecodeRequestHeader(buffer, false)
|
||||
if err == nil {
|
||||
t.Error("nil error")
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ func TestMuxRequest(t *testing.T) {
|
|||
defer common.Close(userValidator)
|
||||
|
||||
server := NewServerSession(userValidator, sessionHistory)
|
||||
actualRequest, err := server.DecodeRequestHeader(buffer)
|
||||
actualRequest, err := server.DecodeRequestHeader(buffer, false)
|
||||
common.Must(err)
|
||||
|
||||
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
|
||||
|
|
|
@ -131,7 +131,7 @@ func parseSecurityType(b byte) protocol.SecurityType {
|
|||
}
|
||||
|
||||
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
|
||||
func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
|
||||
func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) {
|
||||
buffer := buf.New()
|
||||
behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed()))
|
||||
BaseDrainSize := behaviorRand.Roll(3266)
|
||||
|
@ -143,7 +143,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
|||
drainConnection := func(e error) error {
|
||||
// We read a deterministic generated length of data before closing the connection to offset padding read pattern
|
||||
readSizeRemain -= int(buffer.Len())
|
||||
if readSizeRemain > 0 {
|
||||
if readSizeRemain > 0 && isDrain {
|
||||
err := s.DrainConnN(reader, readSizeRemain)
|
||||
if err != nil {
|
||||
return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e)
|
||||
|
|
|
@ -220,9 +220,18 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
|
|||
return newError("unable to set read deadline").Base(err).AtWarning()
|
||||
}
|
||||
|
||||
iConn := connection
|
||||
if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
|
||||
iConn = statConn.Connection
|
||||
}
|
||||
_, isDrain := iConn.(*net.TCPConn)
|
||||
if !isDrain {
|
||||
_, isDrain = iConn.(*net.UnixConn)
|
||||
}
|
||||
|
||||
reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
|
||||
svrSession := encoding.NewServerSession(h.clients, h.sessionHistory)
|
||||
request, err := svrSession.DecodeRequestHeader(reader)
|
||||
request, err := svrSession.DecodeRequestHeader(reader, isDrain)
|
||||
if err != nil {
|
||||
if errors.Cause(err) != io.EOF {
|
||||
log.Record(&log.AccessMessage{
|
||||
|
|
Loading…
Reference in a new issue