diff --git a/adapter/v2ray.go b/adapter/v2ray.go index d1b420ee..5a98d2e5 100644 --- a/adapter/v2ray.go +++ b/adapter/v2ray.go @@ -22,4 +22,5 @@ type V2RayServerTransportHandler interface { type V2RayClientTransport interface { DialContext(ctx context.Context) (net.Conn, error) + Close() error } diff --git a/outbound/hysteria.go b/outbound/hysteria.go index 8c130e33..cf7f7fed 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -130,8 +130,8 @@ func (h *Hysteria) NewPacketConnection(ctx context.Context, conn N.PacketConn, m return NewPacketConnection(ctx, h, conn, metadata) } -func (h *Hysteria) InterfaceUpdated() error { - return h.client.CloseWithError(E.New("network changed")) +func (h *Hysteria) InterfaceUpdated() { + h.client.CloseWithError(E.New("network changed")) } func (h *Hysteria) Close() error { diff --git a/outbound/hysteria2.go b/outbound/hysteria2.go index f2ffe2fd..9079e403 100644 --- a/outbound/hysteria2.go +++ b/outbound/hysteria2.go @@ -116,8 +116,8 @@ func (h *Hysteria2) NewPacketConnection(ctx context.Context, conn N.PacketConn, return NewPacketConnection(ctx, h, conn, metadata) } -func (h *Hysteria2) InterfaceUpdated() error { - return h.client.CloseWithError(E.New("network changed")) +func (h *Hysteria2) InterfaceUpdated() { + h.client.CloseWithError(E.New("network changed")) } func (h *Hysteria2) Close() error { diff --git a/outbound/trojan.go b/outbound/trojan.go index 14369613..52d72757 100644 --- a/outbound/trojan.go +++ b/outbound/trojan.go @@ -108,6 +108,9 @@ func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, met } func (h *Trojan) InterfaceUpdated() { + if h.transport != nil { + h.transport.Close() + } if h.multiplexDialer != nil { h.multiplexDialer.Reset() } diff --git a/outbound/vless.go b/outbound/vless.go index b3e94661..66080eaf 100644 --- a/outbound/vless.go +++ b/outbound/vless.go @@ -127,6 +127,9 @@ func (h *VLESS) NewPacketConnection(ctx context.Context, conn N.PacketConn, meta } func (h *VLESS) InterfaceUpdated() { + if h.transport != nil { + h.transport.Close() + } if h.multiplexDialer != nil { h.multiplexDialer.Reset() } diff --git a/outbound/vmess.go b/outbound/vmess.go index 6add5414..c7d88b90 100644 --- a/outbound/vmess.go +++ b/outbound/vmess.go @@ -103,6 +103,9 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg } func (h *VMess) InterfaceUpdated() { + if h.transport != nil { + h.transport.Close() + } if h.multiplexDialer != nil { h.multiplexDialer.Reset() } diff --git a/transport/v2raygrpc/client.go b/transport/v2raygrpc/client.go index 162002ad..1e72040a 100644 --- a/transport/v2raygrpc/client.go +++ b/transport/v2raygrpc/client.go @@ -72,12 +72,6 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt }, nil } -func (c *Client) Close() error { - return common.Close( - common.PtrOrNil(c.conn), - ) -} - func (c *Client) connect() (*grpc.ClientConn, error) { conn := c.conn if conn != nil && conn.GetState() != connectivity.Shutdown { @@ -113,3 +107,13 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } return NewGRPCConn(stream, cancel), nil } + +func (c *Client) Close() error { + c.connAccess.Lock() + defer c.connAccess.Unlock() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + return nil +} diff --git a/transport/v2raygrpclite/client.go b/transport/v2raygrpclite/client.go index bd52c1d9..de8915a1 100644 --- a/transport/v2raygrpclite/client.go +++ b/transport/v2raygrpclite/client.go @@ -109,8 +109,6 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } func (c *Client) Close() error { - if c.transport != nil { - v2rayhttp.CloseIdleConnections(c.transport) - } + v2rayhttp.ResetTransport(c.transport) return nil } diff --git a/transport/v2rayhttp/client.go b/transport/v2rayhttp/client.go index d817d37d..a105a4f3 100644 --- a/transport/v2rayhttp/client.go +++ b/transport/v2rayhttp/client.go @@ -155,6 +155,6 @@ func (c *Client) dialHTTP2(ctx context.Context) (net.Conn, error) { } func (c *Client) Close() error { - CloseIdleConnections(c.transport) + c.transport = ResetTransport(c.transport) return nil } diff --git a/transport/v2rayhttp/force_close.go b/transport/v2rayhttp/force_close.go new file mode 100644 index 00000000..d574a510 --- /dev/null +++ b/transport/v2rayhttp/force_close.go @@ -0,0 +1,47 @@ +package v2rayhttp + +import ( + "net/http" + "reflect" + "sync" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/net/http2" +) + +type clientConnPool struct { + t *http2.Transport + mu sync.Mutex + conns map[string][]*http2.ClientConn // key is host:port +} + +type efaceWords struct { + typ unsafe.Pointer + data unsafe.Pointer +} + +func ResetTransport(rawTransport http.RoundTripper) http.RoundTripper { + switch transport := rawTransport.(type) { + case *http.Transport: + transport.CloseIdleConnections() + return transport.Clone() + case *http2.Transport: + connPool := transportConnPool(transport) + p := (*clientConnPool)((*efaceWords)(unsafe.Pointer(&connPool)).data) + p.mu.Lock() + defer p.mu.Unlock() + for _, vv := range p.conns { + for _, cc := range vv { + cc.Close() + } + } + return transport + default: + panic(E.New("unknown transport type: ", reflect.TypeOf(transport))) + } +} + +//go:linkname transportConnPool golang.org/x/net/http2.(*Transport).connPool +func transportConnPool(t *http2.Transport) http2.ClientConnPool diff --git a/transport/v2rayhttpupgrade/client.go b/transport/v2rayhttpupgrade/client.go index 1b72cdfc..e2b86b1f 100644 --- a/transport/v2rayhttpupgrade/client.go +++ b/transport/v2rayhttpupgrade/client.go @@ -116,3 +116,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } return conn, nil } + +func (c *Client) Close() error { + return nil +} diff --git a/transport/v2rayquic/client.go b/transport/v2rayquic/client.go index 44455e2f..a1c3e3a6 100644 --- a/transport/v2rayquic/client.go +++ b/transport/v2rayquic/client.go @@ -97,5 +97,15 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } func (c *Client) Close() error { - return common.Close(c.conn, c.rawConn) + c.connAccess.Lock() + defer c.connAccess.Unlock() + if c.conn != nil { + c.conn.CloseWithError(0, "") + } + if c.rawConn != nil { + c.rawConn.Close() + } + c.conn = nil + c.rawConn = nil + return nil } diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go index 5de610c2..9c495076 100644 --- a/transport/v2raywebsocket/client.go +++ b/transport/v2raywebsocket/client.go @@ -127,3 +127,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil } } + +func (c *Client) Close() error { + return nil +}