badtls: Support uTLS and TLS ECH for read waiter

This commit is contained in:
世界 2023-12-27 18:05:52 +08:00
parent d530c724c0
commit daee0b154e
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 117 additions and 23 deletions

View file

@ -4,6 +4,8 @@ package badtls
import ( import (
"bytes" "bytes"
"context"
"net"
"os" "os"
"reflect" "reflect"
"sync" "sync"
@ -18,20 +20,32 @@ import (
var _ N.ReadWaiter = (*ReadWaitConn)(nil) var _ N.ReadWaiter = (*ReadWaitConn)(nil)
type ReadWaitConn struct { type ReadWaitConn struct {
*tls.STDConn tls.Conn
halfAccess *sync.Mutex halfAccess *sync.Mutex
rawInput *bytes.Buffer rawInput *bytes.Buffer
input *bytes.Reader input *bytes.Reader
hand *bytes.Buffer hand *bytes.Buffer
readWaitOptions N.ReadWaitOptions readWaitOptions N.ReadWaitOptions
tlsReadRecord func() error
tlsHandlePostHandshakeMessage func() error
} }
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
stdConn, isSTDConn := conn.(*tls.STDConn) var (
if !isSTDConn { loaded bool
tlsReadRecord func() error
tlsHandlePostHandshakeMessage func() error
)
for _, tlsCreator := range tlsRegistry {
loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
if loaded {
break
}
}
if !loaded {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
rawConn := reflect.Indirect(reflect.ValueOf(stdConn)) rawConn := reflect.Indirect(reflect.ValueOf(conn))
rawHalfConn := rawConn.FieldByName("in") rawHalfConn := rawConn.FieldByName("in")
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half conn") return nil, E.New("badtls: invalid half conn")
@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
} }
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr())) hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
return &ReadWaitConn{ return &ReadWaitConn{
STDConn: stdConn, Conn: conn,
halfAccess: halfAccess, halfAccess: halfAccess,
rawInput: rawInput, rawInput: rawInput,
input: input, input: input,
hand: hand, hand: hand,
tlsReadRecord: tlsReadRecord,
tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
}, nil }, nil
} }
@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
} }
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
err = c.Handshake() err = c.HandshakeContext(context.Background())
if err != nil { if err != nil {
return return
} }
c.halfAccess.Lock() c.halfAccess.Lock()
defer c.halfAccess.Unlock() defer c.halfAccess.Unlock()
for c.input.Len() == 0 { for c.input.Len() == 0 {
err = tlsReadRecord(c.STDConn) err = c.tlsReadRecord()
if err != nil { if err != nil {
return return
} }
for c.hand.Len() > 0 { for c.hand.Len() > 0 {
err = tlsHandlePostHandshakeMessage(c.STDConn) err = c.tlsHandlePostHandshakeMessage()
if err != nil { if err != nil {
return return
} }
@ -100,7 +116,7 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { // recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
c.rawInput.Bytes()[0] == 21 { c.rawInput.Bytes()[0] == 21 {
_ = tlsReadRecord(c.STDConn) _ = c.tlsReadRecord()
// return n, err // will be io.EOF on closeNotify // return n, err // will be io.EOF on closeNotify
} }
@ -109,11 +125,27 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
} }
func (c *ReadWaitConn) Upstream() any { func (c *ReadWaitConn) Upstream() any {
return c.STDConn return c.Conn
} }
//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
func tlsReadRecord(c *tls.STDConn) error
//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage func init() {
func tlsHandlePostHandshakeMessage(c *tls.STDConn) error tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
tlsConn, loaded := conn.(*tls.STDConn)
if !loaded {
return
}
return true, func() error {
return stdTLSReadRecord(tlsConn)
}, func() error {
return stdTLSHandlePostHandshakeMessage(tlsConn)
}
})
}
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
func stdTLSReadRecord(c *tls.STDConn) error
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error

View file

@ -0,0 +1,31 @@
//go:build go1.21 && !without_badtls && with_ech
package badtls
import (
"net"
_ "unsafe"
"github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing/common"
)
func init() {
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
tlsConn, loaded := common.Cast[*tls.Conn](conn)
if !loaded {
return
}
return true, func() error {
return echReadRecord(tlsConn)
}, func() error {
return echHandlePostHandshakeMessage(tlsConn)
}
})
}
//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord
func echReadRecord(c *tls.Conn) error
//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage
func echHandlePostHandshakeMessage(c *tls.Conn) error

View file

@ -0,0 +1,31 @@
//go:build go1.21 && !without_badtls && with_utls
package badtls
import (
"net"
_ "unsafe"
"github.com/sagernet/sing/common"
"github.com/sagernet/utls"
)
func init() {
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
tlsConn, loaded := common.Cast[*tls.UConn](conn)
if !loaded {
return
}
return true, func() error {
return utlsReadRecord(tlsConn.Conn)
}, func() error {
return utlsHandlePostHandshakeMessage(tlsConn.Conn)
}
})
}
//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord
func utlsReadRecord(c *tls.Conn) error
//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage
func utlsHandlePostHandshakeMessage(c *tls.Conn) error