diff --git a/proxy/wireguard/gvisortun/tun.go b/proxy/wireguard/gvisortun/tun.go index 65677c48..f809c5f8 100644 --- a/proxy/wireguard/gvisortun/tun.go +++ b/proxy/wireguard/gvisortun/tun.go @@ -1,3 +1,6 @@ +//go:build go1.23 +// +build go1.23 + /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. diff --git a/proxy/wireguard/gvisortun/tun_go121.go b/proxy/wireguard/gvisortun/tun_go121.go new file mode 100644 index 00000000..1fc5f10b --- /dev/null +++ b/proxy/wireguard/gvisortun/tun_go121.go @@ -0,0 +1,233 @@ +//go:build !go1.23 +// +build !go1.23 + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. + */ + +package gvisortun + +import ( + "context" + "fmt" + "net/netip" + "os" + "syscall" + + "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + incomingPacket chan *buffer.View + mtu int + hasV4, hasV6 bool +} + +type Net netTun + +func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: !promiscuousMode, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + events: make(chan tun.Event, 1), + incomingPacket: make(chan *buffer.View), + mtu: mtu, + } + dev.ep.AddNotify(dev) + tcpipErr := dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { + dev.hasV4 = true + } else if ip.Is6() { + dev.hasV6 = true + } + } + if dev.hasV4 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + } + if dev.hasV6 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + } + if promiscuousMode { + // enable promiscuous mode to handle all packets processed by netstack + dev.stack.SetPromiscuousMode(1, true) + dev.stack.SetSpoofing(1, true) + } + + opt := tcpip.CongestionControlOption("cubic") + if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err) + } + + dev.events <- tun.EventUp + return dev, (*Net)(dev), dev.stack, nil +} + +// BatchSize implements tun.Device +func (tun *netTun) BatchSize() int { + return 1 +} + +// Name implements tun.Device +func (tun *netTun) Name() (string, error) { + return "go", nil +} + +// File implements tun.Device +func (tun *netTun) File() *os.File { + return nil +} + +// Events implements tun.Device +func (tun *netTun) Events() <-chan tun.Event { + return tun.events +} + +// Read implements tun.Device + +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// Write implements tun.Device +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } + + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + } + return len(buf), nil +} + +// WriteNotify implements channel.Notification +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt.IsNil() { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + +// Flush implements tun.Device +func (tun *netTun) Flush() error { + return nil +} + +// Close implements tun.Device +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + + if tun.events != nil { + close(tun.events) + } + + tun.ep.Close() + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +// MTU implements tun.Device +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) +} + +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { + var lfa, rfa *tcpip.FullAddress + var pn tcpip.NetworkProtocolNumber + if laddr.IsValid() || laddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(laddr) + lfa = &addr + } + if raddr.IsValid() || raddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(raddr) + rfa = &addr + } + return gonet.DialUDP(net.stack, lfa, rfa, pn) +} diff --git a/proxy/wireguard/tun.go b/proxy/wireguard/tun.go index 74a3b71d..af91d012 100644 --- a/proxy/wireguard/tun.go +++ b/proxy/wireguard/tun.go @@ -1,3 +1,6 @@ +//go:build go1.23 +// +build go1.23 + package wireguard import ( diff --git a/proxy/wireguard/tun_go121.go b/proxy/wireguard/tun_go121.go new file mode 100644 index 00000000..9c670433 --- /dev/null +++ b/proxy/wireguard/tun_go121.go @@ -0,0 +1,208 @@ +//go:build !go1.23 +// +build !go1.23 + +package wireguard + +import ( + "context" + "fmt" + "net" + "net/netip" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/log" + xnet "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/proxy/wireguard/gvisortun" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) + +type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn) + +type Tunnel interface { + BuildDevice(ipc string, bind conn.Bind) error + DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) + DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) + Close() error +} + +type tunnel struct { + tun tun.Device + device *device.Device + rw sync.Mutex +} + +func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) { + t.rw.Lock() + defer t.rw.Unlock() + + if t.device != nil { + return errors.New("device is already initialized") + } + + logger := &device.Logger{ + Verbosef: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Debug, + Content: fmt.Sprintf(format, args...), + }) + }, + Errorf: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Error, + Content: fmt.Sprintf(format, args...), + }) + }, + } + + t.device = device.NewDevice(t.tun, bind, logger) + if err = t.device.IpcSet(ipc); err != nil { + return err + } + if err = t.device.Up(); err != nil { + return err + } + return nil +} + +func (t *tunnel) Close() (err error) { + t.rw.Lock() + defer t.rw.Unlock() + + if t.device == nil { + return nil + } + + t.device.Close() + t.device = nil + err = t.tun.Close() + t.tun = nil + return nil +} + +func CalculateInterfaceName(name string) (tunName string) { + if runtime.GOOS == "darwin" { + tunName = "utun" + } else if name != "" { + tunName = name + } else { + tunName = "tun" + } + interfaces, err := net.Interfaces() + if err != nil { + return + } + var tunIndex int + for _, netInterface := range interfaces { + if strings.HasPrefix(netInterface.Name, tunName) { + index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16) + if parseErr == nil { + tunIndex = int(index) + 1 + } + } + } + tunName = fmt.Sprintf("%s%d", tunName, tunIndex) + return +} + +var _ Tunnel = (*gvisorNet)(nil) + +type gvisorNet struct { + tunnel + net *gvisortun.Net +} + +func (g *gvisorNet) Close() error { + return g.tunnel.Close() +} + +func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( + net.Conn, error, +) { + return g.net.DialContextTCPAddrPort(ctx, addr) +} + +func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { + return g.net.DialUDPAddrPort(laddr, raddr) +} + +func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) { + out := &gvisorNet{} + tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil) + if err != nil { + return nil, err + } + + if handler != nil { + // handler is only used for promiscuous mode + // capture all packets and send to handler + + tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) { + go func(r *tcp.ForwarderRequest) { + var ( + wq waiter.Queue + id = r.ID() + ) + + // Perform a TCP three-way handshake. + ep, err := r.CreateEndpoint(&wq) + if err != nil { + errors.LogError(context.Background(), err.String()) + r.Complete(true) + return + } + r.Complete(false) + defer ep.Close() + + // enable tcp keep-alive to prevent hanging connections + ep.SocketOptions().SetKeepAlive(true) + + // local address is actually destination + handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep)) + }(r) + }) + stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + + udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) { + go func(r *udp.ForwarderRequest) { + var ( + wq waiter.Queue + id = r.ID() + ) + + ep, err := r.CreateEndpoint(&wq) + if err != nil { + errors.LogError(context.Background(), err.String()) + return + } + defer ep.Close() + + // prevents hanging connections and ensure timely release + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: true, + Timeout: 15 * time.Second, + }) + + handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep)) + }(r) + }) + stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + } + + out.tun, out.net = tun, n + return out, nil +}