From 862e3c430c3d32cdaf9163d9a67a121fe062eab6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Mon, 11 Jul 2022 20:37:57 +0800
Subject: [PATCH] Fix tcp sniffing

---
 common/sniff/sniff.go | 21 ++++++++++++++++++---
 route/router.go       | 15 +++++++--------
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go
index b3a91aac..6ed3d16e 100644
--- a/common/sniff/sniff.go
+++ b/common/sniff/sniff.go
@@ -1,11 +1,16 @@
 package sniff
 
 import (
+	"bytes"
 	"context"
 	"io"
+	"net"
 	"os"
+	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
 )
 
 type (
@@ -13,13 +18,23 @@ type (
 	PacketSniffer = func(ctx context.Context, packet []byte) (*adapter.InboundContext, error)
 )
 
-func PeekStream(ctx context.Context, reader io.Reader, sniffers ...StreamSniffer) (*adapter.InboundContext, error) {
+func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, sniffers ...StreamSniffer) (*adapter.InboundContext, error) {
+	err := conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond))
+	if err != nil {
+		return nil, err
+	}
+	_, err = buffer.ReadFrom(conn)
+	err = E.Errors(err, conn.SetReadDeadline(time.Time{}))
+	if err != nil {
+		return nil, err
+	}
+	var metadata *adapter.InboundContext
 	for _, sniffer := range sniffers {
-		sniffMetadata, err := sniffer(ctx, reader)
+		metadata, err = sniffer(ctx, bytes.NewReader(buffer.Bytes()))
 		if err != nil {
 			continue
 		}
-		return sniffMetadata, nil
+		return metadata, nil
 	}
 	return nil, os.ErrInvalid
 }
diff --git a/route/router.go b/route/router.go
index d34a4ea3..e54aa752 100644
--- a/route/router.go
+++ b/route/router.go
@@ -389,8 +389,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
 		defer common.KeepAlive(_buffer)
 		buffer := common.Dup(_buffer)
 		defer buffer.Release()
-		reader := io.TeeReader(conn, buffer)
-		sniffMetadata, err := sniff.PeekStream(ctx, reader, sniff.TLSClientHello, sniff.HTTPHost)
+		sniffMetadata, err := sniff.PeekStream(ctx, conn, buffer, sniff.TLSClientHello, sniff.HTTPHost)
 		if err == nil {
 			metadata.Protocol = sniffMetadata.Protocol
 			metadata.Domain = sniffMetadata.Domain
@@ -398,9 +397,9 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
 				metadata.Destination.Fqdn = metadata.Domain
 			}
 			if metadata.Domain != "" {
-				r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
+				r.logger.WithContext(ctx).Debug("sniffed protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
 			} else {
-				r.logger.WithContext(ctx).Info("sniffed protocol: ", metadata.Protocol)
+				r.logger.WithContext(ctx).Debug("sniffed protocol: ", metadata.Protocol)
 			}
 		}
 		if !buffer.IsEmpty() {
@@ -413,7 +412,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
 			return err
 		}
 		metadata.DestinationAddresses = addresses
-		r.dnsLogger.WithContext(ctx).Info("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
+		r.dnsLogger.WithContext(ctx).Debug("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
 	}
 	detour := r.match(ctx, metadata, r.defaultOutboundForConnection)
 	if !common.Contains(detour.Network(), C.NetworkTCP) {
@@ -442,9 +441,9 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 				metadata.Destination.Fqdn = metadata.Domain
 			}
 			if metadata.Domain != "" {
-				r.logger.WithContext(ctx).Info("sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
+				r.logger.WithContext(ctx).Debug("sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
 			} else {
-				r.logger.WithContext(ctx).Info("sniffed packet protocol: ", metadata.Protocol)
+				r.logger.WithContext(ctx).Debug("sniffed packet protocol: ", metadata.Protocol)
 			}
 		}
 		conn = bufio.NewCachedPacketConn(conn, buffer, originDestination)
@@ -455,7 +454,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 			return err
 		}
 		metadata.DestinationAddresses = addresses
-		r.dnsLogger.WithContext(ctx).Info("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
+		r.dnsLogger.WithContext(ctx).Debug("resolved [", strings.Join(F.MapToString(metadata.DestinationAddresses), " "), "]")
 	}
 	detour := r.match(ctx, metadata, r.defaultOutboundForPacketConnection)
 	if !common.Contains(detour.Network(), C.NetworkUDP) {