From 27d6b63e71cae71f8b8b5c8ba27f2aceb8ebbc3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 28 Aug 2024 11:40:43 +0800 Subject: [PATCH] Fix stream sniffer --- common/sniff/sniff.go | 49 ++++++++++++++++++++++++++----------------- route/router.go | 2 +- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go index 424311a8..3e21b69f 100644 --- a/common/sniff/sniff.go +++ b/common/sniff/sniff.go @@ -18,33 +18,44 @@ type ( PacketSniffer = func(ctx context.Context, packet []byte) (*adapter.InboundContext, error) ) +func Skip(metadata adapter.InboundContext) bool { + // skip server first protocols + switch metadata.Destination.Port { + case 25, 465, 587: + // SMTP + return true + case 143, 993: + // IMAP + return true + case 110, 995: + // POP3 + return true + } + return false +} + func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) (*adapter.InboundContext, error) { if timeout == 0 { timeout = C.ReadPayloadTimeout } deadline := time.Now().Add(timeout) var errors []error - - for i := 0; i < 3; i++ { - err := conn.SetReadDeadline(deadline) - if err != nil { - return nil, E.Cause(err, "set read deadline") + err := conn.SetReadDeadline(deadline) + if err != nil { + return nil, E.Cause(err, "set read deadline") + } + defer conn.SetReadDeadline(time.Time{}) + var metadata *adapter.InboundContext + for _, sniffer := range sniffers { + if buffer.IsEmpty() { + metadata, err = sniffer(ctx, io.TeeReader(conn, buffer)) + } else { + metadata, err = sniffer(ctx, io.MultiReader(bytes.NewReader(buffer.Bytes()), io.TeeReader(conn, buffer))) } - _, err = buffer.ReadOnceFrom(conn) - err = E.Errors(err, conn.SetReadDeadline(time.Time{})) - if err != nil { - if i > 0 { - break - } - return nil, E.Cause(err, "read payload") - } - for _, sniffer := range sniffers { - metadata, err := sniffer(ctx, bytes.NewReader(buffer.Bytes())) - if metadata != nil { - return metadata, nil - } - errors = append(errors, err) + if metadata != nil { + return metadata, nil } + errors = append(errors, err) } return nil, E.Errors(errors...) } diff --git a/route/router.go b/route/router.go index 0c13b198..5d89b118 100644 --- a/route/router.go +++ b/route/router.go @@ -832,7 +832,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad conn = deadline.NewConn(conn) } - if metadata.InboundOptions.SniffEnabled { + if metadata.InboundOptions.SniffEnabled && !sniff.Skip(metadata) { buffer := buf.NewPacket() sniffMetadata, err := sniff.PeekStream(ctx, conn, buffer, time.Duration(metadata.InboundOptions.SniffTimeout), sniff.StreamDomainNameQuery, sniff.TLSClientHello, sniff.HTTPHost) if sniffMetadata != nil {