Add interrupt support for outbound groups

This commit is contained in:
世界 2023-09-15 00:07:07 +08:00
parent bd7adcbb7e
commit c320be75a7
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 282 additions and 55 deletions

75
common/interrupt/conn.go Normal file
View file

@ -0,0 +1,75 @@
package interrupt
import (
"net"
"github.com/sagernet/sing/common/x/list"
)
/*type GroupedConn interface {
MarkAsInternal()
}
func MarkAsInternal(conn any) {
if groupedConn, isGroupConn := common.Cast[GroupedConn](conn); isGroupConn {
groupedConn.MarkAsInternal()
}
}*/
type Conn struct {
net.Conn
group *Group
element *list.Element[*groupConnItem]
}
/*func (c *Conn) MarkAsInternal() {
c.element.Value.internal = true
}*/
func (c *Conn) Close() error {
c.group.access.Lock()
defer c.group.access.Unlock()
c.group.connections.Remove(c.element)
return c.Conn.Close()
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return true
}
func (c *Conn) Upstream() any {
return c.Conn
}
type PacketConn struct {
net.PacketConn
group *Group
element *list.Element[*groupConnItem]
}
/*func (c *PacketConn) MarkAsInternal() {
c.element.Value.internal = true
}*/
func (c *PacketConn) Close() error {
c.group.access.Lock()
defer c.group.access.Unlock()
c.group.connections.Remove(c.element)
return c.PacketConn.Close()
}
func (c *PacketConn) ReaderReplaceable() bool {
return true
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}
func (c *PacketConn) Upstream() any {
return c.PacketConn
}

View file

@ -0,0 +1,13 @@
package interrupt
import "context"
type contextKeyIsExternalConnection struct{}
func ContextWithIsExternalConnection(ctx context.Context) context.Context {
return context.WithValue(ctx, contextKeyIsExternalConnection{}, true)
}
func IsExternalConnectionFromContext(ctx context.Context) bool {
return ctx.Value(contextKeyIsExternalConnection{}) != nil
}

52
common/interrupt/group.go Normal file
View file

@ -0,0 +1,52 @@
package interrupt
import (
"io"
"net"
"sync"
"github.com/sagernet/sing/common/x/list"
)
type Group struct {
access sync.Mutex
connections list.List[*groupConnItem]
}
type groupConnItem struct {
conn io.Closer
isExternal bool
}
func NewGroup() *Group {
return &Group{}
}
func (g *Group) NewConn(conn net.Conn, isExternal bool) net.Conn {
g.access.Lock()
defer g.access.Unlock()
item := g.connections.PushBack(&groupConnItem{conn, isExternal})
return &Conn{Conn: conn, group: g, element: item}
}
func (g *Group) NewPacketConn(conn net.PacketConn, isExternal bool) net.PacketConn {
g.access.Lock()
defer g.access.Unlock()
item := g.connections.PushBack(&groupConnItem{conn, isExternal})
return &PacketConn{PacketConn: conn, group: g, element: item}
}
func (g *Group) Interrupt(interruptExternalConnections bool) {
g.access.Lock()
defer g.access.Unlock()
var toDelete []*list.Element[*groupConnItem]
for element := g.connections.Front(); element != nil; element = element.Next() {
if !element.Value.isExternal || interruptExternalConnections {
element.Value.conn.Close()
toDelete = append(toDelete, element)
}
}
for _, element := range toDelete {
g.connections.Remove(element)
}
}

View file

@ -10,7 +10,8 @@
"proxy-b", "proxy-b",
"proxy-c" "proxy-c"
], ],
"default": "proxy-c" "default": "proxy-c",
"interrupt_exist_connections": false
} }
``` ```
@ -29,3 +30,9 @@ List of outbound tags to select.
#### default #### default
The default outbound tag. The first outbound will be used if empty. The default outbound tag. The first outbound will be used if empty.
#### interrupt_exist_connections
Interrupt existing connections when the selected outbound has changed.
Only inbound connections are affected by this setting, internal connections will always be interrupted.

View file

@ -10,7 +10,8 @@
"proxy-b", "proxy-b",
"proxy-c" "proxy-c"
], ],
"default": "proxy-c" "default": "proxy-c",
"interrupt_exist_connections": false
} }
``` ```
@ -29,3 +30,9 @@
#### default #### default
默认的出站标签。默认使用第一个出站。 默认的出站标签。默认使用第一个出站。
#### interrupt_exist_connections
当选定的出站发生更改时,中断现有连接。
仅入站连接受此设置影响,内部连接将始终被中断。

View file

@ -12,7 +12,8 @@
], ],
"url": "https://www.gstatic.com/generate_204", "url": "https://www.gstatic.com/generate_204",
"interval": "1m", "interval": "1m",
"tolerance": 50 "tolerance": 50,
"interrupt_exist_connections": false
} }
``` ```
@ -35,3 +36,9 @@ The test interval. `1m` will be used if empty.
#### tolerance #### tolerance
The test tolerance in milliseconds. `50` will be used if empty. The test tolerance in milliseconds. `50` will be used if empty.
#### interrupt_exist_connections
Interrupt existing connections when the selected outbound has changed.
Only inbound connections are affected by this setting, internal connections will always be interrupted.

View file

@ -12,7 +12,8 @@
], ],
"url": "https://www.gstatic.com/generate_204", "url": "https://www.gstatic.com/generate_204",
"interval": "1m", "interval": "1m",
"tolerance": 50 "tolerance": 50,
"interrupt_exist_connections": false
} }
``` ```
@ -35,3 +36,9 @@
#### tolerance #### tolerance
以毫秒为单位的测试容差。 默认使用 `50` 以毫秒为单位的测试容差。 默认使用 `50`
#### interrupt_exist_connections
当选定的出站发生更改时,中断现有连接。
仅入站连接受此设置影响,内部连接将始终被中断。

View file

@ -17,13 +17,15 @@ type ClashAPIOptions struct {
} }
type SelectorOutboundOptions struct { type SelectorOutboundOptions struct {
Outbounds []string `json:"outbounds"` Outbounds []string `json:"outbounds"`
Default string `json:"default,omitempty"` Default string `json:"default,omitempty"`
InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"`
} }
type URLTestOutboundOptions struct { type URLTestOutboundOptions struct {
Outbounds []string `json:"outbounds"` Outbounds []string `json:"outbounds"`
URL string `json:"url,omitempty"` URL string `json:"url,omitempty"`
Interval Duration `json:"interval,omitempty"` Interval Duration `json:"interval,omitempty"`
Tolerance uint16 `json:"tolerance,omitempty"` Tolerance uint16 `json:"tolerance,omitempty"`
InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"`
} }

View file

@ -5,6 +5,7 @@ import (
"net" "net"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/interrupt"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
@ -20,10 +21,12 @@ var (
type Selector struct { type Selector struct {
myOutboundAdapter myOutboundAdapter
tags []string tags []string
defaultTag string defaultTag string
outbounds map[string]adapter.Outbound outbounds map[string]adapter.Outbound
selected adapter.Outbound selected adapter.Outbound
interruptGroup *interrupt.Group
interruptExternalConnections bool
} }
func NewSelector(router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (*Selector, error) { func NewSelector(router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (*Selector, error) {
@ -35,9 +38,11 @@ func NewSelector(router adapter.Router, logger log.ContextLogger, tag string, op
tag: tag, tag: tag,
dependencies: options.Outbounds, dependencies: options.Outbounds,
}, },
tags: options.Outbounds, tags: options.Outbounds,
defaultTag: options.Default, defaultTag: options.Default,
outbounds: make(map[string]adapter.Outbound), outbounds: make(map[string]adapter.Outbound),
interruptGroup: interrupt.NewGroup(),
interruptExternalConnections: options.InterruptExistConnections,
} }
if len(outbound.tags) == 0 { if len(outbound.tags) == 0 {
return nil, E.New("missing tags") return nil, E.New("missing tags")
@ -100,6 +105,9 @@ func (s *Selector) SelectOutbound(tag string) bool {
if !loaded { if !loaded {
return false return false
} }
if s.selected == detour {
return true
}
s.selected = detour s.selected = detour
if s.tag != "" { if s.tag != "" {
if clashServer := s.router.ClashServer(); clashServer != nil && clashServer.StoreSelected() { if clashServer := s.router.ClashServer(); clashServer != nil && clashServer.StoreSelected() {
@ -109,22 +117,33 @@ func (s *Selector) SelectOutbound(tag string) bool {
} }
} }
} }
s.interruptGroup.Interrupt(s.interruptExternalConnections)
return true return true
} }
func (s *Selector) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (s *Selector) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return s.selected.DialContext(ctx, network, destination) conn, err := s.selected.DialContext(ctx, network, destination)
if err != nil {
return nil, err
}
return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
} }
func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return s.selected.ListenPacket(ctx, destination) conn, err := s.selected.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
} }
func (s *Selector) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (s *Selector) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
return s.selected.NewConnection(ctx, conn, metadata) return s.selected.NewConnection(ctx, conn, metadata)
} }
func (s *Selector) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (s *Selector) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
return s.selected.NewPacketConnection(ctx, conn, metadata) return s.selected.NewPacketConnection(ctx, conn, metadata)
} }

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/interrupt"
"github.com/sagernet/sing-box/common/urltest" "github.com/sagernet/sing-box/common/urltest"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@ -30,12 +31,13 @@ var (
type URLTest struct { type URLTest struct {
myOutboundAdapter myOutboundAdapter
ctx context.Context ctx context.Context
tags []string tags []string
link string link string
interval time.Duration interval time.Duration
tolerance uint16 tolerance uint16
group *URLTestGroup group *URLTestGroup
interruptExternalConnections bool
} }
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (*URLTest, error) { func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (*URLTest, error) {
@ -47,11 +49,12 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
tag: tag, tag: tag,
dependencies: options.Outbounds, dependencies: options.Outbounds,
}, },
ctx: ctx, ctx: ctx,
tags: options.Outbounds, tags: options.Outbounds,
link: options.URL, link: options.URL,
interval: time.Duration(options.Interval), interval: time.Duration(options.Interval),
tolerance: options.Tolerance, tolerance: options.Tolerance,
interruptExternalConnections: options.InterruptExistConnections,
} }
if len(outbound.tags) == 0 { if len(outbound.tags) == 0 {
return nil, E.New("missing tags") return nil, E.New("missing tags")
@ -75,7 +78,7 @@ func (s *URLTest) Start() error {
} }
outbounds = append(outbounds, detour) outbounds = append(outbounds, detour)
} }
s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance) s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance, s.interruptExternalConnections)
return nil return nil
} }
@ -111,7 +114,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M
outbound := s.group.Select(network) outbound := s.group.Select(network)
conn, err := outbound.DialContext(ctx, network, destination) conn, err := outbound.DialContext(ctx, network, destination)
if err == nil { if err == nil {
return conn, nil return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
} }
s.logger.ErrorContext(ctx, err) s.logger.ErrorContext(ctx, err)
s.group.history.DeleteURLTestHistory(outbound.Tag()) s.group.history.DeleteURLTestHistory(outbound.Tag())
@ -123,7 +126,7 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
outbound := s.group.Select(N.NetworkUDP) outbound := s.group.Select(N.NetworkUDP)
conn, err := outbound.ListenPacket(ctx, destination) conn, err := outbound.ListenPacket(ctx, destination)
if err == nil { if err == nil {
return conn, nil return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
} }
s.logger.ErrorContext(ctx, err) s.logger.ErrorContext(ctx, err)
s.group.history.DeleteURLTestHistory(outbound.Tag()) s.group.history.DeleteURLTestHistory(outbound.Tag())
@ -131,10 +134,12 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
} }
func (s *URLTest) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (s *URLTest) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
return NewConnection(ctx, s, conn, metadata) return NewConnection(ctx, s, conn, metadata)
} }
func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
ctx = interrupt.ContextWithIsExternalConnection(ctx)
return NewPacketConnection(ctx, s, conn, metadata) return NewPacketConnection(ctx, s, conn, metadata)
} }
@ -144,23 +149,36 @@ func (s *URLTest) InterfaceUpdated() {
} }
type URLTestGroup struct { type URLTestGroup struct {
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
logger log.Logger logger log.Logger
outbounds []adapter.Outbound outbounds []adapter.Outbound
link string link string
interval time.Duration interval time.Duration
tolerance uint16 tolerance uint16
history *urltest.HistoryStorage history *urltest.HistoryStorage
checking atomic.Bool checking atomic.Bool
pauseManager pause.Manager pauseManager pause.Manager
selectedOutboundTCP adapter.Outbound
selectedOutboundUDP adapter.Outbound
interruptGroup *interrupt.Group
interruptExternalConnections bool
access sync.Mutex access sync.Mutex
ticker *time.Ticker ticker *time.Ticker
close chan struct{} close chan struct{}
} }
func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logger, outbounds []adapter.Outbound, link string, interval time.Duration, tolerance uint16) *URLTestGroup { func NewURLTestGroup(
ctx context.Context,
router adapter.Router,
logger log.Logger,
outbounds []adapter.Outbound,
link string,
interval time.Duration,
tolerance uint16,
interruptExternalConnections bool,
) *URLTestGroup {
if interval == 0 { if interval == 0 {
interval = C.DefaultURLTestInterval interval = C.DefaultURLTestInterval
} }
@ -175,16 +193,18 @@ func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logg
history = urltest.NewHistoryStorage() history = urltest.NewHistoryStorage()
} }
return &URLTestGroup{ return &URLTestGroup{
ctx: ctx, ctx: ctx,
router: router, router: router,
logger: logger, logger: logger,
outbounds: outbounds, outbounds: outbounds,
link: link, link: link,
interval: interval, interval: interval,
tolerance: tolerance, tolerance: tolerance,
history: history, history: history,
close: make(chan struct{}), close: make(chan struct{}),
pauseManager: pause.ManagerFromContext(ctx), pauseManager: pause.ManagerFromContext(ctx),
interruptGroup: interrupt.NewGroup(),
interruptExternalConnections: interruptExternalConnections,
} }
} }
@ -329,5 +349,23 @@ func (g *URLTestGroup) urlTest(ctx context.Context, link string, force bool) (ma
}) })
} }
b.Wait() b.Wait()
g.performUpdateCheck()
return result, nil return result, nil
} }
func (g *URLTestGroup) performUpdateCheck() {
outbound := g.Select(N.NetworkTCP)
var updated bool
if outbound != g.selectedOutboundTCP {
g.selectedOutboundTCP = outbound
updated = true
}
outbound = g.Select(N.NetworkUDP)
if outbound != g.selectedOutboundUDP {
g.selectedOutboundUDP = outbound
updated = true
}
if updated {
g.interruptGroup.Interrupt(g.interruptExternalConnections)
}
}