From 3699a57847666d66646798598e15a4a5081fb39e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 8 Jul 2022 12:58:43 +0800 Subject: [PATCH] Add dial parallel for outbound dialer --- common/dialer/default.go | 2 +- common/dialer/dialer.go | 8 ++- common/dialer/parallel.go | 91 +++++++++++++++++++++++++++++++++++ common/dialer/resolve.go | 19 +++++--- common/dialer/resolve_conn.go | 83 ++++++++++++++++++++++++++++++++ common/dialer/serial.go | 1 + dns/client.go | 8 ++- dns/transport.go | 11 ++++- dns/transport_tcp.go | 6 +-- dns/transport_tls.go | 6 +-- dns/transport_udp.go | 6 +-- go.mod | 4 +- go.sum | 8 +-- inbound/default.go | 1 + option/outbound.go | 15 +++--- option/types.go | 21 ++++++++ route/router.go | 2 - 17 files changed, 253 insertions(+), 39 deletions(-) create mode 100644 common/dialer/parallel.go create mode 100644 common/dialer/resolve_conn.go diff --git a/common/dialer/default.go b/common/dialer/default.go index d7e34bd2..430ad9bc 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -38,7 +38,7 @@ func NewDefault(options option.DialerOptions) *DefaultDialer { listener.Control = control.Append(listener.Control, ProtectPath(options.ProtectPath)) } if options.ConnectTimeout != 0 { - dialer.Timeout = time.Duration(options.ConnectTimeout) * time.Second + dialer.Timeout = time.Duration(options.ConnectTimeout) } return &DefaultDialer{tfo.Dialer{Dialer: dialer, DisableTFO: !options.TCPFastOpen}, listener} } diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index f2f7fec1..3dcf06a0 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -1,6 +1,8 @@ package dialer import ( + "time" + "github.com/sagernet/sing/common" N "github.com/sagernet/sing/common/network" @@ -21,7 +23,11 @@ func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N. dialer := New(router, options.DialerOptions) domainStrategy := C.DomainStrategy(options.DomainStrategy) if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" && !C.CGO_ENABLED { - dialer = NewResolveDialer(router, dialer, domainStrategy) + fallbackDelay := time.Duration(options.FallbackDelay) + if fallbackDelay == 0 { + fallbackDelay = time.Millisecond * 300 + } + dialer = NewResolveDialer(router, dialer, domainStrategy, fallbackDelay) } if options.OverrideOptions.IsValid() { dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions)) diff --git a/common/dialer/parallel.go b/common/dialer/parallel.go new file mode 100644 index 00000000..227f5228 --- /dev/null +++ b/common/dialer/parallel.go @@ -0,0 +1,91 @@ +package dialer + +import ( + "context" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + C "github.com/sagernet/sing-box/constant" +) + +func DialParallel(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.DomainStrategy, fallbackDelay time.Duration) (net.Conn, error) { + // kanged form net.Dial + + returned := make(chan struct{}) + defer close(returned) + + addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool { + return address.Is4() || address.Is4In6() + }) + addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool { + return address.Is6() && !address.Is4In6() + }) + if len(addresses4) == 0 || len(addresses6) == 0 { + return DialSerial(ctx, dialer, network, destination, destinationAddresses) + } + var primaries, fallbacks []netip.Addr + switch strategy { + case C.DomainStrategyPreferIPv6: + primaries = addresses6 + fallbacks = addresses4 + default: + primaries = addresses4 + fallbacks = addresses6 + } + type dialResult struct { + net.Conn + error + primary bool + done bool + } + results := make(chan dialResult) // unbuffered + startRacer := func(ctx context.Context, primary bool) { + ras := primaries + if !primary { + ras = fallbacks + } + c, err := DialSerial(ctx, dialer, network, destination, ras) + select { + case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: + case <-returned: + if c != nil { + c.Close() + } + } + } + var primary, fallback dialResult + primaryCtx, primaryCancel := context.WithCancel(ctx) + defer primaryCancel() + go startRacer(primaryCtx, true) + fallbackTimer := time.NewTimer(fallbackDelay) + defer fallbackTimer.Stop() + for { + select { + case <-fallbackTimer.C: + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + go startRacer(fallbackCtx, false) + + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + if res.primary { + primary = res + } else { + fallback = res + } + if primary.done && fallback.done { + return nil, primary.error + } + if res.primary && fallbackTimer.Stop() { + fallbackTimer.Reset(0) + } + } + } +} diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go index e598be79..51b707e7 100644 --- a/common/dialer/resolve.go +++ b/common/dialer/resolve.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "time" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -13,16 +14,18 @@ import ( ) type ResolveDialer struct { - dialer N.Dialer - router adapter.Router - strategy C.DomainStrategy + dialer N.Dialer + router adapter.Router + strategy C.DomainStrategy + fallbackDelay time.Duration } -func NewResolveDialer(router adapter.Router, dialer N.Dialer, strategy C.DomainStrategy) *ResolveDialer { +func NewResolveDialer(router adapter.Router, dialer N.Dialer, strategy C.DomainStrategy, fallbackDelay time.Duration) *ResolveDialer { return &ResolveDialer{ dialer, router, strategy, + fallbackDelay, } } @@ -40,7 +43,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina if err != nil { return nil, err } - return DialSerial(ctx, d.dialer, network, destination, addresses) + return DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy, d.fallbackDelay) } func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { @@ -57,7 +60,11 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd if err != nil { return nil, err } - return ListenSerial(ctx, d.dialer, destination, addresses) + conn, err := ListenSerial(ctx, d.dialer, destination, addresses) + if err != nil { + return nil, err + } + return NewResolvePacketConn(d.router, d.strategy, conn), nil } func (d *ResolveDialer) Upstream() any { diff --git a/common/dialer/resolve_conn.go b/common/dialer/resolve_conn.go new file mode 100644 index 00000000..9def8498 --- /dev/null +++ b/common/dialer/resolve_conn.go @@ -0,0 +1,83 @@ +package dialer + +import ( + "context" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" +) + +func NewResolvePacketConn(router adapter.Router, strategy C.DomainStrategy, conn net.PacketConn) N.NetPacketConn { + if udpConn, ok := conn.(*net.UDPConn); ok { + return &ResolveUDPConn{udpConn, router, strategy} + } else { + return &ResolvePacketConn{conn, router, strategy} + } +} + +type ResolveUDPConn struct { + *net.UDPConn + router adapter.Router + strategy C.DomainStrategy +} + +func (w *ResolveUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, addr, err := w.ReadFromUDPAddrPort(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + return M.SocksaddrFromNetIP(addr), nil +} + +func (w *ResolveUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + if destination.Family().IsFqdn() { + addresses, err := w.router.Lookup(context.Background(), destination.Fqdn, w.strategy) + if err != nil { + return err + } + return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), M.SocksaddrFromAddrPort(addresses[0], destination.Port).UDPAddr())) + } + return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr())) +} + +func (w *ResolveUDPConn) Upstream() any { + return w.UDPConn +} + +type ResolvePacketConn struct { + net.PacketConn + router adapter.Router + strategy C.DomainStrategy +} + +func (w *ResolvePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + _, addr, err := buffer.ReadPacketFrom(w) + if err != nil { + return M.Socksaddr{}, err + } + return M.SocksaddrFromNet(addr), err +} + +func (w *ResolvePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + if destination.Family().IsFqdn() { + addresses, err := w.router.Lookup(context.Background(), destination.Fqdn, w.strategy) + if err != nil { + return err + } + return common.Error(w.WriteTo(buffer.Bytes(), M.SocksaddrFromAddrPort(addresses[0], destination.Port).UDPAddr())) + } + return common.Error(w.WriteTo(buffer.Bytes(), destination.UDPAddr())) +} + +func (w *ResolvePacketConn) Upstream() any { + return w.PacketConn +} diff --git a/common/dialer/serial.go b/common/dialer/serial.go index b5508e94..023443fe 100644 --- a/common/dialer/serial.go +++ b/common/dialer/serial.go @@ -18,6 +18,7 @@ func DialSerial(ctx context.Context, dialer N.Dialer, network string, destinatio conn, err = dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port)) if err != nil { connErrors = append(connErrors, err) + continue } return conn, nil } diff --git a/dns/client.go b/dns/client.go index 2f7c21d5..1f596861 100644 --- a/dns/client.go +++ b/dns/client.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "strings" "time" "github.com/sagernet/sing/common" @@ -71,11 +72,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m if !c.disableCache { c.storeCache(question, response) } - return message, err + return response, err } func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) { - dnsName, err := dnsmessage.NewName(domain) + if strings.HasPrefix(domain, ".") { + domain = domain[:len(domain)-1] + } + dnsName, err := dnsmessage.NewName(domain + ".") if err != nil { return nil, wrapError(err) } diff --git a/dns/transport.go b/dns/transport.go index e6a2c883..14894196 100644 --- a/dns/transport.go +++ b/dns/transport.go @@ -22,8 +22,15 @@ func NewTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, addre } host := serverURL.Hostname() port := serverURL.Port() - if port == "" { - port = "53" + switch serverURL.Scheme { + case "tls": + if port == "" { + port = "853" + } + default: + if port == "" { + port = "53" + } } destination := M.ParseSocksaddrHostPortStr(host, port) switch serverURL.Scheme { diff --git a/dns/transport_tcp.go b/dns/transport_tcp.go index 945928b9..5e0f502a 100644 --- a/dns/transport_tcp.go +++ b/dns/transport_tcp.go @@ -77,10 +77,9 @@ func (t *TCPTransport) offer() (*dnsConnection, error) { func (t *TCPTransport) newConnection(conn *dnsConnection) { defer close(conn.done) defer conn.Close() - ctx, cancel := context.WithCancel(t.ctx) - err := task.Any(t.ctx, func() error { + err := task.Any(t.ctx, func(ctx context.Context) error { return t.loopIn(conn) - }, func() error { + }, func(ctx context.Context) error { select { case <-ctx.Done(): return nil @@ -88,7 +87,6 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) { return os.ErrClosed } }) - cancel() conn.err = err if err != nil { t.logger.Debug("connection closed: ", err) diff --git a/dns/transport_tls.go b/dns/transport_tls.go index b28dbfd7..152bc3fa 100644 --- a/dns/transport_tls.go +++ b/dns/transport_tls.go @@ -85,10 +85,9 @@ func (t *TLSTransport) offer(ctx context.Context) (*dnsConnection, error) { func (t *TLSTransport) newConnection(conn *dnsConnection) { defer close(conn.done) defer conn.Close() - ctx, cancel := context.WithCancel(t.ctx) - err := task.Any(t.ctx, func() error { + err := task.Any(t.ctx, func(ctx context.Context) error { return t.loopIn(conn) - }, func() error { + }, func(ctx context.Context) error { select { case <-ctx.Done(): return nil @@ -96,7 +95,6 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) { return os.ErrClosed } }) - cancel() conn.err = err if err != nil { t.logger.Debug("connection closed: ", err) diff --git a/dns/transport_udp.go b/dns/transport_udp.go index a0358f80..d60c171b 100644 --- a/dns/transport_udp.go +++ b/dns/transport_udp.go @@ -73,10 +73,9 @@ func (t *UDPTransport) offer() (*dnsConnection, error) { func (t *UDPTransport) newConnection(conn *dnsConnection) { defer close(conn.done) defer conn.Close() - ctx, cancel := context.WithCancel(t.ctx) - err := task.Any(t.ctx, func() error { + err := task.Any(t.ctx, func(ctx context.Context) error { return t.loopIn(conn) - }, func() error { + }, func(ctx context.Context) error { select { case <-ctx.Done(): return nil @@ -84,7 +83,6 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) { return os.ErrClosed } }) - cancel() conn.err = err if err != nil { t.logger.Debug("connection closed: ", err) diff --git a/go.mod b/go.mod index 228f4d00..d282bcb2 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,13 @@ require ( github.com/goccy/go-json v0.9.8 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/oschwald/maxminddb-golang v1.9.0 - github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4 + github.com/sagernet/sing v0.0.0-20220708041648-04e100e91a92 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 github.com/sirupsen/logrus v1.8.1 github.com/spf13/cobra v1.5.0 github.com/stretchr/testify v1.8.0 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d - golang.org/x/net v0.0.0-20220630215102-69896b714898 + golang.org/x/net v0.0.0-20220706163947-c90051bbdb60 ) require ( diff --git a/go.sum b/go.sum index a78d1274..cf1b5d41 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4 h1:nV/DyNi+O1VxNoChD5E9M6Y0VoFdVr0UEW9h9JnqxNs= -github.com/sagernet/sing v0.0.0-20220707133944-6a0987c52ae4/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= +github.com/sagernet/sing v0.0.0-20220708041648-04e100e91a92 h1:c+Jg/o4UBZ+7CFdKWy8XhPN5X1rtulYdMqdgjx6PNUo= +github.com/sagernet/sing v0.0.0-20220708041648-04e100e91a92/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 h1:whNDUGOAX5GPZkSy4G3Gv9QyIgk5SXRyjkRuP7ohF8k= github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649/go.mod h1:MuyT+9fEPjvauAv0fSE0a6Q+l0Tv2ZrAafTkYfnxBFw= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= @@ -41,8 +41,8 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw= -golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220706163947-c90051bbdb60 h1:8NSylCMxLW4JvserAndSgFL7aPli6A68yf0bYFTcWCM= +golang.org/x/net v0.0.0-20220706163947-c90051bbdb60/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b h1:2n253B2r0pYSmEV+UNCQoPfU/FiaizQEK5Gu4Bq4JE8= golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/inbound/default.go b/inbound/default.go index 5c7d5f66..5bc7f9af 100644 --- a/inbound/default.go +++ b/inbound/default.go @@ -142,6 +142,7 @@ func (a *myInboundAdapter) loopTCPIn() { a.logger.WithContext(ctx).Info("inbound connection from ", metadata.Source) hErr := a.connHandler.NewConnection(ctx, conn, metadata) if hErr != nil { + conn.Close() a.NewError(ctx, E.Cause(hErr, "process connection from ", metadata.Source)) } }() diff --git a/option/outbound.go b/option/outbound.go index c2376b72..344f2adc 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -67,19 +67,20 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error { } type DialerOptions struct { - Detour string `json:"detour,omitempty"` - BindInterface string `json:"bind_interface,omitempty"` - ProtectPath string `json:"protect_path,omitempty"` - RoutingMark int `json:"routing_mark,omitempty"` - ReuseAddr bool `json:"reuse_addr,omitempty"` - ConnectTimeout int `json:"connect_timeout,omitempty"` - TCPFastOpen bool `json:"tcp_fast_open,omitempty"` + Detour string `json:"detour,omitempty"` + BindInterface string `json:"bind_interface,omitempty"` + ProtectPath string `json:"protect_path,omitempty"` + RoutingMark int `json:"routing_mark,omitempty"` + ReuseAddr bool `json:"reuse_addr,omitempty"` + ConnectTimeout Duration `json:"connect_timeout,omitempty"` + TCPFastOpen bool `json:"tcp_fast_open,omitempty"` } type OutboundDialerOptions struct { DialerOptions OverrideOptions *OverrideStreamOptions `json:"override,omitempty"` DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` + FallbackDelay Duration `json:"fallback_delay,omitempty"` } type OverrideStreamOptions struct { diff --git a/option/types.go b/option/types.go index dca2ddd4..dfc5aef4 100644 --- a/option/types.go +++ b/option/types.go @@ -3,6 +3,7 @@ package option import ( "net/netip" "strings" + "time" E "github.com/sagernet/sing/common/exceptions" @@ -135,3 +136,23 @@ func (s *DomainStrategy) UnmarshalJSON(bytes []byte) error { } return nil } + +type Duration time.Duration + +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal((time.Duration)(d).String()) +} + +func (d *Duration) UnmarshalJSON(bytes []byte) error { + var value string + err := json.Unmarshal(bytes, &value) + if err != nil { + return err + } + duration, err := time.ParseDuration(value) + if err != nil { + return err + } + *d = Duration(duration) + return nil +} diff --git a/route/router.go b/route/router.go index 83bf4e69..0f34cc62 100644 --- a/route/router.go +++ b/route/router.go @@ -450,7 +450,6 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def r.logger.WithContext(ctx).Error("outbound not found: ", detour) } } - r.logger.WithContext(ctx).Info("no match") return defaultOutbound } @@ -470,7 +469,6 @@ func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport { r.dnsLogger.WithContext(ctx).Error("transport not found: ", detour) } } - r.dnsLogger.WithContext(ctx).Info("no match") return r.defaultTransport }