refactor: WireGuard endpoint

This commit is contained in:
世界 2024-11-21 18:10:41 +08:00
parent 445c0ff9cb
commit fb02440650
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
89 changed files with 2187 additions and 679 deletions

28
adapter/endpoint.go Normal file
View file

@ -0,0 +1,28 @@
package adapter
import (
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type Endpoint interface {
Lifecycle
Type() string
Tag() string
Outbound
}
type EndpointRegistry interface {
option.EndpointOptionsRegistry
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) (Endpoint, error)
}
type EndpointManager interface {
Lifecycle
Endpoints() []Endpoint
Get(tag string) (Endpoint, bool)
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) error
}

View file

@ -0,0 +1,43 @@
package endpoint
import "github.com/sagernet/sing-box/option"
type Adapter struct {
endpointType string
endpointTag string
network []string
dependencies []string
}
func NewAdapter(endpointType string, endpointTag string, network []string, dependencies []string) Adapter {
return Adapter{
endpointType: endpointType,
endpointTag: endpointTag,
network: network,
dependencies: dependencies,
}
}
func NewAdapterWithDialerOptions(endpointType string, endpointTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string
if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour}
}
return NewAdapter(endpointType, endpointTag, network, dependencies)
}
func (a *Adapter) Type() string {
return a.endpointType
}
func (a *Adapter) Tag() string {
return a.endpointTag
}
func (a *Adapter) Network() []string {
return a.network
}
func (a *Adapter) Dependencies() []string {
return a.dependencies
}

147
adapter/endpoint/manager.go Normal file
View file

@ -0,0 +1,147 @@
package endpoint
import (
"context"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
var _ adapter.EndpointManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.EndpointRegistry
access sync.Mutex
started bool
stage adapter.StartStage
endpoints []adapter.Endpoint
endpointByTag map[string]adapter.Endpoint
}
func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
endpointByTag: make(map[string]adapter.Endpoint),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
if stage == adapter.StartStateStart {
// started with outbound manager
return nil
}
for _, endpoint := range m.endpoints {
err := adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if !m.started {
return nil
}
m.started = false
endpoints := m.endpoints
m.endpoints = nil
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, endpoint := range endpoints {
monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
err = E.Append(err, endpoint.Close(), func(err error) error {
return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
})
monitor.Finish()
}
return nil
}
func (m *Manager) Endpoints() []adapter.Endpoint {
m.access.Lock()
defer m.access.Unlock()
return m.endpoints
}
func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
m.access.Lock()
defer m.access.Unlock()
endpoint, found := m.endpointByTag[tag]
return endpoint, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
endpoint, found := m.endpointByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.endpointByTag, tag)
index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == endpoint
})
if index == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return endpoint.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
}
if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
if m.started {
err = existsEndpoint.Close()
if err != nil {
return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
}
}
existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == existsEndpoint
})
if existsIndex == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
}
m.endpoints = append(m.endpoints, endpoint)
m.endpointByTag[tag] = endpoint
return nil
}

View file

@ -0,0 +1,72 @@
package endpoint
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options T) (adapter.Endpoint, error)
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any {
return new(Options)
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Endpoint, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.EndpointRegistry = (*Registry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Endpoint, error)
)
type Registry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructor map[string]constructorFunc
}
func NewRegistry() *Registry {
return &Registry{
optionsType: make(map[string]optionsConstructorFunc),
constructor: make(map[string]constructorFunc),
}
}
func (m *Registry) CreateOptions(outboundType string) (any, bool) {
m.access.Lock()
defer m.access.Unlock()
optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Endpoint, error) {
m.access.Lock()
defer m.access.Unlock()
constructor, loaded := m.constructor[outboundType]
if !loaded {
return nil, E.New("outbound type not found: " + outboundType)
}
return constructor(ctx, router, logger, tag, options)
}
func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
m.access.Lock()
defer m.access.Unlock()
m.optionsType[outboundType] = optionsConstructor
m.constructor[outboundType] = constructor
}

View file

@ -13,7 +13,7 @@ import (
) )
type Inbound interface { type Inbound interface {
Service Lifecycle
Type() string Type() string
Tag() string Tag() string
} }

View file

@ -18,6 +18,7 @@ var _ adapter.InboundManager = (*Manager)(nil)
type Manager struct { type Manager struct {
logger log.ContextLogger logger log.ContextLogger
registry adapter.InboundRegistry registry adapter.InboundRegistry
endpoint adapter.EndpointManager
access sync.Mutex access sync.Mutex
started bool started bool
stage adapter.StartStage stage adapter.StartStage
@ -25,10 +26,11 @@ type Manager struct {
inboundByTag map[string]adapter.Inbound inboundByTag map[string]adapter.Inbound
} }
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager { func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endpoint adapter.EndpointManager) *Manager {
return &Manager{ return &Manager{
logger: logger, logger: logger,
registry: registry, registry: registry,
endpoint: endpoint,
inboundByTag: make(map[string]adapter.Inbound), inboundByTag: make(map[string]adapter.Inbound),
} }
} }
@ -79,9 +81,12 @@ func (m *Manager) Inbounds() []adapter.Inbound {
func (m *Manager) Get(tag string) (adapter.Inbound, bool) { func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
m.access.Lock() m.access.Lock()
defer m.access.Unlock()
inbound, found := m.inboundByTag[tag] inbound, found := m.inboundByTag[tag]
return inbound, found m.access.Unlock()
if found {
return inbound, true
}
return m.endpoint.Get(tag)
} }
func (m *Manager) Remove(tag string) error { func (m *Manager) Remove(tag string) error {

View file

@ -1,6 +1,9 @@
package adapter package adapter
func LegacyStart(starter any, stage StartStage) error { func LegacyStart(starter any, stage StartStage) error {
if lifecycle, isLifecycle := starter.(Lifecycle); isLifecycle {
return lifecycle.Start(stage)
}
switch stage { switch stage {
case StartStateInitialize: case StartStateInitialize:
if preStarter, isPreStarter := starter.(interface { if preStarter, isPreStarter := starter.(interface {

View file

@ -5,35 +5,35 @@ import (
) )
type Adapter struct { type Adapter struct {
protocol string outboundType string
outboundTag string
network []string network []string
tag string
dependencies []string dependencies []string
} }
func NewAdapter(protocol string, network []string, tag string, dependencies []string) Adapter { func NewAdapter(outboundType string, outboundTag string, network []string, dependencies []string) Adapter {
return Adapter{ return Adapter{
protocol: protocol, outboundType: outboundType,
outboundTag: outboundTag,
network: network, network: network,
tag: tag,
dependencies: dependencies, dependencies: dependencies,
} }
} }
func NewAdapterWithDialerOptions(protocol string, network []string, tag string, dialOptions option.DialerOptions) Adapter { func NewAdapterWithDialerOptions(outboundType string, outboundTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string var dependencies []string
if dialOptions.Detour != "" { if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour} dependencies = []string{dialOptions.Detour}
} }
return NewAdapter(protocol, network, tag, dependencies) return NewAdapter(outboundType, outboundTag, network, dependencies)
} }
func (a *Adapter) Type() string { func (a *Adapter) Type() string {
return a.protocol return a.outboundType
} }
func (a *Adapter) Tag() string { func (a *Adapter) Tag() string {
return a.tag return a.outboundTag
} }
func (a *Adapter) Network() []string { func (a *Adapter) Network() []string {

View file

@ -21,6 +21,7 @@ var _ adapter.OutboundManager = (*Manager)(nil)
type Manager struct { type Manager struct {
logger log.ContextLogger logger log.ContextLogger
registry adapter.OutboundRegistry registry adapter.OutboundRegistry
endpoint adapter.EndpointManager
defaultTag string defaultTag string
access sync.Mutex access sync.Mutex
started bool started bool
@ -32,10 +33,11 @@ type Manager struct {
defaultOutboundFallback adapter.Outbound defaultOutboundFallback adapter.Outbound
} }
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager { func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
return &Manager{ return &Manager{
logger: logger, logger: logger,
registry: registry, registry: registry,
endpoint: endpoint,
defaultTag: defaultTag, defaultTag: defaultTag,
outboundByTag: make(map[string]adapter.Outbound), outboundByTag: make(map[string]adapter.Outbound),
dependByTag: make(map[string][]string), dependByTag: make(map[string][]string),
@ -56,7 +58,14 @@ func (m *Manager) Start(stage adapter.StartStage) error {
outbounds := m.outbounds outbounds := m.outbounds
m.access.Unlock() m.access.Unlock()
if stage == adapter.StartStateStart { if stage == adapter.StartStateStart {
return m.startOutbounds(outbounds) if m.defaultTag != "" && m.defaultOutbound == nil {
defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
if !loaded {
return E.New("default outbound not found: ", m.defaultTag)
}
m.defaultOutbound = defaultEndpoint
}
return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
} else { } else {
for _, outbound := range outbounds { for _, outbound := range outbounds {
err := adapter.LegacyStart(outbound, stage) err := adapter.LegacyStart(outbound, stage)
@ -87,7 +96,14 @@ func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
} }
started[outboundTag] = true started[outboundTag] = true
canContinue = true canContinue = true
if starter, isStarter := outboundToStart.(interface { if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start(adapter.StartStateStart)
monitor.Finish()
if err != nil {
return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
} else if starter, isStarter := outboundToStart.(interface {
Start() error Start() error
}); isStarter { }); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]") monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
@ -160,9 +176,12 @@ func (m *Manager) Outbounds() []adapter.Outbound {
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock() m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag] outbound, found := m.outboundByTag[tag]
return outbound, found m.access.Unlock()
if found {
return outbound, true
}
return m.endpoint.Get(tag)
} }
func (m *Manager) Default() adapter.Outbound { func (m *Manager) Default() adapter.Outbound {

51
box.go
View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
@ -36,6 +37,7 @@ type Box struct {
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
network *route.NetworkManager network *route.NetworkManager
endpoint *endpoint.Manager
inbound *inbound.Manager inbound *inbound.Manager
outbound *outbound.Manager outbound *outbound.Manager
connection *route.ConnectionManager connection *route.ConnectionManager
@ -54,6 +56,7 @@ func Context(
ctx context.Context, ctx context.Context,
inboundRegistry adapter.InboundRegistry, inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry, outboundRegistry adapter.OutboundRegistry,
endpointRegistry adapter.EndpointRegistry,
) context.Context { ) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil { service.FromContext[adapter.InboundRegistry](ctx) == nil {
@ -65,6 +68,11 @@ func Context(
ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry) ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry)
ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry) ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry)
} }
if service.FromContext[option.EndpointOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.EndpointRegistry](ctx) == nil {
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
}
return ctx return ctx
} }
@ -76,12 +84,16 @@ func New(options Options) (*Box, error) {
} }
ctx = service.ContextWithDefaultRegistry(ctx) ctx = service.ContextWithDefaultRegistry(ctx)
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context")
}
if inboundRegistry == nil { if inboundRegistry == nil {
return nil, E.New("missing inbound registry in context") return nil, E.New("missing inbound registry in context")
} }
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
if outboundRegistry == nil { if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context") return nil, E.New("missing outbound registry in context")
} }
@ -119,8 +131,10 @@ func New(options Options) (*Box, error) {
} }
routeOptions := common.PtrValueOrDefault(options.Route) routeOptions := common.PtrValueOrDefault(options.Route)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry) endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
@ -135,6 +149,24 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize router") return nil, E.Cause(err, "initialize router")
} }
for i, endpointOptions := range options.Endpoints {
var tag string
if endpointOptions.Tag != "" {
tag = endpointOptions.Tag
} else {
tag = F.ToString(i)
}
err = endpointManager.Create(ctx,
router,
logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")),
tag,
endpointOptions.Type,
endpointOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]")
}
}
for i, inboundOptions := range options.Inbounds { for i, inboundOptions := range options.Inbounds {
var tag string var tag string
if inboundOptions.Tag != "" { if inboundOptions.Tag != "" {
@ -241,6 +273,7 @@ func New(options Options) (*Box, error) {
} }
return &Box{ return &Box{
network: networkManager, network: networkManager,
endpoint: endpointManager,
inbound: inboundManager, inbound: inboundManager,
outbound: outboundManager, outbound: outboundManager,
connection: connectionManager, connection: connectionManager,
@ -303,7 +336,7 @@ func (s *Box) preStart() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound) err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
@ -327,7 +360,11 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound) err = adapter.Start(adapter.StartStateStart, s.endpoint)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
@ -335,7 +372,7 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound) err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }

View file

@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) {
configPaths = append(configPaths, "config.json") configPaths = append(configPaths, "config.json")
} }
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger())) globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry()) globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
} }

View file

@ -285,7 +285,7 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
} }
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
return d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay) return d.udpListener.ListenPacket(context.Background(), network, address)
} }
func trackConn(conn net.Conn, err error) (net.Conn, error) { func trackConn(conn net.Conn, err error) (net.Conn, error) {

View file

@ -0,0 +1,32 @@
---
icon: material/new-box
---
!!! question "Since sing-box 1.11.0"
# Endpoint
Endpoint is protocols that has both inbound and outbound behavior.
### Structure
```json
{
"endpoints": [
{
"type": "",
"tag": ""
}
]
}
```
### Fields
| Type | Format |
|-------------|---------------------------|
| `wireguard` | [WireGuard](./wireguard/) |
#### tag
The tag of the endpoint.

View file

@ -0,0 +1,32 @@
---
icon: material/new-box
---
!!! question "自 sing-box 1.11.0 起"
# 端点
端点是具有入站和出站行为的协议。
### 结构
```json
{
"endpoints": [
{
"type": "",
"tag": ""
}
]
}
```
### 字段
| 类型 | 格式 |
|-------------|---------------------------|
| `wireguard` | [WireGuard](./wiregaurd/) |
#### tag
端点的标签。

View file

@ -0,0 +1,138 @@
---
icon: material/new-box
---
!!! question "Since sing-box 1.11.0"
### Structure
```json
{
"type": "wireguard",
"tag": "wg-ep",
"system": false,
"name": "",
"mtu": 1408,
"gso": false,
"address": [],
"private_key": "",
"listen_port": 10000,
"peers": [
{
"address": "127.0.0.1",
"port": 10001,
"public_key": "",
"pre_shared_key": "",
"allowed_ips": [],
"persistent_keepalive_interval": 0,
"reserved": [0, 0, 0]
}
],
"udp_timeout": "",
"workers": 0,
... // Dial Fields
}
```
### Fields
#### system
Use system interface.
Requires privilege and cannot conflict with exists system interfaces.
#### name
Custom interface name for system interface.
#### mtu
WireGuard MTU.
`1408` will be used by default.
#### gso
!!! quote ""
Only supported on Linux.
Try to enable generic segmentation offload.
#### address
==Required==
List of IP (v4 or v6) address prefixes to be assigned to the interface.
#### private_key
==Required==
WireGuard requires base64-encoded public and private keys. These can be generated using the wg(8) utility:
```shell
wg genkey
echo "private key" || wg pubkey
```
or `sing-box generate wg-keypair`.
#### peers
==Required==
List of WireGuard peers.
#### peers.address
WireGuard peer address.
#### peers.port
WireGuard peer port.
#### peers.public_key
==Required==
WireGuard peer public key.
#### peers.pre_shared_key
WireGuard peer pre-shared key.
#### peers.allowed_ips
==Required==
WireGuard allowed IPs.
#### peers.persistent_keepalive_interval
WireGuard persistent keepalive interval, in seconds.
Disabled by default.
#### peers.reserved
WireGuard reserved field bytes.
#### udp_timeout
UDP NAT expiration time.
`5m` will be used by default.
#### workers
WireGuard worker count.
CPU count is used by default.
### Dial Fields
See [Dial Fields](/configuration/shared/dial/) for details.

View file

@ -0,0 +1,140 @@
---
icon: material/new-box
---
!!! question "自 sing-box 1.11.0 起"
### 结构
```json
{
"type": "wireguard",
"tag": "wg-ep",
"system": false,
"name": "",
"mtu": 1408,
"gso": false,
"address": [],
"private_key": "",
"listen_port": 10000,
"peers": [
{
"address": "127.0.0.1",
"port": 10001,
"public_key": "",
"pre_shared_key": "",
"allowed_ips": [],
"persistent_keepalive_interval": 0,
"reserved": [0, 0, 0]
}
],
"udp_timeout": "",
"workers": 0,
... // 拨号字段
}
```
### 字段
#### system_interface
使用系统设备。
需要特权且不能与已有系统接口冲突。
#### name
为系统接口自定义设备名称。
#### mtu
WireGuard MTU。
默认使用 1408。
#### gso
!!! quote ""
仅支持 Linux。
尝试启用通用分段卸载。
#### address
==必填==
接口的 IPv4/IPv6 地址或地址段的列表您。
要分配给接口的 IPv4 或 v6地址段列表。
#### private_key
==必填==
WireGuard 需要 base64 编码的公钥和私钥。 这些可以使用 wg(8) 实用程序生成:
```shell
wg genkey
echo "private key" || wg pubkey
```
`sing-box generate wg-keypair`.
#### peers
==必填==
WireGuard 对等方的列表。
#### peers.address
对等方的 IP 地址。
#### peers.port
对等方的 WireGuard 端口。
#### peers.public_key
==必填==
对等方的 WireGuard 公钥。
#### peers.pre_shared_key
对等方的预共享密钥。
#### peers.allowed_ips
==必填==
对等方的允许 IP 地址。
#### peers.persistent_keepalive_interval
对等方的持久性保持活动间隔,以秒为单位。
默认禁用。
#### peers.reserved
对等方的保留字段字节。
#### udp_timeout
UDP NAT 过期时间。
默认使用 `5m`
#### workers
WireGuard worker 数量。
默认使用 CPU 数量。
### 拨号字段
参阅 [拨号字段](/zh/configuration/shared/dial/)。

View file

@ -360,7 +360,9 @@ Performance may degrade slightly, so it is not recommended to enable on when it
#### udp_timeout #### udp_timeout
UDP NAT expiration time in seconds, default is 300 (5 minutes). UDP NAT expiration time.
`5m` will be used by default.
#### stack #### stack

View file

@ -356,7 +356,9 @@ tun 接口的 IPv6 前缀。
#### udp_timeout #### udp_timeout
UDP NAT 过期时间,以秒为单位,默认为 3005 分钟)。 UDP NAT 过期时间。
默认使用 `5m`
#### stack #### stack

View file

@ -9,6 +9,7 @@ sing-box uses JSON for configuration files.
"log": {}, "log": {},
"dns": {}, "dns": {},
"ntp": {}, "ntp": {},
"endpoints": [],
"inbounds": [], "inbounds": [],
"outbounds": [], "outbounds": [],
"route": {}, "route": {},
@ -23,6 +24,7 @@ sing-box uses JSON for configuration files.
| `log` | [Log](./log/) | | `log` | [Log](./log/) |
| `dns` | [DNS](./dns/) | | `dns` | [DNS](./dns/) |
| `ntp` | [NTP](./ntp/) | | `ntp` | [NTP](./ntp/) |
| `endpoints` | [Endpoint](./endpoint/) |
| `inbounds` | [Inbound](./inbound/) | | `inbounds` | [Inbound](./inbound/) |
| `outbounds` | [Outbound](./outbound/) | | `outbounds` | [Outbound](./outbound/) |
| `route` | [Route](./route/) | | `route` | [Route](./route/) |

View file

@ -8,6 +8,7 @@ sing-box 使用 JSON 作为配置文件格式。
{ {
"log": {}, "log": {},
"dns": {}, "dns": {},
"endpoints": [],
"inbounds": [], "inbounds": [],
"outbounds": [], "outbounds": [],
"route": {}, "route": {},
@ -21,6 +22,7 @@ sing-box 使用 JSON 作为配置文件格式。
|----------------|------------------------| |----------------|------------------------|
| `log` | [日志](./log/) | | `log` | [日志](./log/) |
| `dns` | [DNS](./dns/) | | `dns` | [DNS](./dns/) |
| `endpoints` | [端点](./endpoint/) |
| `inbounds` | [入站](./inbound/) | | `inbounds` | [入站](./inbound/) |
| `outbounds` | [出站](./outbound/) | | `outbounds` | [出站](./outbound/) |
| `route` | [路由](./route/) | | `route` | [路由](./route/) |

View file

@ -1,3 +1,11 @@
---
icon: material/delete-clock
---
!!! failure "Deprecated in sing-box 1.11.0"
WireGuard outbound is deprecated and will be removed in sing-box 1.13.0, check [Migration](/migration/#migrate-wireguard-outbound-to-endpoint).
!!! quote "Changes in sing-box 1.8.0" !!! quote "Changes in sing-box 1.8.0"
:material-plus: [gso](#gso) :material-plus: [gso](#gso)
@ -15,7 +23,7 @@
"gso": false, "gso": false,
"interface_name": "wg0", "interface_name": "wg0",
"local_address": [ "local_address": [
"10.0.0.2/32" "10.0.0.1/32"
], ],
"private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=", "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
"peers": [ "peers": [

View file

@ -1,3 +1,11 @@
---
icon: material/delete-clock
---
!!! failure "已在 sing-box 1.11.0 废弃"
WireGuard 出站已被启用,且将在 sing-box 1.13.0 中被移除,参阅 [迁移指南](/migration/#migrate-wireguard-outbound-to-endpoint)。
!!! quote "sing-box 1.8.0 中的更改" !!! quote "sing-box 1.8.0 中的更改"
:material-plus: [gso](#gso) :material-plus: [gso](#gso)
@ -15,7 +23,7 @@
"gso": false, "gso": false,
"interface_name": "wg0", "interface_name": "wg0",
"local_address": [ "local_address": [
"10.0.0.2/32" "10.0.0.1/32"
], ],
"private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=", "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
"peer_public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=", "peer_public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=",

View file

@ -68,9 +68,9 @@ Enable UDP fragmentation.
#### udp_timeout #### udp_timeout
UDP NAT expiration time in seconds. UDP NAT expiration time.
`5m` is used by default. `5m` will be used by default.
#### detour #### detour

View file

@ -69,7 +69,7 @@ icon: material/delete-clock
#### udp_timeout #### udp_timeout
UDP NAT 过期时间,以秒为单位 UDP NAT 过期时间。
默认使用 `5m` 默认使用 `5m`

View file

@ -28,6 +28,13 @@ Destination override fields (`override_address` / `override_port`) in direct out
and can be replaced by rule actions, and can be replaced by rule actions,
check [Migration](../migration/#migrate-destination-override-fields-to-route-options). check [Migration](../migration/#migrate-destination-override-fields-to-route-options).
#### WireGuard outbound
WireGuard outbound is deprecated and can be replaced by endpoint,
check [Migration](../migration/#migrate-wireguard-outbound-to-endpoint).
Old outbound will be removed in sing-box 1.13.0.
## 1.10.0 ## 1.10.0
#### TUN address fields are merged #### TUN address fields are merged

View file

@ -27,6 +27,13 @@ direct 出站中的目标地址覆盖字段(`override_address` / `override_por
旧字段将在 sing-box 1.13.0 中被移除。 旧字段将在 sing-box 1.13.0 中被移除。
#### WireGuard 出站
WireGuard 出站已废弃且可以通过端点替代,
参阅 [迁移指南](/migration/#migrate-wireguard-outbound-to-endpoint)。
旧出站将在 sing-box 1.13.0 中被移除。
## 1.10.0 ## 1.10.0
#### Match source 规则项已重命名 #### Match source 规则项已重命名

View file

@ -194,6 +194,78 @@ Destination override fields in direct outbound are deprecated and can be replace
} }
``` ```
### Migrate WireGuard outbound to endpoint
WireGuard outbound is deprecated and can be replaced by endpoint.
!!! info "References"
[Endpoint](/configuration/endpoint/) /
[WireGuard Endpoint](/configuration/endpoint/wireguard/) /
[WireGuard Outbound](/configuration/outbound/wireguard/)
=== ":material-card-remove: Deprecated"
```json
{
"outbounds": [
{
"type": "wireguard",
"tag": "wg-out",
"server": "127.0.0.1",
"server_port": 10001,
"system_interface": true,
"gso": true,
"interface_name": "wg0",
"local_address": [
"10.0.0.1/32"
],
"private_key": "<private_key>",
"peer_public_key": "<peer_public_key>",
"pre_shared_key": "<pre_shared_key>",
"reserved": [0, 0, 0],
"mtu": 1408
}
]
}
```
=== ":material-card-multiple: New"
```json
{
"endpoints": [
{
"type": "wireguard",
"tag": "wg-ep",
"system": true,
"name": "wg0",
"mtu": 1408,
"gso": true,
"address": [
"10.0.0.2/32"
],
"private_key": "<private_key>",
"listen_port": 10000,
"peers": [
{
"address": "127.0.0.1",
"port": 10001,
"public_key": "<peer_public_key>",
"pre_shared_key": "<pre_shared_key>",
"allowed_ips": [
"0.0.0.0/0"
],
"persistent_keepalive_interval": 30,
"reserved": [0, 0, 0]
}
]
}
]
}
```
## 1.10.0 ## 1.10.0
### TUN address fields are merged ### TUN address fields are merged

View file

@ -104,7 +104,6 @@ icon: material/arrange-bring-forward
### 迁移旧的入站字段到规则动作 ### 迁移旧的入站字段到规则动作
入站选项已被弃用,且可以被规则动作替代。 入站选项已被弃用,且可以被规则动作替代。
!!! info "参考" !!! info "参考"
@ -196,6 +195,78 @@ direct 出站中的目标地址覆盖字段已废弃,且可以被路由字段
} }
``` ```
### 迁移 WireGuard 出站到端点
WireGuard 出站已被弃用,且可以被端点替代。
!!! info "参考"
[端点](/zh/configuration/endpoint/) /
[WireGuard 端点](/zh/configuration/endpoint/wireguard/) /
[WireGuard 出站](/zh/configuration/outbound/wireguard/)
=== ":material-card-remove: 弃用的"
```json
{
"outbounds": [
{
"type": "wireguard",
"tag": "wg-out",
"server": "127.0.0.1",
"server_port": 10001,
"system_interface": true,
"gso": true,
"interface_name": "wg0",
"local_address": [
"10.0.0.1/32"
],
"private_key": "<private_key>",
"peer_public_key": "<peer_public_key>",
"pre_shared_key": "<pre_shared_key>",
"reserved": [0, 0, 0],
"mtu": 1408
}
]
}
```
=== ":material-card-multiple: 新的"
```json
{
"endpoints": [
{
"type": "wireguard",
"tag": "wg-ep",
"system": true,
"name": "wg0",
"mtu": 1408,
"gso": true,
"address": [
"10.0.0.2/32"
],
"private_key": "<private_key>",
"listen_port": 10000,
"peers": [
{
"address": "127.0.0.1",
"port": 10001,
"public_key": "<peer_public_key>",
"pre_shared_key": "<pre_shared_key>",
"allowed_ips": [
"0.0.0.0/0"
],
"persistent_keepalive_interval": 30,
"reserved": [0, 0, 0]
}
]
}
]
}
```
## 1.10.0 ## 1.10.0
### TUN 地址字段已合并 ### TUN 地址字段已合并

View file

@ -113,6 +113,15 @@ var OptionDestinationOverrideFields = Note{
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options", MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options",
} }
var OptionWireGuardOutbound = Note{
Name: "wireguard-outbound",
Description: "legacy wireguard outbound",
DeprecatedVersion: "1.11.0",
ScheduledVersion: "1.13.0",
EnvName: "WIREGUARD_OUTBOUND",
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-wireguard-outbound-to-endpoint",
}
var Options = []Note{ var Options = []Note{
OptionBadMatchSource, OptionBadMatchSource,
OptionGEOIP, OptionGEOIP,
@ -121,4 +130,5 @@ var Options = []Note{
OptionSpecialOutbounds, OptionSpecialOutbounds,
OptionInboundOptions, OptionInboundOptions,
OptionDestinationOverrideFields, OptionDestinationOverrideFields,
OptionWireGuardOutbound,
} }

View file

@ -30,7 +30,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err
} }
func CheckConfig(configContent string) error { func CheckConfig(configContent string) error {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
options, err := parseConfig(ctx, configContent) options, err := parseConfig(ctx, configContent)
if err != nil { if err != nil {
return err return err
@ -131,7 +131,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica
} }
func FormatConfig(configContent string) (*StringBox, error) { func FormatConfig(configContent string) (*StringBox, error) {
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()), configContent) options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -44,7 +44,7 @@ type BoxService struct {
} }
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) { func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager)) service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
options, err := parseConfig(ctx, configContent) options, err := parseConfig(ctx, configContent)

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
@ -82,6 +83,14 @@ func OutboundRegistry() *outbound.Registry {
return registry return registry
} }
func EndpointRegistry() *endpoint.Registry {
registry := endpoint.NewRegistry()
registerWireGuardEndpoint(registry)
return registry
}
func registerStubForRemovedInbounds(registry *inbound.Registry) { func registerStubForRemovedInbounds(registry *inbound.Registry) {
inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) { inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0") return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0")

View file

@ -3,6 +3,7 @@
package include package include
import ( import (
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/protocol/wireguard" "github.com/sagernet/sing-box/protocol/wireguard"
) )
@ -10,3 +11,7 @@ import (
func registerWireGuardOutbound(registry *outbound.Registry) { func registerWireGuardOutbound(registry *outbound.Registry) {
wireguard.RegisterOutbound(registry) wireguard.RegisterOutbound(registry)
} }
func registerWireGuardEndpoint(registry *endpoint.Registry) {
wireguard.RegisterEndpoint(registry)
}

View file

@ -6,6 +6,7 @@ import (
"context" "context"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@ -14,7 +15,13 @@ import (
) )
func registerWireGuardOutbound(registry *outbound.Registry) { func registerWireGuardOutbound(registry *outbound.Registry) {
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) { outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
})
}
func registerWireGuardEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`) return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
}) })
} }

View file

@ -112,6 +112,9 @@ nav:
- V2Ray Transport: configuration/shared/v2ray-transport.md - V2Ray Transport: configuration/shared/v2ray-transport.md
- UDP over TCP: configuration/shared/udp-over-tcp.md - UDP over TCP: configuration/shared/udp-over-tcp.md
- TCP Brutal: configuration/shared/tcp-brutal.md - TCP Brutal: configuration/shared/tcp-brutal.md
- Endpoint:
- configuration/endpoint/index.md
- WireGuard: configuration/endpoint/wireguard.md
- Inbound: - Inbound:
- configuration/inbound/index.md - configuration/inbound/index.md
- Direct: configuration/inbound/direct.md - Direct: configuration/inbound/direct.md
@ -241,6 +244,7 @@ plugins:
Multiplex: 多路复用 Multiplex: 多路复用
V2Ray Transport: V2Ray 传输层 V2Ray Transport: V2Ray 传输层
Endpoint: 端点
Inbound: 入站 Inbound: 入站
Outbound: 出站 Outbound: 出站

47
option/endpoint.go Normal file
View file

@ -0,0 +1,47 @@
package option
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/service"
)
type EndpointOptionsRegistry interface {
CreateOptions(endpointType string) (any, bool)
}
type _Endpoint struct {
Type string `json:"type"`
Tag string `json:"tag,omitempty"`
Options any `json:"-"`
}
type Endpoint _Endpoint
func (h *Endpoint) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return badjson.MarshallObjectsContext(ctx, (*_Endpoint)(h), h.Options)
}
func (h *Endpoint) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.UnmarshalContext(ctx, content, (*_Endpoint)(h))
if err != nil {
return err
}
registry := service.FromContext[EndpointOptionsRegistry](ctx)
if registry == nil {
return E.New("missing Endpoint fields registry in context")
}
options, loaded := registry.CreateOptions(h.Type)
if !loaded {
return E.New("unknown inbound type: ", h.Type)
}
err = badjson.UnmarshallExcludedContext(ctx, content, (*_Endpoint)(h), options)
if err != nil {
return err
}
h.Options = options
return nil
}

View file

@ -28,7 +28,7 @@ func (h *Inbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
} }
func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error { func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.Unmarshal(content, (*_Inbound)(h)) err := json.UnmarshalContext(ctx, content, (*_Inbound)(h))
if err != nil { if err != nil {
return err return err
} }

View file

@ -13,6 +13,7 @@ type _Options struct {
Log *LogOptions `json:"log,omitempty"` Log *LogOptions `json:"log,omitempty"`
DNS *DNSOptions `json:"dns,omitempty"` DNS *DNSOptions `json:"dns,omitempty"`
NTP *NTPOptions `json:"ntp,omitempty"` NTP *NTPOptions `json:"ntp,omitempty"`
Endpoints []Endpoint `json:"endpoints,omitempty"`
Inbounds []Inbound `json:"inbounds,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"` Route *RouteOptions `json:"route,omitempty"`

View file

@ -30,7 +30,7 @@ func (h *Outbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
} }
func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error { func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.Unmarshal(content, (*_Outbound)(h)) err := json.UnmarshalContext(ctx, content, (*_Outbound)(h))
if err != nil { if err != nil {
return err return err
} }

View file

@ -6,14 +6,38 @@ import (
"github.com/sagernet/sing/common/json/badoption" "github.com/sagernet/sing/common/json/badoption"
) )
type WireGuardOutboundOptions struct { type WireGuardEndpointOptions struct {
System bool `json:"system,omitempty"`
Name string `json:"name,omitempty"`
MTU uint32 `json:"mtu,omitempty"`
GSO bool `json:"gso,omitempty"`
Address badoption.Listable[netip.Prefix] `json:"address"`
PrivateKey string `json:"private_key"`
ListenPort uint16 `json:"listen_port,omitempty"`
Peers []WireGuardPeer `json:"peers,omitempty"`
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
Workers int `json:"workers,omitempty"`
DialerOptions
}
type WireGuardPeer struct {
Address string `json:"address,omitempty"`
Port uint16 `json:"port,omitempty"`
PublicKey string `json:"public_key,omitempty"`
PreSharedKey string `json:"pre_shared_key,omitempty"`
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"`
}
type LegacyWireGuardOutboundOptions struct {
DialerOptions DialerOptions
SystemInterface bool `json:"system_interface,omitempty"` SystemInterface bool `json:"system_interface,omitempty"`
GSO bool `json:"gso,omitempty"` GSO bool `json:"gso,omitempty"`
InterfaceName string `json:"interface_name,omitempty"` InterfaceName string `json:"interface_name,omitempty"`
LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"` LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"`
PrivateKey string `json:"private_key"` PrivateKey string `json:"private_key"`
Peers []WireGuardPeer `json:"peers,omitempty"` Peers []LegacyWireGuardPeer `json:"peers,omitempty"`
ServerOptions ServerOptions
PeerPublicKey string `json:"peer_public_key"` PeerPublicKey string `json:"peer_public_key"`
PreSharedKey string `json:"pre_shared_key,omitempty"` PreSharedKey string `json:"pre_shared_key,omitempty"`
@ -23,10 +47,10 @@ type WireGuardOutboundOptions struct {
Network NetworkList `json:"network,omitempty"` Network NetworkList `json:"network,omitempty"`
} }
type WireGuardPeer struct { type LegacyWireGuardPeer struct {
ServerOptions ServerOptions
PublicKey string `json:"public_key,omitempty"` PublicKey string `json:"public_key,omitempty"`
PreSharedKey string `json:"pre_shared_key,omitempty"` PreSharedKey string `json:"pre_shared_key,omitempty"`
AllowedIPs badoption.Listable[string] `json:"allowed_ips,omitempty"` AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"` Reserved []uint8 `json:"reserved,omitempty"`
} }

View file

@ -26,7 +26,7 @@ type Outbound struct {
func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) { func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) {
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapter(C.TypeBlock, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil), Adapter: outbound.NewAdapter(C.TypeBlock, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
logger: logger, logger: logger,
}, nil }, nil
} }

View file

@ -68,7 +68,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (i *Inbound) Start() error { func (i *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return i.listener.Start() return i.listener.Start()
} }

View file

@ -52,7 +52,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
logger: logger, logger: logger,
domainStrategy: dns.DomainStrategy(options.DomainStrategy), domainStrategy: dns.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay), fallbackDelay: time.Duration(options.FallbackDelay),

View file

@ -28,7 +28,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapter(C.TypeDNS, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil), Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
router: router, router: router,
logger: logger, logger: logger,
}, nil }, nil

View file

@ -38,7 +38,7 @@ type Selector struct {
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) { func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
outbound := &Selector{ outbound := &Selector{
Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
ctx: ctx, ctx: ctx,
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,

View file

@ -49,7 +49,7 @@ type URLTest struct {
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) { func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
outbound := &URLTest{ outbound := &URLTest{
Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx, ctx: ctx,
router: router, router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outboundManager: service.FromContext[adapter.OutboundManager](ctx),

View file

@ -61,7 +61,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -39,7 +39,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, tag, []string{N.NetworkTCP}, options.DialerOptions),
logger: logger, logger: logger,
client: sHTTP.NewClient(sHTTP.Options{ client: sHTTP.NewClient(sHTTP.Options{
Dialer: detour, Dialer: detour,

View file

@ -160,7 +160,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -95,7 +95,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, networkList, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions),
logger: logger, logger: logger,
client: client, client: client,
}, nil }, nil

View file

@ -171,7 +171,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -81,7 +81,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, networkList, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions),
logger: logger, logger: logger,
client: client, client: client,
}, nil }, nil

View file

@ -54,7 +54,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -78,7 +78,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (n *Inbound) Start() error { func (n *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
var tlsConfig *tls.STDConfig var tlsConfig *tls.STDConfig
if n.tlsConfig != nil { if n.tlsConfig != nil {
err := n.tlsConfig.Start() err := n.tlsConfig.Start()

View file

@ -42,7 +42,10 @@ func NewRedirect(ctx context.Context, router adapter.Router, logger log.ContextL
return redirect, nil return redirect, nil
} }
func (h *Redirect) Start() error { func (h *Redirect) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -61,7 +61,10 @@ func NewTProxy(ctx context.Context, router adapter.Router, logger log.ContextLog
return tproxy, nil return tproxy, nil
} }
func (t *TProxy) Start() error { func (t *TProxy) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := t.listener.Start() err := t.listener.Start()
if err != nil { if err != nil {
return err return err

View file

@ -93,7 +93,10 @@ func newInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, err return inbound, err
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -101,7 +101,10 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont
return inbound, err return inbound, err
} }
func (h *MultiInbound) Start() error { func (h *MultiInbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -86,7 +86,10 @@ func newRelayInbound(ctx context.Context, router adapter.Router, logger log.Cont
return inbound, err return inbound, err
} }
func (h *RelayInbound) Start() error { func (h *RelayInbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -49,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
method: method, method: method,

View file

@ -90,7 +90,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -29,7 +29,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) {
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, tag, []string{N.NetworkTCP}, options.DialerOptions),
} }
if options.TLS == nil || !options.TLS.Enabled { if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired return nil, C.ErrTLSRequired

View file

@ -50,7 +50,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start() return h.listener.Start()
} }

View file

@ -50,7 +50,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
router: router, router: router,
logger: logger, logger: logger,
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password), client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),

View file

@ -54,7 +54,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, tag, []string{N.NetworkTCP}, options.DialerOptions),
ctx: ctx, ctx: ctx,
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,

View file

@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, []string{N.NetworkTCP}, tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, tag, []string{N.NetworkTCP}, options.DialerOptions),
ctx: ctx, ctx: ctx,
logger: logger, logger: logger,
proxy: NewProxyListener(ctx, logger, outboundDialer), proxy: NewProxyListener(ctx, logger, outboundDialer),

View file

@ -110,7 +110,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -43,7 +43,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(), serverAddr: options.ServerOptions.Build(),

View file

@ -142,7 +142,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
return &Outbound{ return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
client: client, client: client,
udpStream: options.UDPOverStream, udpStream: options.UDPOverStream,

View file

@ -300,7 +300,9 @@ func (t *Inbound) Tag() string {
return t.tag return t.tag
} }
func (t *Inbound) Start() error { func (t *Inbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
if C.IsAndroid && t.platformInterface == nil { if C.IsAndroid && t.platformInterface == nil {
t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager()) t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager())
} }
@ -348,10 +350,7 @@ func (t *Inbound) Start() error {
} }
t.tunStack = tunStack t.tunStack = tunStack
t.logger.Info("started at ", t.tunOptions.Name) t.logger.Info("started at ", t.tunOptions.Name)
return nil case adapter.StartStatePostStart:
}
func (t *Inbound) PostStart() error {
monitor := taskmonitor.New(t.logger, C.StartTimeout) monitor := taskmonitor.New(t.logger, C.StartTimeout)
monitor.Start("starting tun stack") monitor.Start("starting tun stack")
err := t.tunStack.Start() err := t.tunStack.Start()
@ -399,6 +398,7 @@ func (t *Inbound) PostStart() error {
t.routeAddressSet = nil t.routeAddressSet = nil
t.routeExcludeAddressSet = nil t.routeExcludeAddressSet = nil
} }
}
return nil return nil
} }

View file

@ -89,7 +89,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil { if h.tlsConfig != nil {
err := h.tlsConfig.Start() err := h.tlsConfig.Start()
if err != nil { if err != nil {

View file

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(), serverAddr: options.ServerOptions.Build(),

View file

@ -99,7 +99,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil return inbound, nil
} }
func (h *Inbound) Start() error { func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := h.service.Start() err := h.service.Start()
if err != nil { if err != nil {
return err return err

View file

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err return nil, err
} }
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, tag, options.Network.Build(), options.DialerOptions),
logger: logger, logger: logger,
dialer: outboundDialer, dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(), serverAddr: options.ServerOptions.Build(),

View file

@ -0,0 +1,211 @@
package wireguard
import (
"context"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
}
var (
_ adapter.Endpoint = (*Endpoint)(nil)
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)
type Endpoint struct {
endpoint.Adapter
ctx context.Context
router adapter.Router
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
}
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
ep := &Endpoint{
Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
logger: logger,
localAddresses: options.Address,
}
if options.Detour == "" {
options.IsWireGuardListener = true
} else if options.GSO {
return nil, E.New("gso is conflict with detour")
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil {
return nil, err
}
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx,
Logger: logger,
System: options.System,
Handler: ep,
UDPTimeout: time.Duration(options.UDPTimeout),
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
BindInterface: interfaceName,
}))
},
Name: options.Name,
MTU: options.MTU,
GSO: options.GSO,
Address: options.Address,
PrivateKey: options.PrivateKey,
ListenPort: options.ListenPort,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
return endpointAddresses[0], nil
},
Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions {
return wireguard.PeerOptions{
Endpoint: M.ParseSocksaddrHostPort(it.Address, it.Port),
PublicKey: it.PublicKey,
PreSharedKey: it.PreSharedKey,
AllowedIPs: it.AllowedIPs,
PersistentKeepaliveInterval: it.PersistentKeepaliveInterval,
Reserved: it.Reserved,
}
}),
Workers: options.Workers,
})
if err != nil {
return nil, err
}
ep.endpoint = wgEndpoint
return ep, nil
}
func (w *Endpoint) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
return w.endpoint.Start(false)
case adapter.StartStatePostStart:
return w.endpoint.Start(true)
}
return nil
}
func (w *Endpoint) Close() error {
return w.endpoint.Close()
}
func (w *Endpoint) InterfaceUpdated() {
w.endpoint.BindUpdate()
return
}
func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
return w.router.PreMatch(adapter.InboundContext{
Inbound: w.Tag(),
InboundType: w.Type(),
Network: network,
Source: source,
Destination: destination,
})
}
func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext
metadata.Inbound = w.Tag()
metadata.InboundType = w.Type()
metadata.Source = source
for _, localPrefix := range w.localAddresses {
if localPrefix.Contains(destination.Addr) {
metadata.OriginDestination = destination
if destination.Addr.Is4() {
destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
} else {
destination.Addr = netip.IPv6Loopback()
}
break
}
}
metadata.Destination = destination
w.logger.InfoContext(ctx, "inbound connection from ", source)
w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
w.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}
func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext
metadata.Inbound = w.Tag()
metadata.InboundType = w.Type()
metadata.Source = source
metadata.Destination = destination
for _, localPrefix := range w.localAddresses {
if localPrefix.Contains(destination.Addr) {
metadata.OriginDestination = destination
if destination.Addr.Is4() {
metadata.Destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
} else {
metadata.Destination.Addr = netip.IPv6Loopback()
}
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
}
}
w.logger.InfoContext(ctx, "inbound packet connection from ", source)
w.logger.InfoContext(ctx, "inbound packet connection to ", destination)
w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network {
case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
}
return w.endpoint.DialContext(ctx, network, destination)
}
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, err
}
return w.endpoint.ListenPacket(ctx, destination)
}

View file

@ -2,238 +2,163 @@ package wireguard
import ( import (
"context" "context"
"encoding/base64"
"encoding/hex"
"fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard" "github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
) )
func RegisterOutbound(registry *outbound.Registry) { func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound) outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
} }
var _ adapter.InterfaceUpdateListener = (*Outbound)(nil) var (
_ adapter.Endpoint = (*Endpoint)(nil)
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)
type Outbound struct { type Outbound struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
logger logger.ContextLogger logger logger.ContextLogger
workers int localAddresses []netip.Prefix
peers []wireguard.PeerConfig endpoint *wireguard.Endpoint
useStdNetBind bool
listener N.Dialer
ipcConf string
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
bind conn.Bind
device *device.Device
tunDevice wireguard.Device
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
deprecated.Report(ctx, deprecated.OptionWireGuardOutbound)
outbound := &Outbound{ outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions), Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx, ctx: ctx,
router: router, router: router,
logger: logger, logger: logger,
workers: options.Workers, localAddresses: options.LocalAddress,
pauseManager: service.FromContext[pause.Manager](ctx),
} }
peers, err := wireguard.ParsePeers(options) if options.Detour == "" {
if err != nil { options.IsWireGuardListener = true
return nil, err } else if options.GSO {
}
outbound.peers = peers
if len(options.LocalAddress) == 0 {
return nil, E.New("missing local address")
}
if options.GSO {
if options.GSO && options.Detour != "" {
return nil, E.New("gso is conflict with detour") return nil, E.New("gso is conflict with detour")
} }
options.IsWireGuardListener = true outboundDialer, err := dialer.New(ctx, options.DialerOptions)
outbound.useStdNetBind = true
}
listener, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outbound.listener = listener peers := common.Map(options.Peers, func(it option.LegacyWireGuardPeer) wireguard.PeerOptions {
var privateKey string return wireguard.PeerOptions{
{ Endpoint: it.ServerOptions.Build(),
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) PublicKey: it.PublicKey,
PreSharedKey: it.PreSharedKey,
AllowedIPs: it.AllowedIPs,
// PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval),
Reserved: it.Reserved,
}
})
if len(peers) == 0 {
peers = []wireguard.PeerOptions{{
Endpoint: options.ServerOptions.Build(),
PublicKey: options.PeerPublicKey,
PreSharedKey: options.PreSharedKey,
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.PrefixFrom(netip.IPv6Unspecified(), 0)},
Reserved: options.Reserved,
}}
}
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx,
Logger: logger,
System: options.SystemInterface,
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
BindInterface: interfaceName,
}))
},
Name: options.InterfaceName,
MTU: options.MTU,
GSO: options.GSO,
Address: options.LocalAddress,
PrivateKey: options.PrivateKey,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
return endpointAddresses[0], nil
},
Peers: peers,
Workers: options.Workers,
})
if err != nil { if err != nil {
return nil, E.Cause(err, "decode private key") return nil, err
} }
privateKey = hex.EncodeToString(bytes) outbound.endpoint = wgEndpoint
}
outbound.ipcConf = "private_key=" + privateKey
mtu := options.MTU
if mtu == 0 {
mtu = 1408
}
var wireTunDevice wireguard.Device
if !options.SystemInterface && tun.WithGVisor {
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
} else {
wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
}
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
outbound.tunDevice = wireTunDevice
return outbound, nil return outbound, nil
} }
func (w *Outbound) Start() error { func (o *Outbound) Start(stage adapter.StartStage) error {
if common.Any(w.peers, func(peer wireguard.PeerConfig) bool { switch stage {
return !peer.Endpoint.IsValid() case adapter.StartStateStart:
}) { return o.endpoint.Start(false)
// wait for all outbounds to be started and continue in PortStart case adapter.StartStatePostStart:
return nil return o.endpoint.Start(true)
}
return w.start()
}
func (w *Outbound) PostStart() error {
if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
return peer.Endpoint.IsValid()
}) {
return nil
}
return w.start()
}
func (w *Outbound) start() error {
err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
if err != nil {
return err
}
var bind conn.Bind
if w.useStdNetBind {
bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
peerLen := len(w.peers)
if peerLen == 1 {
isConnect = true
connectAddr = w.peers[0].Endpoint
reserved = w.peers[0].Reserved
}
bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
}
if w.useStdNetBind || len(w.peers) > 1 {
for _, peer := range w.peers {
if peer.Reserved != [3]uint8{} {
bind.SetReservedForEndpoint(peer.Endpoint, peer.Reserved)
}
}
}
err = w.tunDevice.Start()
if err != nil {
return err
}
wgDevice := device.NewDevice(w.ctx, w.tunDevice, bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) {
w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}, w.workers)
ipcConf := w.ipcConf
for _, peer := range w.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
w.device = wgDevice
w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
return nil
}
func (w *Outbound) Close() error {
if w.device != nil {
w.device.Close()
}
if w.pauseCallback != nil {
w.pauseManager.UnregisterCallback(w.pauseCallback)
} }
return nil return nil
} }
func (w *Outbound) InterfaceUpdated() { func (o *Outbound) Close() error {
w.device.BindUpdate() return o.endpoint.Close()
}
func (o *Outbound) InterfaceUpdated() {
o.endpoint.BindUpdate()
return return
} }
func (w *Outbound) onPauseUpdated(event int) { func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch event {
case pause.EventDevicePaused:
w.device.Down()
case pause.EventDeviceWake:
w.device.Up()
}
}
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network { switch network {
case N.NetworkTCP: case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination) o.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP: case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
} }
if destination.IsFqdn() { if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses) return N.DialSerial(ctx, o.endpoint, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
} }
return w.tunDevice.DialContext(ctx, network, destination) return o.endpoint.DialContext(ctx, network, destination)
} }
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() { if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses) packetConn, _, err := N.ListenSerial(ctx, o.endpoint, destination, destinationAddresses)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return packetConn, err return packetConn, err
} }
return w.tunDevice.ListenPacket(ctx, destination) return o.endpoint.ListenPacket(ctx, destination)
} }

View file

@ -41,15 +41,15 @@ type NetworkManager struct {
autoDetectInterface bool autoDetectInterface bool
defaultOptions adapter.NetworkOptions defaultOptions adapter.NetworkOptions
autoRedirectOutputMark uint32 autoRedirectOutputMark uint32
networkMonitor tun.NetworkUpdateMonitor networkMonitor tun.NetworkUpdateMonitor
interfaceMonitor tun.DefaultInterfaceMonitor interfaceMonitor tun.DefaultInterfaceMonitor
packageManager tun.PackageManager packageManager tun.PackageManager
powerListener winpowrprof.EventListener powerListener winpowrprof.EventListener
pauseManager pause.Manager pauseManager pause.Manager
platformInterface platform.Interface platformInterface platform.Interface
inboundManager adapter.InboundManager endpoint adapter.EndpointManager
outboundManager adapter.OutboundManager inbound adapter.InboundManager
outbound adapter.OutboundManager
wifiState adapter.WIFIState wifiState adapter.WIFIState
started bool started bool
} }
@ -69,7 +69,9 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
}, },
pauseManager: service.FromContext[pause.Manager](ctx), pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx), platformInterface: service.FromContext[platform.Interface](ctx),
outboundManager: service.FromContext[adapter.OutboundManager](ctx), endpoint: service.FromContext[adapter.EndpointManager](ctx),
inbound: service.FromContext[adapter.InboundManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
} }
if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault { if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault {
if routeOptions.DefaultInterface != "" { if routeOptions.DefaultInterface != "" {
@ -355,14 +357,21 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
func (r *NetworkManager) ResetNetwork() { func (r *NetworkManager) ResetNetwork() {
conntrack.Close() conntrack.Close()
for _, inbound := range r.inboundManager.Inbounds() { for _, endpoint := range r.endpoint.Endpoints() {
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()
}
}
for _, inbound := range r.inbound.Inbounds() {
listener, isListener := inbound.(adapter.InterfaceUpdateListener) listener, isListener := inbound.(adapter.InterfaceUpdateListener)
if isListener { if isListener {
listener.InterfaceUpdated() listener.InterfaceUpdated()
} }
} }
for _, outbound := range r.outboundManager.Outbounds() { for _, outbound := range r.outbound.Outbounds() {
listener, isListener := outbound.(adapter.InterfaceUpdateListener) listener, isListener := outbound.(adapter.InterfaceUpdateListener)
if isListener { if isListener {
listener.InterfaceUpdated() listener.InterfaceUpdated()

View file

@ -11,7 +11,7 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
R "github.com/sagernet/sing-box/route/rule" R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
tun "github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/cache" "github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"

View file

@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
select { select {
case <-c.done: case <-c.done:
default: default:
c.logger.Error(context.Background(), E.Cause(err, "read packet")) c.logger.Error(E.Cause(err, "read packet"))
err = nil err = nil
} }
return return
@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
b := packets[0] b := packets[0]
common.ClearArray(b[1:4]) common.ClearArray(b[1:4])
} }
eps[0] = Endpoint(M.AddrPortFromNet(addr)) eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
count = 1 count = 1
return return
} }
@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
time.Sleep(time.Second) time.Sleep(time.Second)
return err return err
} }
destination := netip.AddrPort(ep.(Endpoint)) destination := netip.AddrPort(ep.(remoteEndpoint))
for _, b := range bufs { for _, b := range bufs {
if len(b) > 3 { if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination] reserved, loaded := c.reservedForEndpoint[destination]
@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Endpoint(ap), nil return remoteEndpoint(ap), nil
} }
func (c *ClientBind) BatchSize() int { func (c *ClientBind) BatchSize() int {
@ -229,3 +229,31 @@ func (w *wireConn) Close() error {
close(w.done) close(w.done)
return nil return nil
} }
var _ conn.Endpoint = (*remoteEndpoint)(nil)
type remoteEndpoint netip.AddrPort
func (e remoteEndpoint) ClearSrc() {
}
func (e remoteEndpoint) SrcToString() string {
return ""
}
func (e remoteEndpoint) DstToString() string {
return (netip.AddrPort)(e).String()
}
func (e remoteEndpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
}
func (e remoteEndpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
}
func (e remoteEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}

View file

@ -1,13 +1,44 @@
package wireguard package wireguard
import ( import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/tun" "github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
) )
type Device interface { type Device interface {
tun.Device wgTun.Device
N.Dialer N.Dialer
Start() error Start() error
// NewEndpoint() (stack.LinkEndpoint, error) SetDevice(device *device.Device)
}
type DeviceOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
AllowedAddress []netip.Prefix
}
func NewDevice(options DeviceOptions) (Device, error) {
if !options.System {
return newStackDevice(options)
} else if options.Handler == nil {
return newSystemDevice(options)
} else {
return newSystemStackDevice(options)
}
} }

View file

@ -5,7 +5,6 @@ package wireguard
import ( import (
"context" "context"
"net" "net"
"net/netip"
"os" "os"
"github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/buffer"
@ -15,52 +14,41 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun" wgTun "github.com/sagernet/wireguard-go/tun"
) )
var _ Device = (*StackDevice)(nil) var _ Device = (*stackDevice)(nil)
const defaultNIC tcpip.NICID = 1 type stackDevice struct {
type StackDevice struct {
stack *stack.Stack stack *stack.Stack
mtu uint32 mtu uint32
events chan wgTun.Event events chan wgTun.Event
outbound chan *stack.PacketBuffer outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{} done chan struct{}
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
addr4 tcpip.Address addr4 tcpip.Address
addr6 tcpip.Address addr6 tcpip.Address
} }
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { func newStackDevice(options DeviceOptions) (*stackDevice, error) {
ipStack := stack.New(stack.Options{ tunDevice := &stackDevice{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, mtu: options.MTU,
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
events: make(chan wgTun.Event, 1), events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256), outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}), done: make(chan struct{}),
} }
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
if err != nil { if err != nil {
return nil, E.New(err.String()) return nil, err
} }
for _, prefix := range localAddresses { for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr()) addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
tunDevice.addr6 = addr tunDevice.addr6 = addr
protoAddr.Protocol = ipv6.ProtocolNumber protoAddr.Protocol = ipv6.ProtocolNumber
} }
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{}) gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if err != nil { if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
} }
} }
sOpt := tcpip.TCPSACKEnabled(true) tunDevice.stack = ipStack
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) if options.Handler != nil {
cOpt := tcpip.CongestionControlOption("cubic") ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC}) }
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
return tunDevice, nil return tunDevice, nil
} }
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) { func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return (*wireEndpoint)(w), nil
}
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
addr := tcpip.FullAddress{ addr := tcpip.FullAddress{
NIC: defaultNIC, NIC: tun.DefaultNIC,
Port: destination.Port, Port: destination.Port,
Addr: tun.AddressFromAddr(destination.Addr), Addr: tun.AddressFromAddr(destination.Addr),
} }
bind := tcpip.FullAddress{ bind := tcpip.FullAddress{
NIC: defaultNIC, NIC: tun.DefaultNIC,
} }
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() { if destination.IsIPv4() {
@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
} }
} }
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
bind := tcpip.FullAddress{ bind := tcpip.FullAddress{
NIC: defaultNIC, NIC: tun.DefaultNIC,
} }
var networkProtocol tcpip.NetworkProtocolNumber var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() { if destination.IsIPv4() {
@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil return udpConn, nil
} }
func (w *StackDevice) Inet4Address() netip.Addr { func (w *stackDevice) SetDevice(device *device.Device) {
return tun.AddrFromAddress(w.addr4)
} }
func (w *StackDevice) Inet6Address() netip.Addr { func (w *stackDevice) Start() error {
return tun.AddrFromAddress(w.addr6)
}
func (w *StackDevice) Start() error {
w.events <- wgTun.EventUp w.events <- wgTun.EventUp
return nil return nil
} }
func (w *StackDevice) File() *os.File { func (w *stackDevice) File() *os.File {
return nil return nil
} }
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select { select {
case packetBuffer, ok := <-w.outbound: case packetBuffer, ok := <-w.outbound:
if !ok { if !ok {
@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
sizes[0] = n sizes[0] = n
count = 1 count = 1
return return
case packet := <-w.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
count = 1
return
case <-w.done: case <-w.done:
return 0, os.ErrClosed return 0, os.ErrClosed
} }
} }
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) { func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs { for _, b := range bufs {
b = b[offset:] b = b[offset:]
if len(b) == 0 { if len(b) == 0 {
@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return return
} }
func (w *StackDevice) Flush() error { func (w *stackDevice) Flush() error {
return nil return nil
} }
func (w *StackDevice) MTU() (int, error) { func (w *stackDevice) MTU() (int, error) {
return int(w.mtu), nil return int(w.mtu), nil
} }
func (w *StackDevice) Name() (string, error) { func (w *stackDevice) Name() (string, error) {
return "sing-box", nil return "sing-box", nil
} }
func (w *StackDevice) Events() <-chan wgTun.Event { func (w *stackDevice) Events() <-chan wgTun.Event {
return w.events return w.events
} }
func (w *StackDevice) Close() error { func (w *stackDevice) Close() error {
close(w.done) close(w.done)
close(w.events) close(w.events)
w.stack.Close() w.stack.Close()
@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
return nil return nil
} }
func (w *StackDevice) BatchSize() int { func (w *stackDevice) BatchSize() int {
return 1 return 1
} }
var _ stack.LinkEndpoint = (*wireEndpoint)(nil) var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint StackDevice type wireEndpoint stackDevice
func (ep *wireEndpoint) MTU() uint32 { func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu return ep.mtu

View file

@ -2,12 +2,12 @@
package wireguard package wireguard
import ( import "github.com/sagernet/sing-tun"
"net/netip"
"github.com/sagernet/sing-tun" func newStackDevice(options DeviceOptions) (Device, error) {
) return nil, tun.ErrGVisorNotIncluded
}
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
func newSystemStackDevice(options DeviceOptions) (Device, error) {
return nil, tun.ErrGVisorNotIncluded return nil, tun.ErrGVisorNotIncluded
} }

View file

@ -6,96 +6,89 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"runtime"
"sync" "sync"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun" wgTun "github.com/sagernet/wireguard-go/tun"
) )
var _ Device = (*SystemDevice)(nil) var _ Device = (*systemDevice)(nil)
type SystemDevice struct { type systemDevice struct {
options DeviceOptions
dialer N.Dialer dialer N.Dialer
device tun.Tun device tun.Tun
batchDevice tun.LinuxTUN batchDevice tun.LinuxTUN
name string
mtu uint32
inet4Addresses []netip.Prefix
inet6Addresses []netip.Prefix
gso bool
events chan wgTun.Event events chan wgTun.Event
closeOnce sync.Once closeOnce sync.Once
} }
func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) { func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
var inet4Addresses []netip.Prefix if options.Name == "" {
var inet6Addresses []netip.Prefix options.Name = tun.CalculateInterfaceName("wg")
for _, prefixes := range localPrefixes {
if prefixes.Addr().Is4() {
inet4Addresses = append(inet4Addresses, prefixes)
} else {
inet6Addresses = append(inet6Addresses, prefixes)
} }
} return &systemDevice{
if interfaceName == "" { options: options,
interfaceName = tun.CalculateInterfaceName("wg") dialer: options.CreateDialer(options.Name),
}
return &SystemDevice{
dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{
BindInterface: interfaceName,
})),
name: interfaceName,
mtu: mtu,
inet4Addresses: inet4Addresses,
inet6Addresses: inet6Addresses,
gso: gso,
events: make(chan wgTun.Event, 1), events: make(chan wgTun.Event, 1),
}, nil }, nil
} }
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return w.dialer.DialContext(ctx, network, destination) return w.dialer.DialContext(ctx, network, destination)
} }
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return w.dialer.ListenPacket(ctx, destination) return w.dialer.ListenPacket(ctx, destination)
} }
func (w *SystemDevice) Inet4Address() netip.Addr { func (w *systemDevice) SetDevice(device *device.Device) {
if len(w.inet4Addresses) == 0 {
return netip.Addr{}
}
return w.inet4Addresses[0].Addr()
} }
func (w *SystemDevice) Inet6Address() netip.Addr { func (w *systemDevice) Start() error {
if len(w.inet6Addresses) == 0 { networkManager := service.FromContext[adapter.NetworkManager](w.options.Context)
return netip.Addr{} tunOptions := tun.Options{
Name: w.options.Name,
Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is6()
}),
MTU: w.options.MTU,
GSO: w.options.GSO,
InterfaceScope: true,
Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }),
InterfaceMonitor: networkManager.InterfaceMonitor(),
InterfaceFinder: networkManager.InterfaceFinder(),
Logger: w.options.Logger,
} }
return w.inet6Addresses[0].Addr() // works with Linux, macOS with IFSCOPE routes, not tested on Windows
if runtime.GOOS == "darwin" {
tunOptions.AutoRoute = true
} }
tunInterface, err := tun.New(tunOptions)
func (w *SystemDevice) Start() error {
tunInterface, err := tun.New(tun.Options{
Name: w.name,
Inet4Address: w.inet4Addresses,
Inet6Address: w.inet6Addresses,
MTU: w.mtu,
GSO: w.gso,
})
if err != nil { if err != nil {
return err return err
} }
err = tunInterface.Start()
if err != nil {
return err
}
w.options.Logger.Info("started at ", w.options.Name)
w.device = tunInterface w.device = tunInterface
if w.gso { if w.options.GSO {
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN) batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
if !isBatchTUN { if !isBatchTUN {
tunInterface.Close() tunInterface.Close()
@ -107,15 +100,15 @@ func (w *SystemDevice) Start() error {
return nil return nil
} }
func (w *SystemDevice) File() *os.File { func (w *systemDevice) File() *os.File {
return nil return nil
} }
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
if w.batchDevice != nil { if w.batchDevice != nil {
count, err = w.batchDevice.BatchRead(bufs, offset, sizes) count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes)
} else { } else {
sizes[0], err = w.device.Read(bufs[0][offset:]) sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
if err == nil { if err == nil {
count = 1 count = 1
} else if errors.Is(err, tun.ErrTooManySegments) { } else if errors.Is(err, tun.ErrTooManySegments) {
@ -125,12 +118,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int,
return return
} }
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) { func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil { if w.batchDevice != nil {
return 0, w.batchDevice.BatchWrite(bufs, offset) return w.batchDevice.BatchWrite(bufs, offset)
} else { } else {
for _, b := range bufs { for _, packet := range bufs {
_, err = w.device.Write(b[offset:]) if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
if err != nil { if err != nil {
return return
} }
@ -140,28 +137,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return return
} }
func (w *SystemDevice) Flush() error { func (w *systemDevice) Flush() error {
return nil return nil
} }
func (w *SystemDevice) MTU() (int, error) { func (w *systemDevice) MTU() (int, error) {
return int(w.mtu), nil return int(w.options.MTU), nil
} }
func (w *SystemDevice) Name() (string, error) { func (w *systemDevice) Name() (string, error) {
return w.name, nil return w.options.Name, nil
} }
func (w *SystemDevice) Events() <-chan wgTun.Event { func (w *systemDevice) Events() <-chan wgTun.Event {
return w.events return w.events
} }
func (w *SystemDevice) Close() error { func (w *systemDevice) Close() error {
close(w.events) close(w.events)
return w.device.Close() return w.device.Close()
} }
func (w *SystemDevice) BatchSize() int { func (w *systemDevice) BatchSize() int {
if w.batchDevice != nil { if w.batchDevice != nil {
return w.batchDevice.BatchSize() return w.batchDevice.BatchSize()
} }

View file

@ -0,0 +1,182 @@
//go:build with_gvisor
package wireguard
import (
"net/netip"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/wireguard-go/device"
)
var _ Device = (*systemStackDevice)(nil)
type systemStackDevice struct {
*systemDevice
stack *stack.Stack
endpoint *deviceEndpoint
writeBufs [][]byte
}
func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
system, err := newSystemDevice(options)
if err != nil {
return nil, err
}
endpoint := &deviceEndpoint{
mtu: options.MTU,
done: make(chan struct{}),
}
ipStack, err := tun.NewGVisorStack(endpoint)
if err != nil {
return nil, err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
return &systemStackDevice{
systemDevice: system,
stack: ipStack,
endpoint: endpoint,
}, nil
}
func (w *systemStackDevice) SetDevice(device *device.Device) {
w.endpoint.device = device
}
func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil {
w.writeBufs = w.writeBufs[:0]
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
w.writeBufs = append(w.writeBufs, packet)
}
}
if len(w.writeBufs) > 0 {
return w.batchDevice.BatchWrite(bufs, offset)
}
} else {
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
}
if err != nil {
return
}
}
}
// WireGuard will not read count
return
}
func (w *systemStackDevice) Close() error {
close(w.endpoint.done)
w.stack.Close()
for _, endpoint := range w.stack.CleanupEndpoints() {
endpoint.Abort()
}
w.stack.Wait()
return w.systemDevice.Close()
}
func (w *systemStackDevice) writeStack(packet []byte) bool {
var (
networkProtocol tcpip.NetworkProtocolNumber
destination netip.Addr
)
switch header.IPVersion(packet) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4())
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16())
}
for _, prefix := range w.options.Address {
if prefix.Contains(destination) {
return false
}
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
packetBuffer.DecRef()
return true
}
type deviceEndpoint struct {
mtu uint32
done chan struct{}
device *device.Device
dispatcher stack.NetworkDispatcher
}
func (ep *deviceEndpoint) MTU() uint32 {
return ep.mtu
}
func (ep *deviceEndpoint) SetMTU(mtu uint32) {
}
func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityRXChecksumOffload
}
func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
ep.dispatcher = dispatcher
}
func (ep *deviceEndpoint) IsAttached() bool {
return ep.dispatcher != nil
}
func (ep *deviceEndpoint) Wait() {
}
func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
return true
}
func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
destination := packetBuffer.Network().DestinationAddress()
ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
}
return list.Len(), nil
}
func (ep *deviceEndpoint) Close() {
}
func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
}

View file

@ -1,35 +1,260 @@
package wireguard package wireguard
import ( import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip" "net/netip"
"os"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn" "github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
"go4.org/netipx"
) )
var _ conn.Endpoint = (*Endpoint)(nil) type Endpoint struct {
options EndpointOptions
type Endpoint netip.AddrPort peers []peerConfig
ipcConf string
func (e Endpoint) ClearSrc() { allowedAddress []netip.Prefix
tunDevice Device
device *device.Device
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
} }
func (e Endpoint) SrcToString() string { func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
return "" if options.PrivateKey == "" {
return nil, E.New("missing private key")
}
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey := hex.EncodeToString(privateKeyBytes)
ipcConf := "private_key=" + privateKey
if options.ListenPort != 0 {
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
}
var peers []peerConfig
for peerIndex, rawPeer := range options.Peers {
peer := peerConfig{
allowedIPs: rawPeer.AllowedIPs,
keepalive: rawPeer.PersistentKeepaliveInterval,
}
if rawPeer.Endpoint.Addr.IsValid() {
peer.endpoint = rawPeer.Endpoint.AddrPort()
} else if rawPeer.Endpoint.IsFqdn() {
peer.destination = rawPeer.Endpoint
}
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
if rawPeer.PreSharedKey != "" {
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
}
copy(peer.reserved[:], rawPeer.Reserved[:])
}
peers = append(peers, peer)
}
var allowedPrefixBuilder netipx.IPSetBuilder
for _, peer := range options.Peers {
for _, prefix := range peer.AllowedIPs {
allowedPrefixBuilder.AddPrefix(prefix)
}
}
allowedIPSet, err := allowedPrefixBuilder.IPSet()
if err != nil {
return nil, err
}
allowedAddresses := allowedIPSet.Prefixes()
if options.MTU == 0 {
options.MTU = 1408
}
deviceOptions := DeviceOptions{
Context: options.Context,
Logger: options.Logger,
System: options.System,
Handler: options.Handler,
UDPTimeout: options.UDPTimeout,
CreateDialer: options.CreateDialer,
Name: options.Name,
MTU: options.MTU,
GSO: options.GSO,
Address: options.Address,
AllowedAddress: allowedAddresses,
}
tunDevice, err := NewDevice(deviceOptions)
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
return &Endpoint{
options: options,
peers: peers,
ipcConf: ipcConf,
allowedAddress: allowedAddresses,
tunDevice: tunDevice,
}, nil
} }
func (e Endpoint) DstToString() string { func (e *Endpoint) Start(resolve bool) error {
return (netip.AddrPort)(e).String() if common.Any(e.peers, func(peer peerConfig) bool {
return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
}) {
if !resolve {
return nil
}
for peerIndex, peer := range e.peers {
if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
continue
}
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
if err != nil {
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
}
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
}
} else if resolve {
return nil
}
var bind conn.Bind
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
if isWgListener {
bind = conn.NewStdNetBind(wgListener)
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
if len(e.peers) == 1 {
isConnect = true
connectAddr = e.peers[0].endpoint
reserved = e.peers[0].reserved
}
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
}
if isWgListener || len(e.peers) > 1 {
for _, peer := range e.peers {
if peer.reserved != [3]uint8{} {
bind.SetReservedForEndpoint(peer.endpoint, peer.reserved)
}
}
}
err := e.tunDevice.Start()
if err != nil {
return err
}
logger := &device.Logger{
Verbosef: func(format string, args ...interface{}) {
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
e.tunDevice.SetDevice(wgDevice)
ipcConf := e.ipcConf
for _, peer := range e.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
e.device = wgDevice
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
if e.pauseManager != nil {
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
}
return nil
} }
func (e Endpoint) DstToBytes() []byte { func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
b, _ := (netip.AddrPort)(e).MarshalBinary() if !destination.Addr.IsValid() {
return b return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.DialContext(ctx, network, destination)
} }
func (e Endpoint) DstIP() netip.Addr { func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return (netip.AddrPort)(e).Addr() if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.ListenPacket(ctx, destination)
} }
func (e Endpoint) SrcIP() netip.Addr { func (e *Endpoint) BindUpdate() error {
return netip.Addr{} return e.device.BindUpdate()
}
func (e *Endpoint) Close() error {
if e.device != nil {
e.device.Close()
}
if e.pauseCallback != nil {
e.pauseManager.UnregisterCallback(e.pauseCallback)
}
return nil
}
func (e *Endpoint) onPauseUpdated(event int) {
switch event {
case pause.EventDevicePaused:
e.device.Down()
case pause.EventDeviceWake:
e.device.Up()
}
}
type peerConfig struct {
destination M.Socksaddr
endpoint netip.AddrPort
publicKeyHex string
preSharedKeyHex string
allowedIPs []netip.Prefix
keepalive uint16
reserved [3]uint8
}
func (c peerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.publicKeyHex
if c.endpoint.IsValid() {
ipcLines += "\nendpoint=" + c.endpoint.String()
}
if c.preSharedKeyHex != "" {
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
}
for _, allowedIP := range c.allowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP.String()
}
if c.keepalive > 0 {
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
}
return ipcLines
} }

View file

@ -0,0 +1,40 @@
package wireguard
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type EndpointOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
Dialer N.Dialer
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
PrivateKey string
ListenPort uint16
ResolvePeer func(domain string) (netip.Addr, error)
Peers []PeerOptions
Workers int
}
type PeerOptions struct {
Endpoint M.Socksaddr
PublicKey string
PreSharedKey string
AllowedIPs []netip.Prefix
PersistentKeepaliveInterval uint16
Reserved []uint8
}

View file

@ -1,148 +0,0 @@
package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"net/netip"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type PeerConfig struct {
destination M.Socksaddr
domainStrategy dns.DomainStrategy
Endpoint netip.AddrPort
PublicKey string
PreSharedKey string
AllowedIPs []string
Reserved [3]uint8
}
func (c PeerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.PublicKey
ipcLines += "\nendpoint=" + c.Endpoint.String()
if c.PreSharedKey != "" {
ipcLines += "\npreshared_key=" + c.PreSharedKey
}
for _, allowedIP := range c.AllowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP
}
return ipcLines
}
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
var peers []PeerConfig
if len(options.Peers) > 0 {
for peerIndex, rawPeer := range options.Peers {
peer := PeerConfig{
AllowedIPs: rawPeer.AllowedIPs,
}
destination := rawPeer.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if rawPeer.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed_ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
} else {
peer := PeerConfig{}
var (
addressHas4 bool
addressHas6 bool
)
for _, localAddress := range options.LocalAddress {
if localAddress.Addr().Is4() {
addressHas4 = true
} else {
addressHas6 = true
}
}
if addressHas4 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
}
if addressHas6 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
}
destination := options.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
if err != nil {
return nil, E.Cause(err, "decode peer public key")
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if options.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key")
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(options.Reserved) > 0 {
if len(options.Reserved) != 3 {
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
return peers, nil
}
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
for peerIndex, peer := range peers {
if peer.Endpoint.IsValid() {
continue
}
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
if err != nil {
if len(peers) == 1 {
return E.Cause(err, "resolve endpoint domain")
} else {
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
}
}
if len(destinationAddresses) == 0 {
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
}
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
}
return nil
}