Improve timeouts

This commit is contained in:
世界 2024-11-27 18:08:19 +08:00
parent ec310170cc
commit 705c23866a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 105 additions and 346 deletions

View file

@ -56,12 +56,18 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil { if err != nil {
return nil, err return nil, err
} }
var udpTimeout time.Duration
if options.UDPTimeout != 0 {
udpTimeout = time.Duration(options.UDPTimeout)
} else {
udpTimeout = C.UDPTimeout
}
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{ wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx, Context: ctx,
Logger: logger, Logger: logger,
System: options.System, System: options.System,
Handler: ep, Handler: ep,
UDPTimeout: time.Duration(options.UDPTimeout), UDPTimeout: udpTimeout,
Dialer: outboundDialer, Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer { CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{ return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{

View file

@ -5,6 +5,7 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -18,31 +19,35 @@ import (
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
) )
var _ adapter.ConnectionManager = (*ConnectionManager)(nil) var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
type ConnectionManager struct { type ConnectionManager struct {
logger logger.ContextLogger logger logger.ContextLogger
monitor *ConnectionMonitor access sync.Mutex
connections list.List[io.Closer]
} }
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
return &ConnectionManager{ return &ConnectionManager{
logger: logger, logger: logger,
monitor: NewConnectionMonitor(),
} }
} }
func (m *ConnectionManager) Start(stage adapter.StartStage) error { func (m *ConnectionManager) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateInitialize { return nil
return nil
}
return m.monitor.Start()
} }
func (m *ConnectionManager) Close() error { func (m *ConnectionManager) Close() error {
return m.monitor.Close() m.access.Lock()
defer m.access.Unlock()
for element := m.connections.Front(); element != nil; element = element.Next() {
common.Close(element.Value)
}
m.connections.Init()
return nil
} }
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
@ -57,95 +62,32 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
} }
if err != nil { if err != nil {
err = E.Cause(err, "open outbound connection")
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "open outbound connection: ", err) m.logger.ErrorContext(ctx, err)
return return
} }
err = N.ReportConnHandshakeSuccess(conn, remoteConn) err = N.ReportConnHandshakeSuccess(conn, remoteConn)
if err != nil { if err != nil {
err = E.Cause(err, "report handshake success")
remoteConn.Close() remoteConn.Close()
N.CloseOnHandshakeFailure(conn, onClose, err) N.CloseOnHandshakeFailure(conn, onClose, err)
m.logger.ErrorContext(ctx, "report handshake success: ", err) m.logger.ErrorContext(ctx, err)
return return
} }
m.access.Lock()
element := m.connections.PushBack(conn)
m.access.Unlock()
onClose = N.AppendClose(onClose, func(it error) {
m.access.Lock()
defer m.access.Unlock()
m.connections.Remove(element)
})
var done atomic.Bool var done atomic.Bool
if ctx.Done() != nil {
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
}
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
} }
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
originSource := source
originDestination := destination
var readCounters, writeCounters []N.CountFunc
for {
source, readCounters = N.UnwrapCountReader(source, readCounters)
destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
dataLen := cachedBuffer.Len()
_, err := destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
if done.Swap(true) {
if onClose != nil {
onClose(err)
}
}
common.Close(originSource, originDestination)
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
continue
}
break
}
_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
if err != nil {
common.Close(originSource, originDestination)
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
err = duplexDst.CloseWrite()
if err != nil {
common.Close(originSource, originDestination)
}
} else {
common.Close(originDestination)
}
if done.Swap(true) {
if onClose != nil {
onClose(err)
}
common.Close(originSource, originDestination)
}
if !direction {
if err == nil {
m.logger.DebugContext(ctx, "connection upload finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection upload closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection upload closed")
}
} else {
if err == nil {
m.logger.DebugContext(ctx, "connection download finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection download closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection download closed")
}
}
}
func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var ( var (
@ -227,58 +169,91 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout) ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
} }
destination := bufio.NewPacketConn(remotePacketConn) destination := bufio.NewPacketConn(remotePacketConn)
m.access.Lock()
element := m.connections.PushBack(conn)
m.access.Unlock()
onClose = N.AppendClose(onClose, func(it error) {
m.access.Lock()
defer m.access.Unlock()
m.connections.Remove(element)
})
var done atomic.Bool var done atomic.Bool
if ctx.Done() != nil {
onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
}
go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose) go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
} }
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
_, err := bufio.CopyPacket(destination, source)
/*var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
originSource := source originSource := source
originDestination := destination
var readCounters, writeCounters []N.CountFunc
for { for {
source, readCounters = N.UnwrapCountPacketReader(source, readCounters) source, readCounters = N.UnwrapCountReader(source, readCounters)
destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters) destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
if cachedReader, isCached := source.(N.CachedPacketReader); isCached { if cachedSrc, isCached := source.(N.CachedReader); isCached {
packet := cachedReader.ReadCachedPacket() cachedBuffer := cachedSrc.ReadCached()
if packet != nil { if cachedBuffer != nil {
cachedPackets = append(cachedPackets, packet) dataLen := cachedBuffer.Len()
continue _, err := destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
if done.Swap(true) {
onClose(err)
}
common.Close(originSource, originDestination)
if !direction {
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
} else {
m.logger.ErrorContext(ctx, "connection download payload: ", err)
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
} }
continue
} }
break break
} }
var handled bool _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
if natConn, isNatConn := source.(udpnat.Conn); isNatConn { if err != nil {
natConn.SetHandler(&udpHijacker{ common.Close(originDestination)
ctx: ctx, } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
logger: m.logger, err = duplexDst.CloseWrite()
source: natConn,
destination: destination,
direction: direction,
readCounters: readCounters,
writeCounters: writeCounters,
done: done,
onClose: onClose,
})
handled = true
}
if cachedPackets != nil {
_, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
if err != nil { if err != nil {
common.Close(source, destination) common.Close(originSource, originDestination)
m.logger.ErrorContext(ctx, "packet upload payload: ", err) }
return } else {
common.Close(originDestination)
}
if done.Swap(true) {
onClose(err)
common.Close(originSource, originDestination)
}
if !direction {
if err == nil {
m.logger.DebugContext(ctx, "connection upload finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection upload closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection upload closed")
}
} else {
if err == nil {
m.logger.DebugContext(ctx, "connection download finished")
} else if !E.IsClosedOrCanceled(err) {
m.logger.ErrorContext(ctx, "connection download closed: ", err)
} else {
m.logger.TraceContext(ctx, "connection download closed")
} }
} }
if handled { }
return
} func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
_, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/ _, err := bufio.CopyPacket(destination, source)
if !direction { if !direction {
if E.IsClosedOrCanceled(err) { if E.IsClosedOrCanceled(err) {
m.logger.TraceContext(ctx, "packet upload closed") m.logger.TraceContext(ctx, "packet upload closed")
@ -293,58 +268,7 @@ func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.P
} }
} }
if !done.Swap(true) { if !done.Swap(true) {
if onClose != nil { onClose(err)
onClose(err)
}
} }
common.Close(source, destination) common.Close(source, destination)
} }
/*type udpHijacker struct {
ctx context.Context
logger logger.ContextLogger
source io.Closer
destination N.PacketWriter
direction bool
readCounters []N.CountFunc
writeCounters []N.CountFunc
done *atomic.Bool
onClose N.CloseHandlerFunc
}
func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
dataLen := buffer.Len()
for _, counter := range u.readCounters {
counter(int64(dataLen))
}
err := u.destination.WritePacket(buffer, source)
if err != nil {
common.Close(u.source, u.destination)
u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
return
}
for _, counter := range u.writeCounters {
counter(int64(dataLen))
}
}
func (u *udpHijacker) Close() error {
var err error
if !u.done.Swap(true) {
err = common.Close(u.source, u.destination)
if u.onClose != nil {
u.onClose(net.ErrClosed)
}
}
if u.direction {
u.logger.TraceContext(u.ctx, "packet download closed")
} else {
u.logger.TraceContext(u.ctx, "packet upload closed")
}
return err
}
func (u *udpHijacker) Upstream() any {
return u.destination
}
*/

View file

@ -1,128 +0,0 @@
package route
import (
"context"
"io"
"reflect"
"sync"
"time"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
type ConnectionMonitor struct {
access sync.RWMutex
reloadChan chan struct{}
connections list.List[*monitorEntry]
}
type monitorEntry struct {
ctx context.Context
closer io.Closer
}
func NewConnectionMonitor() *ConnectionMonitor {
return &ConnectionMonitor{
reloadChan: make(chan struct{}, 1),
}
}
func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc {
m.access.Lock()
defer m.access.Unlock()
element := m.connections.PushBack(&monitorEntry{
ctx: ctx,
closer: closer,
})
select {
case <-m.reloadChan:
return nil
default:
select {
case m.reloadChan <- struct{}{}:
default:
}
}
return func(it error) {
m.access.Lock()
defer m.access.Unlock()
m.connections.Remove(element)
select {
case <-m.reloadChan:
default:
select {
case m.reloadChan <- struct{}{}:
default:
}
}
}
}
func (m *ConnectionMonitor) Start() error {
go m.monitor()
return nil
}
func (m *ConnectionMonitor) Close() error {
m.access.Lock()
defer m.access.Unlock()
close(m.reloadChan)
for element := m.connections.Front(); element != nil; element = element.Next() {
element.Value.closer.Close()
}
return nil
}
func (m *ConnectionMonitor) monitor() {
var (
selectCases []reflect.SelectCase
elements []*list.Element[*monitorEntry]
)
rootCase := reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(m.reloadChan),
}
for {
m.access.RLock()
if m.connections.Len() == 0 {
m.access.RUnlock()
if _, loaded := <-m.reloadChan; !loaded {
return
} else {
continue
}
}
if len(elements) < m.connections.Len() {
elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len())
}
if len(selectCases) < m.connections.Len()+1 {
selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1)
}
elements = elements[:0]
selectCases = selectCases[:1]
selectCases[0] = rootCase
for element := m.connections.Front(); element != nil; element = element.Next() {
elements = append(elements, element)
selectCases = append(selectCases, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(element.Value.ctx.Done()),
})
}
m.access.RUnlock()
selected, _, loaded := reflect.Select(selectCases)
if selected == 0 {
if !loaded {
return
} else {
time.Sleep(time.Second)
continue
}
}
element := elements[selected-1]
m.access.Lock()
m.connections.Remove(element)
m.access.Unlock()
element.Value.closer.Close() // maybe go close
}
}

View file

@ -1,43 +0,0 @@
package route_test
import (
"context"
"sync"
"testing"
"time"
"github.com/sagernet/sing-box/route"
"github.com/stretchr/testify/require"
)
func TestMonitor(t *testing.T) {
t.Parallel()
var closer myCloser
closer.Add(1)
monitor := route.NewConnectionMonitor()
require.NoError(t, monitor.Start())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
monitor.Add(ctx, &closer)
done := make(chan struct{})
go func() {
closer.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(time.Second + 100*time.Millisecond):
t.Fatal("timeout")
}
cancel()
require.NoError(t, monitor.Close())
}
type myCloser struct {
sync.WaitGroup
}
func (c *myCloser) Close() error {
c.Done()
return nil
}