diff --git a/inbound/trojan.go b/inbound/trojan.go index 026de488..ac639d10 100644 --- a/inbound/trojan.go +++ b/inbound/trojan.go @@ -10,6 +10,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/trojan" "github.com/sagernet/sing-box/transport/v2ray" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" @@ -17,7 +18,6 @@ import ( F "github.com/sagernet/sing/common/format" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/protocol/trojan" ) var ( @@ -157,7 +157,7 @@ func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adap return err } } - return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata)) + return h.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata)) } func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { diff --git a/outbound/trojan.go b/outbound/trojan.go index 4db92f41..7c11d445 100644 --- a/outbound/trojan.go +++ b/outbound/trojan.go @@ -11,13 +11,13 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/trojan" "github.com/sagernet/sing-box/transport/v2ray" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/protocol/trojan" ) var _ adapter.Outbound = (*Trojan)(nil) diff --git a/transport/trojan/mux.go b/transport/trojan/mux.go new file mode 100644 index 00000000..745cde56 --- /dev/null +++ b/transport/trojan/mux.go @@ -0,0 +1,66 @@ +package trojan + +import ( + "context" + "net" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/rw" + "github.com/sagernet/sing/common/task" + "github.com/sagernet/smux" +) + +func HandleMuxConnection(ctx context.Context, conn net.Conn, metadata M.Metadata, handler Handler) error { + session, err := smux.Server(conn, smuxConfig()) + if err != nil { + return err + } + var group task.Group + group.Append0(func(ctx context.Context) error { + var stream net.Conn + for { + stream, err = session.AcceptStream() + if err != nil { + return err + } + go newMuxConnection(ctx, stream, metadata, handler) + } + }) + group.Cleanup(func() { + session.Close() + }) + return group.Run(ctx) +} + +func newMuxConnection(ctx context.Context, stream net.Conn, metadata M.Metadata, handler Handler) { + err := newMuxConnection0(ctx, stream, metadata, handler) + if err != nil { + handler.NewError(ctx, E.Cause(err, "process trojan-go multiplex connection")) + } +} + +func newMuxConnection0(ctx context.Context, stream net.Conn, metadata M.Metadata, handler Handler) error { + command, err := rw.ReadByte(stream) + if err != nil { + return E.Cause(err, "read command") + } + metadata.Destination, err = M.SocksaddrSerializer.ReadAddrPort(stream) + if err != nil { + return E.Cause(err, "read destination") + } + switch command { + case CommandTCP: + return handler.NewConnection(ctx, stream, metadata) + case CommandUDP: + return handler.NewPacketConnection(ctx, &PacketConn{stream}, metadata) + default: + return E.New("unknown command ", command) + } +} + +func smuxConfig() *smux.Config { + config := smux.DefaultConfig() + config.KeepAliveDisabled = true + return config +} diff --git a/transport/trojan/protocol.go b/transport/trojan/protocol.go new file mode 100644 index 00000000..d05a9d36 --- /dev/null +++ b/transport/trojan/protocol.go @@ -0,0 +1,313 @@ +package trojan + +import ( + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "io" + "net" + "os" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" +) + +const ( + KeyLength = 56 + CommandTCP = 1 + CommandUDP = 3 + CommandMux = 0x7f +) + +var CRLF = []byte{'\r', '\n'} + +type ClientConn struct { + N.ExtendedConn + key [KeyLength]byte + destination M.Socksaddr + headerWritten bool +} + +func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn { + return &ClientConn{ + ExtendedConn: bufio.NewExtendedConn(conn), + key: key, + destination: destination, + } +} + +func (c *ClientConn) Write(p []byte) (n int, err error) { + if c.headerWritten { + return c.ExtendedConn.Write(p) + } + err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p) + if err != nil { + return + } + n = len(p) + c.headerWritten = true + return +} + +func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error { + if c.headerWritten { + return c.ExtendedConn.WriteBuffer(buffer) + } + err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer) + if err != nil { + return err + } + c.headerWritten = true + return nil +} + +func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) { + if !c.headerWritten { + return bufio.ReadFrom0(c, r) + } + return bufio.Copy(c.ExtendedConn, r) +} + +func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) { + return bufio.Copy(w, c.ExtendedConn) +} + +func (c *ClientConn) FrontHeadroom() int { + if !c.headerWritten { + return KeyLength + 5 + M.MaxSocksaddrLength + } + return 0 +} + +func (c *ClientConn) Upstream() any { + return c.ExtendedConn +} + +type ClientPacketConn struct { + net.Conn + key [KeyLength]byte + headerWritten bool +} + +func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { + return &ClientPacketConn{ + Conn: conn, + key: key, + } +} + +func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + return ReadPacket(c.Conn, buffer) +} + +func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if !c.headerWritten { + err := ClientHandshakePacket(c.Conn, c.key, destination, buffer) + c.headerWritten = true + return err + } + return WritePacket(c.Conn, buffer, destination) +} + +func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + buffer := buf.With(p) + destination, err := c.ReadPacket(buffer) + if err != nil { + return + } + n = buffer.Len() + addr = destination.UDPAddr() + return +} + +func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return bufio.WritePacket(c, p, addr) +} + +func (c *ClientPacketConn) Read(p []byte) (n int, err error) { + n, _, err = c.ReadFrom(p) + return +} + +func (c *ClientPacketConn) Write(p []byte) (n int, err error) { + return 0, os.ErrInvalid +} + +func (c *ClientPacketConn) FrontHeadroom() int { + if !c.headerWritten { + return KeyLength + 2*M.MaxSocksaddrLength + 9 + } + return M.MaxSocksaddrLength + 4 +} + +func (c *ClientPacketConn) Upstream() any { + return c.Conn +} + +func Key(password string) [KeyLength]byte { + var key [KeyLength]byte + hash := sha256.New224() + common.Must1(hash.Write([]byte(password))) + hex.Encode(key[:], hash.Sum(nil)) + return key +} + +func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error { + _, err := conn.Write(key[:]) + if err != nil { + return err + } + _, err = conn.Write(CRLF) + if err != nil { + return err + } + _, err = conn.Write([]byte{command}) + if err != nil { + return err + } + err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) + if err != nil { + return err + } + _, err = conn.Write(CRLF) + if err != nil { + return err + } + if len(payload) > 0 { + _, err = conn.Write(payload) + if err != nil { + return err + } + } + return nil +} + +func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error { + headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5 + var header *buf.Buffer + defer header.Release() + var writeHeader bool + if len(payload) > 0 && headerLen+len(payload) < 65535 { + buffer := buf.StackNewSize(headerLen + len(payload)) + defer common.KeepAlive(buffer) + header = common.Dup(buffer) + } else { + buffer := buf.StackNewSize(headerLen) + defer common.KeepAlive(buffer) + header = common.Dup(buffer) + writeHeader = true + } + common.Must1(header.Write(key[:])) + common.Must1(header.Write(CRLF)) + common.Must(header.WriteByte(CommandTCP)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + common.Must1(header.Write(CRLF)) + if !writeHeader { + common.Must1(header.Write(payload)) + } + + _, err := conn.Write(header.Bytes()) + if err != nil { + return E.Cause(err, "write request") + } + + if writeHeader { + _, err = conn.Write(payload) + if err != nil { + return E.Cause(err, "write payload") + } + } + return nil +} + +func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { + header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5)) + common.Must1(header.Write(key[:])) + common.Must1(header.Write(CRLF)) + common.Must(header.WriteByte(CommandTCP)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + common.Must1(header.Write(CRLF)) + + _, err := conn.Write(payload.Bytes()) + if err != nil { + return E.Cause(err, "write request") + } + return nil +} + +func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { + headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9 + payloadLen := payload.Len() + var header *buf.Buffer + defer header.Release() + var writeHeader bool + if payload.Start() >= headerLen { + header = buf.With(payload.ExtendHeader(headerLen)) + } else { + buffer := buf.StackNewSize(headerLen) + defer common.KeepAlive(buffer) + header = common.Dup(buffer) + writeHeader = true + } + common.Must1(header.Write(key[:])) + common.Must1(header.Write(CRLF)) + common.Must(header.WriteByte(CommandUDP)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + common.Must1(header.Write(CRLF)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen))) + common.Must1(header.Write(CRLF)) + + if writeHeader { + _, err := conn.Write(header.Bytes()) + if err != nil { + return E.Cause(err, "write request") + } + } + + _, err := conn.Write(payload.Bytes()) + if err != nil { + return E.Cause(err, "write payload") + } + return nil +} + +func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) { + destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "read destination") + } + + var length uint16 + err = binary.Read(conn, binary.BigEndian, &length) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "read chunk length") + } + + err = rw.SkipN(conn, 2) + if err != nil { + return M.Socksaddr{}, E.Cause(err, "skip crlf") + } + + _, err = buffer.ReadFullFrom(conn, int(length)) + return destination, err +} + +func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + bufferLen := buffer.Len() + header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen))) + common.Must1(header.Write(CRLF)) + _, err := conn.Write(buffer.Bytes()) + if err != nil { + return E.Cause(err, "write packet") + } + return nil +} diff --git a/transport/trojan/service.go b/transport/trojan/service.go new file mode 100644 index 00000000..453423d2 --- /dev/null +++ b/transport/trojan/service.go @@ -0,0 +1,138 @@ +package trojan + +import ( + "context" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" +) + +type Handler interface { + N.TCPConnectionHandler + N.UDPConnectionHandler + E.Handler +} + +type Service[K comparable] struct { + users map[K][56]byte + keys map[[56]byte]K + handler Handler + fallbackHandler N.TCPConnectionHandler +} + +func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandler) *Service[K] { + return &Service[K]{ + users: make(map[K][56]byte), + keys: make(map[[56]byte]K), + handler: handler, + fallbackHandler: fallbackHandler, + } +} + +var ErrUserExists = E.New("user already exists") + +func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error { + users := make(map[K][56]byte) + keys := make(map[[56]byte]K) + for i, user := range userList { + if _, loaded := users[user]; loaded { + return ErrUserExists + } + key := Key(passwordList[i]) + if oldUser, loaded := keys[key]; loaded { + return E.Extend(ErrUserExists, "password used by ", oldUser) + } + users[user] = key + keys[key] = user + } + s.users = users + s.keys = keys + return nil +} + +func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { + var key [KeyLength]byte + n, err := conn.Read(common.Dup(key[:])) + if err != nil { + return err + } else if n != KeyLength { + return s.fallback(ctx, conn, metadata, key[:n], E.New("bad request size")) + } + + if user, loaded := s.keys[key]; loaded { + ctx = auth.ContextWithUser(ctx, user) + } else { + return s.fallback(ctx, conn, metadata, key[:], E.New("bad request")) + } + + err = rw.SkipN(conn, 2) + if err != nil { + return E.Cause(err, "skip crlf") + } + + command, err := rw.ReadByte(conn) + if err != nil { + return E.Cause(err, "read command") + } + + switch command { + case CommandTCP, CommandUDP, CommandMux: + default: + return E.New("unknown command ", command) + } + + // var destination M.Socksaddr + destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) + if err != nil { + return E.Cause(err, "read destination") + } + + err = rw.SkipN(conn, 2) + if err != nil { + return E.Cause(err, "skip crlf") + } + + metadata.Protocol = "trojan" + metadata.Destination = destination + + switch command { + case CommandTCP: + return s.handler.NewConnection(ctx, conn, metadata) + case CommandUDP: + return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata) + // case CommandMux: + default: + return HandleMuxConnection(ctx, conn, metadata, s.handler) + } +} + +func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Metadata, header []byte, err error) error { + if s.fallbackHandler == nil { + return E.Extend(err, "fallback disabled") + } + conn = bufio.NewCachedConn(conn, buf.As(header).ToOwned()) + return s.fallbackHandler.NewConnection(ctx, conn, metadata) +} + +type PacketConn struct { + net.Conn +} + +func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + return ReadPacket(c.Conn, buffer) +} + +func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return WritePacket(c.Conn, buffer, destination) +} + +func (c *PacketConn) FrontHeadroom() int { + return M.MaxSocksaddrLength + 4 +}