diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index e1e21b07..4d69816b 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -39,6 +39,8 @@ type dialerConf struct { *internet.MemoryStreamConfig } +type dialerCanceller func() + var ( globalDialerMap map[dialerConf]*grpc.ClientConn globalDialerAccess sync.Mutex @@ -47,8 +49,7 @@ var ( func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) { grpcSettings := streamSettings.ProtocolSettings.(*Config) - conn, err := getGrpcClient(ctx, dest, streamSettings) - + conn, canceller, err := getGrpcClient(ctx, dest, streamSettings) if err != nil { return nil, newError("Cannot dial gRPC").Base(err) } @@ -57,6 +58,7 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne newError("using gRPC multi mode").AtDebug().WriteToLog() grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getNormalizedName()) if err != nil { + canceller() return nil, newError("Cannot dial gRPC").Base(err) } return encoding.NewMultiHunkConn(grpcService, nil), nil @@ -64,13 +66,14 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getNormalizedName()) if err != nil { + canceller() return nil, newError("Cannot dial gRPC").Base(err) } return encoding.NewHunkConn(grpcService, nil), nil } -func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, error) { +func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, dialerCanceller, error) { globalDialerAccess.Lock() defer globalDialerAccess.Unlock() @@ -81,8 +84,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in sockopt := streamSettings.SocketSettings grpcSettings := streamSettings.ProtocolSettings.(*Config) + canceller := func() { + globalDialerAccess.Lock() + defer globalDialerAccess.Unlock() + delete(globalDialerMap, dialerConf{dest, streamSettings}) + } + if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found && client.GetState() != connectivity.Shutdown { - return client, nil + return client, canceller, nil } var dialOptions = []grpc.DialOption{ @@ -147,5 +156,5 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in dialOptions..., ) globalDialerMap[dialerConf{dest, streamSettings}] = conn - return conn, err + return conn, canceller, err }