Transport: Add HTTP3 to HTTP (#3819)

This commit is contained in:
yuhan6665 2024-09-25 21:29:41 -04:00 committed by GitHub
parent 7086d286be
commit 3632e83faa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 316 additions and 128 deletions

View file

@ -650,7 +650,7 @@ func (p TransportProtocol) Build() (string, error) {
return "mkcp", nil return "mkcp", nil
case "ws", "websocket": case "ws", "websocket":
return "websocket", nil return "websocket", nil
case "h2", "http": case "h2", "h3", "http":
return "http", nil return "http", nil
case "grpc", "gun": case "grpc", "gun":
return "grpc", nil return "grpc", nil

View file

@ -9,6 +9,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx" c "github.com/xtls/xray-core/common/ctx"
@ -24,6 +26,13 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
// defines the maximum time an idle TCP session can survive in the tunnel, so
// it should be consistent across HTTP versions and with other transports.
const connIdleTimeout = 300 * time.Second
// consistent with quic-go
const h3KeepalivePeriod = 10 * time.Second
type dialerConf struct { type dialerConf struct {
net.Destination net.Destination
*internet.MemoryStreamConfig *internet.MemoryStreamConfig
@ -48,13 +57,70 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
if tlsConfigs == nil && realityConfigs == nil { if tlsConfigs == nil && realityConfigs == nil {
return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning() return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning()
} }
isH3 := tlsConfigs != nil && (len(tlsConfigs.NextProtocol) == 1 && tlsConfigs.NextProtocol[0] == "h3")
if isH3 {
dest.Network = net.Network_UDP
}
sockopt := streamSettings.SocketSettings sockopt := streamSettings.SocketSettings
if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found { if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
return client, nil return client, nil
} }
transport := &http2.Transport{ var transport http.RoundTripper
if isH3 {
quicConfig := &quic.Config{
MaxIdleTimeout: connIdleTimeout,
// these two are defaults of quic-go/http3. the default of quic-go (no
// http3) is different, so it is hardcoded here for clarity.
// https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
MaxIncomingStreams: -1,
KeepAlivePeriod: h3KeepalivePeriod,
}
roundTripper := &http3.RoundTripper{
QUICConfig: quicConfig,
TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)),
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
return nil, err
}
var udpConn net.PacketConn
var udpAddr *net.UDPAddr
switch c := conn.(type) {
case *internet.PacketConnWrapper:
var ok bool
udpConn, ok = c.Conn.(*net.UDPConn)
if !ok {
return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
}
udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
if err != nil {
return nil, err
}
case *net.UDPConn:
udpConn = c
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
default:
udpConn = &internet.FakePacketConn{c}
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
}
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
},
}
transport = roundTripper
} else {
transportH2 := &http2.Transport{
DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
rawHost, rawPort, err := net.SplitHostPort(addr) rawHost, rawPort, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
@ -106,14 +172,14 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
return cn, nil return cn, nil
}, },
} }
if tlsConfigs != nil { if tlsConfigs != nil {
transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest)) transportH2.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
} }
if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 { if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout) transportH2.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout) transportH2.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
}
transport = transportH2
} }
client := &http.Client{ client := &http.Client{
@ -158,9 +224,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
Host: dest.NetAddr(), Host: dest.NetAddr(),
Path: httpSettings.getNormalizedPath(), Path: httpSettings.getNormalizedPath(),
}, },
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders, Header: httpHeaders,
} }
// Disable any compression method from server. // Disable any compression method from server.

View file

@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol/tls/cert" "github.com/xtls/xray-core/common/protocol/tls/cert"
"github.com/xtls/xray-core/testing/servers/tcp" "github.com/xtls/xray-core/testing/servers/tcp"
"github.com/xtls/xray-core/testing/servers/udp"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
. "github.com/xtls/xray-core/transport/internet/http" . "github.com/xtls/xray-core/transport/internet/http"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
@ -92,3 +93,80 @@ func TestHTTPConnection(t *testing.T) {
t.Error(r) t.Error(r)
} }
} }
func TestH3Connection(t *testing.T) {
port := udp.PickPort()
listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.example.com")))},
},
}, func(conn stat.Connection) {
go func() {
defer conn.Close()
b := buf.New()
defer b.Release()
for {
if _, err := b.ReadFrom(conn); err != nil {
return
}
_, err := conn.Write(b.Bytes())
common.Must(err)
}
}()
})
common.Must(err)
defer listener.Close()
time.Sleep(time.Second)
dctx := context.Background()
conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
ServerName: "www.example.com",
AllowInsecure: true,
},
})
common.Must(err)
defer conn.Close()
const N = 1024
b1 := make([]byte, N)
common.Must2(rand.Read(b1))
b2 := buf.New()
nBytes, err := conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}
b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
nBytes, err = conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}
b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
}

View file

@ -2,11 +2,14 @@ package http
import ( import (
"context" "context"
gotls "crypto/tls"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
goreality "github.com/xtls/reality" goreality "github.com/xtls/reality"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
@ -24,9 +27,11 @@ import (
type Listener struct { type Listener struct {
server *http.Server server *http.Server
h3server *http3.Server
handler internet.ConnHandler handler internet.ConnHandler
local net.Addr local net.Addr
config *Config config *Config
isH3 bool
} }
func (l *Listener) Addr() net.Addr { func (l *Listener) Addr() net.Addr {
@ -34,7 +39,14 @@ func (l *Listener) Addr() net.Addr {
} }
func (l *Listener) Close() error { func (l *Listener) Close() error {
if l.h3server != nil {
if err := l.h3server.Close(); err != nil {
return err
}
} else if l.server != nil {
return l.server.Close() return l.server.Close()
}
return errors.New("listener does not have an HTTP/3 server or h2 server")
} }
type flushWriter struct { type flushWriter struct {
@ -119,29 +131,61 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request)
func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
httpSettings := streamSettings.ProtocolSettings.(*Config) httpSettings := streamSettings.ProtocolSettings.(*Config)
var listener *Listener config := tls.ConfigFromStreamSettings(streamSettings)
if port == net.Port(0) { // unix var tlsConfig *gotls.Config
listener = &Listener{ if config == nil {
tlsConfig = &gotls.Config{}
} else {
tlsConfig = config.GetTLSConfig()
}
isH3 := len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
listener := &Listener{
handler: handler, handler: handler,
local: &net.UnixAddr{ config: httpSettings,
isH3: isH3,
}
if port == net.Port(0) { // unix
listener.local = &net.UnixAddr{
Name: address.Domain(), Name: address.Domain(),
Net: "unix", Net: "unix",
},
config: httpSettings,
} }
} else { // tcp } else if isH3 { // udp
listener = &Listener{ listener.local = &net.UDPAddr{
handler: handler, IP: address.IP(),
local: &net.TCPAddr{ Port: int(port),
}
} else {
listener.local = &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
},
config: httpSettings,
} }
} }
if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
errors.LogWarning(ctx, "accepting PROXY protocol")
}
if isH3 {
Conn, err := internet.ListenSystemPacket(context.Background(), listener.local, streamSettings.SocketSettings)
if err != nil {
return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
}
h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
if err != nil {
return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
}
errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
listener.h3server = &http3.Server{
Handler: listener,
}
go func() {
if err := listener.h3server.ServeListener(h3listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
}
}()
} else {
var server *http.Server var server *http.Server
config := tls.ConfigFromStreamSettings(streamSettings)
if config == nil { if config == nil {
h2s := &http2.Server{} h2s := &http2.Server{}
@ -159,10 +203,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
} }
} }
if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
errors.LogWarning(ctx, "accepting PROXY protocol")
}
listener.server = server listener.server = server
go func() { go func() {
var streamListener net.Listener var streamListener net.Listener
@ -202,6 +242,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
} }
} }
}() }()
}
return listener, nil return listener, nil
} }

View file

@ -365,7 +365,13 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
// Addr implements net.Listener.Addr(). // Addr implements net.Listener.Addr().
func (ln *Listener) Addr() net.Addr { func (ln *Listener) Addr() net.Addr {
if ln.h3listener != nil {
return ln.h3listener.Addr()
}
if ln.listener != nil {
return ln.listener.Addr() return ln.listener.Addr()
}
return nil
} }
// Close implements net.Listener.Close(). // Close implements net.Listener.Close().