mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-23 17:11:29 +00:00
badtls: Support uTLS and TLS ECH for read waiter
This commit is contained in:
parent
11bec79a06
commit
1ce8ae6cd5
|
@ -4,6 +4,8 @@ package badtls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
@ -18,20 +20,32 @@ import (
|
|||
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
|
||||
|
||||
type ReadWaitConn struct {
|
||||
*tls.STDConn
|
||||
halfAccess *sync.Mutex
|
||||
rawInput *bytes.Buffer
|
||||
input *bytes.Reader
|
||||
hand *bytes.Buffer
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
tls.Conn
|
||||
halfAccess *sync.Mutex
|
||||
rawInput *bytes.Buffer
|
||||
input *bytes.Reader
|
||||
hand *bytes.Buffer
|
||||
readWaitOptions N.ReadWaitOptions
|
||||
tlsReadRecord func() error
|
||||
tlsHandlePostHandshakeMessage func() error
|
||||
}
|
||||
|
||||
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
||||
stdConn, isSTDConn := conn.(*tls.STDConn)
|
||||
if !isSTDConn {
|
||||
var (
|
||||
loaded bool
|
||||
tlsReadRecord func() error
|
||||
tlsHandlePostHandshakeMessage func() error
|
||||
)
|
||||
for _, tlsCreator := range tlsRegistry {
|
||||
loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
|
||||
if loaded {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !loaded {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
|
||||
rawConn := reflect.Indirect(reflect.ValueOf(conn))
|
||||
rawHalfConn := rawConn.FieldByName("in")
|
||||
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
||||
return nil, E.New("badtls: invalid half conn")
|
||||
|
@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
|||
}
|
||||
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
||||
return &ReadWaitConn{
|
||||
STDConn: stdConn,
|
||||
halfAccess: halfAccess,
|
||||
rawInput: rawInput,
|
||||
input: input,
|
||||
hand: hand,
|
||||
Conn: conn,
|
||||
halfAccess: halfAccess,
|
||||
rawInput: rawInput,
|
||||
input: input,
|
||||
hand: hand,
|
||||
tlsReadRecord: tlsReadRecord,
|
||||
tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
|
|||
}
|
||||
|
||||
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||||
err = c.Handshake()
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.halfAccess.Lock()
|
||||
defer c.halfAccess.Unlock()
|
||||
for c.input.Len() == 0 {
|
||||
err = tlsReadRecord(c.STDConn)
|
||||
err = c.tlsReadRecord()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for c.hand.Len() > 0 {
|
||||
err = tlsHandlePostHandshakeMessage(c.STDConn)
|
||||
err = c.tlsHandlePostHandshakeMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -100,7 +116,7 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
|||
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
|
||||
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
|
||||
c.rawInput.Bytes()[0] == 21 {
|
||||
_ = tlsReadRecord(c.STDConn)
|
||||
_ = c.tlsReadRecord()
|
||||
// return n, err // will be io.EOF on closeNotify
|
||||
}
|
||||
|
||||
|
@ -108,8 +124,24 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
|
||||
func tlsReadRecord(c *tls.STDConn) error
|
||||
var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
|
||||
|
||||
//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
||||
func tlsHandlePostHandshakeMessage(c *tls.STDConn) error
|
||||
func init() {
|
||||
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||
tlsConn, loaded := conn.(*tls.STDConn)
|
||||
if !loaded {
|
||||
return
|
||||
}
|
||||
return true, func() error {
|
||||
return stdTLSReadRecord(tlsConn)
|
||||
}, func() error {
|
||||
return stdTLSHandlePostHandshakeMessage(tlsConn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
|
||||
func stdTLSReadRecord(c *tls.STDConn) error
|
||||
|
||||
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
||||
func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error
|
||||
|
|
31
common/badtls/read_wait_ech.go
Normal file
31
common/badtls/read_wait_ech.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
//go:build go1.21 && !without_badtls && with_ech
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"net"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/cloudflare-tls"
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
func init() {
|
||||
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||
tlsConn, loaded := common.Cast[*tls.Conn](conn)
|
||||
if !loaded {
|
||||
return
|
||||
}
|
||||
return true, func() error {
|
||||
return echReadRecord(tlsConn)
|
||||
}, func() error {
|
||||
return echHandlePostHandshakeMessage(tlsConn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord
|
||||
func echReadRecord(c *tls.Conn) error
|
||||
|
||||
//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage
|
||||
func echHandlePostHandshakeMessage(c *tls.Conn) error
|
31
common/badtls/read_wait_utls.go
Normal file
31
common/badtls/read_wait_utls.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
//go:build go1.21 && !without_badtls && with_utls
|
||||
|
||||
package badtls
|
||||
|
||||
import (
|
||||
"net"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/utls"
|
||||
)
|
||||
|
||||
func init() {
|
||||
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
|
||||
tlsConn, loaded := common.Cast[*tls.UConn](conn)
|
||||
if !loaded {
|
||||
return
|
||||
}
|
||||
return true, func() error {
|
||||
return utlsReadRecord(tlsConn.Conn)
|
||||
}, func() error {
|
||||
return utlsHandlePostHandshakeMessage(tlsConn.Conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord
|
||||
func utlsReadRecord(c *tls.Conn) error
|
||||
|
||||
//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage
|
||||
func utlsHandlePostHandshakeMessage(c *tls.Conn) error
|
Loading…
Reference in a new issue