From bb63429079e8cda5654b25f04a2ba5ee173955e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 10 Apr 2023 13:00:57 +0800 Subject: [PATCH] Update cancel context usage --- ntp/service.go | 8 +++++--- outbound/dns.go | 10 ++++++++-- transport/v2raygrpc/client.go | 4 ++-- transport/v2raygrpc/conn.go | 10 ++++++---- transport/v2raygrpc/server.go | 3 ++- 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ntp/service.go b/ntp/service.go index b7ef999b..a8dc6ea2 100644 --- a/ntp/service.go +++ b/ntp/service.go @@ -2,6 +2,7 @@ package ntp import ( "context" + "os" "time" "github.com/sagernet/sing-box/adapter" @@ -9,6 +10,7 @@ import ( "github.com/sagernet/sing-box/common/settings" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -20,7 +22,7 @@ var _ adapter.TimeService = (*Service)(nil) type Service struct { ctx context.Context - cancel context.CancelFunc + cancel common.ContextCancelCauseFunc server M.Socksaddr writeToSystem bool dialer N.Dialer @@ -30,7 +32,7 @@ type Service struct { } func NewService(ctx context.Context, router adapter.Router, logger logger.Logger, options option.NTPOptions) *Service { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := common.ContextWithCancelCause(ctx) server := options.ServerOptions.Build() if server.Port == 0 { server.Port = 123 @@ -64,7 +66,7 @@ func (s *Service) Start() error { func (s *Service) Close() error { s.ticker.Stop() - s.cancel() + s.cancel(os.ErrClosed) return nil } diff --git a/outbound/dns.go b/outbound/dns.go index 075e2849..5af64173 100644 --- a/outbound/dns.go +++ b/outbound/dns.go @@ -102,11 +102,10 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { ctx = adapter.WithContext(ctx, &metadata) - fastClose, cancel := context.WithCancel(ctx) + fastClose, cancel := common.ContextWithCancelCause(ctx) timeout := canceler.New(fastClose, cancel, C.DNSTimeout) var group task.Group group.Append0(func(ctx context.Context) error { - defer cancel() _buffer := buf.StackNewSize(dns.FixedPacketSize) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) @@ -115,11 +114,13 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada buffer.FullReset() destination, err := conn.ReadPacket(buffer) if err != nil { + cancel(err) return err } var message mDNS.Msg err = message.Unpack(buffer.Bytes()) if err != nil { + cancel(err) return err } timeout.Update() @@ -127,17 +128,22 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada go func() error { response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) if err != nil { + cancel(err) return err } timeout.Update() responseBuffer := buf.NewPacket() n, err := response.PackBuffer(responseBuffer.FreeBytes()) if err != nil { + cancel(err) responseBuffer.Release() return err } responseBuffer.Truncate(len(n)) err = conn.WritePacket(responseBuffer, destination) + if err != nil { + cancel(err) + } return err }() } diff --git a/transport/v2raygrpc/client.go b/transport/v2raygrpc/client.go index 17b3b5cd..d4f24987 100644 --- a/transport/v2raygrpc/client.go +++ b/transport/v2raygrpc/client.go @@ -101,10 +101,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { return nil, err } client := NewGunServiceClient(clientConn).(GunServiceCustomNameClient) - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := common.ContextWithCancelCause(ctx) stream, err := client.TunCustomName(ctx, c.serviceName) if err != nil { - cancel() + cancel(err) return nil, err } return NewGRPCConn(stream, cancel), nil diff --git a/transport/v2raygrpc/conn.go b/transport/v2raygrpc/conn.go index 1ef3f391..ec31e298 100644 --- a/transport/v2raygrpc/conn.go +++ b/transport/v2raygrpc/conn.go @@ -1,12 +1,12 @@ package v2raygrpc import ( - "context" "net" "os" "time" "github.com/sagernet/sing-box/common/baderror" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/rw" ) @@ -14,11 +14,11 @@ var _ net.Conn = (*GRPCConn)(nil) type GRPCConn struct { GunService - cancel context.CancelFunc + cancel common.ContextCancelCauseFunc cache []byte } -func NewGRPCConn(service GunService, cancel context.CancelFunc) *GRPCConn { +func NewGRPCConn(service GunService, cancel common.ContextCancelCauseFunc) *GRPCConn { if client, isClient := service.(GunService_TunClient); isClient { service = &clientConnWrapper{client} } @@ -37,6 +37,7 @@ func (c *GRPCConn) Read(b []byte) (n int, err error) { hunk, err := c.Recv() err = baderror.WrapGRPC(err) if err != nil { + c.cancel(err) return } n = copy(b, hunk.Data) @@ -49,13 +50,14 @@ func (c *GRPCConn) Read(b []byte) (n int, err error) { func (c *GRPCConn) Write(b []byte) (n int, err error) { err = baderror.WrapGRPC(c.Send(&Hunk{Data: b})) if err != nil { + c.cancel(err) return } return len(b), nil } func (c *GRPCConn) Close() error { - c.cancel() + c.cancel(net.ErrClosed) return nil } diff --git a/transport/v2raygrpc/server.go b/transport/v2raygrpc/server.go index 1b6a34c1..883357d7 100644 --- a/transport/v2raygrpc/server.go +++ b/transport/v2raygrpc/server.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -45,7 +46,7 @@ func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig t } func (s *Server) Tun(server GunService_TunServer) error { - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := common.ContextWithCancelCause(s.ctx) conn := NewGRPCConn(server, cancel) var metadata M.Metadata if remotePeer, loaded := peer.FromContext(server.Context()); loaded {