Update badtls

This commit is contained in:
世界 2023-04-18 19:56:01 +08:00
parent e8dad1afeb
commit 5e1499d67b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 59 additions and 53 deletions

View file

@ -1,4 +1,4 @@
//go:build go1.19 && !go1.20 //go:build go1.20 && !go1.21
package badtls package badtls
@ -14,17 +14,20 @@ import (
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
) )
type Conn struct { type Conn struct {
*tls.Conn *tls.Conn
writer N.ExtendedWriter writer N.ExtendedWriter
activeCall *int32 isHandshakeComplete *atomic.Bool
activeCall *atomic.Int32
closeNotifySent *bool closeNotifySent *bool
version *uint16 version *uint16
rand io.Reader rand io.Reader
@ -37,16 +40,34 @@ type Conn struct {
halfScratchBuf []byte halfScratchBuf []byte
} }
func Create(conn *tls.Conn) (TLSConn, error) { func TryCreate(conn aTLS.Conn) aTLS.Conn {
if !handshakeComplete(conn) { tlsConn, ok := conn.(*tls.Conn)
if !ok {
return conn
}
badConn, err := Create(tlsConn)
if err != nil {
log.Warn("initialize badtls: ", err)
return conn
}
return badConn
}
func Create(conn *tls.Conn) (aTLS.Conn, error) {
rawConn := reflect.Indirect(reflect.ValueOf(conn))
rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid isHandshakeComplete")
}
isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
if !isHandshakeComplete.Load() {
return nil, E.New("handshake not finished") return nil, E.New("handshake not finished")
} }
rawConn := reflect.Indirect(reflect.ValueOf(conn))
rawActiveCall := rawConn.FieldByName("activeCall") rawActiveCall := rawConn.FieldByName("activeCall")
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 { if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid active call") return nil, E.New("badtls: invalid active call")
} }
activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr())) activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
rawHalfConn := rawConn.FieldByName("out") rawHalfConn := rawConn.FieldByName("out")
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half conn") return nil, E.New("badtls: invalid half conn")
@ -110,6 +131,7 @@ func Create(conn *tls.Conn) (TLSConn, error) {
return &Conn{ return &Conn{
Conn: conn, Conn: conn,
writer: bufio.NewExtendedWriter(conn.NetConn()), writer: bufio.NewExtendedWriter(conn.NetConn()),
isHandshakeComplete: isHandshakeComplete,
activeCall: activeCall, activeCall: activeCall,
closeNotifySent: closeNotifySent, closeNotifySent: closeNotifySent,
version: version, version: version,
@ -130,15 +152,15 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
return common.Error(c.Write(buffer.Bytes())) return common.Error(c.Write(buffer.Bytes()))
} }
for { for {
x := atomic.LoadInt32(c.activeCall) x := c.activeCall.Load()
if x&1 != 0 { if x&1 != 0 {
return net.ErrClosed return net.ErrClosed
} }
if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) { if c.activeCall.CompareAndSwap(x, x+2) {
break break
} }
} }
defer atomic.AddInt32(c.activeCall, -2) defer c.activeCall.Add(-2)
c.halfAccess.Lock() c.halfAccess.Lock()
defer c.halfAccess.Unlock() defer c.halfAccess.Unlock()
if err := *c.halfError; err != nil { if err := *c.halfError; err != nil {
@ -186,6 +208,7 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead())) binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
} }
incSeq(c.halfPtr) incSeq(c.halfPtr)
log.Trace("badtls write ", buffer.Len())
return c.writer.WriteBuffer(buffer) return c.writer.WriteBuffer(buffer)
} }

View file

@ -1,4 +1,4 @@
//go:build !go1.19 || go1.20 //go:build !go1.19 || go1.21
package badtls package badtls

View file

@ -1,13 +0,0 @@
package badtls
import (
"context"
"crypto/tls"
"net"
)
type TLSConn interface {
net.Conn
HandshakeContext(ctx context.Context) error
ConnectionState() tls.ConnectionState
}

View file

@ -1,9 +1,8 @@
//go:build go1.19 && !go.1.20 //go:build go1.20 && !go.1.21
package badtls package badtls
import ( import (
"crypto/tls"
"reflect" "reflect"
_ "unsafe" _ "unsafe"
) )
@ -16,9 +15,6 @@ const (
//go:linkname errShutdown crypto/tls.errShutdown //go:linkname errShutdown crypto/tls.errShutdown
var errShutdown error var errShutdown error
//go:linkname handshakeComplete crypto/tls.(*Conn).handshakeComplete
func handshakeComplete(conn *tls.Conn) bool
//go:linkname incSeq crypto/tls.(*halfConn).incSeq //go:linkname incSeq crypto/tls.(*halfConn).incSeq
func incSeq(conn uintptr) func incSeq(conn uintptr)