diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index 4f8fa88d..2d330919 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -38,6 +38,7 @@ type Server struct { trafficManager *trafficontrol.Manager urlTestHistory *urltest.HistoryStorage tcpListener net.Listener + directIO bool mode string storeSelected bool cacheFile adapter.ClashCacheFile @@ -55,6 +56,7 @@ func NewServer(router adapter.Router, logFactory log.ObservableFactory, options }, trafficManager: trafficManager, urlTestHistory: urltest.NewHistoryStorage(), + directIO: options.DirectIO, mode: strings.ToLower(options.DefaultMode), } if server.mode == "" { @@ -149,7 +151,7 @@ func (s *Server) HistoryStorage() *urltest.HistoryStorage { } func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) { - tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule) + tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule, s.directIO) return tracker, tracker } diff --git a/experimental/clashapi/trafficontrol/tracker.go b/experimental/clashapi/trafficontrol/tracker.go index 3155f340..000667b5 100644 --- a/experimental/clashapi/trafficontrol/tracker.go +++ b/experimental/clashapi/trafficontrol/tracker.go @@ -6,9 +6,8 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/experimental/trackerconn" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/gofrs/uuid" @@ -45,7 +44,7 @@ type trackerInfo struct { } type tcpTracker struct { - net.Conn `json:"-"` + N.ExtendedConn `json:"-"` *trackerInfo manager *Manager } @@ -54,25 +53,9 @@ func (tt *tcpTracker) ID() string { return tt.UUID.String() } -func (tt *tcpTracker) Read(b []byte) (int, error) { - n, err := tt.Conn.Read(b) - upload := int64(n) - tt.manager.PushUploaded(upload) - tt.UploadTotal.Add(upload) - return n, err -} - -func (tt *tcpTracker) Write(b []byte) (int, error) { - n, err := tt.Conn.Write(b) - download := int64(n) - tt.manager.PushDownloaded(download) - tt.DownloadTotal.Add(download) - return n, err -} - func (tt *tcpTracker) Close() error { tt.manager.Leave(tt) - return tt.Conn.Close() + return tt.ExtendedConn.Close() } func (tt *tcpTracker) Leave() { @@ -80,10 +63,18 @@ func (tt *tcpTracker) Leave() { } func (tt *tcpTracker) Upstream() any { - return tt.Conn + return tt.ExtendedConn } -func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router adapter.Router, rule adapter.Rule) *tcpTracker { +func (tt *tcpTracker) ReaderReplaceable() bool { + return true +} + +func (tt *tcpTracker) WriterReplaceable() bool { + return true +} + +func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router adapter.Router, rule adapter.Rule, directIO bool) *tcpTracker { uuid, _ := uuid.NewV4() var chain []string @@ -106,17 +97,20 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad next = group.Now() } + upload := atomic.NewInt64(0) + download := atomic.NewInt64(0) + t := &tcpTracker{ - Conn: conn, - manager: manager, + ExtendedConn: trackerconn.New(conn, upload, download, directIO), + manager: manager, trackerInfo: &trackerInfo{ UUID: uuid, Start: time.Now(), Metadata: metadata, Chain: common.Reverse(chain), Rule: "", - UploadTotal: atomic.NewInt64(0), - DownloadTotal: atomic.NewInt64(0), + UploadTotal: upload, + DownloadTotal: download, }, } @@ -140,27 +134,6 @@ func (ut *udpTracker) ID() string { return ut.UUID.String() } -func (ut *udpTracker) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - destination, err = ut.PacketConn.ReadPacket(buffer) - if err == nil { - upload := int64(buffer.Len()) - ut.manager.PushUploaded(upload) - ut.UploadTotal.Add(upload) - } - return -} - -func (ut *udpTracker) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - download := int64(buffer.Len()) - err := ut.PacketConn.WritePacket(buffer, destination) - if err != nil { - return err - } - ut.manager.PushDownloaded(download) - ut.DownloadTotal.Add(download) - return nil -} - func (ut *udpTracker) Close() error { ut.manager.Leave(ut) return ut.PacketConn.Close() @@ -174,6 +147,14 @@ func (ut *udpTracker) Upstream() any { return ut.PacketConn } +func (ut *udpTracker) ReaderReplaceable() bool { + return true +} + +func (ut *udpTracker) WriterReplaceable() bool { + return true +} + func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, router adapter.Router, rule adapter.Rule) *udpTracker { uuid, _ := uuid.NewV4() @@ -197,8 +178,11 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route next = group.Now() } + upload := atomic.NewInt64(0) + download := atomic.NewInt64(0) + ut := &udpTracker{ - PacketConn: conn, + PacketConn: trackerconn.NewPacket(conn, upload, download), manager: manager, trackerInfo: &trackerInfo{ UUID: uuid, @@ -206,8 +190,8 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route Metadata: metadata, Chain: common.Reverse(chain), Rule: "", - UploadTotal: atomic.NewInt64(0), - DownloadTotal: atomic.NewInt64(0), + UploadTotal: upload, + DownloadTotal: download, }, } diff --git a/experimental/trackerconn/conn.go b/experimental/trackerconn/conn.go new file mode 100644 index 00000000..56c029ca --- /dev/null +++ b/experimental/trackerconn/conn.go @@ -0,0 +1,82 @@ +package trackerconn + +import ( + "io" + "net" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + N "github.com/sagernet/sing/common/network" + + "go.uber.org/atomic" +) + +func New(conn net.Conn, readCounter *atomic.Int64, writeCounter *atomic.Int64, direct bool) N.ExtendedConn { + trackerConn := &Conn{bufio.NewExtendedConn(conn), readCounter, writeCounter} + if direct { + return (*DirectConn)(trackerConn) + } else { + return trackerConn + } +} + +type Conn struct { + N.ExtendedConn + readCounter *atomic.Int64 + writeCounter *atomic.Int64 +} + +func (c *Conn) Read(p []byte) (n int, err error) { + n, err = c.ExtendedConn.Read(p) + c.readCounter.Add(int64(n)) + return n, err +} + +func (c *Conn) ReadBuffer(buffer *buf.Buffer) error { + err := c.ExtendedConn.ReadBuffer(buffer) + if err != nil { + return err + } + c.readCounter.Add(int64(buffer.Len())) + return nil +} + +func (c *Conn) Write(p []byte) (n int, err error) { + n, err = c.ExtendedConn.Write(p) + c.writeCounter.Add(int64(n)) + return n, err +} + +func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { + dataLen := int64(buffer.Len()) + err := c.ExtendedConn.WriteBuffer(buffer) + if err != nil { + return err + } + c.writeCounter.Add(dataLen) + return nil +} + +type DirectConn Conn + +func (c *DirectConn) WriteTo(w io.Writer) (n int64, err error) { + reader := N.UnwrapReader(c.ExtendedConn) + if wt, ok := reader.(io.WriterTo); ok { + n, err = wt.WriteTo(w) + c.readCounter.Add(n) + return + } else { + return bufio.Copy(w, (*Conn)(c)) + } +} + +func (c *DirectConn) ReadFrom(r io.Reader) (n int64, err error) { + writer := N.UnwrapWriter(c.ExtendedConn) + if rt, ok := writer.(io.ReaderFrom); ok { + n, err = rt.ReadFrom(r) + c.writeCounter.Add(n) + return + } else { + return bufio.Copy((*Conn)(c), r) + } +} diff --git a/experimental/trackerconn/packet_conn.go b/experimental/trackerconn/packet_conn.go new file mode 100644 index 00000000..5d9e4164 --- /dev/null +++ b/experimental/trackerconn/packet_conn.go @@ -0,0 +1,37 @@ +package trackerconn + +import ( + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "go.uber.org/atomic" +) + +type PacketConn struct { + N.PacketConn + readCounter *atomic.Int64 + writeCounter *atomic.Int64 +} + +func NewPacket(conn N.PacketConn, readCounter *atomic.Int64, writeCounter *atomic.Int64) *PacketConn { + return &PacketConn{conn, readCounter, writeCounter} +} + +func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.PacketConn.ReadPacket(buffer) + if err == nil { + c.readCounter.Add(int64(buffer.Len())) + } + return +} + +func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + dataLen := int64(buffer.Len()) + err := c.PacketConn.WritePacket(buffer, destination) + if err != nil { + return err + } + c.writeCounter.Add(dataLen) + return nil +} diff --git a/option/clash.go b/option/clash.go index df4f4978..a473e21a 100644 --- a/option/clash.go +++ b/option/clash.go @@ -5,6 +5,7 @@ type ClashAPIOptions struct { ExternalUI string `json:"external_ui,omitempty"` Secret string `json:"secret,omitempty"` + DirectIO bool `json:"direct_io,omitempty"` DefaultMode string `json:"default_mode,omitempty"` StoreSelected bool `json:"store_selected,omitempty"` CacheFile string `json:"cache_file,omitempty"`