//go:build go1.19 && !go1.20 package badtls import ( "crypto/cipher" "crypto/rand" "crypto/tls" "encoding/binary" "io" "net" "reflect" "sync" "sync/atomic" "unsafe" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" ) type Conn struct { *tls.Conn writer N.ExtendedWriter activeCall *int32 closeNotifySent *bool version *uint16 rand io.Reader halfAccess *sync.Mutex halfError *error cipher cipher.AEAD explicitNonceLen int halfPtr uintptr halfSeq []byte halfScratchBuf []byte } func Create(conn *tls.Conn) (TLSConn, error) { if !handshakeComplete(conn) { return nil, E.New("handshake not finished") } rawConn := reflect.Indirect(reflect.ValueOf(conn)) rawActiveCall := rawConn.FieldByName("activeCall") if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 { return nil, E.New("badtls: invalid active call") } activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr())) rawHalfConn := rawConn.FieldByName("out") if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { return nil, E.New("badtls: invalid half conn") } rawVersion := rawConn.FieldByName("vers") if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 { return nil, E.New("badtls: invalid version") } version := (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr())) rawCloseNotifySent := rawConn.FieldByName("closeNotifySent") if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool { return nil, E.New("badtls: invalid notify") } closeNotifySent := (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr())) rawConfig := reflect.Indirect(rawConn.FieldByName("config")) if !rawConfig.IsValid() || rawConfig.Kind() != reflect.Struct { return nil, E.New("badtls: bad config") } config := (*tls.Config)(unsafe.Pointer(rawConfig.UnsafeAddr())) randReader := config.Rand if randReader == nil { randReader = rand.Reader } rawHalfMutex := rawHalfConn.FieldByName("Mutex") if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct { return nil, E.New("badtls: invalid half mutex") } halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr())) rawHalfError := rawHalfConn.FieldByName("err") if !rawHalfError.IsValid() || rawHalfError.Kind() != reflect.Interface { return nil, E.New("badtls: invalid half error") } halfError := (*error)(unsafe.Pointer(rawHalfError.UnsafeAddr())) rawHalfCipherInterface := rawHalfConn.FieldByName("cipher") if !rawHalfCipherInterface.IsValid() || rawHalfCipherInterface.Kind() != reflect.Interface { return nil, E.New("badtls: invalid cipher interface") } rawHalfCipher := rawHalfCipherInterface.Elem() aeadCipher, loaded := valueInterface(rawHalfCipher, false).(cipher.AEAD) if !loaded { return nil, E.New("badtls: invalid AEAD cipher") } var explicitNonceLen int switch cipherName := reflect.Indirect(rawHalfCipher).Type().String(); cipherName { case "tls.prefixNonceAEAD": explicitNonceLen = aeadCipher.NonceSize() case "tls.xorNonceAEAD": default: return nil, E.New("badtls: unknown cipher type: ", cipherName) } rawHalfSeq := rawHalfConn.FieldByName("seq") if !rawHalfSeq.IsValid() || rawHalfSeq.Kind() != reflect.Array { return nil, E.New("badtls: invalid seq") } halfSeq := rawHalfSeq.Bytes() rawHalfScratchBuf := rawHalfConn.FieldByName("scratchBuf") if !rawHalfScratchBuf.IsValid() || rawHalfScratchBuf.Kind() != reflect.Array { return nil, E.New("badtls: invalid scratchBuf") } halfScratchBuf := rawHalfScratchBuf.Bytes() return &Conn{ Conn: conn, writer: bufio.NewExtendedWriter(conn.NetConn()), activeCall: activeCall, closeNotifySent: closeNotifySent, version: version, halfAccess: halfAccess, halfError: halfError, cipher: aeadCipher, explicitNonceLen: explicitNonceLen, rand: randReader, halfPtr: rawHalfConn.UnsafeAddr(), halfSeq: halfSeq, halfScratchBuf: halfScratchBuf, }, nil } func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { if buffer.Len() > maxPlaintext { defer buffer.Release() return common.Error(c.Write(buffer.Bytes())) } for { x := atomic.LoadInt32(c.activeCall) if x&1 != 0 { return net.ErrClosed } if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) { break } } defer atomic.AddInt32(c.activeCall, -2) c.halfAccess.Lock() defer c.halfAccess.Unlock() if err := *c.halfError; err != nil { return err } if *c.closeNotifySent { return errShutdown } dataLen := buffer.Len() dataBytes := buffer.Bytes() outBuf := buffer.ExtendHeader(recordHeaderLen + c.explicitNonceLen) outBuf[0] = 23 version := *c.version if version == 0 { version = tls.VersionTLS10 } else if version == tls.VersionTLS13 { version = tls.VersionTLS12 } binary.BigEndian.PutUint16(outBuf[1:], version) var nonce []byte if c.explicitNonceLen > 0 { nonce = outBuf[5 : 5+c.explicitNonceLen] if c.explicitNonceLen < 16 { copy(nonce, c.halfSeq) } else { if _, err := io.ReadFull(c.rand, nonce); err != nil { return err } } } if len(nonce) == 0 { nonce = c.halfSeq } if *c.version == tls.VersionTLS13 { buffer.FreeBytes()[0] = 23 binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+1+c.cipher.Overhead())) c.cipher.Seal(outBuf, nonce, outBuf[recordHeaderLen:recordHeaderLen+c.explicitNonceLen+dataLen+1], outBuf[:recordHeaderLen]) buffer.Extend(1 + c.cipher.Overhead()) } else { binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen)) additionalData := append(c.halfScratchBuf[:0], c.halfSeq...) additionalData = append(additionalData, outBuf[:recordHeaderLen]...) c.cipher.Seal(outBuf, nonce, dataBytes, additionalData) buffer.Extend(c.cipher.Overhead()) binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead())) } incSeq(c.halfPtr) return c.writer.WriteBuffer(buffer) } func (c *Conn) FrontHeadroom() int { return recordHeaderLen + c.explicitNonceLen } func (c *Conn) RearHeadroom() int { return 1 + c.cipher.Overhead() } func (c *Conn) WriterMTU() int { return maxPlaintext } func (c *Conn) Upstream() any { return c.NetConn() }