mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-26 06:39:20 +02:00
WireGuard Inbound (User-space WireGuard server) (#2477)
* feat: wireguard inbound * feat(command): generate wireguard compatible keypair * feat(wireguard): connection idle timeout * fix(wireguard): close endpoint after connection closed * fix(wireguard): resolve conflicts * feat(wireguard): set cubic as default cc algorithm in gVisor TUN * chore(wireguard): resolve conflict * chore(wireguard): remove redurant code * chore(wireguard): remove redurant code * feat: rework server for gvisor tun * feat: keep user-space tun as an option * fix: exclude android from native tun build * feat: auto kernel tun * fix: build * fix: regulate function name & fix test
This commit is contained in:
parent
f1c81557dc
commit
0ac7da2fc8
4
go.mod
4
go.mod
|
@ -27,6 +27,7 @@ require (
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb
|
golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb
|
||||||
google.golang.org/grpc v1.59.0
|
google.golang.org/grpc v1.59.0
|
||||||
google.golang.org/protobuf v1.31.0
|
google.golang.org/protobuf v1.31.0
|
||||||
|
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b
|
||||||
h12.io/socks v1.0.3
|
h12.io/socks v1.0.3
|
||||||
lukechampine.com/blake3 v1.2.1
|
lukechampine.com/blake3 v1.2.1
|
||||||
)
|
)
|
||||||
|
@ -48,7 +49,7 @@ require (
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
|
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
|
||||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect
|
||||||
go.uber.org/mock v0.3.0 // indirect
|
go.uber.org/mock v0.3.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
|
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
|
||||||
golang.org/x/mod v0.14.0 // indirect
|
golang.org/x/mod v0.14.0 // indirect
|
||||||
|
@ -59,5 +60,4 @@ require (
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20231106174013-bbf56f31fb17 // indirect
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
gvisor.dev/gvisor v0.0.0-20231104011432-48a6d7d5bd0b // indirect
|
|
||||||
)
|
)
|
||||||
|
|
3
go.sum
3
go.sum
|
@ -168,9 +168,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
|
||||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||||
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U=
|
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3 h1:tkMT5pTye+1NlKIXETU78NXw0fyjnaNHmJyyLyzw8+U=
|
||||||
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
|
github.com/vishvananda/netlink v1.2.1-beta.2.0.20230316163032-ced5aaba43e3/go.mod h1:cAAsePK2e15YDAMJNyOpGYEWNe4sIghTY7gpz4cX/Ik=
|
||||||
|
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=
|
||||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
|
||||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
|
||||||
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI=
|
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19 h1:capMfFYRgH9BCLd6A3Er/cH3A9Nz3CU2KwxwOQZIePI=
|
||||||
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
|
github.com/xtls/reality v0.0.0-20231112171332-de1173cf2b19/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
|
||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
|
|
|
@ -13,7 +13,7 @@ type WireGuardPeerConfig struct {
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
PreSharedKey string `json:"preSharedKey"`
|
PreSharedKey string `json:"preSharedKey"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
KeepAlive int `json:"keepAlive"`
|
KeepAlive uint32 `json:"keepAlive"`
|
||||||
AllowedIPs []string `json:"allowedIPs,omitempty"`
|
AllowedIPs []string `json:"allowedIPs,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,23 +21,23 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
|
||||||
var err error
|
var err error
|
||||||
config := new(wireguard.PeerConfig)
|
config := new(wireguard.PeerConfig)
|
||||||
|
|
||||||
|
if c.PublicKey != "" {
|
||||||
config.PublicKey, err = parseWireGuardKey(c.PublicKey)
|
config.PublicKey, err = parseWireGuardKey(c.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.PreSharedKey != "" {
|
if c.PreSharedKey != "" {
|
||||||
config.PreSharedKey, err = parseWireGuardKey(c.PreSharedKey)
|
config.PreSharedKey, err = parseWireGuardKey(c.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
config.PreSharedKey = "0000000000000000000000000000000000000000000000000000000000000000"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Endpoint = c.Endpoint
|
config.Endpoint = c.Endpoint
|
||||||
// default 0
|
// default 0
|
||||||
config.KeepAlive = int32(c.KeepAlive)
|
config.KeepAlive = c.KeepAlive
|
||||||
if c.AllowedIPs == nil {
|
if c.AllowedIPs == nil {
|
||||||
config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
|
config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
|
||||||
} else {
|
} else {
|
||||||
|
@ -48,11 +48,14 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type WireGuardConfig struct {
|
type WireGuardConfig struct {
|
||||||
|
IsClient bool `json:""`
|
||||||
|
|
||||||
|
KernelMode *bool `json:"kernelMode"`
|
||||||
SecretKey string `json:"secretKey"`
|
SecretKey string `json:"secretKey"`
|
||||||
Address []string `json:"address"`
|
Address []string `json:"address"`
|
||||||
Peers []*WireGuardPeerConfig `json:"peers"`
|
Peers []*WireGuardPeerConfig `json:"peers"`
|
||||||
MTU int `json:"mtu"`
|
MTU int32 `json:"mtu"`
|
||||||
NumWorkers int `json:"workers"`
|
NumWorkers int32 `json:"workers"`
|
||||||
Reserved []byte `json:"reserved"`
|
Reserved []byte `json:"reserved"`
|
||||||
DomainStrategy string `json:"domainStrategy"`
|
DomainStrategy string `json:"domainStrategy"`
|
||||||
}
|
}
|
||||||
|
@ -87,11 +90,11 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
||||||
if c.MTU == 0 {
|
if c.MTU == 0 {
|
||||||
config.Mtu = 1420
|
config.Mtu = 1420
|
||||||
} else {
|
} else {
|
||||||
config.Mtu = int32(c.MTU)
|
config.Mtu = c.MTU
|
||||||
}
|
}
|
||||||
// these a fallback code exists in github.com/nanoda0523/wireguard-go code,
|
// these a fallback code exists in wireguard-go code,
|
||||||
// we don't need to process fallback manually
|
// we don't need to process fallback manually
|
||||||
config.NumWorkers = int32(c.NumWorkers)
|
config.NumWorkers = c.NumWorkers
|
||||||
|
|
||||||
if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
|
if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
|
||||||
return nil, newError(`"reserved" should be empty or 3 bytes`)
|
return nil, newError(`"reserved" should be empty or 3 bytes`)
|
||||||
|
@ -113,22 +116,42 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
|
||||||
return nil, newError("unsupported domain strategy: ", c.DomainStrategy)
|
return nil, newError("unsupported domain strategy: ", c.DomainStrategy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config.IsClient = c.IsClient
|
||||||
|
if c.KernelMode != nil {
|
||||||
|
config.KernelMode = *c.KernelMode
|
||||||
|
if config.KernelMode && !wireguard.KernelTunSupported() {
|
||||||
|
newError("kernel mode is not supported on your OS or permission is insufficient").AtWarning().WriteToLog()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
config.KernelMode = wireguard.KernelTunSupported()
|
||||||
|
if config.KernelMode {
|
||||||
|
newError("kernel mode is enabled as it's supported and permission is sufficient").AtDebug().WriteToLog()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseWireGuardKey(str string) (string, error) {
|
func parseWireGuardKey(str string) (string, error) {
|
||||||
if len(str) != 64 {
|
var err error
|
||||||
// may in base64 form
|
|
||||||
dat, err := base64.StdEncoding.DecodeString(str)
|
if len(str)%2 == 0 {
|
||||||
if err != nil {
|
_, err = hex.DecodeString(str)
|
||||||
return "", err
|
if err == nil {
|
||||||
}
|
|
||||||
if len(dat) != 32 {
|
|
||||||
return "", newError("key should be 32 bytes: " + str)
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(dat), err
|
|
||||||
} else {
|
|
||||||
// already hex form
|
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dat []byte
|
||||||
|
str = strings.TrimSuffix(str, "=")
|
||||||
|
if strings.ContainsRune(str, '+') || strings.ContainsRune(str, '/') {
|
||||||
|
dat, err = base64.RawStdEncoding.DecodeString(str)
|
||||||
|
} else {
|
||||||
|
dat, err = base64.RawURLEncoding.DecodeString(str)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return hex.EncodeToString(dat), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", newError("failed to deserialize key").Base(err)
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"github.com/xtls/xray-core/proxy/wireguard"
|
"github.com/xtls/xray-core/proxy/wireguard"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWireGuardOutbound(t *testing.T) {
|
func TestWireGuardConfig(t *testing.T) {
|
||||||
creator := func() Buildable {
|
creator := func() Buildable {
|
||||||
return new(WireGuardConfig)
|
return new(WireGuardConfig)
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,8 @@ func TestWireGuardOutbound(t *testing.T) {
|
||||||
],
|
],
|
||||||
"mtu": 1300,
|
"mtu": 1300,
|
||||||
"workers": 2,
|
"workers": 2,
|
||||||
"domainStrategy": "ForceIPv6v4"
|
"domainStrategy": "ForceIPv6v4",
|
||||||
|
"kernelMode": false
|
||||||
}`,
|
}`,
|
||||||
Parser: loadJSON(creator),
|
Parser: loadJSON(creator),
|
||||||
Output: &wireguard.DeviceConfig{
|
Output: &wireguard.DeviceConfig{
|
||||||
|
@ -36,7 +37,6 @@ func TestWireGuardOutbound(t *testing.T) {
|
||||||
{
|
{
|
||||||
// also can read from hex form directly
|
// also can read from hex form directly
|
||||||
PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
|
PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
|
||||||
PreSharedKey: "0000000000000000000000000000000000000000000000000000000000000000",
|
|
||||||
Endpoint: "127.0.0.1:1234",
|
Endpoint: "127.0.0.1:1234",
|
||||||
KeepAlive: 0,
|
KeepAlive: 0,
|
||||||
AllowedIps: []string{"0.0.0.0/0", "::0/0"},
|
AllowedIps: []string{"0.0.0.0/0", "::0/0"},
|
||||||
|
@ -45,6 +45,7 @@ func TestWireGuardOutbound(t *testing.T) {
|
||||||
Mtu: 1300,
|
Mtu: 1300,
|
||||||
NumWorkers: 2,
|
NumWorkers: 2,
|
||||||
DomainStrategy: wireguard.DeviceConfig_FORCE_IP64,
|
DomainStrategy: wireguard.DeviceConfig_FORCE_IP64,
|
||||||
|
KernelMode: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
|
@ -24,6 +24,7 @@ var (
|
||||||
"vless": func() interface{} { return new(VLessInboundConfig) },
|
"vless": func() interface{} { return new(VLessInboundConfig) },
|
||||||
"vmess": func() interface{} { return new(VMessInboundConfig) },
|
"vmess": func() interface{} { return new(VMessInboundConfig) },
|
||||||
"trojan": func() interface{} { return new(TrojanServerConfig) },
|
"trojan": func() interface{} { return new(TrojanServerConfig) },
|
||||||
|
"wireguard": func() interface{} { return &WireGuardConfig{IsClient: false} },
|
||||||
}, "protocol", "settings")
|
}, "protocol", "settings")
|
||||||
|
|
||||||
outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
|
outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
|
||||||
|
@ -37,7 +38,7 @@ var (
|
||||||
"vmess": func() interface{} { return new(VMessOutboundConfig) },
|
"vmess": func() interface{} { return new(VMessOutboundConfig) },
|
||||||
"trojan": func() interface{} { return new(TrojanClientConfig) },
|
"trojan": func() interface{} { return new(TrojanClientConfig) },
|
||||||
"dns": func() interface{} { return new(DNSOutboundConfig) },
|
"dns": func() interface{} { return new(DNSOutboundConfig) },
|
||||||
"wireguard": func() interface{} { return new(WireGuardConfig) },
|
"wireguard": func() interface{} { return &WireGuardConfig{IsClient: true} },
|
||||||
}, "protocol", "settings")
|
}, "protocol", "settings")
|
||||||
|
|
||||||
ctllog = log.New(os.Stderr, "xctl> ", 0)
|
ctllog = log.New(os.Stderr, "xctl> ", 0)
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var cmdX25519 = &base.Command{
|
var cmdX25519 = &base.Command{
|
||||||
UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"]`,
|
UsageLine: `{{.Exec}} x25519 [-i "private key (base64.RawURLEncoding)"] [--std-encoding]`,
|
||||||
Short: `Generate key pair for x25519 key exchange`,
|
Short: `Generate key pair for x25519 key exchange`,
|
||||||
Long: `
|
Long: `
|
||||||
Generate key pair for x25519 key exchange.
|
Generate key pair for x25519 key exchange.
|
||||||
|
@ -18,6 +18,7 @@ Generate key pair for x25519 key exchange.
|
||||||
Random: {{.Exec}} x25519
|
Random: {{.Exec}} x25519
|
||||||
|
|
||||||
From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)"
|
From private key: {{.Exec}} x25519 -i "private key (base64.RawURLEncoding)"
|
||||||
|
For Std Encoding: {{.Exec}} x25519 --std-encoding
|
||||||
`,
|
`,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,12 +27,14 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
var input_base64 = cmdX25519.Flag.String("i", "", "")
|
var input_base64 = cmdX25519.Flag.String("i", "", "")
|
||||||
|
var input_stdEncoding = cmdX25519.Flag.Bool("std-encoding", false, "")
|
||||||
|
|
||||||
func executeX25519(cmd *base.Command, args []string) {
|
func executeX25519(cmd *base.Command, args []string) {
|
||||||
var output string
|
var output string
|
||||||
var err error
|
var err error
|
||||||
var privateKey []byte
|
var privateKey []byte
|
||||||
var publicKey []byte
|
var publicKey []byte
|
||||||
|
var encoding *base64.Encoding
|
||||||
if len(*input_base64) > 0 {
|
if len(*input_base64) > 0 {
|
||||||
privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64)
|
privateKey, err = base64.RawURLEncoding.DecodeString(*input_base64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -63,9 +66,15 @@ func executeX25519(cmd *base.Command, args []string) {
|
||||||
goto out
|
goto out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *input_stdEncoding {
|
||||||
|
encoding = base64.StdEncoding
|
||||||
|
} else {
|
||||||
|
encoding = base64.RawURLEncoding
|
||||||
|
}
|
||||||
|
|
||||||
output = fmt.Sprintf("Private key: %v\nPublic key: %v",
|
output = fmt.Sprintf("Private key: %v\nPublic key: %v",
|
||||||
base64.RawURLEncoding.EncodeToString(privateKey),
|
encoding.EncodeToString(privateKey),
|
||||||
base64.RawURLEncoding.EncodeToString(publicKey))
|
encoding.EncodeToString(publicKey))
|
||||||
out:
|
out:
|
||||||
fmt.Println(output)
|
fmt.Println(output)
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,48 +27,45 @@ type netReadInfo struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
type netBindClient struct {
|
// reduce duplicated code
|
||||||
workers int
|
type netBind struct {
|
||||||
dialer internet.Dialer
|
|
||||||
dns dns.Client
|
dns dns.Client
|
||||||
dnsOption dns.IPOption
|
dnsOption dns.IPOption
|
||||||
reserved []byte
|
|
||||||
|
|
||||||
|
workers int
|
||||||
readQueue chan *netReadInfo
|
readQueue chan *netReadInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
// SetMark implements conn.Bind
|
||||||
ipStr, port, _, err := splitAddrPort(s)
|
func (bind *netBind) SetMark(mark uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseEndpoint implements conn.Bind
|
||||||
|
func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
|
ipStr, port, err := net.SplitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
portNum, err := strconv.Atoi(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var addr net.IP
|
addr := xnet.ParseAddress(ipStr)
|
||||||
if IsDomainName(ipStr) {
|
if addr.Family() == xnet.AddressFamilyDomain {
|
||||||
ips, err := bind.dns.LookupIP(ipStr, bind.dnsOption)
|
ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if len(ips) == 0 {
|
} else if len(ips) == 0 {
|
||||||
return nil, dns.ErrEmptyResponse
|
return nil, dns.ErrEmptyResponse
|
||||||
}
|
}
|
||||||
addr = ips[0]
|
addr = xnet.IPAddress(ips[0])
|
||||||
} else {
|
|
||||||
addr = net.ParseIP(ipStr)
|
|
||||||
}
|
|
||||||
if addr == nil {
|
|
||||||
return nil, errors.New("failed to parse ip: " + ipStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ip xnet.Address
|
|
||||||
if p4 := addr.To4(); len(p4) == net.IPv4len {
|
|
||||||
ip = xnet.IPAddress(p4[:])
|
|
||||||
} else {
|
|
||||||
ip = xnet.IPAddress(addr[:])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dst := xnet.Destination{
|
dst := xnet.Destination{
|
||||||
Address: ip,
|
Address: addr,
|
||||||
Port: xnet.Port(port),
|
Port: xnet.Port(portNum),
|
||||||
Network: xnet.Network_UDP,
|
Network: xnet.Network_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,7 +74,13 @@ func (bind *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
// BatchSize implements conn.Bind
|
||||||
|
func (bind *netBind) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open implements conn.Bind
|
||||||
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||||
bind.readQueue = make(chan *netReadInfo)
|
bind.readQueue = make(chan *netReadInfo)
|
||||||
|
|
||||||
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||||
|
@ -109,13 +112,21 @@ func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error
|
||||||
return arr, uint16(uport), nil
|
return arr, uint16(uport), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *netBindClient) Close() error {
|
// Close implements conn.Bind
|
||||||
|
func (bind *netBind) Close() error {
|
||||||
if bind.readQueue != nil {
|
if bind.readQueue != nil {
|
||||||
close(bind.readQueue)
|
close(bind.readQueue)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type netBindClient struct {
|
||||||
|
netBind
|
||||||
|
|
||||||
|
dialer internet.Dialer
|
||||||
|
reserved []byte
|
||||||
|
}
|
||||||
|
|
||||||
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
||||||
c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
|
c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -177,12 +188,29 @@ func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *netBindClient) SetMark(mark uint32) error {
|
type netBindServer struct {
|
||||||
return nil
|
netBind
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *netBindClient) BatchSize() int {
|
func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
||||||
return 1
|
var err error
|
||||||
|
|
||||||
|
nend, ok := endpoint.(*netEndpoint)
|
||||||
|
if !ok {
|
||||||
|
return conn.ErrWrongEndpointType
|
||||||
|
}
|
||||||
|
|
||||||
|
if nend.conn == nil {
|
||||||
|
return newError("connection not open yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, buff := range buff {
|
||||||
|
if _, err = nend.conn.Write(buff); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
type netEndpoint struct {
|
type netEndpoint struct {
|
||||||
|
@ -193,7 +221,7 @@ type netEndpoint struct {
|
||||||
func (netEndpoint) ClearSrc() {}
|
func (netEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
func (e netEndpoint) DstIP() netip.Addr {
|
func (e netEndpoint) DstIP() netip.Addr {
|
||||||
return toNetIpAddr(e.dst.Address)
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e netEndpoint) SrcIP() netip.Addr {
|
func (e netEndpoint) SrcIP() netip.Addr {
|
||||||
|
@ -232,83 +260,3 @@ func toNetIpAddr(addr xnet.Address) netip.Addr {
|
||||||
return netip.AddrFrom16(arr)
|
return netip.AddrFrom16(arr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func stringsLastIndexByte(s string, b byte) int {
|
|
||||||
for i := len(s) - 1; i >= 0; i-- {
|
|
||||||
if s[i] == b {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
|
|
||||||
i := stringsLastIndexByte(s, ':')
|
|
||||||
if i == -1 {
|
|
||||||
return "", 0, false, errors.New("not an ip:port")
|
|
||||||
}
|
|
||||||
|
|
||||||
ip = s[:i]
|
|
||||||
portStr := s[i+1:]
|
|
||||||
if len(ip) == 0 {
|
|
||||||
return "", 0, false, errors.New("no IP")
|
|
||||||
}
|
|
||||||
if len(portStr) == 0 {
|
|
||||||
return "", 0, false, errors.New("no port")
|
|
||||||
}
|
|
||||||
port64, err := strconv.ParseUint(portStr, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
|
|
||||||
}
|
|
||||||
port = uint16(port64)
|
|
||||||
if ip[0] == '[' {
|
|
||||||
if len(ip) < 2 || ip[len(ip)-1] != ']' {
|
|
||||||
return "", 0, false, errors.New("missing ]")
|
|
||||||
}
|
|
||||||
ip = ip[1 : len(ip)-1]
|
|
||||||
v6 = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return ip, port, v6, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsDomainName(s string) bool {
|
|
||||||
l := len(s)
|
|
||||||
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
last := byte('.')
|
|
||||||
nonNumeric := false
|
|
||||||
partlen := 0
|
|
||||||
for i := 0; i < len(s); i++ {
|
|
||||||
c := s[i]
|
|
||||||
switch {
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
|
|
||||||
nonNumeric = true
|
|
||||||
partlen++
|
|
||||||
case '0' <= c && c <= '9':
|
|
||||||
partlen++
|
|
||||||
case c == '-':
|
|
||||||
if last == '.' {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
partlen++
|
|
||||||
nonNumeric = true
|
|
||||||
case c == '.':
|
|
||||||
if last == '.' || last == '-' {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if partlen > 63 || partlen == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
partlen = 0
|
|
||||||
}
|
|
||||||
last = c
|
|
||||||
}
|
|
||||||
if last == '-' || partlen > 63 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return nonNumeric
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,255 @@
|
||||||
|
/*
|
||||||
|
|
||||||
|
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
|
||||||
|
|
||||||
|
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
|
||||||
|
|
||||||
|
Permission to use, copy, modify, and distribute this software for any
|
||||||
|
purpose with or without fee is hereby granted, provided that the above
|
||||||
|
copyright notice and this permission notice appear in all copies.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/xtls/xray-core/common"
|
||||||
|
"github.com/xtls/xray-core/common/buf"
|
||||||
|
"github.com/xtls/xray-core/common/dice"
|
||||||
|
"github.com/xtls/xray-core/common/log"
|
||||||
|
"github.com/xtls/xray-core/common/net"
|
||||||
|
"github.com/xtls/xray-core/common/protocol"
|
||||||
|
"github.com/xtls/xray-core/common/session"
|
||||||
|
"github.com/xtls/xray-core/common/signal"
|
||||||
|
"github.com/xtls/xray-core/common/task"
|
||||||
|
"github.com/xtls/xray-core/core"
|
||||||
|
"github.com/xtls/xray-core/features/dns"
|
||||||
|
"github.com/xtls/xray-core/features/policy"
|
||||||
|
"github.com/xtls/xray-core/transport"
|
||||||
|
"github.com/xtls/xray-core/transport/internet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler is an outbound connection that silently swallow the entire payload.
|
||||||
|
type Handler struct {
|
||||||
|
conf *DeviceConfig
|
||||||
|
net Tunnel
|
||||||
|
bind *netBindClient
|
||||||
|
policyManager policy.Manager
|
||||||
|
dns dns.Client
|
||||||
|
// cached configuration
|
||||||
|
ipc string
|
||||||
|
endpoints []netip.Addr
|
||||||
|
hasIPv4, hasIPv6 bool
|
||||||
|
wgLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new wireguard handler.
|
||||||
|
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
||||||
|
v := core.MustFromContext(ctx)
|
||||||
|
|
||||||
|
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
||||||
|
return &Handler{
|
||||||
|
conf: conf,
|
||||||
|
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||||
|
dns: d,
|
||||||
|
ipc: createIPCRequest(conf),
|
||||||
|
endpoints: endpoints,
|
||||||
|
hasIPv4: hasIPv4,
|
||||||
|
hasIPv6: hasIPv6,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
|
||||||
|
h.wgLock.Lock()
|
||||||
|
defer h.wgLock.Unlock()
|
||||||
|
|
||||||
|
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Record(&log.GeneralMessage{
|
||||||
|
Severity: log.Severity_Info,
|
||||||
|
Content: "switching dialer",
|
||||||
|
})
|
||||||
|
|
||||||
|
if h.net != nil {
|
||||||
|
_ = h.net.Close()
|
||||||
|
h.net = nil
|
||||||
|
}
|
||||||
|
if h.bind != nil {
|
||||||
|
_ = h.bind.Close()
|
||||||
|
h.bind = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||||
|
bind := &netBindClient{
|
||||||
|
netBind: netBind{
|
||||||
|
dns: h.dns,
|
||||||
|
dnsOption: dns.IPOption{
|
||||||
|
IPv4Enable: h.hasIPv4,
|
||||||
|
IPv6Enable: h.hasIPv6,
|
||||||
|
},
|
||||||
|
workers: int(h.conf.NumWorkers),
|
||||||
|
},
|
||||||
|
dialer: dialer,
|
||||||
|
reserved: h.conf.Reserved,
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = bind.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
h.net, err = h.makeVirtualTun(bind)
|
||||||
|
if err != nil {
|
||||||
|
return newError("failed to create virtual tun interface").Base(err)
|
||||||
|
}
|
||||||
|
h.bind = bind
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process implements OutboundHandler.Dispatch().
|
||||||
|
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||||
|
outbound := session.OutboundFromContext(ctx)
|
||||||
|
if outbound == nil || !outbound.Target.IsValid() {
|
||||||
|
return newError("target not specified")
|
||||||
|
}
|
||||||
|
outbound.Name = "wireguard"
|
||||||
|
inbound := session.InboundFromContext(ctx)
|
||||||
|
if inbound != nil {
|
||||||
|
inbound.SetCanSpliceCopy(3)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.processWireGuard(dialer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destination of the inner request.
|
||||||
|
destination := outbound.Target
|
||||||
|
command := protocol.RequestCommandTCP
|
||||||
|
if destination.Network == net.Network_UDP {
|
||||||
|
command = protocol.RequestCommandUDP
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolve dns
|
||||||
|
addr := destination.Address
|
||||||
|
if addr.Family().IsDomain() {
|
||||||
|
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||||
|
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
|
||||||
|
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
|
||||||
|
})
|
||||||
|
{ // Resolve fallback
|
||||||
|
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
|
||||||
|
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||||
|
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
|
||||||
|
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return newError("failed to lookup DNS").Base(err)
|
||||||
|
} else if len(ips) == 0 {
|
||||||
|
return dns.ErrEmptyResponse
|
||||||
|
}
|
||||||
|
addr = net.IPAddress(ips[dice.Roll(len(ips))])
|
||||||
|
}
|
||||||
|
|
||||||
|
var newCtx context.Context
|
||||||
|
var newCancel context.CancelFunc
|
||||||
|
if session.TimeoutOnlyFromContext(ctx) {
|
||||||
|
newCtx, newCancel = context.WithCancel(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
p := h.policyManager.ForLevel(0)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, func() {
|
||||||
|
cancel()
|
||||||
|
if newCancel != nil {
|
||||||
|
newCancel()
|
||||||
|
}
|
||||||
|
}, p.Timeouts.ConnectionIdle)
|
||||||
|
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
|
||||||
|
|
||||||
|
var requestFunc func() error
|
||||||
|
var responseFunc func() error
|
||||||
|
|
||||||
|
if command == protocol.RequestCommandTCP {
|
||||||
|
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
|
||||||
|
if err != nil {
|
||||||
|
return newError("failed to create TCP connection").Base(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
requestFunc = func() error {
|
||||||
|
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||||
|
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||||
|
}
|
||||||
|
responseFunc = func() error {
|
||||||
|
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||||
|
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||||
|
}
|
||||||
|
} else if command == protocol.RequestCommandUDP {
|
||||||
|
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
|
||||||
|
if err != nil {
|
||||||
|
return newError("failed to create UDP connection").Base(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
requestFunc = func() error {
|
||||||
|
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||||
|
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||||
|
}
|
||||||
|
responseFunc = func() error {
|
||||||
|
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||||
|
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if newCtx != nil {
|
||||||
|
ctx = newCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
|
||||||
|
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
|
||||||
|
common.Interrupt(link.Reader)
|
||||||
|
common.Interrupt(link.Writer)
|
||||||
|
return newError("connection ends").Base(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// creates a tun interface on netstack given a configuration
|
||||||
|
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
|
||||||
|
t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.dnsOption.IPv4Enable = h.hasIPv4
|
||||||
|
bind.dnsOption.IPv6Enable = h.hasIPv6
|
||||||
|
|
||||||
|
if err = t.BuildDevice(h.ipc, bind); err != nil {
|
||||||
|
_ = t.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
|
@ -23,3 +23,10 @@ func (c *DeviceConfig) fallbackIP4() bool {
|
||||||
func (c *DeviceConfig) fallbackIP6() bool {
|
func (c *DeviceConfig) fallbackIP6() bool {
|
||||||
return c.DomainStrategy == DeviceConfig_FORCE_IP46
|
return c.DomainStrategy == DeviceConfig_FORCE_IP46
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *DeviceConfig) createTun() tunCreator {
|
||||||
|
if c.KernelMode {
|
||||||
|
return createKernelTun
|
||||||
|
}
|
||||||
|
return createGVisorTun
|
||||||
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.31.0
|
// protoc-gen-go v1.28.1
|
||||||
// protoc v4.23.1
|
// protoc v4.25.0
|
||||||
// source: proxy/wireguard/config.proto
|
// source: proxy/wireguard/config.proto
|
||||||
|
|
||||||
package wireguard
|
package wireguard
|
||||||
|
@ -83,7 +83,7 @@ type PeerConfig struct {
|
||||||
PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
|
PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
|
||||||
PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"`
|
PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"`
|
||||||
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
|
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
|
||||||
KeepAlive int32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
|
KeepAlive uint32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"`
|
||||||
AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"`
|
AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,7 +140,7 @@ func (x *PeerConfig) GetEndpoint() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *PeerConfig) GetKeepAlive() int32 {
|
func (x *PeerConfig) GetKeepAlive() uint32 {
|
||||||
if x != nil {
|
if x != nil {
|
||||||
return x.KeepAlive
|
return x.KeepAlive
|
||||||
}
|
}
|
||||||
|
@ -166,6 +166,8 @@ type DeviceConfig struct {
|
||||||
NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"`
|
NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"`
|
||||||
Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
|
Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
|
||||||
DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
|
DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
|
||||||
|
IsClient bool `protobuf:"varint,8,opt,name=is_client,json=isClient,proto3" json:"is_client,omitempty"`
|
||||||
|
KernelMode bool `protobuf:"varint,9,opt,name=kernel_mode,json=kernelMode,proto3" json:"kernel_mode,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *DeviceConfig) Reset() {
|
func (x *DeviceConfig) Reset() {
|
||||||
|
@ -249,6 +251,20 @@ func (x *DeviceConfig) GetDomainStrategy() DeviceConfig_DomainStrategy {
|
||||||
return DeviceConfig_FORCE_IP
|
return DeviceConfig_FORCE_IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *DeviceConfig) GetIsClient() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.IsClient
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *DeviceConfig) GetKernelMode() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.KernelMode
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
|
var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
var file_proxy_wireguard_config_proto_rawDesc = []byte{
|
var file_proxy_wireguard_config_proto_rawDesc = []byte{
|
||||||
|
@ -263,10 +279,10 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
|
||||||
0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70,
|
0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70,
|
||||||
0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70,
|
0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x64, 0x70,
|
||||||
0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69,
|
0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69,
|
||||||
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
|
0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c,
|
||||||
0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69,
|
0x69, 0x76, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x5f, 0x69,
|
||||||
0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65,
|
0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65,
|
||||||
0x64, 0x49, 0x70, 0x73, 0x22, 0x8a, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
|
0x64, 0x49, 0x70, 0x73, 0x22, 0xc8, 0x03, 0x0a, 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43,
|
||||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f,
|
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f,
|
||||||
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65,
|
0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x63, 0x72, 0x65,
|
||||||
0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
|
0x74, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
|
||||||
|
@ -285,19 +301,23 @@ var file_proxy_wireguard_config_proto_rawDesc = []byte{
|
||||||
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f,
|
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x6f,
|
||||||
0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
0x6e, 0x66, 0x69, 0x67, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
||||||
0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
0x65, 0x67, 0x79, 0x52, 0x0e, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
||||||
0x65, 0x67, 0x79, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72,
|
0x65, 0x67, 0x79, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x73, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74,
|
||||||
0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49,
|
0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x69, 0x73, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74,
|
||||||
0x50, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34,
|
0x12, 0x1f, 0x0a, 0x0b, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18,
|
||||||
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10,
|
0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x4d, 0x6f, 0x64,
|
||||||
0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10,
|
0x65, 0x22, 0x5c, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74,
|
||||||
0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10,
|
0x65, 0x67, 0x79, 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x10,
|
||||||
0x04, 0x42, 0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72,
|
0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x01,
|
||||||
0x6f, 0x78, 0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a,
|
0x12, 0x0d, 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x02, 0x12,
|
||||||
0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73,
|
0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x03, 0x12,
|
||||||
0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79,
|
0x0e, 0x0a, 0x0a, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x04, 0x42,
|
||||||
0x2f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61,
|
0x5e, 0x0a, 0x18, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x78,
|
||||||
0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72,
|
0x79, 0x2e, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x01, 0x5a, 0x29, 0x67,
|
||||||
0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78,
|
||||||
|
0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x77,
|
||||||
|
0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0xaa, 0x02, 0x14, 0x58, 0x72, 0x61, 0x79, 0x2e,
|
||||||
|
0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x57, 0x69, 0x72, 0x65, 0x47, 0x75, 0x61, 0x72, 0x64, 0x62,
|
||||||
|
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -10,7 +10,7 @@ message PeerConfig {
|
||||||
string public_key = 1;
|
string public_key = 1;
|
||||||
string pre_shared_key = 2;
|
string pre_shared_key = 2;
|
||||||
string endpoint = 3;
|
string endpoint = 3;
|
||||||
int32 keep_alive = 4;
|
uint32 keep_alive = 4;
|
||||||
repeated string allowed_ips = 5;
|
repeated string allowed_ips = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,4 +29,6 @@ message DeviceConfig {
|
||||||
int32 num_workers = 5;
|
int32 num_workers = 5;
|
||||||
bytes reserved = 6;
|
bytes reserved = 6;
|
||||||
DomainStrategy domain_strategy = 7;
|
DomainStrategy domain_strategy = 7;
|
||||||
|
bool is_client = 8;
|
||||||
|
bool kernel_mode = 9;
|
||||||
}
|
}
|
|
@ -0,0 +1,230 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package gvisortun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type netTun struct {
|
||||||
|
ep *channel.Endpoint
|
||||||
|
stack *stack.Stack
|
||||||
|
events chan tun.Event
|
||||||
|
incomingPacket chan *buffer.View
|
||||||
|
mtu int
|
||||||
|
hasV4, hasV6 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Net netTun
|
||||||
|
|
||||||
|
func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
|
||||||
|
opts := stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
|
||||||
|
HandleLocal: !promiscuousMode,
|
||||||
|
}
|
||||||
|
dev := &netTun{
|
||||||
|
ep: channel.New(1024, uint32(mtu), ""),
|
||||||
|
stack: stack.New(opts),
|
||||||
|
events: make(chan tun.Event, 1),
|
||||||
|
incomingPacket: make(chan *buffer.View),
|
||||||
|
mtu: mtu,
|
||||||
|
}
|
||||||
|
dev.ep.AddNotify(dev)
|
||||||
|
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
|
||||||
|
if tcpipErr != nil {
|
||||||
|
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||||
|
}
|
||||||
|
for _, ip := range localAddresses {
|
||||||
|
var protoNumber tcpip.NetworkProtocolNumber
|
||||||
|
if ip.Is4() {
|
||||||
|
protoNumber = ipv4.ProtocolNumber
|
||||||
|
} else if ip.Is6() {
|
||||||
|
protoNumber = ipv6.ProtocolNumber
|
||||||
|
}
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: protoNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
|
||||||
|
}
|
||||||
|
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||||
|
if tcpipErr != nil {
|
||||||
|
return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
||||||
|
}
|
||||||
|
if ip.Is4() {
|
||||||
|
dev.hasV4 = true
|
||||||
|
} else if ip.Is6() {
|
||||||
|
dev.hasV6 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dev.hasV4 {
|
||||||
|
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
|
||||||
|
}
|
||||||
|
if dev.hasV6 {
|
||||||
|
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
|
||||||
|
}
|
||||||
|
if promiscuousMode {
|
||||||
|
// enable promiscuous mode to handle all packets processed by netstack
|
||||||
|
dev.stack.SetPromiscuousMode(1, true)
|
||||||
|
dev.stack.SetSpoofing(1, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
opt := tcpip.CongestionControlOption("cubic")
|
||||||
|
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
|
||||||
|
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dev.events <- tun.EventUp
|
||||||
|
return dev, (*Net)(dev), dev.stack, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize implements tun.Device
|
||||||
|
func (tun *netTun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name implements tun.Device
|
||||||
|
func (tun *netTun) Name() (string, error) {
|
||||||
|
return "go", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// File implements tun.Device
|
||||||
|
func (tun *netTun) File() *os.File {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Events implements tun.Device
|
||||||
|
func (tun *netTun) Events() <-chan tun.Event {
|
||||||
|
return tun.events
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements tun.Device
|
||||||
|
|
||||||
|
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
||||||
|
view, ok := <-tun.incomingPacket
|
||||||
|
if !ok {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := view.Read(buf[0][offset:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sizes[0] = n
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements tun.Device
|
||||||
|
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
||||||
|
for _, buf := range buf {
|
||||||
|
packet := buf[offset:]
|
||||||
|
if len(packet) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
||||||
|
switch packet[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||||
|
case 6:
|
||||||
|
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||||
|
default:
|
||||||
|
return 0, syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteNotify implements channel.Notification
|
||||||
|
func (tun *netTun) WriteNotify() {
|
||||||
|
pkt := tun.ep.Read()
|
||||||
|
if pkt.IsNil() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
view := pkt.ToView()
|
||||||
|
pkt.DecRef()
|
||||||
|
|
||||||
|
tun.incomingPacket <- view
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush implements tun.Device
|
||||||
|
func (tun *netTun) Flush() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements tun.Device
|
||||||
|
func (tun *netTun) Close() error {
|
||||||
|
tun.stack.RemoveNIC(1)
|
||||||
|
|
||||||
|
if tun.events != nil {
|
||||||
|
close(tun.events)
|
||||||
|
}
|
||||||
|
|
||||||
|
tun.ep.Close()
|
||||||
|
|
||||||
|
if tun.incomingPacket != nil {
|
||||||
|
close(tun.incomingPacket)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MTU implements tun.Device
|
||||||
|
func (tun *netTun) MTU() (int, error) {
|
||||||
|
return tun.mtu, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||||
|
var protoNumber tcpip.NetworkProtocolNumber
|
||||||
|
if endpoint.Addr().Is4() {
|
||||||
|
protoNumber = ipv4.ProtocolNumber
|
||||||
|
} else {
|
||||||
|
protoNumber = ipv6.ProtocolNumber
|
||||||
|
}
|
||||||
|
return tcpip.FullAddress{
|
||||||
|
NIC: 1,
|
||||||
|
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
|
||||||
|
Port: endpoint.Port(),
|
||||||
|
}, protoNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||||
|
fa, pn := convertToFullAddr(addr)
|
||||||
|
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
|
||||||
|
var lfa, rfa *tcpip.FullAddress
|
||||||
|
var pn tcpip.NetworkProtocolNumber
|
||||||
|
if laddr.IsValid() || laddr.Port() > 0 {
|
||||||
|
var addr tcpip.FullAddress
|
||||||
|
addr, pn = convertToFullAddr(laddr)
|
||||||
|
lfa = &addr
|
||||||
|
}
|
||||||
|
if raddr.IsValid() || raddr.Port() > 0 {
|
||||||
|
var addr tcpip.FullAddress
|
||||||
|
addr, pn = convertToFullAddr(raddr)
|
||||||
|
rfa = &addr
|
||||||
|
}
|
||||||
|
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||||
|
}
|
|
@ -0,0 +1,181 @@
|
||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/xtls/xray-core/common"
|
||||||
|
"github.com/xtls/xray-core/common/buf"
|
||||||
|
"github.com/xtls/xray-core/common/log"
|
||||||
|
"github.com/xtls/xray-core/common/net"
|
||||||
|
"github.com/xtls/xray-core/common/session"
|
||||||
|
"github.com/xtls/xray-core/common/signal"
|
||||||
|
"github.com/xtls/xray-core/common/task"
|
||||||
|
"github.com/xtls/xray-core/core"
|
||||||
|
"github.com/xtls/xray-core/features/dns"
|
||||||
|
"github.com/xtls/xray-core/features/policy"
|
||||||
|
"github.com/xtls/xray-core/features/routing"
|
||||||
|
"github.com/xtls/xray-core/transport/internet/stat"
|
||||||
|
)
|
||||||
|
|
||||||
|
var nullDestination = net.TCPDestination(net.AnyIP, 0)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
bindServer *netBindServer
|
||||||
|
|
||||||
|
info routingInfo
|
||||||
|
policyManager policy.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
type routingInfo struct {
|
||||||
|
ctx context.Context
|
||||||
|
dispatcher routing.Dispatcher
|
||||||
|
inboundTag *session.Inbound
|
||||||
|
outboundTag *session.Outbound
|
||||||
|
contentTag *session.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||||
|
v := core.MustFromContext(ctx)
|
||||||
|
|
||||||
|
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &Server{
|
||||||
|
bindServer: &netBindServer{
|
||||||
|
netBind: netBind{
|
||||||
|
dns: v.GetFeature(dns.ClientType()).(dns.Client),
|
||||||
|
dnsOption: dns.IPOption{
|
||||||
|
IPv4Enable: hasIPv4,
|
||||||
|
IPv6Enable: hasIPv6,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||||
|
}
|
||||||
|
|
||||||
|
tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
|
||||||
|
_ = tun.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Network implements proxy.Inbound.
|
||||||
|
func (*Server) Network() []net.Network {
|
||||||
|
return []net.Network{net.Network_UDP}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process implements proxy.Inbound.
|
||||||
|
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
|
||||||
|
s.info = routingInfo{
|
||||||
|
ctx: core.ToBackgroundDetachedContext(ctx),
|
||||||
|
dispatcher: dispatcher,
|
||||||
|
inboundTag: session.InboundFromContext(ctx),
|
||||||
|
outboundTag: session.OutboundFromContext(ctx),
|
||||||
|
contentTag: session.ContentFromContext(ctx),
|
||||||
|
}
|
||||||
|
|
||||||
|
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nep := ep.(*netEndpoint)
|
||||||
|
nep.conn = conn
|
||||||
|
|
||||||
|
reader := buf.NewPacketReader(conn)
|
||||||
|
for {
|
||||||
|
mpayload, err := reader.ReadMultiBuffer()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, payload := range mpayload {
|
||||||
|
v, ok := <-s.bindServer.readQueue
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
i, err := payload.Read(v.buff)
|
||||||
|
|
||||||
|
v.bytes = i
|
||||||
|
v.endpoint = nep
|
||||||
|
v.err = err
|
||||||
|
v.waiter.Done()
|
||||||
|
if err != nil && errors.Is(err, io.EOF) {
|
||||||
|
nep.conn = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
||||||
|
if s.info.dispatcher == nil {
|
||||||
|
newError("unexpected: dispatcher == nil").AtError().WriteToLog()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
|
||||||
|
plcy := s.policyManager.ForLevel(0)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
|
||||||
|
|
||||||
|
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||||
|
From: nullDestination,
|
||||||
|
To: dest,
|
||||||
|
Status: log.AccessAccepted,
|
||||||
|
Reason: "",
|
||||||
|
})
|
||||||
|
|
||||||
|
if s.info.inboundTag != nil {
|
||||||
|
ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
|
||||||
|
}
|
||||||
|
if s.info.outboundTag != nil {
|
||||||
|
ctx = session.ContextWithOutbound(ctx, s.info.outboundTag)
|
||||||
|
}
|
||||||
|
if s.info.contentTag != nil {
|
||||||
|
ctx = session.ContextWithContent(ctx, s.info.contentTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := s.info.dispatcher.Dispatch(ctx, dest)
|
||||||
|
if err != nil {
|
||||||
|
newError("dispatch connection").Base(err).AtError().WriteToLog()
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
requestDone := func() error {
|
||||||
|
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
||||||
|
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||||||
|
return newError("failed to transport all TCP request").Base(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
responseDone := func() error {
|
||||||
|
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
|
||||||
|
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
|
||||||
|
return newError("failed to transport all TCP response").Base(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||||
|
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||||
|
common.Interrupt(link.Reader)
|
||||||
|
common.Interrupt(link.Writer)
|
||||||
|
newError("connection ends").Base(err).AtDebug().WriteToLog()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,14 +10,26 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common/log"
|
"github.com/xtls/xray-core/common/log"
|
||||||
|
xnet "github.com/xtls/xray-core/common/net"
|
||||||
|
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
|
||||||
|
|
||||||
|
type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
|
||||||
|
|
||||||
type Tunnel interface {
|
type Tunnel interface {
|
||||||
BuildDevice(ipc string, bind conn.Bind) error
|
BuildDevice(ipc string, bind conn.Bind) error
|
||||||
DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
|
DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
|
||||||
|
@ -103,3 +115,91 @@ func CalculateInterfaceName(name string) (tunName string) {
|
||||||
tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
|
tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Tunnel = (*gvisorNet)(nil)
|
||||||
|
|
||||||
|
type gvisorNet struct {
|
||||||
|
tunnel
|
||||||
|
net *gvisortun.Net
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gvisorNet) Close() error {
|
||||||
|
return g.tunnel.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
|
||||||
|
net.Conn, error,
|
||||||
|
) {
|
||||||
|
return g.net.DialContextTCPAddrPort(ctx, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
||||||
|
return g.net.DialUDPAddrPort(laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
|
||||||
|
out := &gvisorNet{}
|
||||||
|
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler != nil {
|
||||||
|
// handler is only used for promiscuous mode
|
||||||
|
// capture all packets and send to handler
|
||||||
|
|
||||||
|
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
||||||
|
go func(r *tcp.ForwarderRequest) {
|
||||||
|
var (
|
||||||
|
wq waiter.Queue
|
||||||
|
id = r.ID()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Perform a TCP three-way handshake.
|
||||||
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
|
if err != nil {
|
||||||
|
newError(err.String()).AtError().WriteToLog()
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Complete(false)
|
||||||
|
defer ep.Close()
|
||||||
|
|
||||||
|
// enable tcp keep-alive to prevent hanging connections
|
||||||
|
ep.SocketOptions().SetKeepAlive(true)
|
||||||
|
|
||||||
|
// local address is actually destination
|
||||||
|
handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
|
||||||
|
}(r)
|
||||||
|
})
|
||||||
|
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
|
||||||
|
go func(r *udp.ForwarderRequest) {
|
||||||
|
var (
|
||||||
|
wq waiter.Queue
|
||||||
|
id = r.ID()
|
||||||
|
)
|
||||||
|
|
||||||
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
|
if err != nil {
|
||||||
|
newError(err.String()).AtError().WriteToLog()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer ep.Close()
|
||||||
|
|
||||||
|
// prevents hanging connections and ensure timely release
|
||||||
|
ep.SocketOptions().SetLinger(tcpip.LingerOption{
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 15 * time.Second,
|
||||||
|
})
|
||||||
|
|
||||||
|
handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep))
|
||||||
|
}(r)
|
||||||
|
})
|
||||||
|
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
}
|
||||||
|
|
||||||
|
out.tun, out.net = tun, n
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,42 +1,16 @@
|
||||||
//go:build !linux
|
//go:build !linux || android
|
||||||
|
|
||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"errors"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Tunnel = (*gvisorNet)(nil)
|
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
type gvisorNet struct {
|
|
||||||
tunnel
|
|
||||||
net *netstack.Net
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *gvisorNet) Close() error {
|
func KernelTunSupported() bool {
|
||||||
return g.tunnel.Close()
|
return false
|
||||||
}
|
|
||||||
|
|
||||||
func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
|
|
||||||
net.Conn, error,
|
|
||||||
) {
|
|
||||||
return g.net.DialContextTCPAddrPort(ctx, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
|
||||||
return g.net.DialUDPAddrPort(laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateTun(localAddresses []netip.Addr, mtu int) (Tunnel, error) {
|
|
||||||
out := &gvisorNet{}
|
|
||||||
tun, n, err := netstack.CreateNetTUN(localAddresses, nil, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out.tun, out.net = tun, n
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -69,7 +71,11 @@ func (d *deviceNet) Close() (err error) {
|
||||||
return errors.Join(errs...)
|
return errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
|
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
|
||||||
|
if handler != nil {
|
||||||
|
return nil, newError("TODO: support promiscuous mode")
|
||||||
|
}
|
||||||
|
|
||||||
var v4, v6 *netip.Addr
|
var v4, v6 *netip.Addr
|
||||||
for _, prefixes := range localAddresses {
|
for _, prefixes := range localAddresses {
|
||||||
if v4 == nil && prefixes.Is4() {
|
if v4 == nil && prefixes.Is4() {
|
||||||
|
@ -221,3 +227,11 @@ func CreateTun(localAddresses []netip.Addr, mtu int) (t Tunnel, err error) {
|
||||||
out.tun = wgt
|
out.tun = wgt
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func KernelTunSupported() bool {
|
||||||
|
// run a superuser permission check to check
|
||||||
|
// if the current user has the sufficient permission
|
||||||
|
// to create a tun device.
|
||||||
|
|
||||||
|
return unix.Geteuid() == 0 // 0 means root
|
||||||
|
}
|
||||||
|
|
|
@ -1,326 +1,111 @@
|
||||||
/*
|
|
||||||
|
|
||||||
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
|
|
||||||
|
|
||||||
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
|
|
||||||
|
|
||||||
Permission to use, copy, modify, and distribute this software for any
|
|
||||||
purpose with or without fee is hereby granted, provided that the above
|
|
||||||
copyright notice and this permission notice appear in all copies.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
|
||||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
|
||||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
|
||||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
||||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
|
||||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
|
||||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
stdnet "net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common"
|
"github.com/xtls/xray-core/common"
|
||||||
"github.com/xtls/xray-core/common/buf"
|
|
||||||
"github.com/xtls/xray-core/common/dice"
|
|
||||||
"github.com/xtls/xray-core/common/log"
|
"github.com/xtls/xray-core/common/log"
|
||||||
"github.com/xtls/xray-core/common/net"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"github.com/xtls/xray-core/common/protocol"
|
|
||||||
"github.com/xtls/xray-core/common/session"
|
|
||||||
"github.com/xtls/xray-core/common/signal"
|
|
||||||
"github.com/xtls/xray-core/common/task"
|
|
||||||
"github.com/xtls/xray-core/core"
|
|
||||||
"github.com/xtls/xray-core/features/dns"
|
|
||||||
"github.com/xtls/xray-core/features/policy"
|
|
||||||
"github.com/xtls/xray-core/transport"
|
|
||||||
"github.com/xtls/xray-core/transport/internet"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler is an outbound connection that silently swallow the entire payload.
|
//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
|
||||||
type Handler struct {
|
|
||||||
conf *DeviceConfig
|
|
||||||
net Tunnel
|
|
||||||
bind *netBindClient
|
|
||||||
policyManager policy.Manager
|
|
||||||
dns dns.Client
|
|
||||||
// cached configuration
|
|
||||||
ipc string
|
|
||||||
endpoints []netip.Addr
|
|
||||||
hasIPv4, hasIPv6 bool
|
|
||||||
wgLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new wireguard handler.
|
|
||||||
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
|
||||||
v := core.MustFromContext(ctx)
|
|
||||||
|
|
||||||
endpoints, err := parseEndpoints(conf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hasIPv4, hasIPv6 := false, false
|
|
||||||
for _, e := range endpoints {
|
|
||||||
if e.Is4() {
|
|
||||||
hasIPv4 = true
|
|
||||||
}
|
|
||||||
if e.Is6() {
|
|
||||||
hasIPv6 = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
|
||||||
return &Handler{
|
|
||||||
conf: conf,
|
|
||||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
|
||||||
dns: d,
|
|
||||||
ipc: createIPCRequest(conf, d, hasIPv6),
|
|
||||||
endpoints: endpoints,
|
|
||||||
hasIPv4: hasIPv4,
|
|
||||||
hasIPv6: hasIPv6,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
|
|
||||||
h.wgLock.Lock()
|
|
||||||
defer h.wgLock.Unlock()
|
|
||||||
|
|
||||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
|
var wgLogger = &device.Logger{
|
||||||
|
Verbosef: func(format string, args ...any) {
|
||||||
log.Record(&log.GeneralMessage{
|
log.Record(&log.GeneralMessage{
|
||||||
Severity: log.Severity_Info,
|
Severity: log.Severity_Debug,
|
||||||
Content: "switching dialer",
|
Content: fmt.Sprintf(format, args...),
|
||||||
})
|
})
|
||||||
|
},
|
||||||
if h.net != nil {
|
Errorf: func(format string, args ...any) {
|
||||||
_ = h.net.Close()
|
log.Record(&log.GeneralMessage{
|
||||||
h.net = nil
|
Severity: log.Severity_Error,
|
||||||
}
|
Content: fmt.Sprintf(format, args...),
|
||||||
if h.bind != nil {
|
|
||||||
_ = h.bind.Close()
|
|
||||||
h.bind = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
|
||||||
bind := &netBindClient{
|
|
||||||
dialer: dialer,
|
|
||||||
workers: int(h.conf.NumWorkers),
|
|
||||||
dns: h.dns,
|
|
||||||
reserved: h.conf.Reserved,
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
_ = bind.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
h.net, err = h.makeVirtualTun(bind)
|
|
||||||
if err != nil {
|
|
||||||
return newError("failed to create virtual tun interface").Base(err)
|
|
||||||
}
|
|
||||||
h.bind = bind
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process implements OutboundHandler.Dispatch().
|
|
||||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
|
||||||
outbound := session.OutboundFromContext(ctx)
|
|
||||||
if outbound == nil || !outbound.Target.IsValid() {
|
|
||||||
return newError("target not specified")
|
|
||||||
}
|
|
||||||
outbound.Name = "wireguard"
|
|
||||||
inbound := session.InboundFromContext(ctx)
|
|
||||||
if inbound != nil {
|
|
||||||
inbound.SetCanSpliceCopy(3)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.processWireGuard(dialer); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Destination of the inner request.
|
|
||||||
destination := outbound.Target
|
|
||||||
command := protocol.RequestCommandTCP
|
|
||||||
if destination.Network == net.Network_UDP {
|
|
||||||
command = protocol.RequestCommandUDP
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolve dns
|
|
||||||
addr := destination.Address
|
|
||||||
if addr.Family().IsDomain() {
|
|
||||||
ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
|
||||||
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
|
|
||||||
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
|
|
||||||
})
|
})
|
||||||
{ // Resolve fallback
|
},
|
||||||
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
|
|
||||||
ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
|
||||||
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
|
|
||||||
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return newError("failed to lookup DNS").Base(err)
|
|
||||||
} else if len(ips) == 0 {
|
|
||||||
return dns.ErrEmptyResponse
|
|
||||||
}
|
|
||||||
addr = net.IPAddress(ips[dice.Roll(len(ips))])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var newCtx context.Context
|
func init() {
|
||||||
var newCancel context.CancelFunc
|
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||||
if session.TimeoutOnlyFromContext(ctx) {
|
deviceConfig := config.(*DeviceConfig)
|
||||||
newCtx, newCancel = context.WithCancel(context.Background())
|
if deviceConfig.IsClient {
|
||||||
|
return New(ctx, deviceConfig)
|
||||||
|
} else {
|
||||||
|
return NewServer(ctx, deviceConfig)
|
||||||
}
|
}
|
||||||
|
}))
|
||||||
p := h.policyManager.ForLevel(0)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
timer := signal.CancelAfterInactivity(ctx, func() {
|
|
||||||
cancel()
|
|
||||||
if newCancel != nil {
|
|
||||||
newCancel()
|
|
||||||
}
|
|
||||||
}, p.Timeouts.ConnectionIdle)
|
|
||||||
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
|
|
||||||
|
|
||||||
var requestFunc func() error
|
|
||||||
var responseFunc func() error
|
|
||||||
|
|
||||||
if command == protocol.RequestCommandTCP {
|
|
||||||
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
|
|
||||||
if err != nil {
|
|
||||||
return newError("failed to create TCP connection").Base(err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
requestFunc = func() error {
|
|
||||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
|
||||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
|
||||||
}
|
|
||||||
responseFunc = func() error {
|
|
||||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
|
||||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
|
||||||
}
|
|
||||||
} else if command == protocol.RequestCommandUDP {
|
|
||||||
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
|
|
||||||
if err != nil {
|
|
||||||
return newError("failed to create UDP connection").Base(err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
requestFunc = func() error {
|
|
||||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
|
||||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
|
||||||
}
|
|
||||||
responseFunc = func() error {
|
|
||||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
|
||||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if newCtx != nil {
|
|
||||||
ctx = newCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
|
|
||||||
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
|
|
||||||
common.Interrupt(link.Reader)
|
|
||||||
common.Interrupt(link.Writer)
|
|
||||||
return newError("connection ends").Base(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// serialize the config into an IPC request
|
|
||||||
func createIPCRequest(conf *DeviceConfig, d dns.Client, resolveEndPointToV4 bool) string {
|
|
||||||
var request bytes.Buffer
|
|
||||||
|
|
||||||
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
|
|
||||||
|
|
||||||
for _, peer := range conf.Peers {
|
|
||||||
endpoint := peer.Endpoint
|
|
||||||
host, port, err := net.SplitHostPort(endpoint)
|
|
||||||
if resolveEndPointToV4 && err == nil {
|
|
||||||
_, err = netip.ParseAddr(host)
|
|
||||||
if err != nil {
|
|
||||||
ipList, err := d.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: false})
|
|
||||||
if err == nil && len(ipList) > 0 {
|
|
||||||
endpoint = stdnet.JoinHostPort(ipList[0].String(), port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
|
|
||||||
peer.PublicKey, endpoint, peer.KeepAlive, peer.PreSharedKey))
|
|
||||||
|
|
||||||
for _, ip := range peer.AllowedIps {
|
|
||||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return request.String()[:request.Len()]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert endpoint string to netip.Addr
|
// convert endpoint string to netip.Addr
|
||||||
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
|
func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) {
|
||||||
|
var hasIPv4, hasIPv6 bool
|
||||||
|
|
||||||
endpoints := make([]netip.Addr, len(conf.Endpoint))
|
endpoints := make([]netip.Addr, len(conf.Endpoint))
|
||||||
for i, str := range conf.Endpoint {
|
for i, str := range conf.Endpoint {
|
||||||
var addr netip.Addr
|
var addr netip.Addr
|
||||||
if strings.Contains(str, "/") {
|
if strings.Contains(str, "/") {
|
||||||
prefix, err := netip.ParsePrefix(str)
|
prefix, err := netip.ParsePrefix(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, false, err
|
||||||
}
|
}
|
||||||
addr = prefix.Addr()
|
addr = prefix.Addr()
|
||||||
if prefix.Bits() != addr.BitLen() {
|
if prefix.Bits() != addr.BitLen() {
|
||||||
return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
|
return nil, false, false, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
addr, err = netip.ParseAddr(str)
|
addr, err = netip.ParseAddr(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
endpoints[i] = addr
|
endpoints[i] = addr
|
||||||
|
|
||||||
|
if addr.Is4() {
|
||||||
|
hasIPv4 = true
|
||||||
|
} else if addr.Is6() {
|
||||||
|
hasIPv6 = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return endpoints, nil
|
return endpoints, hasIPv4, hasIPv6, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// creates a tun interface on netstack given a configuration
|
// serialize the config into an IPC request
|
||||||
func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
|
func createIPCRequest(conf *DeviceConfig) string {
|
||||||
t, err := CreateTun(h.endpoints, int(h.conf.Mtu))
|
var request strings.Builder
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
|
||||||
|
|
||||||
|
if !conf.IsClient {
|
||||||
|
// placeholder, we'll handle actual port listening on Xray
|
||||||
|
request.WriteString("listen_port=1337\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
bind.dnsOption.IPv4Enable = h.hasIPv4
|
for _, peer := range conf.Peers {
|
||||||
bind.dnsOption.IPv6Enable = h.hasIPv6
|
if peer.PublicKey != "" {
|
||||||
|
request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
|
||||||
if err = t.BuildDevice(h.ipc, bind); err != nil {
|
|
||||||
_ = t.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
if peer.PreSharedKey != "" {
|
||||||
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
|
||||||
return New(ctx, config.(*DeviceConfig))
|
}
|
||||||
}))
|
|
||||||
|
if peer.Endpoint != "" {
|
||||||
|
request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range peer.AllowedIps {
|
||||||
|
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.KeepAlive != 0 {
|
||||||
|
request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return request.String()[:request.Len()]
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue