package wireguard

import (
	"context"
	"net"
	"net/netip"
	"time"

	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/adapter/endpoint"
	"github.com/sagernet/sing-box/common/dialer"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing-box/transport/wireguard"
	"github.com/sagernet/sing-dns"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/bufio"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/sing/common/logger"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)

// RegisterEndpoint registers NewEndpoint as the constructor for endpoints of
// type C.TypeWireGuard in the given registry.
func RegisterEndpoint(registry *endpoint.Registry) {
	endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
}

var (
	_ adapter.Endpoint                = (*Endpoint)(nil)
	_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)

// Endpoint adapts a transport-level wireguard.Endpoint to the adapter.Endpoint
// interface. It is registered as the tunnel's handler, so connections that
// arrive through the tunnel are passed to the router, and it dials outbound
// connections through the tunnel.
//
// localAddresses holds the configured interface prefixes (options.Address);
// inbound traffic addressed to one of them is redirected to loopback.
type Endpoint struct {
	endpoint.Adapter
	ctx            context.Context
	router         adapter.Router
	logger         logger.ContextLogger
	localAddresses []netip.Prefix
	endpoint       *wireguard.Endpoint
}

// NewEndpoint builds a WireGuard endpoint from options.
//
// When no detour is configured, the dialer options are flagged with
// IsWireGuardListener before the outbound dialer is created. Peers whose
// address is a domain name are resolved through the router using
// options.DomainStrategy; the first returned address is used.
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
	ep := &Endpoint{
		Adapter:        endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
		ctx:            ctx,
		router:         router,
		logger:         logger,
		localAddresses: options.Address,
	}
	if options.Detour == "" {
		options.IsWireGuardListener = true
	}
	outboundDialer, err := dialer.New(ctx, options.DialerOptions)
	if err != nil {
		return nil, err
	}
	// A zero UDPTimeout in the options means "use the package default".
	udpTimeout := time.Duration(options.UDPTimeout)
	if udpTimeout == 0 {
		udpTimeout = C.UDPTimeout
	}
	wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
		Context:    ctx,
		Logger:     logger,
		System:     options.System,
		Handler:    ep,
		UDPTimeout: udpTimeout,
		Dialer:     outboundDialer,
		CreateDialer: func(interfaceName string) N.Dialer {
			// NOTE: this callback cannot return an error, so common.Must1
			// panics if the default dialer cannot be created.
			return common.Must1(dialer.NewDefault(ctx, option.DialerOptions{
				BindInterface: interfaceName,
			}))
		},
		Name:       options.Name,
		MTU:        options.MTU,
		Address:    options.Address,
		PrivateKey: options.PrivateKey,
		ListenPort: options.ListenPort,
		ResolvePeer: func(domain string) (netip.Addr, error) {
			endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
			if lookupErr != nil {
				return netip.Addr{}, lookupErr
			}
			// Guard against an empty result with a nil error, which would
			// otherwise panic on the index below.
			if len(endpointAddresses) == 0 {
				return netip.Addr{}, E.New("no address resolved for peer ", domain)
			}
			return endpointAddresses[0], nil
		},
		Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions {
			return wireguard.PeerOptions{
				Endpoint:                    M.ParseSocksaddrHostPort(it.Address, it.Port),
				PublicKey:                   it.PublicKey,
				PreSharedKey:                it.PreSharedKey,
				AllowedIPs:                  it.AllowedIPs,
				PersistentKeepaliveInterval: it.PersistentKeepaliveInterval,
				Reserved:                    it.Reserved,
			}
		}),
		Workers: options.Workers,
	})
	if err != nil {
		return nil, err
	}
	ep.endpoint = wgEndpoint
	return ep, nil
}

// Start starts the underlying WireGuard endpoint in two phases: Start(false)
// at adapter.StartStateStart and Start(true) at adapter.StartStatePostStart.
// All other stages are no-ops.
func (w *Endpoint) Start(stage adapter.StartStage) error {
	switch stage {
	case adapter.StartStateStart:
		return w.endpoint.Start(false)
	case adapter.StartStatePostStart:
		return w.endpoint.Start(true)
	}
	return nil
}

// Close closes the underlying WireGuard endpoint.
func (w *Endpoint) Close() error {
	return w.endpoint.Close()
}

// InterfaceUpdated implements adapter.InterfaceUpdateListener by forwarding
// to the underlying endpoint's BindUpdate.
func (w *Endpoint) InterfaceUpdated() {
	w.endpoint.BindUpdate()
}

// PrepareConnection asks the router to pre-match a connection arriving through
// the tunnel, with this endpoint recorded as the inbound. Any error from
// router.PreMatch is returned unchanged.
func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
	return w.router.PreMatch(adapter.InboundContext{
		Inbound:     w.Tag(),
		InboundType: w.Type(),
		Network:     network,
		Source:      source,
		Destination: destination,
	})
}

// rewriteLocalDestination reports whether destination's address lies inside one
// of the endpoint's configured local prefixes. If it does, it returns a copy of
// destination whose address is replaced by the loopback address of the same
// family (127.0.0.1 for IPv4, ::1 otherwise); the port is preserved. Only the
// first matching prefix matters, so the rewrite happens at most once.
func (w *Endpoint) rewriteLocalDestination(destination M.Socksaddr) (M.Socksaddr, bool) {
	for _, localPrefix := range w.localAddresses {
		if !localPrefix.Contains(destination.Addr) {
			continue
		}
		if destination.Addr.Is4() {
			destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
		} else {
			destination.Addr = netip.IPv6Loopback()
		}
		return destination, true
	}
	return destination, false
}

// NewConnectionEx routes a TCP connection that arrived through the tunnel.
// If the destination is one of the endpoint's own addresses, the original is
// kept in metadata.OriginDestination and the routed destination becomes the
// loopback address of the same family.
func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
	var metadata adapter.InboundContext
	metadata.Inbound = w.Tag()
	metadata.InboundType = w.Type()
	metadata.Source = source
	metadata.Destination = destination
	if rewritten, isLocal := w.rewriteLocalDestination(destination); isLocal {
		metadata.OriginDestination = destination
		metadata.Destination = rewritten
	}
	w.logger.InfoContext(ctx, "inbound connection from ", source)
	w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
	w.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}

// NewPacketConnectionEx routes a UDP packet connection that arrived through the
// tunnel. Destinations on one of the endpoint's own addresses are redirected to
// loopback as in NewConnectionEx. The connection is also wrapped, exactly once,
// in a NAT packet conn that translates between the original and loopback
// addresses.
func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
	var metadata adapter.InboundContext
	metadata.Inbound = w.Tag()
	metadata.InboundType = w.Type()
	metadata.Source = source
	metadata.Destination = destination
	if rewritten, isLocal := w.rewriteLocalDestination(destination); isLocal {
		metadata.OriginDestination = destination
		metadata.Destination = rewritten
		conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
	}
	w.logger.InfoContext(ctx, "inbound packet connection from ", source)
	w.logger.InfoContext(ctx, "inbound packet connection to ", destination)
	w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}

// DialContext dials destination through the WireGuard tunnel.
//
// A domain-name destination is resolved with router.LookupDefault, and the
// resolved addresses are tried via N.DialSerial. A destination that is neither
// a domain name nor a valid IP address is rejected.
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	switch network {
	case N.NetworkTCP:
		w.logger.InfoContext(ctx, "outbound connection to ", destination)
	case N.NetworkUDP:
		w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
	}
	if destination.IsFqdn() {
		destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
		if err != nil {
			return nil, err
		}
		return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses)
	}
	if !destination.Addr.IsValid() {
		return nil, E.New("invalid destination: ", destination)
	}
	return w.endpoint.DialContext(ctx, network, destination)
}

// ListenPacket opens a packet connection through the WireGuard tunnel for
// destination. A domain-name destination is resolved with
// router.LookupDefault, and the resolved addresses are tried via
// N.ListenSerial.
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
	if destination.IsFqdn() {
		destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
		if err != nil {
			return nil, err
		}
		packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
		if err != nil {
			return nil, err
		}
		return packetConn, nil
	}
	return w.endpoint.ListenPacket(ctx, destination)
}