From 862e3c430c3d32cdaf9163d9a67a121fe062eab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu> Date: Mon, 11 Jul 2022 20:37:57 +0800 Subject: [PATCH] Fix tcp sniffing --- common/sniff/sniff.go | 21 ++++++++++++++++++--- route/router.go | 15 +++++++-------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go index b3a91aac..6ed3d16e 100644 --- a/common/sniff/sniff.go +++ b/common/sniff/sniff.go @@ -1,11 +1,16 @@ package sniff import ( + "bytes" "context" "io" + "net" "os" + "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" ) type ( @@ -13,13 +18,23 @@ type ( PacketSniffer = func(ctx context.Context, packet []byte) (*adapter.InboundContext, error) ) -func PeekStream(ctx context.Context, reader io.Reader, sniffers ...StreamSniffer) (*adapter.InboundContext, error) { +func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, sniffers ...StreamSniffer) (*adapter.InboundContext, error) { + err := conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond)) + if err != nil { + return nil, err + } + _, err = buffer.ReadFrom(conn) + err = E.Errors(err, conn.SetReadDeadline(time.Time{})) + if err != nil { + return nil, err + } + var metadata *adapter.InboundContext for _, sniffer := range sniffers { - sniffMetadata, err := sniffer(ctx, reader) + metadata, err = sniffer(ctx, bytes.NewReader(buffer.Bytes())) if err != nil { continue } - return sniffMetadata, nil + return metadata, nil } return nil, os.ErrInvalid } diff --git a/route/router.go b/route/router.go index d34a4ea3..e54aa752 100644 --- a/route/router.go +++ b/route/router.go @@ -389,8 +389,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() - reader := io.TeeReader(conn, buffer) - sniffMetadata, err := sniff.PeekStream(ctx, reader, sniff.TLSClientHello, sniff.HTTPHost) + sniffMetadata, err := sniff.PeekStream(ctx, conn, buffer, sniff.TLSClientHello, sniff.HTTPHost) if err == nil { metadata.Protocol = sniffMetadata.Protocol metadata.Domain = sniffMetadata.Domain @@ -398,9 +397,9 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad metadata.Destination.Fqdn = metadata.Domain } if metadata.Domain != "" { - r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) + r.logger.WithContext(ctx).Debug("sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) } else { - r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol) + r.logger.WithContext(ctx).Debug("sniffed protocol: ", metadata.Protocol) } } if !buffer.IsEmpty() { @@ -413,7 +412,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad return err } metadata.DestinationAddresses = addresses - r.dnsLogger.WithContext(ctx).Info("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") + r.dnsLogger.WithContext(ctx).Debug("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") } detour := r.match(ctx, metadata, r.defaultOutboundForConnection) if !common.Contains(detour.Network(), C.NetworkTCP) { @@ -442,9 +441,9 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m metadata.Destination.Fqdn = metadata.Domain } if metadata.Domain != "" { - r.logger.WithContext(ctx).Info("sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) + r.logger.WithContext(ctx).Debug("sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain) } else { - r.logger.WithContext(ctx).Info("sniffed packet protocol: ", metadata.Protocol) + r.logger.WithContext(ctx).Debug("sniffed packet protocol: ", metadata.Protocol) } } conn = bufio.NewCachedPacketConn(conn, buffer, originDestination) @@ -455,7 +454,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m return err } metadata.DestinationAddresses = addresses - r.dnsLogger.WithContext(ctx).Info("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") + r.dnsLogger.WithContext(ctx).Debug("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]") } detour := r.match(ctx, metadata, r.defaultOutboundForPacketConnection) if !common.Contains(detour.Network(), C.NetworkUDP) {