Improve multiplexer

This commit is contained in:
世界 2022-08-03 21:51:34 +08:00
parent 8e4de29409
commit 03890151d7
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 247 additions and 56 deletions

View file

@ -15,40 +15,44 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/hashicorp/yamux"
)
var _ N.Dialer = (*Client)(nil)
type Client struct {
access sync.Mutex
connections list.List[*yamux.Session]
connections list.List[abstractSession]
ctx context.Context
dialer N.Dialer
protocol Protocol
maxConnections int
minStreams int
maxStreams int
}
func NewClient(ctx context.Context, dialer N.Dialer, maxConnections int, minStreams int, maxStreams int) *Client {
func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client {
return &Client{
ctx: ctx,
dialer: dialer,
protocol: protocol,
maxConnections: maxConnections,
minStreams: minStreams,
maxStreams: maxStreams,
}
}
func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) N.Dialer {
func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) {
if !options.Enabled {
return dialer
return dialer, nil
}
if options.MaxConnections == 0 && options.MaxStreams == 0 {
options.MinStreams = 8
}
return NewClient(ctx, dialer, options.MaxConnections, options.MinStreams, options.MaxStreams)
protocol, err := ParseProtocol(options.Protocol)
if err != nil {
return nil, err
}
return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil
}
func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
@ -80,8 +84,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
func (c *Client) openStream() (net.Conn, error) {
var (
session *yamux.Session
stream *yamux.Stream
session abstractSession
stream net.Conn
err error
)
for attempts := 0; attempts < 2; attempts++ {
@ -89,7 +93,7 @@ func (c *Client) openStream() (net.Conn, error) {
if err != nil {
continue
}
stream, err = session.OpenStream()
stream, err = session.Open()
if err != nil {
continue
}
@ -101,11 +105,11 @@ func (c *Client) openStream() (net.Conn, error) {
return &wrapStream{stream}, nil
}
func (c *Client) offer() (*yamux.Session, error) {
func (c *Client) offer() (abstractSession, error) {
c.access.Lock()
defer c.access.Unlock()
sessions := make([]*yamux.Session, 0, c.maxConnections)
sessions := make([]abstractSession, 0, c.maxConnections)
for element := c.connections.Front(); element != nil; {
if element.Value.IsClosed() {
nextElement := element.Next()
@ -120,10 +124,7 @@ func (c *Client) offer() (*yamux.Session, error) {
if sLen == 0 {
return c.offerNew()
}
// session := common.MinBy(sessions, yamux.Session.NumStreams)
session := common.MinBy(sessions, func(it *yamux.Session) int {
return it.NumStreams()
})
session := common.MinBy(sessions, abstractSession.NumStreams)
numStreams := session.NumStreams()
if numStreams == 0 {
return session, nil
@ -140,12 +141,12 @@ func (c *Client) offer() (*yamux.Session, error) {
return c.offerNew()
}
func (c *Client) offerNew() (*yamux.Session, error) {
func (c *Client) offerNew() (abstractSession, error) {
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
if err != nil {
return nil, err
}
session, err := yamux.Client(conn, newMuxConfig())
session, err := c.protocol.newClient(&protocolConn{Conn: conn, protocol: c.protocol})
if err != nil {
return nil, err
}
@ -170,7 +171,7 @@ type ClientConn struct {
}
func (c *ClientConn) readResponse() error {
response, err := ReadResponse(c.Conn)
response, err := ReadStreamResponse(c.Conn)
if err != nil {
return err
}
@ -195,7 +196,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
if c.requestWrite {
return c.Conn.Write(b)
}
request := Request{
request := StreamRequest{
Network: N.NetworkTCP,
Destination: c.destination,
}
@ -203,7 +204,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) {
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeRequest(request, buffer)
EncodeStreamRequest(request, buffer)
buffer.Write(b)
_, err = c.Conn.Write(buffer.Bytes())
if err != nil {
@ -255,7 +256,7 @@ type ClientPacketConn struct {
}
func (c *ClientPacketConn) readResponse() error {
response, err := ReadResponse(c.ExtendedConn)
response, err := ReadStreamResponse(c.ExtendedConn)
if err != nil {
return err
}
@ -285,7 +286,7 @@ func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
}
func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
request := Request{
request := StreamRequest{
Network: N.NetworkUDP,
Destination: c.destination,
}
@ -297,7 +298,7 @@ func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeRequest(request, buffer)
EncodeStreamRequest(request, buffer)
if len(payload) > 0 {
common.Must(
binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
@ -363,7 +364,7 @@ type ClientPacketAddrConn struct {
}
func (c *ClientPacketAddrConn) readResponse() error {
response, err := ReadResponse(c.ExtendedConn)
response, err := ReadStreamResponse(c.ExtendedConn)
if err != nil {
return err
}
@ -399,7 +400,7 @@ func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
}
func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
request := Request{
request := StreamRequest{
Network: N.NetworkUDP,
Destination: c.destination,
PacketAddr: true,
@ -412,7 +413,7 @@ func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeRequest(request, buffer)
EncodeStreamRequest(request, buffer)
if len(payload) > 0 {
common.Must(
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),

View file

@ -14,6 +14,7 @@ import (
"github.com/sagernet/sing/common/rw"
"github.com/hashicorp/yamux"
"github.com/xtaci/smux"
)
var Destination = M.Socksaddr{
@ -21,7 +22,55 @@ var Destination = M.Socksaddr{
Port: 444,
}
func newMuxConfig() *yamux.Config {
const (
ProtocolYAMux Protocol = 0
ProtocolSMux Protocol = 1
)
type Protocol byte
func ParseProtocol(name string) (Protocol, error) {
switch name {
case "", "yamux":
return ProtocolYAMux, nil
case "smux":
return ProtocolSMux, nil
default:
return ProtocolYAMux, E.New("unknown multiplex protocol: ", name)
}
}
func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
switch p {
case ProtocolYAMux:
return yamux.Server(conn, yaMuxConfig())
case ProtocolSMux:
session, err := smux.Server(conn, nil)
if err != nil {
return nil, err
}
return &smuxSession{session}, nil
default:
panic("unknown protocol")
}
}
func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
switch p {
case ProtocolYAMux:
return yamux.Client(conn, yaMuxConfig())
case ProtocolSMux:
session, err := smux.Client(conn, nil)
if err != nil {
return nil, err
}
return &smuxSession{session}, nil
default:
panic("unknown protocol")
}
}
func yaMuxConfig() *yamux.Config {
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
config.StreamCloseTimeout = C.TCPTimeout
@ -29,18 +78,23 @@ func newMuxConfig() *yamux.Config {
return config
}
func (p Protocol) String() string {
switch p {
case ProtocolYAMux:
return "yamux"
case ProtocolSMux:
return "smux"
default:
return "unknown"
}
}
const (
version0 = 0
flagUDP = 1
flagAddr = 2
statusSuccess = 0
statusError = 1
version0 = 0
)
type Request struct {
Network string
Destination M.Socksaddr
PacketAddr bool
Protocol Protocol
}
func ReadRequest(reader io.Reader) (*Request, error) {
@ -51,8 +105,37 @@ func ReadRequest(reader io.Reader) (*Request, error) {
if version != version0 {
return nil, E.New("unsupported version: ", version)
}
protocol, err := rw.ReadByte(reader)
if err != nil {
return nil, err
}
if protocol > byte(ProtocolSMux) {
return nil, E.New("unsupported protocol: ", protocol)
}
return &Request{Protocol: Protocol(protocol)}, nil
}
func EncodeRequest(buffer *buf.Buffer, request Request) {
buffer.WriteByte(version0)
buffer.WriteByte(byte(request.Protocol))
}
const (
flagUDP = 1
flagAddr = 2
statusSuccess = 0
statusError = 1
)
type StreamRequest struct {
Network string
Destination M.Socksaddr
PacketAddr bool
}
func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
var flags uint16
err = binary.Read(reader, binary.BigEndian, &flags)
err := binary.Read(reader, binary.BigEndian, &flags)
if err != nil {
return nil, err
}
@ -68,10 +151,10 @@ func ReadRequest(reader io.Reader) (*Request, error) {
network = N.NetworkUDP
udpAddr = flags&flagAddr != 0
}
return &Request{network, destination, udpAddr}, nil
return &StreamRequest{network, destination, udpAddr}, nil
}
func requestLen(request Request) int {
func requestLen(request StreamRequest) int {
var rLen int
rLen += 1 // version
rLen += 2 // flags
@ -79,7 +162,7 @@ func requestLen(request Request) int {
return rLen
}
func EncodeRequest(request Request, buffer *buf.Buffer) {
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
destination := request.Destination
var flags uint16
if request.Network == N.NetworkUDP {
@ -92,19 +175,18 @@ func EncodeRequest(request Request, buffer *buf.Buffer) {
}
}
common.Must(
buffer.WriteByte(version0),
binary.Write(buffer, binary.BigEndian, flags),
M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
)
}
type Response struct {
type StreamResponse struct {
Status uint8
Message string
}
func ReadResponse(reader io.Reader) (*Response, error) {
var response Response
func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
var response StreamResponse
status, err := rw.ReadByte(reader)
if err != nil {
return nil, err

View file

@ -14,12 +14,14 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/hashicorp/yamux"
)
func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error {
session, err := yamux.Server(conn, newMuxConfig())
request, err := ReadRequest(conn)
if err != nil {
return err
}
session, err := request.Protocol.newServer(conn)
if err != nil {
return err
}
@ -34,7 +36,7 @@ func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Ha
func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) {
stream = &wrapStream{stream}
request, err := ReadRequest(stream)
request, err := ReadStreamRequest(stream)
if err != nil {
logger.ErrorContext(ctx, err)
return

71
common/mux/session.go Normal file
View file

@ -0,0 +1,71 @@
package mux
import (
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/xtaci/smux"
)
type abstractSession interface {
Open() (net.Conn, error)
Accept() (net.Conn, error)
NumStreams() int
Close() error
IsClosed() bool
}
var _ abstractSession = (*smuxSession)(nil)
type smuxSession struct {
*smux.Session
}
func (s *smuxSession) Open() (net.Conn, error) {
return s.OpenStream()
}
func (s *smuxSession) Accept() (net.Conn, error) {
return s.AcceptStream()
}
type protocolConn struct {
net.Conn
protocol Protocol
protocolWritten bool
}
func (c *protocolConn) Write(p []byte) (n int, err error) {
if c.protocolWritten {
return c.Conn.Write(p)
}
_buffer := buf.StackNewSize(2 + len(p))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
EncodeRequest(buffer, Request{
Protocol: c.protocol,
})
common.Must(common.Error(buffer.Write(p)))
n, err = c.Conn.Write(buffer.Bytes())
if err == nil {
n--
}
c.protocolWritten = true
return n, err
}
func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.protocolWritten {
return bufio.ReadFrom0(c, r)
}
return bufio.Copy(c.Conn, r)
}
func (c *protocolConn) Upstream() any {
return c.Conn
}

View file

@ -7,6 +7,7 @@
```json
{
"enabled": true,
"protocol": "yamux",
"max_connections": 4,
"min_streams": 4,
"max_streams": 0
@ -19,6 +20,17 @@
Enable multiplex.
#### protocol
Multiplex protocol.
| Protocol | Description |
|----------|------------------------------------|
| yamux | https://github.com/hashicorp/yamux |
| smux | https://github.com/xtaci/smux |
YAMux is used by default.
#### max_connections
Maximum connections.

1
go.mod
View file

@ -20,6 +20,7 @@ require (
github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9
github.com/spf13/cobra v1.5.0
github.com/stretchr/testify v1.8.0
github.com/xtaci/smux v1.5.16
go.uber.org/atomic v1.9.0
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b

2
go.sum
View file

@ -201,6 +201,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk=
github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=

View file

@ -100,8 +100,9 @@ func (o ServerOptions) Build() M.Socksaddr {
}
type MultiplexOptions struct {
Enabled bool `json:"enabled,omitempty"`
MaxConnections int `json:"max_connections,omitempty"`
MinStreams int `json:"min_streams,omitempty"`
MaxStreams int `json:"max_streams,omitempty"`
Enabled bool `json:"enabled,omitempty"`
Protocol string `json:"protocol,omitempty"`
MaxConnections int `json:"max_connections,omitempty"`
MinStreams int `json:"min_streams,omitempty"`
MaxStreams int `json:"max_streams,omitempty"`
}

View file

@ -46,7 +46,10 @@ func NewShadowsocks(ctx context.Context, router adapter.Router, logger log.Conte
method: method,
serverAddr: options.ServerOptions.Build(),
}
outbound.multiplexDialer = mux.NewClientWithOptions(ctx, (*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.Multiplex))
outbound.multiplexDialer, err = mux.NewClientWithOptions(ctx, (*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.Multiplex))
if err != nil {
return nil, err
}
return outbound, nil
}

View file

@ -32,13 +32,13 @@ func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dn
}
func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
r.dnsLogger.Debug(ctx, "lookup domain ", domain)
r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
ctx, transport := r.matchDNS(ctx)
ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout)
defer cancel()
addrs, err := r.dnsClient.Lookup(ctx, transport, domain, strategy)
if len(addrs) > 0 {
r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", F.MapToString(addrs))
r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(addrs), " "))
} else {
r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
}

View file

@ -58,6 +58,7 @@ require (
github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
github.com/xtaci/smux v1.5.16 // indirect
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
golang.org/x/mod v0.5.1 // indirect

View file

@ -228,6 +228,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk=
github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

View file

@ -4,12 +4,24 @@ import (
"net/netip"
"testing"
"github.com/sagernet/sing-box/common/mux"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
)
func TestShadowsocksMux(t *testing.T) {
for _, protocol := range []mux.Protocol{
mux.ProtocolYAMux,
mux.ProtocolSMux,
} {
t.Run(protocol.String(), func(t *testing.T) {
testShadowsocksMux(t, protocol.String())
})
}
}
func testShadowsocksMux(t *testing.T, protocol string) {
method := shadowaead_2022.List[0]
password := mkBase64(t, 16)
startInstance(t, option.Options{
@ -54,7 +66,8 @@ func TestShadowsocksMux(t *testing.T) {
Method: method,
Password: password,
Multiplex: &option.MultiplexOptions{
Enabled: true,
Enabled: true,
Protocol: protocol,
},
},
},