package mux import ( "context" "encoding/binary" "net" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/hashicorp/yamux" ) func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error { session, err := yamux.Server(conn, newMuxConfig()) if err != nil { return err } for { stream, err := session.Accept() if err != nil { return err } stream = &wrapStream{stream} request, err := ReadRequest(stream) if err != nil { return err } metadata.Destination = request.Destination if request.Network == N.NetworkTCP { go func() { logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) hErr := router.RouteConnection(ctx, &ServerConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata) if hErr != nil { stream.Close() errorHandler.NewError(ctx, hErr) } }() } else { go func() { var packetConn N.PacketConn if !request.PacketAddr { logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) packetConn = &ServerPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} } else { logger.InfoContext(ctx, "inbound multiplex packet connection") packetConn = &ServerPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} } hErr := router.RoutePacketConnection(ctx, packetConn, metadata) if hErr != nil { stream.Close() errorHandler.NewError(ctx, hErr) } }() } } } var _ N.HandshakeConn = (*ServerConn)(nil) type ServerConn struct { N.ExtendedConn responseWrite bool } func (c *ServerConn) HandshakeFailure(err error) error { errMessage := err.Error() _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() common.Must( buffer.WriteByte(statusError), rw.WriteVString(_buffer, errMessage), ) return c.ExtendedConn.WriteBuffer(buffer) } func (c *ServerConn) Write(b []byte) (n int, err error) { if c.responseWrite { return c.ExtendedConn.Write(b) } _buffer := buf.StackNewSize(1 + len(b)) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() common.Must( buffer.WriteByte(statusSuccess), common.Error(buffer.Write(b)), ) _, err = c.ExtendedConn.Write(buffer.Bytes()) if err != nil { return } c.responseWrite = true return len(b), nil } func (c *ServerConn) WriteBuffer(buffer *buf.Buffer) error { if c.responseWrite { return c.ExtendedConn.WriteBuffer(buffer) } buffer.ExtendHeader(1)[0] = statusSuccess c.responseWrite = true return c.ExtendedConn.WriteBuffer(buffer) } func (c *ServerConn) Headroom() int { if !c.responseWrite { return 1 } return 0 } func (c *ServerConn) Upstream() any { return c.ExtendedConn } var ( _ N.HandshakeConn = (*ServerPacketConn)(nil) _ N.PacketConn = (*ServerPacketConn)(nil) ) type ServerPacketConn struct { N.ExtendedConn destination M.Socksaddr responseWrite bool } func (c *ServerPacketConn) HandshakeFailure(err error) error { errMessage := err.Error() _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() common.Must( buffer.WriteByte(statusError), rw.WriteVString(_buffer, errMessage), ) return c.ExtendedConn.WriteBuffer(buffer) } func (c *ServerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { var length uint16 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) if err != nil { return } _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) if err != nil { return } destination = c.destination return } func (c *ServerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { pLen := buffer.Len() common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) if !c.responseWrite { buffer.ExtendHeader(1)[0] = statusSuccess c.responseWrite = true } return c.ExtendedConn.WriteBuffer(buffer) } func (c *ServerPacketConn) Upstream() any { return c.ExtendedConn } func (c *ServerPacketConn) Headroom() int { if !c.responseWrite { return 3 } return 2 } var ( _ N.HandshakeConn = (*ServerPacketAddrConn)(nil) _ N.PacketConn = (*ServerPacketAddrConn)(nil) ) type ServerPacketAddrConn struct { N.ExtendedConn responseWrite bool } func (c *ServerPacketAddrConn) HandshakeFailure(err error) error { errMessage := err.Error() _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() common.Must( buffer.WriteByte(statusError), rw.WriteVString(_buffer, errMessage), ) return c.ExtendedConn.WriteBuffer(buffer) } func (c *ServerPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) if err != nil { return } var length uint16 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) if err != nil { return } _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) if err != nil { return } return } func (c *ServerPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { pLen := buffer.Len() common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination)) if !c.responseWrite { buffer.ExtendHeader(1)[0] = statusSuccess c.responseWrite = true } return c.ExtendedConn.WriteBuffer(buffer) } func (c *ServerPacketAddrConn) Upstream() any { return c.ExtendedConn } func (c *ServerPacketAddrConn) Headroom() int { if !c.responseWrite { return 3 + M.MaxSocksaddrLength } return 2 + M.MaxSocksaddrLength }