From 62a7cbe47444bb139e39605aaf4a01852f7aa9b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 27 Nov 2024 18:08:19 +0800 Subject: [PATCH] Improve timeouts --- protocol/wireguard/endpoint.go | 8 +- route/conn.go | 272 ++++++++++++--------------------- route/conn_monitor.go | 128 ---------------- route/conn_monitor_test.go | 43 ------ 4 files changed, 105 insertions(+), 346 deletions(-) delete mode 100644 route/conn_monitor.go delete mode 100644 route/conn_monitor_test.go diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index dc40b613..937f84dd 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -56,12 +56,18 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } + var udpTimeout time.Duration + if options.UDPTimeout != 0 { + udpTimeout = time.Duration(options.UDPTimeout) + } else { + udpTimeout = C.UDPTimeout + } wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{ Context: ctx, Logger: logger, System: options.System, Handler: ep, - UDPTimeout: time.Duration(options.UDPTimeout), + UDPTimeout: udpTimeout, Dialer: outboundDialer, CreateDialer: func(interfaceName string) N.Dialer { return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{ diff --git a/route/conn.go b/route/conn.go index 4a2192e0..93ac33e3 100644 --- a/route/conn.go +++ b/route/conn.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/netip" + "sync" "sync/atomic" "time" @@ -18,31 +19,35 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" ) var _ adapter.ConnectionManager = (*ConnectionManager)(nil) type ConnectionManager struct { - logger logger.ContextLogger - monitor *ConnectionMonitor + logger logger.ContextLogger + access sync.Mutex + connections list.List[io.Closer] } func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { return &ConnectionManager{ - logger: logger, - monitor: NewConnectionMonitor(), + logger: logger, } } func (m *ConnectionManager) Start(stage adapter.StartStage) error { - if stage != adapter.StartStateInitialize { - return nil - } - return m.monitor.Start() + return nil } func (m *ConnectionManager) Close() error { - return m.monitor.Close() + m.access.Lock() + defer m.access.Unlock() + for element := m.connections.Front(); element != nil; element = element.Next() { + common.Close(element.Value) + } + m.connections.Init() + return nil } func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { @@ -57,95 +62,32 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } if err != nil { + err = E.Cause(err, "open outbound connection") N.CloseOnHandshakeFailure(conn, onClose, err) - m.logger.ErrorContext(ctx, "open outbound connection: ", err) + m.logger.ErrorContext(ctx, err) return } err = N.ReportConnHandshakeSuccess(conn, remoteConn) if err != nil { + err = E.Cause(err, "report handshake success") remoteConn.Close() N.CloseOnHandshakeFailure(conn, onClose, err) - m.logger.ErrorContext(ctx, "report handshake success: ", err) + m.logger.ErrorContext(ctx, err) return } + m.access.Lock() + element := m.connections.PushBack(conn) + m.access.Unlock() + onClose = N.AppendClose(onClose, func(it error) { + m.access.Lock() + defer m.access.Unlock() + m.connections.Remove(element) + }) var done atomic.Bool - if ctx.Done() != nil { - onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) - } go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) } -func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { - originSource := source - originDestination := destination - var readCounters, writeCounters []N.CountFunc - for { - source, readCounters = N.UnwrapCountReader(source, readCounters) - destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) - if cachedSrc, isCached := source.(N.CachedReader); isCached { - cachedBuffer := cachedSrc.ReadCached() - if cachedBuffer != nil { - dataLen := cachedBuffer.Len() - _, err := destination.Write(cachedBuffer.Bytes()) - cachedBuffer.Release() - if err != nil { - m.logger.ErrorContext(ctx, "connection upload payload: ", err) - if done.Swap(true) { - if onClose != nil { - onClose(err) - } - } - common.Close(originSource, originDestination) - return - } - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - } - continue - } - break - } - _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) - if err != nil { - common.Close(originSource, originDestination) - } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { - err = duplexDst.CloseWrite() - if err != nil { - common.Close(originSource, originDestination) - } - } else { - common.Close(originDestination) - } - if done.Swap(true) { - if onClose != nil { - onClose(err) - } - common.Close(originSource, originDestination) - } - if !direction { - if err == nil { - m.logger.DebugContext(ctx, "connection upload finished") - } else if !E.IsClosedOrCanceled(err) { - m.logger.ErrorContext(ctx, "connection upload closed: ", err) - } else { - m.logger.TraceContext(ctx, "connection upload closed") - } - } else { - if err == nil { - m.logger.DebugContext(ctx, "connection download finished") - } else if !E.IsClosedOrCanceled(err) { - m.logger.ErrorContext(ctx, "connection download closed: ", err) - } else { - m.logger.TraceContext(ctx, "connection download closed") - } - } -} - func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = adapter.WithContext(ctx, &metadata) var ( @@ -227,58 +169,91 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout) } destination := bufio.NewPacketConn(remotePacketConn) + m.access.Lock() + element := m.connections.PushBack(conn) + m.access.Unlock() + onClose = N.AppendClose(onClose, func(it error) { + m.access.Lock() + defer m.access.Unlock() + m.connections.Remove(element) + }) var done atomic.Bool - if ctx.Done() != nil { - onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) - } go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose) go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) } -func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { - _, err := bufio.CopyPacket(destination, source) - /*var readCounters, writeCounters []N.CountFunc - var cachedPackets []*N.PacketBuffer +func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { originSource := source + originDestination := destination + var readCounters, writeCounters []N.CountFunc for { - source, readCounters = N.UnwrapCountPacketReader(source, readCounters) - destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters) - if cachedReader, isCached := source.(N.CachedPacketReader); isCached { - packet := cachedReader.ReadCachedPacket() - if packet != nil { - cachedPackets = append(cachedPackets, packet) - continue + source, readCounters = N.UnwrapCountReader(source, readCounters) + destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) + if cachedSrc, isCached := source.(N.CachedReader); isCached { + cachedBuffer := cachedSrc.ReadCached() + if cachedBuffer != nil { + dataLen := cachedBuffer.Len() + _, err := destination.Write(cachedBuffer.Bytes()) + cachedBuffer.Release() + if err != nil { + if done.Swap(true) { + onClose(err) + } + common.Close(originSource, originDestination) + if !direction { + m.logger.ErrorContext(ctx, "connection upload payload: ", err) + } else { + m.logger.ErrorContext(ctx, "connection download payload: ", err) + } + return + } + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } } + continue } break } - var handled bool - if natConn, isNatConn := source.(udpnat.Conn); isNatConn { - natConn.SetHandler(&udpHijacker{ - ctx: ctx, - logger: m.logger, - source: natConn, - destination: destination, - direction: direction, - readCounters: readCounters, - writeCounters: writeCounters, - done: done, - onClose: onClose, - }) - handled = true - } - if cachedPackets != nil { - _, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters) + _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) + if err != nil { + common.Close(originDestination) + } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { + err = duplexDst.CloseWrite() if err != nil { - common.Close(source, destination) - m.logger.ErrorContext(ctx, "packet upload payload: ", err) - return + common.Close(originSource, originDestination) + } + } else { + common.Close(originDestination) + } + if done.Swap(true) { + onClose(err) + common.Close(originSource, originDestination) + } + if !direction { + if err == nil { + m.logger.DebugContext(ctx, "connection upload finished") + } else if !E.IsClosedOrCanceled(err) { + m.logger.ErrorContext(ctx, "connection upload closed: ", err) + } else { + m.logger.TraceContext(ctx, "connection upload closed") + } + } else { + if err == nil { + m.logger.DebugContext(ctx, "connection download finished") + } else if !E.IsClosedOrCanceled(err) { + m.logger.ErrorContext(ctx, "connection download closed: ", err) + } else { + m.logger.TraceContext(ctx, "connection download closed") } } - if handled { - return - } - _, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/ +} + +func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { + _, err := bufio.CopyPacket(destination, source) if !direction { if E.IsClosedOrCanceled(err) { m.logger.TraceContext(ctx, "packet upload closed") @@ -293,58 +268,7 @@ func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.P } } if !done.Swap(true) { - if onClose != nil { - onClose(err) - } + onClose(err) } common.Close(source, destination) } - -/*type udpHijacker struct { - ctx context.Context - logger logger.ContextLogger - source io.Closer - destination N.PacketWriter - direction bool - readCounters []N.CountFunc - writeCounters []N.CountFunc - done *atomic.Bool - onClose N.CloseHandlerFunc -} - -func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { - dataLen := buffer.Len() - for _, counter := range u.readCounters { - counter(int64(dataLen)) - } - err := u.destination.WritePacket(buffer, source) - if err != nil { - common.Close(u.source, u.destination) - u.logger.DebugContext(u.ctx, "packet upload closed: ", err) - return - } - for _, counter := range u.writeCounters { - counter(int64(dataLen)) - } -} - -func (u *udpHijacker) Close() error { - var err error - if !u.done.Swap(true) { - err = common.Close(u.source, u.destination) - if u.onClose != nil { - u.onClose(net.ErrClosed) - } - } - if u.direction { - u.logger.TraceContext(u.ctx, "packet download closed") - } else { - u.logger.TraceContext(u.ctx, "packet upload closed") - } - return err -} - -func (u *udpHijacker) Upstream() any { - return u.destination -} -*/ diff --git a/route/conn_monitor.go b/route/conn_monitor.go deleted file mode 100644 index 9e271b82..00000000 --- a/route/conn_monitor.go +++ /dev/null @@ -1,128 +0,0 @@ -package route - -import ( - "context" - "io" - "reflect" - "sync" - "time" - - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/x/list" -) - -type ConnectionMonitor struct { - access sync.RWMutex - reloadChan chan struct{} - connections list.List[*monitorEntry] -} - -type monitorEntry struct { - ctx context.Context - closer io.Closer -} - -func NewConnectionMonitor() *ConnectionMonitor { - return &ConnectionMonitor{ - reloadChan: make(chan struct{}, 1), - } -} - -func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc { - m.access.Lock() - defer m.access.Unlock() - element := m.connections.PushBack(&monitorEntry{ - ctx: ctx, - closer: closer, - }) - select { - case <-m.reloadChan: - return nil - default: - select { - case m.reloadChan <- struct{}{}: - default: - } - } - return func(it error) { - m.access.Lock() - defer m.access.Unlock() - m.connections.Remove(element) - select { - case <-m.reloadChan: - default: - select { - case m.reloadChan <- struct{}{}: - default: - } - } - } -} - -func (m *ConnectionMonitor) Start() error { - go m.monitor() - return nil -} - -func (m *ConnectionMonitor) Close() error { - m.access.Lock() - defer m.access.Unlock() - close(m.reloadChan) - for element := m.connections.Front(); element != nil; element = element.Next() { - element.Value.closer.Close() - } - return nil -} - -func (m *ConnectionMonitor) monitor() { - var ( - selectCases []reflect.SelectCase - elements []*list.Element[*monitorEntry] - ) - rootCase := reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(m.reloadChan), - } - for { - m.access.RLock() - if m.connections.Len() == 0 { - m.access.RUnlock() - if _, loaded := <-m.reloadChan; !loaded { - return - } else { - continue - } - } - if len(elements) < m.connections.Len() { - elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len()) - } - if len(selectCases) < m.connections.Len()+1 { - selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1) - } - elements = elements[:0] - selectCases = selectCases[:1] - selectCases[0] = rootCase - for element := m.connections.Front(); element != nil; element = element.Next() { - elements = append(elements, element) - selectCases = append(selectCases, reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(element.Value.ctx.Done()), - }) - } - m.access.RUnlock() - selected, _, loaded := reflect.Select(selectCases) - if selected == 0 { - if !loaded { - return - } else { - time.Sleep(time.Second) - continue - } - } - element := elements[selected-1] - m.access.Lock() - m.connections.Remove(element) - m.access.Unlock() - element.Value.closer.Close() // maybe go close - } -} diff --git a/route/conn_monitor_test.go b/route/conn_monitor_test.go deleted file mode 100644 index a712bddc..00000000 --- a/route/conn_monitor_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package route_test - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/sagernet/sing-box/route" - - "github.com/stretchr/testify/require" -) - -func TestMonitor(t *testing.T) { - t.Parallel() - var closer myCloser - closer.Add(1) - monitor := route.NewConnectionMonitor() - require.NoError(t, monitor.Start()) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - monitor.Add(ctx, &closer) - done := make(chan struct{}) - go func() { - closer.Wait() - close(done) - }() - select { - case <-done: - case <-time.After(time.Second + 100*time.Millisecond): - t.Fatal("timeout") - } - cancel() - require.NoError(t, monitor.Close()) -} - -type myCloser struct { - sync.WaitGroup -} - -func (c *myCloser) Close() error { - c.Done() - return nil -}