//go:build with_wireguard package outbound import ( "context" "encoding/base64" "encoding/hex" "fmt" "net" "net/netip" "os" "strings" "sync" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) var _ adapter.Outbound = (*WireGuard)(nil) type WireGuard struct { myOutboundAdapter ctx context.Context serverAddr M.Socksaddr dialer N.Dialer endpoint conn.Endpoint device *device.Device tunDevice *wireTunDevice connAccess sync.Mutex conn *wireConn } func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) { outbound := &WireGuard{ myOutboundAdapter: myOutboundAdapter{ protocol: C.TypeWireGuard, network: options.Network.Build(), router: router, logger: logger, tag: tag, }, ctx: ctx, serverAddr: options.ServerOptions.Build(), dialer: dialer.NewOutbound(router, options.OutboundDialerOptions), } var endpointIp netip.Addr if !outbound.serverAddr.IsFqdn() { endpointIp = outbound.serverAddr.Addr } else { endpointIp = netip.AddrFrom4([4]byte{127, 0, 0, 1}) } outbound.endpoint = conn.StdNetEndpoint(netip.AddrPortFrom(endpointIp, outbound.serverAddr.Port)) localAddress := make([]tcpip.AddressWithPrefix, len(options.LocalAddress)) if len(localAddress) == 0 { return nil, E.New("missing local address") } for index, address := range options.LocalAddress { if strings.Contains(address, "/") { prefix, err := netip.ParsePrefix(address) if err != nil { return nil, E.Cause(err, "parse local address prefix ", address) } localAddress[index] = tcpip.AddressWithPrefix{ Address: tcpip.Address(prefix.Addr().AsSlice()), PrefixLen: prefix.Bits(), } } else { addr, err := netip.ParseAddr(address) if err != nil { return nil, E.Cause(err, "parse local address ", address) } localAddress[index] = tcpip.Address(addr.AsSlice()).WithPrefix() } } var privateKey, peerPublicKey, preSharedKey string { bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) if err != nil { return nil, E.Cause(err, "decode private key") } privateKey = hex.EncodeToString(bytes) } { bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey) if err != nil { return nil, E.Cause(err, "decode peer public key") } peerPublicKey = hex.EncodeToString(bytes) } if options.PreSharedKey != "" { bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey) if err != nil { return nil, E.Cause(err, "decode pre shared key") } preSharedKey = hex.EncodeToString(bytes) } ipcConf := "private_key=" + privateKey ipcConf += "\npublic_key=" + peerPublicKey ipcConf += "\nendpoint=" + outbound.endpoint.DstToString() if preSharedKey != "" { ipcConf += "\npreshared_key=" + preSharedKey } var has4, has6 bool for _, address := range localAddress { if address.Address.To4() != "" { has4 = true } else { has6 = true } } if has4 { ipcConf += "\nallowed_ip=0.0.0.0/0" } if has6 { ipcConf += "\nallowed_ip=::/0" } mtu := options.MTU if mtu == 0 { mtu = 1408 } wireDevice, err := newWireDevice(localAddress, mtu) if err != nil { return nil, err } wgDevice := device.NewDevice(wireDevice, (*wireClientBind)(outbound), &device.Logger{ Verbosef: func(format string, args ...interface{}) { logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) }, Errorf: func(format string, args ...interface{}) { logger.Error(fmt.Sprintf(strings.ToLower(format), args...)) }, }) if debug.Enabled { logger.Trace("created wireguard ipc conf: \n", ipcConf) } err = wgDevice.IpcSet(ipcConf) if err != nil { return nil, E.Cause(err, "setup wireguard") } outbound.device = wgDevice outbound.tunDevice = wireDevice return outbound, nil } func (w *WireGuard) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { switch network { case N.NetworkTCP: w.logger.InfoContext(ctx, "outbound connection to ", destination) case N.NetworkUDP: w.logger.InfoContext(ctx, "outbound packet connection to ", destination) } addr := tcpip.FullAddress{ NIC: defaultNIC, Port: destination.Port, } if destination.IsFqdn() { addrs, err := w.router.LookupDefault(ctx, destination.Fqdn) if err != nil { return nil, err } addr.Addr = tcpip.Address(addrs[0].AsSlice()) } else { addr.Addr = tcpip.Address(destination.Addr.AsSlice()) } bind := tcpip.FullAddress{ NIC: defaultNIC, } var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { networkProtocol = header.IPv4ProtocolNumber bind.Addr = w.tunDevice.addr4 } else { networkProtocol = header.IPv6ProtocolNumber bind.Addr = w.tunDevice.addr6 } switch N.NetworkName(network) { case N.NetworkTCP: return gonet.DialTCPWithBind(ctx, w.tunDevice.stack, bind, addr, networkProtocol) case N.NetworkUDP: return gonet.DialUDP(w.tunDevice.stack, &bind, &addr, networkProtocol) default: return nil, E.Extend(N.ErrUnknownNetwork, network) } } func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { w.logger.InfoContext(ctx, "outbound packet connection to ", destination) bind := tcpip.FullAddress{ NIC: defaultNIC, } var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() || w.tunDevice.addr6 == "" { networkProtocol = header.IPv4ProtocolNumber bind.Addr = w.tunDevice.addr4 } else { networkProtocol = header.IPv6ProtocolNumber bind.Addr = w.tunDevice.addr6 } return gonet.DialUDP(w.tunDevice.stack, &bind, nil, networkProtocol) } func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { return NewConnection(ctx, w, conn, metadata) } func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { return NewPacketConnection(ctx, w, conn, metadata) } func (w *WireGuard) Start() error { w.tunDevice.events <- tun.EventUp return nil } func (w *WireGuard) Close() error { return common.Close( common.PtrOrNil(w.tunDevice), common.PtrOrNil(w.device), common.PtrOrNil(w.conn), ) } var _ conn.Bind = (*wireClientBind)(nil) type wireClientBind WireGuard func (c *wireClientBind) connect() (*wireConn, error) { c.connAccess.Lock() defer c.connAccess.Unlock() if c.conn != nil { select { case <-c.conn.done: default: return c.conn, nil } } udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr) if err != nil { return nil, &wireError{err} } c.conn = &wireConn{ Conn: udpConn, done: make(chan struct{}), } return c.conn, nil } func (c *wireClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return []conn.ReceiveFunc{c.receive}, 0, nil } func (c *wireClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { udpConn, err := c.connect() if err != nil { return } n, err = udpConn.Read(b) if err != nil { udpConn.Close() err = &wireError{err} } ep = c.endpoint return } func (c *wireClientBind) Close() error { c.connAccess.Lock() defer c.connAccess.Unlock() common.Close(common.PtrOrNil(c.conn)) return nil } func (c *wireClientBind) SetMark(mark uint32) error { return nil } func (c *wireClientBind) Send(b []byte, ep conn.Endpoint) error { udpConn, err := c.connect() if err != nil { return err } _, err = udpConn.Write(b) if err != nil { udpConn.Close() } return err } func (c *wireClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { return c.endpoint, nil } type wireError struct { cause error } func (w *wireError) Error() string { return w.cause.Error() } func (w *wireError) Timeout() bool { if cause, causeNet := w.cause.(net.Error); causeNet { return cause.Timeout() } return false } func (w *wireError) Temporary() bool { return true } func (w *wireError) Unwrap() error { return w.cause } type wireConn struct { net.Conn access sync.Mutex done chan struct{} } func (w *wireConn) Close() error { w.access.Lock() defer w.access.Unlock() select { case <-w.done: return net.ErrClosed default: } w.Conn.Close() close(w.done) return nil } var _ tun.Device = (*wireTunDevice)(nil) const defaultNIC tcpip.NICID = 1 type wireTunDevice struct { stack *stack.Stack mtu uint32 events chan tun.Event outbound chan *stack.PacketBuffer dispatcher stack.NetworkDispatcher done chan struct{} addr4 tcpip.Address addr6 tcpip.Address } func newWireDevice(localAddresses []tcpip.AddressWithPrefix, mtu uint32) (*wireTunDevice, error) { ipStack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, HandleLocal: true, }) tunDevice := &wireTunDevice{ stack: ipStack, mtu: mtu, events: make(chan tun.Event, 4), outbound: make(chan *stack.PacketBuffer, 256), done: make(chan struct{}), } err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) if err != nil { return nil, E.New(err.String()) } for _, addr := range localAddresses { var protoAddr tcpip.ProtocolAddress if len(addr.Address) == net.IPv4len { tunDevice.addr4 = addr.Address protoAddr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: addr, } } else { tunDevice.addr6 = addr.Address protoAddr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: addr, } } err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{}) if err != nil { return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) } } sOpt := tcpip.TCPSACKEnabled(true) ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) cOpt := tcpip.CongestionControlOption("cubic") ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt) ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC}) ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC}) return tunDevice, nil } func (w *wireTunDevice) File() *os.File { return nil } func (w *wireTunDevice) Read(p []byte, offset int) (n int, err error) { packetBuffer, ok := <-w.outbound if !ok { return 0, os.ErrClosed } defer packetBuffer.DecRef() p = p[offset:] for _, slice := range packetBuffer.AsSlices() { n += copy(p[n:], slice) } return } func (w *wireTunDevice) Write(p []byte, offset int) (n int, err error) { p = p[offset:] if len(p) == 0 { return } var networkProtocol tcpip.NetworkProtocolNumber switch header.IPVersion(p) { case header.IPv4Version: networkProtocol = header.IPv4ProtocolNumber case header.IPv6Version: networkProtocol = header.IPv6ProtocolNumber } packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: bufferv2.MakeWithData(p), }) defer packetBuffer.DecRef() w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) n = len(p) return } func (w *wireTunDevice) Flush() error { return nil } func (w *wireTunDevice) MTU() (int, error) { return int(w.mtu), nil } func (w *wireTunDevice) Name() (string, error) { return "sing-box", nil } func (w *wireTunDevice) Events() chan tun.Event { return w.events } func (w *wireTunDevice) Close() error { select { case <-w.done: return os.ErrClosed default: } close(w.done) w.stack.Close() for _, endpoint := range w.stack.CleanupEndpoints() { endpoint.Abort() } w.stack.Wait() close(w.outbound) return nil } var _ stack.LinkEndpoint = (*wireEndpoint)(nil) type wireEndpoint wireTunDevice func (ep *wireEndpoint) MTU() uint32 { return ep.mtu } func (ep *wireEndpoint) MaxHeaderLength() uint16 { return 0 } func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress { return "" } func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone } func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) { ep.dispatcher = dispatcher } func (ep *wireEndpoint) IsAttached() bool { return ep.dispatcher != nil } func (ep *wireEndpoint) Wait() { } func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) { } func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) { for _, packetBuffer := range list.AsSlice() { packetBuffer.IncRef() ep.outbound <- packetBuffer } return list.Len(), nil }