Update badtls

This commit is contained in:
世界 2023-04-18 19:56:01 +08:00
parent e8dad1afeb
commit 5e1499d67b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 59 additions and 53 deletions

View file

@ -1,4 +1,4 @@
//go:build go1.19 && !go1.20
//go:build go1.20 && !go1.21
package badtls
@ -14,39 +14,60 @@ import (
"sync/atomic"
"unsafe"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
)
type Conn struct {
*tls.Conn
writer N.ExtendedWriter
activeCall *int32
closeNotifySent *bool
version *uint16
rand io.Reader
halfAccess *sync.Mutex
halfError *error
cipher cipher.AEAD
explicitNonceLen int
halfPtr uintptr
halfSeq []byte
halfScratchBuf []byte
writer N.ExtendedWriter
isHandshakeComplete *atomic.Bool
activeCall *atomic.Int32
closeNotifySent *bool
version *uint16
rand io.Reader
halfAccess *sync.Mutex
halfError *error
cipher cipher.AEAD
explicitNonceLen int
halfPtr uintptr
halfSeq []byte
halfScratchBuf []byte
}
func Create(conn *tls.Conn) (TLSConn, error) {
if !handshakeComplete(conn) {
func TryCreate(conn aTLS.Conn) aTLS.Conn {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return conn
}
badConn, err := Create(tlsConn)
if err != nil {
log.Warn("initialize badtls: ", err)
return conn
}
return badConn
}
func Create(conn *tls.Conn) (aTLS.Conn, error) {
rawConn := reflect.Indirect(reflect.ValueOf(conn))
rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid isHandshakeComplete")
}
isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
if !isHandshakeComplete.Load() {
return nil, E.New("handshake not finished")
}
rawConn := reflect.Indirect(reflect.ValueOf(conn))
rawActiveCall := rawConn.FieldByName("activeCall")
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 {
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid active call")
}
activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
rawHalfConn := rawConn.FieldByName("out")
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half conn")
@ -108,19 +129,20 @@ func Create(conn *tls.Conn) (TLSConn, error) {
}
halfScratchBuf := rawHalfScratchBuf.Bytes()
return &Conn{
Conn: conn,
writer: bufio.NewExtendedWriter(conn.NetConn()),
activeCall: activeCall,
closeNotifySent: closeNotifySent,
version: version,
halfAccess: halfAccess,
halfError: halfError,
cipher: aeadCipher,
explicitNonceLen: explicitNonceLen,
rand: randReader,
halfPtr: rawHalfConn.UnsafeAddr(),
halfSeq: halfSeq,
halfScratchBuf: halfScratchBuf,
Conn: conn,
writer: bufio.NewExtendedWriter(conn.NetConn()),
isHandshakeComplete: isHandshakeComplete,
activeCall: activeCall,
closeNotifySent: closeNotifySent,
version: version,
halfAccess: halfAccess,
halfError: halfError,
cipher: aeadCipher,
explicitNonceLen: explicitNonceLen,
rand: randReader,
halfPtr: rawHalfConn.UnsafeAddr(),
halfSeq: halfSeq,
halfScratchBuf: halfScratchBuf,
}, nil
}
@ -130,15 +152,15 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
return common.Error(c.Write(buffer.Bytes()))
}
for {
x := atomic.LoadInt32(c.activeCall)
x := c.activeCall.Load()
if x&1 != 0 {
return net.ErrClosed
}
if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) {
if c.activeCall.CompareAndSwap(x, x+2) {
break
}
}
defer atomic.AddInt32(c.activeCall, -2)
defer c.activeCall.Add(-2)
c.halfAccess.Lock()
defer c.halfAccess.Unlock()
if err := *c.halfError; err != nil {
@ -186,6 +208,7 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
}
incSeq(c.halfPtr)
log.Trace("badtls write ", buffer.Len())
return c.writer.WriteBuffer(buffer)
}

View file

@ -1,4 +1,4 @@
//go:build !go1.19 || go1.20
//go:build !go1.19 || go1.21
package badtls

View file

@ -1,13 +0,0 @@
package badtls
import (
"context"
"crypto/tls"
"net"
)
type TLSConn interface {
net.Conn
HandshakeContext(ctx context.Context) error
ConnectionState() tls.ConnectionState
}

View file

@ -1,9 +1,8 @@
//go:build go1.19 && !go.1.20
//go:build go1.20 && !go.1.21
package badtls
import (
"crypto/tls"
"reflect"
_ "unsafe"
)
@ -16,9 +15,6 @@ const (
//go:linkname errShutdown crypto/tls.errShutdown
var errShutdown error
//go:linkname handshakeComplete crypto/tls.(*Conn).handshakeComplete
func handshakeComplete(conn *tls.Conn) bool
//go:linkname incSeq crypto/tls.(*halfConn).incSeq
func incSeq(conn uintptr)