refactor: WireGuard endpoint

This commit is contained in:
世界 2024-11-21 18:10:41 +08:00
parent b910428410
commit 013a69c002
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
76 changed files with 1650 additions and 675 deletions

28
adapter/endpoint.go Normal file
View file

@ -0,0 +1,28 @@
package adapter
import (
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type Endpoint interface {
Lifecycle
Type() string
Tag() string
Outbound
}
type EndpointRegistry interface {
option.EndpointOptionsRegistry
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) (Endpoint, error)
}
type EndpointManager interface {
Lifecycle
Endpoints() []Endpoint
Get(tag string) (Endpoint, bool)
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) error
}

View file

@ -0,0 +1,43 @@
package endpoint
import "github.com/sagernet/sing-box/option"
type Adapter struct {
endpointType string
endpointTag string
network []string
dependencies []string
}
func NewAdapter(endpointType string, endpointTag string, network []string, dependencies []string) Adapter {
return Adapter{
endpointType: endpointType,
endpointTag: endpointTag,
network: network,
dependencies: dependencies,
}
}
func NewAdapterWithDialerOptions(endpointType string, endpointTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string
if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour}
}
return NewAdapter(endpointType, endpointTag, network, dependencies)
}
func (a *Adapter) Type() string {
return a.endpointType
}
func (a *Adapter) Tag() string {
return a.endpointTag
}
func (a *Adapter) Network() []string {
return a.network
}
func (a *Adapter) Dependencies() []string {
return a.dependencies
}

147
adapter/endpoint/manager.go Normal file
View file

@ -0,0 +1,147 @@
package endpoint
import (
"context"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
var _ adapter.EndpointManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.EndpointRegistry
access sync.Mutex
started bool
stage adapter.StartStage
endpoints []adapter.Endpoint
endpointByTag map[string]adapter.Endpoint
}
func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
endpointByTag: make(map[string]adapter.Endpoint),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
if stage == adapter.StartStateStart {
// started with outbound manager
return nil
}
for _, endpoint := range m.endpoints {
err := adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if !m.started {
return nil
}
m.started = false
endpoints := m.endpoints
m.endpoints = nil
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, endpoint := range endpoints {
monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
err = E.Append(err, endpoint.Close(), func(err error) error {
return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
})
monitor.Finish()
}
return nil
}
func (m *Manager) Endpoints() []adapter.Endpoint {
m.access.Lock()
defer m.access.Unlock()
return m.endpoints
}
func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
m.access.Lock()
defer m.access.Unlock()
endpoint, found := m.endpointByTag[tag]
return endpoint, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
endpoint, found := m.endpointByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.endpointByTag, tag)
index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == endpoint
})
if index == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return endpoint.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
}
if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
if m.started {
err = existsEndpoint.Close()
if err != nil {
return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
}
}
existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == existsEndpoint
})
if existsIndex == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
}
m.endpoints = append(m.endpoints, endpoint)
m.endpointByTag[tag] = endpoint
return nil
}

View file

@ -0,0 +1,72 @@
package endpoint
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options T) (adapter.Endpoint, error)
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any {
return new(Options)
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Endpoint, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.EndpointRegistry = (*Registry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Endpoint, error)
)
type Registry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructor map[string]constructorFunc
}
func NewRegistry() *Registry {
return &Registry{
optionsType: make(map[string]optionsConstructorFunc),
constructor: make(map[string]constructorFunc),
}
}
func (m *Registry) CreateOptions(outboundType string) (any, bool) {
m.access.Lock()
defer m.access.Unlock()
optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Endpoint, error) {
m.access.Lock()
defer m.access.Unlock()
constructor, loaded := m.constructor[outboundType]
if !loaded {
return nil, E.New("outbound type not found: " + outboundType)
}
return constructor(ctx, router, logger, tag, options)
}
func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
m.access.Lock()
defer m.access.Unlock()
m.optionsType[outboundType] = optionsConstructor
m.constructor[outboundType] = constructor
}

View file

@ -13,7 +13,7 @@ import (
) )
type Inbound interface { type Inbound interface {
Service Lifecycle
Type() string Type() string
Tag() string Tag() string
} }

View file

@ -18,6 +18,7 @@ var _ adapter.InboundManager = (*Manager)(nil)
type Manager struct { type Manager struct {
logger log.ContextLogger logger log.ContextLogger
registry adapter.InboundRegistry registry adapter.InboundRegistry
endpoint adapter.EndpointManager
access sync.Mutex access sync.Mutex
started bool started bool
stage adapter.StartStage stage adapter.StartStage
@ -25,10 +26,11 @@ type Manager struct {
inboundByTag map[string]adapter.Inbound inboundByTag map[string]adapter.Inbound
} }
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager { func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endpoint adapter.EndpointManager) *Manager {
return &Manager{ return &Manager{
logger: logger, logger: logger,
registry: registry, registry: registry,
endpoint: endpoint,
inboundByTag: make(map[string]adapter.Inbound), inboundByTag: make(map[string]adapter.Inbound),
} }
} }
@ -79,9 +81,12 @@ func (m *Manager) Inbounds() []adapter.Inbound {
func (m *Manager) Get(tag string) (adapter.Inbound, bool) { func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
m.access.Lock() m.access.Lock()
defer m.access.Unlock()
inbound, found := m.inboundByTag[tag] inbound, found := m.inboundByTag[tag]
return inbound, found m.access.Unlock()
if found {
return inbound, true
}
return m.endpoint.Get(tag)
} }
func (m *Manager) Remove(tag string) error { func (m *Manager) Remove(tag string) error {

View file

@ -1,6 +1,9 @@
package adapter package adapter
func LegacyStart(starter any, stage StartStage) error { func LegacyStart(starter any, stage StartStage) error {
if lifecycle, isLifecycle := starter.(Lifecycle); isLifecycle {
return lifecycle.Start(stage)
}
switch stage { switch stage {
case StartStateInitialize: case StartStateInitialize:
if preStarter, isPreStarter := starter.(interface { if preStarter, isPreStarter := starter.(interface {

View file

@ -5,35 +5,35 @@ import (
) )
type Adapter struct { type Adapter struct {
protocol string outboundType string
outboundTag string
network []string network []string
tag string
dependencies []string dependencies []string
} }
func NewAdapter(protocol string, network []string, tag string, dependencies []string) Adapter { func NewAdapter(outboundType string, outboundTag string, network []string, dependencies []string) Adapter {
return Adapter{ return Adapter{
protocol: protocol, outboundType: outboundType,
outboundTag: outboundTag,
network: network, network: network,
tag: tag,
dependencies: dependencies, dependencies: dependencies,
} }
} }
func NewAdapterWithDialerOptions(protocol string, network []string, tag string, dialOptions option.DialerOptions) Adapter { func NewAdapterWithDialerOptions(outboundType string, outboundTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string var dependencies []string
if dialOptions.Detour != "" { if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour} dependencies = []string{dialOptions.Detour}
} }
return NewAdapter(protocol, network, tag, dependencies) return NewAdapter(outboundType, outboundTag, network, dependencies)
} }
func (a *Adapter) Type() string { func (a *Adapter) Type() string {
return a.protocol return a.outboundType
} }
func (a *Adapter) Tag() string { func (a *Adapter) Tag() string {
return a.tag return a.outboundTag
} }
func (a *Adapter) Network() []string { func (a *Adapter) Network() []string {

View file

@ -21,6 +21,7 @@ var _ adapter.OutboundManager = (*Manager)(nil)
type Manager struct { type Manager struct {
logger log.ContextLogger logger log.ContextLogger
registry adapter.OutboundRegistry registry adapter.OutboundRegistry
endpoint adapter.EndpointManager
defaultTag string defaultTag string
access sync.Mutex access sync.Mutex
started bool started bool
@ -32,10 +33,11 @@ type Manager struct {
defaultOutboundFallback adapter.Outbound defaultOutboundFallback adapter.Outbound
} }
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager { func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
return &Manager{ return &Manager{
logger: logger, logger: logger,
registry: registry, registry: registry,
endpoint: endpoint,
defaultTag: defaultTag, defaultTag: defaultTag,
outboundByTag: make(map[string]adapter.Outbound), outboundByTag: make(map[string]adapter.Outbound),
dependByTag: make(map[string][]string), dependByTag: make(map[string][]string),
@ -56,7 +58,14 @@ func (m *Manager) Start(stage adapter.StartStage) error {
outbounds := m.outbounds outbounds := m.outbounds
m.access.Unlock() m.access.Unlock()
if stage == adapter.StartStateStart { if stage == adapter.StartStateStart {
return m.startOutbounds(outbounds) if m.defaultOutbound == nil {
if len(outbounds) > 0 {
m.defaultOutbound = outbounds[0]
} else if len(m.endpoint.Endpoints()) > 0 {
m.defaultOutbound = m.endpoint.Endpoints()[0]
}
}
return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
} else { } else {
for _, outbound := range outbounds { for _, outbound := range outbounds {
err := adapter.LegacyStart(outbound, stage) err := adapter.LegacyStart(outbound, stage)
@ -87,7 +96,14 @@ func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
} }
started[outboundTag] = true started[outboundTag] = true
canContinue = true canContinue = true
if starter, isStarter := outboundToStart.(interface { if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start(adapter.StartStateStart)
monitor.Finish()
if err != nil {
return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
} else if starter, isStarter := outboundToStart.(interface {
Start() error Start() error
}); isStarter { }); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]") monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
@ -160,9 +176,12 @@ func (m *Manager) Outbounds() []adapter.Outbound {
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock() m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag] outbound, found := m.outboundByTag[tag]
return outbound, found m.access.Unlock()
if found {
return outbound, true
}
return m.endpoint.Get(tag)
} }
func (m *Manager) Default() adapter.Outbound { func (m *Manager) Default() adapter.Outbound {
@ -195,6 +214,9 @@ func (m *Manager) Remove(tag string) error {
if len(m.outbounds) > 0 { if len(m.outbounds) > 0 {
m.defaultOutbound = m.outbounds[0] m.defaultOutbound = m.outbounds[0]
m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag()) m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
} else if len(m.endpoint.Endpoints()) > 0 {
m.defaultOutbound = m.endpoint.Endpoints()[0]
m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
} else { } else {
m.defaultOutbound = nil m.defaultOutbound = nil
} }
@ -259,7 +281,7 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.
for _, dependency := range dependencies { for _, dependency := range dependencies {
m.dependByTag[dependency] = append(m.dependByTag[dependency], tag) m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
} }
if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) { if tag == m.defaultTag || (m.started && m.defaultTag == "" && m.defaultOutbound == nil) {
m.defaultOutbound = outbound m.defaultOutbound = outbound
if m.started { if m.started {
m.logger.Info("updated default outbound to ", outbound.Tag()) m.logger.Info("updated default outbound to ", outbound.Tag())

51
box.go
View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
@ -36,6 +37,7 @@ type Box struct {
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
network *route.NetworkManager network *route.NetworkManager
endpoint *endpoint.Manager
inbound *inbound.Manager inbound *inbound.Manager
outbound *outbound.Manager outbound *outbound.Manager
connection *route.ConnectionManager connection *route.ConnectionManager
@ -54,6 +56,7 @@ func Context(
ctx context.Context, ctx context.Context,
inboundRegistry adapter.InboundRegistry, inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry, outboundRegistry adapter.OutboundRegistry,
endpointRegistry adapter.EndpointRegistry,
) context.Context { ) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil { service.FromContext[adapter.InboundRegistry](ctx) == nil {
@ -65,6 +68,11 @@ func Context(
ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry) ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry)
ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry) ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry)
} }
if service.FromContext[option.EndpointOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.EndpointRegistry](ctx) == nil {
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
}
return ctx return ctx
} }
@ -76,12 +84,16 @@ func New(options Options) (*Box, error) {
} }
ctx = service.ContextWithDefaultRegistry(ctx) ctx = service.ContextWithDefaultRegistry(ctx)
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context")
}
if inboundRegistry == nil { if inboundRegistry == nil {
return nil, E.New("missing inbound registry in context") return nil, E.New("missing inbound registry in context")
} }
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
if outboundRegistry == nil { if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context") return nil, E.New("missing outbound registry in context")
} }
@ -119,8 +131,10 @@ func New(options Options) (*Box, error) {
} }
routeOptions := common.PtrValueOrDefault(options.Route) routeOptions := common.PtrValueOrDefault(options.Route)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry) endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
@ -135,6 +149,24 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize router") return nil, E.Cause(err, "initialize router")
} }
for i, endpointOptions := range options.Endpoints {
var tag string
if endpointOptions.Tag != "" {
tag = endpointOptions.Tag
} else {
tag = F.ToString(i)
}
err = endpointManager.Create(ctx,
router,
logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")),
tag,
endpointOptions.Type,
endpointOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]")
}
}
for i, inboundOptions := range options.Inbounds { for i, inboundOptions := range options.Inbounds {
var tag string var tag string
if inboundOptions.Tag != "" { if inboundOptions.Tag != "" {
@ -241,6 +273,7 @@ func New(options Options) (*Box, error) {
} }
return &Box{ return &Box{
network: networkManager, network: networkManager,
endpoint: endpointManager,
inbound: inboundManager, inbound: inboundManager,
outbound: outboundManager, outbound: outboundManager,
connection: connectionManager, connection: connectionManager,
@ -303,7 +336,7 @@ func (s *Box) preStart() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound) err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
@ -327,7 +360,11 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound) err = adapter.Start(adapter.StartStateStart, s.endpoint)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
@ -335,7 +372,7 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound) err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }

View file

@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) {
configPaths = append(configPaths, "config.json") configPaths = append(configPaths, "config.json")
} }
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger())) globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry()) globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
} }

View file

@ -279,7 +279,7 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
} }
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
return trackPacketConn(d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)) return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address))
} }
func trackConn(conn net.Conn, err error) (net.Conn, error) { func trackConn(conn net.Conn, err error) (net.Conn, error) {

View file

@ -109,6 +109,15 @@ var OptionDestinationOverrideFields = Note{
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options", MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options",
} }
var OptionWireGuardOutbound = Note{
Name: "wireguard-outbound",
Description: "legacy wireguard outbound",
DeprecatedVersion: "1.11.0",
ScheduledVersion: "1.13.0",
EnvName: "WIREGUARD_OUTBOUND",
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-wireguard-outbound-to-endpoint",
}
var Options = []Note{ var Options = []Note{
OptionBadMatchSource, OptionBadMatchSource,
OptionGEOIP, OptionGEOIP,
@ -117,4 +126,5 @@ var Options = []Note{
OptionSpecialOutbounds, OptionSpecialOutbounds,
OptionInboundOptions, OptionInboundOptions,
OptionDestinationOverrideFields, OptionDestinationOverrideFields,
OptionWireGuardOutbound,
} }

View file

@ -30,7 +30,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err
} }
func CheckConfig(configContent string) error { func CheckConfig(configContent string) error {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
options, err := parseConfig(ctx, configContent) options, err := parseConfig(ctx, configContent)
if err != nil { if err != nil {
return err return err
@ -131,7 +131,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica
} }
func FormatConfig(configContent string) (string, error) { func FormatConfig(configContent string) (string, error) {
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()), configContent) options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent)
if err != nil { if err != nil {
return "", err return "", err
} }

View file

@ -44,7 +44,7 @@ type BoxService struct {
} }
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) { func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager)) service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
options, err := parseConfig(ctx, configContent) options, err := parseConfig(ctx, configContent)

4
go.mod
View file

@ -32,11 +32,11 @@ require (
github.com/sagernet/sing-shadowsocks v0.2.7 github.com/sagernet/sing-shadowsocks v0.2.7
github.com/sagernet/sing-shadowsocks2 v0.2.0 github.com/sagernet/sing-shadowsocks2 v0.2.0
github.com/sagernet/sing-shadowtls v0.2.0-alpha.2 github.com/sagernet/sing-shadowtls v0.2.0-alpha.2
github.com/sagernet/sing-tun v0.6.0-alpha.9 github.com/sagernet/sing-tun v0.6.0-alpha.10
github.com/sagernet/sing-vmess v0.1.12 github.com/sagernet/sing-vmess v0.1.12
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
github.com/sagernet/utls v1.6.7 github.com/sagernet/utls v1.6.7
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8 github.com/sagernet/wireguard-go v0.0.1-beta.2
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0

8
go.sum
View file

@ -124,16 +124,16 @@ github.com/sagernet/sing-shadowsocks2 v0.2.0 h1:wpZNs6wKnR7mh1wV9OHwOyUr21VkS3wK
github.com/sagernet/sing-shadowsocks2 v0.2.0/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowsocks2 v0.2.0/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ=
github.com/sagernet/sing-shadowtls v0.2.0-alpha.2 h1:RPrpgAdkP5td0vLfS5ldvYosFjSsZtRPxiyLV6jyKg0= github.com/sagernet/sing-shadowtls v0.2.0-alpha.2 h1:RPrpgAdkP5td0vLfS5ldvYosFjSsZtRPxiyLV6jyKg0=
github.com/sagernet/sing-shadowtls v0.2.0-alpha.2/go.mod h1:0j5XlzKxaWRIEjc1uiSKmVoWb0k+L9QgZVb876+thZA= github.com/sagernet/sing-shadowtls v0.2.0-alpha.2/go.mod h1:0j5XlzKxaWRIEjc1uiSKmVoWb0k+L9QgZVb876+thZA=
github.com/sagernet/sing-tun v0.6.0-alpha.9 h1:Qf667035KnlydZ+ftj3U4HH+oddi3RdyKzBiCcnSgaI= github.com/sagernet/sing-tun v0.6.0-alpha.10 h1:kJOMUR6VKHkTrtJ+kPJVsCqrJYmW0nTRJLYv+Or7lNA=
github.com/sagernet/sing-tun v0.6.0-alpha.9/go.mod h1:TgvxE2YD7O9c/unHju0nWAGBGsVppWIuju13vlmdllM= github.com/sagernet/sing-tun v0.6.0-alpha.10/go.mod h1:UmZpZ06gItrbOFLhyeZsilHKQDa5h4NSQy8LalkTkXQ=
github.com/sagernet/sing-vmess v0.1.12 h1:2gFD8JJb+eTFMoa8FIVMnknEi+vCSfaiTXTfEYAYAPg= github.com/sagernet/sing-vmess v0.1.12 h1:2gFD8JJb+eTFMoa8FIVMnknEi+vCSfaiTXTfEYAYAPg=
github.com/sagernet/sing-vmess v0.1.12/go.mod h1:luTSsfyBGAc9VhtCqwjR+dt1QgqBhuYBCONB/POhF8I= github.com/sagernet/sing-vmess v0.1.12/go.mod h1:luTSsfyBGAc9VhtCqwjR+dt1QgqBhuYBCONB/POhF8I=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
github.com/sagernet/utls v1.6.7 h1:Ep3+aJ8FUGGta+II2IEVNUc3EDhaRCZINWkj/LloIA8= github.com/sagernet/utls v1.6.7 h1:Ep3+aJ8FUGGta+II2IEVNUc3EDhaRCZINWkj/LloIA8=
github.com/sagernet/utls v1.6.7/go.mod h1:Uua1TKO/FFuAhLr9rkaVnnrTmmiItzDjv1BUb2+ERwM= github.com/sagernet/utls v1.6.7/go.mod h1:Uua1TKO/FFuAhLr9rkaVnnrTmmiItzDjv1BUb2+ERwM=
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8 h1:R0OMYAScomNAVpTfbHFpxqJpvwuhxSRi+g6z7gZhABs= github.com/sagernet/wireguard-go v0.0.1-beta.2 h1:afmDgfCL2Esc+2EYtdcJFepTWHX9+kZnosC0A84VJ9s=
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8/go.mod h1:K4J7/npM+VAMUeUmTa2JaA02JmyheP0GpRBOUvn3ecc= github.com/sagernet/wireguard-go v0.0.1-beta.2/go.mod h1:8xfewtQJZ1g3HeMQbLpJxTjyTiE3FL+Joq5LQoKLFEw=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@ -82,6 +83,14 @@ func OutboundRegistry() *outbound.Registry {
return registry return registry
} }
func EndpointRegistry() *endpoint.Registry {
registry := endpoint.NewRegistry()
registerWireGuardEndpoint(registry)
return registry
}
func registerStubForRemovedInbounds(registry *inbound.Registry) { func registerStubForRemovedInbounds(registry *inbound.Registry) {
inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) { inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0") return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0")

View file

@ -3,6 +3,7 @@
package include package include
import ( import (
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/protocol/wireguard" "github.com/sagernet/sing-box/protocol/wireguard"
) )
@ -10,3 +11,7 @@ import (
func registerWireGuardOutbound(registry *outbound.Registry) { func registerWireGuardOutbound(registry *outbound.Registry) {
wireguard.RegisterOutbound(registry) wireguard.RegisterOutbound(registry)
} }
func registerWireGuardEndpoint(registry *endpoint.Registry) {
wireguard.RegisterEndpoint(registry)
}

View file

@ -6,6 +6,7 @@ import (
"context" "context"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@ -14,7 +15,13 @@ import (
) )
func registerWireGuardOutbound(registry *outbound.Registry) { func registerWireGuardOutbound(registry *outbound.Registry) {
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) { outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
})
}
func registerWireGuardEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`) return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
}) })
} }

47
option/endpoint.go Normal file
View file

@ -0,0 +1,47 @@
package option
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/service"
)
type EndpointOptionsRegistry interface {
CreateOptions(endpointType string) (any, bool)
}
type _Endpoint struct {
Type string `json:"type"`
Tag string `json:"tag,omitempty"`
Options any `json:"-"`
}
type Endpoint _Endpoint
func (h *Endpoint) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return badjson.MarshallObjectsContext(ctx, (*_Endpoint)(h), h.Options)
}
func (h *Endpoint) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.UnmarshalContext(ctx, content, (*_Endpoint)(h))
if err != nil {
return err
}
registry := service.FromContext[EndpointOptionsRegistry](ctx)
if registry == nil {
return E.New("missing Endpoint fields registry in context")
}
options, loaded := registry.CreateOptions(h.Type)
if !loaded {
return E.New("unknown inbound type: ", h.Type)
}
err = badjson.UnmarshallExcludedContext(ctx, content, (*_Endpoint)(h), options)
if err != nil {
return err
}
h.Options = options
return nil
}

View file

@ -28,7 +28,7 @@ func (h *Inbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
} }
func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error { func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.Unmarshal(content, (*_Inbound)(h)) err := json.UnmarshalContext(ctx, content, (*_Inbound)(h))
if err != nil { if err != nil {
return err return err
} }

View file

@ -13,6 +13,7 @@ type _Options struct {
Log *LogOptions `json:"log,omitempty"` Log *LogOptions `json:"log,omitempty"`
DNS *DNSOptions `json:"dns,omitempty"` DNS *DNSOptions `json:"dns,omitempty"`
NTP *NTPOptions `json:"ntp,omitempty"` NTP *NTPOptions `json:"ntp,omitempty"`
Endpoints []Endpoint `json:"endpoints,omitempty"`
Inbounds []Inbound `json:"inbounds,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"` Route *RouteOptions `json:"route,omitempty"`

View file

@ -30,7 +30,7 @@ func (h *Outbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
} }
func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error { func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.Unmarshal(content, (*_Outbound)(h)) err := json.UnmarshalContext(ctx, content, (*_Outbound)(h))
if err != nil { if err != nil {
return err return err
} }

View file

@ -6,14 +6,38 @@ import (
"github.com/sagernet/sing/common/json/badoption" "github.com/sagernet/sing/common/json/badoption"
) )
type WireGuardOutboundOptions struct { type WireGuardEndpointOptions struct {
System bool `json:"system,omitempty"`
Name string `json:"name,omitempty"`
MTU uint32 `json:"mtu,omitempty"`
GSO bool `json:"gso,omitempty"`
Address badoption.Listable[netip.Prefix] `json:"address"`
PrivateKey string `json:"private_key"`
ListenPort uint16 `json:"listen_port,omitempty"`
Peers []WireGuardPeer `json:"peers,omitempty"`
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
Workers int `json:"workers,omitempty"`
DialerOptions
}
type WireGuardPeer struct {
Address string `json:"address,omitempty"`
Port uint16 `json:"port,omitempty"`
PublicKey string `json:"public_key,omitempty"`
PreSharedKey string `json:"pre_shared_key,omitempty"`
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
PersistentKeepaliveInterval badoption.Duration `json:"persistent_keepalive_interval,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"`
}
type LegacyWireGuardOutboundOptions struct {
DialerOptions DialerOptions
SystemInterface bool `json:"system_interface,omitempty"` SystemInterface bool `json:"system_interface,omitempty"`
GSO bool `json:"gso,omitempty"` GSO bool `json:"gso,omitempty"`
InterfaceName string `json:"interface_name,omitempty"` InterfaceName string `json:"interface_name,omitempty"`
LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"` LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"`
PrivateKey string `json:"private_key"` PrivateKey string `json:"private_key"`
Peers []WireGuardPeer `json:"peers,omitempty"` Peers []LegacyWireGuardPeer `json:"peers,omitempty"`
ServerOptions ServerOptions
PeerPublicKey string `json:"peer_public_key"` PeerPublicKey string `json:"peer_public_key"`
PreSharedKey string `json:"pre_shared_key,omitempty"` PreSharedKey string `json:"pre_shared_key,omitempty"`
@ -23,10 +47,10 @@ type WireGuardOutboundOptions struct {
Network NetworkList `json:"network,omitempty"` Network NetworkList `json:"network,omitempty"`
} }
type WireGuardPeer struct { type LegacyWireGuardPeer struct {
ServerOptions ServerOptions
PublicKey string `json:"public_key,omitempty"` PublicKey string `json:"public_key,omitempty"`
PreSharedKey string `json:"pre_shared_key,omitempty"` PreSharedKey string `json:"pre_shared_key,omitempty"`
AllowedIPs badoption.Listable[string] `json:"allowed_ips,omitempty"` AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"` Reserved []uint8 `json:"reserved,omitempty"`
} }

View file

@ -26,7 +26,7 @@ type Outbound struct {
func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) { func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) {
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapter(C.TypeBlock, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil), Adapter: outbound.NewAdapter(C.TypeBlock, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
logger: logger, logger: logger,
}, nil }, nil
} }

View file

@ -68,7 +68,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (i *Inbound) Start() error { func (i *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return i.listener.Start() return i.listener.Start()
} }

View file

@ -52,7 +52,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
logger: logger, logger: logger,
domainStrategy: dns.DomainStrategy(options.DomainStrategy), domainStrategy: dns.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay), fallbackDelay: time.Duration(options.FallbackDelay),

View file

@ -28,7 +28,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapter(C.TypeDNS, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil), Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
router: router, router: router,
logger: logger, logger: logger,
}, nil }, nil

View file

@ -38,7 +38,7 @@ type Selector struct {
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) { func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
outbound := &Selector{ outbound := &Selector{
Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
ctx: ctx, ctx: ctx,
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,

View file

@ -49,7 +49,7 @@ type URLTest struct {
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) { func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
outbound := &URLTest{ outbound := &URLTest{
Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx, ctx: ctx,
router: router, router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outboundManager: service.FromContext[adapter.OutboundManager](ctx),

View file

@ -61,7 +61,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -39,7 +39,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, tag, []string{N.NetworkTCP}, options.DialerOptions),
logger: logger, logger: logger,
client: sHTTP.NewClient(sHTTP.Options{ client: sHTTP.NewClient(sHTTP.Options{
Dialer: detour, Dialer: detour,

View file

@ -160,7 +160,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -95,7 +95,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, networkList, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions),
logger: logger, logger: logger,
client: client, client: client,
}, nil }, nil

View file

@ -171,7 +171,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -81,7 +81,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, networkList, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions),
logger: logger, logger: logger,
client: client, client: client,
}, nil }, nil

View file

@ -54,7 +54,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -78,7 +78,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (n *Inbound) Start() error { func (n *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
var tlsConfig *tls.STDConfig var tlsConfig *tls.STDConfig
if n.tlsConfig != nil { if n.tlsConfig != nil {
err := n.tlsConfig.Start() err := n.tlsConfig.Start()

View file

@ -42,7 +42,10 @@ func NewRedirect(ctx context.Context, router adapter.Router, logger log.ContextL
return redirect, nil return redirect, nil
} }
func (h *Redirect) Start() error { func (h *Redirect) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -61,7 +61,10 @@ func NewTProxy(ctx context.Context, router adapter.Router, logger log.ContextLog
return tproxy, nil return tproxy, nil
} }
func (t *TProxy) Start() error { func (t *TProxy) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := t.listener.Start() err := t.listener.Start()
if err != nil { if err != nil {
return err return err

View file

@ -93,7 +93,10 @@ func newInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, err return inbound, err
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -101,7 +101,10 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont
return inbound, err return inbound, err
} }
func (h *MultiInbound) Start() error { func (h *MultiInbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -86,7 +86,10 @@ func newRelayInbound(ctx context.Context, router adapter.Router, logger log.Cont
return inbound, err return inbound, err
} }
func (h *RelayInbound) Start() error { func (h *RelayInbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -49,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
method: method, method: method,

View file

@ -90,7 +90,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -29,7 +29,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) {
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, tag, []string{N.NetworkTCP}, options.DialerOptions),
} }
if options.TLS == nil || !options.TLS.Enabled { if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired return nil, C.ErrTLSRequired

View file

@ -50,7 +50,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -50,7 +50,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
router: router, router: router,
logger: logger, logger: logger,
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password), client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),

View file

@ -54,7 +54,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, tag, []string{N.NetworkTCP}, options.DialerOptions),
ctx: ctx, ctx: ctx,
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,

View file

@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, tag, []string{N.NetworkTCP}, options.DialerOptions),
ctx: ctx, ctx: ctx,
logger: logger, logger: logger,
proxy: NewProxyListener(ctx, logger, outboundDialer), proxy: NewProxyListener(ctx, logger, outboundDialer),

View file

@ -110,7 +110,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -43,7 +43,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(), serverAddr: options.ServerOptions.Build(),

View file

@ -142,7 +142,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
client: client, client: client,
udpStream: options.UDPOverStream, udpStream: options.UDPOverStream,

View file

@ -300,7 +300,9 @@ func (t *Inbound) Tag() string {
return t.tag return t.tag
} }
func (t *Inbound) Start() error { func (t *Inbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
if C.IsAndroid && t.platformInterface == nil { if C.IsAndroid && t.platformInterface == nil {
t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager()) t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager())
} }
@ -348,10 +350,7 @@ func (t *Inbound) Start() error {
} }
t.tunStack = tunStack t.tunStack = tunStack
t.logger.Info("started at ", t.tunOptions.Name) t.logger.Info("started at ", t.tunOptions.Name)
return nil case adapter.StartStatePostStart:
}
func (t *Inbound) PostStart() error {
monitor := taskmonitor.New(t.logger, C.StartTimeout) monitor := taskmonitor.New(t.logger, C.StartTimeout)
monitor.Start("starting tun stack") monitor.Start("starting tun stack")
err := t.tunStack.Start() err := t.tunStack.Start()
@ -399,6 +398,7 @@ func (t *Inbound) PostStart() error {
t.routeAddressSet = nil t.routeAddressSet = nil
t.routeExcludeAddressSet = nil t.routeExcludeAddressSet = nil
} }
}
return nil return nil
} }

View file

@ -89,7 +89,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(), serverAddr: options.ServerOptions.Build(),

View file

@ -99,7 +99,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := h.service.Start() err := h.service.Start()
if err != nil { if err != nil {
return err return err

View file

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(), serverAddr: options.ServerOptions.Build(),

View file

@ -0,0 +1,211 @@
package wireguard
import (
"context"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
}
var (
_ adapter.Endpoint = (*Endpoint)(nil)
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)
type Endpoint struct {
endpoint.Adapter
ctx context.Context
router adapter.Router
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
}
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
ep := &Endpoint{
Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
logger: logger,
localAddresses: options.Address,
}
if options.Detour == "" {
options.IsWireGuardListener = true
} else if options.GSO {
return nil, E.New("gso is conflict with detour")
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil {
return nil, err
}
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx,
Logger: logger,
System: options.System,
Handler: ep,
UDPTimeout: time.Duration(options.UDPTimeout),
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
BindInterface: interfaceName,
}))
},
Name: options.Name,
MTU: options.MTU,
GSO: options.GSO,
Address: options.Address,
PrivateKey: options.PrivateKey,
ListenPort: options.ListenPort,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
return endpointAddresses[0], nil
},
Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions {
return wireguard.PeerOptions{
Endpoint: M.ParseSocksaddrHostPort(it.Address, it.Port),
PublicKey: it.PublicKey,
PreSharedKey: it.PreSharedKey,
AllowedIPs: it.AllowedIPs,
PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval),
Reserved: it.Reserved,
}
}),
Workers: options.Workers,
})
if err != nil {
return nil, err
}
ep.endpoint = wgEndpoint
return ep, nil
}
func (w *Endpoint) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
return w.endpoint.Start(false)
case adapter.StartStatePostStart:
return w.endpoint.Start(true)
}
return nil
}
func (w *Endpoint) Close() error {
return w.endpoint.Close()
}
func (w *Endpoint) InterfaceUpdated() {
w.endpoint.BindUpdate()
return
}
func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
return w.router.PreMatch(adapter.InboundContext{
Inbound: w.Tag(),
InboundType: w.Type(),
Network: network,
Source: source,
Destination: destination,
})
}
func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext
metadata.Inbound = w.Tag()
metadata.InboundType = w.Type()
metadata.Source = source
for _, localPrefix := range w.localAddresses {
if localPrefix.Contains(destination.Addr) {
metadata.OriginDestination = destination
if destination.Addr.Is4() {
destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
} else {
destination.Addr = netip.IPv6Loopback()
}
break
}
}
metadata.Destination = destination
w.logger.InfoContext(ctx, "inbound connection from ", source)
w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
w.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}
func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext
metadata.Inbound = w.Tag()
metadata.InboundType = w.Type()
metadata.Source = source
metadata.Destination = destination
for _, localPrefix := range w.localAddresses {
if localPrefix.Contains(destination.Addr) {
metadata.OriginDestination = destination
if destination.Addr.Is4() {
metadata.Destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
} else {
metadata.Destination.Addr = netip.IPv6Loopback()
}
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
}
}
w.logger.InfoContext(ctx, "inbound packet connection from ", source)
w.logger.InfoContext(ctx, "inbound packet connection to ", destination)
w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network {
case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
}
return w.endpoint.DialContext(ctx, network, destination)
}
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, err
}
return w.endpoint.ListenPacket(ctx, destination)
}

View file

@ -2,231 +2,153 @@ package wireguard
import ( import (
"context" "context"
"encoding/base64"
"encoding/hex"
"fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard" "github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
) )
func RegisterOutbound(registry *outbound.Registry) { func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound) outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
} }
var _ adapter.InterfaceUpdateListener = (*Outbound)(nil) var (
_ adapter.Endpoint = (*Endpoint)(nil)
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)
type Outbound struct { type Outbound struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
logger logger.ContextLogger logger logger.ContextLogger
workers int localAddresses []netip.Prefix
peers []wireguard.PeerConfig endpoint *wireguard.Endpoint
useStdNetBind bool
listener N.Dialer
ipcConf string
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
bind conn.Bind
device *device.Device
tunDevice wireguard.Device
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
deprecated.Report(ctx, deprecated.OptionWireGuardOutbound)
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx, ctx: ctx,
router: router, router: router,
logger: logger, logger: logger,
workers: options.Workers, localAddresses: options.LocalAddress,
pauseManager: service.FromContext[pause.Manager](ctx),
} }
peers, err := wireguard.ParsePeers(options) if options.Detour == "" {
if err != nil { options.IsWireGuardListener = true
return nil, err } else if options.GSO {
}
outbound.peers = peers
if len(options.LocalAddress) == 0 {
return nil, E.New("missing local address")
}
if options.GSO {
if options.GSO && options.Detour != "" {
return nil, E.New("gso is conflict with detour") return nil, E.New("gso is conflict with detour")
} }
options.IsWireGuardListener = true outboundDialer, err := dialer.New(ctx, options.DialerOptions)
outbound.useStdNetBind = true
}
listener, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outbound.listener = listener wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
var privateKey string Context: ctx,
{ Logger: logger,
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) System: options.SystemInterface,
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
BindInterface: interfaceName,
}))
},
Name: options.InterfaceName,
MTU: options.MTU,
GSO: options.GSO,
Address: options.LocalAddress,
PrivateKey: options.PrivateKey,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
return endpointAddresses[0], nil
},
Peers: common.Map(options.Peers, func(it option.LegacyWireGuardPeer) wireguard.PeerOptions {
return wireguard.PeerOptions{
Endpoint: it.ServerOptions.Build(),
PublicKey: it.PublicKey,
PreSharedKey: it.PreSharedKey,
AllowedIPs: it.AllowedIPs,
// PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval),
Reserved: it.Reserved,
}
}),
Workers: options.Workers,
})
if err != nil { if err != nil {
return nil, E.Cause(err, "decode private key") return nil, err
} }
privateKey = hex.EncodeToString(bytes) outbound.endpoint = wgEndpoint
}
outbound.ipcConf = "private_key=" + privateKey
mtu := options.MTU
if mtu == 0 {
mtu = 1408
}
var wireTunDevice wireguard.Device
if !options.SystemInterface && tun.WithGVisor {
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
} else {
wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
}
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
outbound.tunDevice = wireTunDevice
return outbound, nil return outbound, nil
} }
func (w *Outbound) Start() error { func (o *Outbound) Start(stage adapter.StartStage) error {
if common.Any(w.peers, func(peer wireguard.PeerConfig) bool { switch stage {
return !peer.Endpoint.IsValid() case adapter.StartStateStart:
}) { return o.endpoint.Start(false)
// wait for all outbounds to be started and continue in PortStart case adapter.StartStatePostStart:
return nil return o.endpoint.Start(true)
}
return w.start()
}
func (w *Outbound) PostStart() error {
if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
return peer.Endpoint.IsValid()
}) {
return nil
}
return w.start()
}
func (w *Outbound) start() error {
err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
if err != nil {
return err
}
var bind conn.Bind
if w.useStdNetBind {
bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
peerLen := len(w.peers)
if peerLen == 1 {
isConnect = true
connectAddr = w.peers[0].Endpoint
reserved = w.peers[0].Reserved
}
bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
}
err = w.tunDevice.Start()
if err != nil {
return err
}
wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) {
w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}, w.workers)
ipcConf := w.ipcConf
for _, peer := range w.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
w.device = wgDevice
w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
return nil
}
func (w *Outbound) Close() error {
if w.device != nil {
w.device.Close()
}
if w.pauseCallback != nil {
w.pauseManager.UnregisterCallback(w.pauseCallback)
} }
return nil return nil
} }
func (w *Outbound) InterfaceUpdated() { func (o *Outbound) Close() error {
w.device.BindUpdate() return o.endpoint.Close()
}
func (o *Outbound) InterfaceUpdated() {
o.endpoint.BindUpdate()
return return
} }
func (w *Outbound) onPauseUpdated(event int) { func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch event {
case pause.EventDevicePaused:
w.device.Down()
case pause.EventDeviceWake:
w.device.Up()
}
}
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network { switch network {
case N.NetworkTCP: case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination) o.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP: case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
} }
if destination.IsFqdn() { if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses) return N.DialSerial(ctx, o.endpoint, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
} }
return w.tunDevice.DialContext(ctx, network, destination) return o.endpoint.DialContext(ctx, network, destination)
} }
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() { if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses) packetConn, _, err := N.ListenSerial(ctx, o.endpoint, destination, destinationAddresses)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return packetConn, err return packetConn, err
} }
return w.tunDevice.ListenPacket(ctx, destination) return o.endpoint.ListenPacket(ctx, destination)
} }

View file

@ -41,15 +41,15 @@ type NetworkManager struct {
autoDetectInterface bool autoDetectInterface bool
defaultOptions adapter.NetworkOptions defaultOptions adapter.NetworkOptions
autoRedirectOutputMark uint32 autoRedirectOutputMark uint32
networkMonitor tun.NetworkUpdateMonitor networkMonitor tun.NetworkUpdateMonitor
interfaceMonitor tun.DefaultInterfaceMonitor interfaceMonitor tun.DefaultInterfaceMonitor
packageManager tun.PackageManager packageManager tun.PackageManager
powerListener winpowrprof.EventListener powerListener winpowrprof.EventListener
pauseManager pause.Manager pauseManager pause.Manager
platformInterface platform.Interface platformInterface platform.Interface
inboundManager adapter.InboundManager endpoint adapter.EndpointManager
outboundManager adapter.OutboundManager inbound adapter.InboundManager
outbound adapter.OutboundManager
wifiState adapter.WIFIState wifiState adapter.WIFIState
started bool started bool
} }
@ -69,7 +69,9 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
}, },
pauseManager: service.FromContext[pause.Manager](ctx), pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx), platformInterface: service.FromContext[platform.Interface](ctx),
outboundManager: service.FromContext[adapter.OutboundManager](ctx), endpoint: service.FromContext[adapter.EndpointManager](ctx),
inbound: service.FromContext[adapter.InboundManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
} }
if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault { if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault {
if routeOptions.DefaultInterface != "" { if routeOptions.DefaultInterface != "" {
@ -358,14 +360,21 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
func (r *NetworkManager) ResetNetwork() { func (r *NetworkManager) ResetNetwork() {
conntrack.Close() conntrack.Close()
for _, inbound := range r.inboundManager.Inbounds() { for _, endpoint := range r.endpoint.Endpoints() {
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()
}
}
for _, inbound := range r.inbound.Inbounds() {
listener, isListener := inbound.(adapter.InterfaceUpdateListener) listener, isListener := inbound.(adapter.InterfaceUpdateListener)
if isListener { if isListener {
listener.InterfaceUpdated() listener.InterfaceUpdated()
} }
} }
for _, outbound := range r.outboundManager.Outbounds() { for _, outbound := range r.outbound.Outbounds() {
listener, isListener := outbound.(adapter.InterfaceUpdateListener) listener, isListener := outbound.(adapter.InterfaceUpdateListener)
if isListener { if isListener {
listener.InterfaceUpdated() listener.InterfaceUpdated()

View file

@ -11,7 +11,7 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
R "github.com/sagernet/sing-box/route/rule" R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
tun "github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/cache" "github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"

View file

@ -32,7 +32,7 @@ func TestMain(m *testing.M) {
var globalCtx context.Context var globalCtx context.Context
func init() { func init() {
globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
} }
func startInstance(t *testing.T, options option.Options) *box.Box { func startInstance(t *testing.T, options option.Options) *box.Box {

View file

@ -37,12 +37,12 @@ func _TestWireGuard(t *testing.T) {
Outbounds: []option.Outbound{ Outbounds: []option.Outbound{
{ {
Type: C.TypeWireGuard, Type: C.TypeWireGuard,
Options: &option.WireGuardOutboundOptions{ Options: &option.WireGuardEndpointOptions{
ServerOptions: option.ServerOptions{ ServerOptions: option.ServerOptions{
Server: "127.0.0.1", Server: "127.0.0.1",
ServerPort: serverPort, ServerPort: serverPort,
}, },
LocalAddress: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")}, Address: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=", PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=", PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
}, },

View file

@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
select { select {
case <-c.done: case <-c.done:
default: default:
c.logger.Error(context.Background(), E.Cause(err, "read packet")) c.logger.Error(E.Cause(err, "read packet"))
err = nil err = nil
} }
return return
@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
b := packets[0] b := packets[0]
common.ClearArray(b[1:4]) common.ClearArray(b[1:4])
} }
eps[0] = Endpoint(M.AddrPortFromNet(addr)) eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
count = 1 count = 1
return return
} }
@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
time.Sleep(time.Second) time.Sleep(time.Second)
return err return err
} }
destination := netip.AddrPort(ep.(Endpoint)) destination := netip.AddrPort(ep.(remoteEndpoint))
for _, b := range bufs { for _, b := range bufs {
if len(b) > 3 { if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination] reserved, loaded := c.reservedForEndpoint[destination]
@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Endpoint(ap), nil return remoteEndpoint(ap), nil
} }
func (c *ClientBind) BatchSize() int { func (c *ClientBind) BatchSize() int {
@ -229,3 +229,31 @@ func (w *wireConn) Close() error {
close(w.done) close(w.done)
return nil return nil
} }
var _ conn.Endpoint = (*remoteEndpoint)(nil)
type remoteEndpoint netip.AddrPort
func (e remoteEndpoint) ClearSrc() {
}
func (e remoteEndpoint) SrcToString() string {
return ""
}
func (e remoteEndpoint) DstToString() string {
return (netip.AddrPort)(e).String()
}
func (e remoteEndpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
}
func (e remoteEndpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
}
func (e remoteEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}

View file

@ -1,13 +1,44 @@
package wireguard package wireguard
import ( import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/tun" "github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
) )
type Device interface { type Device interface {
tun.Device wgTun.Device
N.Dialer N.Dialer
Start() error Start() error
// NewEndpoint() (stack.LinkEndpoint, error) SetDevice(device *device.Device)
}
type DeviceOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
AllowedAddress []netip.Prefix
}
func NewDevice(options DeviceOptions) (Device, error) {
if !options.System {
return newStackDevice(options)
} else if options.Handler == nil {
return newSystemDevice(options)
} else {
return newSystemStackDevice(options)
}
} }

View file

@ -5,7 +5,6 @@ package wireguard
import ( import (
"context" "context"
"net" "net"
"net/netip"
"os" "os"
"github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/buffer"
@ -15,52 +14,41 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun" wgTun "github.com/sagernet/wireguard-go/tun"
) )
var _ Device = (*StackDevice)(nil) var _ Device = (*stackDevice)(nil)
const defaultNIC tcpip.NICID = 1 type stackDevice struct {
type StackDevice struct {
stack *stack.Stack stack *stack.Stack
mtu uint32 mtu uint32
events chan wgTun.Event events chan wgTun.Event
outbound chan *stack.PacketBuffer outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{} done chan struct{}
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
addr4 tcpip.Address addr4 tcpip.Address
addr6 tcpip.Address addr6 tcpip.Address
} }
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { func newStackDevice(options DeviceOptions) (*stackDevice, error) {
ipStack := stack.New(stack.Options{ tunDevice := &stackDevice{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, mtu: options.MTU,
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
events: make(chan wgTun.Event, 1), events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256), outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}), done: make(chan struct{}),
} }
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
if err != nil { if err != nil {
return nil, E.New(err.String()) return nil, err
} }
for _, prefix := range localAddresses { for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr()) addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
tunDevice.addr6 = addr tunDevice.addr6 = addr
protoAddr.Protocol = ipv6.ProtocolNumber protoAddr.Protocol = ipv6.ProtocolNumber
} }
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{}) gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if err != nil { if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
} }
} }
sOpt := tcpip.TCPSACKEnabled(true) tunDevice.stack = ipStack
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) if options.Handler != nil {
cOpt := tcpip.CongestionControlOption("cubic") ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC}) }
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
return tunDevice, nil return tunDevice, nil
} }
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) { func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return (*wireEndpoint)(w), nil
}
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
addr := tcpip.FullAddress{ addr := tcpip.FullAddress{
NIC: defaultNIC, NIC: tun.DefaultNIC,
Port: destination.Port, Port: destination.Port,
Addr: tun.AddressFromAddr(destination.Addr), Addr: tun.AddressFromAddr(destination.Addr),
} }
bind := tcpip.FullAddress{ bind := tcpip.FullAddress{
NIC: defaultNIC, NIC: tun.DefaultNIC,
} }
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() { if destination.IsIPv4() {
@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
} }
} }
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
bind := tcpip.FullAddress{ bind := tcpip.FullAddress{
NIC: defaultNIC, NIC: tun.DefaultNIC,
} }
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() { if destination.IsIPv4() {
@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil return udpConn, nil
} }
func (w *StackDevice) Inet4Address() netip.Addr { func (w *stackDevice) SetDevice(device *device.Device) {
return tun.AddrFromAddress(w.addr4)
} }
func (w *StackDevice) Inet6Address() netip.Addr { func (w *stackDevice) Start() error {
return tun.AddrFromAddress(w.addr6)
}
func (w *StackDevice) Start() error {
w.events <- wgTun.EventUp w.events <- wgTun.EventUp
return nil return nil
} }
func (w *StackDevice) File() *os.File { func (w *stackDevice) File() *os.File {
return nil return nil
} }
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select { select {
case packetBuffer, ok := <-w.outbound: case packetBuffer, ok := <-w.outbound:
if !ok { if !ok {
@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
sizes[0] = n sizes[0] = n
count = 1 count = 1
return return
case packet := <-w.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
count = 1
return
case <-w.done: case <-w.done:
return 0, os.ErrClosed return 0, os.ErrClosed
} }
} }
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) { func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs { for _, b := range bufs {
b = b[offset:] b = b[offset:]
if len(b) == 0 { if len(b) == 0 {
@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return return
} }
func (w *StackDevice) Flush() error { func (w *stackDevice) Flush() error {
return nil return nil
} }
func (w *StackDevice) MTU() (int, error) { func (w *stackDevice) MTU() (int, error) {
return int(w.mtu), nil return int(w.mtu), nil
} }
func (w *StackDevice) Name() (string, error) { func (w *stackDevice) Name() (string, error) {
return "sing-box", nil return "sing-box", nil
} }
func (w *StackDevice) Events() <-chan wgTun.Event { func (w *stackDevice) Events() <-chan wgTun.Event {
return w.events return w.events
} }
func (w *StackDevice) Close() error { func (w *stackDevice) Close() error {
close(w.done) close(w.done)
close(w.events) close(w.events)
w.stack.Close() w.stack.Close()
@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
return nil return nil
} }
func (w *StackDevice) BatchSize() int { func (w *stackDevice) BatchSize() int {
return 1 return 1
} }
var _ stack.LinkEndpoint = (*wireEndpoint)(nil) var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint StackDevice type wireEndpoint stackDevice
func (ep *wireEndpoint) MTU() uint32 { func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu return ep.mtu

View file

@ -2,12 +2,12 @@
package wireguard package wireguard
import ( import "github.com/sagernet/sing-tun"
"net/netip"
"github.com/sagernet/sing-tun" func newStackDevice(options DeviceOptions) (Device, error) {
) return nil, tun.ErrGVisorNotIncluded
}
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
func newSystemStackDevice(options DeviceOptions) (Device, error) {
return nil, tun.ErrGVisorNotIncluded return nil, tun.ErrGVisorNotIncluded
} }

View file

@ -6,96 +6,88 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"runtime"
"sync" "sync"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun" wgTun "github.com/sagernet/wireguard-go/tun"
) )
var _ Device = (*SystemDevice)(nil) var _ Device = (*systemDevice)(nil)
type SystemDevice struct { type systemDevice struct {
options DeviceOptions
dialer N.Dialer dialer N.Dialer
device tun.Tun device tun.Tun
batchDevice tun.LinuxTUN batchDevice tun.LinuxTUN
name string
mtu uint32
inet4Addresses []netip.Prefix
inet6Addresses []netip.Prefix
gso bool
events chan wgTun.Event events chan wgTun.Event
closeOnce sync.Once closeOnce sync.Once
} }
func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) { func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
var inet4Addresses []netip.Prefix if options.Name == "" {
var inet6Addresses []netip.Prefix options.Name = tun.CalculateInterfaceName("wg")
for _, prefixes := range localPrefixes {
if prefixes.Addr().Is4() {
inet4Addresses = append(inet4Addresses, prefixes)
} else {
inet6Addresses = append(inet6Addresses, prefixes)
} }
} return &systemDevice{
if interfaceName == "" { options: options,
interfaceName = tun.CalculateInterfaceName("wg") dialer: options.CreateDialer(options.Name),
}
return &SystemDevice{
dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{
BindInterface: interfaceName,
})),
name: interfaceName,
mtu: mtu,
inet4Addresses: inet4Addresses,
inet6Addresses: inet6Addresses,
gso: gso,
events: make(chan wgTun.Event, 1), events: make(chan wgTun.Event, 1),
}, nil }, nil
} }
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return w.dialer.DialContext(ctx, network, destination) return w.dialer.DialContext(ctx, network, destination)
} }
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return w.dialer.ListenPacket(ctx, destination) return w.dialer.ListenPacket(ctx, destination)
} }
func (w *SystemDevice) Inet4Address() netip.Addr { func (w *systemDevice) SetDevice(device *device.Device) {
if len(w.inet4Addresses) == 0 {
return netip.Addr{}
}
return w.inet4Addresses[0].Addr()
} }
func (w *SystemDevice) Inet6Address() netip.Addr { func (w *systemDevice) Start() error {
if len(w.inet6Addresses) == 0 { networkManager := service.FromContext[adapter.NetworkManager](w.options.Context)
return netip.Addr{} tunOptions := tun.Options{
Name: w.options.Name,
Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is6()
}),
MTU: w.options.MTU,
GSO: w.options.GSO,
InterfaceScope: true,
Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }),
InterfaceMonitor: networkManager.InterfaceMonitor(),
InterfaceFinder: networkManager.InterfaceFinder(),
} }
return w.inet6Addresses[0].Addr() // works with Linux, macOS with IFSCOPE routes, not tested on Windows
if runtime.GOOS == "darwin" {
tunOptions.AutoRoute = true
} }
tunInterface, err := tun.New(tunOptions)
func (w *SystemDevice) Start() error {
tunInterface, err := tun.New(tun.Options{
Name: w.name,
Inet4Address: w.inet4Addresses,
Inet6Address: w.inet6Addresses,
MTU: w.mtu,
GSO: w.gso,
})
if err != nil { if err != nil {
return err return err
} }
err = tunInterface.Start()
if err != nil {
return err
}
w.options.Logger.Info("started at ", w.options.Name)
w.device = tunInterface w.device = tunInterface
if w.gso { if w.options.GSO {
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN) batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
if !isBatchTUN { if !isBatchTUN {
tunInterface.Close() tunInterface.Close()
@ -107,15 +99,15 @@ func (w *SystemDevice) Start() error {
return nil return nil
} }
func (w *SystemDevice) File() *os.File { func (w *systemDevice) File() *os.File {
return nil return nil
} }
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
if w.batchDevice != nil { if w.batchDevice != nil {
count, err = w.batchDevice.BatchRead(bufs, offset, sizes) count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes)
} else { } else {
sizes[0], err = w.device.Read(bufs[0][offset:]) sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
if err == nil { if err == nil {
count = 1 count = 1
} else if errors.Is(err, tun.ErrTooManySegments) { } else if errors.Is(err, tun.ErrTooManySegments) {
@ -125,12 +117,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int,
return return
} }
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) { func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil { if w.batchDevice != nil {
return 0, w.batchDevice.BatchWrite(bufs, offset) return 0, w.batchDevice.BatchWrite(bufs, offset)
} else { } else {
for _, b := range bufs { for _, packet := range bufs {
_, err = w.device.Write(b[offset:]) if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
if err != nil { if err != nil {
return return
} }
@ -140,28 +136,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return return
} }
func (w *SystemDevice) Flush() error { func (w *systemDevice) Flush() error {
return nil return nil
} }
func (w *SystemDevice) MTU() (int, error) { func (w *systemDevice) MTU() (int, error) {
return int(w.mtu), nil return int(w.options.MTU), nil
} }
func (w *SystemDevice) Name() (string, error) { func (w *systemDevice) Name() (string, error) {
return w.name, nil return w.options.Name, nil
} }
func (w *SystemDevice) Events() <-chan wgTun.Event { func (w *systemDevice) Events() <-chan wgTun.Event {
return w.events return w.events
} }
func (w *SystemDevice) Close() error { func (w *systemDevice) Close() error {
close(w.events) close(w.events)
return w.device.Close() return w.device.Close()
} }
func (w *SystemDevice) BatchSize() int { func (w *systemDevice) BatchSize() int {
if w.batchDevice != nil { if w.batchDevice != nil {
return w.batchDevice.BatchSize() return w.batchDevice.BatchSize()
} }

View file

@ -0,0 +1,182 @@
//go:build with_gvisor
package wireguard
import (
"net/netip"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/wireguard-go/device"
)
var _ Device = (*systemStackDevice)(nil)
type systemStackDevice struct {
*systemDevice
stack *stack.Stack
endpoint *deviceEndpoint
writeBufs [][]byte
}
func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
system, err := newSystemDevice(options)
if err != nil {
return nil, err
}
endpoint := &deviceEndpoint{
mtu: options.MTU,
done: make(chan struct{}),
}
ipStack, err := tun.NewGVisorStack(endpoint)
if err != nil {
return nil, err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
return &systemStackDevice{
systemDevice: system,
stack: ipStack,
endpoint: endpoint,
}, nil
}
func (w *systemStackDevice) SetDevice(device *device.Device) {
w.endpoint.device = device
}
func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil {
w.writeBufs = w.writeBufs[:0]
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
w.writeBufs = append(w.writeBufs, packet)
}
}
if len(w.writeBufs) > 0 {
return 0, w.batchDevice.BatchWrite(bufs, offset)
}
} else {
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
}
if err != nil {
return
}
}
}
// WireGuard will not read count
return
}
func (w *systemStackDevice) Close() error {
close(w.endpoint.done)
w.stack.Close()
for _, endpoint := range w.stack.CleanupEndpoints() {
endpoint.Abort()
}
w.stack.Wait()
return w.systemDevice.Close()
}
func (w *systemStackDevice) writeStack(packet []byte) bool {
var (
networkProtocol tcpip.NetworkProtocolNumber
destination netip.Addr
)
switch header.IPVersion(packet) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4())
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16())
}
for _, prefix := range w.options.Address {
if prefix.Contains(destination) {
return false
}
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
packetBuffer.DecRef()
return true
}
type deviceEndpoint struct {
mtu uint32
done chan struct{}
device *device.Device
dispatcher stack.NetworkDispatcher
}
func (ep *deviceEndpoint) MTU() uint32 {
return ep.mtu
}
func (ep *deviceEndpoint) SetMTU(mtu uint32) {
}
func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityRXChecksumOffload
}
func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
ep.dispatcher = dispatcher
}
func (ep *deviceEndpoint) IsAttached() bool {
return ep.dispatcher != nil
}
func (ep *deviceEndpoint) Wait() {
}
func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
return true
}
func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
destination := packetBuffer.Network().DestinationAddress()
ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
}
return list.Len(), nil
}
func (ep *deviceEndpoint) Close() {
}
func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
}

View file

@ -1,35 +1,255 @@
package wireguard package wireguard
import ( import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip" "net/netip"
"os"
"strings"
"time"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn" "github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
"go4.org/netipx"
) )
var _ conn.Endpoint = (*Endpoint)(nil) type Endpoint struct {
options EndpointOptions
type Endpoint netip.AddrPort peers []peerConfig
ipcConf string
func (e Endpoint) ClearSrc() { allowedAddress []netip.Prefix
tunDevice Device
device *device.Device
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
} }
func (e Endpoint) SrcToString() string { func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
return "" if options.PrivateKey == "" {
return nil, E.New("missing private key")
}
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey := hex.EncodeToString(privateKeyBytes)
ipcConf := "private_key=" + privateKey
if options.ListenPort != 0 {
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
}
var peers []peerConfig
for peerIndex, rawPeer := range options.Peers {
peer := peerConfig{
allowedIPs: rawPeer.AllowedIPs,
keepalive: rawPeer.PersistentKeepaliveInterval,
}
if !rawPeer.Endpoint.IsValid() {
return nil, E.New("invalid endpoint for peer ", peerIndex, ": ", rawPeer.Endpoint)
} else if rawPeer.Endpoint.Addr.IsValid() {
peer.endpoint = rawPeer.Endpoint.AddrPort()
} else {
peer.destination = rawPeer.Endpoint
}
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
if rawPeer.PreSharedKey != "" {
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
}
copy(peer.reserved[:], rawPeer.Reserved[:])
}
peers = append(peers, peer)
}
var allowedPrefixBuilder netipx.IPSetBuilder
for _, peer := range options.Peers {
for _, prefix := range peer.AllowedIPs {
allowedPrefixBuilder.AddPrefix(prefix)
}
}
allowedIPSet, err := allowedPrefixBuilder.IPSet()
if err != nil {
return nil, err
}
allowedAddresses := allowedIPSet.Prefixes()
if options.MTU == 0 {
options.MTU = 1408
}
deviceOptions := DeviceOptions{
Context: options.Context,
Logger: options.Logger,
System: options.System,
Handler: options.Handler,
UDPTimeout: options.UDPTimeout,
CreateDialer: options.CreateDialer,
Name: options.Name,
MTU: options.MTU,
GSO: options.GSO,
Address: options.Address,
AllowedAddress: allowedAddresses,
}
tunDevice, err := NewDevice(deviceOptions)
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
return &Endpoint{
options: options,
peers: peers,
ipcConf: ipcConf,
allowedAddress: allowedAddresses,
tunDevice: tunDevice,
}, nil
} }
func (e Endpoint) DstToString() string { func (e *Endpoint) Start(resolve bool) error {
return (netip.AddrPort)(e).String() if common.Any(e.peers, func(peer peerConfig) bool {
return !peer.endpoint.IsValid()
}) {
if !resolve {
return nil
}
for peerIndex, peer := range e.peers {
if peer.endpoint.IsValid() {
continue
}
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
if err != nil {
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
}
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
}
} else if resolve {
return nil
}
var bind conn.Bind
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
if isWgListener {
bind = conn.NewStdNetBind(wgListener)
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
peerLen := len(e.peers)
if peerLen == 1 {
isConnect = true
connectAddr = e.peers[0].endpoint
reserved = e.peers[0].reserved
}
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
}
err := e.tunDevice.Start()
if err != nil {
return err
}
logger := &device.Logger{
Verbosef: func(format string, args ...interface{}) {
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
e.tunDevice.SetDevice(wgDevice)
ipcConf := e.ipcConf
for _, peer := range e.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
e.device = wgDevice
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
if e.pauseManager != nil {
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
}
return nil
} }
func (e Endpoint) DstToBytes() []byte { func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
b, _ := (netip.AddrPort)(e).MarshalBinary() if !destination.Addr.IsValid() {
return b return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.DialContext(ctx, network, destination)
} }
func (e Endpoint) DstIP() netip.Addr { func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return (netip.AddrPort)(e).Addr() if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.ListenPacket(ctx, destination)
} }
func (e Endpoint) SrcIP() netip.Addr { func (e *Endpoint) BindUpdate() error {
return netip.Addr{} return e.device.BindUpdate()
}
func (e *Endpoint) Close() error {
if e.device != nil {
e.device.Close()
}
if e.pauseCallback != nil {
e.pauseManager.UnregisterCallback(e.pauseCallback)
}
return nil
}
func (e *Endpoint) onPauseUpdated(event int) {
switch event {
case pause.EventDevicePaused:
e.device.Down()
case pause.EventDeviceWake:
e.device.Up()
}
}
type peerConfig struct {
destination M.Socksaddr
endpoint netip.AddrPort
publicKeyHex string
preSharedKeyHex string
allowedIPs []netip.Prefix
keepalive time.Duration
reserved [3]uint8
}
func (c peerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.publicKeyHex
ipcLines += "\nendpoint=" + c.endpoint.String()
if c.preSharedKeyHex != "" {
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
}
for _, allowedIP := range c.allowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP.String()
}
if c.keepalive > 0 {
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(int(c.keepalive.Seconds()))
}
return ipcLines
} }

View file

@ -0,0 +1,40 @@
package wireguard
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type EndpointOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
Dialer N.Dialer
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
PrivateKey string
ListenPort uint16
ResolvePeer func(domain string) (netip.Addr, error)
Peers []PeerOptions
Workers int
}
type PeerOptions struct {
Endpoint M.Socksaddr
PublicKey string
PreSharedKey string
AllowedIPs []netip.Prefix
PersistentKeepaliveInterval time.Duration
Reserved []uint8
}

View file

@ -1,148 +0,0 @@
package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"net/netip"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type PeerConfig struct {
destination M.Socksaddr
domainStrategy dns.DomainStrategy
Endpoint netip.AddrPort
PublicKey string
PreSharedKey string
AllowedIPs []string
Reserved [3]uint8
}
func (c PeerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.PublicKey
ipcLines += "\nendpoint=" + c.Endpoint.String()
if c.PreSharedKey != "" {
ipcLines += "\npreshared_key=" + c.PreSharedKey
}
for _, allowedIP := range c.AllowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP
}
return ipcLines
}
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
var peers []PeerConfig
if len(options.Peers) > 0 {
for peerIndex, rawPeer := range options.Peers {
peer := PeerConfig{
AllowedIPs: rawPeer.AllowedIPs,
}
destination := rawPeer.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if rawPeer.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed_ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
} else {
peer := PeerConfig{}
var (
addressHas4 bool
addressHas6 bool
)
for _, localAddress := range options.LocalAddress {
if localAddress.Addr().Is4() {
addressHas4 = true
} else {
addressHas6 = true
}
}
if addressHas4 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
}
if addressHas6 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
}
destination := options.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
if err != nil {
return nil, E.Cause(err, "decode peer public key")
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if options.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key")
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(options.Reserved) > 0 {
if len(options.Reserved) != 3 {
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
return peers, nil
}
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
for peerIndex, peer := range peers {
if peer.Endpoint.IsValid() {
continue
}
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
if err != nil {
if len(peers) == 1 {
return E.Cause(err, "resolve endpoint domain")
} else {
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
}
}
if len(destinationAddresses) == 0 {
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
}
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
}
return nil
}