mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-09 18:43:14 +00:00
badtls: Support uTLS and TLS ECH for read waiter
This commit is contained in:
parent
d530c724c0
commit
daee0b154e
|
@ -4,6 +4,8 @@ package badtls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -18,20 +20,32 @@ import (
|
||||||
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
|
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
|
||||||
|
|
||||||
type ReadWaitConn struct {
|
type ReadWaitConn struct {
|
||||||
*tls.STDConn
|
tls.Conn
|
||||||
halfAccess *sync.Mutex
|
halfAccess *sync.Mutex
|
||||||
rawInput *bytes.Buffer
|
rawInput *bytes.Buffer
|
||||||
input *bytes.Reader
|
input *bytes.Reader
|
||||||
hand *bytes.Buffer
|
hand *bytes.Buffer
|
||||||
readWaitOptions N.ReadWaitOptions
|
readWaitOptions N.ReadWaitOptions
|
||||||
|
tlsReadRecord func() error
|
||||||
|
tlsHandlePostHandshakeMessage func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
||||||
stdConn, isSTDConn := conn.(*tls.STDConn)
|
var (
|
||||||
if !isSTDConn {
|
loaded bool
|
||||||
|
tlsReadRecord func() error
|
||||||
|
tlsHandlePostHandshakeMessage func() error
|
||||||
|
)
|
||||||
|
for _, tlsCreator := range tlsRegistry {
|
||||||
|
loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
|
||||||
|
if loaded {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !loaded {
|
||||||
return nil, os.ErrInvalid
|
return nil, os.ErrInvalid
|
||||||
}
|
}
|
||||||
rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
|
rawConn := reflect.Indirect(reflect.ValueOf(conn))
|
||||||
rawHalfConn := rawConn.FieldByName("in")
|
rawHalfConn := rawConn.FieldByName("in")
|
||||||
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
||||||
return nil, E.New("badtls: invalid half conn")
|
return nil, E.New("badtls: invalid half conn")
|
||||||
|
@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
||||||
}
|
}
|
||||||
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
||||||
return &ReadWaitConn{
|
return &ReadWaitConn{
|
||||||
STDConn: stdConn,
|
Conn: conn,
|
||||||
halfAccess: halfAccess,
|
halfAccess: halfAccess,
|
||||||
rawInput: rawInput,
|
rawInput: rawInput,
|
||||||
input: input,
|
input: input,
|
||||||
hand: hand,
|
hand: hand,
|
||||||
|
tlsReadRecord: tlsReadRecord,
|
||||||
|
tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||||
err = c.Handshake()
|
err = c.HandshakeContext(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.halfAccess.Lock()
|
c.halfAccess.Lock()
|
||||||
defer c.halfAccess.Unlock()
|
defer c.halfAccess.Unlock()
|
||||||
for c.input.Len() == 0 {
|
for c.input.Len() == 0 {
|
||||||
err = tlsReadRecord(c.STDConn)
|
err = c.tlsReadRecord()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for c.hand.Len() > 0 {
|
for c.hand.Len() > 0 {
|
||||||
err = tlsHandlePostHandshakeMessage(c.STDConn)
|
err = c.tlsHandlePostHandshakeMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -100,7 +116,7 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||||
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
|
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
|
||||||
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
|
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
|
||||||
c.rawInput.Bytes()[0] == 21 {
|
c.rawInput.Bytes()[0] == 21 {
|
||||||
_ = tlsReadRecord(c.STDConn)
|
_ = c.tlsReadRecord()
|
||||||
// return n, err // will be io.EOF on closeNotify
|
// return n, err // will be io.EOF on closeNotify
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,11 +125,27 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ReadWaitConn) Upstream() any {
|
func (c *ReadWaitConn) Upstream() any {
|
||||||
return c.STDConn
|
return c.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
|
var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
|
||||||
func tlsReadRecord(c *tls.STDConn) error
|
|
||||||
|
|
||||||
//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
func init() {
|
||||||
func tlsHandlePostHandshakeMessage(c *tls.STDConn) error
|
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||||
|
tlsConn, loaded := conn.(*tls.STDConn)
|
||||||
|
if !loaded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return true, func() error {
|
||||||
|
return stdTLSReadRecord(tlsConn)
|
||||||
|
}, func() error {
|
||||||
|
return stdTLSHandlePostHandshakeMessage(tlsConn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
|
||||||
|
func stdTLSReadRecord(c *tls.STDConn) error
|
||||||
|
|
||||||
|
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
||||||
|
func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error
|
||||||
|
|
31
common/badtls/read_wait_ech.go
Normal file
31
common/badtls/read_wait_ech.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
//go:build go1.21 && !without_badtls && with_ech
|
||||||
|
|
||||||
|
package badtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
_ "unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/cloudflare-tls"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||||
|
tlsConn, loaded := common.Cast[*tls.Conn](conn)
|
||||||
|
if !loaded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return true, func() error {
|
||||||
|
return echReadRecord(tlsConn)
|
||||||
|
}, func() error {
|
||||||
|
return echHandlePostHandshakeMessage(tlsConn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord
|
||||||
|
func echReadRecord(c *tls.Conn) error
|
||||||
|
|
||||||
|
//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage
|
||||||
|
func echHandlePostHandshakeMessage(c *tls.Conn) error
|
31
common/badtls/read_wait_utls.go
Normal file
31
common/badtls/read_wait_utls.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
//go:build go1.21 && !without_badtls && with_utls
|
||||||
|
|
||||||
|
package badtls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
_ "unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/utls"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||||
|
tlsConn, loaded := common.Cast[*tls.UConn](conn)
|
||||||
|
if !loaded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return true, func() error {
|
||||||
|
return utlsReadRecord(tlsConn.Conn)
|
||||||
|
}, func() error {
|
||||||
|
return utlsHandlePostHandshakeMessage(tlsConn.Conn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord
|
||||||
|
func utlsReadRecord(c *tls.Conn) error
|
||||||
|
|
||||||
|
//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage
|
||||||
|
func utlsHandlePostHandshakeMessage(c *tls.Conn) error
|
Loading…
Reference in a new issue