mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-24 09:31:30 +00:00
116 lines
3.1 KiB
Go
116 lines
3.1 KiB
Go
|
//go:build go1.21 && !without_badtls
|
||
|
|
||
|
package badtls
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"os"
|
||
|
"reflect"
|
||
|
"sync"
|
||
|
"unsafe"
|
||
|
|
||
|
"github.com/sagernet/sing/common/buf"
|
||
|
E "github.com/sagernet/sing/common/exceptions"
|
||
|
N "github.com/sagernet/sing/common/network"
|
||
|
"github.com/sagernet/sing/common/tls"
|
||
|
)
|
||
|
|
||
|
// Compile-time assertion that *ReadWaitConn satisfies the N.ReadWaiter interface.
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
|
||
|
|
||
|
type ReadWaitConn struct {
|
||
|
*tls.STDConn
|
||
|
halfAccess *sync.Mutex
|
||
|
rawInput *bytes.Buffer
|
||
|
input *bytes.Reader
|
||
|
hand *bytes.Buffer
|
||
|
readWaitOptions N.ReadWaitOptions
|
||
|
}
|
||
|
|
||
|
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
|
||
|
stdConn, isSTDConn := conn.(*tls.STDConn)
|
||
|
if !isSTDConn {
|
||
|
return nil, os.ErrInvalid
|
||
|
}
|
||
|
rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
|
||
|
rawHalfConn := rawConn.FieldByName("in")
|
||
|
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
|
||
|
return nil, E.New("badtls: invalid half conn")
|
||
|
}
|
||
|
rawHalfMutex := rawHalfConn.FieldByName("Mutex")
|
||
|
if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
|
||
|
return nil, E.New("badtls: invalid half mutex")
|
||
|
}
|
||
|
halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
|
||
|
rawRawInput := rawConn.FieldByName("rawInput")
|
||
|
if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
|
||
|
return nil, E.New("badtls: invalid raw input")
|
||
|
}
|
||
|
rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
|
||
|
rawInput0 := rawConn.FieldByName("input")
|
||
|
if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct {
|
||
|
return nil, E.New("badtls: invalid input")
|
||
|
}
|
||
|
input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr()))
|
||
|
rawHand := rawConn.FieldByName("hand")
|
||
|
if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
|
||
|
return nil, E.New("badtls: invalid hand")
|
||
|
}
|
||
|
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
|
||
|
return &ReadWaitConn{
|
||
|
STDConn: stdConn,
|
||
|
halfAccess: halfAccess,
|
||
|
rawInput: rawInput,
|
||
|
input: input,
|
||
|
hand: hand,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
||
|
c.readWaitOptions = options
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
|
||
|
err = c.Handshake()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
c.halfAccess.Lock()
|
||
|
defer c.halfAccess.Unlock()
|
||
|
for c.input.Len() == 0 {
|
||
|
err = tlsReadRecord(c.STDConn)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
for c.hand.Len() > 0 {
|
||
|
err = tlsHandlePostHandshakeMessage(c.STDConn)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
buffer = c.readWaitOptions.NewBuffer()
|
||
|
n, err := c.input.Read(buffer.FreeBytes())
|
||
|
if err != nil {
|
||
|
buffer.Release()
|
||
|
return
|
||
|
}
|
||
|
buffer.Truncate(n)
|
||
|
|
||
|
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
|
||
|
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
|
||
|
c.rawInput.Bytes()[0] == 21 {
|
||
|
_ = tlsReadRecord(c.STDConn)
|
||
|
// return n, err // will be io.EOF on closeNotify
|
||
|
}
|
||
|
|
||
|
c.readWaitOptions.PostReturn(buffer)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
|
||
|
func tlsReadRecord(c *tls.STDConn) error
|
||
|
|
||
|
//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
|
||
|
func tlsHandlePostHandshakeMessage(c *tls.STDConn) error
|