diff --git a/go.mod b/go.mod index c4442319..ae7d2b1c 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/sagernet/sing-shadowsocks2 v0.2.0 github.com/sagernet/sing-shadowtls v0.1.4 github.com/sagernet/sing-tun v0.3.2 - github.com/sagernet/sing-vmess v0.1.10 + github.com/sagernet/sing-vmess v0.1.11 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 github.com/sagernet/tfo-go v0.0.0-20231209031829-7b5343ac1dc6 github.com/sagernet/utls v1.5.4 diff --git a/go.sum b/go.sum index 7e74fe27..9195582e 100644 --- a/go.sum +++ b/go.sum @@ -122,8 +122,8 @@ github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnV github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4= github.com/sagernet/sing-tun v0.3.2 h1:z0bLUT/YXH9RrJS9DsIpB0Bb9afl2hVJOmHd0zA3HJY= github.com/sagernet/sing-tun v0.3.2/go.mod h1:DxLIyhjWU/HwGYoX0vNGg2c5QgTQIakphU1MuERR5tQ= -github.com/sagernet/sing-vmess v0.1.10 h1:3f/ZUEYK35CI/jdquxyyU/4FY70b+E3gORG5fG3ZBnk= -github.com/sagernet/sing-vmess v0.1.10/go.mod h1:luTSsfyBGAc9VhtCqwjR+dt1QgqBhuYBCONB/POhF8I= +github.com/sagernet/sing-vmess v0.1.11 h1:Kq20MJOBrZzxyHko+/fPHuTszYxe41ClbiNt0TTaqP4= +github.com/sagernet/sing-vmess v0.1.11/go.mod h1:luTSsfyBGAc9VhtCqwjR+dt1QgqBhuYBCONB/POhF8I= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo= github.com/sagernet/tfo-go v0.0.0-20231209031829-7b5343ac1dc6 h1:z3SJQhVyU63FT26Wn/UByW6b7q8QKB0ZkPqsyqcz2PI= diff --git a/inbound/vless.go b/inbound/vless.go index 029fdaf7..69ed042b 100644 --- a/inbound/vless.go +++ b/inbound/vless.go @@ -13,9 +13,9 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/v2ray" - "github.com/sagernet/sing-box/transport/vless" "github.com/sagernet/sing-vmess" "github.com/sagernet/sing-vmess/packetaddr" + "github.com/sagernet/sing-vmess/vless" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" E "github.com/sagernet/sing/common/exceptions" diff --git a/outbound/vless.go b/outbound/vless.go index 506521f7..b3e94661 100644 --- a/outbound/vless.go +++ b/outbound/vless.go @@ -12,8 +12,8 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/v2ray" - "github.com/sagernet/sing-box/transport/vless" "github.com/sagernet/sing-vmess/packetaddr" + "github.com/sagernet/sing-vmess/vless" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" diff --git a/transport/vless/client.go b/transport/vless/client.go deleted file mode 100644 index 09150f6d..00000000 --- a/transport/vless/client.go +++ /dev/null @@ -1,294 +0,0 @@ -package vless - -import ( - "encoding/binary" - "io" - "net" - "sync" - - "github.com/sagernet/sing-vmess" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - - "github.com/gofrs/uuid/v5" -) - -type Client struct { - key [16]byte - flow string - logger logger.Logger -} - -func NewClient(userId string, flow string, logger logger.Logger) (*Client, error) { - user := uuid.FromStringOrNil(userId) - if user == uuid.Nil { - user = uuid.NewV5(user, userId) - } - switch flow { - case "", "xtls-rprx-vision": - default: - return nil, E.New("unsupported flow: " + flow) - } - return &Client{user, flow, logger}, nil -} - -func (c *Client) prepareConn(conn net.Conn, tlsConn net.Conn) (net.Conn, error) { - if c.flow == FlowVision { - protocolConn, err := NewVisionConn(conn, tlsConn, c.key, c.logger) - if err != nil { - return nil, E.Cause(err, "initialize vision") - } - conn = protocolConn - } - return conn, nil -} - -func (c *Client) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { - remoteConn := NewConn(conn, c.key, vmess.CommandTCP, destination, c.flow) - protocolConn, err := c.prepareConn(remoteConn, conn) - if err != nil { - return nil, err - } - return protocolConn, common.Error(remoteConn.Write(nil)) -} - -func (c *Client) DialEarlyConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { - return c.prepareConn(NewConn(conn, c.key, vmess.CommandTCP, destination, c.flow), conn) -} - -func (c *Client) DialPacketConn(conn net.Conn, destination M.Socksaddr) (*PacketConn, error) { - serverConn := &PacketConn{Conn: conn, key: c.key, destination: destination, flow: c.flow} - return serverConn, common.Error(serverConn.Write(nil)) -} - -func (c *Client) DialEarlyPacketConn(conn net.Conn, destination M.Socksaddr) (*PacketConn, error) { - return &PacketConn{Conn: conn, key: c.key, destination: destination, flow: c.flow}, nil -} - -func (c *Client) DialXUDPPacketConn(conn net.Conn, destination M.Socksaddr) (vmess.PacketConn, error) { - remoteConn := NewConn(conn, c.key, vmess.CommandTCP, destination, c.flow) - protocolConn, err := c.prepareConn(remoteConn, conn) - if err != nil { - return nil, err - } - return vmess.NewXUDPConn(protocolConn, destination), common.Error(remoteConn.Write(nil)) -} - -func (c *Client) DialEarlyXUDPPacketConn(conn net.Conn, destination M.Socksaddr) (vmess.PacketConn, error) { - remoteConn := NewConn(conn, c.key, vmess.CommandMux, destination, c.flow) - protocolConn, err := c.prepareConn(remoteConn, conn) - if err != nil { - return nil, err - } - return vmess.NewXUDPConn(protocolConn, destination), common.Error(remoteConn.Write(nil)) -} - -var ( - _ N.EarlyConn = (*Conn)(nil) - _ N.VectorisedWriter = (*Conn)(nil) -) - -type Conn struct { - N.ExtendedConn - writer N.VectorisedWriter - request Request - requestWritten bool - responseRead bool -} - -func NewConn(conn net.Conn, uuid [16]byte, command byte, destination M.Socksaddr, flow string) *Conn { - return &Conn{ - ExtendedConn: bufio.NewExtendedConn(conn), - writer: bufio.NewVectorisedWriter(conn), - request: Request{ - UUID: uuid, - Command: command, - Destination: destination, - Flow: flow, - }, - } -} - -func (c *Conn) Read(b []byte) (n int, err error) { - if !c.responseRead { - err = ReadResponse(c.ExtendedConn) - if err != nil { - return - } - c.responseRead = true - } - return c.ExtendedConn.Read(b) -} - -func (c *Conn) ReadBuffer(buffer *buf.Buffer) error { - if !c.responseRead { - err := ReadResponse(c.ExtendedConn) - if err != nil { - return err - } - c.responseRead = true - } - return c.ExtendedConn.ReadBuffer(buffer) -} - -func (c *Conn) Write(b []byte) (n int, err error) { - if !c.requestWritten { - err = WriteRequest(c.ExtendedConn, c.request, b) - if err == nil { - n = len(b) - } - c.requestWritten = true - return - } - return c.ExtendedConn.Write(b) -} - -func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { - if !c.requestWritten { - err := EncodeRequest(c.request, buf.With(buffer.ExtendHeader(RequestLen(c.request)))) - if err != nil { - return err - } - c.requestWritten = true - } - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *Conn) WriteVectorised(buffers []*buf.Buffer) error { - if !c.requestWritten { - buffer := buf.NewSize(RequestLen(c.request)) - err := EncodeRequest(c.request, buffer) - if err != nil { - buffer.Release() - return err - } - c.requestWritten = true - return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...)) - } - return c.writer.WriteVectorised(buffers) -} - -func (c *Conn) ReaderReplaceable() bool { - return c.responseRead -} - -func (c *Conn) WriterReplaceable() bool { - return c.requestWritten -} - -func (c *Conn) NeedHandshake() bool { - return !c.requestWritten -} - -func (c *Conn) FrontHeadroom() int { - if c.requestWritten { - return 0 - } - return RequestLen(c.request) -} - -func (c *Conn) Upstream() any { - return c.ExtendedConn -} - -type PacketConn struct { - net.Conn - access sync.Mutex - key [16]byte - destination M.Socksaddr - flow string - requestWritten bool - responseRead bool -} - -func (c *PacketConn) Read(b []byte) (n int, err error) { - if !c.responseRead { - err = ReadResponse(c.Conn) - if err != nil { - return - } - c.responseRead = true - } - var length uint16 - err = binary.Read(c.Conn, binary.BigEndian, &length) - if err != nil { - return - } - if cap(b) < int(length) { - return 0, io.ErrShortBuffer - } - return io.ReadFull(c.Conn, b[:length]) -} - -func (c *PacketConn) Write(b []byte) (n int, err error) { - if !c.requestWritten { - c.access.Lock() - if c.requestWritten { - c.access.Unlock() - } else { - err = WritePacketRequest(c.Conn, Request{c.key, vmess.CommandUDP, c.destination, c.flow}, nil) - if err == nil { - n = len(b) - } - c.requestWritten = true - c.access.Unlock() - } - } - err = binary.Write(c.Conn, binary.BigEndian, uint16(len(b))) - if err != nil { - return - } - return c.Conn.Write(b) -} - -func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() - dataLen := buffer.Len() - binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(dataLen)) - if !c.requestWritten { - c.access.Lock() - if c.requestWritten { - c.access.Unlock() - } else { - err := WritePacketRequest(c.Conn, Request{c.key, vmess.CommandUDP, c.destination, c.flow}, buffer.Bytes()) - c.requestWritten = true - c.access.Unlock() - return err - } - } - return common.Error(c.Conn.Write(buffer.Bytes())) -} - -func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(p) - if err != nil { - return - } - if c.destination.IsFqdn() { - addr = c.destination - } else { - addr = c.destination.UDPAddr() - } - return -} - -func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return c.Write(p) -} - -func (c *PacketConn) FrontHeadroom() int { - return 2 -} - -func (c *PacketConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *PacketConn) Upstream() any { - return c.Conn -} diff --git a/transport/vless/constant.go b/transport/vless/constant.go deleted file mode 100644 index fb27af56..00000000 --- a/transport/vless/constant.go +++ /dev/null @@ -1,40 +0,0 @@ -package vless - -import ( - "bytes" - - "github.com/sagernet/sing/common/buf" -) - -var ( - tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} - tlsClientHandShakeStart = []byte{0x16, 0x03} - tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03} - tlsApplicationDataStart = []byte{0x17, 0x03, 0x03} -) - -const ( - commandPaddingContinue byte = iota - commandPaddingEnd - commandPaddingDirect -) - -var tls13CipherSuiteDic = map[uint16]string{ - 0x1301: "TLS_AES_128_GCM_SHA256", - 0x1302: "TLS_AES_256_GCM_SHA384", - 0x1303: "TLS_CHACHA20_POLY1305_SHA256", - 0x1304: "TLS_AES_128_CCM_SHA256", - 0x1305: "TLS_AES_128_CCM_8_SHA256", -} - -func reshapeBuffer(b []byte) []*buf.Buffer { - const bufferLimit = 8192 - 21 - if len(b) < bufferLimit { - return []*buf.Buffer{buf.As(b)} - } - index := int32(bytes.LastIndex(b, tlsApplicationDataStart)) - if index <= 0 { - index = 8192 / 2 - } - return []*buf.Buffer{buf.As(b[:index]), buf.As(b[index:])} -} diff --git a/transport/vless/protocol.go b/transport/vless/protocol.go deleted file mode 100644 index 5cda06e1..00000000 --- a/transport/vless/protocol.go +++ /dev/null @@ -1,297 +0,0 @@ -package vless - -import ( - "bytes" - "encoding/binary" - "io" - - "github.com/sagernet/sing-vmess" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - "github.com/sagernet/sing/common/rw" -) - -const ( - Version = 0 - FlowVision = "xtls-rprx-vision" -) - -type Request struct { - UUID [16]byte - Command byte - Destination M.Socksaddr - Flow string -} - -func ReadRequest(reader io.Reader) (*Request, error) { - var request Request - - var version uint8 - err := binary.Read(reader, binary.BigEndian, &version) - if err != nil { - return nil, err - } - if version != Version { - return nil, E.New("unknown version: ", version) - } - - _, err = io.ReadFull(reader, request.UUID[:]) - if err != nil { - return nil, err - } - - var addonsLen uint8 - err = binary.Read(reader, binary.BigEndian, &addonsLen) - if err != nil { - return nil, err - } - - if addonsLen > 0 { - addonsBytes, err := rw.ReadBytes(reader, int(addonsLen)) - if err != nil { - return nil, err - } - - addons, err := readAddons(bytes.NewReader(addonsBytes)) - if err != nil { - return nil, err - } - request.Flow = addons.Flow - } - - err = binary.Read(reader, binary.BigEndian, &request.Command) - if err != nil { - return nil, err - } - - if request.Command != vmess.CommandMux { - request.Destination, err = vmess.AddressSerializer.ReadAddrPort(reader) - if err != nil { - return nil, err - } - } - - return &request, nil -} - -type Addons struct { - Flow string - Seed string -} - -func readAddons(reader io.Reader) (*Addons, error) { - protoHeader, err := rw.ReadByte(reader) - if err != nil { - return nil, err - } - if protoHeader != 10 { - return nil, E.New("unknown protobuf message header: ", protoHeader) - } - - var addons Addons - - flowLen, err := rw.ReadUVariant(reader) - if err != nil { - if err == io.EOF { - return &addons, nil - } - return nil, err - } - flowBytes, err := rw.ReadBytes(reader, int(flowLen)) - if err != nil { - return nil, err - } - addons.Flow = string(flowBytes) - - seedLen, err := rw.ReadUVariant(reader) - if err != nil { - if err == io.EOF { - return &addons, nil - } - return nil, err - } - seedBytes, err := rw.ReadBytes(reader, int(seedLen)) - if err != nil { - return nil, err - } - addons.Seed = string(seedBytes) - - return &addons, nil -} - -func WriteRequest(writer io.Writer, request Request, payload []byte) error { - var requestLen int - requestLen += 1 // version - requestLen += 16 // uuid - requestLen += 1 // protobuf length - - var addonsLen int - if request.Flow != "" { - addonsLen += 1 // protobuf header - addonsLen += rw.UVariantLen(uint64(len(request.Flow))) - addonsLen += len(request.Flow) - requestLen += addonsLen - } - requestLen += 1 // command - if request.Command != vmess.CommandMux { - requestLen += vmess.AddressSerializer.AddrPortLen(request.Destination) - } - requestLen += len(payload) - buffer := buf.NewSize(requestLen) - defer buffer.Release() - common.Must( - buffer.WriteByte(Version), - common.Error(buffer.Write(request.UUID[:])), - buffer.WriteByte(byte(addonsLen)), - ) - if addonsLen > 0 { - common.Must(buffer.WriteByte(10)) - binary.PutUvarint(buffer.Extend(rw.UVariantLen(uint64(len(request.Flow)))), uint64(len(request.Flow))) - common.Must(common.Error(buffer.WriteString(request.Flow))) - } - common.Must( - buffer.WriteByte(request.Command), - ) - - if request.Command != vmess.CommandMux { - err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination) - if err != nil { - return err - } - } - - common.Must1(buffer.Write(payload)) - return common.Error(writer.Write(buffer.Bytes())) -} - -func EncodeRequest(request Request, buffer *buf.Buffer) error { - var requestLen int - requestLen += 1 // version - requestLen += 16 // uuid - requestLen += 1 // protobuf length - - var addonsLen int - if request.Flow != "" { - addonsLen += 1 // protobuf header - addonsLen += rw.UVariantLen(uint64(len(request.Flow))) - addonsLen += len(request.Flow) - requestLen += addonsLen - } - requestLen += 1 // command - if request.Command != vmess.CommandMux { - requestLen += vmess.AddressSerializer.AddrPortLen(request.Destination) - } - common.Must( - buffer.WriteByte(Version), - common.Error(buffer.Write(request.UUID[:])), - buffer.WriteByte(byte(addonsLen)), - ) - if addonsLen > 0 { - common.Must(buffer.WriteByte(10)) - binary.PutUvarint(buffer.Extend(rw.UVariantLen(uint64(len(request.Flow)))), uint64(len(request.Flow))) - common.Must(common.Error(buffer.WriteString(request.Flow))) - } - common.Must( - buffer.WriteByte(request.Command), - ) - - if request.Command != vmess.CommandMux { - err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination) - if err != nil { - return err - } - } - return nil -} - -func RequestLen(request Request) int { - var requestLen int - requestLen += 1 // version - requestLen += 16 // uuid - requestLen += 1 // protobuf length - - var addonsLen int - if request.Flow != "" { - addonsLen += 1 // protobuf header - addonsLen += rw.UVariantLen(uint64(len(request.Flow))) - addonsLen += len(request.Flow) - requestLen += addonsLen - } - requestLen += 1 // command - if request.Command != vmess.CommandMux { - requestLen += vmess.AddressSerializer.AddrPortLen(request.Destination) - } - return requestLen -} - -func WritePacketRequest(writer io.Writer, request Request, payload []byte) error { - var requestLen int - requestLen += 1 // version - requestLen += 16 // uuid - requestLen += 1 // protobuf length - var addonsLen int - /*if request.Flow != "" { - addonsLen += 1 // protobuf header - addonsLen += rw.UVariantLen(uint64(len(request.Flow))) - addonsLen += len(request.Flow) - requestLen += addonsLen - }*/ - requestLen += 1 // command - requestLen += vmess.AddressSerializer.AddrPortLen(request.Destination) - if len(payload) > 0 { - requestLen += 2 - requestLen += len(payload) - } - buffer := buf.NewSize(requestLen) - defer buffer.Release() - common.Must( - buffer.WriteByte(Version), - common.Error(buffer.Write(request.UUID[:])), - buffer.WriteByte(byte(addonsLen)), - ) - - if addonsLen > 0 { - common.Must(buffer.WriteByte(10)) - binary.PutUvarint(buffer.Extend(rw.UVariantLen(uint64(len(request.Flow)))), uint64(len(request.Flow))) - common.Must(common.Error(buffer.WriteString(request.Flow))) - } - - common.Must(buffer.WriteByte(vmess.CommandUDP)) - - err := vmess.AddressSerializer.WriteAddrPort(buffer, request.Destination) - if err != nil { - return err - } - - if len(payload) > 0 { - common.Must( - binary.Write(buffer, binary.BigEndian, uint16(len(payload))), - common.Error(buffer.Write(payload)), - ) - } - - return common.Error(writer.Write(buffer.Bytes())) -} - -func ReadResponse(reader io.Reader) error { - version, err := rw.ReadByte(reader) - if err != nil { - return err - } - if version != Version { - return E.New("unknown version: ", version) - } - protobufLength, err := rw.ReadByte(reader) - if err != nil { - return err - } - if protobufLength > 0 { - err = rw.SkipN(reader, int(protobufLength)) - if err != nil { - return err - } - } - return nil -} diff --git a/transport/vless/service.go b/transport/vless/service.go deleted file mode 100644 index 7b690202..00000000 --- a/transport/vless/service.go +++ /dev/null @@ -1,260 +0,0 @@ -package vless - -import ( - "context" - "encoding/binary" - "io" - "net" - - "github.com/sagernet/sing-vmess" - "github.com/sagernet/sing/common/auth" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - - "github.com/gofrs/uuid/v5" -) - -type Service[T comparable] struct { - userMap map[[16]byte]T - userFlow map[T]string - logger logger.Logger - handler Handler -} - -type Handler interface { - N.TCPConnectionHandler - N.UDPConnectionHandler - E.Handler -} - -func NewService[T comparable](logger logger.Logger, handler Handler) *Service[T] { - return &Service[T]{ - logger: logger, - handler: handler, - } -} - -func (s *Service[T]) UpdateUsers(userList []T, userUUIDList []string, userFlowList []string) { - userMap := make(map[[16]byte]T) - userFlowMap := make(map[T]string) - for i, userName := range userList { - userID := uuid.FromStringOrNil(userUUIDList[i]) - if userID == uuid.Nil { - userID = uuid.NewV5(uuid.Nil, userUUIDList[i]) - } - userMap[userID] = userName - userFlowMap[userName] = userFlowList[i] - } - s.userMap = userMap - s.userFlow = userFlowMap -} - -var _ N.TCPConnectionHandler = (*Service[int])(nil) - -func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - request, err := ReadRequest(conn) - if err != nil { - return err - } - user, loaded := s.userMap[request.UUID] - if !loaded { - return E.New("unknown UUID: ", uuid.FromBytesOrNil(request.UUID[:])) - } - ctx = auth.ContextWithUser(ctx, user) - metadata.Destination = request.Destination - - userFlow := s.userFlow[user] - if request.Flow == FlowVision && request.Command == vmess.NetworkUDP { - return E.New(FlowVision, " flow does not support UDP") - } else if request.Flow != userFlow { - return E.New("flow mismatch: expected ", flowName(userFlow), ", but got ", flowName(request.Flow)) - } - - if request.Command == vmess.CommandUDP { - return s.handler.NewPacketConnection(ctx, &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(conn), destination: request.Destination}, metadata) - } - responseConn := &serverConn{ExtendedConn: bufio.NewExtendedConn(conn), writer: bufio.NewVectorisedWriter(conn)} - switch userFlow { - case FlowVision: - conn, err = NewVisionConn(responseConn, conn, request.UUID, s.logger) - if err != nil { - return E.Cause(err, "initialize vision") - } - case "": - conn = responseConn - default: - return E.New("unknown flow: ", userFlow) - } - switch request.Command { - case vmess.CommandTCP: - return s.handler.NewConnection(ctx, conn, metadata) - case vmess.CommandMux: - return vmess.HandleMuxConnection(ctx, conn, s.handler) - default: - return E.New("unknown command: ", request.Command) - } -} - -func flowName(value string) string { - if value == "" { - return "none" - } - return value -} - -var _ N.VectorisedWriter = (*serverConn)(nil) - -type serverConn struct { - N.ExtendedConn - writer N.VectorisedWriter - responseWritten bool -} - -func (c *serverConn) Read(b []byte) (n int, err error) { - return c.ExtendedConn.Read(b) -} - -func (c *serverConn) Write(b []byte) (n int, err error) { - if !c.responseWritten { - _, err = bufio.WriteVectorised(c.writer, [][]byte{{Version, 0}, b}) - if err == nil { - n = len(b) - } - c.responseWritten = true - return - } - return c.ExtendedConn.Write(b) -} - -func (c *serverConn) WriteBuffer(buffer *buf.Buffer) error { - if !c.responseWritten { - header := buffer.ExtendHeader(2) - header[0] = Version - header[1] = 0 - c.responseWritten = true - } - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error { - if !c.responseWritten { - err := c.writer.WriteVectorised(append([]*buf.Buffer{buf.As([]byte{Version, 0})}, buffers...)) - c.responseWritten = true - return err - } - return c.writer.WriteVectorised(buffers) -} - -func (c *serverConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *serverConn) FrontHeadroom() int { - if c.responseWritten { - return 0 - } - return 2 -} - -func (c *serverConn) ReaderReplaceable() bool { - return true -} - -func (c *serverConn) WriterReplaceable() bool { - return c.responseWritten -} - -func (c *serverConn) Upstream() any { - return c.ExtendedConn -} - -type serverPacketConn struct { - N.ExtendedConn - responseWriter io.Writer - responseWritten bool - destination M.Socksaddr -} - -func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.ExtendedConn.Read(p) - if err != nil { - return - } - if c.destination.IsFqdn() { - addr = c.destination - } else { - addr = c.destination.UDPAddr() - } - return -} - -func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if !c.responseWritten { - if c.responseWriter == nil { - var packetLen [2]byte - binary.BigEndian.PutUint16(packetLen[:], uint16(len(p))) - _, err = bufio.WriteVectorised(bufio.NewVectorisedWriter(c.ExtendedConn), [][]byte{{Version, 0}, packetLen[:], p}) - if err == nil { - n = len(p) - } - c.responseWritten = true - return - } else { - _, err = c.responseWriter.Write([]byte{Version, 0}) - if err != nil { - return - } - c.responseWritten = true - } - } - return c.ExtendedConn.Write(p) -} - -func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - var packetLen uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &packetLen) - if err != nil { - return - } - - _, err = buffer.ReadFullFrom(c.ExtendedConn, int(packetLen)) - if err != nil { - return - } - - destination = c.destination - return -} - -func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - if !c.responseWritten { - if c.responseWriter == nil { - var packetLen [2]byte - binary.BigEndian.PutUint16(packetLen[:], uint16(buffer.Len())) - err := bufio.NewVectorisedWriter(c.ExtendedConn).WriteVectorised([]*buf.Buffer{buf.As([]byte{Version, 0}), buf.As(packetLen[:]), buffer}) - c.responseWritten = true - return err - } else { - _, err := c.responseWriter.Write([]byte{Version, 0}) - if err != nil { - return err - } - c.responseWritten = true - } - } - packetLen := buffer.Len() - binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(packetLen)) - return c.ExtendedConn.WriteBuffer(buffer) -} - -func (c *serverPacketConn) FrontHeadroom() int { - return 2 -} - -func (c *serverPacketConn) Upstream() any { - return c.ExtendedConn -} diff --git a/transport/vless/vision.go b/transport/vless/vision.go deleted file mode 100644 index 5ee2d0df..00000000 --- a/transport/vless/vision.go +++ /dev/null @@ -1,380 +0,0 @@ -package vless - -import ( - "bytes" - "crypto/rand" - "crypto/tls" - "io" - "math/big" - "net" - "reflect" - "time" - "unsafe" - - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - N "github.com/sagernet/sing/common/network" -) - -var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) - -func init() { - tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) { - tlsConn, loaded := common.Cast[*tls.Conn](conn) - if !loaded { - return - } - return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn)) - }) -} - -const xrayChunkSize = 8192 - -type VisionConn struct { - net.Conn - reader *bufio.ChunkReader - writer N.VectorisedWriter - input *bytes.Reader - rawInput *bytes.Buffer - netConn net.Conn - logger logger.Logger - - userUUID [16]byte - isTLS bool - numberOfPacketToFilter int - isTLS12orAbove bool - remainingServerHello int32 - cipher uint16 - enableXTLS bool - isPadding bool - directWrite bool - writeUUID bool - withinPaddingBuffers bool - remainingContent int - remainingPadding int - currentCommand byte - directRead bool - remainingReader io.Reader -} - -func NewVisionConn(conn net.Conn, tlsConn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) { - var ( - loaded bool - reflectType reflect.Type - reflectPointer uintptr - netConn net.Conn - ) - for _, tlsCreator := range tlsRegistry { - loaded, netConn, reflectType, reflectPointer = tlsCreator(tlsConn) - if loaded { - break - } - } - if !loaded { - return nil, C.ErrTLSRequired - } - input, _ := reflectType.FieldByName("input") - rawInput, _ := reflectType.FieldByName("rawInput") - return &VisionConn{ - Conn: conn, - reader: bufio.NewChunkReader(conn, xrayChunkSize), - writer: bufio.NewVectorisedWriter(conn), - input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)), - rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)), - netConn: netConn, - logger: logger, - - userUUID: userUUID, - numberOfPacketToFilter: 8, - remainingServerHello: -1, - isPadding: true, - writeUUID: true, - withinPaddingBuffers: true, - remainingContent: -1, - remainingPadding: -1, - }, nil -} - -func (c *VisionConn) Read(p []byte) (n int, err error) { - if c.remainingReader != nil { - n, err = c.remainingReader.Read(p) - if err == io.EOF { - err = nil - c.remainingReader = nil - } - if n > 0 { - return - } - } - if c.directRead { - return c.netConn.Read(p) - } - var bufferBytes []byte - var chunkBuffer *buf.Buffer - if len(p) > xrayChunkSize { - n, err = c.Conn.Read(p) - if err != nil { - return - } - bufferBytes = p[:n] - } else { - chunkBuffer, err = c.reader.ReadChunk() - if err != nil { - return 0, err - } - bufferBytes = chunkBuffer.Bytes() - } - if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 { - buffers := c.unPadding(bufferBytes) - if chunkBuffer != nil { - buffers = common.Map(buffers, func(it *buf.Buffer) *buf.Buffer { - return it.ToOwned() - }) - chunkBuffer.Reset() - } - if c.remainingContent == 0 && c.remainingPadding == 0 { - if c.currentCommand == commandPaddingEnd { - c.withinPaddingBuffers = false - c.remainingContent = -1 - c.remainingPadding = -1 - } else if c.currentCommand == commandPaddingDirect { - c.withinPaddingBuffers = false - c.directRead = true - - inputBuffer, err := io.ReadAll(c.input) - if err != nil { - return 0, err - } - buffers = append(buffers, buf.As(inputBuffer)) - - rawInputBuffer, err := io.ReadAll(c.rawInput) - if err != nil { - return 0, err - } - - buffers = append(buffers, buf.As(rawInputBuffer)) - - c.logger.Trace("XtlsRead readV") - } else if c.currentCommand == commandPaddingContinue { - c.withinPaddingBuffers = true - } else { - return 0, E.New("unknown command ", c.currentCommand) - } - } else if c.remainingContent > 0 || c.remainingPadding > 0 { - c.withinPaddingBuffers = true - } else { - c.withinPaddingBuffers = false - } - if c.numberOfPacketToFilter > 0 { - c.filterTLS(buf.ToSliceMulti(buffers)) - } - c.remainingReader = io.MultiReader(common.Map(buffers, func(it *buf.Buffer) io.Reader { return it })...) - return c.Read(p) - } else { - if c.numberOfPacketToFilter > 0 { - c.filterTLS([][]byte{bufferBytes}) - } - if chunkBuffer != nil { - n = copy(p, bufferBytes) - chunkBuffer.Advance(n) - } - return - } -} - -func (c *VisionConn) Write(p []byte) (n int, err error) { - if c.numberOfPacketToFilter > 0 { - c.filterTLS([][]byte{p}) - } - if c.isPadding { - inputLen := len(p) - buffers := reshapeBuffer(p) - var specIndex int - for i, buffer := range buffers { - if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) { - var command byte = commandPaddingEnd - if c.enableXTLS { - c.directWrite = true - specIndex = i - command = commandPaddingDirect - } - c.isPadding = false - buffers[i] = c.padding(buffer, command) - break - } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 { - c.isPadding = false - buffers[i] = c.padding(buffer, commandPaddingEnd) - break - } - buffers[i] = c.padding(buffer, commandPaddingContinue) - } - if c.directWrite { - encryptedBuffer := buffers[:specIndex+1] - err = c.writer.WriteVectorised(encryptedBuffer) - if err != nil { - return - } - buffers = buffers[specIndex+1:] - c.writer = bufio.NewVectorisedWriter(c.netConn) - c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers)) - time.Sleep(5 * time.Millisecond) // wtf - } - err = c.writer.WriteVectorised(buffers) - if err == nil { - n = inputLen - } - return - } - if c.directWrite { - return c.netConn.Write(p) - } else { - return c.Conn.Write(p) - } -} - -func (c *VisionConn) filterTLS(buffers [][]byte) { - for _, buffer := range buffers { - c.numberOfPacketToFilter-- - if len(buffer) > 6 { - if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 { - c.isTLS = true - if buffer[5] == 2 { - c.isTLS12orAbove = true - c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5 - if len(buffer) >= 79 && c.remainingServerHello >= 79 { - sessionIdLen := int32(buffer[43]) - cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3] - c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) - } else { - c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello) - } - } - } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 { - c.isTLS = true - c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer)) - } - } - if c.remainingServerHello > 0 { - end := int(c.remainingServerHello) - if end > len(buffer) { - end = len(buffer) - } - c.remainingServerHello -= int32(end) - if bytes.Contains(buffer[:end], tls13SupportedVersions) { - cipher, ok := tls13CipherSuiteDic[c.cipher] - if ok && cipher != "TLS_AES_128_CCM_8_SHA256" { - c.enableXTLS = true - } - c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS) - c.numberOfPacketToFilter = 0 - return - } else if c.remainingServerHello == 0 { - c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer)) - c.numberOfPacketToFilter = 0 - return - } - } - if c.numberOfPacketToFilter == 0 { - c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer)) - } - } -} - -func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer { - contentLen := 0 - paddingLen := 0 - if buffer != nil { - contentLen = buffer.Len() - } - if contentLen < 900 && c.isTLS { - l, _ := rand.Int(rand.Reader, big.NewInt(500)) - paddingLen = int(l.Int64()) + 900 - contentLen - } else { - l, _ := rand.Int(rand.Reader, big.NewInt(256)) - paddingLen = int(l.Int64()) - } - var bufferLen int - if c.writeUUID { - bufferLen += 16 - } - bufferLen += 5 - if buffer != nil { - bufferLen += buffer.Len() - } - bufferLen += paddingLen - newBuffer := buf.NewSize(bufferLen) - if c.writeUUID { - common.Must1(newBuffer.Write(c.userUUID[:])) - c.writeUUID = false - } - common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)})) - if buffer != nil { - common.Must1(newBuffer.Write(buffer.Bytes())) - buffer.Release() - } - newBuffer.Extend(paddingLen) - c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command) - return newBuffer -} - -func (c *VisionConn) unPadding(buffer []byte) []*buf.Buffer { - var bufferIndex int - if c.remainingContent == -1 && c.remainingPadding == -1 { - if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) { - bufferIndex = 16 - c.remainingContent = 0 - c.remainingPadding = 0 - c.currentCommand = 0 - } - } - if c.remainingContent == -1 && c.remainingPadding == -1 { - return []*buf.Buffer{buf.As(buffer)} - } - var buffers []*buf.Buffer - for bufferIndex < len(buffer) { - if c.remainingContent <= 0 && c.remainingPadding <= 0 { - if c.currentCommand == 1 { - buffers = append(buffers, buf.As(buffer[bufferIndex:])) - break - } else { - paddingInfo := buffer[bufferIndex : bufferIndex+5] - c.currentCommand = paddingInfo[0] - c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2]) - c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4]) - bufferIndex += 5 - c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand) - } - } else if c.remainingContent > 0 { - end := c.remainingContent - if end > len(buffer)-bufferIndex { - end = len(buffer) - bufferIndex - } - buffers = append(buffers, buf.As(buffer[bufferIndex:bufferIndex+end])) - c.remainingContent -= end - bufferIndex += end - } else { - end := c.remainingPadding - if end > len(buffer)-bufferIndex { - end = len(buffer) - bufferIndex - } - c.remainingPadding -= end - bufferIndex += end - } - if bufferIndex == len(buffer) { - break - } - } - return buffers -} - -func (c *VisionConn) NeedAdditionalReadDeadline() bool { - return true -} - -func (c *VisionConn) Upstream() any { - return c.Conn -} diff --git a/transport/vless/vision_reality.go b/transport/vless/vision_reality.go deleted file mode 100644 index fa04c422..00000000 --- a/transport/vless/vision_reality.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build with_reality_server - -package vless - -import ( - "net" - "reflect" - "unsafe" - - "github.com/sagernet/reality" - "github.com/sagernet/sing/common" -) - -func init() { - tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) { - tlsConn, loaded := common.Cast[*reality.Conn](conn) - if !loaded { - return - } - return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn)) - }) -} diff --git a/transport/vless/vision_utls.go b/transport/vless/vision_utls.go deleted file mode 100644 index e2469c88..00000000 --- a/transport/vless/vision_utls.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build with_utls - -package vless - -import ( - "net" - "reflect" - "unsafe" - - "github.com/sagernet/sing/common" - utls "github.com/sagernet/utls" -) - -func init() { - tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) { - tlsConn, loaded := common.Cast[*utls.UConn](conn) - if !loaded { - return - } - return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn.Conn).Elem(), uintptr(unsafe.Pointer(tlsConn.Conn)) - }) -}