Refactor: Shadowsocks & Trojan UDP FullCone NAT

https://t.me/projectXray/95704
This commit is contained in:
RPRX 2020-12-23 13:06:21 +00:00 committed by GitHub
parent 4140ed7ab0
commit 8f8f7dd66f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 234 additions and 44 deletions

View File

@ -2,6 +2,7 @@ package buf
import ( import (
"io" "io"
"net"
"github.com/xtls/xray-core/common/bytespool" "github.com/xtls/xray-core/common/bytespool"
) )
@ -20,6 +21,7 @@ type Buffer struct {
v []byte v []byte
start int32 start int32
end int32 end int32
UDP *net.UDPAddr
} }
// New creates a Buffer with 0 length and 2K capacity. // New creates a Buffer with 0 length and 2K capacity.

View File

@ -17,6 +17,7 @@ import (
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
) )
@ -148,7 +149,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if destination.Network == net.Network_TCP { if destination.Network == net.Network_TCP {
writer = buf.NewWriter(conn) writer = buf.NewWriter(conn)
} else { } else {
writer = &buf.SequentialWriter{Writer: conn} writer = NewPacketWriter(conn)
} }
if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
@ -165,7 +166,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if destination.Network == net.Network_TCP { if destination.Network == net.Network_TCP {
reader = buf.NewReader(conn) reader = buf.NewReader(conn)
} else { } else {
reader = buf.NewPacketReader(conn) reader = NewPacketReader(conn)
} }
if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
return newError("failed to process response").Base(err) return newError("failed to process response").Base(err)
@ -180,3 +181,93 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return nil return nil
} }
func NewPacketReader(conn net.Conn) buf.Reader {
iConn := conn
statConn, ok := iConn.(*internet.StatCouterConnection)
if ok {
iConn = statConn.Connection
}
var counter stats.Counter
if statConn != nil {
counter = statConn.ReadCounter
}
if c, ok := iConn.(*internet.PacketConnWrapper); ok {
return &PacketReader{
PacketConnWrapper: c,
Counter: counter,
}
}
return &buf.PacketReader{Reader: conn}
}
type PacketReader struct {
*internet.PacketConnWrapper
stats.Counter
}
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b := buf.New()
b.Resize(0, buf.Size)
n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
if err != nil {
b.Release()
return nil, err
}
b.Resize(0, int32(n))
b.UDP = d.(*net.UDPAddr)
if r.Counter != nil {
r.Counter.Add(int64(n))
}
return buf.MultiBuffer{b}, nil
}
func NewPacketWriter(conn net.Conn) buf.Writer {
iConn := conn
statConn, ok := iConn.(*internet.StatCouterConnection)
if ok {
iConn = statConn.Connection
}
var counter stats.Counter
if statConn != nil {
counter = statConn.WriteCounter
}
if c, ok := iConn.(*internet.PacketConnWrapper); ok {
return &PacketWriter{
PacketConnWrapper: c,
Counter: counter,
}
}
return &buf.SequentialWriter{Writer: conn}
}
type PacketWriter struct {
*internet.PacketConnWrapper
stats.Counter
}
func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for {
mb2, b := buf.SplitFirst(mb)
mb = mb2
if b == nil {
break
}
var n int
var err error
if b.UDP != nil {
n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), b.UDP)
} else {
n, err = w.PacketConnWrapper.Write(b.Bytes())
}
b.Release()
if err != nil {
buf.ReleaseMulti(mb)
return err
}
if w.Counter != nil {
w.Counter.Add(int64(n))
}
}
return nil
}

View File

@ -134,14 +134,15 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
} }
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
writer := &buf.SequentialWriter{Writer: &UDPWriter{
Writer: conn,
Request: request,
}}
requestDone := func() error { requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
writer := &UDPWriter{
Writer: conn,
Request: request,
}
if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all UDP request").Base(err) return newError("failed to transport all UDP request").Base(err)
} }

View File

@ -230,11 +230,15 @@ func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer.Release() buffer.Release()
return nil, err return nil, err
} }
_, payload, err := DecodeUDPPacket(v.User, buffer) u, payload, err := DecodeUDPPacket(v.User, buffer)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return nil, err return nil, err
} }
payload.UDP = &net.UDPAddr{
IP: u.Address.IP(),
Port: int(u.Port),
}
return buf.MultiBuffer{payload}, nil return buf.MultiBuffer{payload}, nil
} }
@ -243,13 +247,36 @@ type UDPWriter struct {
Request *protocol.RequestHeader Request *protocol.RequestHeader
} }
// Write implements io.Writer. func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
func (w *UDPWriter) Write(payload []byte) (int, error) { for {
packet, err := EncodeUDPPacket(w.Request, payload) mb2, b := buf.SplitFirst(mb)
mb = mb2
if b == nil {
break
}
var packet *buf.Buffer
var err error
if b.UDP != nil {
request := &protocol.RequestHeader{
User: w.Request.User,
Address: net.IPAddress(b.UDP.IP),
Port: net.Port(b.UDP.Port),
}
packet, err = EncodeUDPPacket(request, b.Bytes())
} else {
packet, err = EncodeUDPPacket(w.Request, b.Bytes())
}
b.Release()
if err != nil { if err != nil {
return 0, err buf.ReleaseMulti(mb)
return err
} }
_, err = w.Writer.Write(packet.Bytes()) _, err = w.Writer.Write(packet.Bytes())
packet.Release() packet.Release()
return len(payload), err if err != nil {
buf.ReleaseMulti(mb)
return err
}
}
return nil
} }

View File

@ -145,7 +145,7 @@ func TestUDPReaderWriter(t *testing.T) {
cache := buf.New() cache := buf.New()
defer cache.Release() defer cache.Release()
writer := &buf.SequentialWriter{Writer: &UDPWriter{ writer := &UDPWriter{
Writer: cache, Writer: cache,
Request: &protocol.RequestHeader{ Request: &protocol.RequestHeader{
Version: Version, Version: Version,
@ -153,7 +153,7 @@ func TestUDPReaderWriter(t *testing.T) {
Port: 123, Port: 123,
User: user, User: user,
}, },
}} }
reader := &UDPReader{ reader := &UDPReader{
Reader: cache, Reader: cache,

View File

@ -77,6 +77,15 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
} }
payload := packet.Payload payload := packet.Payload
if payload.UDP != nil {
request = &protocol.RequestHeader{
User: request.User,
Address: net.IPAddress(payload.UDP.IP),
Port: net.Port(payload.UDP.Port),
}
}
data, err := EncodeUDPPacket(request, payload.Bytes()) data, err := EncodeUDPPacket(request, payload.Bytes())
payload.Release() payload.Release()
if err != nil { if err != nil {
@ -94,6 +103,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
} }
inbound.User = s.user inbound.User = s.user
var dest net.Destination
reader := buf.NewPacketReader(conn) reader := buf.NewPacketReader(conn)
for { for {
mpayload, err := reader.ReadMultiBuffer() mpayload, err := reader.ReadMultiBuffer()
@ -118,17 +129,25 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
} }
currentPacketCtx := ctx currentPacketCtx := ctx
dest := request.Destination()
if inbound.Source.IsValid() { if inbound.Source.IsValid() {
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source, From: inbound.Source,
To: dest, To: request.Destination(),
Status: log.AccessAccepted, Status: log.AccessAccepted,
Reason: "", Reason: "",
Email: request.User.Email, Email: request.User.Email,
}) })
} }
newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(currentPacketCtx)) newError("tunnelling request to ", request.Destination()).WriteToLog(session.ExportIDToError(currentPacketCtx))
data.UDP = &net.UDPAddr{
IP: request.Address.IP(),
Port: int(request.Port),
}
if dest.Network == 0 {
dest = request.Destination() // JUST FOLLOW THE FIREST PACKET
}
currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
udpServer.Dispatch(currentPacketCtx, dest, data) udpServer.Dispatch(currentPacketCtx, dest, data)

View File

@ -196,6 +196,15 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
if request == nil { if request == nil {
return return
} }
if payload.UDP != nil {
request = &protocol.RequestHeader{
User: request.User,
Address: net.IPAddress(payload.UDP.IP),
Port: net.Port(payload.UDP.Port),
}
}
udpMessage, err := EncodeUDPPacket(request, payload.Bytes()) udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
payload.Release() payload.Release()
@ -211,6 +220,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx)) newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
} }
var dest net.Destination
reader := buf.NewPacketReader(conn) reader := buf.NewPacketReader(conn)
for { for {
mpayload, err := reader.ReadMultiBuffer() mpayload, err := reader.ReadMultiBuffer()
@ -242,8 +253,17 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
}) })
} }
payload.UDP = &net.UDPAddr{
IP: request.Address.IP(),
Port: int(request.Port),
}
if dest.Network == 0 {
dest = request.Destination() // JUST FOLLOW THE FIREST PACKET
}
currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
udpServer.Dispatch(currentPacketCtx, request.Destination(), payload) udpServer.Dispatch(currentPacketCtx, dest, payload)
} }
} }
} }

View File

@ -128,31 +128,43 @@ type PacketWriter struct {
// WriteMultiBuffer implements buf.Writer // WriteMultiBuffer implements buf.Writer
func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
b := make([]byte, maxLength) for {
for !mb.IsEmpty() { mb2, b := buf.SplitFirst(mb)
var length int mb = mb2
mb, length = buf.SplitBytes(mb, b) if b == nil {
if _, err := w.writePacket(b[:length], w.Target); err != nil { break
}
target := w.Target
if b.UDP != nil {
target.Address = net.IPAddress(b.UDP.IP)
target.Port = net.Port(b.UDP.Port)
}
if _, err := w.writePacket(b.Bytes(), target); err != nil {
buf.ReleaseMulti(mb) buf.ReleaseMulti(mb)
return err return err
} }
} }
return nil return nil
} }
// WriteMultiBufferWithMetadata writes udp packet with destination specified // WriteMultiBufferWithMetadata writes udp packet with destination specified
func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error { func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
b := make([]byte, maxLength) for {
for !mb.IsEmpty() { mb2, b := buf.SplitFirst(mb)
var length int mb = mb2
mb, length = buf.SplitBytes(mb, b) if b == nil {
if _, err := w.writePacket(b[:length], dest); err != nil { break
}
source := dest
if b.UDP != nil {
source.Address = net.IPAddress(b.UDP.IP)
source.Port = net.Port(b.UDP.Port)
}
if _, err := w.writePacket(b.Bytes(), source); err != nil {
buf.ReleaseMulti(mb) buf.ReleaseMulti(mb)
return err return err
} }
} }
return nil return nil
} }
@ -300,6 +312,10 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
} }
b := buf.New() b := buf.New()
b.UDP = &net.UDPAddr{
IP: addr.IP(),
Port: int(port.Value()),
}
mb = append(mb, b) mb = append(mb, b)
n, err := b.ReadFullFrom(r, int32(length)) n, err := b.ReadFullFrom(r, int32(length))
if err != nil { if err != nil {

View File

@ -256,6 +256,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
user := inbound.User user := inbound.User
var dest net.Destination
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -278,8 +280,12 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
}) })
newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx)) newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx))
if dest.Network == 0 {
dest = p.Target // JUST FOLLOW THE FIREST PACKET
}
for _, b := range p.Buffer { for _, b := range p.Buffer {
udpServer.Dispatch(ctx, p.Target, b) udpServer.Dispatch(ctx, dest, b)
} }
} }
} }

View File

@ -60,7 +60,7 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &packetConnWrapper{ return &PacketConnWrapper{
conn: packetConn, conn: packetConn,
dest: destAddr, dest: destAddr,
}, nil }, nil
@ -98,41 +98,49 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne
return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr()) return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr())
} }
type packetConnWrapper struct { type PacketConnWrapper struct {
conn net.PacketConn conn net.PacketConn
dest net.Addr dest net.Addr
} }
func (c *packetConnWrapper) Close() error { func (c *PacketConnWrapper) Close() error {
return c.conn.Close() return c.conn.Close()
} }
func (c *packetConnWrapper) LocalAddr() net.Addr { func (c *PacketConnWrapper) LocalAddr() net.Addr {
return c.conn.LocalAddr() return c.conn.LocalAddr()
} }
func (c *packetConnWrapper) RemoteAddr() net.Addr { func (c *PacketConnWrapper) RemoteAddr() net.Addr {
return c.dest return c.dest
} }
func (c *packetConnWrapper) Write(p []byte) (int, error) { func (c *PacketConnWrapper) Write(p []byte) (int, error) {
return c.conn.WriteTo(p, c.dest) return c.conn.WriteTo(p, c.dest)
} }
func (c *packetConnWrapper) Read(p []byte) (int, error) { func (c *PacketConnWrapper) Read(p []byte) (int, error) {
n, _, err := c.conn.ReadFrom(p) n, _, err := c.conn.ReadFrom(p)
return n, err return n, err
} }
func (c *packetConnWrapper) SetDeadline(t time.Time) error { func (c *PacketConnWrapper) WriteTo(p []byte, d net.Addr) (int, error) {
return c.conn.WriteTo(p, d)
}
func (c *PacketConnWrapper) ReadFrom(p []byte) (int, net.Addr, error) {
return c.conn.ReadFrom(p)
}
func (c *PacketConnWrapper) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t) return c.conn.SetDeadline(t)
} }
func (c *packetConnWrapper) SetReadDeadline(t time.Time) error { func (c *PacketConnWrapper) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t) return c.conn.SetReadDeadline(t)
} }
func (c *packetConnWrapper) SetWriteDeadline(t time.Time) error { func (c *PacketConnWrapper) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t) return c.conn.SetWriteDeadline(t)
} }