Fix shadowtls server detection

This commit is contained in:
世界 2022-11-22 22:11:09 +08:00
parent ffd54eef6c
commit a401828ed5
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 30 additions and 8 deletions

View file

@ -91,7 +91,7 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a
hashConn := shadowtls.NewHashWriteConn(conn, s.password) hashConn := shadowtls.NewHashWriteConn(conn, s.password)
go bufio.Copy(hashConn, handshakeConn) go bufio.Copy(hashConn, handshakeConn)
var request *buf.Buffer var request *buf.Buffer
request, err = s.copyUntilHandshakeFinishedV2(handshakeConn, conn, hashConn, s.fallbackAfter) request, err = s.copyUntilHandshakeFinishedV2(ctx, handshakeConn, conn, hashConn, s.fallbackAfter)
if err == nil { if err == nil {
handshakeConn.Close() handshakeConn.Close()
return s.newConnection(ctx, bufio.NewCachedConn(shadowtls.NewConn(conn), request), metadata) return s.newConnection(ctx, bufio.NewCachedConn(shadowtls.NewConn(conn), request), metadata)
@ -135,7 +135,7 @@ func (s *ShadowTLS) copyUntilHandshakeFinished(dst io.Writer, src io.Reader) err
} }
} }
func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) { func (s *ShadowTLS) copyUntilHandshakeFinishedV2(ctx context.Context, dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) {
const applicationData = 0x17 const applicationData = 0x17
var tlsHdr [5]byte var tlsHdr [5]byte
var applicationDataCount int var applicationDataCount int
@ -152,9 +152,17 @@ func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, ha
data.Release() data.Release()
return nil, err return nil, err
} }
if length >= 8 && bytes.Equal(data.To(8), hash.Sum()) { if hash.HasContent() && length >= 8 {
data.Advance(8) checksum := hash.Sum()
return data, nil if bytes.Equal(data.To(8), checksum) {
s.logger.TraceContext(ctx, "match current hashcode")
data.Advance(8)
return data, nil
} else if hash.LastSum() != nil && bytes.Equal(data.To(8), hash.LastSum()) {
s.logger.TraceContext(ctx, "match last hashcode")
data.Advance(8)
return data, nil
}
} }
_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data)) _, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data))
data.Release() data.Release()

View file

@ -34,19 +34,25 @@ func (c *HashReadConn) Sum() []byte {
type HashWriteConn struct { type HashWriteConn struct {
net.Conn net.Conn
hmac hash.Hash hmac hash.Hash
hasContent bool
lastSum []byte
} }
func NewHashWriteConn(conn net.Conn, password string) *HashWriteConn { func NewHashWriteConn(conn net.Conn, password string) *HashWriteConn {
return &HashWriteConn{ return &HashWriteConn{
conn, Conn: conn,
hmac.New(sha1.New, []byte(password)), hmac: hmac.New(sha1.New, []byte(password)),
} }
} }
func (c *HashWriteConn) Write(p []byte) (n int, err error) { func (c *HashWriteConn) Write(p []byte) (n int, err error) {
if c.hmac != nil { if c.hmac != nil {
if c.hasContent {
c.lastSum = c.Sum()
}
c.hmac.Write(p) c.hmac.Write(p)
c.hasContent = true
} }
return c.Conn.Write(p) return c.Conn.Write(p)
} }
@ -55,6 +61,14 @@ func (c *HashWriteConn) Sum() []byte {
return c.hmac.Sum(nil)[:8] return c.hmac.Sum(nil)[:8]
} }
func (c *HashWriteConn) LastSum() []byte {
return c.lastSum
}
func (c *HashWriteConn) Fallback() { func (c *HashWriteConn) Fallback() {
c.hmac = nil c.hmac = nil
} }
func (c *HashWriteConn) HasContent() bool {
return c.hasContent
}