Fix h2c transport

This commit is contained in:
世界 2022-11-09 11:49:01 +08:00
parent eb2e8a0b40
commit 5510c474c7
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 29 additions and 8 deletions

View file

@ -13,6 +13,7 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"golang.org/x/net/http2"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/backoff" "google.golang.org/grpc/backoff"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
@ -34,6 +35,7 @@ type Client struct {
func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) { func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayGRPCOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
var dialOptions []grpc.DialOption var dialOptions []grpc.DialOption
if tlsConfig != nil { if tlsConfig != nil {
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
dialOptions = append(dialOptions, grpc.WithTransportCredentials(NewTLSTransportCredentials(tlsConfig))) dialOptions = append(dialOptions, grpc.WithTransportCredentials(NewTLSTransportCredentials(tlsConfig)))
} else { } else {
dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))

View file

@ -19,6 +19,7 @@ import (
sHttp "github.com/sagernet/sing/protocol/http" sHttp "github.com/sagernet/sing/protocol/http"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
) )
var _ adapter.V2RayServerTransport = (*Server)(nil) var _ adapter.V2RayServerTransport = (*Server)(nil)
@ -27,6 +28,8 @@ type Server struct {
handler N.TCPConnectionHandler handler N.TCPConnectionHandler
errorHandler E.Handler errorHandler E.Handler
httpServer *http.Server httpServer *http.Server
h2Server *http2.Server
h2cHandler http.Handler
path string path string
} }
@ -39,10 +42,12 @@ func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig t
handler: handler, handler: handler,
errorHandler: errorHandler, errorHandler: errorHandler,
path: fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)), path: fmt.Sprintf("/%s/Tun", url.QueryEscape(options.ServiceName)),
h2Server: new(http2.Server),
} }
server.httpServer = &http.Server{ server.httpServer = &http.Server{
Handler: server, Handler: server,
} }
server.h2cHandler = h2c.NewHandler(server, server.h2Server)
if tlsConfig != nil { if tlsConfig != nil {
stdConfig, err := tlsConfig.Config() stdConfig, err := tlsConfig.Config()
if err != nil { if err != nil {
@ -57,7 +62,12 @@ func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig t
} }
func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" {
s.h2cHandler.ServeHTTP(writer, request)
return
}
if request.URL.Path != s.path { if request.URL.Path != s.path {
request.Write(os.Stdout)
writer.WriteHeader(http.StatusNotFound) writer.WriteHeader(http.StatusNotFound)
s.badRequest(request, E.New("bad path: ", request.URL.Path)) s.badRequest(request, E.New("bad path: ", request.URL.Path))
return return
@ -86,13 +96,13 @@ func (s *Server) badRequest(request *http.Request, err error) {
} }
func (s *Server) Serve(listener net.Listener) error { func (s *Server) Serve(listener net.Listener) error {
err := http2.ConfigureServer(s.httpServer, s.h2Server)
if err != nil {
return err
}
if s.httpServer.TLSConfig == nil { if s.httpServer.TLSConfig == nil {
return s.httpServer.Serve(listener) return s.httpServer.Serve(listener)
} else { } else {
err := http2.ConfigureServer(s.httpServer, &http2.Server{})
if err != nil {
return err
}
return s.httpServer.ServeTLS(listener, "", "") return s.httpServer.ServeTLS(listener, "", "")
} }
} }

View file

@ -18,6 +18,7 @@ import (
sHttp "github.com/sagernet/sing/protocol/http" sHttp "github.com/sagernet/sing/protocol/http"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
) )
var _ adapter.V2RayServerTransport = (*Server)(nil) var _ adapter.V2RayServerTransport = (*Server)(nil)
@ -27,6 +28,8 @@ type Server struct {
handler N.TCPConnectionHandler handler N.TCPConnectionHandler
errorHandler E.Handler errorHandler E.Handler
httpServer *http.Server httpServer *http.Server
h2Server *http2.Server
h2cHandler http.Handler
host []string host []string
path string path string
method string method string
@ -42,6 +45,7 @@ func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig t
ctx: ctx, ctx: ctx,
handler: handler, handler: handler,
errorHandler: errorHandler, errorHandler: errorHandler,
h2Server: new(http2.Server),
host: options.Host, host: options.Host,
path: options.Path, path: options.Path,
method: options.Method, method: options.Method,
@ -61,6 +65,7 @@ func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig t
ReadHeaderTimeout: C.TCPTimeout, ReadHeaderTimeout: C.TCPTimeout,
MaxHeaderBytes: http.DefaultMaxHeaderBytes, MaxHeaderBytes: http.DefaultMaxHeaderBytes,
} }
server.h2cHandler = h2c.NewHandler(server, server.h2Server)
if tlsConfig != nil { if tlsConfig != nil {
stdConfig, err := tlsConfig.Config() stdConfig, err := tlsConfig.Config()
if err != nil { if err != nil {
@ -72,6 +77,10 @@ func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig t
} }
func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" {
s.h2cHandler.ServeHTTP(writer, request)
return
}
host := request.Host host := request.Host
if len(s.host) > 0 && !common.Contains(s.host, host) { if len(s.host) > 0 && !common.Contains(s.host, host) {
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
@ -124,13 +133,13 @@ func (s *Server) badRequest(request *http.Request, err error) {
} }
func (s *Server) Serve(listener net.Listener) error { func (s *Server) Serve(listener net.Listener) error {
err := http2.ConfigureServer(s.httpServer, s.h2Server)
if err != nil {
return err
}
if s.httpServer.TLSConfig == nil { if s.httpServer.TLSConfig == nil {
return s.httpServer.Serve(listener) return s.httpServer.Serve(listener)
} else { } else {
err := http2.ConfigureServer(s.httpServer, &http2.Server{})
if err != nil {
return err
}
return s.httpServer.ServeTLS(listener, "", "") return s.httpServer.ServeTLS(listener, "", "")
} }
} }