From 1f63ce5deeb18e6a66de084778cc5276154ddf85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 6 Nov 2022 10:36:19 +0800 Subject: [PATCH] Fix reset outbound --- adapter/router.go | 4 +++ outbound/hysteria.go | 10 +++++++- outbound/ssh.go | 10 +++++++- outbound/wireguard.go | 10 +++++++- route/router.go | 40 +++++++++++++++++++----------- transport/wireguard/client_bind.go | 6 +++++ 6 files changed, 63 insertions(+), 17 deletions(-) diff --git a/adapter/router.go b/adapter/router.go index 6bf30589..447c272d 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -60,3 +60,7 @@ type DNSRule interface { Rule DisableCache() bool } + +type InterfaceUpdateListener interface { + InterfaceUpdated() error +} diff --git a/outbound/hysteria.go b/outbound/hysteria.go index d3875f0c..b773970a 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -23,7 +23,10 @@ import ( N "github.com/sagernet/sing/common/network" ) -var _ adapter.Outbound = (*Hysteria)(nil) +var ( + _ adapter.Outbound = (*Hysteria)(nil) + _ adapter.InterfaceUpdateListener = (*Hysteria)(nil) +) type Hysteria struct { myOutboundAdapter @@ -236,6 +239,11 @@ func (h *Hysteria) udpRecvLoop(conn quic.Connection) { } } +func (h *Hysteria) InterfaceUpdated() error { + h.Close() + return nil +} + func (h *Hysteria) Close() error { h.connAccess.Lock() defer h.connAccess.Unlock() diff --git a/outbound/ssh.go b/outbound/ssh.go index 48d51a78..33579a7e 100644 --- a/outbound/ssh.go +++ b/outbound/ssh.go @@ -21,7 +21,10 @@ import ( "golang.org/x/crypto/ssh" ) -var _ adapter.Outbound = (*SSH)(nil) +var ( + _ adapter.Outbound = (*SSH)(nil) + _ adapter.InterfaceUpdateListener = (*SSH)(nil) +) type SSH struct { myOutboundAdapter @@ -149,6 +152,11 @@ func (s *SSH) connect() (*ssh.Client, error) { return client, nil } +func (s *SSH) InterfaceUpdated() error { + common.Close(s.clientConn) + return nil +} + func (s *SSH) Close() error { return common.Close(s.clientConn) } diff --git a/outbound/wireguard.go b/outbound/wireguard.go index a0c1e933..c87f4396 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -26,7 +26,10 @@ import ( "golang.zx2c4.com/wireguard/device" ) -var _ adapter.Outbound = (*WireGuard)(nil) +var ( + _ adapter.Outbound = (*WireGuard)(nil) + _ adapter.InterfaceUpdateListener = (*WireGuard)(nil) +) type WireGuard struct { myOutboundAdapter @@ -134,6 +137,11 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context return outbound, nil } +func (w *WireGuard) InterfaceUpdated() error { + w.bind.Reset() + return nil +} + func (w *WireGuard) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { switch network { case N.NetworkTCP: diff --git a/route/router.go b/route/router.go index cd401706..4487d26f 100644 --- a/route/router.go +++ b/route/router.go @@ -262,20 +262,7 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont if err != nil { return nil, E.New("auto_detect_interface unsupported on current platform") } - interfaceMonitor.RegisterCallback(func(event int) error { - if C.IsAndroid { - var vpnStatus string - if router.interfaceMonitor.AndroidVPNEnabled() { - vpnStatus = "enabled" - } else { - vpnStatus = "disabled" - } - router.logger.Info("updated default interface ", router.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", router.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()), ", vpn ", vpnStatus) - } else { - router.logger.Info("updated default interface ", router.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", router.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified())) - } - return nil - }) + interfaceMonitor.RegisterCallback(router.notifyNetworkUpdate) router.interfaceMonitor = interfaceMonitor } @@ -1014,3 +1001,28 @@ func (r *Router) NewError(ctx context.Context, err error) { } r.logger.ErrorContext(ctx, err) } + +func (r *Router) notifyNetworkUpdate(int) error { + if C.IsAndroid { + var vpnStatus string + if r.interfaceMonitor.AndroidVPNEnabled() { + vpnStatus = "enabled" + } else { + vpnStatus = "disabled" + } + r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()), ", vpn ", vpnStatus) + } else { + r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified())) + } + + for _, outbound := range r.outbounds { + listener, isListener := outbound.(adapter.InterfaceUpdateListener) + if isListener { + err := listener.InterfaceUpdated() + if err != nil { + return err + } + } + } + return nil +} diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 1ecbda23..dc6fd2ac 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -100,6 +100,12 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { return } +func (c *ClientBind) Reset() { + c.connAccess.Lock() + defer c.connAccess.Unlock() + common.Close(common.PtrOrNil(c.conn)) +} + func (c *ClientBind) Close() error { c.connAccess.Lock() defer c.connAccess.Unlock()