// Mirror of https://github.com/SagerNet/sing-box.git
// (synced 2024-11-22 16:41:30 +00:00; 319 lines, 9.6 KiB, Go).
package congestion
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/sagernet/quic-go/congestion"
|
|
"github.com/sagernet/quic-go/logging"
|
|
)
|
|
|
|
// Tuning constants for the cubic/Reno sender. Window-related values are
// expressed in packets and are multiplied by the datagram size where used.
const (
	// initialCongestionWindow is the starting congestion window, in packets.
	initialCongestionWindow = 32

	// minCongestionWindowPackets is the floor of the congestion window, in
	// packets.
	minCongestionWindowPackets = 2

	// maxBurstPackets is the amount of unused window, in packets, below which
	// the sender still counts as congestion-window limited.
	maxBurstPackets = 3

	// renoBeta is the Reno backoff factor applied to the congestion window
	// when a loss is handled in Reno mode.
	renoBeta = 0.7
)
|
|
|
|
const (
|
|
InvalidPacketNumber congestion.PacketNumber = -1
|
|
MaxCongestionWindowPackets = 20000
|
|
MaxByteCount = congestion.ByteCount(1<<62 - 1)
|
|
)
|
|
|
|
type cubicSender struct {
|
|
hybridSlowStart HybridSlowStart
|
|
rttStats congestion.RTTStatsProvider
|
|
cubic *Cubic
|
|
pacer *pacer
|
|
clock Clock
|
|
|
|
reno bool
|
|
|
|
// Track the largest packet that has been sent.
|
|
largestSentPacketNumber congestion.PacketNumber
|
|
|
|
// Track the largest packet that has been acked.
|
|
largestAckedPacketNumber congestion.PacketNumber
|
|
|
|
// Track the largest packet number outstanding when a CWND cutback occurs.
|
|
largestSentAtLastCutback congestion.PacketNumber
|
|
|
|
// Whether the last loss event caused us to exit slowstart.
|
|
// Used for stats collection of slowstartPacketsLost
|
|
lastCutbackExitedSlowstart bool
|
|
|
|
// Congestion window in bytes.
|
|
congestionWindow congestion.ByteCount
|
|
|
|
// Slow start congestion window in bytes, aka ssthresh.
|
|
slowStartThreshold congestion.ByteCount
|
|
|
|
// ACK counter for the Reno implementation.
|
|
numAckedPackets uint64
|
|
|
|
initialCongestionWindow congestion.ByteCount
|
|
initialMaxCongestionWindow congestion.ByteCount
|
|
|
|
maxDatagramSize congestion.ByteCount
|
|
|
|
lastState logging.CongestionState
|
|
tracer logging.ConnectionTracer
|
|
}
|
|
|
|
// Compile-time check that *cubicSender implements
// congestion.CongestionControl. A typed nil pointer is used so that no
// cubicSender value has to be constructed for the check.
var _ congestion.CongestionControl = (*cubicSender)(nil)
|
|
|
|
// NewCubicSender makes a new cubic sender
|
|
func NewCubicSender(
|
|
clock Clock,
|
|
initialMaxDatagramSize congestion.ByteCount,
|
|
reno bool,
|
|
tracer logging.ConnectionTracer,
|
|
) *cubicSender {
|
|
return newCubicSender(
|
|
clock,
|
|
reno,
|
|
initialMaxDatagramSize,
|
|
initialCongestionWindow*initialMaxDatagramSize,
|
|
MaxCongestionWindowPackets*initialMaxDatagramSize,
|
|
tracer,
|
|
)
|
|
}
|
|
|
|
func newCubicSender(
|
|
clock Clock,
|
|
reno bool,
|
|
initialMaxDatagramSize,
|
|
initialCongestionWindow,
|
|
initialMaxCongestionWindow congestion.ByteCount,
|
|
tracer logging.ConnectionTracer,
|
|
) *cubicSender {
|
|
c := &cubicSender{
|
|
largestSentPacketNumber: InvalidPacketNumber,
|
|
largestAckedPacketNumber: InvalidPacketNumber,
|
|
largestSentAtLastCutback: InvalidPacketNumber,
|
|
initialCongestionWindow: initialCongestionWindow,
|
|
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
|
congestionWindow: initialCongestionWindow,
|
|
slowStartThreshold: MaxByteCount,
|
|
cubic: NewCubic(clock),
|
|
clock: clock,
|
|
reno: reno,
|
|
tracer: tracer,
|
|
maxDatagramSize: initialMaxDatagramSize,
|
|
}
|
|
c.pacer = newPacer(c.BandwidthEstimate)
|
|
if c.tracer != nil {
|
|
c.lastState = logging.CongestionStateSlowStart
|
|
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
|
|
}
|
|
return c
|
|
}
|
|
|
|
func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) {
|
|
c.rttStats = provider
|
|
}
|
|
|
|
// TimeUntilSend returns when the next packet should be sent.
|
|
func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time {
|
|
return c.pacer.TimeUntilSend()
|
|
}
|
|
|
|
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
|
|
return c.pacer.Budget(now) >= c.maxDatagramSize
|
|
}
|
|
|
|
func (c *cubicSender) maxCongestionWindow() congestion.ByteCount {
|
|
return c.maxDatagramSize * MaxCongestionWindowPackets
|
|
}
|
|
|
|
func (c *cubicSender) minCongestionWindow() congestion.ByteCount {
|
|
return c.maxDatagramSize * minCongestionWindowPackets
|
|
}
|
|
|
|
func (c *cubicSender) OnPacketSent(
|
|
sentTime time.Time,
|
|
_ congestion.ByteCount,
|
|
packetNumber congestion.PacketNumber,
|
|
bytes congestion.ByteCount,
|
|
isRetransmittable bool,
|
|
) {
|
|
c.pacer.SentPacket(sentTime, bytes)
|
|
if !isRetransmittable {
|
|
return
|
|
}
|
|
c.largestSentPacketNumber = packetNumber
|
|
c.hybridSlowStart.OnPacketSent(packetNumber)
|
|
}
|
|
|
|
func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
|
return bytesInFlight < c.GetCongestionWindow()
|
|
}
|
|
|
|
func (c *cubicSender) InRecovery() bool {
|
|
return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
|
|
}
|
|
|
|
func (c *cubicSender) InSlowStart() bool {
|
|
return c.GetCongestionWindow() < c.slowStartThreshold
|
|
}
|
|
|
|
func (c *cubicSender) GetCongestionWindow() congestion.ByteCount {
|
|
return c.congestionWindow
|
|
}
|
|
|
|
func (c *cubicSender) MaybeExitSlowStart() {
|
|
if c.InSlowStart() &&
|
|
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
|
|
// exit slow start
|
|
c.slowStartThreshold = c.congestionWindow
|
|
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
|
}
|
|
}
|
|
|
|
func (c *cubicSender) OnPacketAcked(
|
|
ackedPacketNumber congestion.PacketNumber,
|
|
ackedBytes congestion.ByteCount,
|
|
priorInFlight congestion.ByteCount,
|
|
eventTime time.Time,
|
|
) {
|
|
c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
|
|
if c.InRecovery() {
|
|
return
|
|
}
|
|
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
|
if c.InSlowStart() {
|
|
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
|
}
|
|
}
|
|
|
|
func (c *cubicSender) OnPacketLost(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
|
|
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
|
// already sent should be treated as a single loss event, since it's expected.
|
|
if packetNumber <= c.largestSentAtLastCutback {
|
|
return
|
|
}
|
|
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
|
c.maybeTraceStateChange(logging.CongestionStateRecovery)
|
|
|
|
if c.reno {
|
|
c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
|
|
} else {
|
|
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
|
}
|
|
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
|
|
c.congestionWindow = minCwnd
|
|
}
|
|
c.slowStartThreshold = c.congestionWindow
|
|
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
|
// reset packet count from congestion avoidance mode. We start
|
|
// counting again when we're out of recovery.
|
|
c.numAckedPackets = 0
|
|
}
|
|
|
|
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
|
// represents, but quic has a separate ack for each packet.
|
|
func (c *cubicSender) maybeIncreaseCwnd(
|
|
_ congestion.PacketNumber,
|
|
ackedBytes congestion.ByteCount,
|
|
priorInFlight congestion.ByteCount,
|
|
eventTime time.Time,
|
|
) {
|
|
// Do not increase the congestion window unless the sender is close to using
|
|
// the current window.
|
|
if !c.isCwndLimited(priorInFlight) {
|
|
c.cubic.OnApplicationLimited()
|
|
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
|
|
return
|
|
}
|
|
if c.congestionWindow >= c.maxCongestionWindow() {
|
|
return
|
|
}
|
|
if c.InSlowStart() {
|
|
// TCP slow start, exponential growth, increase by one for each ACK.
|
|
c.congestionWindow += c.maxDatagramSize
|
|
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
|
|
return
|
|
}
|
|
// Congestion avoidance
|
|
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
|
if c.reno {
|
|
// Classic Reno congestion avoidance.
|
|
c.numAckedPackets++
|
|
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
|
|
c.congestionWindow += c.maxDatagramSize
|
|
c.numAckedPackets = 0
|
|
}
|
|
} else {
|
|
c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
|
}
|
|
}
|
|
|
|
func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
|
|
congestionWindow := c.GetCongestionWindow()
|
|
if bytesInFlight >= congestionWindow {
|
|
return true
|
|
}
|
|
availableBytes := congestionWindow - bytesInFlight
|
|
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
|
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
|
|
}
|
|
|
|
// BandwidthEstimate returns the current bandwidth estimate
|
|
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
|
if c.rttStats == nil {
|
|
return infBandwidth
|
|
}
|
|
srtt := c.rttStats.SmoothedRTT()
|
|
if srtt == 0 {
|
|
// If we haven't measured an rtt, the bandwidth estimate is unknown.
|
|
return infBandwidth
|
|
}
|
|
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
|
}
|
|
|
|
// OnRetransmissionTimeout is called on an retransmission timeout
|
|
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
|
c.largestSentAtLastCutback = InvalidPacketNumber
|
|
if !packetsRetransmitted {
|
|
return
|
|
}
|
|
c.hybridSlowStart.Restart()
|
|
c.cubic.Reset()
|
|
c.slowStartThreshold = c.congestionWindow / 2
|
|
c.congestionWindow = c.minCongestionWindow()
|
|
}
|
|
|
|
// OnConnectionMigration is called when the connection is migrated (?)
|
|
func (c *cubicSender) OnConnectionMigration() {
|
|
c.hybridSlowStart.Restart()
|
|
c.largestSentPacketNumber = InvalidPacketNumber
|
|
c.largestAckedPacketNumber = InvalidPacketNumber
|
|
c.largestSentAtLastCutback = InvalidPacketNumber
|
|
c.lastCutbackExitedSlowstart = false
|
|
c.cubic.Reset()
|
|
c.numAckedPackets = 0
|
|
c.congestionWindow = c.initialCongestionWindow
|
|
c.slowStartThreshold = c.initialMaxCongestionWindow
|
|
}
|
|
|
|
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
|
|
if c.tracer == nil || new == c.lastState {
|
|
return
|
|
}
|
|
c.tracer.UpdatedCongestionState(new)
|
|
c.lastState = new
|
|
}
|
|
|
|
func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) {
|
|
if s < c.maxDatagramSize {
|
|
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
|
|
}
|
|
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
|
|
c.maxDatagramSize = s
|
|
if cwndIsMinCwnd {
|
|
c.congestionWindow = c.minCongestionWindow()
|
|
}
|
|
c.pacer.SetMaxDatagramSize(s)
|
|
}
|