Improve direct copy

This commit is contained in:
世界 2023-04-24 19:01:10 +08:00
parent bc32c78d03
commit c287731df9
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 114 additions and 197 deletions

View file

@ -7,9 +7,9 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/experimental/trackerconn"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
@ -115,13 +115,13 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad
download := new(atomic.Int64) download := new(atomic.Int64)
t := &tcpTracker{ t := &tcpTracker{
ExtendedConn: trackerconn.NewHook(conn, func(n int64) { ExtendedConn: bufio.NewCounterConn(conn, []N.CountFunc{func(n int64) {
upload.Add(n) upload.Add(n)
manager.PushUploaded(n) manager.PushUploaded(n)
}, func(n int64) { }}, []N.CountFunc{func(n int64) {
download.Add(n) download.Add(n)
manager.PushDownloaded(n) manager.PushDownloaded(n)
}), }}),
manager: manager, manager: manager,
trackerInfo: &trackerInfo{ trackerInfo: &trackerInfo{
UUID: uuid, UUID: uuid,
@ -202,13 +202,13 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route
download := new(atomic.Int64) download := new(atomic.Int64)
ut := &udpTracker{ ut := &udpTracker{
PacketConn: trackerconn.NewHookPacket(conn, func(n int64) { PacketConn: bufio.NewCounterPacketConn(conn, []N.CountFunc{func(n int64) {
upload.Add(n) upload.Add(n)
manager.PushUploaded(n) manager.PushUploaded(n)
}, func(n int64) { }}, []N.CountFunc{func(n int64) {
download.Add(n) download.Add(n)
manager.PushDownloaded(n) manager.PushDownloaded(n)
}), }}),
manager: manager, manager: manager,
trackerInfo: &trackerInfo{ trackerInfo: &trackerInfo{
UUID: uuid, UUID: uuid,

View file

@ -1,108 +0,0 @@
package trackerconn
import (
"net"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
func New(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *Conn {
return &Conn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
}
func NewHook(conn net.Conn, readCounter func(n int64), writeCounter func(n int64)) *HookConn {
return &HookConn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
}
type Conn struct {
N.ExtendedConn
readCounter []*atomic.Int64
writeCounter []*atomic.Int64
}
func (c *Conn) Read(p []byte) (n int, err error) {
n, err = c.ExtendedConn.Read(p)
for _, counter := range c.readCounter {
counter.Add(int64(n))
}
return n, err
}
func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
err := c.ExtendedConn.ReadBuffer(buffer)
if err != nil {
return err
}
for _, counter := range c.readCounter {
counter.Add(int64(buffer.Len()))
}
return nil
}
func (c *Conn) Write(p []byte) (n int, err error) {
n, err = c.ExtendedConn.Write(p)
for _, counter := range c.writeCounter {
counter.Add(int64(n))
}
return n, err
}
func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
dataLen := int64(buffer.Len())
err := c.ExtendedConn.WriteBuffer(buffer)
if err != nil {
return err
}
for _, counter := range c.writeCounter {
counter.Add(dataLen)
}
return nil
}
func (c *Conn) Upstream() any {
return c.ExtendedConn
}
type HookConn struct {
N.ExtendedConn
readCounter func(n int64)
writeCounter func(n int64)
}
func (c *HookConn) Read(p []byte) (n int, err error) {
n, err = c.ExtendedConn.Read(p)
c.readCounter(int64(n))
return n, err
}
func (c *HookConn) ReadBuffer(buffer *buf.Buffer) error {
err := c.ExtendedConn.ReadBuffer(buffer)
if err != nil {
return err
}
c.readCounter(int64(buffer.Len()))
return nil
}
func (c *HookConn) Write(p []byte) (n int, err error) {
n, err = c.ExtendedConn.Write(p)
c.writeCounter(int64(n))
return n, err
}
func (c *HookConn) WriteBuffer(buffer *buf.Buffer) error {
dataLen := int64(buffer.Len())
err := c.ExtendedConn.WriteBuffer(buffer)
if err != nil {
return err
}
c.writeCounter(dataLen)
return nil
}
func (c *HookConn) Upstream() any {
return c.ExtendedConn
}

View file

@ -1,76 +0,0 @@
package trackerconn
import (
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func NewPacket(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *PacketConn {
return &PacketConn{conn, readCounter, writeCounter}
}
func NewHookPacket(conn N.PacketConn, readCounter func(n int64), writeCounter func(n int64)) *HookPacketConn {
return &HookPacketConn{conn, readCounter, writeCounter}
}
type PacketConn struct {
N.PacketConn
readCounter []*atomic.Int64
writeCounter []*atomic.Int64
}
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
for _, counter := range c.readCounter {
counter.Add(int64(buffer.Len()))
}
}
return
}
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
dataLen := int64(buffer.Len())
err := c.PacketConn.WritePacket(buffer, destination)
if err != nil {
return err
}
for _, counter := range c.writeCounter {
counter.Add(dataLen)
}
return nil
}
func (c *PacketConn) Upstream() any {
return c.PacketConn
}
type HookPacketConn struct {
N.PacketConn
readCounter func(n int64)
writeCounter func(n int64)
}
func (c *HookPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.PacketConn.ReadPacket(buffer)
if err == nil {
c.readCounter(int64(buffer.Len()))
}
return
}
func (c *HookPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
dataLen := int64(buffer.Len())
err := c.PacketConn.WritePacket(buffer, destination)
if err != nil {
return err
}
c.writeCounter(dataLen)
return nil
}
func (c *HookPacketConn) Upstream() any {
return c.PacketConn
}

View file

@ -10,9 +10,9 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/experimental/trackerconn"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -83,7 +83,7 @@ func (s *StatsService) RoutedConnection(inbound string, outbound string, user st
writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink")) writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink"))
} }
s.access.Unlock() s.access.Unlock()
return trackerconn.New(conn, readCounter, writeCounter) return bufio.NewInt64CounterConn(conn, readCounter, writeCounter)
} }
func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn { func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn {
@ -109,7 +109,7 @@ func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, u
writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink")) writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink"))
} }
s.access.Unlock() s.access.Unlock()
return trackerconn.NewPacket(conn, readCounter, writeCounter) return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter)
} }
func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) { func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) {

2
go.mod
View file

@ -24,7 +24,7 @@ require (
github.com/sagernet/gomobile v0.0.0-20230413023804-244d7ff07035 github.com/sagernet/gomobile v0.0.0-20230413023804-244d7ff07035
github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32 github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754
github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc
github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646 github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646
github.com/sagernet/sing-shadowsocks v0.2.2-0.20230417102954-f77257340507 github.com/sagernet/sing-shadowsocks v0.2.2-0.20230417102954-f77257340507

4
go.sum
View file

@ -111,8 +111,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss= github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754 h1:y89Ntm1rrZPQVb1f+TKd4DH6NwX5XCyMIwoseTQd/5U=
github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc h1:hmbuqKv48SAjiKPoqtJGvS5pEHVPZjTHq9CPwQY2cZ4= github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc h1:hmbuqKv48SAjiKPoqtJGvS5pEHVPZjTHq9CPwQY2cZ4=
github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc/go.mod h1:ZKuuqgsHRxDahYrzgSgy4vIAGGuKPlIf4hLcNzYzLkY= github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc/go.mod h1:ZKuuqgsHRxDahYrzgSgy4vIAGGuKPlIf4hLcNzYzLkY=
github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646 h1:X3ADfMqeGns1Q1FlXc9kaL9FwW1UM6D6tEQo8jFstpc= github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646 h1:X3ADfMqeGns1Q1FlXc9kaL9FwW1UM6D6tEQo8jFstpc=

View file

@ -11,6 +11,7 @@ import (
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler" "github.com/sagernet/sing/common/canceler"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -101,6 +102,24 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
} }
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
var reader N.PacketReader = conn
var counters []N.CountFunc
var cachedBuffer []*N.PacketBuffer
for {
reader, counters = N.UnwrapCountPacketReader(reader, counters)
if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
packet := cachedReader.ReadCachedPacket()
if packet != nil {
cachedBuffer = append([]*N.PacketBuffer{packet}, cachedBuffer...)
continue
}
}
if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedBuffer, metadata)
}
break
}
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
fastClose, cancel := common.ContextWithCancelCause(ctx) fastClose, cancel := common.ContextWithCancelCause(ctx)
timeout := canceler.New(fastClose, cancel, C.DNSTimeout) timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
@ -153,3 +172,85 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
}) })
return group.Run(fastClose) return group.Run(fastClose)
} }
func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
ctx = adapter.WithContext(ctx, &metadata)
fastClose, cancel := common.ContextWithCancelCause(ctx)
timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
var group task.Group
group.Append0(func(ctx context.Context) error {
var buffer *buf.Buffer
newBuffer := func() *buf.Buffer {
if buffer != nil {
buffer.Release()
}
buffer = buf.NewSize(dns.FixedPacketSize)
buffer.FullReset()
return buffer
}
for {
var message mDNS.Msg
var destination M.Socksaddr
var err error
if len(cached) > 0 {
packet := cached[0]
cached = cached[1:]
for _, counter := range readCounters {
counter(int64(packet.Buffer.Len()))
}
err = message.Unpack(packet.Buffer.Bytes())
packet.Buffer.Release()
if err != nil {
cancel(err)
return err
}
destination = packet.Destination
} else {
destination, err = readWaiter.WaitReadPacket(newBuffer)
if err != nil {
if buffer != nil {
buffer.Release()
}
cancel(err)
return err
}
for _, counter := range readCounters {
counter(int64(buffer.Len()))
}
err = message.Unpack(buffer.Bytes())
buffer.Release()
if err != nil {
cancel(err)
return err
}
timeout.Update()
}
metadataInQuery := metadata
go func() error {
response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
if err != nil {
cancel(err)
return err
}
timeout.Update()
responseBuffer := buf.NewPacket()
n, err := response.PackBuffer(responseBuffer.FreeBytes())
if err != nil {
cancel(err)
responseBuffer.Release()
return err
}
responseBuffer.Truncate(len(n))
err = conn.WritePacket(responseBuffer, destination)
if err != nil {
cancel(err)
}
return err
}()
}
})
group.Cleanup(func() {
conn.Close()
})
return group.Run(fastClose)
}