Fix wireguard reconnect

This commit is contained in:
世界 2023-04-13 16:02:28 +08:00
parent 9d32fc9bd1
commit 1fbe7c54bf
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 36 additions and 42 deletions

View file

@ -39,6 +39,10 @@ func (a *myOutboundAdapter) Network() []string {
return a.network return a.network
} }
func (a *myOutboundAdapter) NewError(ctx context.Context, err error) {
NewError(a.logger, ctx, err)
}
func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn var outConn net.Conn
@ -121,3 +125,12 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro
} }
return bufio.CopyConn(ctx, conn, serverConn) return bufio.CopyConn(ctx, conn, serverConn)
} }
func NewError(logger log.ContextLogger, ctx context.Context, err error) {
common.Close(err)
if E.IsClosedOrCanceled(err) {
logger.DebugContext(ctx, "connection closed: ", err)
return
}
logger.ErrorContext(ctx, err)
}

View file

@ -64,7 +64,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
connectAddr = options.ServerOptions.Build() connectAddr = options.ServerOptions.Build()
} }
} }
outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved) outbound.bind = wireguard.NewClientBind(ctx, outbound, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved)
localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build) localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
if len(localPrefixes) == 0 { if len(localPrefixes) == 0 {
return nil, E.New("missing local address") return nil, E.New("missing local address")

View file

@ -7,8 +7,8 @@ import (
"sync" "sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/conn" "github.com/sagernet/wireguard-go/conn"
@ -18,6 +18,7 @@ var _ conn.Bind = (*ClientBind)(nil)
type ClientBind struct { type ClientBind struct {
ctx context.Context ctx context.Context
errorHandler E.Handler
dialer N.Dialer dialer N.Dialer
reservedForEndpoint map[M.Socksaddr][3]uint8 reservedForEndpoint map[M.Socksaddr][3]uint8
connAccess sync.Mutex connAccess sync.Mutex
@ -28,9 +29,10 @@ type ClientBind struct {
reserved [3]uint8 reserved [3]uint8
} }
func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind { func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
return &ClientBind{ return &ClientBind{
ctx: ctx, ctx: ctx,
errorHandler: errorHandler,
dialer: dialer, dialer: dialer,
reservedForEndpoint: make(map[M.Socksaddr][3]uint8), reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
isConnect: isConnect, isConnect: isConnect,
@ -67,10 +69,10 @@ func (c *ClientBind) connect() (*wireConn, error) {
if c.isConnect { if c.isConnect {
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr) udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
if err != nil { if err != nil {
return nil, &wireError{err} return nil, err
} }
c.conn = &wireConn{ c.conn = &wireConn{
NetPacketConn: &bufio.UnbindPacketConn{ PacketConn: &bufio.UnbindPacketConn{
ExtendedConn: bufio.NewExtendedConn(udpConn), ExtendedConn: bufio.NewExtendedConn(udpConn),
Addr: c.connectAddr, Addr: c.connectAddr,
}, },
@ -79,11 +81,11 @@ func (c *ClientBind) connect() (*wireConn, error) {
} else { } else {
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
if err != nil { if err != nil {
return nil, &wireError{err} return nil, err
} }
c.conn = &wireConn{ c.conn = &wireConn{
NetPacketConn: bufio.NewPacketConn(udpConn), PacketConn: bufio.NewPacketConn(udpConn),
done: make(chan struct{}), done: make(chan struct{}),
} }
} }
return c.conn, nil return c.conn, nil
@ -102,30 +104,31 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
udpConn, err := c.connect() udpConn, err := c.connect()
if err != nil { if err != nil {
err = &wireError{err} select {
case <-c.done:
return
default:
}
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
err = nil
return return
} }
buffer := buf.With(b) n, addr, err := udpConn.ReadFrom(b)
destination, err := udpConn.ReadPacket(buffer)
if err != nil { if err != nil {
udpConn.Close() udpConn.Close()
select { select {
case <-c.done: case <-c.done:
default: default:
err = &wireError{err} c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
} }
return return
} }
n = buffer.Len()
if buffer.Start() > 0 {
copy(b, buffer.Bytes())
}
if n > 3 { if n > 3 {
b[1] = 0 b[1] = 0
b[2] = 0 b[2] = 0
b[3] = 0 b[3] = 0
} }
ep = Endpoint(destination) ep = Endpoint(M.SocksaddrFromNet(addr))
return return
} }
@ -167,7 +170,7 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
b[2] = reserved[1] b[2] = reserved[1]
b[3] = reserved[2] b[3] = reserved[2]
} }
err = udpConn.WritePacket(buf.As(b), destination) _, err = udpConn.WriteTo(b, destination)
if err != nil { if err != nil {
udpConn.Close() udpConn.Close()
} }
@ -179,7 +182,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
} }
type wireConn struct { type wireConn struct {
N.NetPacketConn net.PacketConn
access sync.Mutex access sync.Mutex
done chan struct{} done chan struct{}
} }
@ -192,7 +195,7 @@ func (w *wireConn) Close() error {
return net.ErrClosed return net.ErrClosed
default: default:
} }
w.NetPacketConn.Close() w.PacketConn.Close()
close(w.done) close(w.done)
return nil return nil
} }

View file

@ -1,22 +0,0 @@
package wireguard
import "net"
type wireError struct {
cause error
}
func (w *wireError) Error() string {
return w.cause.Error()
}
func (w *wireError) Timeout() bool {
if cause, causeNet := w.cause.(net.Error); causeNet {
return cause.Timeout()
}
return false
}
func (w *wireError) Temporary() bool {
return true
}