From fe8d46cce527964b397352ef267f1bbb0d5d8be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 9 Sep 2023 19:51:10 +0800 Subject: [PATCH] Fix TFO async write --- common/dialer/tfo.go | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/common/dialer/tfo.go b/common/dialer/tfo.go index 2e3eb9b3..0b0c9fcc 100644 --- a/common/dialer/tfo.go +++ b/common/dialer/tfo.go @@ -7,6 +7,7 @@ import ( "io" "net" "os" + "sync" "time" "github.com/sagernet/sing/common" @@ -24,6 +25,7 @@ type slowOpenConn struct { destination M.Socksaddr conn net.Conn create chan struct{} + access sync.Mutex err error } @@ -60,16 +62,26 @@ func (c *slowOpenConn) Read(b []byte) (n int, err error) { } func (c *slowOpenConn) Write(b []byte) (n int, err error) { - if c.conn == nil { - c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) - if err != nil { - c.conn = nil - c.err = E.Cause(err, "dial tcp fast open") - } - close(c.create) - return + if c.conn != nil { + return c.conn.Write(b) } - return c.conn.Write(b) + c.access.Lock() + defer c.access.Unlock() + select { + case <-c.create: + if c.err != nil { + return 0, c.err + } + return c.conn.Write(b) + default: + } + c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) + if err != nil { + c.conn = nil + c.err = E.Cause(err, "dial tcp fast open") + } + close(c.create) + return } func (c *slowOpenConn) Close() error {