//go:build with_wireguard package outbound import ( "context" "encoding/base64" "encoding/hex" "fmt" "net" "os" "strings" "syscall" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/wireguard" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/wireguard-go/device" ) var ( _ adapter.IPOutbound = (*WireGuard)(nil) _ adapter.InterfaceUpdateListener = (*WireGuard)(nil) ) type WireGuard struct { myOutboundAdapter bind *wireguard.ClientBind device *device.Device natDevice wireguard.NatDevice tunDevice wireguard.Device } func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) { outbound := &WireGuard{ myOutboundAdapter: myOutboundAdapter{ protocol: C.TypeWireGuard, network: options.Network.Build(), router: router, logger: logger, tag: tag, }, } var reserved [3]uint8 if len(options.Reserved) > 0 { if len(options.Reserved) != 3 { return nil, E.New("invalid reserved value, required 3 bytes, got ", len(options.Reserved)) } copy(reserved[:], options.Reserved) } var isConnect bool var connectAddr M.Socksaddr if len(options.Peers) < 2 { isConnect = true if len(options.Peers) == 1 { connectAddr = options.Peers[0].ServerOptions.Build() } else { connectAddr = options.ServerOptions.Build() } } outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved) localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build) if len(localPrefixes) == 0 { return nil, E.New("missing local address") } var privateKey string { bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) if err != nil { return nil, E.Cause(err, "decode private key") } privateKey = hex.EncodeToString(bytes) } ipcConf := "private_key=" + privateKey if len(options.Peers) > 0 { for i, peer := range options.Peers { var peerPublicKey, preSharedKey string { bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey) if err != nil { return nil, E.Cause(err, "decode public key for peer ", i) } peerPublicKey = hex.EncodeToString(bytes) } if peer.PreSharedKey != "" { bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey) if err != nil { return nil, E.Cause(err, "decode pre shared key for peer ", i) } preSharedKey = hex.EncodeToString(bytes) } destination := peer.ServerOptions.Build() ipcConf += "\npublic_key=" + peerPublicKey ipcConf += "\nendpoint=" + destination.String() if preSharedKey != "" { ipcConf += "\npreshared_key=" + preSharedKey } if len(peer.AllowedIPs) == 0 { return nil, E.New("missing allowed_ips for peer ", i) } for _, allowedIP := range peer.AllowedIPs { ipcConf += "\nallowed_ip=" + allowedIP } if len(peer.Reserved) > 0 { if len(peer.Reserved) != 3 { return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved)) } copy(reserved[:], options.Reserved) outbound.bind.SetReservedForEndpoint(destination, reserved) } } } else { var peerPublicKey, preSharedKey string { bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey) if err != nil { return nil, E.Cause(err, "decode peer public key") } peerPublicKey = hex.EncodeToString(bytes) } if options.PreSharedKey != "" { bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey) if err != nil { return nil, E.Cause(err, "decode pre shared key") } preSharedKey = hex.EncodeToString(bytes) } ipcConf += "\npublic_key=" + peerPublicKey ipcConf += "\nendpoint=" + options.ServerOptions.Build().String() if preSharedKey != "" { ipcConf += "\npreshared_key=" + preSharedKey } var has4, has6 bool for _, address := range localPrefixes { if address.Addr().Is4() { has4 = true } else { has6 = true } } if has4 { ipcConf += "\nallowed_ip=0.0.0.0/0" } if has6 { ipcConf += "\nallowed_ip=::/0" } } mtu := options.MTU if mtu == 0 { mtu = 1408 } var tunDevice wireguard.Device var err error if !options.SystemInterface && tun.WithGVisor { tunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu, options.IPRewrite) } else { tunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu) } if err != nil { return nil, E.Cause(err, "create WireGuard device") } natDevice, isNatDevice := tunDevice.(wireguard.NatDevice) if !isNatDevice && router.NatRequired(tag) { natDevice = wireguard.NewNATDevice(tunDevice, options.IPRewrite) } deviceInput := tunDevice if natDevice != nil { deviceInput = natDevice } wgDevice := device.NewDevice(deviceInput, outbound.bind, &device.Logger{ Verbosef: func(format string, args ...interface{}) { logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) }, Errorf: func(format string, args ...interface{}) { logger.Error(fmt.Sprintf(strings.ToLower(format), args...)) }, }, options.Workers) if debug.Enabled { logger.Trace("created wireguard ipc conf: \n", ipcConf) } err = wgDevice.IpcSet(ipcConf) if err != nil { return nil, E.Cause(err, "setup wireguard") } outbound.device = wgDevice outbound.natDevice = natDevice outbound.tunDevice = tunDevice return outbound, nil } func (w *WireGuard) InterfaceUpdated() error { w.bind.Reset() return nil } func (w *WireGuard) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { switch network { case N.NetworkTCP: w.logger.InfoContext(ctx, "outbound connection to ", destination) case N.NetworkUDP: w.logger.InfoContext(ctx, "outbound packet connection to ", destination) } if destination.IsFqdn() { addrs, err := w.router.LookupDefault(ctx, destination.Fqdn) if err != nil { return nil, err } return N.DialSerial(ctx, w.tunDevice, network, destination, addrs) } return w.tunDevice.DialContext(ctx, network, destination) } func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { w.logger.InfoContext(ctx, "outbound packet connection to ", destination) return w.tunDevice.ListenPacket(ctx, destination) } func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { return NewConnection(ctx, w, conn, metadata) } func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { return NewPacketConnection(ctx, w, conn, metadata) } func (w *WireGuard) NewIPConnection(ctx context.Context, conn tun.RouteContext, metadata adapter.InboundContext) (tun.DirectDestination, error) { if w.natDevice == nil { return nil, os.ErrInvalid } session := tun.RouteSession{ IPVersion: metadata.IPVersion, Network: tun.NetworkFromName(metadata.Network), Source: metadata.Source.AddrPort(), Destination: metadata.Destination.AddrPort(), } switch session.Network { case syscall.IPPROTO_TCP: w.logger.InfoContext(ctx, "linked connection to ", metadata.Destination) case syscall.IPPROTO_UDP: w.logger.InfoContext(ctx, "linked packet connection to ", metadata.Destination) default: w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection to ", metadata.Destination.AddrString()) } return w.natDevice.CreateDestination(session, conn), nil } func (w *WireGuard) Start() error { return w.tunDevice.Start() } func (w *WireGuard) Close() error { if w.device != nil { w.device.Close() } w.tunDevice.Close() return nil }