cancel failed grpc connection (#707)

Co-authored-by: Shelikhoo <xiaokangwang@outlook.com>
This commit is contained in:
yuhan6665 2021-09-14 13:40:38 -04:00 committed by GitHub
parent 7246001029
commit 0f79126379
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -39,6 +39,8 @@ type dialerConf struct {
*internet.MemoryStreamConfig *internet.MemoryStreamConfig
} }
type dialerCanceller func()
var ( var (
globalDialerMap map[dialerConf]*grpc.ClientConn globalDialerMap map[dialerConf]*grpc.ClientConn
globalDialerAccess sync.Mutex globalDialerAccess sync.Mutex
@ -47,8 +49,7 @@ var (
func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) { func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
grpcSettings := streamSettings.ProtocolSettings.(*Config) grpcSettings := streamSettings.ProtocolSettings.(*Config)
conn, err := getGrpcClient(ctx, dest, streamSettings) conn, canceller, err := getGrpcClient(ctx, dest, streamSettings)
if err != nil { if err != nil {
return nil, newError("Cannot dial gRPC").Base(err) return nil, newError("Cannot dial gRPC").Base(err)
} }
@ -57,6 +58,7 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
newError("using gRPC multi mode").AtDebug().WriteToLog() newError("using gRPC multi mode").AtDebug().WriteToLog()
grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getNormalizedName()) grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getNormalizedName())
if err != nil { if err != nil {
canceller()
return nil, newError("Cannot dial gRPC").Base(err) return nil, newError("Cannot dial gRPC").Base(err)
} }
return encoding.NewMultiHunkConn(grpcService, nil), nil return encoding.NewMultiHunkConn(grpcService, nil), nil
@ -64,13 +66,14 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getNormalizedName()) grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getNormalizedName())
if err != nil { if err != nil {
canceller()
return nil, newError("Cannot dial gRPC").Base(err) return nil, newError("Cannot dial gRPC").Base(err)
} }
return encoding.NewHunkConn(grpcService, nil), nil return encoding.NewHunkConn(grpcService, nil), nil
} }
func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, error) { func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, dialerCanceller, error) {
globalDialerAccess.Lock() globalDialerAccess.Lock()
defer globalDialerAccess.Unlock() defer globalDialerAccess.Unlock()
@ -81,8 +84,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
sockopt := streamSettings.SocketSettings sockopt := streamSettings.SocketSettings
grpcSettings := streamSettings.ProtocolSettings.(*Config) grpcSettings := streamSettings.ProtocolSettings.(*Config)
canceller := func() {
globalDialerAccess.Lock()
defer globalDialerAccess.Unlock()
delete(globalDialerMap, dialerConf{dest, streamSettings})
}
if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found && client.GetState() != connectivity.Shutdown { if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found && client.GetState() != connectivity.Shutdown {
return client, nil return client, canceller, nil
} }
var dialOptions = []grpc.DialOption{ var dialOptions = []grpc.DialOption{
@ -147,5 +156,5 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in
dialOptions..., dialOptions...,
) )
globalDialerMap[dialerConf{dest, streamSettings}] = conn globalDialerMap[dialerConf{dest, streamSettings}] = conn
return conn, err return conn, canceller, err
} }