mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-22 20:59:19 +02:00
Disable VMess drain when not pure connection
This commit is contained in:
parent
ff9bb2d8df
commit
f390047b37
|
@ -57,14 +57,14 @@ func TestRequestSerialization(t *testing.T) {
|
||||||
defer common.Close(userValidator)
|
defer common.Close(userValidator)
|
||||||
|
|
||||||
server := NewServerSession(userValidator, sessionHistory)
|
server := NewServerSession(userValidator, sessionHistory)
|
||||||
actualRequest, err := server.DecodeRequestHeader(buffer)
|
actualRequest, err := server.DecodeRequestHeader(buffer, false)
|
||||||
common.Must(err)
|
common.Must(err)
|
||||||
|
|
||||||
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
|
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
|
||||||
t.Error(r)
|
t.Error(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = server.DecodeRequestHeader(buffer2)
|
_, err = server.DecodeRequestHeader(buffer2, false)
|
||||||
// anti replay attack
|
// anti replay attack
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("nil error")
|
t.Error("nil error")
|
||||||
|
@ -107,7 +107,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||||
defer common.Close(userValidator)
|
defer common.Close(userValidator)
|
||||||
|
|
||||||
server := NewServerSession(userValidator, sessionHistory)
|
server := NewServerSession(userValidator, sessionHistory)
|
||||||
_, err := server.DecodeRequestHeader(buffer)
|
_, err := server.DecodeRequestHeader(buffer, false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("nil error")
|
t.Error("nil error")
|
||||||
}
|
}
|
||||||
|
@ -148,7 +148,7 @@ func TestMuxRequest(t *testing.T) {
|
||||||
defer common.Close(userValidator)
|
defer common.Close(userValidator)
|
||||||
|
|
||||||
server := NewServerSession(userValidator, sessionHistory)
|
server := NewServerSession(userValidator, sessionHistory)
|
||||||
actualRequest, err := server.DecodeRequestHeader(buffer)
|
actualRequest, err := server.DecodeRequestHeader(buffer, false)
|
||||||
common.Must(err)
|
common.Must(err)
|
||||||
|
|
||||||
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
|
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
|
||||||
|
|
|
@ -131,7 +131,7 @@ func parseSecurityType(b byte) protocol.SecurityType {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
|
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
|
||||||
func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
|
func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) {
|
||||||
buffer := buf.New()
|
buffer := buf.New()
|
||||||
behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed()))
|
behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed()))
|
||||||
BaseDrainSize := behaviorRand.Roll(3266)
|
BaseDrainSize := behaviorRand.Roll(3266)
|
||||||
|
@ -143,7 +143,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
||||||
drainConnection := func(e error) error {
|
drainConnection := func(e error) error {
|
||||||
// We read a deterministic generated length of data before closing the connection to offset padding read pattern
|
// We read a deterministic generated length of data before closing the connection to offset padding read pattern
|
||||||
readSizeRemain -= int(buffer.Len())
|
readSizeRemain -= int(buffer.Len())
|
||||||
if readSizeRemain > 0 {
|
if readSizeRemain > 0 && isDrain {
|
||||||
err := s.DrainConnN(reader, readSizeRemain)
|
err := s.DrainConnN(reader, readSizeRemain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e)
|
return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e)
|
||||||
|
|
|
@ -220,9 +220,18 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
|
||||||
return newError("unable to set read deadline").Base(err).AtWarning()
|
return newError("unable to set read deadline").Base(err).AtWarning()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
iConn := connection
|
||||||
|
if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
|
||||||
|
iConn = statConn.Connection
|
||||||
|
}
|
||||||
|
_, isDrain := iConn.(*net.TCPConn)
|
||||||
|
if !isDrain {
|
||||||
|
_, isDrain = iConn.(*net.UnixConn)
|
||||||
|
}
|
||||||
|
|
||||||
reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
|
reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
|
||||||
svrSession := encoding.NewServerSession(h.clients, h.sessionHistory)
|
svrSession := encoding.NewServerSession(h.clients, h.sessionHistory)
|
||||||
request, err := svrSession.DecodeRequestHeader(reader)
|
request, err := svrSession.DecodeRequestHeader(reader, isDrain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Cause(err) != io.EOF {
|
if errors.Cause(err) != io.EOF {
|
||||||
log.Record(&log.AccessMessage{
|
log.Record(&log.AccessMessage{
|
||||||
|
|
Loading…
Reference in New Issue