package wireguard

import (
	"io"
	"sync"

	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"

	"golang.zx2c4.com/wireguard/conn"
)

var _ conn.Bind = (*ServerBind)(nil)

// ServerBind is a conn.Bind whose incoming packets are pushed in through
// WritePacket (queued on inbound) rather than read from a socket, and whose
// outgoing packets are handed to writeBack by Send.
//
// done is closed exactly once by Abort; closeOnce guards that close so that
// concurrent Abort calls cannot double-close the channel.
type ServerBind struct {
	inbound   chan serverPacket
	done      chan struct{}
	closeOnce sync.Once
	writeBack N.PacketWriter
}

// NewServerBind returns a ServerBind with an inbound queue of capacity 256
// whose outgoing packets are written to writeBack.
func NewServerBind(writeBack N.PacketWriter) *ServerBind {
	return &ServerBind{
		inbound:   make(chan serverPacket, 256),
		done:      make(chan struct{}),
		writeBack: writeBack,
	}
}

// Abort permanently shuts the bind down by closing done, which makes pending
// and future receive and WritePacket calls return io.ErrClosedPipe.
//
// The first call returns nil; every later call returns io.ErrClosedPipe.
// sync.Once makes the check-and-close atomic: the previous select/default
// pattern let two concurrent callers both observe an open channel and both
// call close, which panics.
func (s *ServerBind) Abort() error {
	err := io.ErrClosedPipe
	s.closeOnce.Do(func() {
		close(s.done)
		err = nil
	})
	return err
}

// serverPacket is one datagram queued by WritePacket and consumed by receive.
// receive copies buffer's bytes out and then releases buffer; source is
// reported to WireGuard as the packet's endpoint.
type serverPacket struct {
	buffer *buf.Buffer
	source M.Socksaddr
}

// Open hands WireGuard a single receive function that drains the inbound
// queue. The requested port is ignored and the reported port is always 0,
// because no socket is bound here; Open never fails.
func (s *ServerBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
	return []conn.ReceiveFunc{s.receive}, 0, nil
}

// receive blocks until either a queued packet is available or the bind is
// aborted. On a packet, it copies as many bytes as fit into b, releases the
// queued buffer, and reports the packet's source as the endpoint. After
// Abort it may return io.ErrClosedPipe.
func (s *ServerBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
	var packet serverPacket
	select {
	case <-s.done:
		return 0, nil, io.ErrClosedPipe
	case packet = <-s.inbound:
	}
	copied := copy(b, packet.buffer.Bytes())
	packet.buffer.Release()
	return copied, Endpoint(packet.source), nil
}

// WriteIsThreadUnsafe has an empty body; it exists only so that callers can
// detect ServerBind through an interface assertion.
// NOTE(review): presumably signals that WritePacket must not be called
// concurrently — confirm against the code that checks for this method.
func (s *ServerBind) WriteIsThreadUnsafe() {
}

// WritePacket queues buffer for receive, tagging it with destination as the
// packet's source endpoint. It blocks while the inbound queue is full and
// may return io.ErrClosedPipe once the bind has been aborted.
func (s *ServerBind) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	queued := serverPacket{source: destination, buffer: buffer}
	select {
	case <-s.done:
		return io.ErrClosedPipe
	case s.inbound <- queued:
		return nil
	}
}

// Close is a no-op that always returns nil. It does not close done, so it
// does not unblock receive or WritePacket; Abort is what terminates the bind.
func (s *ServerBind) Close() error {
	return nil
}

// SetMark ignores mark and always returns nil.
func (s *ServerBind) SetMark(mark uint32) error {
	return nil
}

// Send forwards an outgoing WireGuard packet b to writeBack, addressed to ep.
//
// ep must be an Endpoint, as produced by receive or ParseEndpoint. Any other
// conn.Endpoint implementation, or a nil ep, is rejected with an error; the
// previous unchecked type assertion would have panicked instead.
func (s *ServerBind) Send(b []byte, ep conn.Endpoint) error {
	endpoint, isEndpoint := ep.(Endpoint)
	if !isEndpoint {
		return E.New("unsupported endpoint type")
	}
	return s.writeBack.WritePacket(buf.As(b), M.Socksaddr(endpoint))
}

// ParseEndpoint parses addr into an Endpoint. It rejects addresses that do
// not parse to a valid socks address, and addresses whose port is 0.
func (s *ServerBind) ParseEndpoint(addr string) (conn.Endpoint, error) {
	if parsed := M.ParseSocksaddr(addr); parsed.IsValid() && parsed.Port != 0 {
		return Endpoint(parsed), nil
	}
	return nil, E.New("invalid endpoint: ", addr)
}