diff --git a/common/dialer/tfo.go b/common/dialer/tfo.go index 2e3eb9b3..0b0c9fcc 100644 --- a/common/dialer/tfo.go +++ b/common/dialer/tfo.go @@ -7,6 +7,7 @@ import ( "io" "net" "os" + "sync" "time" "github.com/sagernet/sing/common" @@ -24,6 +25,7 @@ type slowOpenConn struct { destination M.Socksaddr conn net.Conn create chan struct{} + access sync.Mutex err error } @@ -60,16 +62,26 @@ func (c *slowOpenConn) Read(b []byte) (n int, err error) { } func (c *slowOpenConn) Write(b []byte) (n int, err error) { - if c.conn == nil { - c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) - if err != nil { - c.conn = nil - c.err = E.Cause(err, "dial tcp fast open") - } - close(c.create) - return + if c.conn != nil { + return c.conn.Write(b) } - return c.conn.Write(b) + c.access.Lock() + defer c.access.Unlock() + select { + case <-c.create: + if c.err != nil { + return 0, c.err + } + return c.conn.Write(b) + default: + } + c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) + if err != nil { + c.conn = nil + c.err = E.Cause(err, "dial tcp fast open") + } + close(c.create) + return } func (c *slowOpenConn) Close() error {