From 0ac7da2fc8e26017c694c16bd2c2c2f1734c301c Mon Sep 17 00:00:00 2001 From: hax0r31337 <65506006+hax0r31337@users.noreply.github.com> Date: Sat, 18 Nov 2023 11:27:17 +0800 Subject: [PATCH] WireGuard Inbound (User-space WireGuard server) (#2477) * feat: wireguard inbound * feat(command): generate wireguard compatible keypair * feat(wireguard): connection idle timeout * fix(wireguard): close endpoint after connection closed * fix(wireguard): resolve conflicts * feat(wireguard): set cubic as default cc algorithm in gVisor TUN * chore(wireguard): resolve conflict * chore(wireguard): remove redurant code * chore(wireguard): remove redurant code * feat: rework server for gvisor tun * feat: keep user-space tun as an option * fix: exclude android from native tun build * feat: auto kernel tun * fix: build * fix: regulate function name & fix test --- go.mod | 4 +- go.sum | 3 +- infra/conf/wireguard.go | 71 ++++--- infra/conf/wireguard_test.go | 15 +- infra/conf/xray.go | 3 +- main/commands/all/x25519.go | 15 +- proxy/wireguard/bind.go | 170 ++++++--------- proxy/wireguard/client.go | 255 +++++++++++++++++++++++ proxy/wireguard/config.go | 7 + proxy/wireguard/config.pb.go | 58 ++++-- proxy/wireguard/config.proto | 40 ++-- proxy/wireguard/gvisortun/tun.go | 230 +++++++++++++++++++++ proxy/wireguard/server.go | 181 ++++++++++++++++ proxy/wireguard/tun.go | 100 +++++++++ proxy/wireguard/tun_default.go | 38 +--- proxy/wireguard/tun_linux.go | 16 +- proxy/wireguard/wireguard.go | 343 ++++++------------------------- 17 files changed, 1049 insertions(+), 500 deletions(-) create mode 100644 proxy/wireguard/client.go create mode 100644 proxy/wireguard/gvisortun/tun.go create mode 100644 proxy/wireguard/server.go diff --git a/go.mod b/go.mod index f0e15d8f..d7f43f89 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 + gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b h12.io/socks v1.0.3 lukechampine.com/blake3 v1.2.1 ) @@ -48,7 +49,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect - github.com/vishvananda/netns v0.0.4 // indirect + github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect go.uber.org/mock v0.3.0 // indirect golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect golang.org/x/mod v0.14.0 // indirect @@ -59,5 +60,4 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b // indirect ) diff --git a/go.sum b/go.sum index 36a5e317..22473052 100644 --- a/go.sum +++ b/go.sum @@ -168,9 +168,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U= github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI= github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/infra/conf/wireguard.go b/infra/conf/wireguard.go index 7b2b6bbf..a4f0eda6 100644 --- a/infra/conf/wireguard.go +++ b/infra/conf/wireguard.go @@ -13,7 +13,7 @@ type WireGuardPeerConfig struct { PublicKey string `json:"publicKey"` PreSharedKey string `json:"preSharedKey"` Endpoint string `json:"endpoint"` - KeepAlive int `json:"keepAlive"` + KeepAlive uint32 `json:"keepAlive"` AllowedIPs []string `json:"allowedIPs,omitempty"` } @@ -21,9 +21,11 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) { var err error config := new(wireguard.PeerConfig) - config.PublicKey, err = parseWireGuardKey(c.PublicKey) - if err != nil { - return nil, err + if c.PublicKey != "" { + config.PublicKey, err = parseWireGuardKey(c.PublicKey) + if err != nil { + return nil, err + } } if c.PreSharedKey != "" { @@ -31,13 +33,11 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) { if err != nil { return nil, err } - } else { - config.PreSharedKey = "0000000000000000000000000000000000000000000000000000000000000000" } config.Endpoint = c.Endpoint // default 0 - config.KeepAlive = int32(c.KeepAlive) + config.KeepAlive = c.KeepAlive if c.AllowedIPs == nil { config.AllowedIps = []string{"0.0.0.0/0", "::0/0"} } else { @@ -48,11 +48,14 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) { } type WireGuardConfig struct { + IsClient bool `json:""` + + KernelMode *bool `json:"kernelMode"` SecretKey string `json:"secretKey"` Address []string `json:"address"` Peers []*WireGuardPeerConfig `json:"peers"` - MTU int `json:"mtu"` - NumWorkers int `json:"workers"` + MTU int32 `json:"mtu"` + NumWorkers int32 `json:"workers"` Reserved []byte `json:"reserved"` DomainStrategy string `json:"domainStrategy"` } @@ -87,11 +90,11 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { if c.MTU == 0 { config.Mtu = 1420 } else { - config.Mtu = int32(c.MTU) + config.Mtu = c.MTU } - // these a fallback code exists in github.com/nanoda0523/wireguard-go code, + // these a fallback code exists in wireguard-go code, // we don't need to process fallback manually - config.NumWorkers = int32(c.NumWorkers) + config.NumWorkers = c.NumWorkers if len(c.Reserved) != 0 && len(c.Reserved) != 3 { return nil, newError(`"reserved" should be empty or 3 bytes`) @@ -113,22 +116,42 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { return nil, newError("unsupported domain strategy: ", c.DomainStrategy) } + config.IsClient = c.IsClient + if c.KernelMode != nil { + config.KernelMode = *c.KernelMode + if config.KernelMode && !wireguard.KernelTunSupported() { + newError("kernel mode is not supported on your OS or permission is insufficient").AtWarning().WriteToLog() + } + } else { + config.KernelMode = wireguard.KernelTunSupported() + if config.KernelMode { + newError("kernel mode is enabled as it's supported and permission is sufficient").AtDebug().WriteToLog() + } + } + return config, nil } func parseWireGuardKey(str string) (string, error) { - if len(str) != 64 { - // may in base64 form - dat, err := base64.StdEncoding.DecodeString(str) - if err != nil { - return "", err + var err error + + if len(str)%2 == 0 { + _, err = hex.DecodeString(str) + if err == nil { + return str, nil } - if len(dat) != 32 { - return "", newError("key should be 32 bytes: " + str) - } - return hex.EncodeToString(dat), err - } else { - // already hex form - return str, nil } + + var dat []byte + str = strings.TrimSuffix(str, "=") + if strings.ContainsRune(str, '+') || strings.ContainsRune(str, '/') { + dat, err = base64.RawStdEncoding.DecodeString(str) + } else { + dat, err = base64.RawURLEncoding.DecodeString(str) + } + if err == nil { + return hex.EncodeToString(dat), nil + } + + return "", newError("failed to deserialize key").Base(err) } diff --git a/infra/conf/wireguard_test.go b/infra/conf/wireguard_test.go index 7a4adf36..57951105 100644 --- a/infra/conf/wireguard_test.go +++ b/infra/conf/wireguard_test.go @@ -7,7 +7,7 @@ import ( "github.com/xtls/xray-core/proxy/wireguard" ) -func TestWireGuardOutbound(t *testing.T) { +func TestWireGuardConfig(t *testing.T) { creator := func() Buildable { return new(WireGuardConfig) } @@ -25,7 +25,8 @@ func TestWireGuardOutbound(t *testing.T) { ], "mtu": 1300, "workers": 2, - "domainStrategy": "ForceIPv6v4" + "domainStrategy": "ForceIPv6v4", + "kernelMode": false }`, Parser: loadJSON(creator), Output: &wireguard.DeviceConfig{ @@ -35,16 +36,16 @@ func TestWireGuardOutbound(t *testing.T) { Peers: []*wireguard.PeerConfig{ { // also can read from hex form directly - PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a", - PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000", - Endpoint: "127.0.0.1:1234", - KeepAlive: 0, - AllowedIps: []string{"0.0.0.0/0", "::0/0"}, + PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a", + Endpoint: "127.0.0.1:1234", + KeepAlive: 0, + AllowedIps: []string{"0.0.0.0/0", "::0/0"}, }, }, Mtu: 1300, NumWorkers: 2, DomainStrategy: wireguard.DeviceConfig_FORCE_IP64, + KernelMode: false, }, }, }) diff --git a/infra/conf/xray.go b/infra/conf/xray.go index dfc34a8c..0935b1b0 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -24,6 +24,7 @@ var ( "vless": func() interface{} { return new(VLessInboundConfig) }, "vmess": func() interface{} { return new(VMessInboundConfig) }, "trojan": func() interface{} { return new(TrojanServerConfig) }, + "wireguard": func() interface{} { return &WireGuardConfig{IsClient: false} }, }, "protocol", "settings") outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{ @@ -37,7 +38,7 @@ var ( "vmess": func() interface{} { return new(VMessOutboundConfig) }, "trojan": func() interface{} { return new(TrojanClientConfig) }, "dns": func() interface{} { return new(DNSOutboundConfig) }, - "wireguard": func() interface{} { return new(WireGuardConfig) }, + "wireguard": func() interface{} { return &WireGuardConfig{IsClient: true} }, }, "protocol", "settings") ctllog = log.New(os.Stderr, "xctl> ", 0) diff --git a/main/commands/all/x25519.go b/main/commands/all/x25519.go index e7909d9b..814cca72 100644 --- a/main/commands/all/x25519.go +++ b/main/commands/all/x25519.go @@ -10,7 +10,7 @@ import ( ) var cmdX25519 = &base.Command{ - UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"]`, + UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"] [--std-encoding]`, Short: `Generate key pair for x25519 key exchange`, Long: ` Generate key pair for x25519 key exchange. @@ -18,6 +18,7 @@ Generate key pair for x25519 key exchange. Random: {{.Exec}} x25519 From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)" +For Std Encoding: {{.Exec}} x25519 --std-encoding `, } @@ -26,12 +27,14 @@ func init() { } var input_base64 = cmdX25519.Flag.String("i", "", "") +var input_stdEncoding = cmdX25519.Flag.Bool("std-encoding", false, "") func executeX25519(cmd *base.Command, args []string) { var output string var err error var privateKey []byte var publicKey []byte + var encoding *base64.Encoding if len(*input_base64) > 0 { privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64) if err != nil { @@ -63,9 +66,15 @@ func executeX25519(cmd *base.Command, args []string) { goto out } + if *input_stdEncoding { + encoding = base64.StdEncoding + } else { + encoding = base64.RawURLEncoding + } + output = fmt.Sprintf("Private key: %v\nPublic key: %v", - base64.RawURLEncoding.EncodeToString(privateKey), - base64.RawURLEncoding.EncodeToString(publicKey)) + encoding.EncodeToString(privateKey), + encoding.EncodeToString(publicKey)) out: fmt.Println(output) } diff --git a/proxy/wireguard/bind.go b/proxy/wireguard/bind.go index c224dc56..1fbcbc98 100644 --- a/proxy/wireguard/bind.go +++ b/proxy/wireguard/bind.go @@ -27,48 +27,45 @@ type netReadInfo struct { err error } -type netBindClient struct { - workers int - dialer internet.Dialer +// reduce duplicated code +type netBind struct { dns dns.Client dnsOption dns.IPOption - reserved []byte + workers int readQueue chan *netReadInfo } -func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { - ipStr, port, _, err := splitAddrPort(s) +// SetMark implements conn.Bind +func (bind *netBind) SetMark(mark uint32) error { + return nil +} + +// ParseEndpoint implements conn.Bind +func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) { + ipStr, port, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + portNum, err := strconv.Atoi(port) if err != nil { return nil, err } - var addr net.IP - if IsDomainName(ipStr) { - ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption) + addr := xnet.ParseAddress(ipStr) + if addr.Family() == xnet.AddressFamilyDomain { + ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption) if err != nil { return nil, err } else if len(ips) == 0 { return nil, dns.ErrEmptyResponse } - addr = ips[0] - } else { - addr = net.ParseIP(ipStr) - } - if addr == nil { - return nil, errors.New("failed to parse ip: " + ipStr) - } - - var ip xnet.Address - if p4 := addr.To4(); len(p4) == net.IPv4len { - ip = xnet.IPAddress(p4[:]) - } else { - ip = xnet.IPAddress(addr[:]) + addr = xnet.IPAddress(ips[0]) } dst := xnet.Destination{ - Address: ip, - Port: xnet.Port(port), + Address: addr, + Port: xnet.Port(portNum), Network: xnet.Network_UDP, } @@ -77,7 +74,13 @@ func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { }, nil } -func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { +// BatchSize implements conn.Bind +func (bind *netBind) BatchSize() int { + return 1 +} + +// Open implements conn.Bind +func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { bind.readQueue = make(chan *netReadInfo) fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { @@ -109,13 +112,21 @@ func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error return arr, uint16(uport), nil } -func (bind *netBindClient) Close() error { +// Close implements conn.Bind +func (bind *netBind) Close() error { if bind.readQueue != nil { close(bind.readQueue) } return nil } +type netBindClient struct { + netBind + + dialer internet.Dialer + reserved []byte +} + func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { c, err := bind.dialer.Dial(context.Background(), endpoint.dst) if err != nil { @@ -177,12 +188,29 @@ func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error { return nil } -func (bind *netBindClient) SetMark(mark uint32) error { - return nil +type netBindServer struct { + netBind } -func (bind *netBindClient) BatchSize() int { - return 1 +func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error { + var err error + + nend, ok := endpoint.(*netEndpoint) + if !ok { + return conn.ErrWrongEndpointType + } + + if nend.conn == nil { + return newError("connection not open yet") + } + + for _, buff := range buff { + if _, err = nend.conn.Write(buff); err != nil { + return err + } + } + + return err } type netEndpoint struct { @@ -193,7 +221,7 @@ type netEndpoint struct { func (netEndpoint) ClearSrc() {} func (e netEndpoint) DstIP() netip.Addr { - return toNetIpAddr(e.dst.Address) + return netip.Addr{} } func (e netEndpoint) SrcIP() netip.Addr { @@ -232,83 +260,3 @@ func toNetIpAddr(addr xnet.Address) netip.Addr { return netip.AddrFrom16(arr) } } - -func stringsLastIndexByte(s string, b byte) int { - for i := len(s) - 1; i >= 0; i-- { - if s[i] == b { - return i - } - } - return -1 -} - -func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) { - i := stringsLastIndexByte(s, ':') - if i == -1 { - return "", 0, false, errors.New("not an ip:port") - } - - ip = s[:i] - portStr := s[i+1:] - if len(ip) == 0 { - return "", 0, false, errors.New("no IP") - } - if len(portStr) == 0 { - return "", 0, false, errors.New("no port") - } - port64, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s)) - } - port = uint16(port64) - if ip[0] == '[' { - if len(ip) < 2 || ip[len(ip)-1] != ']' { - return "", 0, false, errors.New("missing ]") - } - ip = ip[1 : len(ip)-1] - v6 = true - } - - return ip, port, v6, nil -} - -func IsDomainName(s string) bool { - l := len(s) - if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { - return false - } - last := byte('.') - nonNumeric := false - partlen := 0 - for i := 0; i < len(s); i++ { - c := s[i] - switch { - default: - return false - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': - nonNumeric = true - partlen++ - case '0' <= c && c <= '9': - partlen++ - case c == '-': - if last == '.' { - return false - } - partlen++ - nonNumeric = true - case c == '.': - if last == '.' || last == '-' { - return false - } - if partlen > 63 || partlen == 0 { - return false - } - partlen = 0 - } - last = c - } - if last == '-' || partlen > 63 { - return false - } - return nonNumeric -} diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go new file mode 100644 index 00000000..def07878 --- /dev/null +++ b/proxy/wireguard/client.go @@ -0,0 +1,255 @@ +/* + +Some of codes are copied from https://github.com/octeep/wireproxy, license below. + +Copyright (c) 2022 Wind T.F. Wong + +Permission to use, copy, modify, and distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +*/ + +package wireguard + +import ( + "context" + "net/netip" + "sync" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/dice" + "github.com/xtls/xray-core/common/log" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal" + "github.com/xtls/xray-core/common/task" + "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/features/dns" + "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/internet" +) + +// Handler is an outbound connection that silently swallow the entire payload. +type Handler struct { + conf *DeviceConfig + net Tunnel + bind *netBindClient + policyManager policy.Manager + dns dns.Client + // cached configuration + ipc string + endpoints []netip.Addr + hasIPv4, hasIPv6 bool + wgLock sync.Mutex +} + +// New creates a new wireguard handler. +func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { + v := core.MustFromContext(ctx) + + endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) + if err != nil { + return nil, err + } + + d := v.GetFeature(dns.ClientType()).(dns.Client) + return &Handler{ + conf: conf, + policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + dns: d, + ipc: createIPCRequest(conf), + endpoints: endpoints, + hasIPv4: hasIPv4, + hasIPv6: hasIPv6, + }, nil +} + +func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { + h.wgLock.Lock() + defer h.wgLock.Unlock() + + if h.bind != nil && h.bind.dialer == dialer && h.net != nil { + return nil + } + + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Info, + Content: "switching dialer", + }) + + if h.net != nil { + _ = h.net.Close() + h.net = nil + } + if h.bind != nil { + _ = h.bind.Close() + h.bind = nil + } + + // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer + bind := &netBindClient{ + netBind: netBind{ + dns: h.dns, + dnsOption: dns.IPOption{ + IPv4Enable: h.hasIPv4, + IPv6Enable: h.hasIPv6, + }, + workers: int(h.conf.NumWorkers), + }, + dialer: dialer, + reserved: h.conf.Reserved, + } + defer func() { + if err != nil { + _ = bind.Close() + } + }() + + h.net, err = h.makeVirtualTun(bind) + if err != nil { + return newError("failed to create virtual tun interface").Base(err) + } + h.bind = bind + return nil +} + +// Process implements OutboundHandler.Dispatch(). +func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { + return newError("target not specified") + } + outbound.Name = "wireguard" + inbound := session.InboundFromContext(ctx) + if inbound != nil { + inbound.SetCanSpliceCopy(3) + } + + if err := h.processWireGuard(dialer); err != nil { + return err + } + + // Destination of the inner request. + destination := outbound.Target + command := protocol.RequestCommandTCP + if destination.Network == net.Network_UDP { + command = protocol.RequestCommandUDP + } + + // resolve dns + addr := destination.Address + if addr.Family().IsDomain() { + ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ + IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), + IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), + }) + { // Resolve fallback + if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { + ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ + IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), + IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), + }) + } + } + if err != nil { + return newError("failed to lookup DNS").Base(err) + } else if len(ips) == 0 { + return dns.ErrEmptyResponse + } + addr = net.IPAddress(ips[dice.Roll(len(ips))]) + } + + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + + p := h.policyManager.ForLevel(0) + + ctx, cancel := context.WithCancel(ctx) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, p.Timeouts.ConnectionIdle) + addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) + + var requestFunc func() error + var responseFunc func() error + + if command == protocol.RequestCommandTCP { + conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) + if err != nil { + return newError("failed to create TCP connection").Base(err) + } + defer conn.Close() + + requestFunc = func() error { + defer timer.SetTimeout(p.Timeouts.DownlinkOnly) + return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) + } + responseFunc = func() error { + defer timer.SetTimeout(p.Timeouts.UplinkOnly) + return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + } + } else if command == protocol.RequestCommandUDP { + conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) + if err != nil { + return newError("failed to create UDP connection").Base(err) + } + defer conn.Close() + + requestFunc = func() error { + defer timer.SetTimeout(p.Timeouts.DownlinkOnly) + return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) + } + responseFunc = func() error { + defer timer.SetTimeout(p.Timeouts.UplinkOnly) + return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + } + } + + if newCtx != nil { + ctx = newCtx + } + + responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) + if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { + common.Interrupt(link.Reader) + common.Interrupt(link.Writer) + return newError("connection ends").Base(err) + } + + return nil +} + +// creates a tun interface on netstack given a configuration +func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { + t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil) + if err != nil { + return nil, err + } + + bind.dnsOption.IPv4Enable = h.hasIPv4 + bind.dnsOption.IPv6Enable = h.hasIPv6 + + if err = t.BuildDevice(h.ipc, bind); err != nil { + _ = t.Close() + return nil, err + } + return t, nil +} diff --git a/proxy/wireguard/config.go b/proxy/wireguard/config.go index 75622753..2a316cdd 100644 --- a/proxy/wireguard/config.go +++ b/proxy/wireguard/config.go @@ -23,3 +23,10 @@ func (c *DeviceConfig) fallbackIP4() bool { func (c *DeviceConfig) fallbackIP6() bool { return c.DomainStrategy == DeviceConfig_FORCE_IP46 } + +func (c *DeviceConfig) createTun() tunCreator { + if c.KernelMode { + return createKernelTun + } + return createGVisorTun +} diff --git a/proxy/wireguard/config.pb.go b/proxy/wireguard/config.pb.go index dfe7dab5..ed8b135e 100644 --- a/proxy/wireguard/config.pb.go +++ b/proxy/wireguard/config.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.31.0 -// protoc v4.23.1 +// protoc-gen-go v1.28.1 +// protoc v4.25.0 // source: proxy/wireguard/config.proto package wireguard @@ -83,7 +83,7 @@ type PeerConfig struct { PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"` Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"` - KeepAlive int32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` + KeepAlive uint32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"` } @@ -140,7 +140,7 @@ func (x *PeerConfig) GetEndpoint() string { return "" } -func (x *PeerConfig) GetKeepAlive() int32 { +func (x *PeerConfig) GetKeepAlive() uint32 { if x != nil { return x.KeepAlive } @@ -166,6 +166,8 @@ type DeviceConfig struct { NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"` Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"` DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"` + IsClient bool `protobuf:"varint,8,opt,name=is_client,json=isClient,proto3" json:"is_client,omitempty"` + KernelMode bool `protobuf:"varint,9,opt,name=kernel_mode,json=kernelMode,proto3" json:"kernel_mode,omitempty"` } func (x *DeviceConfig) Reset() { @@ -249,6 +251,20 @@ func (x *DeviceConfig) GetDomainStrategy() DeviceConfig_DomainStrategy { return DeviceConfig_FORCE_IP } +func (x *DeviceConfig) GetIsClient() bool { + if x != nil { + return x.IsClient + } + return false +} + +func (x *DeviceConfig) GetKernelMode() bool { + if x != nil { + return x.KernelMode + } + return false +} + var File_proxy_wireguard_config_proto protoreflect.FileDescriptor var file_proxy_wireguard_config_proto_rawDesc = []byte{ @@ -263,10 +279,10 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{ 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69, - 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c, + 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c, 0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, - 0x64, 0x49, 0x70, 0x73, 0x22, 0x8a, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, + 0x64, 0x49, 0x70, 0x73, 0x22, 0xc8, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, @@ -285,19 +301,23 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{ 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, - 0x65, 0x67, 0x79, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, - 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, - 0x50, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, - 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, - 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, - 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, - 0x04, 0x42, 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, - 0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, - 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, - 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, - 0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, - 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, - 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x67, 0x79, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x73, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x73, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x12, 0x1f, 0x0a, 0x0b, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x4d, 0x6f, 0x64, + 0x65, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, + 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x10, + 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x01, + 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x02, 0x12, + 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x03, 0x12, + 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x04, 0x42, + 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78, + 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, + 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77, + 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e, + 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/proxy/wireguard/config.proto b/proxy/wireguard/config.proto index 0a12c009..e7fd66f4 100644 --- a/proxy/wireguard/config.proto +++ b/proxy/wireguard/config.proto @@ -7,26 +7,28 @@ option java_package = "com.xray.proxy.wireguard"; option java_multiple_files = true; message PeerConfig { - string public_key = 1; - string pre_shared_key = 2; - string endpoint = 3; - int32 keep_alive = 4; - repeated string allowed_ips = 5; + string public_key = 1; + string pre_shared_key = 2; + string endpoint = 3; + uint32 keep_alive = 4; + repeated string allowed_ips = 5; } message DeviceConfig { - enum DomainStrategy { - FORCE_IP = 0; - FORCE_IP4 = 1; - FORCE_IP6 = 2; - FORCE_IP46 = 3; - FORCE_IP64 = 4; - } - string secret_key = 1; - repeated string endpoint = 2; - repeated PeerConfig peers = 3; - int32 mtu = 4; - int32 num_workers = 5; - bytes reserved = 6; - DomainStrategy domain_strategy = 7; + enum DomainStrategy { + FORCE_IP = 0; + FORCE_IP4 = 1; + FORCE_IP6 = 2; + FORCE_IP46 = 3; + FORCE_IP64 = 4; + } + string secret_key = 1; + repeated string endpoint = 2; + repeated PeerConfig peers = 3; + int32 mtu = 4; + int32 num_workers = 5; + bytes reserved = 6; + DomainStrategy domain_strategy = 7; + bool is_client = 8; + bool kernel_mode = 9; } \ No newline at end of file diff --git a/proxy/wireguard/gvisortun/tun.go b/proxy/wireguard/gvisortun/tun.go new file mode 100644 index 00000000..9e9a0b2b --- /dev/null +++ b/proxy/wireguard/gvisortun/tun.go @@ -0,0 +1,230 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. + */ + +package gvisortun + +import ( + "context" + "fmt" + "net/netip" + "os" + "syscall" + + "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + incomingPacket chan *buffer.View + mtu int + hasV4, hasV6 bool +} + +type Net netTun + +func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: !promiscuousMode, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + events: make(chan tun.Event, 1), + incomingPacket: make(chan *buffer.View), + mtu: mtu, + } + dev.ep.AddNotify(dev) + tcpipErr := dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { + dev.hasV4 = true + } else if ip.Is6() { + dev.hasV6 = true + } + } + if dev.hasV4 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + } + if dev.hasV6 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + } + if promiscuousMode { + // enable promiscuous mode to handle all packets processed by netstack + dev.stack.SetPromiscuousMode(1, true) + dev.stack.SetSpoofing(1, true) + } + + opt := tcpip.CongestionControlOption("cubic") + if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err) + } + + dev.events <- tun.EventUp + return dev, (*Net)(dev), dev.stack, nil +} + +// BatchSize implements tun.Device +func (tun *netTun) BatchSize() int { + return 1 +} + +// Name implements tun.Device +func (tun *netTun) Name() (string, error) { + return "go", nil +} + +// File implements tun.Device +func (tun *netTun) File() *os.File { + return nil +} + +// Events implements tun.Device +func (tun *netTun) Events() <-chan tun.Event { + return tun.events +} + +// Read implements tun.Device + +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +// Write implements tun.Device +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } + + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + } + return len(buf), nil +} + +// WriteNotify implements channel.Notification +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt.IsNil() { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + +// Flush implements tun.Device +func (tun *netTun) Flush() error { + return nil +} + +// Close implements tun.Device +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + + if tun.events != nil { + close(tun.events) + } + + tun.ep.Close() + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +// MTU implements tun.Device +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) +} + +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { + var lfa, rfa *tcpip.FullAddress + var pn tcpip.NetworkProtocolNumber + if laddr.IsValid() || laddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(laddr) + lfa = &addr + } + if raddr.IsValid() || raddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(raddr) + rfa = &addr + } + return gonet.DialUDP(net.stack, lfa, rfa, pn) +} diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go new file mode 100644 index 00000000..6cd2d7ad --- /dev/null +++ b/proxy/wireguard/server.go @@ -0,0 +1,181 @@ +package wireguard + +import ( + "context" + "errors" + "io" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/log" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal" + "github.com/xtls/xray-core/common/task" + "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/features/dns" + "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/transport/internet/stat" +) + +var nullDestination = net.TCPDestination(net.AnyIP, 0) + +type Server struct { + bindServer *netBindServer + + info routingInfo + policyManager policy.Manager +} + +type routingInfo struct { + ctx context.Context + dispatcher routing.Dispatcher + inboundTag *session.Inbound + outboundTag *session.Outbound + contentTag *session.Content +} + +func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { + v := core.MustFromContext(ctx) + + endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) + if err != nil { + return nil, err + } + + server := &Server{ + bindServer: &netBindServer{ + netBind: netBind{ + dns: v.GetFeature(dns.ClientType()).(dns.Client), + dnsOption: dns.IPOption{ + IPv4Enable: hasIPv4, + IPv6Enable: hasIPv6, + }, + }, + }, + policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + } + + tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) + if err != nil { + return nil, err + } + + if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil { + _ = tun.Close() + return nil, err + } + + return server, nil +} + +// Network implements proxy.Inbound. +func (*Server) Network() []net.Network { + return []net.Network{net.Network_UDP} +} + +// Process implements proxy.Inbound. +func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { + s.info = routingInfo{ + ctx: core.ToBackgroundDetachedContext(ctx), + dispatcher: dispatcher, + inboundTag: session.InboundFromContext(ctx), + outboundTag: session.OutboundFromContext(ctx), + contentTag: session.ContentFromContext(ctx), + } + + ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) + if err != nil { + return err + } + + nep := ep.(*netEndpoint) + nep.conn = conn + + reader := buf.NewPacketReader(conn) + for { + mpayload, err := reader.ReadMultiBuffer() + if err != nil { + return err + } + + for _, payload := range mpayload { + v, ok := <-s.bindServer.readQueue + if !ok { + return nil + } + i, err := payload.Read(v.buff) + + v.bytes = i + v.endpoint = nep + v.err = err + v.waiter.Done() + if err != nil && errors.Is(err, io.EOF) { + nep.conn = nil + return nil + } + } + } +} + +func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { + if s.info.dispatcher == nil { + newError("unexpected: dispatcher == nil").AtError().WriteToLog() + return + } + defer conn.Close() + + ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) + plcy := s.policyManager.ForLevel(0) + timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) + + ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ + From: nullDestination, + To: dest, + Status: log.AccessAccepted, + Reason: "", + }) + + if s.info.inboundTag != nil { + ctx = session.ContextWithInbound(ctx, s.info.inboundTag) + } + if s.info.outboundTag != nil { + ctx = session.ContextWithOutbound(ctx, s.info.outboundTag) + } + if s.info.contentTag != nil { + ctx = session.ContextWithContent(ctx, s.info.contentTag) + } + + link, err := s.info.dispatcher.Dispatch(ctx, dest) + if err != nil { + newError("dispatch connection").Base(err).AtError().WriteToLog() + } + defer cancel() + + requestDone := func() error { + defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) + if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil { + return newError("failed to transport all TCP request").Base(err) + } + + return nil + } + + responseDone := func() error { + defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) + if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil { + return newError("failed to transport all TCP response").Base(err) + } + + return nil + } + + requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { + common.Interrupt(link.Reader) + common.Interrupt(link.Writer) + newError("connection ends").Base(err).AtDebug().WriteToLog() + return + } +} diff --git a/proxy/wireguard/tun.go b/proxy/wireguard/tun.go index c320d0d0..c2d30323 100644 --- a/proxy/wireguard/tun.go +++ b/proxy/wireguard/tun.go @@ -10,14 +10,26 @@ import ( "strconv" "strings" "sync" + "time" "github.com/xtls/xray-core/common/log" + xnet "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/proxy/wireguard/gvisortun" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" ) +type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) + +type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn) + type Tunnel interface { BuildDevice(ipc string, bind conn.Bind) error DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) @@ -103,3 +115,91 @@ func CalculateInterfaceName(name string) (tunName string) { tunName = fmt.Sprintf("%s%d", tunName, tunIndex) return } + +var _ Tunnel = (*gvisorNet)(nil) + +type gvisorNet struct { + tunnel + net *gvisortun.Net +} + +func (g *gvisorNet) Close() error { + return g.tunnel.Close() +} + +func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( + net.Conn, error, +) { + return g.net.DialContextTCPAddrPort(ctx, addr) +} + +func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { + return g.net.DialUDPAddrPort(laddr, raddr) +} + +func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) { + out := &gvisorNet{} + tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil) + if err != nil { + return nil, err + } + + if handler != nil { + // handler is only used for promiscuous mode + // capture all packets and send to handler + + tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) { + go func(r *tcp.ForwarderRequest) { + var ( + wq waiter.Queue + id = r.ID() + ) + + // Perform a TCP three-way handshake. + ep, err := r.CreateEndpoint(&wq) + if err != nil { + newError(err.String()).AtError().WriteToLog() + r.Complete(true) + return + } + r.Complete(false) + defer ep.Close() + + // enable tcp keep-alive to prevent hanging connections + ep.SocketOptions().SetKeepAlive(true) + + // local address is actually destination + handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep)) + }(r) + }) + stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + + udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) { + go func(r *udp.ForwarderRequest) { + var ( + wq waiter.Queue + id = r.ID() + ) + + ep, err := r.CreateEndpoint(&wq) + if err != nil { + newError(err.String()).AtError().WriteToLog() + return + } + defer ep.Close() + + // prevents hanging connections and ensure timely release + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: true, + Timeout: 15 * time.Second, + }) + + handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep)) + }(r) + }) + stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + } + + out.tun, out.net = tun, n + return out, nil +} diff --git a/proxy/wireguard/tun_default.go b/proxy/wireguard/tun_default.go index 07f21272..4d0567af 100644 --- a/proxy/wireguard/tun_default.go +++ b/proxy/wireguard/tun_default.go @@ -1,42 +1,16 @@ -//go:build !linux +//go:build !linux || android package wireguard import ( - "context" - "net" + "errors" "net/netip" - - "golang.zx2c4.com/wireguard/tun/netstack" ) -var _ Tunnel = (*gvisorNet)(nil) - -type gvisorNet struct { - tunnel - net *netstack.Net +func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) { + return nil, errors.New("not implemented") } -func (g *gvisorNet) Close() error { - return g.tunnel.Close() -} - -func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( - net.Conn, error, -) { - return g.net.DialContextTCPAddrPort(ctx, addr) -} - -func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { - return g.net.DialUDPAddrPort(laddr, raddr) -} - -func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) { - out := &gvisorNet{} - tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu) - if err != nil { - return nil, err - } - out.tun, out.net = tun, n - return out, nil +func KernelTunSupported() bool { + return false } diff --git a/proxy/wireguard/tun_linux.go b/proxy/wireguard/tun_linux.go index ec940c56..b85a9d09 100644 --- a/proxy/wireguard/tun_linux.go +++ b/proxy/wireguard/tun_linux.go @@ -1,3 +1,5 @@ +//go:build linux && !android + package wireguard import ( @@ -69,7 +71,11 @@ func (d *deviceNet) Close() (err error) { return errors.Join(errs...) } -func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) { +func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) { + if handler != nil { + return nil, newError("TODO: support promiscuous mode") + } + var v4, v6 *netip.Addr for _, prefixes := range localAddresses { if v4 == nil && prefixes.Is4() { @@ -221,3 +227,11 @@ func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) { out.tun = wgt return out, nil } + +func KernelTunSupported() bool { + // run a superuser permission check to check + // if the current user has the sufficient permission + // to create a tun device. + + return unix.Geteuid() == 0 // 0 means root +} diff --git a/proxy/wireguard/wireguard.go b/proxy/wireguard/wireguard.go index 48e2ace3..2b3c3007 100644 --- a/proxy/wireguard/wireguard.go +++ b/proxy/wireguard/wireguard.go @@ -1,326 +1,111 @@ -/* - -Some of codes are copied from https://github.com/octeep/wireproxy, license below. - -Copyright (c) 2022 Wind T.F. Wong - -Permission to use, copy, modify, and distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -*/ - package wireguard import ( - "bytes" "context" "fmt" - stdnet "net" "net/netip" "strings" - "sync" "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/buf" - "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/log" - "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/protocol" - "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal" - "github.com/xtls/xray-core/common/task" - "github.com/xtls/xray-core/core" - "github.com/xtls/xray-core/features/dns" - "github.com/xtls/xray-core/features/policy" - "github.com/xtls/xray-core/transport" - "github.com/xtls/xray-core/transport/internet" + "golang.zx2c4.com/wireguard/device" ) -// Handler is an outbound connection that silently swallow the entire payload. -type Handler struct { - conf *DeviceConfig - net Tunnel - bind *netBindClient - policyManager policy.Manager - dns dns.Client - // cached configuration - ipc string - endpoints []netip.Addr - hasIPv4, hasIPv6 bool - wgLock sync.Mutex -} +//go:generate go run github.com/xtls/xray-core/common/errors/errorgen -// New creates a new wireguard handler. -func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { - v := core.MustFromContext(ctx) - - endpoints, err := parseEndpoints(conf) - if err != nil { - return nil, err - } - - hasIPv4, hasIPv6 := false, false - for _, e := range endpoints { - if e.Is4() { - hasIPv4 = true - } - if e.Is6() { - hasIPv6 = true - } - } - - d := v.GetFeature(dns.ClientType()).(dns.Client) - return &Handler{ - conf: conf, - policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), - dns: d, - ipc: createIPCRequest(conf, d, hasIPv6), - endpoints: endpoints, - hasIPv4: hasIPv4, - hasIPv6: hasIPv6, - }, nil -} - -func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { - h.wgLock.Lock() - defer h.wgLock.Unlock() - - if h.bind != nil && h.bind.dialer == dialer && h.net != nil { - return nil - } - - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Info, - Content: "switching dialer", - }) - - if h.net != nil { - _ = h.net.Close() - h.net = nil - } - if h.bind != nil { - _ = h.bind.Close() - h.bind = nil - } - - // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer - bind := &netBindClient{ - dialer: dialer, - workers: int(h.conf.NumWorkers), - dns: h.dns, - reserved: h.conf.Reserved, - } - defer func() { - if err != nil { - _ = bind.Close() - } - }() - - h.net, err = h.makeVirtualTun(bind) - if err != nil { - return newError("failed to create virtual tun interface").Base(err) - } - h.bind = bind - return nil -} - -// Process implements OutboundHandler.Dispatch(). -func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { - return newError("target not specified") - } - outbound.Name = "wireguard" - inbound := session.InboundFromContext(ctx) - if inbound != nil { - inbound.SetCanSpliceCopy(3) - } - - if err := h.processWireGuard(dialer); err != nil { - return err - } - - // Destination of the inner request. - destination := outbound.Target - command := protocol.RequestCommandTCP - if destination.Network == net.Network_UDP { - command = protocol.RequestCommandUDP - } - - // resolve dns - addr := destination.Address - if addr.Family().IsDomain() { - ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ - IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), - IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), +var wgLogger = &device.Logger{ + Verbosef: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Debug, + Content: fmt.Sprintf(format, args...), }) - { // Resolve fallback - if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { - ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ - IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), - IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), - }) - } - } - if err != nil { - return newError("failed to lookup DNS").Base(err) - } else if len(ips) == 0 { - return dns.ErrEmptyResponse - } - addr = net.IPAddress(ips[dice.Roll(len(ips))]) - } - - var newCtx context.Context - var newCancel context.CancelFunc - if session.TimeoutOnlyFromContext(ctx) { - newCtx, newCancel = context.WithCancel(context.Background()) - } - - p := h.policyManager.ForLevel(0) - - ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, func() { - cancel() - if newCancel != nil { - newCancel() - } - }, p.Timeouts.ConnectionIdle) - addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) - - var requestFunc func() error - var responseFunc func() error - - if command == protocol.RequestCommandTCP { - conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) - if err != nil { - return newError("failed to create TCP connection").Base(err) - } - defer conn.Close() - - requestFunc = func() error { - defer timer.SetTimeout(p.Timeouts.DownlinkOnly) - return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) - } - responseFunc = func() error { - defer timer.SetTimeout(p.Timeouts.UplinkOnly) - return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) - } - } else if command == protocol.RequestCommandUDP { - conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) - if err != nil { - return newError("failed to create UDP connection").Base(err) - } - defer conn.Close() - - requestFunc = func() error { - defer timer.SetTimeout(p.Timeouts.DownlinkOnly) - return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) - } - responseFunc = func() error { - defer timer.SetTimeout(p.Timeouts.UplinkOnly) - return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) - } - } - - if newCtx != nil { - ctx = newCtx - } - - responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) - if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { - common.Interrupt(link.Reader) - common.Interrupt(link.Writer) - return newError("connection ends").Base(err) - } - - return nil + }, + Errorf: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Error, + Content: fmt.Sprintf(format, args...), + }) + }, } -// serialize the config into an IPC request -func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string { - var request bytes.Buffer - - request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) - - for _, peer := range conf.Peers { - endpoint := peer.Endpoint - host, port, err := net.SplitHostPort(endpoint) - if resolveEndPointToV4 && err == nil { - _, err = netip.ParseAddr(host) - if err != nil { - ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false}) - if err == nil && len(ipList) > 0 { - endpoint = stdnet.JoinHostPort(ipList[0].String(), port) - } - } +func init() { + common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { + deviceConfig := config.(*DeviceConfig) + if deviceConfig.IsClient { + return New(ctx, deviceConfig) + } else { + return NewServer(ctx, deviceConfig) } - - request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n", - peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey)) - - for _, ip := range peer.AllowedIps { - request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) - } - } - - return request.String()[:request.Len()] + })) } // convert endpoint string to netip.Addr -func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { +func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) { + var hasIPv4, hasIPv6 bool + endpoints := make([]netip.Addr, len(conf.Endpoint)) for i, str := range conf.Endpoint { var addr netip.Addr if strings.Contains(str, "/") { prefix, err := netip.ParsePrefix(str) if err != nil { - return nil, err + return nil, false, false, err } addr = prefix.Addr() if prefix.Bits() != addr.BitLen() { - return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") + return nil, false, false, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") } } else { var err error addr, err = netip.ParseAddr(str) if err != nil { - return nil, err + return nil, false, false, err } } endpoints[i] = addr + + if addr.Is4() { + hasIPv4 = true + } else if addr.Is6() { + hasIPv6 = true + } } - return endpoints, nil + return endpoints, hasIPv4, hasIPv6, nil } -// creates a tun interface on netstack given a configuration -func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { - t, err := CreateTun(h.endpoints, int(h.conf.Mtu)) - if err != nil { - return nil, err +// serialize the config into an IPC request +func createIPCRequest(conf *DeviceConfig) string { + var request strings.Builder + + request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) + + if !conf.IsClient { + // placeholder, we'll handle actual port listening on Xray + request.WriteString("listen_port=1337\n") } - bind.dnsOption.IPv4Enable = h.hasIPv4 - bind.dnsOption.IPv6Enable = h.hasIPv6 + for _, peer := range conf.Peers { + if peer.PublicKey != "" { + request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey)) + } - if err = t.BuildDevice(h.ipc, bind); err != nil { - _ = t.Close() - return nil, err + if peer.PreSharedKey != "" { + request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey)) + } + + if peer.Endpoint != "" { + request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint)) + } + + for _, ip := range peer.AllowedIps { + request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) + } + + if peer.KeepAlive != 0 { + request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive)) + } } - return t, nil -} -func init() { - common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { - return New(ctx, config.(*DeviceConfig)) - })) + return request.String()[:request.Len()] }