diff --git a/common/tls/reality_client.go b/common/tls/reality_client.go index 8f3f6b83..1673749c 100644 --- a/common/tls/reality_client.go +++ b/common/tls/reality_client.go @@ -111,6 +111,16 @@ func (e *RealityClientConfig) ClientHandshake(ctx context.Context, conn net.Conn if err != nil { return nil, err } + + if len(uConfig.NextProtos) > 0 { + for _, extension := range uConn.Extensions { + if alpnExtension, isALPN := extension.(*utls.ALPNExtension); isALPN { + alpnExtension.AlpnProtocols = uConfig.NextProtos + break + } + } + } + hello := uConn.HandshakeState.Hello hello.SessionId = make([]byte, 32) copy(hello.Raw[39:], hello.SessionId) diff --git a/common/tls/utls_client.go b/common/tls/utls_client.go index cedc2712..b5f38eab 100644 --- a/common/tls/utls_client.go +++ b/common/tls/utls_client.go @@ -3,6 +3,7 @@ package tls import ( + "context" "crypto/tls" "crypto/x509" "math/rand" @@ -47,7 +48,7 @@ func (e *UTLSClientConfig) Config() (*STDConfig, error) { } func (e *UTLSClientConfig) Client(conn net.Conn) (Conn, error) { - return &utlsConnWrapper{utls.UClient(conn, e.config.Clone(), e.id)}, nil + return &utlsALPNWrapper{utlsConnWrapper{utls.UClient(conn, e.config.Clone(), e.id)}, e.config.NextProtos}, nil } func (e *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) { @@ -87,6 +88,31 @@ func (c *utlsConnWrapper) Upstream() any { return c.UConn } +type utlsALPNWrapper struct { + utlsConnWrapper + nextProtocols []string +} + +func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error { + if len(c.nextProtocols) > 0 { + err := c.BuildHandshakeState() + if err != nil { + return err + } + for _, extension := range c.Extensions { + if alpnExtension, isALPN := extension.(*utls.ALPNExtension); isALPN { + alpnExtension.AlpnProtocols = c.nextProtocols + err = c.BuildHandshakeState() + if err != nil { + return err + } + break + } + } + } + return c.UConn.HandshakeContext(ctx) +} + func NewUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (*UTLSClientConfig, error) { var serverName string if options.ServerName != "" {