mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-22 16:41:30 +00:00
Update badtls
This commit is contained in:
parent
e8dad1afeb
commit
5e1499d67b
|
@ -1,4 +1,4 @@
|
||||||
//go:build go1.19 && !go1.20
|
//go:build go1.20 && !go1.21
|
||||||
|
|
||||||
package badtls
|
package badtls
|
||||||
|
|
||||||
|
@ -14,39 +14,60 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
aTLS "github.com/sagernet/sing/common/tls"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
*tls.Conn
|
*tls.Conn
|
||||||
writer N.ExtendedWriter
|
writer N.ExtendedWriter
|
||||||
activeCall *int32
|
isHandshakeComplete *atomic.Bool
|
||||||
closeNotifySent *bool
|
activeCall *atomic.Int32
|
||||||
version *uint16
|
closeNotifySent *bool
|
||||||
rand io.Reader
|
version *uint16
|
||||||
halfAccess *sync.Mutex
|
rand io.Reader
|
||||||
halfError *error
|
halfAccess *sync.Mutex
|
||||||
cipher cipher.AEAD
|
halfError *error
|
||||||
explicitNonceLen int
|
cipher cipher.AEAD
|
||||||
halfPtr uintptr
|
explicitNonceLen int
|
||||||
halfSeq []byte
|
halfPtr uintptr
|
||||||
halfScratchBuf []byte
|
halfSeq []byte
|
||||||
|
halfScratchBuf []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(conn *tls.Conn) (TLSConn, error) {
|
func TryCreate(conn aTLS.Conn) aTLS.Conn {
|
||||||
if !handshakeComplete(conn) {
|
tlsConn, ok := conn.(*tls.Conn)
|
||||||
|
if !ok {
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
badConn, err := Create(tlsConn)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("initialize badtls: ", err)
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
return badConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func Create(conn *tls.Conn) (aTLS.Conn, error) {
|
||||||
|
rawConn := reflect.Indirect(reflect.ValueOf(conn))
|
||||||
|
rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
|
||||||
|
if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
|
||||||
|
return nil, E.New("badtls: invalid isHandshakeComplete")
|
||||||
|
}
|
||||||
|
isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
|
||||||
|
if !isHandshakeComplete.Load() {
|
||||||
return nil, E.New("handshake not finished")
|
return nil, E.New("handshake not finished")
|
||||||
}
|
}
|
||||||
rawConn := reflect.Indirect(reflect.ValueOf(conn))
|
|
||||||
rawActiveCall := rawConn.FieldByName("activeCall")
|
rawActiveCall := rawConn.FieldByName("activeCall")
|
||||||
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Int32 {
|
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
|
||||||
return nil, E.New("badtls: invalid active call")
|
return nil, E.New("badtls: invalid active call")
|
||||||
}
|
}
|
||||||
activeCall := (*int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
|
activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
|
||||||
rawHalfConn := rawConn.FieldByName("out")
|
rawHalfConn := rawConn.FieldByName("out")
|
||||||
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
||||||
return nil, E.New("badtls: invalid half conn")
|
return nil, E.New("badtls: invalid half conn")
|
||||||
|
@ -108,19 +129,20 @@ func Create(conn *tls.Conn) (TLSConn, error) {
|
||||||
}
|
}
|
||||||
halfScratchBuf := rawHalfScratchBuf.Bytes()
|
halfScratchBuf := rawHalfScratchBuf.Bytes()
|
||||||
return &Conn{
|
return &Conn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
writer: bufio.NewExtendedWriter(conn.NetConn()),
|
writer: bufio.NewExtendedWriter(conn.NetConn()),
|
||||||
activeCall: activeCall,
|
isHandshakeComplete: isHandshakeComplete,
|
||||||
closeNotifySent: closeNotifySent,
|
activeCall: activeCall,
|
||||||
version: version,
|
closeNotifySent: closeNotifySent,
|
||||||
halfAccess: halfAccess,
|
version: version,
|
||||||
halfError: halfError,
|
halfAccess: halfAccess,
|
||||||
cipher: aeadCipher,
|
halfError: halfError,
|
||||||
explicitNonceLen: explicitNonceLen,
|
cipher: aeadCipher,
|
||||||
rand: randReader,
|
explicitNonceLen: explicitNonceLen,
|
||||||
halfPtr: rawHalfConn.UnsafeAddr(),
|
rand: randReader,
|
||||||
halfSeq: halfSeq,
|
halfPtr: rawHalfConn.UnsafeAddr(),
|
||||||
halfScratchBuf: halfScratchBuf,
|
halfSeq: halfSeq,
|
||||||
|
halfScratchBuf: halfScratchBuf,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,15 +152,15 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
||||||
return common.Error(c.Write(buffer.Bytes()))
|
return common.Error(c.Write(buffer.Bytes()))
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
x := atomic.LoadInt32(c.activeCall)
|
x := c.activeCall.Load()
|
||||||
if x&1 != 0 {
|
if x&1 != 0 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
if atomic.CompareAndSwapInt32(c.activeCall, x, x+2) {
|
if c.activeCall.CompareAndSwap(x, x+2) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
defer atomic.AddInt32(c.activeCall, -2)
|
defer c.activeCall.Add(-2)
|
||||||
c.halfAccess.Lock()
|
c.halfAccess.Lock()
|
||||||
defer c.halfAccess.Unlock()
|
defer c.halfAccess.Unlock()
|
||||||
if err := *c.halfError; err != nil {
|
if err := *c.halfError; err != nil {
|
||||||
|
@ -186,6 +208,7 @@ func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
|
||||||
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
|
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
|
||||||
}
|
}
|
||||||
incSeq(c.halfPtr)
|
incSeq(c.halfPtr)
|
||||||
|
log.Trace("badtls write ", buffer.Len())
|
||||||
return c.writer.WriteBuffer(buffer)
|
return c.writer.WriteBuffer(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build !go1.19 || go1.20
|
//go:build !go1.19 || go1.21
|
||||||
|
|
||||||
package badtls
|
package badtls
|
||||||
|
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
package badtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TLSConn interface {
|
|
||||||
net.Conn
|
|
||||||
HandshakeContext(ctx context.Context) error
|
|
||||||
ConnectionState() tls.ConnectionState
|
|
||||||
}
|
|
|
@ -1,9 +1,8 @@
|
||||||
//go:build go1.19 && !go.1.20
|
//go:build go1.20 && !go.1.21
|
||||||
|
|
||||||
package badtls
|
package badtls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
_ "unsafe"
|
_ "unsafe"
|
||||||
)
|
)
|
||||||
|
@ -16,9 +15,6 @@ const (
|
||||||
//go:linkname errShutdown crypto/tls.errShutdown
|
//go:linkname errShutdown crypto/tls.errShutdown
|
||||||
var errShutdown error
|
var errShutdown error
|
||||||
|
|
||||||
//go:linkname handshakeComplete crypto/tls.(*Conn).handshakeComplete
|
|
||||||
func handshakeComplete(conn *tls.Conn) bool
|
|
||||||
|
|
||||||
//go:linkname incSeq crypto/tls.(*halfConn).incSeq
|
//go:linkname incSeq crypto/tls.(*halfConn).incSeq
|
||||||
func incSeq(conn uintptr)
|
func incSeq(conn uintptr)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue