sing-box/experimental/ssmapi/traffic.go
2023-02-18 19:25:02 +08:00

228 lines
7 KiB
Go

package ssmapi
import (
"net"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/experimental/trackerconn"
N "github.com/sagernet/sing/common/network"
"go.uber.org/atomic"
)
// TrafficManager accumulates traffic statistics for the SSM API: totals for
// the managed node inbounds, plus per-user totals for the tracked users.
type TrafficManager struct {
	// Inbound tags of the managed nodes; only connections on these inbounds
	// feed the global byte/packet counters.
	nodeTags map[string]bool
	// User names whose traffic is tracked individually; replaced wholesale
	// by UpdateUsers.
	nodeUsers map[string]bool

	// Global counters.
	globalUplink, globalDownlink               *atomic.Int64
	globalUplinkPackets, globalDownlinkPackets *atomic.Int64
	globalTCPSessions, globalUDPSessions       *atomic.Int64

	// userAccess guards the per-user counter maps below; the counters
	// themselves are atomic and are updated without holding it.
	userAccess                             sync.Mutex
	userUplink, userDownlink               map[string]*atomic.Int64
	userUplinkPackets, userDownlinkPackets map[string]*atomic.Int64
	userTCPSessions, userUDPSessions       map[string]*atomic.Int64
}
// NewTrafficManager builds a TrafficManager whose global counters cover the
// inbounds tagged by nodes. Every counter starts at zero, and no user is
// tracked individually until UpdateUsers is called.
func NewTrafficManager(nodes []Node) *TrafficManager {
	tags := make(map[string]bool, len(nodes))
	for _, node := range nodes {
		tags[node.Tag()] = true
	}
	return &TrafficManager{
		nodeTags: tags,

		globalUplink:          atomic.NewInt64(0),
		globalDownlink:        atomic.NewInt64(0),
		globalUplinkPackets:   atomic.NewInt64(0),
		globalDownlinkPackets: atomic.NewInt64(0),
		globalTCPSessions:     atomic.NewInt64(0),
		globalUDPSessions:     atomic.NewInt64(0),

		userUplink:          make(map[string]*atomic.Int64),
		userDownlink:        make(map[string]*atomic.Int64),
		userUplinkPackets:   make(map[string]*atomic.Int64),
		userDownlinkPackets: make(map[string]*atomic.Int64),
		userTCPSessions:     make(map[string]*atomic.Int64),
		userUDPSessions:     make(map[string]*atomic.Int64),
	}
}
// UpdateUsers replaces the set of user names whose traffic is tracked per
// user. Counters already accumulated for other users are kept.
//
// NOTE(review): s.nodeUsers is swapped without synchronization while
// RoutedConnection/RoutedPacketConnection read it; confirm callers serialize
// this or guard the field.
func (s *TrafficManager) UpdateUsers(users []string) {
	tracked := make(map[string]bool, len(users))
	for i := range users {
		tracked[users[i]] = true
	}
	s.nodeUsers = tracked
}
// userCounter returns the uplink-bytes, downlink-bytes, uplink-packets,
// downlink-packets, TCP-session and UDP-session counters for user, in that
// order. Any counter that does not exist yet is created at zero and stored.
// The whole lookup runs under userAccess.
func (s *TrafficManager) userCounter(user string) (*atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64) {
	s.userAccess.Lock()
	defer s.userAccess.Unlock()

	// getOrCreate fetches user's counter from counters, inserting a zeroed
	// one on first use.
	getOrCreate := func(counters map[string]*atomic.Int64) *atomic.Int64 {
		if existing, ok := counters[user]; ok {
			return existing
		}
		fresh := atomic.NewInt64(0)
		counters[user] = fresh
		return fresh
	}

	return getOrCreate(s.userUplink),
		getOrCreate(s.userDownlink),
		getOrCreate(s.userUplinkPackets),
		getOrCreate(s.userDownlinkPackets),
		getOrCreate(s.userTCPSessions),
		getOrCreate(s.userUDPSessions)
}
// createCounter returns a hook that, on each call with a byte count n, adds n
// to every counter in counterList and adds one to every counter in
// packetCounterList. In other words, each call is counted as one packet.
func createCounter(counterList []*atomic.Int64, packetCounterList []*atomic.Int64) func(n int64) {
	hook := func(n int64) {
		for i := range counterList {
			counterList[i].Add(n)
		}
		for i := range packetCounterList {
			packetCounterList[i].Inc()
		}
	}
	return hook
}
// RoutedConnection accounts for a routed TCP connection. When there is
// anything to count, it wraps conn with trackerconn, passing the uplink
// counters as read counters and the downlink counters as write counters.
//
// Global statistics (session count and bytes) are recorded only when the
// connection arrived on a managed node inbound (nodeTags). Per-user
// statistics are recorded whenever metadata.User is a tracked user. If
// neither applies, conn is returned unchanged.
func (s *TrafficManager) RoutedConnection(metadata adapter.InboundContext, conn net.Conn) net.Conn {
	var readCounter []*atomic.Int64
	var writeCounter []*atomic.Int64
	if s.nodeTags[metadata.Inbound] {
		// Count the session together with the global byte counters so that
		// connections on unmanaged inbounds do not inflate the node totals.
		s.globalTCPSessions.Inc()
		readCounter = append(readCounter, s.globalUplink)
		writeCounter = append(writeCounter, s.globalDownlink)
	}
	if s.nodeUsers[metadata.User] {
		upCounter, downCounter, _, _, tcpSessionCounter, _ := s.userCounter(metadata.User)
		readCounter = append(readCounter, upCounter)
		writeCounter = append(writeCounter, downCounter)
		tcpSessionCounter.Inc()
	}
	if len(readCounter) > 0 {
		return trackerconn.New(conn, readCounter, writeCounter)
	}
	return conn
}
// RoutedPacketConnection accounts for a routed UDP (packet) connection. When
// there is anything to count, it wraps conn with trackerconn.NewHookPacket.
// The read hook adds bytes to the uplink counters and one packet per call to
// the uplink packet counters; the write hook does the same for downlink.
//
// Global statistics (session count, bytes and packets) are recorded only when
// the connection arrived on a managed node inbound (nodeTags). Per-user
// statistics are recorded whenever metadata.User is a tracked user. If
// neither applies, conn is returned unchanged.
func (s *TrafficManager) RoutedPacketConnection(metadata adapter.InboundContext, conn N.PacketConn) N.PacketConn {
	var readCounter []*atomic.Int64
	var readPacketCounter []*atomic.Int64
	var writeCounter []*atomic.Int64
	var writePacketCounter []*atomic.Int64
	if s.nodeTags[metadata.Inbound] {
		// Count the session together with the global byte/packet counters so
		// that sessions on unmanaged inbounds do not inflate the node totals.
		s.globalUDPSessions.Inc()
		readCounter = append(readCounter, s.globalUplink)
		writeCounter = append(writeCounter, s.globalDownlink)
		readPacketCounter = append(readPacketCounter, s.globalUplinkPackets)
		writePacketCounter = append(writePacketCounter, s.globalDownlinkPackets)
	}
	if s.nodeUsers[metadata.User] {
		upCounter, downCounter, upPacketsCounter, downPacketsCounter, _, udpSessionCounter := s.userCounter(metadata.User)
		readCounter = append(readCounter, upCounter)
		writeCounter = append(writeCounter, downCounter)
		readPacketCounter = append(readPacketCounter, upPacketsCounter)
		writePacketCounter = append(writePacketCounter, downPacketsCounter)
		udpSessionCounter.Inc()
	}
	if len(readCounter) > 0 {
		return trackerconn.NewHookPacket(conn, createCounter(readCounter, readPacketCounter), createCounter(writeCounter, writePacketCounter))
	}
	return conn
}
// ReadUser copies the accumulated statistics for user.UserName into user,
// holding userAccess while doing so.
func (s *TrafficManager) ReadUser(user *SSMUserObject) {
	s.userAccess.Lock()
	defer s.userAccess.Unlock()
	s.readUser(user)
}

// readUser copies each per-user counter that exists for user.UserName into
// the matching field of user. Fields with no counter are left untouched.
// The caller must hold userAccess.
func (s *TrafficManager) readUser(user *SSMUserObject) {
	name := user.UserName
	sources := []struct {
		counters map[string]*atomic.Int64
		target   *int64
	}{
		{s.userUplink, &user.UplinkBytes},
		{s.userDownlink, &user.DownlinkBytes},
		{s.userUplinkPackets, &user.UplinkPackets},
		{s.userDownlinkPackets, &user.DownlinkPackets},
		{s.userTCPSessions, &user.TCPSessions},
		{s.userUDPSessions, &user.UDPSessions},
	}
	for _, src := range sources {
		if counter, ok := src.counters[name]; ok {
			*src.target = counter.Load()
		}
	}
}
// ReadUsers copies the accumulated statistics into every entry of users.
// All entries are read under a single acquisition of userAccess.
func (s *TrafficManager) ReadUsers(users []*SSMUserObject) {
	s.userAccess.Lock()
	defer s.userAccess.Unlock()
	for _, user := range users {
		s.readUser(user)
	}
}
// ReadGlobal returns the node-wide counters. Each value is loaded atomically
// on its own, so the six results are not one consistent snapshot.
func (s *TrafficManager) ReadGlobal() (
	uplinkBytes int64,
	downlinkBytes int64,
	uplinkPackets int64,
	downlinkPackets int64,
	tcpSessions int64,
	udpSessions int64,
) {
	uplinkBytes = s.globalUplink.Load()
	downlinkBytes = s.globalDownlink.Load()
	uplinkPackets = s.globalUplinkPackets.Load()
	downlinkPackets = s.globalDownlinkPackets.Load()
	tcpSessions = s.globalTCPSessions.Load()
	udpSessions = s.globalUDPSessions.Load()
	return uplinkBytes, downlinkBytes, uplinkPackets, downlinkPackets, tcpSessions, udpSessions
}
// Clear resets every global and per-user statistic to zero.
//
// Per-user counters are zeroed in place rather than discarded. Connections
// that are still open hold pointers to these counters, obtained from
// userCounter when the connection was routed. Replacing the maps would orphan
// those counters, and any traffic such connections carry after the reset
// would never be reported again.
func (s *TrafficManager) Clear() {
	s.globalUplink.Store(0)
	s.globalDownlink.Store(0)
	s.globalUplinkPackets.Store(0)
	s.globalDownlinkPackets.Store(0)
	s.globalTCPSessions.Store(0)
	s.globalUDPSessions.Store(0)
	s.userAccess.Lock()
	defer s.userAccess.Unlock()
	for _, counters := range []map[string]*atomic.Int64{
		s.userUplink,
		s.userDownlink,
		s.userUplinkPackets,
		s.userDownlinkPackets,
		s.userTCPSessions,
		s.userUDPSessions,
	} {
		for _, counter := range counters {
			counter.Store(0)
		}
	}
}