package trackerconn import ( "io" "net" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" N "github.com/sagernet/sing/common/network" "go.uber.org/atomic" ) func New(conn net.Conn, readCounter *atomic.Int64, writeCounter *atomic.Int64, direct bool) N.ExtendedConn { trackerConn := &Conn{bufio.NewExtendedConn(conn), readCounter, writeCounter} if direct { return (*DirectConn)(trackerConn) } else { return trackerConn } } func NewHook(conn net.Conn, readCounter func(n int64), writeCounter func(n int64), direct bool) N.ExtendedConn { trackerConn := &HookConn{bufio.NewExtendedConn(conn), readCounter, writeCounter} if direct { return (*DirectHookConn)(trackerConn) } else { return trackerConn } } type Conn struct { N.ExtendedConn readCounter *atomic.Int64 writeCounter *atomic.Int64 } func (c *Conn) Read(p []byte) (n int, err error) { n, err = c.ExtendedConn.Read(p) c.readCounter.Add(int64(n)) return n, err } func (c *Conn) ReadBuffer(buffer *buf.Buffer) error { err := c.ExtendedConn.ReadBuffer(buffer) if err != nil { return err } c.readCounter.Add(int64(buffer.Len())) return nil } func (c *Conn) Write(p []byte) (n int, err error) { n, err = c.ExtendedConn.Write(p) c.writeCounter.Add(int64(n)) return n, err } func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { dataLen := int64(buffer.Len()) err := c.ExtendedConn.WriteBuffer(buffer) if err != nil { return err } c.writeCounter.Add(dataLen) return nil } func (c *Conn) Upstream() any { return c.ExtendedConn } type HookConn struct { N.ExtendedConn readCounter func(n int64) writeCounter func(n int64) } func (c *HookConn) Read(p []byte) (n int, err error) { n, err = c.ExtendedConn.Read(p) c.readCounter(int64(n)) return n, err } func (c *HookConn) ReadBuffer(buffer *buf.Buffer) error { err := c.ExtendedConn.ReadBuffer(buffer) if err != nil { return err } c.readCounter(int64(buffer.Len())) return nil } func (c *HookConn) Write(p []byte) (n int, err error) { n, err = c.ExtendedConn.Write(p) c.writeCounter(int64(n)) return n, err } func (c *HookConn) WriteBuffer(buffer *buf.Buffer) error { dataLen := int64(buffer.Len()) err := c.ExtendedConn.WriteBuffer(buffer) if err != nil { return err } c.writeCounter(dataLen) return nil } func (c *HookConn) Upstream() any { return c.ExtendedConn } type DirectConn Conn func (c *DirectConn) WriteTo(w io.Writer) (n int64, err error) { reader := N.UnwrapReader(c.ExtendedConn) if wt, ok := reader.(io.WriterTo); ok { n, err = wt.WriteTo(w) c.readCounter.Add(n) return } else { return bufio.Copy(w, (*Conn)(c)) } } func (c *DirectConn) ReadFrom(r io.Reader) (n int64, err error) { writer := N.UnwrapWriter(c.ExtendedConn) if rt, ok := writer.(io.ReaderFrom); ok { n, err = rt.ReadFrom(r) c.writeCounter.Add(n) return } else { return bufio.Copy((*Conn)(c), r) } } type DirectHookConn HookConn func (c *DirectHookConn) WriteTo(w io.Writer) (n int64, err error) { reader := N.UnwrapReader(c.ExtendedConn) if wt, ok := reader.(io.WriterTo); ok { n, err = wt.WriteTo(w) c.readCounter(n) return } else { return bufio.Copy(w, (*HookConn)(c)) } } func (c *DirectHookConn) ReadFrom(r io.Reader) (n int64, err error) { writer := N.UnwrapWriter(c.ExtendedConn) if rt, ok := writer.(io.ReaderFrom); ok { n, err = rt.ReadFrom(r) c.writeCounter(n) return } else { return bufio.Copy((*HookConn)(c), r) } }