From f390047b37ee1dc68a9b0af749816db4670ebb1d Mon Sep 17 00:00:00 2001 From: RPRX <63339210+rprx@users.noreply.github.com> Date: Fri, 18 Dec 2020 12:45:47 +0000 Subject: [PATCH] Disable VMess drain when not pure connection --- proxy/vmess/encoding/encoding_test.go | 8 ++++---- proxy/vmess/encoding/server.go | 4 ++-- proxy/vmess/inbound/inbound.go | 11 ++++++++++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index c136731e..b8c34cf8 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -57,14 +57,14 @@ func TestRequestSerialization(t *testing.T) { defer common.Close(userValidator) server := NewServerSession(userValidator, sessionHistory) - actualRequest, err := server.DecodeRequestHeader(buffer) + actualRequest, err := server.DecodeRequestHeader(buffer, false) common.Must(err) if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { t.Error(r) } - _, err = server.DecodeRequestHeader(buffer2) + _, err = server.DecodeRequestHeader(buffer2, false) // anti replay attack if err == nil { t.Error("nil error") @@ -107,7 +107,7 @@ func TestInvalidRequest(t *testing.T) { defer common.Close(userValidator) server := NewServerSession(userValidator, sessionHistory) - _, err := server.DecodeRequestHeader(buffer) + _, err := server.DecodeRequestHeader(buffer, false) if err == nil { t.Error("nil error") } @@ -148,7 +148,7 @@ func TestMuxRequest(t *testing.T) { defer common.Close(userValidator) server := NewServerSession(userValidator, sessionHistory) - actualRequest, err := server.DecodeRequestHeader(buffer) + actualRequest, err := server.DecodeRequestHeader(buffer, false) common.Must(err) if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 9951919f..66711fd1 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -131,7 +131,7 @@ func parseSecurityType(b byte) protocol.SecurityType { } // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream. -func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) { +func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) { buffer := buf.New() behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed())) BaseDrainSize := behaviorRand.Roll(3266) @@ -143,7 +143,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request drainConnection := func(e error) error { // We read a deterministic generated length of data before closing the connection to offset padding read pattern readSizeRemain -= int(buffer.Len()) - if readSizeRemain > 0 { + if readSizeRemain > 0 && isDrain { err := s.DrainConnN(reader, readSizeRemain) if err != nil { return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index ff962e6a..5c01ac57 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -220,9 +220,18 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i return newError("unable to set read deadline").Base(err).AtWarning() } + iConn := connection + if statConn, ok := iConn.(*internet.StatCouterConnection); ok { + iConn = statConn.Connection + } + _, isDrain := iConn.(*net.TCPConn) + if !isDrain { + _, isDrain = iConn.(*net.UnixConn) + } + reader := &buf.BufferedReader{Reader: buf.NewReader(connection)} svrSession := encoding.NewServerSession(h.clients, h.sessionHistory) - request, err := svrSession.DecodeRequestHeader(reader) + request, err := svrSession.DecodeRequestHeader(reader, isDrain) if err != nil { if errors.Cause(err) != io.EOF { log.Record(&log.AccessMessage{