diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 39adce25..6c534532 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -12,6 +12,8 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" + "github.com/sagernet/sing/service/pause" "github.com/sagernet/wireguard-go/conn" ) @@ -19,6 +21,9 @@ var _ conn.Bind = (*ClientBind)(nil) type ClientBind struct { ctx context.Context + pauseManager pause.Manager + bindCtx context.Context + bindDone context.CancelFunc errorHandler E.Handler dialer N.Dialer reservedForEndpoint map[netip.AddrPort][3]uint8 @@ -33,6 +38,7 @@ type ClientBind struct { func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind { return &ClientBind{ ctx: ctx, + pauseManager: service.FromContext[pause.Manager](ctx), errorHandler: errorHandler, dialer: dialer, reservedForEndpoint: make(map[netip.AddrPort][3]uint8), @@ -55,6 +61,11 @@ func (c *ClientBind) connect() (*wireConn, error) { } c.connAccess.Lock() defer c.connAccess.Unlock() + select { + case <-c.done: + return nil, net.ErrClosed + default: + } serverConn = c.conn if serverConn != nil { select { @@ -65,7 +76,7 @@ func (c *ClientBind) connect() (*wireConn, error) { } } if c.isConnect { - udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr)) + udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr)) if err != nil { return nil, err } @@ -74,7 +85,7 @@ func (c *ClientBind) connect() (*wireConn, error) { done: make(chan struct{}), } } else { - udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) + udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) if err != nil { return nil, err } @@ -92,6 +103,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1 c.done = make(chan struct{}) default: } + c.bindCtx, c.bindDone = context.WithCancel(c.ctx) return []conn.ReceiveFunc{c.receive}, 0, nil } @@ -105,6 +117,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) } c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server")) err = nil + c.pauseManager.WaitActive() time.Sleep(time.Second) return } @@ -130,12 +143,17 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) } func (c *ClientBind) Close() error { - common.Close(common.PtrOrNil(c.conn)) select { case <-c.done: default: close(c.done) } + if c.bindDone != nil { + c.bindDone() + } + c.connAccess.Lock() + defer c.connAccess.Unlock() + common.Close(common.PtrOrNil(c.conn)) return nil } @@ -146,6 +164,8 @@ func (c *ClientBind) SetMark(mark uint32) error { func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error { udpConn, err := c.connect() if err != nil { + c.pauseManager.WaitActive() + time.Sleep(time.Second) return err } destination := netip.AddrPort(ep.(Endpoint))