diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index 060b8840..d34760ed 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -112,7 +112,7 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati } switch N.NetworkName(network) { case N.NetworkTCP: - tcpConn, err := gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol) + tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol) if err != nil { return nil, err } diff --git a/transport/wireguard/gonet.go b/transport/wireguard/gonet.go new file mode 100644 index 00000000..96d5b601 --- /dev/null +++ b/transport/wireguard/gonet.go @@ -0,0 +1,78 @@ +//go:build with_gvisor + +package wireguard + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + M "github.com/sagernet/sing/common/metadata" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*gonet.TCPConn, error) { + // Create TCP endpoint, then connect. + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) + if err != nil { + return nil, errors.New(err.String()) + } + + // Create wait queue entry that notifies a channel. + // + // We do this unconditionally as Connect will always return an error. + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) + wq.EventRegister(&waitEntry) + defer wq.EventUnregister(&waitEntry) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Bind before connect if requested. + if localAddr != (tcpip.FullAddress{}) { + if err = ep.Bind(localAddr); err != nil { + return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err) + } + } + + err = ep.Connect(remoteAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); ok { + select { + case <-ctx.Done(): + ep.Close() + return nil, ctx.Err() + case <-notifyCh: + } + + err = ep.LastError() + } + if err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "tcp", + Addr: M.SocksaddrFrom(M.AddrFromIP(net.IP(remoteAddr.Addr)), remoteAddr.Port).TCPAddr(), + Err: errors.New(err.String()), + } + } + + // sing-box added: set keepalive + ep.SocketOptions().SetKeepAlive(true) + keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second) + ep.SetSockOpt(&keepAliveIdle) + keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second) + ep.SetSockOpt(&keepAliveInterval) + + return gonet.NewTCPConn(&wq, ep), nil +}