From 526c6789ed5f56aaf8bad665f12692b8aae1565a Mon Sep 17 00:00:00 2001 From: Hirbod Behnam Date: Sun, 26 Mar 2023 09:28:19 +0330 Subject: [PATCH] Add custom path to gRPC (#1815) --- transport/internet/grpc/config.go | 40 ++++++- transport/internet/grpc/config_test.go | 111 ++++++++++++++++++ transport/internet/grpc/dial.go | 7 +- .../grpc/encoding/customSeviceName.go | 22 ++-- transport/internet/grpc/hub.go | 3 +- 5 files changed, 166 insertions(+), 17 deletions(-) create mode 100644 transport/internet/grpc/config_test.go diff --git a/transport/internet/grpc/config.go b/transport/internet/grpc/config.go index d87722a4..39eadf31 100644 --- a/transport/internet/grpc/config.go +++ b/transport/internet/grpc/config.go @@ -2,6 +2,7 @@ package grpc import ( "net/url" + "strings" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/transport/internet" @@ -15,6 +16,41 @@ func init() { })) } -func (c *Config) getNormalizedName() string { - return url.PathEscape(c.ServiceName) +func (c *Config) getServiceName() string { + // Normal old school config + if !strings.HasPrefix(c.ServiceName, "/") { + return url.PathEscape(c.ServiceName) + } + // Otherwise new custom paths + rawServiceName := c.ServiceName[1:strings.LastIndex(c.ServiceName, "/")] // trim from first to last '/' + serviceNameParts := strings.Split(rawServiceName, "/") + for i := range serviceNameParts { + serviceNameParts[i] = url.PathEscape(serviceNameParts[i]) + } + return strings.Join(serviceNameParts, "/") +} + +func (c *Config) getTunStreamName() string { + // Normal old school config + if !strings.HasPrefix(c.ServiceName, "/") { + return "Tun" + } + // Otherwise new custom paths + endingPath := c.ServiceName[strings.LastIndex(c.ServiceName, "/")+1:] // from the last '/' to end of string + return url.PathEscape(strings.Split(endingPath, "|")[0]) +} + +func (c *Config) getTunMultiStreamName() string { + // Normal old school config + if !strings.HasPrefix(c.ServiceName, "/") { + return "TunMulti" + } + // Otherwise new custom paths + endingPath := c.ServiceName[strings.LastIndex(c.ServiceName, "/")+1:] // from the last '/' to end of string + streamNames := strings.Split(endingPath, "|") + if len(streamNames) == 1 { // client side. Service name is the full path to multi tun + return url.PathEscape(streamNames[0]) + } else { // server side. The second part is the path to multi tun + return url.PathEscape(streamNames[1]) + } } diff --git a/transport/internet/grpc/config_test.go b/transport/internet/grpc/config_test.go new file mode 100644 index 00000000..fbc549b4 --- /dev/null +++ b/transport/internet/grpc/config_test.go @@ -0,0 +1,111 @@ +package grpc + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestConfig_GetServiceName(t *testing.T) { + tests := []struct { + TestName string + ServiceName string + Expected string + }{ + { + TestName: "simple no absolute path", + ServiceName: "hello", + Expected: "hello", + }, + { + TestName: "escape no absolute path", + ServiceName: "hello/world!", + Expected: "hello%2Fworld%21", + }, + { + TestName: "absolute path", + ServiceName: "/my/sample/path/a|b", + Expected: "my/sample/path", + }, + { + TestName: "escape absolute path", + ServiceName: "/hello /world!/a|b", + Expected: "hello%20/world%21", + }, + } + for _, test := range tests { + t.Run(test.TestName, func(t *testing.T) { + config := Config{ServiceName: test.ServiceName} + assert.Equal(t, test.Expected, config.getServiceName()) + }) + } +} + +func TestConfig_GetTunStreamName(t *testing.T) { + tests := []struct { + TestName string + ServiceName string + Expected string + }{ + { + TestName: "no absolute path", + ServiceName: "hello", + Expected: "Tun", + }, + { + TestName: "absolute path server", + ServiceName: "/my/sample/path/tun_service|multi_service", + Expected: "tun_service", + }, + { + TestName: "absolute path client", + ServiceName: "/my/sample/path/tun_service", + Expected: "tun_service", + }, + { + TestName: "escape absolute path client", + ServiceName: "/m y/sa !mple/pa\\th/tun\\_serv!ice", + Expected: "tun%5C_serv%21ice", + }, + } + for _, test := range tests { + t.Run(test.TestName, func(t *testing.T) { + config := Config{ServiceName: test.ServiceName} + assert.Equal(t, test.Expected, config.getTunStreamName()) + }) + } +} + +func TestConfig_GetTunMultiStreamName(t *testing.T) { + tests := []struct { + TestName string + ServiceName string + Expected string + }{ + { + TestName: "no absolute path", + ServiceName: "hello", + Expected: "TunMulti", + }, + { + TestName: "absolute path server", + ServiceName: "/my/sample/path/tun_service|multi_service", + Expected: "multi_service", + }, + { + TestName: "absolute path client", + ServiceName: "/my/sample/path/multi_service", + Expected: "multi_service", + }, + { + TestName: "escape absolute path client", + ServiceName: "/m y/sa !mple/pa\\th/mu%lti\\_serv!ice", + Expected: "mu%25lti%5C_serv%21ice", + }, + } + for _, test := range tests { + t.Run(test.TestName, func(t *testing.T) { + config := Config{ServiceName: test.ServiceName} + assert.Equal(t, test.Expected, config.getTunMultiStreamName()) + }) + } +} diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 8c58ff14..8fd544b5 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -54,15 +54,16 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne } client := encoding.NewGRPCServiceClient(conn) if grpcSettings.MultiMode { - newError("using gRPC multi mode").AtDebug().WriteToLog() - grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getNormalizedName()) + newError("using gRPC multi mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunMultiStreamName() + "`").AtDebug().WriteToLog() + grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunMultiStreamName()) if err != nil { return nil, newError("Cannot dial gRPC").Base(err) } return encoding.NewMultiHunkConn(grpcService, nil), nil } - grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getNormalizedName()) + newError("using gRPC tun mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunStreamName() + "`").AtDebug().WriteToLog() + grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunStreamName()) if err != nil { return nil, newError("Cannot dial gRPC").Base(err) } diff --git a/transport/internet/grpc/encoding/customSeviceName.go b/transport/internet/grpc/encoding/customSeviceName.go index aa098835..dd990755 100644 --- a/transport/internet/grpc/encoding/customSeviceName.go +++ b/transport/internet/grpc/encoding/customSeviceName.go @@ -6,20 +6,20 @@ import ( "google.golang.org/grpc" ) -func ServerDesc(name string) grpc.ServiceDesc { +func ServerDesc(name, tun, tunMulti string) grpc.ServiceDesc { return grpc.ServiceDesc{ ServiceName: name, HandlerType: (*GRPCServiceServer)(nil), Methods: []grpc.MethodDesc{}, Streams: []grpc.StreamDesc{ { - StreamName: "Tun", + StreamName: tun, Handler: _GRPCService_Tun_Handler, ServerStreams: true, ClientStreams: true, }, { - StreamName: "TunMulti", + StreamName: tunMulti, Handler: _GRPCService_TunMulti_Handler, ServerStreams: true, ClientStreams: true, @@ -29,8 +29,8 @@ func ServerDesc(name string) grpc.ServiceDesc { } } -func (c *gRPCServiceClient) TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunClient, error) { - stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[0], "/"+name+"/Tun", opts...) +func (c *gRPCServiceClient) TunCustomName(ctx context.Context, name, tun string, opts ...grpc.CallOption) (GRPCService_TunClient, error) { + stream, err := c.cc.NewStream(ctx, &ServerDesc(name, tun, "").Streams[0], "/"+name+"/"+tun, opts...) if err != nil { return nil, err } @@ -38,8 +38,8 @@ func (c *gRPCServiceClient) TunCustomName(ctx context.Context, name string, opts return x, nil } -func (c *gRPCServiceClient) TunMultiCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error) { - stream, err := c.cc.NewStream(ctx, &ServerDesc(name).Streams[1], "/"+name+"/TunMulti", opts...) +func (c *gRPCServiceClient) TunMultiCustomName(ctx context.Context, name, tunMulti string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error) { + stream, err := c.cc.NewStream(ctx, &ServerDesc(name, "", tunMulti).Streams[1], "/"+name+"/"+tunMulti, opts...) if err != nil { return nil, err } @@ -48,13 +48,13 @@ func (c *gRPCServiceClient) TunMultiCustomName(ctx context.Context, name string, } type GRPCServiceClientX interface { - TunCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunClient, error) - TunMultiCustomName(ctx context.Context, name string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error) + TunCustomName(ctx context.Context, name, tun string, opts ...grpc.CallOption) (GRPCService_TunClient, error) + TunMultiCustomName(ctx context.Context, name, tunMulti string, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error) Tun(ctx context.Context, opts ...grpc.CallOption) (GRPCService_TunClient, error) TunMulti(ctx context.Context, opts ...grpc.CallOption) (GRPCService_TunMultiClient, error) } -func RegisterGRPCServiceServerX(s *grpc.Server, srv GRPCServiceServer, name string) { - desc := ServerDesc(name) +func RegisterGRPCServiceServerX(s *grpc.Server, srv GRPCServiceServer, name, tun, tunMulti string) { + desc := ServerDesc(name, tun, tunMulti) s.RegisterService(&desc, srv) } diff --git a/transport/internet/grpc/hub.go b/transport/internet/grpc/hub.go index 9bce2274..d3dd6da5 100644 --- a/transport/internet/grpc/hub.go +++ b/transport/internet/grpc/hub.go @@ -125,7 +125,8 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i } } - encoding.RegisterGRPCServiceServerX(s, listener, grpcSettings.getNormalizedName()) + newError("gRPC listen for service name `" + grpcSettings.getServiceName() + "` tun `" + grpcSettings.getTunStreamName() + "` multi tun `" + grpcSettings.getTunMultiStreamName() + "`").AtDebug().WriteToLog() + encoding.RegisterGRPCServiceServerX(s, listener, grpcSettings.getServiceName(), grpcSettings.getTunStreamName(), grpcSettings.getTunMultiStreamName()) if config := reality.ConfigFromStreamSettings(settings); config != nil { streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig())