package mux

import (
	"encoding/binary"
	"io"
	"net"

	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"
	"github.com/sagernet/smux"

	"github.com/hashicorp/yamux"
)

// Destination is the placeholder address used by this package.
// EncodeStreamRequest writes it in place of the request's own destination
// when PacketAddr is set and that destination is not valid.
// NOTE(review): it presumably also identifies multiplex connections to other
// parts of the project — confirm against callers outside this file.
var Destination = M.Socksaddr{
	Fqdn: "sp.mux.sing-box.arpa",
	Port: 444,
}

// Supported multiplexing protocols. The numeric values are written to and read
// from the wire as a single byte (see EncodeRequest and ReadRequest), and
// ReadRequest rejects any value greater than ProtocolYAMux, so new protocols
// must be appended after the existing ones.
const (
	// ProtocolSMux selects the smux implementation (value 0).
	ProtocolSMux Protocol = iota
	// ProtocolYAMux selects the yamux implementation (value 1).
	ProtocolYAMux
)

// Protocol identifies which multiplexing library backs a session.
type Protocol byte

// ParseProtocol maps a protocol name to its Protocol value.
//
// The empty string and "smux" both yield ProtocolSMux; "yamux" yields
// ProtocolYAMux. Any other name produces an error; the Protocol returned
// alongside that error carries no meaning and must not be used.
func ParseProtocol(name string) (Protocol, error) {
	if name == "" || name == "smux" {
		return ProtocolSMux, nil
	}
	if name == "yamux" {
		return ProtocolYAMux, nil
	}
	return ProtocolYAMux, E.New("unknown multiplex protocol: ", name)
}

// newServer starts the server side of the multiplexing session selected by p
// on top of conn.
//
// When the underlying library reports an error, the returned session is a
// literal nil interface value, so callers may safely compare it against nil.
// It panics if p is not one of the declared Protocol constants, which
// indicates a programming error rather than bad input.
func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
	switch p {
	case ProtocolSMux:
		session, err := smux.Server(conn, nil)
		if err != nil {
			return nil, err
		}
		return &smuxSession{session}, nil
	case ProtocolYAMux:
		// Do not forward yamux.Server's results directly: a nil *yamux.Session
		// converted to abstractSession is a non-nil interface wrapping a nil
		// pointer, which would defeat a `session == nil` check in callers.
		session, err := yamux.Server(conn, yaMuxConfig())
		if err != nil {
			return nil, err
		}
		return session, nil
	default:
		panic("unknown protocol")
	}
}

// newClient starts the client side of the multiplexing session selected by p
// on top of conn.
//
// When the underlying library reports an error, the returned session is a
// literal nil interface value, so callers may safely compare it against nil.
// It panics if p is not one of the declared Protocol constants, which
// indicates a programming error rather than bad input.
func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
	switch p {
	case ProtocolSMux:
		session, err := smux.Client(conn, nil)
		if err != nil {
			return nil, err
		}
		return &smuxSession{session}, nil
	case ProtocolYAMux:
		// Do not forward yamux.Client's results directly: a nil *yamux.Session
		// converted to abstractSession is a non-nil interface wrapping a nil
		// pointer, which would defeat a `session == nil` check in callers.
		session, err := yamux.Client(conn, yaMuxConfig())
		if err != nil {
			return nil, err
		}
		return session, nil
	default:
		panic("unknown protocol")
	}
}

// yaMuxConfig builds the yamux configuration shared by newServer and
// newClient: the library defaults with log output discarded and both the
// stream-open and stream-close timeouts set to C.TCPTimeout.
func yaMuxConfig() *yamux.Config {
	cfg := yamux.DefaultConfig()
	cfg.StreamOpenTimeout, cfg.StreamCloseTimeout = C.TCPTimeout, C.TCPTimeout
	// Silence yamux's own logging.
	cfg.LogOutput = io.Discard
	return cfg
}

// String returns the configuration name of the protocol ("smux" or "yamux"),
// or "unknown" for any value that is not a declared Protocol constant.
func (p Protocol) String() string {
	if p == ProtocolSMux {
		return "smux"
	}
	if p == ProtocolYAMux {
		return "yamux"
	}
	return "unknown"
}

const (
	// version0 is the only session request version: EncodeRequest always
	// writes it and ReadRequest rejects anything else.
	version0 = 0
)

// Request is the session-level header exchanged before multiplexing starts.
// On the wire it is a version byte followed by one protocol byte.
type Request struct {
	// Protocol is the multiplexing protocol the session will use.
	Protocol Protocol
}

// ReadRequest decodes a session Request from reader.
//
// It consumes one version byte, which must equal version0, followed by one
// protocol byte, which must not exceed ProtocolYAMux. Read failures are
// returned unchanged; an unexpected version or protocol yields a descriptive
// error. The protocol byte is not read when the version is rejected.
func ReadRequest(reader io.Reader) (*Request, error) {
	version, err := rw.ReadByte(reader)
	switch {
	case err != nil:
		return nil, err
	case version != version0:
		return nil, E.New("unsupported version: ", version)
	}

	protocol, err := rw.ReadByte(reader)
	switch {
	case err != nil:
		return nil, err
	case Protocol(protocol) > ProtocolYAMux:
		return nil, E.New("unsupported protocol: ", protocol)
	}

	request := new(Request)
	request.Protocol = Protocol(protocol)
	return request, nil
}

// EncodeRequest appends the session request header to buffer: the version0
// byte followed by the protocol byte.
//
// Errors from buffer.WriteByte are intentionally not inspected here.
// NOTE(review): presumably callers size the buffer for at least two bytes —
// confirm against the call sites.
func EncodeRequest(buffer *buf.Buffer, request Request) {
	header := [2]byte{version0, byte(request.Protocol)}
	for _, b := range header {
		buffer.WriteByte(b)
	}
}

// Bit flags carried in the 16-bit flags field of a stream request, and status
// codes carried in the first byte of a stream response.
const (
	// flagUDP is set when the stream carries UDP; when clear the stream is TCP.
	flagUDP       = 1
	// flagAddr is set when StreamRequest.PacketAddr is true. The reader only
	// honours it when flagUDP is also set.
	flagAddr      = 2
	// statusSuccess is not referenced in this part of the file; presumably the
	// status written for a successfully opened stream — confirm with the writer.
	statusSuccess = 0
	// statusError marks a response that is followed by an error message string.
	statusError   = 1
)

// StreamRequest is the per-stream header describing what a newly opened
// multiplexed stream should connect to.
type StreamRequest struct {
	// Network is N.NetworkTCP or N.NetworkUDP when produced by ReadStreamRequest.
	Network     string
	// Destination is the target address and port of the stream.
	Destination M.Socksaddr
	// PacketAddr mirrors the flagAddr bit of the encoded request.
	// NOTE(review): presumably means each UDP packet carries its own address —
	// confirm against the packet connection implementation.
	PacketAddr  bool
}

// ReadStreamRequest decodes a StreamRequest from reader.
//
// The wire form is a big-endian 16-bit flags field followed by an address and
// port in M.SocksaddrSerializer format. The request is TCP unless flagUDP is
// set; PacketAddr is taken from flagAddr only for UDP requests and is always
// false for TCP. Any read error is returned unchanged.
func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
	var rawFlags uint16
	if err := binary.Read(reader, binary.BigEndian, &rawFlags); err != nil {
		return nil, err
	}
	addr, err := M.SocksaddrSerializer.ReadAddrPort(reader)
	if err != nil {
		return nil, err
	}

	// Start from the TCP form and upgrade to UDP when the flag says so.
	request := &StreamRequest{
		Network:     N.NetworkTCP,
		Destination: addr,
	}
	if rawFlags&flagUDP != 0 {
		request.Network = N.NetworkUDP
		request.PacketAddr = rawFlags&flagAddr != 0
	}
	return request, nil
}

// requestLen returns the number of bytes to reserve for the encoded form of
// request as produced by EncodeStreamRequest.
//
// The size is computed for the destination that will actually be written:
// when PacketAddr is set and the request's destination is not valid,
// EncodeStreamRequest substitutes the package-level Destination, whose
// encoded length can differ from that of the original value.
func requestLen(request StreamRequest) int {
	destination := request.Destination
	if request.PacketAddr && !destination.IsValid() {
		// Keep in sync with the substitution in EncodeStreamRequest.
		destination = Destination
	}
	var rLen int
	// One byte labelled "version" is reserved even though EncodeStreamRequest
	// writes no version byte; it is kept so the reported size never shrinks
	// for existing callers.
	rLen += 1 // version
	rLen += 2 // flags
	rLen += M.SocksaddrSerializer.AddrPortLen(destination)
	return rLen
}

// EncodeStreamRequest appends the wire form of request to buffer: a big-endian
// 16-bit flags field followed by the destination in M.SocksaddrSerializer
// format.
//
// flagUDP is set for UDP requests and flagAddr when PacketAddr is true. If
// PacketAddr is true and the request's destination is not valid, the
// package-level Destination is written instead. Both writes are always
// attempted; the function then panics via common.Must if either failed.
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
	var (
		flags  uint16
		target = request.Destination
	)
	if request.Network == N.NetworkUDP {
		flags |= flagUDP
	}
	if request.PacketAddr {
		flags |= flagAddr
	}
	if request.PacketAddr && !target.IsValid() {
		target = Destination
	}

	flagsErr := binary.Write(buffer, binary.BigEndian, flags)
	addrErr := M.SocksaddrSerializer.WriteAddrPort(buffer, target)
	common.Must(flagsErr, addrErr)
}

// StreamResponse is the reply to a stream request.
type StreamResponse struct {
	// Status is the raw status byte read from the peer.
	Status  uint8
	// Message is only populated when Status equals statusError.
	Message string
}

// ReadStreamResponse decodes a StreamResponse from reader.
//
// It reads one status byte and, only when that byte equals statusError, a
// message string via rw.ReadVString. For every other status the message is
// left empty and nothing further is consumed. Read errors are returned
// unchanged with a nil response.
func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
	code, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	result := &StreamResponse{Status: code}
	if code != statusError {
		return result, nil
	}
	if result.Message, err = rw.ReadVString(reader); err != nil {
		return nil, err
	}
	return result, nil
}

// wrapStream decorates a net.Conn so that errors returned by Read and Write
// pass through wrapError, which reports yamux.ErrStreamClosed as io.EOF.
type wrapStream struct {
	net.Conn
}

// Read reads from the wrapped connection and passes the resulting error
// through wrapError before returning it; the byte count is returned as-is.
func (w *wrapStream) Read(p []byte) (int, error) {
	count, readErr := w.Conn.Read(p)
	return count, wrapError(readErr)
}

// Write writes to the wrapped connection and passes the resulting error
// through wrapError before returning it; the byte count is returned as-is.
func (w *wrapStream) Write(p []byte) (int, error) {
	count, writeErr := w.Conn.Write(p)
	return count, wrapError(writeErr)
}

// WriteIsThreadUnsafe is an intentionally empty marker method.
// NOTE(review): presumably detected through an interface assertion elsewhere
// to signal that concurrent Write calls are not safe — confirm with the
// consumer of this marker.
func (w *wrapStream) WriteIsThreadUnsafe() {
}

// Upstream returns the wrapped connection so callers can reach through this
// decorator to the underlying net.Conn.
func (w *wrapStream) Upstream() any {
	return w.Conn
}

// wrapError translates yamux.ErrStreamClosed into io.EOF and returns every
// other error, including nil, unchanged. The comparison is by identity, so an
// error that merely wraps yamux.ErrStreamClosed is not translated.
func wrapError(err error) error {
	if err == yamux.ErrStreamClosed {
		return io.EOF
	}
	return err
}