Fixing tcp connestions leak

- always use HandshakeContext instead of Handshake

- pickup dailer dropped ctx

- rename HandshakeContextAddress to HandshakeAddressContext
This commit is contained in:
deorth-kku 2024-02-03 19:45:37 +08:00 committed by yuhan6665
parent 5ea1315b85
commit cae94570df
7 changed files with 38 additions and 18 deletions

View file

@ -71,8 +71,8 @@ func (d *DokodemoDoor) policy() policy.Session {
return p return p
} }
type hasHandshakeAddress interface { type hasHandshakeAddressContext interface {
HandshakeAddress() net.Address HandshakeAddressContext(ctx context.Context) net.Address
} }
// Process implements proxy.Inbound. // Process implements proxy.Inbound.
@ -89,8 +89,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() { if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
dest = outbound.Target dest = outbound.Target
destinationOverridden = true destinationOverridden = true
} else if handshake, ok := conn.(hasHandshakeAddress); ok { } else if handshake, ok := conn.(hasHandshakeAddressContext); ok {
addr := handshake.HandshakeAddress() addr := handshake.HandshakeAddressContext(ctx)
if addr != nil { if addr != nil {
dest.Address = addr dest.Address = addr
destinationOverridden = true destinationOverridden = true

View file

@ -308,7 +308,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
nextProto := "" nextProto := ""
if tlsConn, ok := iConn.(*tls.Conn); ok { if tlsConn, ok := iConn.(*tls.Conn); ok {
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.HandshakeContext(ctx); err != nil {
rawConn.Close() rawConn.Close()
return nil, err return nil, err
} }

View file

@ -87,7 +87,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
} else { } else {
cn = tls.Client(pconn, tlsConfig).(*tls.Conn) cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
} }
if err := cn.Handshake(); err != nil { if err := cn.HandshakeContext(ctx); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err return nil, err
} }

View file

@ -24,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
tlsConfig := config.GetTLSConfig(tls.WithDestination(dest)) tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil { if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
conn = tls.UClient(conn, tlsConfig, fingerprint) conn = tls.UClient(conn, tlsConfig, fingerprint)
if err := conn.(*tls.UConn).Handshake(); err != nil { if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil {
return nil, err return nil, err
} }
} else { } else {

View file

@ -65,7 +65,7 @@ func (c *grpcUtls) ClientHandshake(ctx context.Context, authority string, rawCon
conn := UClient(rawConn, cfg, c.fingerprint).(*UConn) conn := UClient(rawConn, cfg, c.fingerprint).(*UConn)
errChannel := make(chan error, 1) errChannel := make(chan error, 1)
go func() { go func() {
errChannel <- conn.Handshake() errChannel <- conn.HandshakeContext(ctx)
close(errChannel) close(errChannel)
}() }()
select { select {

View file

@ -1,9 +1,11 @@
package tls package tls
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"math/big" "math/big"
"time"
utls "github.com/refraction-networking/utls" utls "github.com/refraction-networking/utls"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
@ -14,7 +16,7 @@ import (
type Interface interface { type Interface interface {
net.Conn net.Conn
Handshake() error HandshakeContext(ctx context.Context) error
VerifyHostname(host string) error VerifyHostname(host string) error
NegotiatedProtocol() (name string, mutual bool) NegotiatedProtocol() (name string, mutual bool)
} }
@ -25,6 +27,16 @@ type Conn struct {
*tls.Conn *tls.Conn
} }
const tlsCloseTimeout = 250 * time.Millisecond
func (c *Conn) Close() error {
timer := time.AfterFunc(tlsCloseTimeout, func() {
c.Conn.NetConn().Close()
})
defer timer.Stop()
return c.Conn.Close()
}
func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error { func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
mb = buf.Compact(mb) mb = buf.Compact(mb)
mb, err := buf.WriteMultiBuffer(c, mb) mb, err := buf.WriteMultiBuffer(c, mb)
@ -32,8 +44,8 @@ func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
return err return err
} }
func (c *Conn) HandshakeAddress() net.Address { func (c *Conn) HandshakeAddressContext(ctx context.Context) net.Address {
if err := c.Handshake(); err != nil { if err := c.HandshakeContext(ctx); err != nil {
return nil return nil
} }
state := c.ConnectionState() state := c.ConnectionState()
@ -64,8 +76,16 @@ type UConn struct {
*utls.UConn *utls.UConn
} }
func (c *UConn) HandshakeAddress() net.Address { func (c *UConn) Close() error {
if err := c.Handshake(); err != nil { timer := time.AfterFunc(tlsCloseTimeout, func() {
c.Conn.NetConn().Close()
})
defer timer.Stop()
return c.Conn.Close()
}
func (c *UConn) HandshakeAddressContext(ctx context.Context) net.Address {
if err := c.HandshakeContext(ctx); err != nil {
return nil return nil
} }
state := c.ConnectionState() state := c.ConnectionState()
@ -77,7 +97,7 @@ func (c *UConn) HandshakeAddress() net.Address {
// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send // WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
// http/1.1 in its ALPN. // http/1.1 in its ALPN.
func (c *UConn) WebsocketHandshake() error { func (c *UConn) WebsocketHandshakeContext(ctx context.Context) error {
// Build the handshake state. This will apply every variable of the TLS of the // Build the handshake state. This will apply every variable of the TLS of the
// fingerprint in the UConn // fingerprint in the UConn
if err := c.BuildHandshakeState(); err != nil { if err := c.BuildHandshakeState(); err != nil {
@ -99,7 +119,7 @@ func (c *UConn) WebsocketHandshake() error {
if err := c.BuildHandshakeState(); err != nil { if err := c.BuildHandshakeState(); err != nil {
return err return err
} }
return c.Handshake() return c.HandshakeContext(ctx)
} }
func (c *UConn) NegotiatedProtocol() (name string, mutual bool) { func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
@ -118,7 +138,7 @@ func copyConfig(c *tls.Config) *utls.Config {
ServerName: c.ServerName, ServerName: c.ServerName,
InsecureSkipVerify: c.InsecureSkipVerify, InsecureSkipVerify: c.InsecureSkipVerify,
VerifyPeerCertificate: c.VerifyPeerCertificate, VerifyPeerCertificate: c.VerifyPeerCertificate,
KeyLogWriter: c.KeyLogWriter, KeyLogWriter: c.KeyLogWriter,
} }
} }

View file

@ -96,7 +96,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
} }
// TLS and apply the handshake // TLS and apply the handshake
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
if err := cn.WebsocketHandshake(); err != nil { if err := cn.WebsocketHandshakeContext(ctx); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err return nil, err
} }
@ -147,7 +147,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed)) header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
} }
conn, resp, err := dialer.Dial(uri, header) conn, resp, err := dialer.DialContext(ctx, uri, header)
if err != nil { if err != nil {
var reason string var reason string
if resp != nil { if resp != nil {