/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. */ package wireguard import ( "context" "fmt" "net" "net/netip" "os" "github.com/sagernet/wireguard-go/tun" "github.com/xtls/xray-core/features/dns" "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event incomingPacket chan *bufferv2.View mtu int dnsClient dns.Client hasV4, hasV6 bool } type Net netTun func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, HandleLocal: true, } dev := &netTun{ ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), incomingPacket: make(chan *bufferv2.View), dnsClient: dnsClient, mtu: mtu, } dev.ep.AddNotify(dev) tcpipErr := dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { var protoNumber tcpip.NetworkProtocolNumber if ip.Is4() { protoNumber = ipv4.ProtocolNumber } else if ip.Is6() { protoNumber = ipv6.ProtocolNumber } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) } if ip.Is4() { dev.hasV4 = true } else if ip.Is6() { dev.hasV6 = true } } if dev.hasV4 { dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) } if dev.hasV6 { dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) } dev.events <- tun.EventUp return dev, (*Net)(dev), nil } func (tun *netTun) Name() (string, error) { return "go", nil } func (tun *netTun) File() *os.File { return nil } func (tun *netTun) Events() chan tun.Event { return tun.events } func (tun *netTun) Read(buf []byte, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } return view.Read(buf[offset:]) } func (tun *netTun) Write(buf []byte, offset int) (int, error) { packet := buf[offset:] if len(packet) == 0 { return 0, nil } pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) case 6: tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) } return len(buf), nil } func (tun *netTun) WriteNotify() { pkt := tun.ep.Read() if pkt == nil { return } view := pkt.ToView() pkt.DecRef() tun.incomingPacket <- view } func (tun *netTun) Flush() error { return nil } func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) if tun.events != nil { close(tun.events) } tun.ep.Close() if tun.incomingPacket != nil { close(tun.incomingPacket) } return nil } func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { var protoNumber tcpip.NetworkProtocolNumber if endpoint.Addr().Is4() { protoNumber = ipv4.ProtocolNumber } else { protoNumber = ipv6.ProtocolNumber } return tcpip.FullAddress{ NIC: 1, Addr: tcpip.Address(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { fa, pn := convertToFullAddr(addr) return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { fa, pn := convertToFullAddr(addr) return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { return net.DialTCPAddrPort(netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { fa, pn := convertToFullAddr(addr) return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { return net.ListenTCPAddrPort(netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress addr, pn = convertToFullAddr(laddr) lfa = &addr } if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress addr, pn = convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { return net.DialUDPAddrPort(laddr, netip.AddrPort{}) } func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { var la, ra netip.AddrPort if laddr != nil { ip, _ := netip.AddrFromSlice(laddr.IP) la = netip.AddrPortFrom(ip, uint16(laddr.Port)) } if raddr != nil { ip, _ := netip.AddrFromSlice(raddr.IP) ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) } return net.DialUDPAddrPort(la, ra) } func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { return net.DialUDP(laddr, nil) } func (n *Net) HasV4() bool { return n.hasV4 } func (n *Net) HasV6() bool { return n.hasV6 } func IsDomainName(s string) bool { l := len(s) if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { return false } last := byte('.') nonNumeric := false partlen := 0 for i := 0; i < len(s); i++ { c := s[i] switch { default: return false case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': nonNumeric = true partlen++ case '0' <= c && c <= '9': partlen++ case c == '-': if last == '.' { return false } partlen++ nonNumeric = true case c == '.': if last == '.' || last == '-' { return false } if partlen > 63 || partlen == 0 { return false } partlen = 0 } last = c } if last == '-' || partlen > 63 { return false } return nonNumeric }