package wireguard

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"net"
	"net/netip"
	"os"
	"strings"

	"github.com/sagernet/sing/common"
	E "github.com/sagernet/sing/common/exceptions"
	F "github.com/sagernet/sing/common/format"
	M "github.com/sagernet/sing/common/metadata"
	"github.com/sagernet/sing/common/x/list"
	"github.com/sagernet/sing/service"
	"github.com/sagernet/sing/service/pause"
	"github.com/sagernet/wireguard-go/conn"
	"github.com/sagernet/wireguard-go/device"
	"go4.org/netipx"
)

// Endpoint runs a userspace wireguard-go device on top of a Device created
// from EndpointOptions.
//
// Fields:
//   - options: the options given to NewEndpoint (MTU defaulted when unset).
//   - peers: per-peer settings; endpoints of domain-configured peers are
//     filled in by Start.
//   - ipcConf: interface-level IPC lines (private_key, optional listen_port).
//   - allowedAddress: the merged AllowedIPs of all peers.
//   - tunDevice: the Device handed to wireguard-go as its tun.
//   - device: the wireguard-go device; nil until Start succeeds.
//   - pauseManager / pauseCallback: the optional pause.Manager found in the
//     context and the callback registered with it.
type Endpoint struct {
	options        EndpointOptions
	peers          []peerConfig
	ipcConf        string
	allowedAddress []netip.Prefix
	tunDevice      Device
	device         *device.Device
	pauseManager   pause.Manager
	pauseCallback  *list.Element[pause.Callback]
}

// NewEndpoint validates options and prepares an Endpoint without starting it.
//
// Keys are decoded from base64 into the hex form used by wireguard-go's IPC
// protocol. Per-peer settings are collected, with peers whose endpoint is a
// domain kept unresolved until Start. The AllowedIPs of all peers are merged
// into one prefix set for the Device. MTU defaults to 1408 when unset.
//
// It returns an error for a missing or invalid private key, an invalid public
// or pre-shared key, a peer without AllowedIPs, a reserved value that is not
// exactly 3 bytes, or a failure to create the Device.
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
	if options.PrivateKey == "" {
		return nil, E.New("missing private key")
	}
	privateKey, err := decodeWireGuardKey(options.PrivateKey)
	if err != nil {
		return nil, E.Cause(err, "decode private key")
	}
	ipcConf := "private_key=" + privateKey
	if options.ListenPort != 0 {
		ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
	}
	var allowedPrefixBuilder netipx.IPSetBuilder
	peers := make([]peerConfig, 0, len(options.Peers))
	for peerIndex, rawPeer := range options.Peers {
		peer := peerConfig{
			allowedIPs: rawPeer.AllowedIPs,
			keepalive:  rawPeer.PersistentKeepaliveInterval,
		}
		if rawPeer.Endpoint.Addr.IsValid() {
			peer.endpoint = rawPeer.Endpoint.AddrPort()
		} else if rawPeer.Endpoint.IsFqdn() {
			// Resolved later by Start(true).
			peer.destination = rawPeer.Endpoint
		}
		peer.publicKeyHex, err = decodeWireGuardKey(rawPeer.PublicKey)
		if err != nil {
			return nil, E.Cause(err, "decode public key for peer ", peerIndex)
		}
		if rawPeer.PreSharedKey != "" {
			peer.preSharedKeyHex, err = decodeWireGuardKey(rawPeer.PreSharedKey)
			if err != nil {
				return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
			}
		}
		if len(rawPeer.AllowedIPs) == 0 {
			return nil, E.New("missing allowed ips for peer ", peerIndex)
		}
		if len(rawPeer.Reserved) > 0 {
			if len(rawPeer.Reserved) != 3 {
				// Report the configured length; peer.reserved is a fixed
				// [3]uint8 and would always read 3.
				return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(rawPeer.Reserved))
			}
			copy(peer.reserved[:], rawPeer.Reserved[:])
		}
		for _, prefix := range rawPeer.AllowedIPs {
			allowedPrefixBuilder.AddPrefix(prefix)
		}
		peers = append(peers, peer)
	}
	allowedIPSet, err := allowedPrefixBuilder.IPSet()
	if err != nil {
		return nil, err
	}
	allowedAddresses := allowedIPSet.Prefixes()
	if options.MTU == 0 {
		options.MTU = 1408
	}
	deviceOptions := DeviceOptions{
		Context:        options.Context,
		Logger:         options.Logger,
		System:         options.System,
		Handler:        options.Handler,
		UDPTimeout:     options.UDPTimeout,
		CreateDialer:   options.CreateDialer,
		Name:           options.Name,
		MTU:            options.MTU,
		GSO:            options.GSO,
		Address:        options.Address,
		AllowedAddress: allowedAddresses,
	}
	tunDevice, err := NewDevice(deviceOptions)
	if err != nil {
		return nil, E.Cause(err, "create WireGuard device")
	}
	return &Endpoint{
		options:        options,
		peers:          peers,
		ipcConf:        ipcConf,
		allowedAddress: allowedAddresses,
		tunDevice:      tunDevice,
	}, nil
}

// decodeWireGuardKey decodes a standard-base64 WireGuard key and returns it
// hex-encoded, as expected by the wireguard-go IPC protocol. WireGuard keys
// are exactly 32 bytes. Checking the size here gives a clear error instead of
// an opaque IpcSet failure in Start.
func decodeWireGuardKey(encoded string) (string, error) {
	const keySize = 32
	keyBytes, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return "", err
	}
	if len(keyBytes) != keySize {
		return "", E.New("invalid key length, required ", keySize, " bytes, got ", len(keyBytes))
	}
	return hex.EncodeToString(keyBytes), nil
}

// Start creates and configures the wireguard-go device. It is designed to be
// called twice, first with resolve=false and then with resolve=true:
//
//   - If some peer has a domain endpoint that is not resolved yet, the
//     resolve=false call returns nil without doing anything. The resolve=true
//     call then resolves each such domain through options.ResolvePeer before
//     starting the device.
//   - Otherwise the resolve=false call starts the device and the resolve=true
//     call returns nil immediately.
//
// The bind is a standard net bind when options.Dialer implements
// conn.Listener. Otherwise a client bind is created, with isConnect=true and
// the peer's endpoint and reserved bytes when exactly one peer is configured.
// After a successful start, a pause callback is registered if the context
// provides a pause.Manager.
func (e *Endpoint) Start(resolve bool) error {
	needResolve := common.Any(e.peers, func(peer peerConfig) bool {
		return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
	})
	if needResolve {
		if !resolve {
			return nil
		}
		for peerIndex, peer := range e.peers {
			if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
				continue
			}
			destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
			if err != nil {
				return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
			}
			e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
		}
	} else if resolve {
		return nil
	}

	var bind conn.Bind
	if wgListener, isWgListener := e.options.Dialer.(conn.Listener); isWgListener {
		bind = conn.NewStdNetBind(wgListener)
	} else {
		var (
			isConnect   bool
			connectAddr netip.AddrPort
			reserved    [3]uint8
		)
		if len(e.peers) == 1 {
			isConnect = true
			connectAddr = e.peers[0].endpoint
			reserved = e.peers[0].reserved
		}
		bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
	}

	if err := e.tunDevice.Start(); err != nil {
		return err
	}
	// The wireguard-go format string (not its arguments) is lowercased
	// before formatting.
	logger := &device.Logger{
		Verbosef: func(format string, args ...any) {
			e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
		},
		Errorf: func(format string, args ...any) {
			e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
		},
	}
	wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
	e.tunDevice.SetDevice(wgDevice)

	var ipcBuilder strings.Builder
	ipcBuilder.WriteString(e.ipcConf)
	for _, peer := range e.peers {
		ipcBuilder.WriteString(peer.GenerateIpcLines())
	}
	ipcConf := ipcBuilder.String()
	if err := wgDevice.IpcSet(ipcConf); err != nil {
		// e.device is still nil here, so Close would never reach this device;
		// release it now instead of leaking it.
		wgDevice.Close()
		// Never put private or pre-shared keys into error messages.
		return E.Cause(err, "setup wireguard: \n", redactIpcSecrets(ipcConf))
	}
	e.device = wgDevice
	e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
	if e.pauseManager != nil {
		e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
	}
	return nil
}

// redactIpcSecrets returns ipcConf with the values of its private_key and
// preshared_key lines masked, so the configuration can be included in error
// messages without leaking key material.
func redactIpcSecrets(ipcConf string) string {
	lines := strings.Split(ipcConf, "\n")
	for i, line := range lines {
		key, _, found := strings.Cut(line, "=")
		if found && (key == "private_key" || key == "preshared_key") {
			lines[i] = key + "=<redacted>"
		}
	}
	return strings.Join(lines, "\n")
}

// DialContext dials destination through the tun Device. Only IP destinations
// are accepted; a domain destination fails with os.ErrInvalid.
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if !destination.Addr.IsValid() {
		return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
	}
	return e.tunDevice.DialContext(ctx, network, destination)
}

// ListenPacket opens a packet connection through the tun Device. Only IP
// destinations are accepted; a domain destination fails with os.ErrInvalid.
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if !destination.Addr.IsValid() {
		return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
	}
	return e.tunDevice.ListenPacket(ctx, destination)
}

// BindUpdate forwards to the wireguard-go device's BindUpdate. Before Start
// has created the device, for example while startup is deferred until peer
// resolution, there is nothing to update and it returns nil.
func (e *Endpoint) BindUpdate() error {
	if e.device == nil {
		return nil
	}
	return e.device.BindUpdate()
}

// Close unregisters the pause callback, if any, and then closes the
// wireguard-go device, if it was started. It always returns nil.
func (e *Endpoint) Close() error {
	// Unregister first so no pause/wake event can reach a device being closed.
	if e.pauseCallback != nil {
		e.pauseManager.UnregisterCallback(e.pauseCallback)
		e.pauseCallback = nil
	}
	if e.device != nil {
		e.device.Close()
	}
	return nil
}

// onPauseUpdated brings the device down when the system reports the device
// paused and back up when it wakes. It is only registered after Start has
// set e.device.
func (e *Endpoint) onPauseUpdated(event int) {
	switch event {
	case pause.EventDevicePaused:
		e.device.Down()
	case pause.EventDeviceWake:
		e.device.Up()
	}
}

// peerConfig is the per-peer state derived from the endpoint options.
//
// Fields:
//   - destination: the configured domain endpoint, kept for Start to resolve.
//   - endpoint: the peer address; empty until resolved for domain peers.
//   - publicKeyHex / preSharedKeyHex: keys hex-encoded for the IPC protocol;
//     preSharedKeyHex is empty when no pre-shared key is configured.
//   - allowedIPs: the peer's AllowedIPs.
//   - keepalive: value written as persistent_keepalive_interval when non-zero.
//   - reserved: 3 reserved bytes; passed to the client bind when exactly one
//     peer is configured.
type peerConfig struct {
	destination     M.Socksaddr
	endpoint        netip.AddrPort
	publicKeyHex    string
	preSharedKeyHex string
	allowedIPs      []netip.Prefix
	keepalive       uint16
	reserved        [3]uint8
}

// GenerateIpcLines renders the peer as wireguard-go IPC lines. Each line is
// prefixed with a newline, so the result can be appended directly to the
// interface-level configuration.
func (c peerConfig) GenerateIpcLines() string {
	var builder strings.Builder
	builder.WriteString("\npublic_key=")
	builder.WriteString(c.publicKeyHex)
	if c.endpoint.IsValid() {
		builder.WriteString("\nendpoint=")
		builder.WriteString(c.endpoint.String())
	}
	if c.preSharedKeyHex != "" {
		builder.WriteString("\npreshared_key=")
		builder.WriteString(c.preSharedKeyHex)
	}
	for _, allowedIP := range c.allowedIPs {
		builder.WriteString("\nallowed_ip=")
		builder.WriteString(allowedIP.String())
	}
	if c.keepalive > 0 {
		builder.WriteString("\npersistent_keepalive_interval=")
		builder.WriteString(F.ToString(c.keepalive))
	}
	return builder.String()
}