mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-23 17:11:29 +00:00
refactor: WireGuard endpoint
This commit is contained in:
parent
7dbc105f89
commit
e5bfd9e6b1
28
adapter/endpoint.go
Normal file
28
adapter/endpoint.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Endpoint interface {
|
||||||
|
Lifecycle
|
||||||
|
Type() string
|
||||||
|
Tag() string
|
||||||
|
Outbound
|
||||||
|
}
|
||||||
|
|
||||||
|
type EndpointRegistry interface {
|
||||||
|
option.EndpointOptionsRegistry
|
||||||
|
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) (Endpoint, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EndpointManager interface {
|
||||||
|
Lifecycle
|
||||||
|
Endpoints() []Endpoint
|
||||||
|
Get(tag string) (Endpoint, bool)
|
||||||
|
Remove(tag string) error
|
||||||
|
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) error
|
||||||
|
}
|
43
adapter/endpoint/adapter.go
Normal file
43
adapter/endpoint/adapter.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package endpoint
|
||||||
|
|
||||||
|
import "github.com/sagernet/sing-box/option"
|
||||||
|
|
||||||
|
type Adapter struct {
|
||||||
|
endpointType string
|
||||||
|
endpointTag string
|
||||||
|
network []string
|
||||||
|
dependencies []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAdapter(endpointType string, endpointTag string, network []string, dependencies []string) Adapter {
|
||||||
|
return Adapter{
|
||||||
|
endpointType: endpointType,
|
||||||
|
endpointTag: endpointTag,
|
||||||
|
network: network,
|
||||||
|
dependencies: dependencies,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAdapterWithDialerOptions(endpointType string, endpointTag string, network []string, dialOptions option.DialerOptions) Adapter {
|
||||||
|
var dependencies []string
|
||||||
|
if dialOptions.Detour != "" {
|
||||||
|
dependencies = []string{dialOptions.Detour}
|
||||||
|
}
|
||||||
|
return NewAdapter(endpointType, endpointTag, network, dependencies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adapter) Type() string {
|
||||||
|
return a.endpointType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adapter) Tag() string {
|
||||||
|
return a.endpointTag
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adapter) Network() []string {
|
||||||
|
return a.network
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adapter) Dependencies() []string {
|
||||||
|
return a.dependencies
|
||||||
|
}
|
147
adapter/endpoint/manager.go
Normal file
147
adapter/endpoint/manager.go
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
package endpoint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ adapter.EndpointManager = (*Manager)(nil)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
logger log.ContextLogger
|
||||||
|
registry adapter.EndpointRegistry
|
||||||
|
access sync.Mutex
|
||||||
|
started bool
|
||||||
|
stage adapter.StartStage
|
||||||
|
endpoints []adapter.Endpoint
|
||||||
|
endpointByTag map[string]adapter.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
logger: logger,
|
||||||
|
registry: registry,
|
||||||
|
endpointByTag: make(map[string]adapter.Endpoint),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(stage adapter.StartStage) error {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if m.started && m.stage >= stage {
|
||||||
|
panic("already started")
|
||||||
|
}
|
||||||
|
m.started = true
|
||||||
|
m.stage = stage
|
||||||
|
if stage == adapter.StartStateStart {
|
||||||
|
// started with outbound manager
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, endpoint := range m.endpoints {
|
||||||
|
err := adapter.LegacyStart(endpoint, stage)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Close() error {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if !m.started {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.started = false
|
||||||
|
endpoints := m.endpoints
|
||||||
|
m.endpoints = nil
|
||||||
|
monitor := taskmonitor.New(m.logger, C.StopTimeout)
|
||||||
|
var err error
|
||||||
|
for _, endpoint := range endpoints {
|
||||||
|
monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||||
|
err = E.Append(err, endpoint.Close(), func(err error) error {
|
||||||
|
return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||||
|
})
|
||||||
|
monitor.Finish()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Endpoints() []adapter.Endpoint {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
return m.endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
endpoint, found := m.endpointByTag[tag]
|
||||||
|
return endpoint, found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Remove(tag string) error {
|
||||||
|
m.access.Lock()
|
||||||
|
endpoint, found := m.endpointByTag[tag]
|
||||||
|
if !found {
|
||||||
|
m.access.Unlock()
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
delete(m.endpointByTag, tag)
|
||||||
|
index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
|
||||||
|
return it == endpoint
|
||||||
|
})
|
||||||
|
if index == -1 {
|
||||||
|
panic("invalid endpoint index")
|
||||||
|
}
|
||||||
|
m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
|
||||||
|
started := m.started
|
||||||
|
m.access.Unlock()
|
||||||
|
if started {
|
||||||
|
return endpoint.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
|
||||||
|
endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
if m.started {
|
||||||
|
for _, stage := range adapter.ListStartStages {
|
||||||
|
err = adapter.LegacyStart(endpoint, stage)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
|
||||||
|
if m.started {
|
||||||
|
err = existsEndpoint.Close()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
|
||||||
|
return it == existsEndpoint
|
||||||
|
})
|
||||||
|
if existsIndex == -1 {
|
||||||
|
panic("invalid endpoint index")
|
||||||
|
}
|
||||||
|
m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
|
||||||
|
}
|
||||||
|
m.endpoints = append(m.endpoints, endpoint)
|
||||||
|
m.endpointByTag[tag] = endpoint
|
||||||
|
return nil
|
||||||
|
}
|
72
adapter/endpoint/registry.go
Normal file
72
adapter/endpoint/registry.go
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
package endpoint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options T) (adapter.Endpoint, error)
|
||||||
|
|
||||||
|
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
|
||||||
|
registry.register(outboundType, func() any {
|
||||||
|
return new(Options)
|
||||||
|
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Endpoint, error) {
|
||||||
|
var options *Options
|
||||||
|
if rawOptions != nil {
|
||||||
|
options = rawOptions.(*Options)
|
||||||
|
}
|
||||||
|
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ adapter.EndpointRegistry = (*Registry)(nil)
|
||||||
|
|
||||||
|
type (
|
||||||
|
optionsConstructorFunc func() any
|
||||||
|
constructorFunc func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Endpoint, error)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Registry struct {
|
||||||
|
access sync.Mutex
|
||||||
|
optionsType map[string]optionsConstructorFunc
|
||||||
|
constructor map[string]constructorFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
return &Registry{
|
||||||
|
optionsType: make(map[string]optionsConstructorFunc),
|
||||||
|
constructor: make(map[string]constructorFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Registry) CreateOptions(outboundType string) (any, bool) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
optionsConstructor, loaded := m.optionsType[outboundType]
|
||||||
|
if !loaded {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return optionsConstructor(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Endpoint, error) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
constructor, loaded := m.constructor[outboundType]
|
||||||
|
if !loaded {
|
||||||
|
return nil, E.New("outbound type not found: " + outboundType)
|
||||||
|
}
|
||||||
|
return constructor(ctx, router, logger, tag, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
|
||||||
|
m.access.Lock()
|
||||||
|
defer m.access.Unlock()
|
||||||
|
m.optionsType[outboundType] = optionsConstructor
|
||||||
|
m.constructor[outboundType] = constructor
|
||||||
|
}
|
|
@ -13,7 +13,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Inbound interface {
|
type Inbound interface {
|
||||||
Service
|
Lifecycle
|
||||||
Type() string
|
Type() string
|
||||||
Tag() string
|
Tag() string
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ var _ adapter.InboundManager = (*Manager)(nil)
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
registry adapter.InboundRegistry
|
registry adapter.InboundRegistry
|
||||||
|
endpoint adapter.EndpointManager
|
||||||
access sync.Mutex
|
access sync.Mutex
|
||||||
started bool
|
started bool
|
||||||
stage adapter.StartStage
|
stage adapter.StartStage
|
||||||
|
@ -25,10 +26,11 @@ type Manager struct {
|
||||||
inboundByTag map[string]adapter.Inbound
|
inboundByTag map[string]adapter.Inbound
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager {
|
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endpoint adapter.EndpointManager) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
registry: registry,
|
registry: registry,
|
||||||
|
endpoint: endpoint,
|
||||||
inboundByTag: make(map[string]adapter.Inbound),
|
inboundByTag: make(map[string]adapter.Inbound),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -79,9 +81,12 @@ func (m *Manager) Inbounds() []adapter.Inbound {
|
||||||
|
|
||||||
func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
|
func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
|
||||||
m.access.Lock()
|
m.access.Lock()
|
||||||
defer m.access.Unlock()
|
|
||||||
inbound, found := m.inboundByTag[tag]
|
inbound, found := m.inboundByTag[tag]
|
||||||
return inbound, found
|
m.access.Unlock()
|
||||||
|
if found {
|
||||||
|
return inbound, true
|
||||||
|
}
|
||||||
|
return m.endpoint.Get(tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Remove(tag string) error {
|
func (m *Manager) Remove(tag string) error {
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package adapter
|
package adapter
|
||||||
|
|
||||||
func LegacyStart(starter any, stage StartStage) error {
|
func LegacyStart(starter any, stage StartStage) error {
|
||||||
|
if lifecycle, isLifecycle := starter.(Lifecycle); isLifecycle {
|
||||||
|
return lifecycle.Start(stage)
|
||||||
|
}
|
||||||
switch stage {
|
switch stage {
|
||||||
case StartStateInitialize:
|
case StartStateInitialize:
|
||||||
if preStarter, isPreStarter := starter.(interface {
|
if preStarter, isPreStarter := starter.(interface {
|
||||||
|
|
|
@ -5,35 +5,35 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adapter struct {
|
type Adapter struct {
|
||||||
protocol string
|
outboundType string
|
||||||
|
outboundTag string
|
||||||
network []string
|
network []string
|
||||||
tag string
|
|
||||||
dependencies []string
|
dependencies []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAdapter(protocol string, network []string, tag string, dependencies []string) Adapter {
|
func NewAdapter(outboundType string, outboundTag string, network []string, dependencies []string) Adapter {
|
||||||
return Adapter{
|
return Adapter{
|
||||||
protocol: protocol,
|
outboundType: outboundType,
|
||||||
|
outboundTag: outboundTag,
|
||||||
network: network,
|
network: network,
|
||||||
tag: tag,
|
|
||||||
dependencies: dependencies,
|
dependencies: dependencies,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAdapterWithDialerOptions(protocol string, network []string, tag string, dialOptions option.DialerOptions) Adapter {
|
func NewAdapterWithDialerOptions(outboundType string, outboundTag string, network []string, dialOptions option.DialerOptions) Adapter {
|
||||||
var dependencies []string
|
var dependencies []string
|
||||||
if dialOptions.Detour != "" {
|
if dialOptions.Detour != "" {
|
||||||
dependencies = []string{dialOptions.Detour}
|
dependencies = []string{dialOptions.Detour}
|
||||||
}
|
}
|
||||||
return NewAdapter(protocol, network, tag, dependencies)
|
return NewAdapter(outboundType, outboundTag, network, dependencies)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adapter) Type() string {
|
func (a *Adapter) Type() string {
|
||||||
return a.protocol
|
return a.outboundType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adapter) Tag() string {
|
func (a *Adapter) Tag() string {
|
||||||
return a.tag
|
return a.outboundTag
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adapter) Network() []string {
|
func (a *Adapter) Network() []string {
|
||||||
|
|
|
@ -21,6 +21,7 @@ var _ adapter.OutboundManager = (*Manager)(nil)
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
registry adapter.OutboundRegistry
|
registry adapter.OutboundRegistry
|
||||||
|
endpoint adapter.EndpointManager
|
||||||
defaultTag string
|
defaultTag string
|
||||||
access sync.Mutex
|
access sync.Mutex
|
||||||
started bool
|
started bool
|
||||||
|
@ -32,10 +33,11 @@ type Manager struct {
|
||||||
defaultOutboundFallback adapter.Outbound
|
defaultOutboundFallback adapter.Outbound
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager {
|
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
registry: registry,
|
registry: registry,
|
||||||
|
endpoint: endpoint,
|
||||||
defaultTag: defaultTag,
|
defaultTag: defaultTag,
|
||||||
outboundByTag: make(map[string]adapter.Outbound),
|
outboundByTag: make(map[string]adapter.Outbound),
|
||||||
dependByTag: make(map[string][]string),
|
dependByTag: make(map[string][]string),
|
||||||
|
@ -56,7 +58,14 @@ func (m *Manager) Start(stage adapter.StartStage) error {
|
||||||
outbounds := m.outbounds
|
outbounds := m.outbounds
|
||||||
m.access.Unlock()
|
m.access.Unlock()
|
||||||
if stage == adapter.StartStateStart {
|
if stage == adapter.StartStateStart {
|
||||||
return m.startOutbounds(outbounds)
|
if m.defaultTag != "" && m.defaultOutbound == nil {
|
||||||
|
defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
|
||||||
|
if !loaded {
|
||||||
|
return E.New("default outbound not found: ", m.defaultTag)
|
||||||
|
}
|
||||||
|
m.defaultOutbound = defaultEndpoint
|
||||||
|
}
|
||||||
|
return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
|
||||||
} else {
|
} else {
|
||||||
for _, outbound := range outbounds {
|
for _, outbound := range outbounds {
|
||||||
err := adapter.LegacyStart(outbound, stage)
|
err := adapter.LegacyStart(outbound, stage)
|
||||||
|
@ -87,7 +96,14 @@ func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
|
||||||
}
|
}
|
||||||
started[outboundTag] = true
|
started[outboundTag] = true
|
||||||
canContinue = true
|
canContinue = true
|
||||||
if starter, isStarter := outboundToStart.(interface {
|
if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
|
||||||
|
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
|
||||||
|
err := starter.Start(adapter.StartStateStart)
|
||||||
|
monitor.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
|
||||||
|
}
|
||||||
|
} else if starter, isStarter := outboundToStart.(interface {
|
||||||
Start() error
|
Start() error
|
||||||
}); isStarter {
|
}); isStarter {
|
||||||
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
|
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
|
||||||
|
@ -160,9 +176,12 @@ func (m *Manager) Outbounds() []adapter.Outbound {
|
||||||
|
|
||||||
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
|
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
|
||||||
m.access.Lock()
|
m.access.Lock()
|
||||||
defer m.access.Unlock()
|
|
||||||
outbound, found := m.outboundByTag[tag]
|
outbound, found := m.outboundByTag[tag]
|
||||||
return outbound, found
|
m.access.Unlock()
|
||||||
|
if found {
|
||||||
|
return outbound, true
|
||||||
|
}
|
||||||
|
return m.endpoint.Get(tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Default() adapter.Outbound {
|
func (m *Manager) Default() adapter.Outbound {
|
||||||
|
|
51
box.go
51
box.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||||
"github.com/sagernet/sing-box/adapter/inbound"
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
@ -36,6 +37,7 @@ type Box struct {
|
||||||
logFactory log.Factory
|
logFactory log.Factory
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
network *route.NetworkManager
|
network *route.NetworkManager
|
||||||
|
endpoint *endpoint.Manager
|
||||||
inbound *inbound.Manager
|
inbound *inbound.Manager
|
||||||
outbound *outbound.Manager
|
outbound *outbound.Manager
|
||||||
connection *route.ConnectionManager
|
connection *route.ConnectionManager
|
||||||
|
@ -54,6 +56,7 @@ func Context(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
inboundRegistry adapter.InboundRegistry,
|
inboundRegistry adapter.InboundRegistry,
|
||||||
outboundRegistry adapter.OutboundRegistry,
|
outboundRegistry adapter.OutboundRegistry,
|
||||||
|
endpointRegistry adapter.EndpointRegistry,
|
||||||
) context.Context {
|
) context.Context {
|
||||||
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
|
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
|
||||||
service.FromContext[adapter.InboundRegistry](ctx) == nil {
|
service.FromContext[adapter.InboundRegistry](ctx) == nil {
|
||||||
|
@ -65,6 +68,11 @@ func Context(
|
||||||
ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry)
|
ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry)
|
||||||
ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry)
|
ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry)
|
||||||
}
|
}
|
||||||
|
if service.FromContext[option.EndpointOptionsRegistry](ctx) == nil ||
|
||||||
|
service.FromContext[adapter.EndpointRegistry](ctx) == nil {
|
||||||
|
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
|
||||||
|
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
|
||||||
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,12 +84,16 @@ func New(options Options) (*Box, error) {
|
||||||
}
|
}
|
||||||
ctx = service.ContextWithDefaultRegistry(ctx)
|
ctx = service.ContextWithDefaultRegistry(ctx)
|
||||||
|
|
||||||
|
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
|
||||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||||
|
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||||
|
|
||||||
|
if endpointRegistry == nil {
|
||||||
|
return nil, E.New("missing endpoint registry in context")
|
||||||
|
}
|
||||||
if inboundRegistry == nil {
|
if inboundRegistry == nil {
|
||||||
return nil, E.New("missing inbound registry in context")
|
return nil, E.New("missing inbound registry in context")
|
||||||
}
|
}
|
||||||
|
|
||||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
|
||||||
if outboundRegistry == nil {
|
if outboundRegistry == nil {
|
||||||
return nil, E.New("missing outbound registry in context")
|
return nil, E.New("missing outbound registry in context")
|
||||||
}
|
}
|
||||||
|
@ -119,8 +131,10 @@ func New(options Options) (*Box, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
routeOptions := common.PtrValueOrDefault(options.Route)
|
routeOptions := common.PtrValueOrDefault(options.Route)
|
||||||
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry)
|
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
|
||||||
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final)
|
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
|
||||||
|
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
|
||||||
|
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
|
||||||
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
|
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
|
||||||
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
|
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
|
||||||
|
|
||||||
|
@ -135,6 +149,24 @@ func New(options Options) (*Box, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "initialize router")
|
return nil, E.Cause(err, "initialize router")
|
||||||
}
|
}
|
||||||
|
for i, endpointOptions := range options.Endpoints {
|
||||||
|
var tag string
|
||||||
|
if endpointOptions.Tag != "" {
|
||||||
|
tag = endpointOptions.Tag
|
||||||
|
} else {
|
||||||
|
tag = F.ToString(i)
|
||||||
|
}
|
||||||
|
err = endpointManager.Create(ctx,
|
||||||
|
router,
|
||||||
|
logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")),
|
||||||
|
tag,
|
||||||
|
endpointOptions.Type,
|
||||||
|
endpointOptions.Options,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "initialize inbound[", i, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
for i, inboundOptions := range options.Inbounds {
|
for i, inboundOptions := range options.Inbounds {
|
||||||
var tag string
|
var tag string
|
||||||
if inboundOptions.Tag != "" {
|
if inboundOptions.Tag != "" {
|
||||||
|
@ -241,6 +273,7 @@ func New(options Options) (*Box, error) {
|
||||||
}
|
}
|
||||||
return &Box{
|
return &Box{
|
||||||
network: networkManager,
|
network: networkManager,
|
||||||
|
endpoint: endpointManager,
|
||||||
inbound: inboundManager,
|
inbound: inboundManager,
|
||||||
outbound: outboundManager,
|
outbound: outboundManager,
|
||||||
connection: connectionManager,
|
connection: connectionManager,
|
||||||
|
@ -303,7 +336,7 @@ func (s *Box) preStart() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound)
|
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound, s.endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -327,7 +360,11 @@ func (s *Box) start() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound)
|
err = adapter.Start(adapter.StartStateStart, s.endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound, s.endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -335,7 +372,7 @@ func (s *Box) start() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound)
|
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound, s.endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) {
|
||||||
configPaths = append(configPaths, "config.json")
|
configPaths = append(configPaths, "config.json")
|
||||||
}
|
}
|
||||||
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
|
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
|
||||||
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry())
|
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||||
}
|
}
|
||||||
|
|
|
@ -279,7 +279,7 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
|
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
|
||||||
return trackPacketConn(d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay))
|
return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address))
|
||||||
}
|
}
|
||||||
|
|
||||||
func trackConn(conn net.Conn, err error) (net.Conn, error) {
|
func trackConn(conn net.Conn, err error) (net.Conn, error) {
|
||||||
|
|
|
@ -109,6 +109,15 @@ var OptionDestinationOverrideFields = Note{
|
||||||
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options",
|
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var OptionWireGuardOutbound = Note{
|
||||||
|
Name: "wireguard-outbound",
|
||||||
|
Description: "legacy wireguard outbound",
|
||||||
|
DeprecatedVersion: "1.11.0",
|
||||||
|
ScheduledVersion: "1.13.0",
|
||||||
|
EnvName: "WIREGUARD_OUTBOUND",
|
||||||
|
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-wireguard-outbound-to-endpoint",
|
||||||
|
}
|
||||||
|
|
||||||
var Options = []Note{
|
var Options = []Note{
|
||||||
OptionBadMatchSource,
|
OptionBadMatchSource,
|
||||||
OptionGEOIP,
|
OptionGEOIP,
|
||||||
|
@ -117,4 +126,5 @@ var Options = []Note{
|
||||||
OptionSpecialOutbounds,
|
OptionSpecialOutbounds,
|
||||||
OptionInboundOptions,
|
OptionInboundOptions,
|
||||||
OptionDestinationOverrideFields,
|
OptionDestinationOverrideFields,
|
||||||
|
OptionWireGuardOutbound,
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckConfig(configContent string) error {
|
func CheckConfig(configContent string) error {
|
||||||
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
|
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||||
options, err := parseConfig(ctx, configContent)
|
options, err := parseConfig(ctx, configContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -131,7 +131,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatConfig(configContent string) (string, error) {
|
func FormatConfig(configContent string) (string, error) {
|
||||||
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()), configContent)
|
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,7 @@ type BoxService struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
|
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
|
||||||
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
|
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||||
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
|
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
|
||||||
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
|
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
|
||||||
options, err := parseConfig(ctx, configContent)
|
options, err := parseConfig(ctx, configContent)
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -36,7 +36,7 @@ require (
|
||||||
github.com/sagernet/sing-vmess v0.1.12
|
github.com/sagernet/sing-vmess v0.1.12
|
||||||
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
|
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
|
||||||
github.com/sagernet/utls v1.6.7
|
github.com/sagernet/utls v1.6.7
|
||||||
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8
|
github.com/sagernet/wireguard-go v0.0.1-beta.4
|
||||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
|
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
|
||||||
github.com/spf13/cobra v1.8.1
|
github.com/spf13/cobra v1.8.1
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
|
@ -92,6 +92,7 @@ require (
|
||||||
golang.org/x/text v0.20.0 // indirect
|
golang.org/x/text v0.20.0 // indirect
|
||||||
golang.org/x/time v0.7.0 // indirect
|
golang.org/x/time v0.7.0 // indirect
|
||||||
golang.org/x/tools v0.24.0 // indirect
|
golang.org/x/tools v0.24.0 // indirect
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
|
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -132,8 +132,8 @@ github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxe
|
||||||
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
|
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
|
||||||
github.com/sagernet/utls v1.6.7 h1:Ep3+aJ8FUGGta+II2IEVNUc3EDhaRCZINWkj/LloIA8=
|
github.com/sagernet/utls v1.6.7 h1:Ep3+aJ8FUGGta+II2IEVNUc3EDhaRCZINWkj/LloIA8=
|
||||||
github.com/sagernet/utls v1.6.7/go.mod h1:Uua1TKO/FFuAhLr9rkaVnnrTmmiItzDjv1BUb2+ERwM=
|
github.com/sagernet/utls v1.6.7/go.mod h1:Uua1TKO/FFuAhLr9rkaVnnrTmmiItzDjv1BUb2+ERwM=
|
||||||
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8 h1:R0OMYAScomNAVpTfbHFpxqJpvwuhxSRi+g6z7gZhABs=
|
github.com/sagernet/wireguard-go v0.0.1-beta.4 h1:8uyM5fxfEXdu4RH05uOK+v25i3lTNdCYMPSAUJ14FnI=
|
||||||
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8/go.mod h1:K4J7/npM+VAMUeUmTa2JaA02JmyheP0GpRBOUvn3ecc=
|
github.com/sagernet/wireguard-go v0.0.1-beta.4/go.mod h1:jGXij2Gn2wbrWuYNUmmNhf1dwcZtvyAvQoe8Xd8MbUo=
|
||||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
|
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
|
||||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
|
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
|
||||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||||
|
@ -195,6 +195,8 @@ golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY=
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||||
"github.com/sagernet/sing-box/adapter/inbound"
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
@ -82,6 +83,14 @@ func OutboundRegistry() *outbound.Registry {
|
||||||
return registry
|
return registry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EndpointRegistry() *endpoint.Registry {
|
||||||
|
registry := endpoint.NewRegistry()
|
||||||
|
|
||||||
|
registerWireGuardEndpoint(registry)
|
||||||
|
|
||||||
|
return registry
|
||||||
|
}
|
||||||
|
|
||||||
func registerStubForRemovedInbounds(registry *inbound.Registry) {
|
func registerStubForRemovedInbounds(registry *inbound.Registry) {
|
||||||
inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
|
inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
|
||||||
return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0")
|
return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0")
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
package include
|
package include
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
"github.com/sagernet/sing-box/protocol/wireguard"
|
"github.com/sagernet/sing-box/protocol/wireguard"
|
||||||
)
|
)
|
||||||
|
@ -10,3 +11,7 @@ import (
|
||||||
func registerWireGuardOutbound(registry *outbound.Registry) {
|
func registerWireGuardOutbound(registry *outbound.Registry) {
|
||||||
wireguard.RegisterOutbound(registry)
|
wireguard.RegisterOutbound(registry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerWireGuardEndpoint(registry *endpoint.Registry) {
|
||||||
|
wireguard.RegisterEndpoint(registry)
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
|
@ -14,7 +15,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func registerWireGuardOutbound(registry *outbound.Registry) {
|
func registerWireGuardOutbound(registry *outbound.Registry) {
|
||||||
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
|
outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
|
||||||
|
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerWireGuardEndpoint(registry *endpoint.Registry) {
|
||||||
|
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
|
||||||
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
|
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
47
option/endpoint.go
Normal file
47
option/endpoint.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package option
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/json"
|
||||||
|
"github.com/sagernet/sing/common/json/badjson"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EndpointOptionsRegistry interface {
|
||||||
|
CreateOptions(endpointType string) (any, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type _Endpoint struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Tag string `json:"tag,omitempty"`
|
||||||
|
Options any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Endpoint _Endpoint
|
||||||
|
|
||||||
|
func (h *Endpoint) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||||
|
return badjson.MarshallObjectsContext(ctx, (*_Endpoint)(h), h.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Endpoint) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||||
|
err := json.UnmarshalContext(ctx, content, (*_Endpoint)(h))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registry := service.FromContext[EndpointOptionsRegistry](ctx)
|
||||||
|
if registry == nil {
|
||||||
|
return E.New("missing Endpoint fields registry in context")
|
||||||
|
}
|
||||||
|
options, loaded := registry.CreateOptions(h.Type)
|
||||||
|
if !loaded {
|
||||||
|
return E.New("unknown inbound type: ", h.Type)
|
||||||
|
}
|
||||||
|
err = badjson.UnmarshallExcludedContext(ctx, content, (*_Endpoint)(h), options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.Options = options
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -28,7 +28,7 @@ func (h *Inbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||||
err := json.Unmarshal(content, (*_Inbound)(h))
|
err := json.UnmarshalContext(ctx, content, (*_Inbound)(h))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ type _Options struct {
|
||||||
Log *LogOptions `json:"log,omitempty"`
|
Log *LogOptions `json:"log,omitempty"`
|
||||||
DNS *DNSOptions `json:"dns,omitempty"`
|
DNS *DNSOptions `json:"dns,omitempty"`
|
||||||
NTP *NTPOptions `json:"ntp,omitempty"`
|
NTP *NTPOptions `json:"ntp,omitempty"`
|
||||||
|
Endpoints []Endpoint `json:"endpoints,omitempty"`
|
||||||
Inbounds []Inbound `json:"inbounds,omitempty"`
|
Inbounds []Inbound `json:"inbounds,omitempty"`
|
||||||
Outbounds []Outbound `json:"outbounds,omitempty"`
|
Outbounds []Outbound `json:"outbounds,omitempty"`
|
||||||
Route *RouteOptions `json:"route,omitempty"`
|
Route *RouteOptions `json:"route,omitempty"`
|
||||||
|
|
|
@ -30,7 +30,7 @@ func (h *Outbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||||
err := json.Unmarshal(content, (*_Outbound)(h))
|
err := json.UnmarshalContext(ctx, content, (*_Outbound)(h))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,14 +6,38 @@ import (
|
||||||
"github.com/sagernet/sing/common/json/badoption"
|
"github.com/sagernet/sing/common/json/badoption"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WireGuardOutboundOptions struct {
|
type WireGuardEndpointOptions struct {
|
||||||
|
System bool `json:"system,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
MTU uint32 `json:"mtu,omitempty"`
|
||||||
|
GSO bool `json:"gso,omitempty"`
|
||||||
|
Address badoption.Listable[netip.Prefix] `json:"address"`
|
||||||
|
PrivateKey string `json:"private_key"`
|
||||||
|
ListenPort uint16 `json:"listen_port,omitempty"`
|
||||||
|
Peers []WireGuardPeer `json:"peers,omitempty"`
|
||||||
|
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
|
||||||
|
Workers int `json:"workers,omitempty"`
|
||||||
|
DialerOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type WireGuardPeer struct {
|
||||||
|
Address string `json:"address,omitempty"`
|
||||||
|
Port uint16 `json:"port,omitempty"`
|
||||||
|
PublicKey string `json:"public_key,omitempty"`
|
||||||
|
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
||||||
|
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
|
||||||
|
PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"`
|
||||||
|
Reserved []uint8 `json:"reserved,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LegacyWireGuardOutboundOptions struct {
|
||||||
DialerOptions
|
DialerOptions
|
||||||
SystemInterface bool `json:"system_interface,omitempty"`
|
SystemInterface bool `json:"system_interface,omitempty"`
|
||||||
GSO bool `json:"gso,omitempty"`
|
GSO bool `json:"gso,omitempty"`
|
||||||
InterfaceName string `json:"interface_name,omitempty"`
|
InterfaceName string `json:"interface_name,omitempty"`
|
||||||
LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"`
|
LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"`
|
||||||
PrivateKey string `json:"private_key"`
|
PrivateKey string `json:"private_key"`
|
||||||
Peers []WireGuardPeer `json:"peers,omitempty"`
|
Peers []LegacyWireGuardPeer `json:"peers,omitempty"`
|
||||||
ServerOptions
|
ServerOptions
|
||||||
PeerPublicKey string `json:"peer_public_key"`
|
PeerPublicKey string `json:"peer_public_key"`
|
||||||
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
||||||
|
@ -23,10 +47,10 @@ type WireGuardOutboundOptions struct {
|
||||||
Network NetworkList `json:"network,omitempty"`
|
Network NetworkList `json:"network,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type WireGuardPeer struct {
|
type LegacyWireGuardPeer struct {
|
||||||
ServerOptions
|
ServerOptions
|
||||||
PublicKey string `json:"public_key,omitempty"`
|
PublicKey string `json:"public_key,omitempty"`
|
||||||
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
||||||
AllowedIPs badoption.Listable[string] `json:"allowed_ips,omitempty"`
|
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
|
||||||
Reserved []uint8 `json:"reserved,omitempty"`
|
Reserved []uint8 `json:"reserved,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ type Outbound struct {
|
||||||
|
|
||||||
func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) {
|
func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) {
|
||||||
return &Outbound{
|
return &Outbound{
|
||||||
Adapter: outbound.NewAdapter(C.TypeBlock, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil),
|
Adapter: outbound.NewAdapter(C.TypeBlock, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Inbound) Start() error {
|
func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return i.listener.Start()
|
return i.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
domainStrategy: dns.DomainStrategy(options.DomainStrategy),
|
domainStrategy: dns.DomainStrategy(options.DomainStrategy),
|
||||||
fallbackDelay: time.Duration(options.FallbackDelay),
|
fallbackDelay: time.Duration(options.FallbackDelay),
|
||||||
|
|
|
@ -28,7 +28,7 @@ type Outbound struct {
|
||||||
|
|
||||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
|
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
|
||||||
return &Outbound{
|
return &Outbound{
|
||||||
Adapter: outbound.NewAdapter(C.TypeDNS, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil),
|
Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
|
||||||
router: router,
|
router: router,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}, nil
|
}, nil
|
||||||
|
|
|
@ -38,7 +38,7 @@ type Selector struct {
|
||||||
|
|
||||||
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
|
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
|
||||||
outbound := &Selector{
|
outbound := &Selector{
|
||||||
Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds),
|
Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
|
|
@ -49,7 +49,7 @@ type URLTest struct {
|
||||||
|
|
||||||
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
|
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
|
||||||
outbound := &URLTest{
|
outbound := &URLTest{
|
||||||
Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds),
|
Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
router: router,
|
router: router,
|
||||||
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
||||||
|
|
|
@ -61,7 +61,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if h.tlsConfig != nil {
|
if h.tlsConfig != nil {
|
||||||
err := h.tlsConfig.Start()
|
err := h.tlsConfig.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -39,7 +39,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Outbound{
|
return &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: sHTTP.NewClient(sHTTP.Options{
|
client: sHTTP.NewClient(sHTTP.Options{
|
||||||
Dialer: detour,
|
Dialer: detour,
|
||||||
|
|
|
@ -160,7 +160,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
||||||
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if h.tlsConfig != nil {
|
if h.tlsConfig != nil {
|
||||||
err := h.tlsConfig.Start()
|
err := h.tlsConfig.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -95,7 +95,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Outbound{
|
return &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, networkList, tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: client,
|
client: client,
|
||||||
}, nil
|
}, nil
|
||||||
|
|
|
@ -171,7 +171,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
||||||
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if h.tlsConfig != nil {
|
if h.tlsConfig != nil {
|
||||||
err := h.tlsConfig.Start()
|
err := h.tlsConfig.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -81,7 +81,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Outbound{
|
return &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, networkList, tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: client,
|
client: client,
|
||||||
}, nil
|
}, nil
|
||||||
|
|
|
@ -54,7 +54,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return h.listener.Start()
|
return h.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Inbound) Start() error {
|
func (n *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
var tlsConfig *tls.STDConfig
|
var tlsConfig *tls.STDConfig
|
||||||
if n.tlsConfig != nil {
|
if n.tlsConfig != nil {
|
||||||
err := n.tlsConfig.Start()
|
err := n.tlsConfig.Start()
|
||||||
|
|
|
@ -42,7 +42,10 @@ func NewRedirect(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return redirect, nil
|
return redirect, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Redirect) Start() error {
|
func (h *Redirect) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return h.listener.Start()
|
return h.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,10 @@ func NewTProxy(ctx context.Context, router adapter.Router, logger log.ContextLog
|
||||||
return tproxy, nil
|
return tproxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TProxy) Start() error {
|
func (t *TProxy) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
err := t.listener.Start()
|
err := t.listener.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -93,7 +93,10 @@ func newInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, err
|
return inbound, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return h.listener.Start()
|
return h.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,10 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont
|
||||||
return inbound, err
|
return inbound, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MultiInbound) Start() error {
|
func (h *MultiInbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return h.listener.Start()
|
return h.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,10 @@ func newRelayInbound(ctx context.Context, router adapter.Router, logger log.Cont
|
||||||
return inbound, err
|
return inbound, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RelayInbound) Start() error {
|
func (h *RelayInbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return h.listener.Start()
|
return h.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, options.Network.Build(), tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, tag, options.Network.Build(), options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
dialer: outboundDialer,
|
dialer: outboundDialer,
|
||||||
method: method,
|
method: method,
|
||||||
|
|
|
@ -90,7 +90,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return h.listener.Start()
|
return h.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ type Outbound struct {
|
||||||
|
|
||||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) {
|
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) {
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||||
}
|
}
|
||||||
if options.TLS == nil || !options.TLS.Enabled {
|
if options.TLS == nil || !options.TLS.Enabled {
|
||||||
return nil, C.ErrTLSRequired
|
return nil, C.ErrTLSRequired
|
||||||
|
|
|
@ -50,7 +50,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return h.listener.Start()
|
return h.listener.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, options.Network.Build(), tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
|
||||||
router: router,
|
router: router,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),
|
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),
|
||||||
|
|
|
@ -54,7 +54,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
dialer: outboundDialer,
|
dialer: outboundDialer,
|
||||||
|
|
|
@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Outbound{
|
return &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
proxy: NewProxyListener(ctx, logger, outboundDialer),
|
proxy: NewProxyListener(ctx, logger, outboundDialer),
|
||||||
|
|
|
@ -110,7 +110,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if h.tlsConfig != nil {
|
if h.tlsConfig != nil {
|
||||||
err := h.tlsConfig.Start()
|
err := h.tlsConfig.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -43,7 +43,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, options.Network.Build(), tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, tag, options.Network.Build(), options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
dialer: outboundDialer,
|
dialer: outboundDialer,
|
||||||
serverAddr: options.ServerOptions.Build(),
|
serverAddr: options.ServerOptions.Build(),
|
||||||
|
|
|
@ -142,7 +142,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
||||||
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if h.tlsConfig != nil {
|
if h.tlsConfig != nil {
|
||||||
err := h.tlsConfig.Start()
|
err := h.tlsConfig.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Outbound{
|
return &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, options.Network.Build(), tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, tag, options.Network.Build(), options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: client,
|
client: client,
|
||||||
udpStream: options.UDPOverStream,
|
udpStream: options.UDPOverStream,
|
||||||
|
|
|
@ -300,104 +300,104 @@ func (t *Inbound) Tag() string {
|
||||||
return t.tag
|
return t.tag
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Inbound) Start() error {
|
func (t *Inbound) Start(stage adapter.StartStage) error {
|
||||||
if C.IsAndroid && t.platformInterface == nil {
|
switch stage {
|
||||||
t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager())
|
case adapter.StartStateStart:
|
||||||
}
|
if C.IsAndroid && t.platformInterface == nil {
|
||||||
if t.tunOptions.Name == "" {
|
t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager())
|
||||||
t.tunOptions.Name = tun.CalculateInterfaceName("")
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
tunInterface tun.Tun
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
monitor := taskmonitor.New(t.logger, C.StartTimeout)
|
|
||||||
monitor.Start("open tun interface")
|
|
||||||
if t.platformInterface != nil {
|
|
||||||
tunInterface, err = t.platformInterface.OpenTun(&t.tunOptions, t.platformOptions)
|
|
||||||
} else {
|
|
||||||
tunInterface, err = tun.New(t.tunOptions)
|
|
||||||
}
|
|
||||||
monitor.Finish()
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "configure tun interface")
|
|
||||||
}
|
|
||||||
t.logger.Trace("creating stack")
|
|
||||||
t.tunIf = tunInterface
|
|
||||||
var (
|
|
||||||
forwarderBindInterface bool
|
|
||||||
includeAllNetworks bool
|
|
||||||
)
|
|
||||||
if t.platformInterface != nil {
|
|
||||||
forwarderBindInterface = true
|
|
||||||
includeAllNetworks = t.platformInterface.IncludeAllNetworks()
|
|
||||||
}
|
|
||||||
tunStack, err := tun.NewStack(t.stack, tun.StackOptions{
|
|
||||||
Context: t.ctx,
|
|
||||||
Tun: tunInterface,
|
|
||||||
TunOptions: t.tunOptions,
|
|
||||||
UDPTimeout: t.udpTimeout,
|
|
||||||
Handler: t,
|
|
||||||
Logger: t.logger,
|
|
||||||
ForwarderBindInterface: forwarderBindInterface,
|
|
||||||
InterfaceFinder: t.networkManager.InterfaceFinder(),
|
|
||||||
IncludeAllNetworks: includeAllNetworks,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
t.tunStack = tunStack
|
|
||||||
t.logger.Info("started at ", t.tunOptions.Name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Inbound) PostStart() error {
|
|
||||||
monitor := taskmonitor.New(t.logger, C.StartTimeout)
|
|
||||||
monitor.Start("starting tun stack")
|
|
||||||
err := t.tunStack.Start()
|
|
||||||
monitor.Finish()
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "starting tun stack")
|
|
||||||
}
|
|
||||||
monitor.Start("starting tun interface")
|
|
||||||
err = t.tunIf.Start()
|
|
||||||
monitor.Finish()
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "starting TUN interface")
|
|
||||||
}
|
|
||||||
if t.autoRedirect != nil {
|
|
||||||
t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
|
|
||||||
for _, routeRuleSet := range t.routeRuleSet {
|
|
||||||
ipSets := routeRuleSet.ExtractIPSet()
|
|
||||||
if len(ipSets) == 0 {
|
|
||||||
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name())
|
|
||||||
}
|
|
||||||
t.routeAddressSet = append(t.routeAddressSet, ipSets...)
|
|
||||||
}
|
}
|
||||||
t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
|
if t.tunOptions.Name == "" {
|
||||||
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
t.tunOptions.Name = tun.CalculateInterfaceName("")
|
||||||
ipSets := routeExcludeRuleSet.ExtractIPSet()
|
}
|
||||||
if len(ipSets) == 0 {
|
var (
|
||||||
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
tunInterface tun.Tun
|
||||||
}
|
err error
|
||||||
t.routeExcludeAddressSet = append(t.routeExcludeAddressSet, ipSets...)
|
)
|
||||||
|
monitor := taskmonitor.New(t.logger, C.StartTimeout)
|
||||||
|
monitor.Start("open tun interface")
|
||||||
|
if t.platformInterface != nil {
|
||||||
|
tunInterface, err = t.platformInterface.OpenTun(&t.tunOptions, t.platformOptions)
|
||||||
|
} else {
|
||||||
|
tunInterface, err = tun.New(t.tunOptions)
|
||||||
}
|
}
|
||||||
monitor.Start("initialize auto-redirect")
|
|
||||||
err := t.autoRedirect.Start()
|
|
||||||
monitor.Finish()
|
monitor.Finish()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "auto-redirect")
|
return E.Cause(err, "configure tun interface")
|
||||||
}
|
}
|
||||||
for _, routeRuleSet := range t.routeRuleSet {
|
t.logger.Trace("creating stack")
|
||||||
t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
t.tunIf = tunInterface
|
||||||
routeRuleSet.DecRef()
|
var (
|
||||||
|
forwarderBindInterface bool
|
||||||
|
includeAllNetworks bool
|
||||||
|
)
|
||||||
|
if t.platformInterface != nil {
|
||||||
|
forwarderBindInterface = true
|
||||||
|
includeAllNetworks = t.platformInterface.IncludeAllNetworks()
|
||||||
}
|
}
|
||||||
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
tunStack, err := tun.NewStack(t.stack, tun.StackOptions{
|
||||||
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
Context: t.ctx,
|
||||||
routeExcludeRuleSet.DecRef()
|
Tun: tunInterface,
|
||||||
|
TunOptions: t.tunOptions,
|
||||||
|
UDPTimeout: t.udpTimeout,
|
||||||
|
Handler: t,
|
||||||
|
Logger: t.logger,
|
||||||
|
ForwarderBindInterface: forwarderBindInterface,
|
||||||
|
InterfaceFinder: t.networkManager.InterfaceFinder(),
|
||||||
|
IncludeAllNetworks: includeAllNetworks,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.tunStack = tunStack
|
||||||
|
t.logger.Info("started at ", t.tunOptions.Name)
|
||||||
|
case adapter.StartStatePostStart:
|
||||||
|
monitor := taskmonitor.New(t.logger, C.StartTimeout)
|
||||||
|
monitor.Start("starting tun stack")
|
||||||
|
err := t.tunStack.Start()
|
||||||
|
monitor.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "starting tun stack")
|
||||||
|
}
|
||||||
|
monitor.Start("starting tun interface")
|
||||||
|
err = t.tunIf.Start()
|
||||||
|
monitor.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "starting TUN interface")
|
||||||
|
}
|
||||||
|
if t.autoRedirect != nil {
|
||||||
|
t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||||
|
for _, routeRuleSet := range t.routeRuleSet {
|
||||||
|
ipSets := routeRuleSet.ExtractIPSet()
|
||||||
|
if len(ipSets) == 0 {
|
||||||
|
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name())
|
||||||
|
}
|
||||||
|
t.routeAddressSet = append(t.routeAddressSet, ipSets...)
|
||||||
|
}
|
||||||
|
t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet)
|
||||||
|
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
||||||
|
ipSets := routeExcludeRuleSet.ExtractIPSet()
|
||||||
|
if len(ipSets) == 0 {
|
||||||
|
t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name())
|
||||||
|
}
|
||||||
|
t.routeExcludeAddressSet = append(t.routeExcludeAddressSet, ipSets...)
|
||||||
|
}
|
||||||
|
monitor.Start("initialize auto-redirect")
|
||||||
|
err := t.autoRedirect.Start()
|
||||||
|
monitor.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "auto-redirect")
|
||||||
|
}
|
||||||
|
for _, routeRuleSet := range t.routeRuleSet {
|
||||||
|
t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||||
|
routeRuleSet.DecRef()
|
||||||
|
}
|
||||||
|
for _, routeExcludeRuleSet := range t.routeExcludeRuleSet {
|
||||||
|
t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet))
|
||||||
|
routeExcludeRuleSet.DecRef()
|
||||||
|
}
|
||||||
|
t.routeAddressSet = nil
|
||||||
|
t.routeExcludeAddressSet = nil
|
||||||
}
|
}
|
||||||
t.routeAddressSet = nil
|
|
||||||
t.routeExcludeAddressSet = nil
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,7 +89,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if h.tlsConfig != nil {
|
if h.tlsConfig != nil {
|
||||||
err := h.tlsConfig.Start()
|
err := h.tlsConfig.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, options.Network.Build(), tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, tag, options.Network.Build(), options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
dialer: outboundDialer,
|
dialer: outboundDialer,
|
||||||
serverAddr: options.ServerOptions.Build(),
|
serverAddr: options.ServerOptions.Build(),
|
||||||
|
|
|
@ -99,7 +99,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||||
return inbound, nil
|
return inbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) Start() error {
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
err := h.service.Start()
|
err := h.service.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, options.Network.Build(), tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, tag, options.Network.Build(), options.DialerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
dialer: outboundDialer,
|
dialer: outboundDialer,
|
||||||
serverAddr: options.ServerOptions.Build(),
|
serverAddr: options.ServerOptions.Build(),
|
||||||
|
|
211
protocol/wireguard/endpoint.go
Normal file
211
protocol/wireguard/endpoint.go
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/transport/wireguard"
|
||||||
|
"github.com/sagernet/sing-dns"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RegisterEndpoint(registry *endpoint.Registry) {
|
||||||
|
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ adapter.Endpoint = (*Endpoint)(nil)
|
||||||
|
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Endpoint struct {
|
||||||
|
endpoint.Adapter
|
||||||
|
ctx context.Context
|
||||||
|
router adapter.Router
|
||||||
|
logger logger.ContextLogger
|
||||||
|
localAddresses []netip.Prefix
|
||||||
|
endpoint *wireguard.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
|
||||||
|
ep := &Endpoint{
|
||||||
|
Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||||
|
ctx: ctx,
|
||||||
|
router: router,
|
||||||
|
logger: logger,
|
||||||
|
localAddresses: options.Address,
|
||||||
|
}
|
||||||
|
if options.Detour == "" {
|
||||||
|
options.IsWireGuardListener = true
|
||||||
|
} else if options.GSO {
|
||||||
|
return nil, E.New("gso is conflict with detour")
|
||||||
|
}
|
||||||
|
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
|
||||||
|
Context: ctx,
|
||||||
|
Logger: logger,
|
||||||
|
System: options.System,
|
||||||
|
Handler: ep,
|
||||||
|
UDPTimeout: time.Duration(options.UDPTimeout),
|
||||||
|
Dialer: outboundDialer,
|
||||||
|
CreateDialer: func(interfaceName string) N.Dialer {
|
||||||
|
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
|
||||||
|
BindInterface: interfaceName,
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
Name: options.Name,
|
||||||
|
MTU: options.MTU,
|
||||||
|
GSO: options.GSO,
|
||||||
|
Address: options.Address,
|
||||||
|
PrivateKey: options.PrivateKey,
|
||||||
|
ListenPort: options.ListenPort,
|
||||||
|
ResolvePeer: func(domain string) (netip.Addr, error) {
|
||||||
|
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
|
||||||
|
if lookupErr != nil {
|
||||||
|
return netip.Addr{}, lookupErr
|
||||||
|
}
|
||||||
|
return endpointAddresses[0], nil
|
||||||
|
},
|
||||||
|
Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions {
|
||||||
|
return wireguard.PeerOptions{
|
||||||
|
Endpoint: M.ParseSocksaddrHostPort(it.Address, it.Port),
|
||||||
|
PublicKey: it.PublicKey,
|
||||||
|
PreSharedKey: it.PreSharedKey,
|
||||||
|
AllowedIPs: it.AllowedIPs,
|
||||||
|
PersistentKeepaliveInterval: it.PersistentKeepaliveInterval,
|
||||||
|
Reserved: it.Reserved,
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
Workers: options.Workers,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ep.endpoint = wgEndpoint
|
||||||
|
return ep, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) Start(stage adapter.StartStage) error {
|
||||||
|
switch stage {
|
||||||
|
case adapter.StartStateStart:
|
||||||
|
return w.endpoint.Start(false)
|
||||||
|
case adapter.StartStatePostStart:
|
||||||
|
return w.endpoint.Start(true)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) Close() error {
|
||||||
|
return w.endpoint.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) InterfaceUpdated() {
|
||||||
|
w.endpoint.BindUpdate()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
|
||||||
|
return w.router.PreMatch(adapter.InboundContext{
|
||||||
|
Inbound: w.Tag(),
|
||||||
|
InboundType: w.Type(),
|
||||||
|
Network: network,
|
||||||
|
Source: source,
|
||||||
|
Destination: destination,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Inbound = w.Tag()
|
||||||
|
metadata.InboundType = w.Type()
|
||||||
|
metadata.Source = source
|
||||||
|
for _, localPrefix := range w.localAddresses {
|
||||||
|
if localPrefix.Contains(destination.Addr) {
|
||||||
|
metadata.OriginDestination = destination
|
||||||
|
if destination.Addr.Is4() {
|
||||||
|
destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
|
||||||
|
} else {
|
||||||
|
destination.Addr = netip.IPv6Loopback()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata.Destination = destination
|
||||||
|
w.logger.InfoContext(ctx, "inbound connection from ", source)
|
||||||
|
w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
|
w.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Inbound = w.Tag()
|
||||||
|
metadata.InboundType = w.Type()
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
for _, localPrefix := range w.localAddresses {
|
||||||
|
if localPrefix.Contains(destination.Addr) {
|
||||||
|
metadata.OriginDestination = destination
|
||||||
|
if destination.Addr.Is4() {
|
||||||
|
metadata.Destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
|
||||||
|
} else {
|
||||||
|
metadata.Destination.Addr = netip.IPv6Loopback()
|
||||||
|
}
|
||||||
|
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.logger.InfoContext(ctx, "inbound packet connection from ", source)
|
||||||
|
w.logger.InfoContext(ctx, "inbound packet connection to ", destination)
|
||||||
|
w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
switch network {
|
||||||
|
case N.NetworkTCP:
|
||||||
|
w.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||||
|
case N.NetworkUDP:
|
||||||
|
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||||
|
}
|
||||||
|
if destination.IsFqdn() {
|
||||||
|
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses)
|
||||||
|
} else if !destination.Addr.IsValid() {
|
||||||
|
return nil, E.New("invalid destination: ", destination)
|
||||||
|
}
|
||||||
|
return w.endpoint.DialContext(ctx, network, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
|
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||||
|
if destination.IsFqdn() {
|
||||||
|
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return packetConn, err
|
||||||
|
}
|
||||||
|
return w.endpoint.ListenPacket(ctx, destination)
|
||||||
|
}
|
|
@ -2,231 +2,153 @@ package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/adapter/outbound"
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/wireguard"
|
"github.com/sagernet/sing-box/transport/wireguard"
|
||||||
"github.com/sagernet/sing-tun"
|
"github.com/sagernet/sing-dns"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/sing/common/x/list"
|
|
||||||
"github.com/sagernet/sing/service"
|
"github.com/sagernet/sing/service"
|
||||||
"github.com/sagernet/sing/service/pause"
|
|
||||||
"github.com/sagernet/wireguard-go/conn"
|
|
||||||
"github.com/sagernet/wireguard-go/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterOutbound(registry *outbound.Registry) {
|
func RegisterOutbound(registry *outbound.Registry) {
|
||||||
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
|
outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
|
var (
|
||||||
|
_ adapter.Endpoint = (*Endpoint)(nil)
|
||||||
|
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
type Outbound struct {
|
type Outbound struct {
|
||||||
outbound.Adapter
|
outbound.Adapter
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
router adapter.Router
|
router adapter.Router
|
||||||
logger logger.ContextLogger
|
logger logger.ContextLogger
|
||||||
workers int
|
localAddresses []netip.Prefix
|
||||||
peers []wireguard.PeerConfig
|
endpoint *wireguard.Endpoint
|
||||||
useStdNetBind bool
|
|
||||||
listener N.Dialer
|
|
||||||
ipcConf string
|
|
||||||
|
|
||||||
pauseManager pause.Manager
|
|
||||||
pauseCallback *list.Element[pause.Callback]
|
|
||||||
bind conn.Bind
|
|
||||||
device *device.Device
|
|
||||||
tunDevice wireguard.Device
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
|
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
|
||||||
|
deprecated.Report(ctx, deprecated.OptionWireGuardOutbound)
|
||||||
outbound := &Outbound{
|
outbound := &Outbound{
|
||||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions),
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
router: router,
|
router: router,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
workers: options.Workers,
|
localAddresses: options.LocalAddress,
|
||||||
pauseManager: service.FromContext[pause.Manager](ctx),
|
|
||||||
}
|
}
|
||||||
peers, err := wireguard.ParsePeers(options)
|
if options.Detour == "" {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
outbound.peers = peers
|
|
||||||
if len(options.LocalAddress) == 0 {
|
|
||||||
return nil, E.New("missing local address")
|
|
||||||
}
|
|
||||||
if options.GSO {
|
|
||||||
if options.GSO && options.Detour != "" {
|
|
||||||
return nil, E.New("gso is conflict with detour")
|
|
||||||
}
|
|
||||||
options.IsWireGuardListener = true
|
options.IsWireGuardListener = true
|
||||||
outbound.useStdNetBind = true
|
} else if options.GSO {
|
||||||
|
return nil, E.New("gso is conflict with detour")
|
||||||
}
|
}
|
||||||
listener, err := dialer.New(ctx, options.DialerOptions)
|
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound.listener = listener
|
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
|
||||||
var privateKey string
|
Context: ctx,
|
||||||
{
|
Logger: logger,
|
||||||
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
System: options.SystemInterface,
|
||||||
if err != nil {
|
Dialer: outboundDialer,
|
||||||
return nil, E.Cause(err, "decode private key")
|
CreateDialer: func(interfaceName string) N.Dialer {
|
||||||
}
|
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
|
||||||
privateKey = hex.EncodeToString(bytes)
|
BindInterface: interfaceName,
|
||||||
}
|
}))
|
||||||
outbound.ipcConf = "private_key=" + privateKey
|
},
|
||||||
mtu := options.MTU
|
Name: options.InterfaceName,
|
||||||
if mtu == 0 {
|
MTU: options.MTU,
|
||||||
mtu = 1408
|
GSO: options.GSO,
|
||||||
}
|
Address: options.LocalAddress,
|
||||||
var wireTunDevice wireguard.Device
|
PrivateKey: options.PrivateKey,
|
||||||
if !options.SystemInterface && tun.WithGVisor {
|
ResolvePeer: func(domain string) (netip.Addr, error) {
|
||||||
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
|
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
|
||||||
} else {
|
if lookupErr != nil {
|
||||||
wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
|
return netip.Addr{}, lookupErr
|
||||||
}
|
}
|
||||||
|
return endpointAddresses[0], nil
|
||||||
|
},
|
||||||
|
Peers: common.Map(options.Peers, func(it option.LegacyWireGuardPeer) wireguard.PeerOptions {
|
||||||
|
return wireguard.PeerOptions{
|
||||||
|
Endpoint: it.ServerOptions.Build(),
|
||||||
|
PublicKey: it.PublicKey,
|
||||||
|
PreSharedKey: it.PreSharedKey,
|
||||||
|
AllowedIPs: it.AllowedIPs,
|
||||||
|
// PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval),
|
||||||
|
Reserved: it.Reserved,
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
Workers: options.Workers,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "create WireGuard device")
|
return nil, err
|
||||||
}
|
}
|
||||||
outbound.tunDevice = wireTunDevice
|
outbound.endpoint = wgEndpoint
|
||||||
return outbound, nil
|
return outbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Outbound) Start() error {
|
func (o *Outbound) Start(stage adapter.StartStage) error {
|
||||||
if common.Any(w.peers, func(peer wireguard.PeerConfig) bool {
|
switch stage {
|
||||||
return !peer.Endpoint.IsValid()
|
case adapter.StartStateStart:
|
||||||
}) {
|
return o.endpoint.Start(false)
|
||||||
// wait for all outbounds to be started and continue in PortStart
|
case adapter.StartStatePostStart:
|
||||||
return nil
|
return o.endpoint.Start(true)
|
||||||
}
|
|
||||||
return w.start()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Outbound) PostStart() error {
|
|
||||||
if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
|
|
||||||
return peer.Endpoint.IsValid()
|
|
||||||
}) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return w.start()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Outbound) start() error {
|
|
||||||
err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var bind conn.Bind
|
|
||||||
if w.useStdNetBind {
|
|
||||||
bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
|
|
||||||
} else {
|
|
||||||
var (
|
|
||||||
isConnect bool
|
|
||||||
connectAddr netip.AddrPort
|
|
||||||
reserved [3]uint8
|
|
||||||
)
|
|
||||||
peerLen := len(w.peers)
|
|
||||||
if peerLen == 1 {
|
|
||||||
isConnect = true
|
|
||||||
connectAddr = w.peers[0].Endpoint
|
|
||||||
reserved = w.peers[0].Reserved
|
|
||||||
}
|
|
||||||
bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
|
|
||||||
}
|
|
||||||
err = w.tunDevice.Start()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{
|
|
||||||
Verbosef: func(format string, args ...interface{}) {
|
|
||||||
w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
|
||||||
},
|
|
||||||
Errorf: func(format string, args ...interface{}) {
|
|
||||||
w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
|
||||||
},
|
|
||||||
}, w.workers)
|
|
||||||
ipcConf := w.ipcConf
|
|
||||||
for _, peer := range w.peers {
|
|
||||||
ipcConf += peer.GenerateIpcLines()
|
|
||||||
}
|
|
||||||
err = wgDevice.IpcSet(ipcConf)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "setup wireguard: \n", ipcConf)
|
|
||||||
}
|
|
||||||
w.device = wgDevice
|
|
||||||
w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Outbound) Close() error {
|
|
||||||
if w.device != nil {
|
|
||||||
w.device.Close()
|
|
||||||
}
|
|
||||||
if w.pauseCallback != nil {
|
|
||||||
w.pauseManager.UnregisterCallback(w.pauseCallback)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Outbound) InterfaceUpdated() {
|
func (o *Outbound) Close() error {
|
||||||
w.device.BindUpdate()
|
return o.endpoint.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Outbound) InterfaceUpdated() {
|
||||||
|
o.endpoint.BindUpdate()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Outbound) onPauseUpdated(event int) {
|
func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
switch event {
|
|
||||||
case pause.EventDevicePaused:
|
|
||||||
w.device.Down()
|
|
||||||
case pause.EventDeviceWake:
|
|
||||||
w.device.Up()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
|
||||||
switch network {
|
switch network {
|
||||||
case N.NetworkTCP:
|
case N.NetworkTCP:
|
||||||
w.logger.InfoContext(ctx, "outbound connection to ", destination)
|
o.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||||
case N.NetworkUDP:
|
case N.NetworkUDP:
|
||||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||||
}
|
}
|
||||||
if destination.IsFqdn() {
|
if destination.IsFqdn() {
|
||||||
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
|
return N.DialSerial(ctx, o.endpoint, network, destination, destinationAddresses)
|
||||||
|
} else if !destination.Addr.IsValid() {
|
||||||
|
return nil, E.New("invalid destination: ", destination)
|
||||||
}
|
}
|
||||||
return w.tunDevice.DialContext(ctx, network, destination)
|
return o.endpoint.DialContext(ctx, network, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||||
if destination.IsFqdn() {
|
if destination.IsFqdn() {
|
||||||
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
|
packetConn, _, err := N.ListenSerial(ctx, o.endpoint, destination, destinationAddresses)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return packetConn, err
|
return packetConn, err
|
||||||
}
|
}
|
||||||
return w.tunDevice.ListenPacket(ctx, destination)
|
return o.endpoint.ListenPacket(ctx, destination)
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,17 +41,17 @@ type NetworkManager struct {
|
||||||
autoDetectInterface bool
|
autoDetectInterface bool
|
||||||
defaultOptions adapter.NetworkOptions
|
defaultOptions adapter.NetworkOptions
|
||||||
autoRedirectOutputMark uint32
|
autoRedirectOutputMark uint32
|
||||||
|
networkMonitor tun.NetworkUpdateMonitor
|
||||||
networkMonitor tun.NetworkUpdateMonitor
|
interfaceMonitor tun.DefaultInterfaceMonitor
|
||||||
interfaceMonitor tun.DefaultInterfaceMonitor
|
packageManager tun.PackageManager
|
||||||
packageManager tun.PackageManager
|
powerListener winpowrprof.EventListener
|
||||||
powerListener winpowrprof.EventListener
|
pauseManager pause.Manager
|
||||||
pauseManager pause.Manager
|
platformInterface platform.Interface
|
||||||
platformInterface platform.Interface
|
endpoint adapter.EndpointManager
|
||||||
inboundManager adapter.InboundManager
|
inbound adapter.InboundManager
|
||||||
outboundManager adapter.OutboundManager
|
outbound adapter.OutboundManager
|
||||||
wifiState adapter.WIFIState
|
wifiState adapter.WIFIState
|
||||||
started bool
|
started bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) {
|
func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) {
|
||||||
|
@ -69,7 +69,9 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
|
||||||
},
|
},
|
||||||
pauseManager: service.FromContext[pause.Manager](ctx),
|
pauseManager: service.FromContext[pause.Manager](ctx),
|
||||||
platformInterface: service.FromContext[platform.Interface](ctx),
|
platformInterface: service.FromContext[platform.Interface](ctx),
|
||||||
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
endpoint: service.FromContext[adapter.EndpointManager](ctx),
|
||||||
|
inbound: service.FromContext[adapter.InboundManager](ctx),
|
||||||
|
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||||
}
|
}
|
||||||
if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault {
|
if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault {
|
||||||
if routeOptions.DefaultInterface != "" {
|
if routeOptions.DefaultInterface != "" {
|
||||||
|
@ -358,14 +360,21 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
|
||||||
func (r *NetworkManager) ResetNetwork() {
|
func (r *NetworkManager) ResetNetwork() {
|
||||||
conntrack.Close()
|
conntrack.Close()
|
||||||
|
|
||||||
for _, inbound := range r.inboundManager.Inbounds() {
|
for _, endpoint := range r.endpoint.Endpoints() {
|
||||||
|
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
|
||||||
|
if isListener {
|
||||||
|
listener.InterfaceUpdated()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, inbound := range r.inbound.Inbounds() {
|
||||||
listener, isListener := inbound.(adapter.InterfaceUpdateListener)
|
listener, isListener := inbound.(adapter.InterfaceUpdateListener)
|
||||||
if isListener {
|
if isListener {
|
||||||
listener.InterfaceUpdated()
|
listener.InterfaceUpdated()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, outbound := range r.outboundManager.Outbounds() {
|
for _, outbound := range r.outbound.Outbounds() {
|
||||||
listener, isListener := outbound.(adapter.InterfaceUpdateListener)
|
listener, isListener := outbound.(adapter.InterfaceUpdateListener)
|
||||||
if isListener {
|
if isListener {
|
||||||
listener.InterfaceUpdated()
|
listener.InterfaceUpdated()
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
R "github.com/sagernet/sing-box/route/rule"
|
R "github.com/sagernet/sing-box/route/rule"
|
||||||
"github.com/sagernet/sing-dns"
|
"github.com/sagernet/sing-dns"
|
||||||
tun "github.com/sagernet/sing-tun"
|
"github.com/sagernet/sing-tun"
|
||||||
"github.com/sagernet/sing/common/cache"
|
"github.com/sagernet/sing/common/cache"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
F "github.com/sagernet/sing/common/format"
|
F "github.com/sagernet/sing/common/format"
|
||||||
|
|
|
@ -32,7 +32,7 @@ func TestMain(m *testing.M) {
|
||||||
var globalCtx context.Context
|
var globalCtx context.Context
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
|
globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||||
}
|
}
|
||||||
|
|
||||||
func startInstance(t *testing.T, options option.Options) *box.Box {
|
func startInstance(t *testing.T, options option.Options) *box.Box {
|
||||||
|
|
|
@ -37,12 +37,12 @@ func _TestWireGuard(t *testing.T) {
|
||||||
Outbounds: []option.Outbound{
|
Outbounds: []option.Outbound{
|
||||||
{
|
{
|
||||||
Type: C.TypeWireGuard,
|
Type: C.TypeWireGuard,
|
||||||
Options: &option.WireGuardOutboundOptions{
|
Options: &option.WireGuardEndpointOptions{
|
||||||
ServerOptions: option.ServerOptions{
|
ServerOptions: option.ServerOptions{
|
||||||
Server: "127.0.0.1",
|
Server: "127.0.0.1",
|
||||||
ServerPort: serverPort,
|
ServerPort: serverPort,
|
||||||
},
|
},
|
||||||
LocalAddress: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
|
Address: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
|
||||||
PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
|
PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
|
||||||
PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
|
PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
|
||||||
},
|
},
|
||||||
|
|
|
@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
||||||
select {
|
select {
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
default:
|
default:
|
||||||
c.logger.Error(context.Background(), E.Cause(err, "read packet"))
|
c.logger.Error(E.Cause(err, "read packet"))
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
||||||
b := packets[0]
|
b := packets[0]
|
||||||
common.ClearArray(b[1:4])
|
common.ClearArray(b[1:4])
|
||||||
}
|
}
|
||||||
eps[0] = Endpoint(M.AddrPortFromNet(addr))
|
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
|
||||||
count = 1
|
count = 1
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
destination := netip.AddrPort(ep.(Endpoint))
|
destination := netip.AddrPort(ep.(remoteEndpoint))
|
||||||
for _, b := range bufs {
|
for _, b := range bufs {
|
||||||
if len(b) > 3 {
|
if len(b) > 3 {
|
||||||
reserved, loaded := c.reservedForEndpoint[destination]
|
reserved, loaded := c.reservedForEndpoint[destination]
|
||||||
|
@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return Endpoint(ap), nil
|
return remoteEndpoint(ap), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientBind) BatchSize() int {
|
func (c *ClientBind) BatchSize() int {
|
||||||
|
@ -229,3 +229,31 @@ func (w *wireConn) Close() error {
|
||||||
close(w.done)
|
close(w.done)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ conn.Endpoint = (*remoteEndpoint)(nil)
|
||||||
|
|
||||||
|
type remoteEndpoint netip.AddrPort
|
||||||
|
|
||||||
|
func (e remoteEndpoint) ClearSrc() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e remoteEndpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e remoteEndpoint) DstToString() string {
|
||||||
|
return (netip.AddrPort)(e).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e remoteEndpoint) DstToBytes() []byte {
|
||||||
|
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e remoteEndpoint) DstIP() netip.Addr {
|
||||||
|
return (netip.AddrPort)(e).Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e remoteEndpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
|
@ -1,13 +1,44 @@
|
||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
"github.com/sagernet/wireguard-go/tun"
|
"github.com/sagernet/wireguard-go/device"
|
||||||
|
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
tun.Device
|
wgTun.Device
|
||||||
N.Dialer
|
N.Dialer
|
||||||
Start() error
|
Start() error
|
||||||
// NewEndpoint() (stack.LinkEndpoint, error)
|
SetDevice(device *device.Device)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceOptions struct {
|
||||||
|
Context context.Context
|
||||||
|
Logger logger.ContextLogger
|
||||||
|
System bool
|
||||||
|
Handler tun.Handler
|
||||||
|
UDPTimeout time.Duration
|
||||||
|
CreateDialer func(interfaceName string) N.Dialer
|
||||||
|
Name string
|
||||||
|
MTU uint32
|
||||||
|
GSO bool
|
||||||
|
Address []netip.Prefix
|
||||||
|
AllowedAddress []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDevice(options DeviceOptions) (Device, error) {
|
||||||
|
if !options.System {
|
||||||
|
return newStackDevice(options)
|
||||||
|
} else if options.Handler == nil {
|
||||||
|
return newSystemDevice(options)
|
||||||
|
} else {
|
||||||
|
return newSystemStackDevice(options)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ package wireguard
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sagernet/gvisor/pkg/buffer"
|
"github.com/sagernet/gvisor/pkg/buffer"
|
||||||
|
@ -15,52 +14,41 @@ import (
|
||||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
|
||||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||||
"github.com/sagernet/sing-tun"
|
"github.com/sagernet/sing-tun"
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/wireguard-go/device"
|
||||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Device = (*StackDevice)(nil)
|
var _ Device = (*stackDevice)(nil)
|
||||||
|
|
||||||
const defaultNIC tcpip.NICID = 1
|
type stackDevice struct {
|
||||||
|
stack *stack.Stack
|
||||||
type StackDevice struct {
|
mtu uint32
|
||||||
stack *stack.Stack
|
events chan wgTun.Event
|
||||||
mtu uint32
|
outbound chan *stack.PacketBuffer
|
||||||
events chan wgTun.Event
|
done chan struct{}
|
||||||
outbound chan *stack.PacketBuffer
|
dispatcher stack.NetworkDispatcher
|
||||||
packetOutbound chan *buf.Buffer
|
addr4 tcpip.Address
|
||||||
done chan struct{}
|
addr6 tcpip.Address
|
||||||
dispatcher stack.NetworkDispatcher
|
|
||||||
addr4 tcpip.Address
|
|
||||||
addr6 tcpip.Address
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
|
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
||||||
ipStack := stack.New(stack.Options{
|
tunDevice := &stackDevice{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
mtu: options.MTU,
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
|
events: make(chan wgTun.Event, 1),
|
||||||
HandleLocal: true,
|
outbound: make(chan *stack.PacketBuffer, 256),
|
||||||
})
|
done: make(chan struct{}),
|
||||||
tunDevice := &StackDevice{
|
|
||||||
stack: ipStack,
|
|
||||||
mtu: mtu,
|
|
||||||
events: make(chan wgTun.Event, 1),
|
|
||||||
outbound: make(chan *stack.PacketBuffer, 256),
|
|
||||||
packetOutbound: make(chan *buf.Buffer, 256),
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
|
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.New(err.String())
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, prefix := range localAddresses {
|
for _, prefix := range options.Address {
|
||||||
addr := tun.AddressFromAddr(prefix.Addr())
|
addr := tun.AddressFromAddr(prefix.Addr())
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
|
@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
|
||||||
tunDevice.addr6 = addr
|
tunDevice.addr6 = addr
|
||||||
protoAddr.Protocol = ipv6.ProtocolNumber
|
protoAddr.Protocol = ipv6.ProtocolNumber
|
||||||
}
|
}
|
||||||
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
|
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
|
||||||
if err != nil {
|
if gErr != nil {
|
||||||
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
|
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sOpt := tcpip.TCPSACKEnabled(true)
|
tunDevice.stack = ipStack
|
||||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
|
if options.Handler != nil {
|
||||||
cOpt := tcpip.CongestionControlOption("cubic")
|
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
|
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||||
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
|
}
|
||||||
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
|
|
||||||
return tunDevice, nil
|
return tunDevice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
|
func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
return (*wireEndpoint)(w), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
|
||||||
addr := tcpip.FullAddress{
|
addr := tcpip.FullAddress{
|
||||||
NIC: defaultNIC,
|
NIC: tun.DefaultNIC,
|
||||||
Port: destination.Port,
|
Port: destination.Port,
|
||||||
Addr: tun.AddressFromAddr(destination.Addr),
|
Addr: tun.AddressFromAddr(destination.Addr),
|
||||||
}
|
}
|
||||||
bind := tcpip.FullAddress{
|
bind := tcpip.FullAddress{
|
||||||
NIC: defaultNIC,
|
NIC: tun.DefaultNIC,
|
||||||
}
|
}
|
||||||
var networkProtocol tcpip.NetworkProtocolNumber
|
var networkProtocol tcpip.NetworkProtocolNumber
|
||||||
if destination.IsIPv4() {
|
if destination.IsIPv4() {
|
||||||
|
@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
bind := tcpip.FullAddress{
|
bind := tcpip.FullAddress{
|
||||||
NIC: defaultNIC,
|
NIC: tun.DefaultNIC,
|
||||||
}
|
}
|
||||||
var networkProtocol tcpip.NetworkProtocolNumber
|
var networkProtocol tcpip.NetworkProtocolNumber
|
||||||
if destination.IsIPv4() {
|
if destination.IsIPv4() {
|
||||||
|
@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
|
||||||
return udpConn, nil
|
return udpConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Inet4Address() netip.Addr {
|
func (w *stackDevice) SetDevice(device *device.Device) {
|
||||||
return tun.AddrFromAddress(w.addr4)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Inet6Address() netip.Addr {
|
func (w *stackDevice) Start() error {
|
||||||
return tun.AddrFromAddress(w.addr6)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *StackDevice) Start() error {
|
|
||||||
w.events <- wgTun.EventUp
|
w.events <- wgTun.EventUp
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) File() *os.File {
|
func (w *stackDevice) File() *os.File {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||||
select {
|
select {
|
||||||
case packetBuffer, ok := <-w.outbound:
|
case packetBuffer, ok := <-w.outbound:
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
|
||||||
sizes[0] = n
|
sizes[0] = n
|
||||||
count = 1
|
count = 1
|
||||||
return
|
return
|
||||||
case packet := <-w.packetOutbound:
|
|
||||||
defer packet.Release()
|
|
||||||
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
|
|
||||||
count = 1
|
|
||||||
return
|
|
||||||
case <-w.done:
|
case <-w.done:
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||||
for _, b := range bufs {
|
for _, b := range bufs {
|
||||||
b = b[offset:]
|
b = b[offset:]
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
|
@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Flush() error {
|
func (w *stackDevice) Flush() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) MTU() (int, error) {
|
func (w *stackDevice) MTU() (int, error) {
|
||||||
return int(w.mtu), nil
|
return int(w.mtu), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Name() (string, error) {
|
func (w *stackDevice) Name() (string, error) {
|
||||||
return "sing-box", nil
|
return "sing-box", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Events() <-chan wgTun.Event {
|
func (w *stackDevice) Events() <-chan wgTun.Event {
|
||||||
return w.events
|
return w.events
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) Close() error {
|
func (w *stackDevice) Close() error {
|
||||||
close(w.done)
|
close(w.done)
|
||||||
close(w.events)
|
close(w.events)
|
||||||
w.stack.Close()
|
w.stack.Close()
|
||||||
|
@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StackDevice) BatchSize() int {
|
func (w *stackDevice) BatchSize() int {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
|
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
|
||||||
|
|
||||||
type wireEndpoint StackDevice
|
type wireEndpoint stackDevice
|
||||||
|
|
||||||
func (ep *wireEndpoint) MTU() uint32 {
|
func (ep *wireEndpoint) MTU() uint32 {
|
||||||
return ep.mtu
|
return ep.mtu
|
||||||
|
|
|
@ -2,12 +2,12 @@
|
||||||
|
|
||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import "github.com/sagernet/sing-tun"
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing-tun"
|
func newStackDevice(options DeviceOptions) (Device, error) {
|
||||||
)
|
return nil, tun.ErrGVisorNotIncluded
|
||||||
|
}
|
||||||
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
|
|
||||||
|
func newSystemStackDevice(options DeviceOptions) (Device, error) {
|
||||||
return nil, tun.ErrGVisorNotIncluded
|
return nil, tun.ErrGVisorNotIncluded
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,96 +6,88 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
|
||||||
"github.com/sagernet/sing-box/option"
|
|
||||||
"github.com/sagernet/sing-tun"
|
"github.com/sagernet/sing-tun"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
"github.com/sagernet/wireguard-go/device"
|
||||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Device = (*SystemDevice)(nil)
|
var _ Device = (*systemDevice)(nil)
|
||||||
|
|
||||||
type SystemDevice struct {
|
type systemDevice struct {
|
||||||
dialer N.Dialer
|
options DeviceOptions
|
||||||
device tun.Tun
|
dialer N.Dialer
|
||||||
batchDevice tun.LinuxTUN
|
device tun.Tun
|
||||||
name string
|
batchDevice tun.LinuxTUN
|
||||||
mtu uint32
|
events chan wgTun.Event
|
||||||
inet4Addresses []netip.Prefix
|
closeOnce sync.Once
|
||||||
inet6Addresses []netip.Prefix
|
|
||||||
gso bool
|
|
||||||
events chan wgTun.Event
|
|
||||||
closeOnce sync.Once
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
|
func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
|
||||||
var inet4Addresses []netip.Prefix
|
if options.Name == "" {
|
||||||
var inet6Addresses []netip.Prefix
|
options.Name = tun.CalculateInterfaceName("wg")
|
||||||
for _, prefixes := range localPrefixes {
|
|
||||||
if prefixes.Addr().Is4() {
|
|
||||||
inet4Addresses = append(inet4Addresses, prefixes)
|
|
||||||
} else {
|
|
||||||
inet6Addresses = append(inet6Addresses, prefixes)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if interfaceName == "" {
|
return &systemDevice{
|
||||||
interfaceName = tun.CalculateInterfaceName("wg")
|
options: options,
|
||||||
}
|
dialer: options.CreateDialer(options.Name),
|
||||||
|
events: make(chan wgTun.Event, 1),
|
||||||
return &SystemDevice{
|
|
||||||
dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{
|
|
||||||
BindInterface: interfaceName,
|
|
||||||
})),
|
|
||||||
name: interfaceName,
|
|
||||||
mtu: mtu,
|
|
||||||
inet4Addresses: inet4Addresses,
|
|
||||||
inet6Addresses: inet6Addresses,
|
|
||||||
gso: gso,
|
|
||||||
events: make(chan wgTun.Event, 1),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
return w.dialer.DialContext(ctx, network, destination)
|
return w.dialer.DialContext(ctx, network, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
return w.dialer.ListenPacket(ctx, destination)
|
return w.dialer.ListenPacket(ctx, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Inet4Address() netip.Addr {
|
func (w *systemDevice) SetDevice(device *device.Device) {
|
||||||
if len(w.inet4Addresses) == 0 {
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
return w.inet4Addresses[0].Addr()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Inet6Address() netip.Addr {
|
func (w *systemDevice) Start() error {
|
||||||
if len(w.inet6Addresses) == 0 {
|
networkManager := service.FromContext[adapter.NetworkManager](w.options.Context)
|
||||||
return netip.Addr{}
|
tunOptions := tun.Options{
|
||||||
|
Name: w.options.Name,
|
||||||
|
Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
|
||||||
|
return it.Addr().Is4()
|
||||||
|
}),
|
||||||
|
Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
|
||||||
|
return it.Addr().Is6()
|
||||||
|
}),
|
||||||
|
MTU: w.options.MTU,
|
||||||
|
GSO: w.options.GSO,
|
||||||
|
InterfaceScope: true,
|
||||||
|
Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool {
|
||||||
|
return it.Addr().Is4()
|
||||||
|
}),
|
||||||
|
Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }),
|
||||||
|
InterfaceMonitor: networkManager.InterfaceMonitor(),
|
||||||
|
InterfaceFinder: networkManager.InterfaceFinder(),
|
||||||
}
|
}
|
||||||
return w.inet6Addresses[0].Addr()
|
// works with Linux, macOS with IFSCOPE routes, not tested on Windows
|
||||||
}
|
if runtime.GOOS == "darwin" {
|
||||||
|
tunOptions.AutoRoute = true
|
||||||
func (w *SystemDevice) Start() error {
|
}
|
||||||
tunInterface, err := tun.New(tun.Options{
|
tunInterface, err := tun.New(tunOptions)
|
||||||
Name: w.name,
|
|
||||||
Inet4Address: w.inet4Addresses,
|
|
||||||
Inet6Address: w.inet6Addresses,
|
|
||||||
MTU: w.mtu,
|
|
||||||
GSO: w.gso,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
err = tunInterface.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.options.Logger.Info("started at ", w.options.Name)
|
||||||
w.device = tunInterface
|
w.device = tunInterface
|
||||||
if w.gso {
|
if w.options.GSO {
|
||||||
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
|
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
|
||||||
if !isBatchTUN {
|
if !isBatchTUN {
|
||||||
tunInterface.Close()
|
tunInterface.Close()
|
||||||
|
@ -107,15 +99,15 @@ func (w *SystemDevice) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) File() *os.File {
|
func (w *systemDevice) File() *os.File {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||||
if w.batchDevice != nil {
|
if w.batchDevice != nil {
|
||||||
count, err = w.batchDevice.BatchRead(bufs, offset, sizes)
|
count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes)
|
||||||
} else {
|
} else {
|
||||||
sizes[0], err = w.device.Read(bufs[0][offset:])
|
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
count = 1
|
count = 1
|
||||||
} else if errors.Is(err, tun.ErrTooManySegments) {
|
} else if errors.Is(err, tun.ErrTooManySegments) {
|
||||||
|
@ -125,12 +117,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||||
if w.batchDevice != nil {
|
if w.batchDevice != nil {
|
||||||
return 0, w.batchDevice.BatchWrite(bufs, offset)
|
return w.batchDevice.BatchWrite(bufs, offset)
|
||||||
} else {
|
} else {
|
||||||
for _, b := range bufs {
|
for _, packet := range bufs {
|
||||||
_, err = w.device.Write(b[offset:])
|
if tun.PacketOffset > 0 {
|
||||||
|
common.ClearArray(packet[offset-tun.PacketOffset : offset])
|
||||||
|
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
|
||||||
|
}
|
||||||
|
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -140,28 +136,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Flush() error {
|
func (w *systemDevice) Flush() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) MTU() (int, error) {
|
func (w *systemDevice) MTU() (int, error) {
|
||||||
return int(w.mtu), nil
|
return int(w.options.MTU), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Name() (string, error) {
|
func (w *systemDevice) Name() (string, error) {
|
||||||
return w.name, nil
|
return w.options.Name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Events() <-chan wgTun.Event {
|
func (w *systemDevice) Events() <-chan wgTun.Event {
|
||||||
return w.events
|
return w.events
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) Close() error {
|
func (w *systemDevice) Close() error {
|
||||||
close(w.events)
|
close(w.events)
|
||||||
return w.device.Close()
|
return w.device.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SystemDevice) BatchSize() int {
|
func (w *systemDevice) BatchSize() int {
|
||||||
if w.batchDevice != nil {
|
if w.batchDevice != nil {
|
||||||
return w.batchDevice.BatchSize()
|
return w.batchDevice.BatchSize()
|
||||||
}
|
}
|
||||||
|
|
182
transport/wireguard/device_system_stack.go
Normal file
182
transport/wireguard/device_system_stack.go
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
//go:build with_gvisor
|
||||||
|
|
||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/sagernet/gvisor/pkg/buffer"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"github.com/sagernet/sing-tun"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/wireguard-go/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ Device = (*systemStackDevice)(nil)
|
||||||
|
|
||||||
|
type systemStackDevice struct {
|
||||||
|
*systemDevice
|
||||||
|
stack *stack.Stack
|
||||||
|
endpoint *deviceEndpoint
|
||||||
|
writeBufs [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
|
||||||
|
system, err := newSystemDevice(options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
endpoint := &deviceEndpoint{
|
||||||
|
mtu: options.MTU,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
ipStack, err := tun.NewGVisorStack(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||||
|
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||||
|
return &systemStackDevice{
|
||||||
|
systemDevice: system,
|
||||||
|
stack: ipStack,
|
||||||
|
endpoint: endpoint,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *systemStackDevice) SetDevice(device *device.Device) {
|
||||||
|
w.endpoint.device = device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||||
|
if w.batchDevice != nil {
|
||||||
|
w.writeBufs = w.writeBufs[:0]
|
||||||
|
for _, packet := range bufs {
|
||||||
|
if !w.writeStack(packet[offset:]) {
|
||||||
|
w.writeBufs = append(w.writeBufs, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(w.writeBufs) > 0 {
|
||||||
|
return w.batchDevice.BatchWrite(bufs, offset)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, packet := range bufs {
|
||||||
|
if !w.writeStack(packet[offset:]) {
|
||||||
|
if tun.PacketOffset > 0 {
|
||||||
|
common.ClearArray(packet[offset-tun.PacketOffset : offset])
|
||||||
|
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
|
||||||
|
}
|
||||||
|
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// WireGuard will not read count
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *systemStackDevice) Close() error {
|
||||||
|
close(w.endpoint.done)
|
||||||
|
w.stack.Close()
|
||||||
|
for _, endpoint := range w.stack.CleanupEndpoints() {
|
||||||
|
endpoint.Abort()
|
||||||
|
}
|
||||||
|
w.stack.Wait()
|
||||||
|
return w.systemDevice.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *systemStackDevice) writeStack(packet []byte) bool {
|
||||||
|
var (
|
||||||
|
networkProtocol tcpip.NetworkProtocolNumber
|
||||||
|
destination netip.Addr
|
||||||
|
)
|
||||||
|
switch header.IPVersion(packet) {
|
||||||
|
case header.IPv4Version:
|
||||||
|
networkProtocol = header.IPv4ProtocolNumber
|
||||||
|
destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4())
|
||||||
|
case header.IPv6Version:
|
||||||
|
networkProtocol = header.IPv6ProtocolNumber
|
||||||
|
destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16())
|
||||||
|
}
|
||||||
|
for _, prefix := range w.options.Address {
|
||||||
|
if prefix.Contains(destination) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(packet),
|
||||||
|
})
|
||||||
|
w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
|
||||||
|
packetBuffer.DecRef()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type deviceEndpoint struct {
|
||||||
|
mtu uint32
|
||||||
|
done chan struct{}
|
||||||
|
device *device.Device
|
||||||
|
dispatcher stack.NetworkDispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) MTU() uint32 {
|
||||||
|
return ep.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) SetMTU(mtu uint32) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||||
|
return stack.CapabilityRXChecksumOffload
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
|
ep.dispatcher = dispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) IsAttached() bool {
|
||||||
|
return ep.dispatcher != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) Wait() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
|
||||||
|
return header.ARPHardwareNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
|
for _, packetBuffer := range list.AsSlice() {
|
||||||
|
destination := packetBuffer.Network().DestinationAddress()
|
||||||
|
ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
|
||||||
|
}
|
||||||
|
return list.Len(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) Close() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
|
||||||
|
}
|
|
@ -1,35 +1,254 @@
|
||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
F "github.com/sagernet/sing/common/format"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
"github.com/sagernet/sing/common/x/list"
|
||||||
|
"github.com/sagernet/sing/service"
|
||||||
|
"github.com/sagernet/sing/service/pause"
|
||||||
"github.com/sagernet/wireguard-go/conn"
|
"github.com/sagernet/wireguard-go/conn"
|
||||||
|
"github.com/sagernet/wireguard-go/device"
|
||||||
|
|
||||||
|
"go4.org/netipx"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ conn.Endpoint = (*Endpoint)(nil)
|
type Endpoint struct {
|
||||||
|
options EndpointOptions
|
||||||
type Endpoint netip.AddrPort
|
peers []peerConfig
|
||||||
|
ipcConf string
|
||||||
func (e Endpoint) ClearSrc() {
|
allowedAddress []netip.Prefix
|
||||||
|
tunDevice Device
|
||||||
|
device *device.Device
|
||||||
|
pauseManager pause.Manager
|
||||||
|
pauseCallback *list.Element[pause.Callback]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) SrcToString() string {
|
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
|
||||||
return ""
|
if options.PrivateKey == "" {
|
||||||
|
return nil, E.New("missing private key")
|
||||||
|
}
|
||||||
|
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "decode private key")
|
||||||
|
}
|
||||||
|
privateKey := hex.EncodeToString(privateKeyBytes)
|
||||||
|
ipcConf := "private_key=" + privateKey
|
||||||
|
if options.ListenPort != 0 {
|
||||||
|
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
|
||||||
|
}
|
||||||
|
var peers []peerConfig
|
||||||
|
for peerIndex, rawPeer := range options.Peers {
|
||||||
|
peer := peerConfig{
|
||||||
|
allowedIPs: rawPeer.AllowedIPs,
|
||||||
|
keepalive: rawPeer.PersistentKeepaliveInterval,
|
||||||
|
}
|
||||||
|
if rawPeer.Endpoint.Addr.IsValid() {
|
||||||
|
peer.endpoint = rawPeer.Endpoint.AddrPort()
|
||||||
|
} else if rawPeer.Endpoint.IsFqdn() {
|
||||||
|
peer.destination = rawPeer.Endpoint
|
||||||
|
}
|
||||||
|
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
||||||
|
}
|
||||||
|
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
|
||||||
|
if rawPeer.PreSharedKey != "" {
|
||||||
|
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
||||||
|
}
|
||||||
|
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
|
||||||
|
}
|
||||||
|
if len(rawPeer.AllowedIPs) == 0 {
|
||||||
|
return nil, E.New("missing allowed ips for peer ", peerIndex)
|
||||||
|
}
|
||||||
|
if len(rawPeer.Reserved) > 0 {
|
||||||
|
if len(rawPeer.Reserved) != 3 {
|
||||||
|
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
|
||||||
|
}
|
||||||
|
copy(peer.reserved[:], rawPeer.Reserved[:])
|
||||||
|
}
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
var allowedPrefixBuilder netipx.IPSetBuilder
|
||||||
|
for _, peer := range options.Peers {
|
||||||
|
for _, prefix := range peer.AllowedIPs {
|
||||||
|
allowedPrefixBuilder.AddPrefix(prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allowedIPSet, err := allowedPrefixBuilder.IPSet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
allowedAddresses := allowedIPSet.Prefixes()
|
||||||
|
if options.MTU == 0 {
|
||||||
|
options.MTU = 1408
|
||||||
|
}
|
||||||
|
deviceOptions := DeviceOptions{
|
||||||
|
Context: options.Context,
|
||||||
|
Logger: options.Logger,
|
||||||
|
System: options.System,
|
||||||
|
Handler: options.Handler,
|
||||||
|
UDPTimeout: options.UDPTimeout,
|
||||||
|
CreateDialer: options.CreateDialer,
|
||||||
|
Name: options.Name,
|
||||||
|
MTU: options.MTU,
|
||||||
|
GSO: options.GSO,
|
||||||
|
Address: options.Address,
|
||||||
|
AllowedAddress: allowedAddresses,
|
||||||
|
}
|
||||||
|
tunDevice, err := NewDevice(deviceOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "create WireGuard device")
|
||||||
|
}
|
||||||
|
return &Endpoint{
|
||||||
|
options: options,
|
||||||
|
peers: peers,
|
||||||
|
ipcConf: ipcConf,
|
||||||
|
allowedAddress: allowedAddresses,
|
||||||
|
tunDevice: tunDevice,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) DstToString() string {
|
func (e *Endpoint) Start(resolve bool) error {
|
||||||
return (netip.AddrPort)(e).String()
|
if common.Any(e.peers, func(peer peerConfig) bool {
|
||||||
|
return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
|
||||||
|
}) {
|
||||||
|
if !resolve {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for peerIndex, peer := range e.peers {
|
||||||
|
if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
|
||||||
|
}
|
||||||
|
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
|
||||||
|
}
|
||||||
|
} else if resolve {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var bind conn.Bind
|
||||||
|
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
|
||||||
|
if isWgListener {
|
||||||
|
bind = conn.NewStdNetBind(wgListener)
|
||||||
|
} else {
|
||||||
|
var (
|
||||||
|
isConnect bool
|
||||||
|
connectAddr netip.AddrPort
|
||||||
|
reserved [3]uint8
|
||||||
|
)
|
||||||
|
peerLen := len(e.peers)
|
||||||
|
if peerLen == 1 {
|
||||||
|
isConnect = true
|
||||||
|
connectAddr = e.peers[0].endpoint
|
||||||
|
reserved = e.peers[0].reserved
|
||||||
|
}
|
||||||
|
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
|
||||||
|
}
|
||||||
|
err := e.tunDevice.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger := &device.Logger{
|
||||||
|
Verbosef: func(format string, args ...interface{}) {
|
||||||
|
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
||||||
|
},
|
||||||
|
Errorf: func(format string, args ...interface{}) {
|
||||||
|
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
|
||||||
|
e.tunDevice.SetDevice(wgDevice)
|
||||||
|
ipcConf := e.ipcConf
|
||||||
|
for _, peer := range e.peers {
|
||||||
|
ipcConf += peer.GenerateIpcLines()
|
||||||
|
}
|
||||||
|
err = wgDevice.IpcSet(ipcConf)
|
||||||
|
if err != nil {
|
||||||
|
return E.Cause(err, "setup wireguard: \n", ipcConf)
|
||||||
|
}
|
||||||
|
e.device = wgDevice
|
||||||
|
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
|
||||||
|
if e.pauseManager != nil {
|
||||||
|
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) DstToBytes() []byte {
|
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
if !destination.Addr.IsValid() {
|
||||||
return b
|
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||||
|
}
|
||||||
|
return e.tunDevice.DialContext(ctx, network, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) DstIP() netip.Addr {
|
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
return (netip.AddrPort)(e).Addr()
|
if !destination.Addr.IsValid() {
|
||||||
|
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||||
|
}
|
||||||
|
return e.tunDevice.ListenPacket(ctx, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) SrcIP() netip.Addr {
|
func (e *Endpoint) BindUpdate() error {
|
||||||
return netip.Addr{}
|
return e.device.BindUpdate()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Endpoint) Close() error {
|
||||||
|
if e.device != nil {
|
||||||
|
e.device.Close()
|
||||||
|
}
|
||||||
|
if e.pauseCallback != nil {
|
||||||
|
e.pauseManager.UnregisterCallback(e.pauseCallback)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Endpoint) onPauseUpdated(event int) {
|
||||||
|
switch event {
|
||||||
|
case pause.EventDevicePaused:
|
||||||
|
e.device.Down()
|
||||||
|
case pause.EventDeviceWake:
|
||||||
|
e.device.Up()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerConfig struct {
|
||||||
|
destination M.Socksaddr
|
||||||
|
endpoint netip.AddrPort
|
||||||
|
publicKeyHex string
|
||||||
|
preSharedKeyHex string
|
||||||
|
allowedIPs []netip.Prefix
|
||||||
|
keepalive uint16
|
||||||
|
reserved [3]uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c peerConfig) GenerateIpcLines() string {
|
||||||
|
ipcLines := "\npublic_key=" + c.publicKeyHex
|
||||||
|
if c.endpoint.IsValid() {
|
||||||
|
ipcLines += "\nendpoint=" + c.endpoint.String()
|
||||||
|
}
|
||||||
|
if c.preSharedKeyHex != "" {
|
||||||
|
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
|
||||||
|
}
|
||||||
|
for _, allowedIP := range c.allowedIPs {
|
||||||
|
ipcLines += "\nallowed_ip=" + allowedIP.String()
|
||||||
|
}
|
||||||
|
if c.keepalive > 0 {
|
||||||
|
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
|
||||||
|
}
|
||||||
|
return ipcLines
|
||||||
}
|
}
|
||||||
|
|
40
transport/wireguard/endpoint_options.go
Normal file
40
transport/wireguard/endpoint_options.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-tun"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EndpointOptions struct {
|
||||||
|
Context context.Context
|
||||||
|
Logger logger.ContextLogger
|
||||||
|
System bool
|
||||||
|
Handler tun.Handler
|
||||||
|
UDPTimeout time.Duration
|
||||||
|
Dialer N.Dialer
|
||||||
|
CreateDialer func(interfaceName string) N.Dialer
|
||||||
|
Name string
|
||||||
|
MTU uint32
|
||||||
|
GSO bool
|
||||||
|
Address []netip.Prefix
|
||||||
|
PrivateKey string
|
||||||
|
ListenPort uint16
|
||||||
|
ResolvePeer func(domain string) (netip.Addr, error)
|
||||||
|
Peers []PeerOptions
|
||||||
|
Workers int
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerOptions struct {
|
||||||
|
Endpoint M.Socksaddr
|
||||||
|
PublicKey string
|
||||||
|
PreSharedKey string
|
||||||
|
AllowedIPs []netip.Prefix
|
||||||
|
PersistentKeepaliveInterval uint16
|
||||||
|
Reserved []uint8
|
||||||
|
}
|
|
@ -1,148 +0,0 @@
|
||||||
package wireguard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
|
||||||
"github.com/sagernet/sing-box/option"
|
|
||||||
"github.com/sagernet/sing-dns"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PeerConfig struct {
|
|
||||||
destination M.Socksaddr
|
|
||||||
domainStrategy dns.DomainStrategy
|
|
||||||
Endpoint netip.AddrPort
|
|
||||||
PublicKey string
|
|
||||||
PreSharedKey string
|
|
||||||
AllowedIPs []string
|
|
||||||
Reserved [3]uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c PeerConfig) GenerateIpcLines() string {
|
|
||||||
ipcLines := "\npublic_key=" + c.PublicKey
|
|
||||||
ipcLines += "\nendpoint=" + c.Endpoint.String()
|
|
||||||
if c.PreSharedKey != "" {
|
|
||||||
ipcLines += "\npreshared_key=" + c.PreSharedKey
|
|
||||||
}
|
|
||||||
for _, allowedIP := range c.AllowedIPs {
|
|
||||||
ipcLines += "\nallowed_ip=" + allowedIP
|
|
||||||
}
|
|
||||||
return ipcLines
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
|
|
||||||
var peers []PeerConfig
|
|
||||||
if len(options.Peers) > 0 {
|
|
||||||
for peerIndex, rawPeer := range options.Peers {
|
|
||||||
peer := PeerConfig{
|
|
||||||
AllowedIPs: rawPeer.AllowedIPs,
|
|
||||||
}
|
|
||||||
destination := rawPeer.ServerOptions.Build()
|
|
||||||
if destination.IsFqdn() {
|
|
||||||
peer.destination = destination
|
|
||||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
|
||||||
} else {
|
|
||||||
peer.Endpoint = destination.AddrPort()
|
|
||||||
}
|
|
||||||
{
|
|
||||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
|
||||||
}
|
|
||||||
peer.PublicKey = hex.EncodeToString(bytes)
|
|
||||||
}
|
|
||||||
if rawPeer.PreSharedKey != "" {
|
|
||||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
|
||||||
}
|
|
||||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
|
||||||
}
|
|
||||||
if len(rawPeer.AllowedIPs) == 0 {
|
|
||||||
return nil, E.New("missing allowed_ips for peer ", peerIndex)
|
|
||||||
}
|
|
||||||
if len(rawPeer.Reserved) > 0 {
|
|
||||||
if len(rawPeer.Reserved) != 3 {
|
|
||||||
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
|
|
||||||
}
|
|
||||||
copy(peer.Reserved[:], options.Reserved)
|
|
||||||
}
|
|
||||||
peers = append(peers, peer)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
peer := PeerConfig{}
|
|
||||||
var (
|
|
||||||
addressHas4 bool
|
|
||||||
addressHas6 bool
|
|
||||||
)
|
|
||||||
for _, localAddress := range options.LocalAddress {
|
|
||||||
if localAddress.Addr().Is4() {
|
|
||||||
addressHas4 = true
|
|
||||||
} else {
|
|
||||||
addressHas6 = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if addressHas4 {
|
|
||||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
|
|
||||||
}
|
|
||||||
if addressHas6 {
|
|
||||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
|
|
||||||
}
|
|
||||||
destination := options.ServerOptions.Build()
|
|
||||||
if destination.IsFqdn() {
|
|
||||||
peer.destination = destination
|
|
||||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
|
||||||
} else {
|
|
||||||
peer.Endpoint = destination.AddrPort()
|
|
||||||
}
|
|
||||||
{
|
|
||||||
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "decode peer public key")
|
|
||||||
}
|
|
||||||
peer.PublicKey = hex.EncodeToString(bytes)
|
|
||||||
}
|
|
||||||
if options.PreSharedKey != "" {
|
|
||||||
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, E.Cause(err, "decode pre shared key")
|
|
||||||
}
|
|
||||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
|
||||||
}
|
|
||||||
if len(options.Reserved) > 0 {
|
|
||||||
if len(options.Reserved) != 3 {
|
|
||||||
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
|
|
||||||
}
|
|
||||||
copy(peer.Reserved[:], options.Reserved)
|
|
||||||
}
|
|
||||||
peers = append(peers, peer)
|
|
||||||
}
|
|
||||||
return peers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
|
|
||||||
for peerIndex, peer := range peers {
|
|
||||||
if peer.Endpoint.IsValid() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
|
|
||||||
if err != nil {
|
|
||||||
if len(peers) == 1 {
|
|
||||||
return E.Cause(err, "resolve endpoint domain")
|
|
||||||
} else {
|
|
||||||
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(destinationAddresses) == 0 {
|
|
||||||
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
|
|
||||||
}
|
|
||||||
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
|
|
||||||
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
Reference in a new issue