Fix wireguard client bind

This commit is contained in:
世界 2024-06-03 16:59:13 +08:00
parent 53927d8bbd
commit e08c052fc9
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -12,6 +12,8 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn" "github.com/sagernet/wireguard-go/conn"
) )
@ -19,6 +21,9 @@ var _ conn.Bind = (*ClientBind)(nil)
type ClientBind struct { type ClientBind struct {
ctx context.Context ctx context.Context
pauseManager pause.Manager
bindCtx context.Context
bindDone context.CancelFunc
errorHandler E.Handler errorHandler E.Handler
dialer N.Dialer dialer N.Dialer
reservedForEndpoint map[netip.AddrPort][3]uint8 reservedForEndpoint map[netip.AddrPort][3]uint8
@ -33,6 +38,7 @@ type ClientBind struct {
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind { func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
return &ClientBind{ return &ClientBind{
ctx: ctx, ctx: ctx,
pauseManager: service.FromContext[pause.Manager](ctx),
errorHandler: errorHandler, errorHandler: errorHandler,
dialer: dialer, dialer: dialer,
reservedForEndpoint: make(map[netip.AddrPort][3]uint8), reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
@ -55,6 +61,11 @@ func (c *ClientBind) connect() (*wireConn, error) {
} }
c.connAccess.Lock() c.connAccess.Lock()
defer c.connAccess.Unlock() defer c.connAccess.Unlock()
select {
case <-c.done:
return nil, net.ErrClosed
default:
}
serverConn = c.conn serverConn = c.conn
if serverConn != nil { if serverConn != nil {
select { select {
@ -65,7 +76,7 @@ func (c *ClientBind) connect() (*wireConn, error) {
} }
} }
if c.isConnect { if c.isConnect {
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr)) udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -74,7 +85,7 @@ func (c *ClientBind) connect() (*wireConn, error) {
done: make(chan struct{}), done: make(chan struct{}),
} }
} else { } else {
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -92,6 +103,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
c.done = make(chan struct{}) c.done = make(chan struct{})
default: default:
} }
c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
return []conn.ReceiveFunc{c.receive}, 0, nil return []conn.ReceiveFunc{c.receive}, 0, nil
} }
@ -105,6 +117,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
} }
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server")) c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
err = nil err = nil
c.pauseManager.WaitActive()
time.Sleep(time.Second) time.Sleep(time.Second)
return return
} }
@ -130,12 +143,17 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
} }
func (c *ClientBind) Close() error { func (c *ClientBind) Close() error {
common.Close(common.PtrOrNil(c.conn))
select { select {
case <-c.done: case <-c.done:
default: default:
close(c.done) close(c.done)
} }
if c.bindDone != nil {
c.bindDone()
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
common.Close(common.PtrOrNil(c.conn))
return nil return nil
} }
@ -146,6 +164,8 @@ func (c *ClientBind) SetMark(mark uint32) error {
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error { func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
udpConn, err := c.connect() udpConn, err := c.connect()
if err != nil { if err != nil {
c.pauseManager.WaitActive()
time.Sleep(time.Second)
return err return err
} }
destination := netip.AddrPort(ep.(Endpoint)) destination := netip.AddrPort(ep.(Endpoint))