diff --git a/common/badtls/read_wait.go b/common/badtls/read_wait.go index 4657bc5b..1a0998bf 100644 --- a/common/badtls/read_wait.go +++ b/common/badtls/read_wait.go @@ -4,6 +4,8 @@ package badtls import ( "bytes" + "context" + "net" "os" "reflect" "sync" @@ -18,20 +20,32 @@ import ( var _ N.ReadWaiter = (*ReadWaitConn)(nil) type ReadWaitConn struct { - *tls.STDConn - halfAccess *sync.Mutex - rawInput *bytes.Buffer - input *bytes.Reader - hand *bytes.Buffer - readWaitOptions N.ReadWaitOptions + tls.Conn + halfAccess *sync.Mutex + rawInput *bytes.Buffer + input *bytes.Reader + hand *bytes.Buffer + readWaitOptions N.ReadWaitOptions + tlsReadRecord func() error + tlsHandlePostHandshakeMessage func() error } func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { - stdConn, isSTDConn := conn.(*tls.STDConn) - if !isSTDConn { + var ( + loaded bool + tlsReadRecord func() error + tlsHandlePostHandshakeMessage func() error + ) + for _, tlsCreator := range tlsRegistry { + loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn) + if loaded { + break + } + } + if !loaded { return nil, os.ErrInvalid } - rawConn := reflect.Indirect(reflect.ValueOf(stdConn)) + rawConn := reflect.Indirect(reflect.ValueOf(conn)) rawHalfConn := rawConn.FieldByName("in") if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { return nil, E.New("badtls: invalid half conn") @@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { } hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr())) return &ReadWaitConn{ - STDConn: stdConn, - halfAccess: halfAccess, - rawInput: rawInput, - input: input, - hand: hand, + Conn: conn, + halfAccess: halfAccess, + rawInput: rawInput, + input: input, + hand: hand, + tlsReadRecord: tlsReadRecord, + tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage, }, nil } @@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy } func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { - err = c.Handshake() + err = c.HandshakeContext(context.Background()) if err != nil { return } c.halfAccess.Lock() defer c.halfAccess.Unlock() for c.input.Len() == 0 { - err = tlsReadRecord(c.STDConn) + err = c.tlsReadRecord() if err != nil { return } for c.hand.Len() > 0 { - err = tlsHandlePostHandshakeMessage(c.STDConn) + err = c.tlsHandlePostHandshakeMessage() if err != nil { return } @@ -100,7 +116,7 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && // recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { c.rawInput.Bytes()[0] == 21 { - _ = tlsReadRecord(c.STDConn) + _ = c.tlsReadRecord() // return n, err // will be io.EOF on closeNotify } @@ -108,8 +124,24 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { return } -//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord -func tlsReadRecord(c *tls.STDConn) error +var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) -//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage -func tlsHandlePostHandshakeMessage(c *tls.STDConn) error +func init() { + tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) { + tlsConn, loaded := conn.(*tls.STDConn) + if !loaded { + return + } + return true, func() error { + return stdTLSReadRecord(tlsConn) + }, func() error { + return stdTLSHandlePostHandshakeMessage(tlsConn) + } + }) +} + +//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord +func stdTLSReadRecord(c *tls.STDConn) error + +//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage +func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error diff --git a/common/badtls/read_wait_ech.go b/common/badtls/read_wait_ech.go new file mode 100644 index 00000000..6a0d5b5f --- /dev/null +++ b/common/badtls/read_wait_ech.go @@ -0,0 +1,31 @@ +//go:build go1.21 && !without_badtls && with_ech + +package badtls + +import ( + "net" + _ "unsafe" + + "github.com/sagernet/cloudflare-tls" + "github.com/sagernet/sing/common" +) + +func init() { + tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) { + tlsConn, loaded := common.Cast[*tls.Conn](conn) + if !loaded { + return + } + return true, func() error { + return echReadRecord(tlsConn) + }, func() error { + return echHandlePostHandshakeMessage(tlsConn) + } + }) +} + +//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord +func echReadRecord(c *tls.Conn) error + +//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage +func echHandlePostHandshakeMessage(c *tls.Conn) error diff --git a/common/badtls/read_wait_utls.go b/common/badtls/read_wait_utls.go new file mode 100644 index 00000000..ebdb2251 --- /dev/null +++ b/common/badtls/read_wait_utls.go @@ -0,0 +1,31 @@ +//go:build go1.21 && !without_badtls && with_utls + +package badtls + +import ( + "net" + _ "unsafe" + + "github.com/sagernet/sing/common" + "github.com/sagernet/utls" +) + +func init() { + tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) { + tlsConn, loaded := common.Cast[*tls.UConn](conn) + if !loaded { + return + } + return true, func() error { + return utlsReadRecord(tlsConn.Conn) + }, func() error { + return utlsHandlePostHandshakeMessage(tlsConn.Conn) + } + }) +} + +//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord +func utlsReadRecord(c *tls.Conn) error + +//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage +func utlsHandlePostHandshakeMessage(c *tls.Conn) error