package inbound

import (
	"bytes"
	"context"
	"encoding/binary"
	"io"
	"net"

	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/common/dialer"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/task"
)

// ShadowTLS is a TCP inbound that, for each accepted connection, relays the
// TLS handshake records between the client and a configured handshake server,
// then passes the client connection to the embedded adapter's newConnection.
type ShadowTLS struct {
	// Shared inbound plumbing. NewShadowTLS sets its connHandler to the
	// ShadowTLS itself (presumably so the adapter dispatches accepted
	// connections to NewConnection — confirm in myInboundAdapter).
	myInboundAdapter
	handshakeDialer N.Dialer    // dials the handshake server in NewConnection
	handshakeAddr   M.Socksaddr // handshake server address, from options.Handshake
}

// NewShadowTLS constructs a ShadowTLS inbound.
//
// The listening side comes from options.ListenOptions and is restricted to
// TCP. options.Handshake supplies the dialer settings and the address used to
// reach the handshake server for every incoming connection. The returned
// inbound installs itself as the embedded adapter's connection handler. The
// error result is currently always nil.
func NewShadowTLS(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSInboundOptions) (*ShadowTLS, error) {
	inbound := new(ShadowTLS)
	inbound.myInboundAdapter = myInboundAdapter{
		protocol:      C.TypeShadowTLS,
		network:       []string{N.NetworkTCP},
		ctx:           ctx,
		router:        router,
		logger:        logger,
		tag:           tag,
		listenOptions: options.ListenOptions,
	}
	inbound.handshakeDialer = dialer.New(router, options.Handshake.DialerOptions)
	inbound.handshakeAddr = options.Handshake.ServerOptions.Build()
	// Route connections accepted by the embedded adapter back to this type.
	inbound.connHandler = inbound
	return inbound, nil
}

// NewConnection handles one accepted client connection.
//
// It dials the configured handshake server. It then relays TLS records in both
// directions concurrently ("client handshake": conn -> handshake server,
// "server handshake": handshake server -> conn) until copyUntilHandshakeFinished
// returns for each direction. Finally it passes conn to the embedded adapter's
// newConnection. If either direction fails, the group fails fast and that
// error is returned.
//
// The handshake-server connection is only needed for the handshake, so it is
// closed when the handshake group stops, on success and on failure alike.
// NOTE(review): this relies on task.Group running the Cleanup callback as soon
// as the group stops, which also unblocks a relay goroutine still reading from
// handshakeConn after the other direction failed — confirm against sing's
// task package.
func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
	handshakeConn, err := s.handshakeDialer.DialContext(ctx, N.NetworkTCP, s.handshakeAddr)
	if err != nil {
		return E.Cause(err, "dial handshake server")
	}
	var handshake task.Group
	handshake.Append("client handshake", func(ctx context.Context) error {
		return s.copyUntilHandshakeFinished(handshakeConn, conn)
	})
	handshake.Append("server handshake", func(ctx context.Context) error {
		return s.copyUntilHandshakeFinished(conn, handshakeConn)
	})
	handshake.FastFail()
	handshake.Cleanup(func() {
		// Best-effort: the connection is no longer used after the handshake,
		// and there is nothing useful to do with a close error here.
		handshakeConn.Close()
	})
	err = handshake.Run(ctx)
	if err != nil {
		return err
	}
	return s.newConnection(ctx, conn, metadata)
}

// copyUntilHandshakeFinished forwards whole TLS records from src to dst until
// the handshake in that direction is complete.
//
// It returns nil in either of two cases:
//   - after forwarding the first handshake (0x16) record that follows a
//     ChangeCipherSpec (0x14) record; in a TLS 1.2 handshake this is expected
//     to be the Finished message;
//   - on a second ChangeCipherSpec record.
//
// Every record is forwarded before its type is inspected. Any other record type
// (for example an alert) therefore still reaches dst before an error is
// returned.
//
// It returns the first read or write error, io.ErrUnexpectedEOF when src ends
// in the middle of a record, or an error for an unexpected record type.
//
// NOTE(review): this appears to assume a TLS 1.2 handshake server. A TLS 1.3
// server sends application_data (0x17) records after its ChangeCipherSpec,
// which are rejected here.
func (s *ShadowTLS) copyUntilHandshakeFinished(dst io.Writer, src io.Reader) error {
	const (
		recordTypeChangeCipherSpec = 0x14
		recordTypeHandshake        = 0x16
		recordHeaderLen            = 5
	)
	var hasSeenChangeCipherSpec bool
	var recordHeader [recordHeaderLen]byte
	for {
		if _, err := io.ReadFull(src, recordHeader[:]); err != nil {
			return err
		}
		// Header bytes 3-4 hold the big-endian payload length.
		length := int64(binary.BigEndian.Uint16(recordHeader[3:]))
		n, err := io.Copy(dst, io.MultiReader(bytes.NewReader(recordHeader[:]), io.LimitReader(src, length)))
		if err != nil {
			return err
		}
		// io.Copy treats EOF from the LimitReader as success, so a record cut
		// short by src ending must be detected explicitly.
		if n != recordHeaderLen+length {
			return io.ErrUnexpectedEOF
		}
		switch recordHeader[0] {
		case recordTypeHandshake:
			if hasSeenChangeCipherSpec {
				return nil
			}
		case recordTypeChangeCipherSpec:
			if hasSeenChangeCipherSpec {
				// A second ChangeCipherSpec also ends this direction.
				return nil
			}
			hasSeenChangeCipherSpec = true
		default:
			return E.New("unexpected tls frame type: ", recordHeader[0])
		}
	}
}