sing-box/transport/wireguard/endpoint.go
2024-12-10 21:37:35 +08:00

260 lines
7.2 KiB
Go

package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"os"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
"go4.org/netipx"
)
// Endpoint runs a userspace wireguard-go device on top of a tun Device.
// It is built (and its options validated) by NewEndpoint; the WireGuard
// device itself is only created and configured by Start.
type Endpoint struct {
// options are the options given to NewEndpoint, with MTU defaulted to 1408
// when it was zero.
options EndpointOptions
// peers holds the parsed per-peer settings. Start fills in endpoint for
// peers whose configured address was a domain name.
peers []peerConfig
// ipcConf holds the interface-level IPC lines (private_key and, if set,
// listen_port). Start appends each peer's lines to it before calling IpcSet.
ipcConf string
// allowedAddress is the union of every peer's allowed IPs, as prefixes
// (the same list is passed to the tun device as AllowedAddress).
allowedAddress []netip.Prefix
// tunDevice is created by NewEndpoint and started by Start.
tunDevice Device
// device is the wireguard-go device; nil until Start completes successfully.
device *device.Device
// pauseManager is looked up from the context in Start; may be nil.
pauseManager pause.Manager
// pauseCallback is the registration of onPauseUpdated with pauseManager;
// nil when no pause manager was found. Close unregisters it.
pauseCallback *list.Element[pause.Callback]
}
// NewEndpoint validates options and constructs an Endpoint together with its
// tun device. The WireGuard device itself is not created until Start.
//
// The base64-encoded private key, peer public keys and optional pre-shared
// keys are re-encoded as hex for the IPC configuration text that Start passes
// to IpcSet. The union of all peers' allowed IPs becomes the tun device's
// AllowedAddress, and a zero MTU defaults to 1408.
//
// It returns an error if the private key is empty, any key is not valid
// base64, a peer has no allowed IPs, a peer's reserved value is not exactly
// 3 bytes long, or the tun device cannot be created.
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
	if options.PrivateKey == "" {
		return nil, E.New("missing private key")
	}
	privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
	if err != nil {
		return nil, E.Cause(err, "decode private key")
	}
	ipcConf := "private_key=" + hex.EncodeToString(privateKeyBytes)
	if options.ListenPort != 0 {
		ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
	}

	peers := make([]peerConfig, 0, len(options.Peers))
	// Allowed IPs are collected while validating so the peers are walked once.
	var allowedPrefixBuilder netipx.IPSetBuilder
	for peerIndex, rawPeer := range options.Peers {
		peer := peerConfig{
			allowedIPs: rawPeer.AllowedIPs,
			keepalive:  rawPeer.PersistentKeepaliveInterval,
		}
		// A literal IP endpoint is used as-is; a domain name is kept in
		// destination and resolved later by Start.
		if rawPeer.Endpoint.Addr.IsValid() {
			peer.endpoint = rawPeer.Endpoint.AddrPort()
		} else if rawPeer.Endpoint.IsFqdn() {
			peer.destination = rawPeer.Endpoint
		}
		publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
		if err != nil {
			return nil, E.Cause(err, "decode public key for peer ", peerIndex)
		}
		peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
		if rawPeer.PreSharedKey != "" {
			preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
			if err != nil {
				return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
			}
			peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
		}
		if len(rawPeer.AllowedIPs) == 0 {
			return nil, E.New("missing allowed ips for peer ", peerIndex)
		}
		for _, prefix := range rawPeer.AllowedIPs {
			allowedPrefixBuilder.AddPrefix(prefix)
		}
		if len(rawPeer.Reserved) > 0 {
			if len(rawPeer.Reserved) != 3 {
				// Report the configured length: peer.reserved is a fixed
				// [3]uint8, so its length would always read 3.
				return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(rawPeer.Reserved))
			}
			copy(peer.reserved[:], rawPeer.Reserved[:])
		}
		peers = append(peers, peer)
	}

	allowedIPSet, err := allowedPrefixBuilder.IPSet()
	if err != nil {
		return nil, err
	}
	allowedAddresses := allowedIPSet.Prefixes()
	if options.MTU == 0 {
		options.MTU = 1408
	}
	tunDevice, err := NewDevice(DeviceOptions{
		Context:        options.Context,
		Logger:         options.Logger,
		System:         options.System,
		Handler:        options.Handler,
		UDPTimeout:     options.UDPTimeout,
		CreateDialer:   options.CreateDialer,
		Name:           options.Name,
		MTU:            options.MTU,
		Address:        options.Address,
		AllowedAddress: allowedAddresses,
	})
	if err != nil {
		return nil, E.Cause(err, "create WireGuard device")
	}
	return &Endpoint{
		options:        options,
		peers:          peers,
		ipcConf:        ipcConf,
		allowedAddress: allowedAddresses,
		tunDevice:      tunDevice,
	}, nil
}
// Start creates the wireguard-go device on top of the tun device, applies the
// interface and peer configuration via IpcSet, and registers for pause events.
//
// The resolve flag selects one of two passes, and the real work happens in
// exactly one of them:
//   - if some peer's endpoint is a domain name that has not been resolved
//     yet, the resolve=false pass returns nil immediately and the
//     resolve=true pass resolves those names via options.ResolvePeer and
//     then starts;
//   - otherwise the resolve=false pass starts and the resolve=true pass is a
//     no-op.
//
// NOTE(review): the order in which the caller issues the two passes is not
// visible here — confirm against the caller.
func (e *Endpoint) Start(resolve bool) error {
	needsResolve := common.Any(e.peers, func(peer peerConfig) bool {
		return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
	})
	if needsResolve != resolve {
		// Not this pass's turn: either the resolve pass must do the work, or
		// the non-resolve pass already did it.
		return nil
	}
	if needsResolve {
		for peerIndex, peer := range e.peers {
			if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
				continue
			}
			destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
			if err != nil {
				return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
			}
			e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
		}
	}

	var bind conn.Bind
	wgListener, isWgListener := e.options.Dialer.(conn.Listener)
	if isWgListener {
		bind = conn.NewStdNetBind(wgListener)
	} else {
		var (
			isConnect   bool
			connectAddr netip.AddrPort
			reserved    [3]uint8
		)
		// Connected mode needs a real target. A lone peer without a known
		// endpoint would otherwise hand the bind the zero (invalid) AddrPort.
		if len(e.peers) == 1 && e.peers[0].endpoint.IsValid() {
			isConnect = true
			connectAddr = e.peers[0].endpoint
			reserved = e.peers[0].reserved
		}
		bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
	}
	if isWgListener || len(e.peers) > 1 {
		// Without a single connected peer, reserved bytes are registered per
		// peer endpoint.
		for _, peer := range e.peers {
			if peer.reserved != [3]uint8{} {
				bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
			}
		}
	}

	err := e.tunDevice.Start()
	if err != nil {
		return err
	}
	logger := &device.Logger{
		// Only the format string is lower-cased; formatted arguments are
		// logged unchanged.
		Verbosef: func(format string, args ...any) {
			e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
		},
		Errorf: func(format string, args ...any) {
			e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
		},
	}
	wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
	e.tunDevice.SetDevice(wgDevice)
	ipcConf := e.ipcConf
	for _, peer := range e.peers {
		ipcConf += peer.GenerateIpcLines()
	}
	err = wgDevice.IpcSet(ipcConf)
	if err != nil {
		// e.device is not set on this path, so Close would never release the
		// device (and the tun device/bind it owns); close it here instead.
		wgDevice.Close()
		// ipcConf contains the private and pre-shared keys, so it must not be
		// embedded in an error that may end up in logs.
		return E.Cause(err, "setup wireguard")
	}
	e.device = wgDevice
	e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
	if e.pauseManager != nil {
		e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
	}
	return nil
}
// DialContext opens a connection to destination through the WireGuard tun
// device. Only IP destinations are accepted; a destination without a valid IP
// address fails with an error wrapping os.ErrInvalid.
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if destination.Addr.IsValid() {
		return e.tunDevice.DialContext(ctx, network, destination)
	}
	return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
// ListenPacket opens a packet connection for destination through the
// WireGuard tun device. Only IP destinations are accepted; a destination
// without a valid IP address fails with an error wrapping os.ErrInvalid.
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if destination.Addr.IsValid() {
		return e.tunDevice.ListenPacket(ctx, destination)
	}
	return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
// BindUpdate forwards to the wireguard-go device's BindUpdate.
//
// The device only exists once Start has completed, and Start may be deferred
// to its resolve pass. Before that there is nothing to update, so BindUpdate
// returns nil instead of dereferencing a nil device.
func (e *Endpoint) BindUpdate() error {
	if e.device == nil {
		return nil
	}
	return e.device.BindUpdate()
}
// Close shuts down the wireguard-go device if Start created one, then removes
// the pause callback if one was registered. It always returns nil.
func (e *Endpoint) Close() error {
	if wgDevice := e.device; wgDevice != nil {
		wgDevice.Close()
	}
	if callback := e.pauseCallback; callback != nil {
		e.pauseManager.UnregisterCallback(callback)
	}
	return nil
}
// onPauseUpdated is registered with the pause manager by Start. It brings the
// WireGuard device down when the device is paused and back up on wake. Other
// events are ignored.
//
// Failures cannot be returned from a callback, so they are logged rather than
// silently dropped.
func (e *Endpoint) onPauseUpdated(event int) {
	switch event {
	case pause.EventDevicePaused:
		if err := e.device.Down(); err != nil {
			e.options.Logger.Error(E.Cause(err, "pause wireguard device").Error())
		}
	case pause.EventDeviceWake:
		if err := e.device.Up(); err != nil {
			e.options.Logger.Error(E.Cause(err, "wake wireguard device").Error())
		}
	}
}
// peerConfig is one parsed peer from EndpointOptions.Peers, in the form used
// to emit its IPC configuration lines (see GenerateIpcLines).
type peerConfig struct {
// destination is set when the configured endpoint is a domain name; Start
// resolves it into endpoint.
destination M.Socksaddr
// endpoint is the peer's IP address and port (configured directly or
// resolved from destination). It is emitted only when valid.
endpoint netip.AddrPort
// publicKeyHex is the peer's public key, hex-encoded.
publicKeyHex string
// preSharedKeyHex is the hex-encoded pre-shared key; empty when none is set.
preSharedKeyHex string
// allowedIPs are the prefixes routed to this peer; never empty after
// NewEndpoint validation.
allowedIPs []netip.Prefix
// keepalive is the persistent keepalive interval and is emitted only when
// > 0. NOTE(review): presumably seconds (WireGuard UAPI unit); confirm.
keepalive uint16
// reserved holds the 3 reserved bytes; the all-zero value means "unset".
reserved [3]uint8
}
// GenerateIpcLines renders this peer as key=value IPC configuration lines.
// Every line, including the first, is preceded by "\n", so the result can be
// appended directly to the interface-level configuration. The endpoint,
// preshared_key and persistent_keepalive_interval lines are emitted only when
// those settings are present; one allowed_ip line is written per prefix.
func (c peerConfig) GenerateIpcLines() string {
	var out strings.Builder
	out.WriteString("\npublic_key=")
	out.WriteString(c.publicKeyHex)
	if c.endpoint.IsValid() {
		out.WriteString("\nendpoint=")
		out.WriteString(c.endpoint.String())
	}
	if c.preSharedKeyHex != "" {
		out.WriteString("\npreshared_key=")
		out.WriteString(c.preSharedKeyHex)
	}
	for i := range c.allowedIPs {
		out.WriteString("\nallowed_ip=")
		out.WriteString(c.allowedIPs[i].String())
	}
	if c.keepalive > 0 {
		out.WriteString("\npersistent_keepalive_interval=")
		out.WriteString(F.ToString(c.keepalive))
	}
	return out.String()
}