package trafficcontrol

import (
	"io"
	"net"
	"sync"
	"sync/atomic"

	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)

type Manager[U comparable] struct {
	access sync.Mutex
	users  map[U]*Traffic
}

type Traffic struct {
	Upload   uint64
	Download uint64
}

func NewManager[U comparable]() *Manager[U] {
	return &Manager[U]{
		users: make(map[U]*Traffic),
	}
}

func (m *Manager[U]) Reset() {
	m.users = make(map[U]*Traffic)
}

func (m *Manager[U]) TrackConnection(user U, conn net.Conn) net.Conn {
	m.access.Lock()
	defer m.access.Unlock()
	var traffic *Traffic
	if t, loaded := m.users[user]; loaded {
		traffic = t
	} else {
		traffic = new(Traffic)
		m.users[user] = traffic
	}
	return &TrackConn{conn, traffic}
}

func (m *Manager[U]) TrackPacketConnection(user U, conn N.PacketConn) N.PacketConn {
	m.access.Lock()
	defer m.access.Unlock()
	var traffic *Traffic
	if t, loaded := m.users[user]; loaded {
		traffic = t
	} else {
		traffic = new(Traffic)
		m.users[user] = traffic
	}
	return &TrackPacketConn{conn, traffic}
}

func (m *Manager[U]) ReadTraffics() map[U]Traffic {
	m.access.Lock()
	defer m.access.Unlock()

	trafficMap := make(map[U]Traffic)
	for user, traffic := range m.users {
		upload := atomic.SwapUint64(&traffic.Upload, 0)
		download := atomic.SwapUint64(&traffic.Download, 0)
		if upload == 0 && download == 0 {
			continue
		}
		trafficMap[user] = Traffic{
			Upload:   upload,
			Download: download,
		}
	}
	return trafficMap
}

type TrackConn struct {
	net.Conn
	*Traffic
}

func (c *TrackConn) Read(p []byte) (n int, err error) {
	n, err = c.Conn.Read(p)
	if n > 0 {
		atomic.AddUint64(&c.Upload, uint64(n))
	}
	return
}

func (c *TrackConn) Write(p []byte) (n int, err error) {
	n, err = c.Conn.Write(p)
	if n > 0 {
		atomic.AddUint64(&c.Download, uint64(n))
	}
	return
}

func (c *TrackConn) WriteTo(w io.Writer) (n int64, err error) {
	n, err = bufio.Copy(w, c.Conn)
	if n > 0 {
		atomic.AddUint64(&c.Upload, uint64(n))
	}
	return
}

func (c *TrackConn) ReadFrom(r io.Reader) (n int64, err error) {
	n, err = bufio.Copy(c.Conn, r)
	if n > 0 {
		atomic.AddUint64(&c.Download, uint64(n))
	}
	return
}

func (c *TrackConn) Upstream() any {
	return c.Conn
}

type TrackPacketConn struct {
	N.PacketConn
	*Traffic
}

func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
	destination, err := c.PacketConn.ReadPacket(buffer)
	if err == nil {
		atomic.AddUint64(&c.Upload, uint64(buffer.Len()))
	}
	return destination, err
}

func (c *TrackPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	n := buffer.Len()
	err := c.PacketConn.WritePacket(buffer, destination)
	if err == nil {
		atomic.AddUint64(&c.Download, uint64(n))
	}
	return err
}

func (c *TrackPacketConn) Upstream() any {
	return c.PacketConn
}