sing-box/transport/wireguard/client_bind.go

202 lines
4.2 KiB
Go
Raw Normal View History

2022-09-05 16:15:09 +00:00
package wireguard
import (
	"context"
	"net"
	"net/netip"
	"sync"
	"time"

	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/bufio"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/wireguard-go/conn"
)
var _ conn.Bind = (*ClientBind)(nil)
type ClientBind struct {
ctx context.Context
2023-04-13 08:02:28 +00:00
errorHandler E.Handler
dialer N.Dialer
reservedForEndpoint map[M.Socksaddr][3]uint8
connAccess sync.Mutex
conn *wireConn
done chan struct{}
isConnect bool
connectAddr M.Socksaddr
reserved [3]uint8
2022-09-05 16:15:09 +00:00
}
2023-04-13 08:02:28 +00:00
func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
2022-09-05 16:15:09 +00:00
return &ClientBind{
ctx: ctx,
2023-04-13 08:02:28 +00:00
errorHandler: errorHandler,
dialer: dialer,
reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
isConnect: isConnect,
connectAddr: connectAddr,
reserved: reserved,
2022-09-05 16:15:09 +00:00
}
}
// SetReservedForEndpoint records the reserved header bytes that Send writes
// into packets addressed to destination, replacing any earlier entry.
// NOTE(review): the map is not guarded by a lock while Send reads it;
// presumably this is only called during setup before traffic flows — confirm.
func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
	overrides := c.reservedForEndpoint
	overrides[destination] = reserved
}
2022-09-05 16:15:09 +00:00
func (c *ClientBind) connect() (*wireConn, error) {
serverConn := c.conn
if serverConn != nil {
select {
case <-serverConn.done:
serverConn = nil
default:
return serverConn, nil
}
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
serverConn = c.conn
if serverConn != nil {
select {
case <-serverConn.done:
serverConn = nil
default:
return serverConn, nil
}
}
if c.isConnect {
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
if err != nil {
2023-04-13 08:02:28 +00:00
return nil, err
}
c.conn = &wireConn{
2023-04-13 08:02:28 +00:00
PacketConn: &bufio.UnbindPacketConn{
ExtendedConn: bufio.NewExtendedConn(udpConn),
Addr: c.connectAddr,
},
done: make(chan struct{}),
}
} else {
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
if err != nil {
2023-04-13 08:02:28 +00:00
return nil, err
}
c.conn = &wireConn{
2023-04-13 08:02:28 +00:00
PacketConn: bufio.NewPacketConn(udpConn),
done: make(chan struct{}),
}
2022-09-05 16:15:09 +00:00
}
return c.conn, nil
}
func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
2022-09-23 05:14:31 +00:00
select {
case <-c.done:
err = net.ErrClosed
return
default:
}
2022-09-05 16:15:09 +00:00
return []conn.ReceiveFunc{c.receive}, 0, nil
}
func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
udpConn, err := c.connect()
if err != nil {
2023-04-13 08:02:28 +00:00
select {
case <-c.done:
return
default:
}
c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
err = nil
2022-09-05 16:15:09 +00:00
return
}
2023-04-13 08:02:28 +00:00
n, addr, err := udpConn.ReadFrom(b)
2022-09-05 16:15:09 +00:00
if err != nil {
udpConn.Close()
2022-09-23 05:14:31 +00:00
select {
case <-c.done:
default:
2023-04-13 08:02:28 +00:00
c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
2022-09-23 05:14:31 +00:00
}
return
2022-09-05 16:15:09 +00:00
}
if n > 3 {
b[1] = 0
b[2] = 0
b[3] = 0
}
2023-04-13 08:02:28 +00:00
ep = Endpoint(M.SocksaddrFromNet(addr))
2022-09-05 16:15:09 +00:00
return
}
2022-11-06 02:36:19 +00:00
func (c *ClientBind) Reset() {
common.Close(common.PtrOrNil(c.conn))
}
2022-09-05 16:15:09 +00:00
func (c *ClientBind) Close() error {
common.Close(common.PtrOrNil(c.conn))
2022-09-23 05:14:31 +00:00
if c.done == nil {
c.done = make(chan struct{})
return nil
}
select {
case <-c.done:
return net.ErrClosed
default:
close(c.done)
}
2022-09-05 16:15:09 +00:00
return nil
}
// SetMark is a no-op for this bind: the mark is ignored and a nil error is
// always returned.
func (c *ClientBind) SetMark(mark uint32) error {
	var err error
	return err
}
func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
udpConn, err := c.connect()
if err != nil {
return err
}
destination := M.Socksaddr(ep.(Endpoint))
if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination]
if !loaded {
reserved = c.reserved
}
b[1] = reserved[0]
b[2] = reserved[1]
b[3] = reserved[2]
}
2023-04-13 08:02:28 +00:00
_, err = udpConn.WriteTo(b, destination)
2022-09-05 16:15:09 +00:00
if err != nil {
udpConn.Close()
}
return err
}
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return Endpoint(M.ParseSocksaddr(s)), nil
2022-09-05 16:15:09 +00:00
}
// wireConn is the packet connection held by ClientBind. Its done channel is
// closed by the first Close, which is how ClientBind detects that the
// connection must be replaced.
type wireConn struct {
	net.PacketConn

	// access serializes Close calls.
	access sync.Mutex
	// done is closed exactly once, by the first Close.
	done chan struct{}
}

// Close closes the underlying PacketConn and then the done channel. The first
// call returns nil (the underlying Close error is discarded); every later call
// returns net.ErrClosed without touching the PacketConn again.
func (w *wireConn) Close() error {
	w.access.Lock()
	defer w.access.Unlock()
	select {
	case <-w.done:
		return net.ErrClosed
	default:
		w.PacketConn.Close()
		close(w.done)
		return nil
	}
}