mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-12-12 16:18:50 +00:00
260 lines
5.6 KiB
Go
260 lines
5.6 KiB
Go
package wireguard
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/netip"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/bufio"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/logger"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
"github.com/sagernet/sing/service"
|
|
"github.com/sagernet/sing/service/pause"
|
|
"github.com/sagernet/wireguard-go/conn"
|
|
)
|
|
|
|
// Compile-time assertion that *ClientBind satisfies the conn.Bind interface.
var _ conn.Bind = (*ClientBind)(nil)
|
|
|
|
type ClientBind struct {
|
|
ctx context.Context
|
|
logger logger.Logger
|
|
pauseManager pause.Manager
|
|
bindCtx context.Context
|
|
bindDone context.CancelFunc
|
|
dialer N.Dialer
|
|
reservedForEndpoint map[netip.AddrPort][3]uint8
|
|
connAccess sync.Mutex
|
|
conn *wireConn
|
|
done chan struct{}
|
|
isConnect bool
|
|
connectAddr netip.AddrPort
|
|
reserved [3]uint8
|
|
}
|
|
|
|
func NewClientBind(ctx context.Context, logger logger.Logger, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
|
|
return &ClientBind{
|
|
ctx: ctx,
|
|
logger: logger,
|
|
pauseManager: service.FromContext[pause.Manager](ctx),
|
|
dialer: dialer,
|
|
reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
|
|
done: make(chan struct{}),
|
|
isConnect: isConnect,
|
|
connectAddr: connectAddr,
|
|
reserved: reserved,
|
|
}
|
|
}
|
|
|
|
func (c *ClientBind) connect() (*wireConn, error) {
|
|
serverConn := c.conn
|
|
if serverConn != nil {
|
|
select {
|
|
case <-serverConn.done:
|
|
serverConn = nil
|
|
default:
|
|
return serverConn, nil
|
|
}
|
|
}
|
|
c.connAccess.Lock()
|
|
defer c.connAccess.Unlock()
|
|
select {
|
|
case <-c.done:
|
|
return nil, net.ErrClosed
|
|
default:
|
|
}
|
|
serverConn = c.conn
|
|
if serverConn != nil {
|
|
select {
|
|
case <-serverConn.done:
|
|
serverConn = nil
|
|
default:
|
|
return serverConn, nil
|
|
}
|
|
}
|
|
if c.isConnect {
|
|
udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.conn = &wireConn{
|
|
PacketConn: bufio.NewUnbindPacketConn(udpConn),
|
|
done: make(chan struct{}),
|
|
}
|
|
} else {
|
|
udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.conn = &wireConn{
|
|
PacketConn: bufio.NewPacketConn(udpConn),
|
|
done: make(chan struct{}),
|
|
}
|
|
}
|
|
return c.conn, nil
|
|
}
|
|
|
|
func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
|
select {
|
|
case <-c.done:
|
|
c.done = make(chan struct{})
|
|
default:
|
|
}
|
|
c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
|
|
return []conn.ReceiveFunc{c.receive}, 0, nil
|
|
}
|
|
|
|
func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
|
|
udpConn, err := c.connect()
|
|
if err != nil {
|
|
select {
|
|
case <-c.done:
|
|
return
|
|
default:
|
|
}
|
|
c.logger.Error(E.Cause(err, "connect to server"))
|
|
err = nil
|
|
c.pauseManager.WaitActive()
|
|
time.Sleep(time.Second)
|
|
return
|
|
}
|
|
n, addr, err := udpConn.ReadFrom(packets[0])
|
|
if err != nil {
|
|
udpConn.Close()
|
|
select {
|
|
case <-c.done:
|
|
default:
|
|
c.logger.Error(E.Cause(err, "read packet"))
|
|
err = nil
|
|
}
|
|
return
|
|
}
|
|
sizes[0] = n
|
|
if n > 3 {
|
|
b := packets[0]
|
|
common.ClearArray(b[1:4])
|
|
}
|
|
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
|
|
count = 1
|
|
return
|
|
}
|
|
|
|
func (c *ClientBind) Close() error {
|
|
select {
|
|
case <-c.done:
|
|
default:
|
|
close(c.done)
|
|
}
|
|
if c.bindDone != nil {
|
|
c.bindDone()
|
|
}
|
|
c.connAccess.Lock()
|
|
defer c.connAccess.Unlock()
|
|
common.Close(common.PtrOrNil(c.conn))
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientBind) SetMark(mark uint32) error {
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
|
udpConn, err := c.connect()
|
|
if err != nil {
|
|
c.pauseManager.WaitActive()
|
|
time.Sleep(time.Second)
|
|
return err
|
|
}
|
|
destination := netip.AddrPort(ep.(remoteEndpoint))
|
|
for _, b := range bufs {
|
|
if len(b) > 3 {
|
|
reserved, loaded := c.reservedForEndpoint[destination]
|
|
if !loaded {
|
|
reserved = c.reserved
|
|
}
|
|
copy(b[1:4], reserved[:])
|
|
}
|
|
_, err = udpConn.WriteToUDPAddrPort(b, destination)
|
|
if err != nil {
|
|
udpConn.Close()
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
|
ap, err := netip.ParseAddrPort(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return remoteEndpoint(ap), nil
|
|
}
|
|
|
|
func (c *ClientBind) BatchSize() int {
|
|
return 1
|
|
}
|
|
|
|
func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
|
|
c.reservedForEndpoint[destination] = reserved
|
|
}
|
|
|
|
// wireConn wraps the packet connection used by ClientBind together with the
// state needed to close it exactly once.
type wireConn struct {
	net.PacketConn

	// conn, when non-nil, is written to directly by WriteToUDPAddrPort
	// instead of going through PacketConn.WriteTo.
	conn net.Conn

	// access serialises Close; done is closed once the connection has been
	// closed.
	access sync.Mutex
	done   chan struct{}
}
|
|
|
|
func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
|
|
if w.conn != nil {
|
|
return w.conn.Write(b)
|
|
}
|
|
return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr())
|
|
}
|
|
|
|
func (w *wireConn) Close() error {
|
|
w.access.Lock()
|
|
defer w.access.Unlock()
|
|
select {
|
|
case <-w.done:
|
|
return net.ErrClosed
|
|
default:
|
|
}
|
|
w.PacketConn.Close()
|
|
close(w.done)
|
|
return nil
|
|
}
|
|
|
|
// Compile-time assertion that *remoteEndpoint satisfies the conn.Endpoint interface.
var _ conn.Endpoint = (*remoteEndpoint)(nil)
|
|
|
|
// remoteEndpoint is the endpoint type handed out by ClientBind. It carries
// only a destination address and port; no source information is tracked.
type remoteEndpoint netip.AddrPort

// ClearSrc is a no-op because no source address is stored.
func (e remoteEndpoint) ClearSrc() {
}

// SrcToString always returns the empty string.
func (e remoteEndpoint) SrcToString() string {
	const noSource = ""
	return noSource
}

// DstToString returns the destination formatted by netip.AddrPort.String.
func (e remoteEndpoint) DstToString() string {
	destination := netip.AddrPort(e)
	return destination.String()
}

// DstToBytes returns the destination encoded by netip.AddrPort.MarshalBinary.
func (e remoteEndpoint) DstToBytes() []byte {
	destination := netip.AddrPort(e)
	// The marshalling error is deliberately discarded; callers only get the
	// encoded bytes.
	encoded, _ := destination.MarshalBinary()
	return encoded
}

// DstIP returns the address part of the destination.
func (e remoteEndpoint) DstIP() netip.Addr {
	destination := netip.AddrPort(e)
	return destination.Addr()
}

// SrcIP always returns the zero netip.Addr.
func (e remoteEndpoint) SrcIP() netip.Addr {
	var unset netip.Addr
	return unset
}
|