mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-09 18:43:14 +00:00
Fix wireguard client bind
This commit is contained in:
parent
53927d8bbd
commit
e08c052fc9
|
@ -12,6 +12,8 @@ import (
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
"github.com/sagernet/sing/service/pause"
|
||||||
"github.com/sagernet/wireguard-go/conn"
|
"github.com/sagernet/wireguard-go/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,6 +21,9 @@ var _ conn.Bind = (*ClientBind)(nil)
|
||||||
|
|
||||||
type ClientBind struct {
|
type ClientBind struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
pauseManager pause.Manager
|
||||||
|
bindCtx context.Context
|
||||||
|
bindDone context.CancelFunc
|
||||||
errorHandler E.Handler
|
errorHandler E.Handler
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
reservedForEndpoint map[netip.AddrPort][3]uint8
|
reservedForEndpoint map[netip.AddrPort][3]uint8
|
||||||
|
@ -33,6 +38,7 @@ type ClientBind struct {
|
||||||
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
|
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
|
||||||
return &ClientBind{
|
return &ClientBind{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
pauseManager: service.FromContext[pause.Manager](ctx),
|
||||||
errorHandler: errorHandler,
|
errorHandler: errorHandler,
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
|
reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
|
||||||
|
@ -55,6 +61,11 @@ func (c *ClientBind) connect() (*wireConn, error) {
|
||||||
}
|
}
|
||||||
c.connAccess.Lock()
|
c.connAccess.Lock()
|
||||||
defer c.connAccess.Unlock()
|
defer c.connAccess.Unlock()
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
serverConn = c.conn
|
serverConn = c.conn
|
||||||
if serverConn != nil {
|
if serverConn != nil {
|
||||||
select {
|
select {
|
||||||
|
@ -65,7 +76,7 @@ func (c *ClientBind) connect() (*wireConn, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if c.isConnect {
|
if c.isConnect {
|
||||||
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
|
udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -74,7 +85,7 @@ func (c *ClientBind) connect() (*wireConn, error) {
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
|
udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -92,6 +103,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
|
||||||
c.done = make(chan struct{})
|
c.done = make(chan struct{})
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
|
||||||
return []conn.ReceiveFunc{c.receive}, 0, nil
|
return []conn.ReceiveFunc{c.receive}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,6 +117,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
||||||
}
|
}
|
||||||
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
|
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
|
||||||
err = nil
|
err = nil
|
||||||
|
c.pauseManager.WaitActive()
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -130,12 +143,17 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientBind) Close() error {
|
func (c *ClientBind) Close() error {
|
||||||
common.Close(common.PtrOrNil(c.conn))
|
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
default:
|
default:
|
||||||
close(c.done)
|
close(c.done)
|
||||||
}
|
}
|
||||||
|
if c.bindDone != nil {
|
||||||
|
c.bindDone()
|
||||||
|
}
|
||||||
|
c.connAccess.Lock()
|
||||||
|
defer c.connAccess.Unlock()
|
||||||
|
common.Close(common.PtrOrNil(c.conn))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,6 +164,8 @@ func (c *ClientBind) SetMark(mark uint32) error {
|
||||||
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||||
udpConn, err := c.connect()
|
udpConn, err := c.connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.pauseManager.WaitActive()
|
||||||
|
time.Sleep(time.Second)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
destination := netip.AddrPort(ep.(Endpoint))
|
destination := netip.AddrPort(ep.(Endpoint))
|
||||||
|
|
Loading…
Reference in a new issue