diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index d189ec6f..207facf6 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -31,9 +31,7 @@ import ( ) func init() { - experimental.RegisterClashServerConstructor(func(router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { - return NewServer(router, logFactory, options) - }) + experimental.RegisterClashServerConstructor(NewServer) } var _ adapter.ClashServer = (*Server)(nil) @@ -51,7 +49,7 @@ type Server struct { cacheFile adapter.ClashCacheFile } -func NewServer(router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (*Server, error) { +func NewServer(router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { trafficManager := trafficontrol.NewManager() chiRouter := chi.NewRouter() server := &Server{ diff --git a/experimental/clashapi/trafficontrol/tracker.go b/experimental/clashapi/trafficontrol/tracker.go index 000667b5..d12f508d 100644 --- a/experimental/clashapi/trafficontrol/tracker.go +++ b/experimental/clashapi/trafficontrol/tracker.go @@ -101,8 +101,14 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad download := atomic.NewInt64(0) t := &tcpTracker{ - ExtendedConn: trackerconn.New(conn, upload, download, directIO), - manager: manager, + ExtendedConn: trackerconn.NewHook(conn, func(n int64) { + upload.Add(n) + manager.PushUploaded(n) + }, func(n int64) { + download.Add(n) + manager.PushDownloaded(n) + }, directIO), + manager: manager, trackerInfo: &trackerInfo{ UUID: uuid, Start: time.Now(), @@ -182,8 +188,14 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route download := atomic.NewInt64(0) ut := &udpTracker{ - PacketConn: trackerconn.NewPacket(conn, upload, download), - manager: manager, + PacketConn: trackerconn.NewHookPacket(conn, func(n int64) { + upload.Add(n) + manager.PushUploaded(n) + }, func(n int64) { + download.Add(n) + manager.PushDownloaded(n) + }), + manager: manager, trackerInfo: &trackerInfo{ UUID: uuid, Start: time.Now(), diff --git a/experimental/trackerconn/conn.go b/experimental/trackerconn/conn.go index dd54b6e7..8c04b2d4 100644 --- a/experimental/trackerconn/conn.go +++ b/experimental/trackerconn/conn.go @@ -20,6 +20,15 @@ func New(conn net.Conn, readCounter *atomic.Int64, writeCounter *atomic.Int64, d } } +func NewHook(conn net.Conn, readCounter func(n int64), writeCounter func(n int64), direct bool) N.ExtendedConn { + trackerConn := &HookConn{bufio.NewExtendedConn(conn), readCounter, writeCounter} + if direct { + return (*DirectHookConn)(trackerConn) + } else { + return trackerConn + } +} + type Conn struct { N.ExtendedConn readCounter *atomic.Int64 @@ -61,6 +70,47 @@ func (c *Conn) Upstream() any { return c.ExtendedConn } +type HookConn struct { + N.ExtendedConn + readCounter func(n int64) + writeCounter func(n int64) +} + +func (c *HookConn) Read(p []byte) (n int, err error) { + n, err = c.ExtendedConn.Read(p) + c.readCounter(int64(n)) + return n, err +} + +func (c *HookConn) ReadBuffer(buffer *buf.Buffer) error { + err := c.ExtendedConn.ReadBuffer(buffer) + if err != nil { + return err + } + c.readCounter(int64(buffer.Len())) + return nil +} + +func (c *HookConn) Write(p []byte) (n int, err error) { + n, err = c.ExtendedConn.Write(p) + c.writeCounter(int64(n)) + return n, err +} + +func (c *HookConn) WriteBuffer(buffer *buf.Buffer) error { + dataLen := int64(buffer.Len()) + err := c.ExtendedConn.WriteBuffer(buffer) + if err != nil { + return err + } + c.writeCounter(dataLen) + return nil +} + +func (c *HookConn) Upstream() any { + return c.ExtendedConn +} + type DirectConn Conn func (c *DirectConn) WriteTo(w io.Writer) (n int64, err error) { @@ -84,3 +134,27 @@ func (c *DirectConn) ReadFrom(r io.Reader) (n int64, err error) { return bufio.Copy((*Conn)(c), r) } } + +type DirectHookConn HookConn + +func (c *DirectHookConn) WriteTo(w io.Writer) (n int64, err error) { + reader := N.UnwrapReader(c.ExtendedConn) + if wt, ok := reader.(io.WriterTo); ok { + n, err = wt.WriteTo(w) + c.readCounter(n) + return + } else { + return bufio.Copy(w, (*HookConn)(c)) + } +} + +func (c *DirectHookConn) ReadFrom(r io.Reader) (n int64, err error) { + writer := N.UnwrapWriter(c.ExtendedConn) + if rt, ok := writer.(io.ReaderFrom); ok { + n, err = rt.ReadFrom(r) + c.writeCounter(n) + return + } else { + return bufio.Copy((*HookConn)(c), r) + } +} diff --git a/experimental/trackerconn/packet_conn.go b/experimental/trackerconn/packet_conn.go index e1da2413..3aaccb84 100644 --- a/experimental/trackerconn/packet_conn.go +++ b/experimental/trackerconn/packet_conn.go @@ -8,16 +8,20 @@ import ( "go.uber.org/atomic" ) +func NewPacket(conn N.PacketConn, readCounter *atomic.Int64, writeCounter *atomic.Int64) *PacketConn { + return &PacketConn{conn, readCounter, writeCounter} +} + +func NewHookPacket(conn N.PacketConn, readCounter func(n int64), writeCounter func(n int64)) *HookPacketConn { + return &HookPacketConn{conn, readCounter, writeCounter} +} + type PacketConn struct { N.PacketConn readCounter *atomic.Int64 writeCounter *atomic.Int64 } -func NewPacket(conn N.PacketConn, readCounter *atomic.Int64, writeCounter *atomic.Int64) *PacketConn { - return &PacketConn{conn, readCounter, writeCounter} -} - func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { destination, err = c.PacketConn.ReadPacket(buffer) if err == nil { @@ -39,3 +43,31 @@ func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) er func (c *PacketConn) Upstream() any { return c.PacketConn } + +type HookPacketConn struct { + N.PacketConn + readCounter func(n int64) + writeCounter func(n int64) +} + +func (c *HookPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.PacketConn.ReadPacket(buffer) + if err == nil { + c.readCounter(int64(buffer.Len())) + } + return +} + +func (c *HookPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + dataLen := int64(buffer.Len()) + err := c.PacketConn.WritePacket(buffer, destination) + if err != nil { + return err + } + c.writeCounter(dataLen) + return nil +} + +func (c *HookPacketConn) Upstream() any { + return c.PacketConn +}