diff --git a/common/badtls/badtls.go b/common/badtls/badtls.go index 3964defa..c5c55e3c 100644 --- a/common/badtls/badtls.go +++ b/common/badtls/badtls.go @@ -1,4 +1,4 @@ -//go:build go1.19 && !go1.20 +//go:build go1.20 && !go1.21 package badtls @@ -14,39 +14,60 @@ import ( "sync/atomic" "unsafe" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" ) type Conn struct { *tls.Conn - writer N.ExtendedWriter - activeCall *int32 - closeNotifySent *bool - version *uint16 - rand io.Reader - halfAccess *sync.Mutex - halfError *error - cipher cipher.AEAD - explicitNonceLen int - halfPtr uintptr - halfSeq []byte - halfScratchBuf []byte + writer N.ExtendedWriter + isHandshakeComplete *atomic.Bool + activeCall *atomic.Int32 + closeNotifySent *bool + version *uint16 + rand io.Reader + halfAccess *sync.Mutex + halfError *error + cipher cipher.AEAD + explicitNonceLen int + halfPtr uintptr + halfSeq []byte + halfScratchBuf []byte } -func Create(conn *tls.Conn) (TLSConn, error) { - if !handshakeComplete(conn) { +func TryCreate(conn aTLS.Conn) aTLS.Conn { + tlsConn, ok := conn.(*tls.Conn) + if !ok { + return conn + } + badConn, err := Create(tlsConn) + if err != nil { + log.Warn("initialize badtls: ", err) + return conn + } + return badConn +} + +func Create(conn *tls.Conn) (aTLS.Conn, error) { + rawConn := reflect.Indirect(reflect.ValueOf(conn)) + rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete") + if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct { + return nil, E.New("badtls: invalid isHandshakeComplete") + } + isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr())) + if !isHandshakeComplete.Load() { return nil, E.New("handshake not finished") } - rawConn := reflect.Indirect(reflect.ValueOf(conn)) rawActiveCall := rawConn.FieldByName("activeCall") - if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 { + if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct { return nil, E.New("badtls: invalid active call") } - activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr())) + activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr())) rawHalfConn := rawConn.FieldByName("out") if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { return nil, E.New("badtls: invalid half conn") @@ -108,19 +129,20 @@ func Create(conn *tls.Conn) (TLSConn, error) { } halfScratchBuf := rawHalfScratchBuf.Bytes() return &Conn{ - Conn: conn, - writer: bufio.NewExtendedWriter(conn.NetConn()), - activeCall: activeCall, - closeNotifySent: closeNotifySent, - version: version, - halfAccess: halfAccess, - halfError: halfError, - cipher: aeadCipher, - explicitNonceLen: explicitNonceLen, - rand: randReader, - halfPtr: rawHalfConn.UnsafeAddr(), - halfSeq: halfSeq, - halfScratchBuf: halfScratchBuf, + Conn: conn, + writer: bufio.NewExtendedWriter(conn.NetConn()), + isHandshakeComplete: isHandshakeComplete, + activeCall: activeCall, + closeNotifySent: closeNotifySent, + version: version, + halfAccess: halfAccess, + halfError: halfError, + cipher: aeadCipher, + explicitNonceLen: explicitNonceLen, + rand: randReader, + halfPtr: rawHalfConn.UnsafeAddr(), + halfSeq: halfSeq, + halfScratchBuf: halfScratchBuf, }, nil } @@ -130,15 +152,15 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { return common.Error(c.Write(buffer.Bytes())) } for { - x := atomic.LoadInt32(c.activeCall) + x := c.activeCall.Load() if x&1 != 0 { return net.ErrClosed } - if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) { + if c.activeCall.CompareAndSwap(x, x+2) { break } } - defer atomic.AddInt32(c.activeCall, -2) + defer c.activeCall.Add(-2) c.halfAccess.Lock() defer c.halfAccess.Unlock() if err := *c.halfError; err != nil { @@ -186,6 +208,7 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead())) } incSeq(c.halfPtr) + log.Trace("badtls write ", buffer.Len()) return c.writer.WriteBuffer(buffer) } diff --git a/common/badtls/badtls_stub.go b/common/badtls/badtls_stub.go index c44d8792..7810bb1c 100644 --- a/common/badtls/badtls_stub.go +++ b/common/badtls/badtls_stub.go @@ -1,4 +1,4 @@ -//go:build !go1.19 || go1.20 +//go:build !go1.19 || go1.21 package badtls diff --git a/common/badtls/conn.go b/common/badtls/conn.go deleted file mode 100644 index 235763dc..00000000 --- a/common/badtls/conn.go +++ /dev/null @@ -1,13 +0,0 @@ -package badtls - -import ( - "context" - "crypto/tls" - "net" -) - -type TLSConn interface { - net.Conn - HandshakeContext(ctx context.Context) error - ConnectionState() tls.ConnectionState -} diff --git a/common/badtls/link.go b/common/badtls/link.go index c86c7b49..b8d5f4bd 100644 --- a/common/badtls/link.go +++ b/common/badtls/link.go @@ -1,9 +1,8 @@ -//go:build go1.19 && !go.1.20 +//go:build go1.20 && !go.1.21 package badtls import ( - "crypto/tls" "reflect" _ "unsafe" ) @@ -16,9 +15,6 @@ const ( //go:linkname errShutdown crypto/tls.errShutdown var errShutdown error -//go:linkname handshakeComplete crypto/tls.(*Conn).handshakeComplete -func handshakeComplete(conn *tls.Conn) bool - //go:linkname incSeq crypto/tls.(*halfConn).incSeq func incSeq(conn uintptr)