mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-29 03:51:28 +00:00
Fixing tcp connestions leak
- always use HandshakeContext instead of Handshake - pickup dailer dropped ctx - rename HandshakeContextAddress to HandshakeAddressContext
This commit is contained in:
parent
5ea1315b85
commit
cae94570df
|
@ -71,8 +71,8 @@ func (d *DokodemoDoor) policy() policy.Session {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
type hasHandshakeAddress interface {
|
type hasHandshakeAddressContext interface {
|
||||||
HandshakeAddress() net.Address
|
HandshakeAddressContext(ctx context.Context) net.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process implements proxy.Inbound.
|
// Process implements proxy.Inbound.
|
||||||
|
@ -89,8 +89,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
|
||||||
if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
|
if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
|
||||||
dest = outbound.Target
|
dest = outbound.Target
|
||||||
destinationOverridden = true
|
destinationOverridden = true
|
||||||
} else if handshake, ok := conn.(hasHandshakeAddress); ok {
|
} else if handshake, ok := conn.(hasHandshakeAddressContext); ok {
|
||||||
addr := handshake.HandshakeAddress()
|
addr := handshake.HandshakeAddressContext(ctx)
|
||||||
if addr != nil {
|
if addr != nil {
|
||||||
dest.Address = addr
|
dest.Address = addr
|
||||||
destinationOverridden = true
|
destinationOverridden = true
|
||||||
|
|
|
@ -308,7 +308,7 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
|
||||||
|
|
||||||
nextProto := ""
|
nextProto := ""
|
||||||
if tlsConn, ok := iConn.(*tls.Conn); ok {
|
if tlsConn, ok := iConn.(*tls.Conn); ok {
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||||
rawConn.Close()
|
rawConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
|
||||||
} else {
|
} else {
|
||||||
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
|
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
|
||||||
}
|
}
|
||||||
if err := cn.Handshake(); err != nil {
|
if err := cn.HandshakeContext(ctx); err != nil {
|
||||||
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
|
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
|
||||||
tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
|
tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
|
||||||
if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
|
if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
|
||||||
conn = tls.UClient(conn, tlsConfig, fingerprint)
|
conn = tls.UClient(conn, tlsConfig, fingerprint)
|
||||||
if err := conn.(*tls.UConn).Handshake(); err != nil {
|
if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -65,7 +65,7 @@ func (c *grpcUtls) ClientHandshake(ctx context.Context, authority string, rawCon
|
||||||
conn := UClient(rawConn, cfg, c.fingerprint).(*UConn)
|
conn := UClient(rawConn, cfg, c.fingerprint).(*UConn)
|
||||||
errChannel := make(chan error, 1)
|
errChannel := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errChannel <- conn.Handshake()
|
errChannel <- conn.HandshakeContext(ctx)
|
||||||
close(errChannel)
|
close(errChannel)
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package tls
|
package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"time"
|
||||||
|
|
||||||
utls "github.com/refraction-networking/utls"
|
utls "github.com/refraction-networking/utls"
|
||||||
"github.com/xtls/xray-core/common/buf"
|
"github.com/xtls/xray-core/common/buf"
|
||||||
|
@ -14,7 +16,7 @@ import (
|
||||||
|
|
||||||
type Interface interface {
|
type Interface interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
Handshake() error
|
HandshakeContext(ctx context.Context) error
|
||||||
VerifyHostname(host string) error
|
VerifyHostname(host string) error
|
||||||
NegotiatedProtocol() (name string, mutual bool)
|
NegotiatedProtocol() (name string, mutual bool)
|
||||||
}
|
}
|
||||||
|
@ -25,6 +27,16 @@ type Conn struct {
|
||||||
*tls.Conn
|
*tls.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const tlsCloseTimeout = 250 * time.Millisecond
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
timer := time.AfterFunc(tlsCloseTimeout, func() {
|
||||||
|
c.Conn.NetConn().Close()
|
||||||
|
})
|
||||||
|
defer timer.Stop()
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||||
mb = buf.Compact(mb)
|
mb = buf.Compact(mb)
|
||||||
mb, err := buf.WriteMultiBuffer(c, mb)
|
mb, err := buf.WriteMultiBuffer(c, mb)
|
||||||
|
@ -32,8 +44,8 @@ func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) HandshakeAddress() net.Address {
|
func (c *Conn) HandshakeAddressContext(ctx context.Context) net.Address {
|
||||||
if err := c.Handshake(); err != nil {
|
if err := c.HandshakeContext(ctx); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
state := c.ConnectionState()
|
state := c.ConnectionState()
|
||||||
|
@ -64,8 +76,16 @@ type UConn struct {
|
||||||
*utls.UConn
|
*utls.UConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UConn) HandshakeAddress() net.Address {
|
func (c *UConn) Close() error {
|
||||||
if err := c.Handshake(); err != nil {
|
timer := time.AfterFunc(tlsCloseTimeout, func() {
|
||||||
|
c.Conn.NetConn().Close()
|
||||||
|
})
|
||||||
|
defer timer.Stop()
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UConn) HandshakeAddressContext(ctx context.Context) net.Address {
|
||||||
|
if err := c.HandshakeContext(ctx); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
state := c.ConnectionState()
|
state := c.ConnectionState()
|
||||||
|
@ -77,7 +97,7 @@ func (c *UConn) HandshakeAddress() net.Address {
|
||||||
|
|
||||||
// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
|
// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
|
||||||
// http/1.1 in its ALPN.
|
// http/1.1 in its ALPN.
|
||||||
func (c *UConn) WebsocketHandshake() error {
|
func (c *UConn) WebsocketHandshakeContext(ctx context.Context) error {
|
||||||
// Build the handshake state. This will apply every variable of the TLS of the
|
// Build the handshake state. This will apply every variable of the TLS of the
|
||||||
// fingerprint in the UConn
|
// fingerprint in the UConn
|
||||||
if err := c.BuildHandshakeState(); err != nil {
|
if err := c.BuildHandshakeState(); err != nil {
|
||||||
|
@ -99,7 +119,7 @@ func (c *UConn) WebsocketHandshake() error {
|
||||||
if err := c.BuildHandshakeState(); err != nil {
|
if err := c.BuildHandshakeState(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.Handshake()
|
return c.HandshakeContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
|
func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
|
||||||
|
|
|
@ -96,7 +96,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
|
||||||
}
|
}
|
||||||
// TLS and apply the handshake
|
// TLS and apply the handshake
|
||||||
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
|
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
|
||||||
if err := cn.WebsocketHandshake(); err != nil {
|
if err := cn.WebsocketHandshakeContext(ctx); err != nil {
|
||||||
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
|
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -147,7 +147,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
|
||||||
header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
|
header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, resp, err := dialer.Dial(uri, header)
|
conn, resp, err := dialer.DialContext(ctx, uri, header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var reason string
|
var reason string
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
|
|
Loading…
Reference in a new issue