Improve connection timeout

This commit is contained in:
世界 2022-07-18 20:40:14 +08:00
parent 3fb011712b
commit c7fabe40ed
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
14 changed files with 152 additions and 111 deletions

5
box.go
View file

@ -143,9 +143,12 @@ func (s *Box) Start() error {
if err != nil { if err != nil {
return err return err
} }
for _, in := range s.inbounds { for i, in := range s.inbounds {
err = in.Start() err = in.Start()
if err != nil { if err != nil {
for g := 0; g < i; g++ {
s.inbounds[g].Close()
}
return err return err
} }
} }

View file

@ -59,6 +59,8 @@ func NewDefault(router adapter.Router, options option.DialerOptions) *DefaultDia
} }
if options.ConnectTimeout != 0 { if options.ConnectTimeout != 0 {
dialer.Timeout = time.Duration(options.ConnectTimeout) dialer.Timeout = time.Duration(options.ConnectTimeout)
} else {
dialer.Timeout = C.DefaultTCPTimeout
} }
return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener} return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener}
} }

View file

@ -6,7 +6,6 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -24,8 +23,5 @@ func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.
if domainStrategy != dns.DomainStrategyAsIS || options.Detour == "" { if domainStrategy != dns.DomainStrategyAsIS || options.Detour == "" {
dialer = NewResolveDialer(router, dialer, domainStrategy, time.Duration(options.FallbackDelay)) dialer = NewResolveDialer(router, dialer, domainStrategy, time.Duration(options.FallbackDelay))
} }
if options.OverrideOptions.IsValid() {
dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions))
}
return dialer return dialer
} }

View file

@ -1,69 +0,0 @@
package dialer
import (
"context"
"crypto/tls"
"net"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
)
var _ N.Dialer = (*OverrideDialer)(nil)
type OverrideDialer struct {
upstream N.Dialer
tlsEnabled bool
tlsConfig tls.Config
uotEnabled bool
}
func NewOverride(upstream N.Dialer, options option.OverrideStreamOptions) N.Dialer {
return &OverrideDialer{
upstream,
options.TLS,
tls.Config{
ServerName: options.TLSServerName,
InsecureSkipVerify: options.TLSInsecure,
},
options.UDPOverTCP,
}
}
func (d *OverrideDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network {
case C.NetworkTCP:
conn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination)
if err != nil {
return nil, err
}
return tls.Client(conn, &d.tlsConfig), nil
case C.NetworkUDP:
if d.uotEnabled {
tcpConn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination)
if err != nil {
return nil, err
}
return uot.NewClientConn(tcpConn), nil
}
}
return d.upstream.DialContext(ctx, network, destination)
}
func (d *OverrideDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if d.uotEnabled {
tcpConn, err := d.upstream.DialContext(ctx, C.NetworkTCP, destination)
if err != nil {
return nil, err
}
return uot.NewClientConn(tcpConn), nil
}
return d.upstream.ListenPacket(ctx, destination)
}
func (d *OverrideDialer) Upstream() any {
return d.upstream
}

86
common/dialer/tls.go Normal file
View file

@ -0,0 +1,86 @@
package dialer
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/netip"
"os"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TLSDialer struct {
dialer N.Dialer
config *tls.Config
}
func NewTLS(dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
if !options.Enabled {
return dialer, nil
}
var serverName string
if options.ServerName != "" {
serverName = options.ServerName
} else if serverAddress != "" {
if _, err := netip.ParseAddr(serverName); err != nil {
serverName = serverAddress
}
}
if serverName == "" && options.Insecure {
return nil, E.New("missing server_name or insecure=true")
}
var tlsConfig tls.Config
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {
tlsConfig.ServerName = serverName
}
if options.Insecure {
tlsConfig.InsecureSkipVerify = options.Insecure
} else if options.DisableSNI {
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyConnection = func(state tls.ConnectionState) error {
verifyOptions := x509.VerifyOptions{
DNSName: serverName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range state.PeerCertificates[1:] {
verifyOptions.Intermediates.AddCert(cert)
}
_, err := state.PeerCertificates[0].Verify(verifyOptions)
return err
}
}
return &TLSDialer{
dialer: dialer,
config: &tlsConfig,
}, nil
}
func (d *TLSDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if network != C.NetworkTCP {
return nil, os.ErrInvalid
}
conn, err := d.dialer.DialContext(ctx, network, destination)
if err != nil {
return nil, err
}
tlsConn := tls.Client(conn, d.config)
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
err = tlsConn.HandshakeContext(ctx)
return tlsConn, err
}
func (d *TLSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return nil, os.ErrInvalid
}

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
@ -19,7 +20,7 @@ type (
) )
func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, sniffers ...StreamSniffer) (*adapter.InboundContext, error) { func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, sniffers ...StreamSniffer) (*adapter.InboundContext, error) {
err := conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond)) err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
if err != nil { if err != nil {
return nil, err return nil, err
} }

8
constant/timeout.go Normal file
View file

@ -0,0 +1,8 @@
package constant
import "time"
const (
DefaultTCPTimeout = 5 * time.Second
ReadPayloadTimeout = 300 * time.Millisecond
)

View file

@ -82,22 +82,10 @@ type DialerOptions struct {
type OutboundDialerOptions struct { type OutboundDialerOptions struct {
DialerOptions DialerOptions
OverrideOptions *OverrideStreamOptions `json:"override,omitempty"`
DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"`
FallbackDelay Duration `json:"fallback_delay,omitempty"` FallbackDelay Duration `json:"fallback_delay,omitempty"`
} }
type OverrideStreamOptions struct {
TLS bool `json:"tls,omitempty"`
TLSServerName string `json:"tls_servername,omitempty"`
TLSInsecure bool `json:"tls_insecure,omitempty"`
UDPOverTCP bool `json:"udp_over_tcp,omitempty"`
}
func (o *OverrideStreamOptions) IsValid() bool {
return o != nil && (o.TLS || o.UDPOverTCP)
}
type ServerOptions struct { type ServerOptions struct {
Server string `json:"server"` Server string `json:"server"`
ServerPort uint16 `json:"server_port"` ServerPort uint16 `json:"server_port"`
@ -127,6 +115,14 @@ type HTTPOutboundOptions struct {
ServerOptions ServerOptions
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`
TLSOptions *OutboundTLSOptions `json:"tls,omitempty"`
}
type OutboundTLSOptions struct {
Enabled bool `json:"enabled,omitempty"`
DisableSNI bool `json:"disable_sni,omitempty"`
ServerName string `json:"server_name,omitempty"`
Insecure bool `json:"insecure,omitempty"`
} }
type ShadowsocksOutboundOptions struct { type ShadowsocksOutboundOptions struct {
@ -146,4 +142,5 @@ type VMessOutboundOptions struct {
GlobalPadding bool `json:"global_padding,omitempty"` GlobalPadding bool `json:"global_padding,omitempty"`
AuthenticatedLength bool `json:"authenticated_length,omitempty"` AuthenticatedLength bool `json:"authenticated_length,omitempty"`
Network NetworkList `json:"network,omitempty"` Network NetworkList `json:"network,omitempty"`
TLSOptions *OutboundTLSOptions `json:"tls,omitempty"`
} }

View file

@ -21,7 +21,7 @@ func New(router adapter.Router, logger log.ContextLogger, options option.Outboun
case C.TypeSocks: case C.TypeSocks:
return NewSocks(router, logger, options.Tag, options.SocksOptions) return NewSocks(router, logger, options.Tag, options.SocksOptions)
case C.TypeHTTP: case C.TypeHTTP:
return NewHTTP(router, logger, options.Tag, options.HTTPOptions), nil return NewHTTP(router, logger, options.Tag, options.HTTPOptions)
case C.TypeShadowsocks: case C.TypeShadowsocks:
return NewShadowsocks(router, logger, options.Tag, options.ShadowsocksOptions) return NewShadowsocks(router, logger, options.Tag, options.ShadowsocksOptions)
case C.TypeVMess: case C.TypeVMess:

View file

@ -93,7 +93,7 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro
} }
_payload := buf.StackNew() _payload := buf.StackNew()
payload := common.Dup(_payload) payload := common.Dup(_payload)
err := conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond)) err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
if err != nil { if err != nil {
return err return err
} }

View file

@ -10,6 +10,7 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/protocol/http" "github.com/sagernet/sing/protocol/http"
@ -22,7 +23,11 @@ type HTTP struct {
client *http.Client client *http.Client
} }
func NewHTTP(router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) *HTTP { func NewHTTP(router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (*HTTP, error) {
detour, err := dialer.NewTLS(dialer.NewOutbound(router, options.OutboundDialerOptions), options.Server, common.PtrValueOrDefault(options.TLSOptions))
if err != nil {
return nil, err
}
return &HTTP{ return &HTTP{
myOutboundAdapter{ myOutboundAdapter{
protocol: C.TypeHTTP, protocol: C.TypeHTTP,
@ -30,8 +35,8 @@ func NewHTTP(router adapter.Router, logger log.ContextLogger, tag string, option
tag: tag, tag: tag,
network: []string{C.NetworkTCP}, network: []string{C.NetworkTCP},
}, },
http.NewClient(dialer.NewOutbound(router, options.OutboundDialerOptions), options.ServerOptions.Build(), options.Username, options.Password), http.NewClient(detour, options.ServerOptions.Build(), options.Username, options.Password),
} }, nil
} }
func (h *HTTP) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (h *HTTP) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {

View file

@ -10,6 +10,7 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-vmess" "github.com/sagernet/sing-vmess"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
@ -35,6 +36,10 @@ func NewVMess(router adapter.Router, logger log.ContextLogger, tag string, optio
if err != nil { if err != nil {
return nil, err return nil, err
} }
detour, err := dialer.NewTLS(dialer.NewOutbound(router, options.OutboundDialerOptions), options.Server, common.PtrValueOrDefault(options.TLSOptions))
if err != nil {
return nil, err
}
return &VMess{ return &VMess{
myOutboundAdapter{ myOutboundAdapter{
protocol: C.TypeDirect, protocol: C.TypeDirect,
@ -42,7 +47,7 @@ func NewVMess(router adapter.Router, logger log.ContextLogger, tag string, optio
tag: tag, tag: tag,
network: options.Network.Build(), network: options.Network.Build(),
}, },
dialer.NewOutbound(router, options.OutboundDialerOptions), detour,
client, client,
options.ServerOptions.Build(), options.ServerOptions.Build(),
}, nil }, nil

View file

@ -35,13 +35,20 @@ func mkPort(t *testing.T) uint16 {
} }
func startInstance(t *testing.T, options option.Options) { func startInstance(t *testing.T, options option.Options) {
var err error
for retry := 0; retry < 3; retry++ {
instance, err := box.New(context.Background(), options) instance, err := box.New(context.Background(), options)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, instance.Start()) err = instance.Start()
if err != nil {
time.Sleep(5 * time.Millisecond)
continue
}
t.Cleanup(func() { t.Cleanup(func() {
instance.Close() instance.Close()
}) })
time.Sleep(time.Second) }
require.NoError(t, err)
} }
func testSuit(t *testing.T, clientPort uint16, testPort uint16) { func testSuit(t *testing.T, clientPort uint16, testPort uint16) {

View file

@ -484,7 +484,7 @@ func listen(network, address string) (net.Listener, error) {
} }
lastErr = err lastErr = err
time.Sleep(time.Millisecond * 200) time.Sleep(5 * time.Millisecond)
} }
return nil, lastErr return nil, lastErr
} }
@ -500,7 +500,7 @@ func listenPacket(network, address string) (net.PacketConn, error) {
} }
lastErr = err lastErr = err
time.Sleep(time.Millisecond * 200) time.Sleep(5 * time.Millisecond)
} }
return nil, lastErr return nil, lastErr
} }