From c287731df97eea5667a3c30393cbef124ef63295 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 24 Apr 2023 19:01:10 +0800 Subject: [PATCH] Improve direct copy --- .../clashapi/trafficontrol/tracker.go | 14 +-- experimental/trackerconn/conn.go | 108 ------------------ experimental/trackerconn/packet_conn.go | 76 ------------ experimental/v2rayapi/stats.go | 6 +- go.mod | 2 +- go.sum | 4 +- outbound/dns.go | 101 ++++++++++++++++ 7 files changed, 114 insertions(+), 197 deletions(-) delete mode 100644 experimental/trackerconn/conn.go delete mode 100644 experimental/trackerconn/packet_conn.go diff --git a/experimental/clashapi/trafficontrol/tracker.go b/experimental/clashapi/trafficontrol/tracker.go index 97be411f..3dc5a367 100644 --- a/experimental/clashapi/trafficontrol/tracker.go +++ b/experimental/clashapi/trafficontrol/tracker.go @@ -7,9 +7,9 @@ import ( "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/experimental/trackerconn" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/bufio" N "github.com/sagernet/sing/common/network" "github.com/gofrs/uuid/v5" @@ -115,13 +115,13 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad download := new(atomic.Int64) t := &tcpTracker{ - ExtendedConn: trackerconn.NewHook(conn, func(n int64) { + ExtendedConn: bufio.NewCounterConn(conn, []N.CountFunc{func(n int64) { upload.Add(n) manager.PushUploaded(n) - }, func(n int64) { + }}, []N.CountFunc{func(n int64) { download.Add(n) manager.PushDownloaded(n) - }), + }}), manager: manager, trackerInfo: &trackerInfo{ UUID: uuid, @@ -202,13 +202,13 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route download := new(atomic.Int64) ut := &udpTracker{ - PacketConn: trackerconn.NewHookPacket(conn, func(n int64) { + PacketConn: bufio.NewCounterPacketConn(conn, []N.CountFunc{func(n int64) { upload.Add(n) manager.PushUploaded(n) - }, func(n int64) { + }}, []N.CountFunc{func(n int64) { download.Add(n) manager.PushDownloaded(n) - }), + }}), manager: manager, trackerInfo: &trackerInfo{ UUID: uuid, diff --git a/experimental/trackerconn/conn.go b/experimental/trackerconn/conn.go deleted file mode 100644 index f9f70e7c..00000000 --- a/experimental/trackerconn/conn.go +++ /dev/null @@ -1,108 +0,0 @@ -package trackerconn - -import ( - "net" - - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - N "github.com/sagernet/sing/common/network" -) - -func New(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *Conn { - return &Conn{bufio.NewExtendedConn(conn), readCounter, writeCounter} -} - -func NewHook(conn net.Conn, readCounter func(n int64), writeCounter func(n int64)) *HookConn { - return &HookConn{bufio.NewExtendedConn(conn), readCounter, writeCounter} -} - -type Conn struct { - N.ExtendedConn - readCounter []*atomic.Int64 - writeCounter []*atomic.Int64 -} - -func (c *Conn) Read(p []byte) (n int, err error) { - n, err = c.ExtendedConn.Read(p) - for _, counter := range c.readCounter { - counter.Add(int64(n)) - } - return n, err -} - -func (c *Conn) ReadBuffer(buffer *buf.Buffer) error { - err := c.ExtendedConn.ReadBuffer(buffer) - if err != nil { - return err - } - for _, counter := range c.readCounter { - counter.Add(int64(buffer.Len())) - } - return nil -} - -func (c *Conn) Write(p []byte) (n int, err error) { - n, err = c.ExtendedConn.Write(p) - for _, counter := range c.writeCounter { - counter.Add(int64(n)) - } - return n, err -} - -func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { - dataLen := int64(buffer.Len()) - err := c.ExtendedConn.WriteBuffer(buffer) - if err != nil { - return err - } - for _, counter := range c.writeCounter { - counter.Add(dataLen) - } - return nil -} - -func (c *Conn) Upstream() any { - return c.ExtendedConn -} - -type HookConn struct { - N.ExtendedConn - readCounter func(n int64) - writeCounter func(n int64) -} - -func (c *HookConn) Read(p []byte) (n int, err error) { - n, err = c.ExtendedConn.Read(p) - c.readCounter(int64(n)) - return n, err -} - -func (c *HookConn) ReadBuffer(buffer *buf.Buffer) error { - err := c.ExtendedConn.ReadBuffer(buffer) - if err != nil { - return err - } - c.readCounter(int64(buffer.Len())) - return nil -} - -func (c *HookConn) Write(p []byte) (n int, err error) { - n, err = c.ExtendedConn.Write(p) - c.writeCounter(int64(n)) - return n, err -} - -func (c *HookConn) WriteBuffer(buffer *buf.Buffer) error { - dataLen := int64(buffer.Len()) - err := c.ExtendedConn.WriteBuffer(buffer) - if err != nil { - return err - } - c.writeCounter(dataLen) - return nil -} - -func (c *HookConn) Upstream() any { - return c.ExtendedConn -} diff --git a/experimental/trackerconn/packet_conn.go b/experimental/trackerconn/packet_conn.go deleted file mode 100644 index 3a84c565..00000000 --- a/experimental/trackerconn/packet_conn.go +++ /dev/null @@ -1,76 +0,0 @@ -package trackerconn - -import ( - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/buf" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" -) - -func NewPacket(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *PacketConn { - return &PacketConn{conn, readCounter, writeCounter} -} - -func NewHookPacket(conn N.PacketConn, readCounter func(n int64), writeCounter func(n int64)) *HookPacketConn { - return &HookPacketConn{conn, readCounter, writeCounter} -} - -type PacketConn struct { - N.PacketConn - readCounter []*atomic.Int64 - writeCounter []*atomic.Int64 -} - -func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - destination, err = c.PacketConn.ReadPacket(buffer) - if err == nil { - for _, counter := range c.readCounter { - counter.Add(int64(buffer.Len())) - } - } - return -} - -func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - dataLen := int64(buffer.Len()) - err := c.PacketConn.WritePacket(buffer, destination) - if err != nil { - return err - } - for _, counter := range c.writeCounter { - counter.Add(dataLen) - } - return nil -} - -func (c *PacketConn) Upstream() any { - return c.PacketConn -} - -type HookPacketConn struct { - N.PacketConn - readCounter func(n int64) - writeCounter func(n int64) -} - -func (c *HookPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - destination, err = c.PacketConn.ReadPacket(buffer) - if err == nil { - c.readCounter(int64(buffer.Len())) - } - return -} - -func (c *HookPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - dataLen := int64(buffer.Len()) - err := c.PacketConn.WritePacket(buffer, destination) - if err != nil { - return err - } - c.writeCounter(dataLen) - return nil -} - -func (c *HookPacketConn) Upstream() any { - return c.PacketConn -} diff --git a/experimental/v2rayapi/stats.go b/experimental/v2rayapi/stats.go index dfb5d666..38b9a301 100644 --- a/experimental/v2rayapi/stats.go +++ b/experimental/v2rayapi/stats.go @@ -10,9 +10,9 @@ import ( "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/experimental/trackerconn" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" ) @@ -83,7 +83,7 @@ func (s *StatsService) RoutedConnection(inbound string, outbound string, user st writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink")) } s.access.Unlock() - return trackerconn.New(conn, readCounter, writeCounter) + return bufio.NewInt64CounterConn(conn, readCounter, writeCounter) } func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn { @@ -109,7 +109,7 @@ func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, u writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink")) } s.access.Unlock() - return trackerconn.NewPacket(conn, readCounter, writeCounter) + return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter) } func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) { diff --git a/go.mod b/go.mod index 2d1a96f3..744ffa87 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/sagernet/gomobile v0.0.0-20230413023804-244d7ff07035 github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 + github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754 github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646 github.com/sagernet/sing-shadowsocks v0.2.2-0.20230417102954-f77257340507 diff --git a/go.sum b/go.sum index 7f7ebc29..c4343ca1 100644 --- a/go.sum +++ b/go.sum @@ -111,8 +111,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss= -github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= +github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754 h1:y89Ntm1rrZPQVb1f+TKd4DH6NwX5XCyMIwoseTQd/5U= +github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc h1:hmbuqKv48SAjiKPoqtJGvS5pEHVPZjTHq9CPwQY2cZ4= github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc/go.mod h1:ZKuuqgsHRxDahYrzgSgy4vIAGGuKPlIf4hLcNzYzLkY= github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646 h1:X3ADfMqeGns1Q1FlXc9kaL9FwW1UM6D6tEQo8jFstpc= diff --git a/outbound/dns.go b/outbound/dns.go index 5af64173..ad9bec45 100644 --- a/outbound/dns.go +++ b/outbound/dns.go @@ -11,6 +11,7 @@ import ( "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -101,6 +102,24 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap } func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + var reader N.PacketReader = conn + var counters []N.CountFunc + var cachedBuffer []*N.PacketBuffer + for { + reader, counters = N.UnwrapCountPacketReader(reader, counters) + if cachedReader, isCached := reader.(N.CachedPacketReader); isCached { + packet := cachedReader.ReadCachedPacket() + if packet != nil { + cachedBuffer = append([]*N.PacketBuffer{packet}, cachedBuffer...) + continue + } + } + if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { + return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedBuffer, metadata) + } + break + } + ctx = adapter.WithContext(ctx, &metadata) fastClose, cancel := common.ContextWithCancelCause(ctx) timeout := canceler.New(fastClose, cancel, C.DNSTimeout) @@ -153,3 +172,85 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada }) return group.Run(fastClose) } + +func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + fastClose, cancel := common.ContextWithCancelCause(ctx) + timeout := canceler.New(fastClose, cancel, C.DNSTimeout) + var group task.Group + group.Append0(func(ctx context.Context) error { + var buffer *buf.Buffer + newBuffer := func() *buf.Buffer { + if buffer != nil { + buffer.Release() + } + buffer = buf.NewSize(dns.FixedPacketSize) + buffer.FullReset() + return buffer + } + for { + var message mDNS.Msg + var destination M.Socksaddr + var err error + if len(cached) > 0 { + packet := cached[0] + cached = cached[1:] + for _, counter := range readCounters { + counter(int64(packet.Buffer.Len())) + } + err = message.Unpack(packet.Buffer.Bytes()) + packet.Buffer.Release() + if err != nil { + cancel(err) + return err + } + destination = packet.Destination + } else { + destination, err = readWaiter.WaitReadPacket(newBuffer) + if err != nil { + if buffer != nil { + buffer.Release() + } + cancel(err) + return err + } + for _, counter := range readCounters { + counter(int64(buffer.Len())) + } + err = message.Unpack(buffer.Bytes()) + buffer.Release() + if err != nil { + cancel(err) + return err + } + timeout.Update() + } + metadataInQuery := metadata + go func() error { + response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) + if err != nil { + cancel(err) + return err + } + timeout.Update() + responseBuffer := buf.NewPacket() + n, err := response.PackBuffer(responseBuffer.FreeBytes()) + if err != nil { + cancel(err) + responseBuffer.Release() + return err + } + responseBuffer.Truncate(len(n)) + err = conn.WritePacket(responseBuffer, destination) + if err != nil { + cancel(err) + } + return err + }() + } + }) + group.Cleanup(func() { + conn.Close() + }) + return group.Run(fastClose) +}