mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-22 08:31:30 +00:00
Fix wireguard reconnect
This commit is contained in:
parent
9d32fc9bd1
commit
1fbe7c54bf
|
@ -39,6 +39,10 @@ func (a *myOutboundAdapter) Network() []string {
|
||||||
return a.network
|
return a.network
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *myOutboundAdapter) NewError(ctx context.Context, err error) {
|
||||||
|
NewError(a.logger, ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
|
func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
|
||||||
ctx = adapter.WithContext(ctx, &metadata)
|
ctx = adapter.WithContext(ctx, &metadata)
|
||||||
var outConn net.Conn
|
var outConn net.Conn
|
||||||
|
@ -121,3 +125,12 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro
|
||||||
}
|
}
|
||||||
return bufio.CopyConn(ctx, conn, serverConn)
|
return bufio.CopyConn(ctx, conn, serverConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewError(logger log.ContextLogger, ctx context.Context, err error) {
|
||||||
|
common.Close(err)
|
||||||
|
if E.IsClosedOrCanceled(err) {
|
||||||
|
logger.DebugContext(ctx, "connection closed: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.ErrorContext(ctx, err)
|
||||||
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
|
||||||
connectAddr = options.ServerOptions.Build()
|
connectAddr = options.ServerOptions.Build()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved)
|
outbound.bind = wireguard.NewClientBind(ctx, outbound, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved)
|
||||||
localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
|
localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
|
||||||
if len(localPrefixes) == 0 {
|
if len(localPrefixes) == 0 {
|
||||||
return nil, E.New("missing local address")
|
return nil, E.New("missing local address")
|
||||||
|
|
|
@ -7,8 +7,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/wireguard-go/conn"
|
"github.com/sagernet/wireguard-go/conn"
|
||||||
|
@ -18,6 +18,7 @@ var _ conn.Bind = (*ClientBind)(nil)
|
||||||
|
|
||||||
type ClientBind struct {
|
type ClientBind struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
errorHandler E.Handler
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
reservedForEndpoint map[M.Socksaddr][3]uint8
|
reservedForEndpoint map[M.Socksaddr][3]uint8
|
||||||
connAccess sync.Mutex
|
connAccess sync.Mutex
|
||||||
|
@ -28,9 +29,10 @@ type ClientBind struct {
|
||||||
reserved [3]uint8
|
reserved [3]uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
|
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
|
||||||
return &ClientBind{
|
return &ClientBind{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
errorHandler: errorHandler,
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
|
reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
|
||||||
isConnect: isConnect,
|
isConnect: isConnect,
|
||||||
|
@ -67,10 +69,10 @@ func (c *ClientBind) connect() (*wireConn, error) {
|
||||||
if c.isConnect {
|
if c.isConnect {
|
||||||
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
|
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &wireError{err}
|
return nil, err
|
||||||
}
|
}
|
||||||
c.conn = &wireConn{
|
c.conn = &wireConn{
|
||||||
NetPacketConn: &bufio.UnbindPacketConn{
|
PacketConn: &bufio.UnbindPacketConn{
|
||||||
ExtendedConn: bufio.NewExtendedConn(udpConn),
|
ExtendedConn: bufio.NewExtendedConn(udpConn),
|
||||||
Addr: c.connectAddr,
|
Addr: c.connectAddr,
|
||||||
},
|
},
|
||||||
|
@ -79,11 +81,11 @@ func (c *ClientBind) connect() (*wireConn, error) {
|
||||||
} else {
|
} else {
|
||||||
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
|
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &wireError{err}
|
return nil, err
|
||||||
}
|
}
|
||||||
c.conn = &wireConn{
|
c.conn = &wireConn{
|
||||||
NetPacketConn: bufio.NewPacketConn(udpConn),
|
PacketConn: bufio.NewPacketConn(udpConn),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return c.conn, nil
|
return c.conn, nil
|
||||||
|
@ -102,30 +104,31 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
|
||||||
func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
|
func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||||
udpConn, err := c.connect()
|
udpConn, err := c.connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = &wireError{err}
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
|
||||||
|
err = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buffer := buf.With(b)
|
n, addr, err := udpConn.ReadFrom(b)
|
||||||
destination, err := udpConn.ReadPacket(buffer)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
udpConn.Close()
|
udpConn.Close()
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
default:
|
default:
|
||||||
err = &wireError{err}
|
c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n = buffer.Len()
|
|
||||||
if buffer.Start() > 0 {
|
|
||||||
copy(b, buffer.Bytes())
|
|
||||||
}
|
|
||||||
if n > 3 {
|
if n > 3 {
|
||||||
b[1] = 0
|
b[1] = 0
|
||||||
b[2] = 0
|
b[2] = 0
|
||||||
b[3] = 0
|
b[3] = 0
|
||||||
}
|
}
|
||||||
ep = Endpoint(destination)
|
ep = Endpoint(M.SocksaddrFromNet(addr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,7 +170,7 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
|
||||||
b[2] = reserved[1]
|
b[2] = reserved[1]
|
||||||
b[3] = reserved[2]
|
b[3] = reserved[2]
|
||||||
}
|
}
|
||||||
err = udpConn.WritePacket(buf.As(b), destination)
|
_, err = udpConn.WriteTo(b, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
udpConn.Close()
|
udpConn.Close()
|
||||||
}
|
}
|
||||||
|
@ -179,7 +182,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type wireConn struct {
|
type wireConn struct {
|
||||||
N.NetPacketConn
|
net.PacketConn
|
||||||
access sync.Mutex
|
access sync.Mutex
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
@ -192,7 +195,7 @@ func (w *wireConn) Close() error {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
w.NetPacketConn.Close()
|
w.PacketConn.Close()
|
||||||
close(w.done)
|
close(w.done)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
package wireguard
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
type wireError struct {
|
|
||||||
cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireError) Error() string {
|
|
||||||
return w.cause.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireError) Timeout() bool {
|
|
||||||
if cause, causeNet := w.cause.(net.Error); causeNet {
|
|
||||||
return cause.Timeout()
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wireError) Temporary() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
Loading…
Reference in a new issue