From 3ed14c2fcd0dd52d1e67953d5eecb7c2e45c07f9 Mon Sep 17 00:00:00 2001 From: Jim Han <50871214+JimhHan@users.noreply.github.com> Date: Wed, 31 Mar 2021 00:43:31 +0800 Subject: [PATCH] Fix: gRPC & HTTP/2 dialer (#445) --- transport/internet/grpc/dial.go | 22 ++++++++++++---------- transport/internet/http/dialer.go | 5 +++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 457f67df..ed7ad3b5 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -36,6 +36,7 @@ func init() { type dialerConf struct { net.Destination *internet.SocketConfig + *tls.Config } var ( @@ -46,14 +47,9 @@ var ( func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) { grpcSettings := streamSettings.ProtocolSettings.(*Config) - config := tls.ConfigFromStreamSettings(streamSettings) - var dialOption = grpc.WithInsecure() + tlsConfig := tls.ConfigFromStreamSettings(streamSettings) - if config != nil { - dialOption = grpc.WithTransportCredentials(credentials.NewTLS(config.GetTLSConfig())) - } - - conn, err := getGrpcClient(ctx, dest, dialOption, streamSettings.SocketSettings) + conn, err := getGrpcClient(ctx, dest, tlsConfig, streamSettings.SocketSettings) if err != nil { return nil, newError("Cannot dial gRPC").Base(err) @@ -76,7 +72,7 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne return encoding.NewHunkConn(grpcService, nil), nil } -func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) { +func getGrpcClient(ctx context.Context, dest net.Destination, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) { globalDialerAccess.Lock() defer globalDialerAccess.Unlock() @@ -84,10 +80,16 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di globalDialerMap = make(map[dialerConf]*grpc.ClientConn) } - if client, found := globalDialerMap[dialerConf{dest, sockopt}]; found && client.GetState() != connectivity.Shutdown { + if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsConfig}]; found && client.GetState() != connectivity.Shutdown { return client, nil } + dialOption := grpc.WithInsecure() + + if tlsConfig != nil { + dialOption = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig.GetTLSConfig())) + } + conn, err := grpc.Dial( gonet.JoinHostPort(dest.Address.String(), dest.Port.String()), dialOption, @@ -125,6 +127,6 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di return internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt) }), ) - globalDialerMap[dialerConf{dest, sockopt}] = conn + globalDialerMap[dialerConf{dest, sockopt, tlsConfig}] = conn return conn, err } diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index dc2cd8ab..ae3ba6d2 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -21,6 +21,7 @@ import ( type dialerConf struct { net.Destination *internet.SocketConfig + *tls.Config } var ( @@ -36,7 +37,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C globalDialerMap = make(map[dialerConf]*http.Client) } - if client, found := globalDialerMap[dialerConf{dest, sockopt}]; found { + if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsSettings}]; found { return client, nil } @@ -92,7 +93,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.C Transport: transport, } - globalDialerMap[dialerConf{dest, sockopt}] = client + globalDialerMap[dialerConf{dest, sockopt, tlsSettings}] = client return client, nil }