From 2bd91baad0b297561d962fb4b315f8139eba678f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 13 Feb 2023 05:53:03 +0800 Subject: [PATCH] Add fallback support for v2ray transport --- adapter/v2ray.go | 10 +++++ inbound/trojan.go | 20 +++++++++- inbound/vmess.go | 17 +++++++- include/quic_stub.go | 3 +- transport/v2ray/grpc.go | 5 +-- transport/v2ray/grpc_lite.go | 5 +-- transport/v2ray/quic.go | 5 +-- transport/v2ray/transport.go | 12 +++--- transport/v2raygrpclite/server.go | 33 ++++++++-------- transport/v2rayhttp/conn.go | 2 +- transport/v2rayhttp/server.go | 62 +++++++++++++++--------------- transport/v2rayquic/server.go | 13 +++---- transport/v2raywebsocket/server.go | 30 ++++++++------- 13 files changed, 132 insertions(+), 85 deletions(-) diff --git a/adapter/v2ray.go b/adapter/v2ray.go index 724c4935..df6372a1 100644 --- a/adapter/v2ray.go +++ b/adapter/v2ray.go @@ -3,6 +3,10 @@ package adapter import ( "context" "net" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type V2RayServerTransport interface { @@ -12,6 +16,12 @@ type V2RayServerTransport interface { Close() error } +type V2RayServerTransportHandler interface { + N.TCPConnectionHandler + E.Handler + FallbackConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error +} + type V2RayClientTransport interface { DialContext(ctx context.Context) (net.Conn, error) } diff --git a/inbound/trojan.go b/inbound/trojan.go index ac639d10..84ef9bcf 100644 --- a/inbound/trojan.go +++ b/inbound/trojan.go @@ -89,7 +89,7 @@ func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLog return nil, err } if options.Transport != nil { - inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newTransportConnection, nil, nil), inbound) + inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*trojanTransportHandler)(inbound)) if err != nil { return nil, E.Cause(err, "create server transport: ", options.Transport.Type) } @@ -216,3 +216,21 @@ func (h *Trojan) newPacketConnection(ctx context.Context, conn N.PacketConn, met h.logger.InfoContext(ctx, "[", user, "] inbound packet connection to ", metadata.Destination) return h.router.RoutePacketConnection(ctx, conn, metadata) } + +var _ adapter.V2RayServerTransportHandler = (*trojanTransportHandler)(nil) + +type trojanTransportHandler Trojan + +func (t *trojanTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + return (*Trojan)(t).newTransportConnection(ctx, conn, adapter.InboundContext{ + Source: metadata.Source, + Destination: metadata.Destination, + }) +} + +func (t *trojanTransportHandler) FallbackConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + return (*Trojan)(t).fallbackConnection(ctx, conn, adapter.InboundContext{ + Source: metadata.Source, + Destination: metadata.Destination, + }) +} diff --git a/inbound/vmess.go b/inbound/vmess.go index 6c57c483..3ce73062 100644 --- a/inbound/vmess.go +++ b/inbound/vmess.go @@ -72,7 +72,7 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg } } if options.Transport != nil { - inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, adapter.NewUpstreamHandler(adapter.InboundContext{}, inbound.newTransportConnection, nil, nil), inbound) + inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*vmessTransportHandler)(inbound)) if err != nil { return nil, E.Cause(err, "create server transport: ", options.Transport.Type) } @@ -183,3 +183,18 @@ func (h *VMess) newPacketConnection(ctx context.Context, conn N.PacketConn, meta } return h.router.RoutePacketConnection(ctx, conn, metadata) } + +var _ adapter.V2RayServerTransportHandler = (*vmessTransportHandler)(nil) + +type vmessTransportHandler VMess + +func (t *vmessTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + return (*VMess)(t).newTransportConnection(ctx, conn, adapter.InboundContext{ + Source: metadata.Source, + Destination: metadata.Destination, + }) +} + +func (t *vmessTransportHandler) FallbackConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + return os.ErrInvalid +} diff --git a/include/quic_stub.go b/include/quic_stub.go index a914f45f..18c49b48 100644 --- a/include/quic_stub.go +++ b/include/quic_stub.go @@ -11,7 +11,6 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/v2ray" "github.com/sagernet/sing-dns" - E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -24,7 +23,7 @@ func init() { return nil, C.ErrQUICNotIncluded }) v2ray.RegisterQUICConstructor( - func(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { + func(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) { return nil, C.ErrQUICNotIncluded }, func(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayQUICOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) { diff --git a/transport/v2ray/grpc.go b/transport/v2ray/grpc.go index 64865c2c..05bc5a2a 100644 --- a/transport/v2ray/grpc.go +++ b/transport/v2ray/grpc.go @@ -10,14 +10,13 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/v2raygrpc" "github.com/sagernet/sing-box/transport/v2raygrpclite" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { +func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) { if options.ForceLite { - return v2raygrpclite.NewServer(ctx, options, tlsConfig, handler, errorHandler) + return v2raygrpclite.NewServer(ctx, options, tlsConfig, handler) } return v2raygrpc.NewServer(ctx, options, tlsConfig, handler) } diff --git a/transport/v2ray/grpc_lite.go b/transport/v2ray/grpc_lite.go index 589e9fbe..94f6fad1 100644 --- a/transport/v2ray/grpc_lite.go +++ b/transport/v2ray/grpc_lite.go @@ -9,13 +9,12 @@ import ( "github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/v2raygrpclite" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { - return v2raygrpclite.NewServer(ctx, options, tlsConfig, handler, errorHandler) +func NewGRPCServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) { + return v2raygrpclite.NewServer(ctx, options, tlsConfig, handler) } func NewGRPCClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) { diff --git a/transport/v2ray/quic.go b/transport/v2ray/quic.go index 02fcd974..5471157a 100644 --- a/transport/v2ray/quic.go +++ b/transport/v2ray/quic.go @@ -7,7 +7,6 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/tls" "github.com/sagernet/sing-box/option" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -22,11 +21,11 @@ func RegisterQUICConstructor(server ServerConstructor[option.V2RayQUICOptions], quicClientConstructor = client } -func NewQUICServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { +func NewQUICServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) { if quicServerConstructor == nil { return nil, os.ErrInvalid } - return quicServerConstructor(ctx, options, tlsConfig, handler, errorHandler) + return quicServerConstructor(ctx, options, tlsConfig, handler) } func NewQUICClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayQUICOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) { diff --git a/transport/v2ray/transport.go b/transport/v2ray/transport.go index 28a98d52..649999a8 100644 --- a/transport/v2ray/transport.go +++ b/transport/v2ray/transport.go @@ -15,26 +15,26 @@ import ( ) type ( - ServerConstructor[O any] func(ctx context.Context, options O, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) + ServerConstructor[O any] func(ctx context.Context, options O, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) ClientConstructor[O any] func(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options O, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) ) -func NewServerTransport(ctx context.Context, options option.V2RayTransportOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { +func NewServerTransport(ctx context.Context, options option.V2RayTransportOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) { if options.Type == "" { return nil, nil } switch options.Type { case C.V2RayTransportTypeHTTP: - return v2rayhttp.NewServer(ctx, options.HTTPOptions, tlsConfig, handler, errorHandler) + return v2rayhttp.NewServer(ctx, options.HTTPOptions, tlsConfig, handler) case C.V2RayTransportTypeWebsocket: - return v2raywebsocket.NewServer(ctx, options.WebsocketOptions, tlsConfig, handler, errorHandler) + return v2raywebsocket.NewServer(ctx, options.WebsocketOptions, tlsConfig, handler) case C.V2RayTransportTypeQUIC: if tlsConfig == nil { return nil, C.ErrTLSRequired } - return NewQUICServer(ctx, options.QUICOptions, tlsConfig, handler, errorHandler) + return NewQUICServer(ctx, options.QUICOptions, tlsConfig, handler) case C.V2RayTransportTypeGRPC: - return NewGRPCServer(ctx, options.GRPCOptions, tlsConfig, handler, errorHandler) + return NewGRPCServer(ctx, options.GRPCOptions, tlsConfig, handler) default: return nil, E.New("unknown transport type: " + options.Type) } diff --git a/transport/v2raygrpclite/server.go b/transport/v2raygrpclite/server.go index b45c690b..02ca0ae9 100644 --- a/transport/v2raygrpclite/server.go +++ b/transport/v2raygrpclite/server.go @@ -26,7 +26,7 @@ import ( var _ adapter.V2RayServerTransport = (*Server)(nil) type Server struct { - handler N.TCPConnectionHandler + handler adapter.V2RayServerTransportHandler errorHandler E.Handler httpServer *http.Server h2Server *http2.Server @@ -38,12 +38,11 @@ func (s *Server) Network() []string { return []string{N.NetworkTCP} } -func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (*Server, error) { +func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { server := &Server{ - handler: handler, - errorHandler: errorHandler, - path: fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)), - h2Server: new(http2.Server), + handler: handler, + path: fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)), + h2Server: new(http2.Server), } server.httpServer = &http.Server{ Handler: server, @@ -68,19 +67,15 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { return } if request.URL.Path != s.path { - request.Write(os.Stdout) - writer.WriteHeader(http.StatusNotFound) - s.badRequest(request, E.New("bad path: ", request.URL.Path)) + s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) return } if request.Method != http.MethodPost { - writer.WriteHeader(http.StatusNotFound) - s.badRequest(request, E.New("bad method: ", request.Method)) + s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad method: ", request.Method)) return } if ct := request.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/grpc") { - writer.WriteHeader(http.StatusNotFound) - s.badRequest(request, E.New("bad content type: ", ct)) + s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad content type: ", ct)) return } writer.Header().Set("Content-Type", "application/grpc") @@ -93,8 +88,16 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { conn.CloseWrapper() } -func (s *Server) badRequest(request *http.Request, err error) { - s.errorHandler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) +func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) { + conn := v2rayhttp.NewHTTPConn(request.Body, writer) + fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{}) + if fErr == nil { + return + } else if fErr == os.ErrInvalid { + fErr = nil + } + writer.WriteHeader(statusCode) + s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr)) } func (s *Server) Serve(listener net.Listener) error { diff --git a/transport/v2rayhttp/conn.go b/transport/v2rayhttp/conn.go index 89d8baa8..1a22331b 100644 --- a/transport/v2rayhttp/conn.go +++ b/transport/v2rayhttp/conn.go @@ -22,7 +22,7 @@ type HTTPConn struct { err error } -func newHTTPConn(reader io.Reader, writer io.Writer) HTTPConn { +func NewHTTPConn(reader io.Reader, writer io.Writer) HTTPConn { return HTTPConn{ reader: reader, writer: writer, diff --git a/transport/v2rayhttp/server.go b/transport/v2rayhttp/server.go index 1777ca80..25c4e4be 100644 --- a/transport/v2rayhttp/server.go +++ b/transport/v2rayhttp/server.go @@ -24,32 +24,30 @@ import ( var _ adapter.V2RayServerTransport = (*Server)(nil) type Server struct { - ctx context.Context - handler N.TCPConnectionHandler - errorHandler E.Handler - httpServer *http.Server - h2Server *http2.Server - h2cHandler http.Handler - host []string - path string - method string - headers http.Header + ctx context.Context + handler adapter.V2RayServerTransportHandler + httpServer *http.Server + h2Server *http2.Server + h2cHandler http.Handler + host []string + path string + method string + headers http.Header } func (s *Server) Network() []string { return []string{N.NetworkTCP} } -func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (*Server, error) { +func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { server := &Server{ - ctx: ctx, - handler: handler, - errorHandler: errorHandler, - h2Server: new(http2.Server), - host: options.Host, - path: options.Path, - method: options.Method, - headers: make(http.Header), + ctx: ctx, + handler: handler, + h2Server: new(http2.Server), + host: options.Host, + path: options.Path, + method: options.Method, + headers: make(http.Header), } if server.method == "" { server.method = "PUT" @@ -83,18 +81,15 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } host := request.Host if len(s.host) > 0 && !common.Contains(s.host, host) { - writer.WriteHeader(http.StatusBadRequest) - s.badRequest(request, E.New("bad host: ", host)) + s.fallbackRequest(request.Context(), writer, request, http.StatusBadRequest, E.New("bad host: ", host)) return } if !strings.HasPrefix(request.URL.Path, s.path) { - writer.WriteHeader(http.StatusNotFound) - s.badRequest(request, E.New("bad path: ", request.URL.Path)) + s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) return } if request.Method != s.method { - writer.WriteHeader(http.StatusNotFound) - s.badRequest(request, E.New("bad method: ", request.Method)) + s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad method: ", request.Method)) return } @@ -114,14 +109,13 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if h, ok := writer.(http.Hijacker); ok { conn, _, err := h.Hijack() if err != nil { - writer.WriteHeader(http.StatusInternalServerError) - s.badRequest(request, E.Cause(err, "hijack conn")) + s.fallbackRequest(request.Context(), writer, request, http.StatusInternalServerError, E.Cause(err, "hijack conn")) return } s.handler.NewConnection(request.Context(), conn, metadata) } else { conn := NewHTTP2Wrapper(&ServerHTTPConn{ - newHTTPConn(request.Body, writer), + NewHTTPConn(request.Body, writer), writer.(http.Flusher), }) s.handler.NewConnection(request.Context(), conn, metadata) @@ -129,8 +123,16 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } } -func (s *Server) badRequest(request *http.Request, err error) { - s.errorHandler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) +func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) { + conn := NewHTTPConn(request.Body, writer) + fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{}) + if fErr == nil { + return + } else if fErr == os.ErrInvalid { + fErr = nil + } + writer.WriteHeader(statusCode) + s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr)) } func (s *Server) Serve(listener net.Listener) error { diff --git a/transport/v2rayquic/server.go b/transport/v2rayquic/server.go index e9e635b4..8375cef2 100644 --- a/transport/v2rayquic/server.go +++ b/transport/v2rayquic/server.go @@ -23,13 +23,13 @@ type Server struct { ctx context.Context tlsConfig *tls.STDConfig quicConfig *quic.Config - handler N.TCPConnectionHandler + handler adapter.V2RayServerTransportHandler errorHandler E.Handler udpListener net.PacketConn quicListener quic.Listener } -func NewServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (adapter.V2RayServerTransport, error) { +func NewServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (adapter.V2RayServerTransport, error) { quicConfig := &quic.Config{ DisablePathMTUDiscovery: !C.IsLinux && !C.IsWindows, } @@ -41,11 +41,10 @@ func NewServer(ctx context.Context, options option.V2RayQUICOptions, tlsConfig t stdConfig.NextProtos = []string{"h2", "http/1.1"} } server := &Server{ - ctx: ctx, - tlsConfig: stdConfig, - quicConfig: quicConfig, - handler: handler, - errorHandler: errorHandler, + ctx: ctx, + tlsConfig: stdConfig, + quicConfig: quicConfig, + handler: handler, } return server, nil } diff --git a/transport/v2raywebsocket/server.go b/transport/v2raywebsocket/server.go index 6f1c5146..8fa38f2b 100644 --- a/transport/v2raywebsocket/server.go +++ b/transport/v2raywebsocket/server.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/v2rayhttp" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -26,19 +27,17 @@ var _ adapter.V2RayServerTransport = (*Server)(nil) type Server struct { ctx context.Context - handler N.TCPConnectionHandler - errorHandler E.Handler + handler adapter.V2RayServerTransportHandler httpServer *http.Server path string maxEarlyData uint32 earlyDataHeaderName string } -func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig tls.ServerConfig, handler N.TCPConnectionHandler, errorHandler E.Handler) (*Server, error) { +func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { server := &Server{ ctx: ctx, handler: handler, - errorHandler: errorHandler, path: options.Path, maxEarlyData: options.MaxEarlyData, earlyDataHeaderName: options.EarlyDataHeaderName, @@ -71,8 +70,7 @@ var upgrader = websocket.Upgrader{ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" { if request.URL.Path != s.path { - writer.WriteHeader(http.StatusNotFound) - s.badRequest(request, E.New("bad path: ", request.URL.Path)) + s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) return } } @@ -86,8 +84,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { earlyDataStr := request.URL.RequestURI()[len(s.path):] earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr) } else { - writer.WriteHeader(http.StatusNotFound) - s.badRequest(request, E.New("bad path: ", request.URL.Path)) + s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) return } } else { @@ -97,13 +94,12 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } } if err != nil { - writer.WriteHeader(http.StatusBadRequest) - s.badRequest(request, E.Cause(err, "decode early data")) + s.fallbackRequest(request.Context(), writer, request, http.StatusBadRequest, E.Cause(err, "decode early data")) return } wsConn, err := upgrader.Upgrade(writer, request, nil) if err != nil { - s.badRequest(request, E.Cause(err, "upgrade websocket connection")) + s.fallbackRequest(request.Context(), writer, request, http.StatusBadRequest, E.Cause(err, "upgrade websocket connection")) return } var metadata M.Metadata @@ -115,8 +111,16 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { s.handler.NewConnection(request.Context(), conn, metadata) } -func (s *Server) badRequest(request *http.Request, err error) { - s.errorHandler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) +func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) { + conn := v2rayhttp.NewHTTPConn(request.Body, writer) + fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{}) + if fErr == nil { + return + } else if fErr == os.ErrInvalid { + fErr = nil + } + writer.WriteHeader(statusCode) + s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr)) } func (s *Server) Network() []string {