mirror of
https://github.com/SagerNet/sing-box.git
synced 2024-11-27 02:51:36 +00:00
refactor: WireGuard endpoint
This commit is contained in:
parent
7dbc105f89
commit
e5bfd9e6b1
28
adapter/endpoint.go
Normal file
28
adapter/endpoint.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
type Endpoint interface {
|
||||
Lifecycle
|
||||
Type() string
|
||||
Tag() string
|
||||
Outbound
|
||||
}
|
||||
|
||||
type EndpointRegistry interface {
|
||||
option.EndpointOptionsRegistry
|
||||
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) (Endpoint, error)
|
||||
}
|
||||
|
||||
type EndpointManager interface {
|
||||
Lifecycle
|
||||
Endpoints() []Endpoint
|
||||
Get(tag string) (Endpoint, bool)
|
||||
Remove(tag string) error
|
||||
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) error
|
||||
}
|
43
adapter/endpoint/adapter.go
Normal file
43
adapter/endpoint/adapter.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package endpoint
|
||||
|
||||
import "github.com/sagernet/sing-box/option"
|
||||
|
||||
type Adapter struct {
|
||||
endpointType string
|
||||
endpointTag string
|
||||
network []string
|
||||
dependencies []string
|
||||
}
|
||||
|
||||
func NewAdapter(endpointType string, endpointTag string, network []string, dependencies []string) Adapter {
|
||||
return Adapter{
|
||||
endpointType: endpointType,
|
||||
endpointTag: endpointTag,
|
||||
network: network,
|
||||
dependencies: dependencies,
|
||||
}
|
||||
}
|
||||
|
||||
func NewAdapterWithDialerOptions(endpointType string, endpointTag string, network []string, dialOptions option.DialerOptions) Adapter {
|
||||
var dependencies []string
|
||||
if dialOptions.Detour != "" {
|
||||
dependencies = []string{dialOptions.Detour}
|
||||
}
|
||||
return NewAdapter(endpointType, endpointTag, network, dependencies)
|
||||
}
|
||||
|
||||
func (a *Adapter) Type() string {
|
||||
return a.endpointType
|
||||
}
|
||||
|
||||
func (a *Adapter) Tag() string {
|
||||
return a.endpointTag
|
||||
}
|
||||
|
||||
func (a *Adapter) Network() []string {
|
||||
return a.network
|
||||
}
|
||||
|
||||
func (a *Adapter) Dependencies() []string {
|
||||
return a.dependencies
|
||||
}
|
147
adapter/endpoint/manager.go
Normal file
147
adapter/endpoint/manager.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package endpoint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/taskmonitor"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
var _ adapter.EndpointManager = (*Manager)(nil)
|
||||
|
||||
type Manager struct {
|
||||
logger log.ContextLogger
|
||||
registry adapter.EndpointRegistry
|
||||
access sync.Mutex
|
||||
started bool
|
||||
stage adapter.StartStage
|
||||
endpoints []adapter.Endpoint
|
||||
endpointByTag map[string]adapter.Endpoint
|
||||
}
|
||||
|
||||
func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
registry: registry,
|
||||
endpointByTag: make(map[string]adapter.Endpoint),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(stage adapter.StartStage) error {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
if m.started && m.stage >= stage {
|
||||
panic("already started")
|
||||
}
|
||||
m.started = true
|
||||
m.stage = stage
|
||||
if stage == adapter.StartStateStart {
|
||||
// started with outbound manager
|
||||
return nil
|
||||
}
|
||||
for _, endpoint := range m.endpoints {
|
||||
err := adapter.LegacyStart(endpoint, stage)
|
||||
if err != nil {
|
||||
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Close() error {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
if !m.started {
|
||||
return nil
|
||||
}
|
||||
m.started = false
|
||||
endpoints := m.endpoints
|
||||
m.endpoints = nil
|
||||
monitor := taskmonitor.New(m.logger, C.StopTimeout)
|
||||
var err error
|
||||
for _, endpoint := range endpoints {
|
||||
monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||
err = E.Append(err, endpoint.Close(), func(err error) error {
|
||||
return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||
})
|
||||
monitor.Finish()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Endpoints() []adapter.Endpoint {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
return m.endpoints
|
||||
}
|
||||
|
||||
func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
endpoint, found := m.endpointByTag[tag]
|
||||
return endpoint, found
|
||||
}
|
||||
|
||||
func (m *Manager) Remove(tag string) error {
|
||||
m.access.Lock()
|
||||
endpoint, found := m.endpointByTag[tag]
|
||||
if !found {
|
||||
m.access.Unlock()
|
||||
return os.ErrInvalid
|
||||
}
|
||||
delete(m.endpointByTag, tag)
|
||||
index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
|
||||
return it == endpoint
|
||||
})
|
||||
if index == -1 {
|
||||
panic("invalid endpoint index")
|
||||
}
|
||||
m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
|
||||
started := m.started
|
||||
m.access.Unlock()
|
||||
if started {
|
||||
return endpoint.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
|
||||
endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
if m.started {
|
||||
for _, stage := range adapter.ListStartStages {
|
||||
err = adapter.LegacyStart(endpoint, stage)
|
||||
if err != nil {
|
||||
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
|
||||
if m.started {
|
||||
err = existsEndpoint.Close()
|
||||
if err != nil {
|
||||
return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
|
||||
}
|
||||
}
|
||||
existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
|
||||
return it == existsEndpoint
|
||||
})
|
||||
if existsIndex == -1 {
|
||||
panic("invalid endpoint index")
|
||||
}
|
||||
m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
|
||||
}
|
||||
m.endpoints = append(m.endpoints, endpoint)
|
||||
m.endpointByTag[tag] = endpoint
|
||||
return nil
|
||||
}
|
72
adapter/endpoint/registry.go
Normal file
72
adapter/endpoint/registry.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package endpoint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options T) (adapter.Endpoint, error)
|
||||
|
||||
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
|
||||
registry.register(outboundType, func() any {
|
||||
return new(Options)
|
||||
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Endpoint, error) {
|
||||
var options *Options
|
||||
if rawOptions != nil {
|
||||
options = rawOptions.(*Options)
|
||||
}
|
||||
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
|
||||
})
|
||||
}
|
||||
|
||||
var _ adapter.EndpointRegistry = (*Registry)(nil)
|
||||
|
||||
type (
|
||||
optionsConstructorFunc func() any
|
||||
constructorFunc func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Endpoint, error)
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
access sync.Mutex
|
||||
optionsType map[string]optionsConstructorFunc
|
||||
constructor map[string]constructorFunc
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
optionsType: make(map[string]optionsConstructorFunc),
|
||||
constructor: make(map[string]constructorFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Registry) CreateOptions(outboundType string) (any, bool) {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
optionsConstructor, loaded := m.optionsType[outboundType]
|
||||
if !loaded {
|
||||
return nil, false
|
||||
}
|
||||
return optionsConstructor(), true
|
||||
}
|
||||
|
||||
func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Endpoint, error) {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
constructor, loaded := m.constructor[outboundType]
|
||||
if !loaded {
|
||||
return nil, E.New("outbound type not found: " + outboundType)
|
||||
}
|
||||
return constructor(ctx, router, logger, tag, options)
|
||||
}
|
||||
|
||||
func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
m.optionsType[outboundType] = optionsConstructor
|
||||
m.constructor[outboundType] = constructor
|
||||
}
|
|
@ -13,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
type Inbound interface {
|
||||
Service
|
||||
Lifecycle
|
||||
Type() string
|
||||
Tag() string
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ var _ adapter.InboundManager = (*Manager)(nil)
|
|||
type Manager struct {
|
||||
logger log.ContextLogger
|
||||
registry adapter.InboundRegistry
|
||||
endpoint adapter.EndpointManager
|
||||
access sync.Mutex
|
||||
started bool
|
||||
stage adapter.StartStage
|
||||
|
@ -25,10 +26,11 @@ type Manager struct {
|
|||
inboundByTag map[string]adapter.Inbound
|
||||
}
|
||||
|
||||
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager {
|
||||
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endpoint adapter.EndpointManager) *Manager {
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
registry: registry,
|
||||
endpoint: endpoint,
|
||||
inboundByTag: make(map[string]adapter.Inbound),
|
||||
}
|
||||
}
|
||||
|
@ -79,9 +81,12 @@ func (m *Manager) Inbounds() []adapter.Inbound {
|
|||
|
||||
func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
inbound, found := m.inboundByTag[tag]
|
||||
return inbound, found
|
||||
m.access.Unlock()
|
||||
if found {
|
||||
return inbound, true
|
||||
}
|
||||
return m.endpoint.Get(tag)
|
||||
}
|
||||
|
||||
func (m *Manager) Remove(tag string) error {
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package adapter
|
||||
|
||||
func LegacyStart(starter any, stage StartStage) error {
|
||||
if lifecycle, isLifecycle := starter.(Lifecycle); isLifecycle {
|
||||
return lifecycle.Start(stage)
|
||||
}
|
||||
switch stage {
|
||||
case StartStateInitialize:
|
||||
if preStarter, isPreStarter := starter.(interface {
|
||||
|
|
|
@ -5,35 +5,35 @@ import (
|
|||
)
|
||||
|
||||
type Adapter struct {
|
||||
protocol string
|
||||
outboundType string
|
||||
outboundTag string
|
||||
network []string
|
||||
tag string
|
||||
dependencies []string
|
||||
}
|
||||
|
||||
func NewAdapter(protocol string, network []string, tag string, dependencies []string) Adapter {
|
||||
func NewAdapter(outboundType string, outboundTag string, network []string, dependencies []string) Adapter {
|
||||
return Adapter{
|
||||
protocol: protocol,
|
||||
outboundType: outboundType,
|
||||
outboundTag: outboundTag,
|
||||
network: network,
|
||||
tag: tag,
|
||||
dependencies: dependencies,
|
||||
}
|
||||
}
|
||||
|
||||
func NewAdapterWithDialerOptions(protocol string, network []string, tag string, dialOptions option.DialerOptions) Adapter {
|
||||
func NewAdapterWithDialerOptions(outboundType string, outboundTag string, network []string, dialOptions option.DialerOptions) Adapter {
|
||||
var dependencies []string
|
||||
if dialOptions.Detour != "" {
|
||||
dependencies = []string{dialOptions.Detour}
|
||||
}
|
||||
return NewAdapter(protocol, network, tag, dependencies)
|
||||
return NewAdapter(outboundType, outboundTag, network, dependencies)
|
||||
}
|
||||
|
||||
func (a *Adapter) Type() string {
|
||||
return a.protocol
|
||||
return a.outboundType
|
||||
}
|
||||
|
||||
func (a *Adapter) Tag() string {
|
||||
return a.tag
|
||||
return a.outboundTag
|
||||
}
|
||||
|
||||
func (a *Adapter) Network() []string {
|
||||
|
|
|
@ -21,6 +21,7 @@ var _ adapter.OutboundManager = (*Manager)(nil)
|
|||
type Manager struct {
|
||||
logger log.ContextLogger
|
||||
registry adapter.OutboundRegistry
|
||||
endpoint adapter.EndpointManager
|
||||
defaultTag string
|
||||
access sync.Mutex
|
||||
started bool
|
||||
|
@ -32,10 +33,11 @@ type Manager struct {
|
|||
defaultOutboundFallback adapter.Outbound
|
||||
}
|
||||
|
||||
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager {
|
||||
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
registry: registry,
|
||||
endpoint: endpoint,
|
||||
defaultTag: defaultTag,
|
||||
outboundByTag: make(map[string]adapter.Outbound),
|
||||
dependByTag: make(map[string][]string),
|
||||
|
@ -56,7 +58,14 @@ func (m *Manager) Start(stage adapter.StartStage) error {
|
|||
outbounds := m.outbounds
|
||||
m.access.Unlock()
|
||||
if stage == adapter.StartStateStart {
|
||||
return m.startOutbounds(outbounds)
|
||||
if m.defaultTag != "" && m.defaultOutbound == nil {
|
||||
defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
|
||||
if !loaded {
|
||||
return E.New("default outbound not found: ", m.defaultTag)
|
||||
}
|
||||
m.defaultOutbound = defaultEndpoint
|
||||
}
|
||||
return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
|
||||
} else {
|
||||
for _, outbound := range outbounds {
|
||||
err := adapter.LegacyStart(outbound, stage)
|
||||
|
@ -87,7 +96,14 @@ func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
|
|||
}
|
||||
started[outboundTag] = true
|
||||
canContinue = true
|
||||
if starter, isStarter := outboundToStart.(interface {
|
||||
if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
|
||||
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
|
||||
err := starter.Start(adapter.StartStateStart)
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
|
||||
}
|
||||
} else if starter, isStarter := outboundToStart.(interface {
|
||||
Start() error
|
||||
}); isStarter {
|
||||
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
|
||||
|
@ -160,9 +176,12 @@ func (m *Manager) Outbounds() []adapter.Outbound {
|
|||
|
||||
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
outbound, found := m.outboundByTag[tag]
|
||||
return outbound, found
|
||||
m.access.Unlock()
|
||||
if found {
|
||||
return outbound, true
|
||||
}
|
||||
return m.endpoint.Get(tag)
|
||||
}
|
||||
|
||||
func (m *Manager) Default() adapter.Outbound {
|
||||
|
|
51
box.go
51
box.go
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
|
@ -36,6 +37,7 @@ type Box struct {
|
|||
logFactory log.Factory
|
||||
logger log.ContextLogger
|
||||
network *route.NetworkManager
|
||||
endpoint *endpoint.Manager
|
||||
inbound *inbound.Manager
|
||||
outbound *outbound.Manager
|
||||
connection *route.ConnectionManager
|
||||
|
@ -54,6 +56,7 @@ func Context(
|
|||
ctx context.Context,
|
||||
inboundRegistry adapter.InboundRegistry,
|
||||
outboundRegistry adapter.OutboundRegistry,
|
||||
endpointRegistry adapter.EndpointRegistry,
|
||||
) context.Context {
|
||||
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
|
||||
service.FromContext[adapter.InboundRegistry](ctx) == nil {
|
||||
|
@ -65,6 +68,11 @@ func Context(
|
|||
ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry)
|
||||
ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry)
|
||||
}
|
||||
if service.FromContext[option.EndpointOptionsRegistry](ctx) == nil ||
|
||||
service.FromContext[adapter.EndpointRegistry](ctx) == nil {
|
||||
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
|
||||
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
|
@ -76,12 +84,16 @@ func New(options Options) (*Box, error) {
|
|||
}
|
||||
ctx = service.ContextWithDefaultRegistry(ctx)
|
||||
|
||||
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
|
||||
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
|
||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||
|
||||
if endpointRegistry == nil {
|
||||
return nil, E.New("missing endpoint registry in context")
|
||||
}
|
||||
if inboundRegistry == nil {
|
||||
return nil, E.New("missing inbound registry in context")
|
||||
}
|
||||
|
||||
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
|
||||
if outboundRegistry == nil {
|
||||
return nil, E.New("missing outbound registry in context")
|
||||
}
|
||||
|
@ -119,8 +131,10 @@ func New(options Options) (*Box, error) {
|
|||
}
|
||||
|
||||
routeOptions := common.PtrValueOrDefault(options.Route)
|
||||
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry)
|
||||
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final)
|
||||
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
|
||||
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
|
||||
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
|
||||
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
|
||||
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
|
||||
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
|
||||
|
||||
|
@ -135,6 +149,24 @@ func New(options Options) (*Box, error) {
|
|||
if err != nil {
|
||||
return nil, E.Cause(err, "initialize router")
|
||||
}
|
||||
for i, endpointOptions := range options.Endpoints {
|
||||
var tag string
|
||||
if endpointOptions.Tag != "" {
|
||||
tag = endpointOptions.Tag
|
||||
} else {
|
||||
tag = F.ToString(i)
|
||||
}
|
||||
err = endpointManager.Create(ctx,
|
||||
router,
|
||||
logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")),
|
||||
tag,
|
||||
endpointOptions.Type,
|
||||
endpointOptions.Options,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "initialize inbound[", i, "]")
|
||||
}
|
||||
}
|
||||
for i, inboundOptions := range options.Inbounds {
|
||||
var tag string
|
||||
if inboundOptions.Tag != "" {
|
||||
|
@ -241,6 +273,7 @@ func New(options Options) (*Box, error) {
|
|||
}
|
||||
return &Box{
|
||||
network: networkManager,
|
||||
endpoint: endpointManager,
|
||||
inbound: inboundManager,
|
||||
outbound: outboundManager,
|
||||
connection: connectionManager,
|
||||
|
@ -303,7 +336,7 @@ func (s *Box) preStart() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound)
|
||||
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound, s.endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -327,7 +360,11 @@ func (s *Box) start() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound)
|
||||
err = adapter.Start(adapter.StartStateStart, s.endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound, s.endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -335,7 +372,7 @@ func (s *Box) start() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound)
|
||||
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound, s.endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) {
|
|||
configPaths = append(configPaths, "config.json")
|
||||
}
|
||||
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
|
||||
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry())
|
||||
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||
}
|
||||
|
|
|
@ -279,7 +279,7 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
|
|||
}
|
||||
|
||||
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
|
||||
return trackPacketConn(d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay))
|
||||
return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address))
|
||||
}
|
||||
|
||||
func trackConn(conn net.Conn, err error) (net.Conn, error) {
|
||||
|
|
|
@ -109,6 +109,15 @@ var OptionDestinationOverrideFields = Note{
|
|||
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options",
|
||||
}
|
||||
|
||||
var OptionWireGuardOutbound = Note{
|
||||
Name: "wireguard-outbound",
|
||||
Description: "legacy wireguard outbound",
|
||||
DeprecatedVersion: "1.11.0",
|
||||
ScheduledVersion: "1.13.0",
|
||||
EnvName: "WIREGUARD_OUTBOUND",
|
||||
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-wireguard-outbound-to-endpoint",
|
||||
}
|
||||
|
||||
var Options = []Note{
|
||||
OptionBadMatchSource,
|
||||
OptionGEOIP,
|
||||
|
@ -117,4 +126,5 @@ var Options = []Note{
|
|||
OptionSpecialOutbounds,
|
||||
OptionInboundOptions,
|
||||
OptionDestinationOverrideFields,
|
||||
OptionWireGuardOutbound,
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err
|
|||
}
|
||||
|
||||
func CheckConfig(configContent string) error {
|
||||
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
|
||||
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||
options, err := parseConfig(ctx, configContent)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -131,7 +131,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica
|
|||
}
|
||||
|
||||
func FormatConfig(configContent string) (string, error) {
|
||||
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()), configContent)
|
||||
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ type BoxService struct {
|
|||
}
|
||||
|
||||
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
|
||||
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
|
||||
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
|
||||
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
|
||||
options, err := parseConfig(ctx, configContent)
|
||||
|
|
3
go.mod
3
go.mod
|
@ -36,7 +36,7 @@ require (
|
|||
github.com/sagernet/sing-vmess v0.1.12
|
||||
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
|
||||
github.com/sagernet/utls v1.6.7
|
||||
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8
|
||||
github.com/sagernet/wireguard-go v0.0.1-beta.4
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/stretchr/testify v1.9.0
|
||||
|
@ -92,6 +92,7 @@ require (
|
|||
golang.org/x/text v0.20.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
golang.org/x/tools v0.24.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
|
6
go.sum
6
go.sum
|
@ -132,8 +132,8 @@ github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxe
|
|||
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
|
||||
github.com/sagernet/utls v1.6.7 h1:Ep3+aJ8FUGGta+II2IEVNUc3EDhaRCZINWkj/LloIA8=
|
||||
github.com/sagernet/utls v1.6.7/go.mod h1:Uua1TKO/FFuAhLr9rkaVnnrTmmiItzDjv1BUb2+ERwM=
|
||||
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8 h1:R0OMYAScomNAVpTfbHFpxqJpvwuhxSRi+g6z7gZhABs=
|
||||
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8/go.mod h1:K4J7/npM+VAMUeUmTa2JaA02JmyheP0GpRBOUvn3ecc=
|
||||
github.com/sagernet/wireguard-go v0.0.1-beta.4 h1:8uyM5fxfEXdu4RH05uOK+v25i3lTNdCYMPSAUJ14FnI=
|
||||
github.com/sagernet/wireguard-go v0.0.1-beta.4/go.mod h1:jGXij2Gn2wbrWuYNUmmNhf1dwcZtvyAvQoe8Xd8MbUo=
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
|
||||
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
|
@ -195,6 +195,8 @@ golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
|||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY=
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
|
@ -82,6 +83,14 @@ func OutboundRegistry() *outbound.Registry {
|
|||
return registry
|
||||
}
|
||||
|
||||
func EndpointRegistry() *endpoint.Registry {
|
||||
registry := endpoint.NewRegistry()
|
||||
|
||||
registerWireGuardEndpoint(registry)
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
func registerStubForRemovedInbounds(registry *inbound.Registry) {
|
||||
inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
|
||||
return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0")
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package include
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/protocol/wireguard"
|
||||
)
|
||||
|
@ -10,3 +11,7 @@ import (
|
|||
func registerWireGuardOutbound(registry *outbound.Registry) {
|
||||
wireguard.RegisterOutbound(registry)
|
||||
}
|
||||
|
||||
func registerWireGuardEndpoint(registry *endpoint.Registry) {
|
||||
wireguard.RegisterEndpoint(registry)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
|
@ -14,7 +15,13 @@ import (
|
|||
)
|
||||
|
||||
func registerWireGuardOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
|
||||
outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
|
||||
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
|
||||
})
|
||||
}
|
||||
|
||||
func registerWireGuardEndpoint(registry *endpoint.Registry) {
|
||||
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
|
||||
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
|
||||
})
|
||||
}
|
||||
|
|
47
option/endpoint.go
Normal file
47
option/endpoint.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package option
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json"
|
||||
"github.com/sagernet/sing/common/json/badjson"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
type EndpointOptionsRegistry interface {
|
||||
CreateOptions(endpointType string) (any, bool)
|
||||
}
|
||||
|
||||
type _Endpoint struct {
|
||||
Type string `json:"type"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Options any `json:"-"`
|
||||
}
|
||||
|
||||
type Endpoint _Endpoint
|
||||
|
||||
func (h *Endpoint) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||
return badjson.MarshallObjectsContext(ctx, (*_Endpoint)(h), h.Options)
|
||||
}
|
||||
|
||||
func (h *Endpoint) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
err := json.UnmarshalContext(ctx, content, (*_Endpoint)(h))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registry := service.FromContext[EndpointOptionsRegistry](ctx)
|
||||
if registry == nil {
|
||||
return E.New("missing Endpoint fields registry in context")
|
||||
}
|
||||
options, loaded := registry.CreateOptions(h.Type)
|
||||
if !loaded {
|
||||
return E.New("unknown inbound type: ", h.Type)
|
||||
}
|
||||
err = badjson.UnmarshallExcludedContext(ctx, content, (*_Endpoint)(h), options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Options = options
|
||||
return nil
|
||||
}
|
|
@ -28,7 +28,7 @@ func (h *Inbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|||
}
|
||||
|
||||
func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
err := json.Unmarshal(content, (*_Inbound)(h))
|
||||
err := json.UnmarshalContext(ctx, content, (*_Inbound)(h))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ type _Options struct {
|
|||
Log *LogOptions `json:"log,omitempty"`
|
||||
DNS *DNSOptions `json:"dns,omitempty"`
|
||||
NTP *NTPOptions `json:"ntp,omitempty"`
|
||||
Endpoints []Endpoint `json:"endpoints,omitempty"`
|
||||
Inbounds []Inbound `json:"inbounds,omitempty"`
|
||||
Outbounds []Outbound `json:"outbounds,omitempty"`
|
||||
Route *RouteOptions `json:"route,omitempty"`
|
||||
|
|
|
@ -30,7 +30,7 @@ func (h *Outbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|||
}
|
||||
|
||||
func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
||||
err := json.Unmarshal(content, (*_Outbound)(h))
|
||||
err := json.UnmarshalContext(ctx, content, (*_Outbound)(h))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,14 +6,38 @@ import (
|
|||
"github.com/sagernet/sing/common/json/badoption"
|
||||
)
|
||||
|
||||
type WireGuardOutboundOptions struct {
|
||||
type WireGuardEndpointOptions struct {
|
||||
System bool `json:"system,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
MTU uint32 `json:"mtu,omitempty"`
|
||||
GSO bool `json:"gso,omitempty"`
|
||||
Address badoption.Listable[netip.Prefix] `json:"address"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ListenPort uint16 `json:"listen_port,omitempty"`
|
||||
Peers []WireGuardPeer `json:"peers,omitempty"`
|
||||
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
|
||||
Workers int `json:"workers,omitempty"`
|
||||
DialerOptions
|
||||
}
|
||||
|
||||
type WireGuardPeer struct {
|
||||
Address string `json:"address,omitempty"`
|
||||
Port uint16 `json:"port,omitempty"`
|
||||
PublicKey string `json:"public_key,omitempty"`
|
||||
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
||||
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
|
||||
PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"`
|
||||
Reserved []uint8 `json:"reserved,omitempty"`
|
||||
}
|
||||
|
||||
type LegacyWireGuardOutboundOptions struct {
|
||||
DialerOptions
|
||||
SystemInterface bool `json:"system_interface,omitempty"`
|
||||
GSO bool `json:"gso,omitempty"`
|
||||
InterfaceName string `json:"interface_name,omitempty"`
|
||||
LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
Peers []WireGuardPeer `json:"peers,omitempty"`
|
||||
Peers []LegacyWireGuardPeer `json:"peers,omitempty"`
|
||||
ServerOptions
|
||||
PeerPublicKey string `json:"peer_public_key"`
|
||||
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
||||
|
@ -23,10 +47,10 @@ type WireGuardOutboundOptions struct {
|
|||
Network NetworkList `json:"network,omitempty"`
|
||||
}
|
||||
|
||||
type WireGuardPeer struct {
|
||||
type LegacyWireGuardPeer struct {
|
||||
ServerOptions
|
||||
PublicKey string `json:"public_key,omitempty"`
|
||||
PreSharedKey string `json:"pre_shared_key,omitempty"`
|
||||
AllowedIPs badoption.Listable[string] `json:"allowed_ips,omitempty"`
|
||||
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
|
||||
Reserved []uint8 `json:"reserved,omitempty"`
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ type Outbound struct {
|
|||
|
||||
func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) {
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapter(C.TypeBlock, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil),
|
||||
Adapter: outbound.NewAdapter(C.TypeBlock, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -68,7 +68,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (i *Inbound) Start() error {
|
||||
func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return i.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||
logger: logger,
|
||||
domainStrategy: dns.DomainStrategy(options.DomainStrategy),
|
||||
fallbackDelay: time.Duration(options.FallbackDelay),
|
||||
|
|
|
@ -28,7 +28,7 @@ type Outbound struct {
|
|||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapter(C.TypeDNS, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil),
|
||||
Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
|
||||
router: router,
|
||||
logger: logger,
|
||||
}, nil
|
||||
|
|
|
@ -38,7 +38,7 @@ type Selector struct {
|
|||
|
||||
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
|
||||
outbound := &Selector{
|
||||
Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds),
|
||||
Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
|
||||
ctx: ctx,
|
||||
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
||||
logger: logger,
|
||||
|
|
|
@ -49,7 +49,7 @@ type URLTest struct {
|
|||
|
||||
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
|
||||
outbound := &URLTest{
|
||||
Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds),
|
||||
Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
||||
|
|
|
@ -61,7 +61,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if h.tlsConfig != nil {
|
||||
err := h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
|
|
|
@ -39,7 +39,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||
logger: logger,
|
||||
client: sHTTP.NewClient(sHTTP.Options{
|
||||
Dialer: detour,
|
||||
|
|
|
@ -160,7 +160,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
|||
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if h.tlsConfig != nil {
|
||||
err := h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
|
|
|
@ -95,7 +95,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, networkList, tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions),
|
||||
logger: logger,
|
||||
client: client,
|
||||
}, nil
|
||||
|
|
|
@ -171,7 +171,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
|||
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if h.tlsConfig != nil {
|
||||
err := h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
|
|
|
@ -81,7 +81,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, networkList, tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions),
|
||||
logger: logger,
|
||||
client: client,
|
||||
}, nil
|
||||
|
|
|
@ -54,7 +54,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -78,7 +78,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (n *Inbound) Start() error {
|
||||
func (n *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
var tlsConfig *tls.STDConfig
|
||||
if n.tlsConfig != nil {
|
||||
err := n.tlsConfig.Start()
|
||||
|
|
|
@ -42,7 +42,10 @@ func NewRedirect(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return redirect, nil
|
||||
}
|
||||
|
||||
func (h *Redirect) Start() error {
|
||||
func (h *Redirect) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -61,7 +61,10 @@ func NewTProxy(ctx context.Context, router adapter.Router, logger log.ContextLog
|
|||
return tproxy, nil
|
||||
}
|
||||
|
||||
func (t *TProxy) Start() error {
|
||||
func (t *TProxy) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
err := t.listener.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -93,7 +93,10 @@ func newInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, err
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -101,7 +101,10 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont
|
|||
return inbound, err
|
||||
}
|
||||
|
||||
func (h *MultiInbound) Start() error {
|
||||
func (h *MultiInbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -86,7 +86,10 @@ func newRelayInbound(ctx context.Context, router adapter.Router, logger log.Cont
|
|||
return inbound, err
|
||||
}
|
||||
|
||||
func (h *RelayInbound) Start() error {
|
||||
func (h *RelayInbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, options.Network.Build(), tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, tag, options.Network.Build(), options.DialerOptions),
|
||||
logger: logger,
|
||||
dialer: outboundDialer,
|
||||
method: method,
|
||||
|
|
|
@ -90,7 +90,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ type Outbound struct {
|
|||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) {
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||
}
|
||||
if options.TLS == nil || !options.TLS.Enabled {
|
||||
return nil, C.ErrTLSRequired
|
||||
|
|
|
@ -50,7 +50,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, options.Network.Build(), tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
|
||||
router: router,
|
||||
logger: logger,
|
||||
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),
|
||||
|
|
|
@ -54,7 +54,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
dialer: outboundDialer,
|
||||
|
|
|
@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, []string{N.NetworkTCP}, tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
proxy: NewProxyListener(ctx, logger, outboundDialer),
|
||||
|
|
|
@ -110,7 +110,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if h.tlsConfig != nil {
|
||||
err := h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
|
|
|
@ -43,7 +43,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, options.Network.Build(), tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, tag, options.Network.Build(), options.DialerOptions),
|
||||
logger: logger,
|
||||
dialer: outboundDialer,
|
||||
serverAddr: options.ServerOptions.Build(),
|
||||
|
|
|
@ -142,7 +142,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
|||
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if h.tlsConfig != nil {
|
||||
err := h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
|
|
|
@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, options.Network.Build(), tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, tag, options.Network.Build(), options.DialerOptions),
|
||||
logger: logger,
|
||||
client: client,
|
||||
udpStream: options.UDPOverStream,
|
||||
|
|
|
@ -300,7 +300,9 @@ func (t *Inbound) Tag() string {
|
|||
return t.tag
|
||||
}
|
||||
|
||||
func (t *Inbound) Start() error {
|
||||
func (t *Inbound) Start(stage adapter.StartStage) error {
|
||||
switch stage {
|
||||
case adapter.StartStateStart:
|
||||
if C.IsAndroid && t.platformInterface == nil {
|
||||
t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager())
|
||||
}
|
||||
|
@ -348,10 +350,7 @@ func (t *Inbound) Start() error {
|
|||
}
|
||||
t.tunStack = tunStack
|
||||
t.logger.Info("started at ", t.tunOptions.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Inbound) PostStart() error {
|
||||
case adapter.StartStatePostStart:
|
||||
monitor := taskmonitor.New(t.logger, C.StartTimeout)
|
||||
monitor.Start("starting tun stack")
|
||||
err := t.tunStack.Start()
|
||||
|
@ -399,6 +398,7 @@ func (t *Inbound) PostStart() error {
|
|||
t.routeAddressSet = nil
|
||||
t.routeExcludeAddressSet = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -89,7 +89,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if h.tlsConfig != nil {
|
||||
err := h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
|
|
|
@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, options.Network.Build(), tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, tag, options.Network.Build(), options.DialerOptions),
|
||||
logger: logger,
|
||||
dialer: outboundDialer,
|
||||
serverAddr: options.ServerOptions.Build(),
|
||||
|
|
|
@ -99,7 +99,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
|||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start() error {
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
err := h.service.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||
return nil, err
|
||||
}
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, options.Network.Build(), tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, tag, options.Network.Build(), options.DialerOptions),
|
||||
logger: logger,
|
||||
dialer: outboundDialer,
|
||||
serverAddr: options.ServerOptions.Build(),
|
||||
|
|
211
protocol/wireguard/endpoint.go
Normal file
211
protocol/wireguard/endpoint.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/endpoint"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/wireguard"
|
||||
"github.com/sagernet/sing-dns"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterEndpoint(registry *endpoint.Registry) {
|
||||
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
|
||||
}
|
||||
|
||||
var (
|
||||
_ adapter.Endpoint = (*Endpoint)(nil)
|
||||
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
|
||||
)
|
||||
|
||||
type Endpoint struct {
|
||||
endpoint.Adapter
|
||||
ctx context.Context
|
||||
router adapter.Router
|
||||
logger logger.ContextLogger
|
||||
localAddresses []netip.Prefix
|
||||
endpoint *wireguard.Endpoint
|
||||
}
|
||||
|
||||
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
|
||||
ep := &Endpoint{
|
||||
Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
logger: logger,
|
||||
localAddresses: options.Address,
|
||||
}
|
||||
if options.Detour == "" {
|
||||
options.IsWireGuardListener = true
|
||||
} else if options.GSO {
|
||||
return nil, E.New("gso is conflict with detour")
|
||||
}
|
||||
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
System: options.System,
|
||||
Handler: ep,
|
||||
UDPTimeout: time.Duration(options.UDPTimeout),
|
||||
Dialer: outboundDialer,
|
||||
CreateDialer: func(interfaceName string) N.Dialer {
|
||||
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
|
||||
BindInterface: interfaceName,
|
||||
}))
|
||||
},
|
||||
Name: options.Name,
|
||||
MTU: options.MTU,
|
||||
GSO: options.GSO,
|
||||
Address: options.Address,
|
||||
PrivateKey: options.PrivateKey,
|
||||
ListenPort: options.ListenPort,
|
||||
ResolvePeer: func(domain string) (netip.Addr, error) {
|
||||
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
|
||||
if lookupErr != nil {
|
||||
return netip.Addr{}, lookupErr
|
||||
}
|
||||
return endpointAddresses[0], nil
|
||||
},
|
||||
Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions {
|
||||
return wireguard.PeerOptions{
|
||||
Endpoint: M.ParseSocksaddrHostPort(it.Address, it.Port),
|
||||
PublicKey: it.PublicKey,
|
||||
PreSharedKey: it.PreSharedKey,
|
||||
AllowedIPs: it.AllowedIPs,
|
||||
PersistentKeepaliveInterval: it.PersistentKeepaliveInterval,
|
||||
Reserved: it.Reserved,
|
||||
}
|
||||
}),
|
||||
Workers: options.Workers,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ep.endpoint = wgEndpoint
|
||||
return ep, nil
|
||||
}
|
||||
|
||||
func (w *Endpoint) Start(stage adapter.StartStage) error {
|
||||
switch stage {
|
||||
case adapter.StartStateStart:
|
||||
return w.endpoint.Start(false)
|
||||
case adapter.StartStatePostStart:
|
||||
return w.endpoint.Start(true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Endpoint) Close() error {
|
||||
return w.endpoint.Close()
|
||||
}
|
||||
|
||||
func (w *Endpoint) InterfaceUpdated() {
|
||||
w.endpoint.BindUpdate()
|
||||
return
|
||||
}
|
||||
|
||||
func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
|
||||
return w.router.PreMatch(adapter.InboundContext{
|
||||
Inbound: w.Tag(),
|
||||
InboundType: w.Type(),
|
||||
Network: network,
|
||||
Source: source,
|
||||
Destination: destination,
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||
var metadata adapter.InboundContext
|
||||
metadata.Inbound = w.Tag()
|
||||
metadata.InboundType = w.Type()
|
||||
metadata.Source = source
|
||||
for _, localPrefix := range w.localAddresses {
|
||||
if localPrefix.Contains(destination.Addr) {
|
||||
metadata.OriginDestination = destination
|
||||
if destination.Addr.Is4() {
|
||||
destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
|
||||
} else {
|
||||
destination.Addr = netip.IPv6Loopback()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
metadata.Destination = destination
|
||||
w.logger.InfoContext(ctx, "inbound connection from ", source)
|
||||
w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||
w.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||
var metadata adapter.InboundContext
|
||||
metadata.Inbound = w.Tag()
|
||||
metadata.InboundType = w.Type()
|
||||
metadata.Source = source
|
||||
metadata.Destination = destination
|
||||
for _, localPrefix := range w.localAddresses {
|
||||
if localPrefix.Contains(destination.Addr) {
|
||||
metadata.OriginDestination = destination
|
||||
if destination.Addr.Is4() {
|
||||
metadata.Destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
|
||||
} else {
|
||||
metadata.Destination.Addr = netip.IPv6Loopback()
|
||||
}
|
||||
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
|
||||
}
|
||||
}
|
||||
w.logger.InfoContext(ctx, "inbound packet connection from ", source)
|
||||
w.logger.InfoContext(ctx, "inbound packet connection to ", destination)
|
||||
w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch network {
|
||||
case N.NetworkTCP:
|
||||
w.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
case N.NetworkUDP:
|
||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
}
|
||||
if destination.IsFqdn() {
|
||||
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses)
|
||||
} else if !destination.Addr.IsValid() {
|
||||
return nil, E.New("invalid destination: ", destination)
|
||||
}
|
||||
return w.endpoint.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
if destination.IsFqdn() {
|
||||
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return packetConn, err
|
||||
}
|
||||
return w.endpoint.ListenPacket(ctx, destination)
|
||||
}
|
|
@ -2,231 +2,153 @@ package wireguard
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/experimental/deprecated"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/wireguard"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-dns"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/sing/service/pause"
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
|
||||
outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
|
||||
}
|
||||
|
||||
var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
|
||||
var (
|
||||
_ adapter.Endpoint = (*Endpoint)(nil)
|
||||
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
|
||||
)
|
||||
|
||||
type Outbound struct {
|
||||
outbound.Adapter
|
||||
ctx context.Context
|
||||
router adapter.Router
|
||||
logger logger.ContextLogger
|
||||
workers int
|
||||
peers []wireguard.PeerConfig
|
||||
useStdNetBind bool
|
||||
listener N.Dialer
|
||||
ipcConf string
|
||||
|
||||
pauseManager pause.Manager
|
||||
pauseCallback *list.Element[pause.Callback]
|
||||
bind conn.Bind
|
||||
device *device.Device
|
||||
tunDevice wireguard.Device
|
||||
localAddresses []netip.Prefix
|
||||
endpoint *wireguard.Endpoint
|
||||
}
|
||||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
|
||||
deprecated.Report(ctx, deprecated.OptionWireGuardOutbound)
|
||||
outbound := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions),
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
logger: logger,
|
||||
workers: options.Workers,
|
||||
pauseManager: service.FromContext[pause.Manager](ctx),
|
||||
localAddresses: options.LocalAddress,
|
||||
}
|
||||
peers, err := wireguard.ParsePeers(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outbound.peers = peers
|
||||
if len(options.LocalAddress) == 0 {
|
||||
return nil, E.New("missing local address")
|
||||
}
|
||||
if options.GSO {
|
||||
if options.GSO && options.Detour != "" {
|
||||
if options.Detour == "" {
|
||||
options.IsWireGuardListener = true
|
||||
} else if options.GSO {
|
||||
return nil, E.New("gso is conflict with detour")
|
||||
}
|
||||
options.IsWireGuardListener = true
|
||||
outbound.useStdNetBind = true
|
||||
}
|
||||
listener, err := dialer.New(ctx, options.DialerOptions)
|
||||
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outbound.listener = listener
|
||||
var privateKey string
|
||||
{
|
||||
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
||||
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
System: options.SystemInterface,
|
||||
Dialer: outboundDialer,
|
||||
CreateDialer: func(interfaceName string) N.Dialer {
|
||||
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
|
||||
BindInterface: interfaceName,
|
||||
}))
|
||||
},
|
||||
Name: options.InterfaceName,
|
||||
MTU: options.MTU,
|
||||
GSO: options.GSO,
|
||||
Address: options.LocalAddress,
|
||||
PrivateKey: options.PrivateKey,
|
||||
ResolvePeer: func(domain string) (netip.Addr, error) {
|
||||
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
|
||||
if lookupErr != nil {
|
||||
return netip.Addr{}, lookupErr
|
||||
}
|
||||
return endpointAddresses[0], nil
|
||||
},
|
||||
Peers: common.Map(options.Peers, func(it option.LegacyWireGuardPeer) wireguard.PeerOptions {
|
||||
return wireguard.PeerOptions{
|
||||
Endpoint: it.ServerOptions.Build(),
|
||||
PublicKey: it.PublicKey,
|
||||
PreSharedKey: it.PreSharedKey,
|
||||
AllowedIPs: it.AllowedIPs,
|
||||
// PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval),
|
||||
Reserved: it.Reserved,
|
||||
}
|
||||
}),
|
||||
Workers: options.Workers,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode private key")
|
||||
return nil, err
|
||||
}
|
||||
privateKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
outbound.ipcConf = "private_key=" + privateKey
|
||||
mtu := options.MTU
|
||||
if mtu == 0 {
|
||||
mtu = 1408
|
||||
}
|
||||
var wireTunDevice wireguard.Device
|
||||
if !options.SystemInterface && tun.WithGVisor {
|
||||
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
|
||||
} else {
|
||||
wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create WireGuard device")
|
||||
}
|
||||
outbound.tunDevice = wireTunDevice
|
||||
outbound.endpoint = wgEndpoint
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func (w *Outbound) Start() error {
|
||||
if common.Any(w.peers, func(peer wireguard.PeerConfig) bool {
|
||||
return !peer.Endpoint.IsValid()
|
||||
}) {
|
||||
// wait for all outbounds to be started and continue in PortStart
|
||||
return nil
|
||||
}
|
||||
return w.start()
|
||||
}
|
||||
|
||||
func (w *Outbound) PostStart() error {
|
||||
if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
|
||||
return peer.Endpoint.IsValid()
|
||||
}) {
|
||||
return nil
|
||||
}
|
||||
return w.start()
|
||||
}
|
||||
|
||||
func (w *Outbound) start() error {
|
||||
err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var bind conn.Bind
|
||||
if w.useStdNetBind {
|
||||
bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
|
||||
} else {
|
||||
var (
|
||||
isConnect bool
|
||||
connectAddr netip.AddrPort
|
||||
reserved [3]uint8
|
||||
)
|
||||
peerLen := len(w.peers)
|
||||
if peerLen == 1 {
|
||||
isConnect = true
|
||||
connectAddr = w.peers[0].Endpoint
|
||||
reserved = w.peers[0].Reserved
|
||||
}
|
||||
bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
|
||||
}
|
||||
err = w.tunDevice.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{
|
||||
Verbosef: func(format string, args ...interface{}) {
|
||||
w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
Errorf: func(format string, args ...interface{}) {
|
||||
w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
}, w.workers)
|
||||
ipcConf := w.ipcConf
|
||||
for _, peer := range w.peers {
|
||||
ipcConf += peer.GenerateIpcLines()
|
||||
}
|
||||
err = wgDevice.IpcSet(ipcConf)
|
||||
if err != nil {
|
||||
return E.Cause(err, "setup wireguard: \n", ipcConf)
|
||||
}
|
||||
w.device = wgDevice
|
||||
w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Outbound) Close() error {
|
||||
if w.device != nil {
|
||||
w.device.Close()
|
||||
}
|
||||
if w.pauseCallback != nil {
|
||||
w.pauseManager.UnregisterCallback(w.pauseCallback)
|
||||
func (o *Outbound) Start(stage adapter.StartStage) error {
|
||||
switch stage {
|
||||
case adapter.StartStateStart:
|
||||
return o.endpoint.Start(false)
|
||||
case adapter.StartStatePostStart:
|
||||
return o.endpoint.Start(true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Outbound) InterfaceUpdated() {
|
||||
w.device.BindUpdate()
|
||||
func (o *Outbound) Close() error {
|
||||
return o.endpoint.Close()
|
||||
}
|
||||
|
||||
func (o *Outbound) InterfaceUpdated() {
|
||||
o.endpoint.BindUpdate()
|
||||
return
|
||||
}
|
||||
|
||||
func (w *Outbound) onPauseUpdated(event int) {
|
||||
switch event {
|
||||
case pause.EventDevicePaused:
|
||||
w.device.Down()
|
||||
case pause.EventDeviceWake:
|
||||
w.device.Up()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch network {
|
||||
case N.NetworkTCP:
|
||||
w.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
o.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
case N.NetworkUDP:
|
||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
}
|
||||
if destination.IsFqdn() {
|
||||
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
||||
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
|
||||
return N.DialSerial(ctx, o.endpoint, network, destination, destinationAddresses)
|
||||
} else if !destination.Addr.IsValid() {
|
||||
return nil, E.New("invalid destination: ", destination)
|
||||
}
|
||||
return w.tunDevice.DialContext(ctx, network, destination)
|
||||
return o.endpoint.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
if destination.IsFqdn() {
|
||||
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
|
||||
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
|
||||
packetConn, _, err := N.ListenSerial(ctx, o.endpoint, destination, destinationAddresses)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return packetConn, err
|
||||
}
|
||||
return w.tunDevice.ListenPacket(ctx, destination)
|
||||
return o.endpoint.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
|
|
@ -41,15 +41,15 @@ type NetworkManager struct {
|
|||
autoDetectInterface bool
|
||||
defaultOptions adapter.NetworkOptions
|
||||
autoRedirectOutputMark uint32
|
||||
|
||||
networkMonitor tun.NetworkUpdateMonitor
|
||||
interfaceMonitor tun.DefaultInterfaceMonitor
|
||||
packageManager tun.PackageManager
|
||||
powerListener winpowrprof.EventListener
|
||||
pauseManager pause.Manager
|
||||
platformInterface platform.Interface
|
||||
inboundManager adapter.InboundManager
|
||||
outboundManager adapter.OutboundManager
|
||||
endpoint adapter.EndpointManager
|
||||
inbound adapter.InboundManager
|
||||
outbound adapter.OutboundManager
|
||||
wifiState adapter.WIFIState
|
||||
started bool
|
||||
}
|
||||
|
@ -69,7 +69,9 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
|
|||
},
|
||||
pauseManager: service.FromContext[pause.Manager](ctx),
|
||||
platformInterface: service.FromContext[platform.Interface](ctx),
|
||||
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
|
||||
endpoint: service.FromContext[adapter.EndpointManager](ctx),
|
||||
inbound: service.FromContext[adapter.InboundManager](ctx),
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
}
|
||||
if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault {
|
||||
if routeOptions.DefaultInterface != "" {
|
||||
|
@ -358,14 +360,21 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
|
|||
func (r *NetworkManager) ResetNetwork() {
|
||||
conntrack.Close()
|
||||
|
||||
for _, inbound := range r.inboundManager.Inbounds() {
|
||||
for _, endpoint := range r.endpoint.Endpoints() {
|
||||
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
|
||||
if isListener {
|
||||
listener.InterfaceUpdated()
|
||||
}
|
||||
}
|
||||
|
||||
for _, inbound := range r.inbound.Inbounds() {
|
||||
listener, isListener := inbound.(adapter.InterfaceUpdateListener)
|
||||
if isListener {
|
||||
listener.InterfaceUpdated()
|
||||
}
|
||||
}
|
||||
|
||||
for _, outbound := range r.outboundManager.Outbounds() {
|
||||
for _, outbound := range r.outbound.Outbounds() {
|
||||
listener, isListener := outbound.(adapter.InterfaceUpdateListener)
|
||||
if isListener {
|
||||
listener.InterfaceUpdated()
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
C "github.com/sagernet/sing-box/constant"
|
||||
R "github.com/sagernet/sing-box/route/rule"
|
||||
"github.com/sagernet/sing-dns"
|
||||
tun "github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/cache"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
|
|
|
@ -32,7 +32,7 @@ func TestMain(m *testing.M) {
|
|||
var globalCtx context.Context
|
||||
|
||||
func init() {
|
||||
globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
|
||||
globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
|
||||
}
|
||||
|
||||
func startInstance(t *testing.T, options option.Options) *box.Box {
|
||||
|
|
|
@ -37,12 +37,12 @@ func _TestWireGuard(t *testing.T) {
|
|||
Outbounds: []option.Outbound{
|
||||
{
|
||||
Type: C.TypeWireGuard,
|
||||
Options: &option.WireGuardOutboundOptions{
|
||||
Options: &option.WireGuardEndpointOptions{
|
||||
ServerOptions: option.ServerOptions{
|
||||
Server: "127.0.0.1",
|
||||
ServerPort: serverPort,
|
||||
},
|
||||
LocalAddress: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
|
||||
Address: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
|
||||
PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
|
||||
PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
|
||||
},
|
||||
|
|
|
@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
|||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
c.logger.Error(context.Background(), E.Cause(err, "read packet"))
|
||||
c.logger.Error(E.Cause(err, "read packet"))
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
|
@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
|
|||
b := packets[0]
|
||||
common.ClearArray(b[1:4])
|
||||
}
|
||||
eps[0] = Endpoint(M.AddrPortFromNet(addr))
|
||||
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
|
||||
count = 1
|
||||
return
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
|||
time.Sleep(time.Second)
|
||||
return err
|
||||
}
|
||||
destination := netip.AddrPort(ep.(Endpoint))
|
||||
destination := netip.AddrPort(ep.(remoteEndpoint))
|
||||
for _, b := range bufs {
|
||||
if len(b) > 3 {
|
||||
reserved, loaded := c.reservedForEndpoint[destination]
|
||||
|
@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Endpoint(ap), nil
|
||||
return remoteEndpoint(ap), nil
|
||||
}
|
||||
|
||||
func (c *ClientBind) BatchSize() int {
|
||||
|
@ -229,3 +229,31 @@ func (w *wireConn) Close() error {
|
|||
close(w.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ conn.Endpoint = (*remoteEndpoint)(nil)
|
||||
|
||||
type remoteEndpoint netip.AddrPort
|
||||
|
||||
func (e remoteEndpoint) ClearSrc() {
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) DstToString() string {
|
||||
return (netip.AddrPort)(e).String()
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) DstToBytes() []byte {
|
||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) DstIP() netip.Addr {
|
||||
return (netip.AddrPort)(e).Addr()
|
||||
}
|
||||
|
||||
func (e remoteEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,44 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/wireguard-go/tun"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
tun.Device
|
||||
wgTun.Device
|
||||
N.Dialer
|
||||
Start() error
|
||||
// NewEndpoint() (stack.LinkEndpoint, error)
|
||||
SetDevice(device *device.Device)
|
||||
}
|
||||
|
||||
type DeviceOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.ContextLogger
|
||||
System bool
|
||||
Handler tun.Handler
|
||||
UDPTimeout time.Duration
|
||||
CreateDialer func(interfaceName string) N.Dialer
|
||||
Name string
|
||||
MTU uint32
|
||||
GSO bool
|
||||
Address []netip.Prefix
|
||||
AllowedAddress []netip.Prefix
|
||||
}
|
||||
|
||||
func NewDevice(options DeviceOptions) (Device, error) {
|
||||
if !options.System {
|
||||
return newStackDevice(options)
|
||||
} else if options.Handler == nil {
|
||||
return newSystemDevice(options)
|
||||
} else {
|
||||
return newSystemStackDevice(options)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ package wireguard
|
|||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
|
@ -15,52 +14,41 @@ import (
|
|||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
)
|
||||
|
||||
var _ Device = (*StackDevice)(nil)
|
||||
var _ Device = (*stackDevice)(nil)
|
||||
|
||||
const defaultNIC tcpip.NICID = 1
|
||||
|
||||
type StackDevice struct {
|
||||
type stackDevice struct {
|
||||
stack *stack.Stack
|
||||
mtu uint32
|
||||
events chan wgTun.Event
|
||||
outbound chan *stack.PacketBuffer
|
||||
packetOutbound chan *buf.Buffer
|
||||
done chan struct{}
|
||||
dispatcher stack.NetworkDispatcher
|
||||
addr4 tcpip.Address
|
||||
addr6 tcpip.Address
|
||||
}
|
||||
|
||||
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
|
||||
ipStack := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
|
||||
HandleLocal: true,
|
||||
})
|
||||
tunDevice := &StackDevice{
|
||||
stack: ipStack,
|
||||
mtu: mtu,
|
||||
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
||||
tunDevice := &stackDevice{
|
||||
mtu: options.MTU,
|
||||
events: make(chan wgTun.Event, 1),
|
||||
outbound: make(chan *stack.PacketBuffer, 256),
|
||||
packetOutbound: make(chan *buf.Buffer, 256),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
|
||||
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
|
||||
if err != nil {
|
||||
return nil, E.New(err.String())
|
||||
return nil, err
|
||||
}
|
||||
for _, prefix := range localAddresses {
|
||||
for _, prefix := range options.Address {
|
||||
addr := tun.AddressFromAddr(prefix.Addr())
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
|
@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
|
|||
tunDevice.addr6 = addr
|
||||
protoAddr.Protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
|
||||
if err != nil {
|
||||
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
|
||||
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
|
||||
if gErr != nil {
|
||||
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
|
||||
}
|
||||
}
|
||||
sOpt := tcpip.TCPSACKEnabled(true)
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
|
||||
cOpt := tcpip.CongestionControlOption("cubic")
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
|
||||
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
|
||||
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
|
||||
tunDevice.stack = ipStack
|
||||
if options.Handler != nil {
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||
}
|
||||
return tunDevice, nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
|
||||
return (*wireEndpoint)(w), nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
addr := tcpip.FullAddress{
|
||||
NIC: defaultNIC,
|
||||
NIC: tun.DefaultNIC,
|
||||
Port: destination.Port,
|
||||
Addr: tun.AddressFromAddr(destination.Addr),
|
||||
}
|
||||
bind := tcpip.FullAddress{
|
||||
NIC: defaultNIC,
|
||||
NIC: tun.DefaultNIC,
|
||||
}
|
||||
var networkProtocol tcpip.NetworkProtocolNumber
|
||||
if destination.IsIPv4() {
|
||||
|
@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
|
|||
}
|
||||
}
|
||||
|
||||
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
bind := tcpip.FullAddress{
|
||||
NIC: defaultNIC,
|
||||
NIC: tun.DefaultNIC,
|
||||
}
|
||||
var networkProtocol tcpip.NetworkProtocolNumber
|
||||
if destination.IsIPv4() {
|
||||
|
@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
|
|||
return udpConn, nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Inet4Address() netip.Addr {
|
||||
return tun.AddrFromAddress(w.addr4)
|
||||
func (w *stackDevice) SetDevice(device *device.Device) {
|
||||
}
|
||||
|
||||
func (w *StackDevice) Inet6Address() netip.Addr {
|
||||
return tun.AddrFromAddress(w.addr6)
|
||||
}
|
||||
|
||||
func (w *StackDevice) Start() error {
|
||||
func (w *stackDevice) Start() error {
|
||||
w.events <- wgTun.EventUp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) File() *os.File {
|
||||
func (w *stackDevice) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
select {
|
||||
case packetBuffer, ok := <-w.outbound:
|
||||
if !ok {
|
||||
|
@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
|
|||
sizes[0] = n
|
||||
count = 1
|
||||
return
|
||||
case packet := <-w.packetOutbound:
|
||||
defer packet.Release()
|
||||
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
|
||||
count = 1
|
||||
return
|
||||
case <-w.done:
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
for _, b := range bufs {
|
||||
b = b[offset:]
|
||||
if len(b) == 0 {
|
||||
|
@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *StackDevice) Flush() error {
|
||||
func (w *stackDevice) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) MTU() (int, error) {
|
||||
func (w *stackDevice) MTU() (int, error) {
|
||||
return int(w.mtu), nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Name() (string, error) {
|
||||
func (w *stackDevice) Name() (string, error) {
|
||||
return "sing-box", nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) Events() <-chan wgTun.Event {
|
||||
func (w *stackDevice) Events() <-chan wgTun.Event {
|
||||
return w.events
|
||||
}
|
||||
|
||||
func (w *StackDevice) Close() error {
|
||||
func (w *stackDevice) Close() error {
|
||||
close(w.done)
|
||||
close(w.events)
|
||||
w.stack.Close()
|
||||
|
@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (w *StackDevice) BatchSize() int {
|
||||
func (w *stackDevice) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
|
||||
|
||||
type wireEndpoint StackDevice
|
||||
type wireEndpoint stackDevice
|
||||
|
||||
func (ep *wireEndpoint) MTU() uint32 {
|
||||
return ep.mtu
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
import "github.com/sagernet/sing-tun"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
)
|
||||
|
||||
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
|
||||
func newStackDevice(options DeviceOptions) (Device, error) {
|
||||
return nil, tun.ErrGVisorNotIncluded
|
||||
}
|
||||
|
||||
func newSystemStackDevice(options DeviceOptions) (Device, error) {
|
||||
return nil, tun.ErrGVisorNotIncluded
|
||||
}
|
||||
|
|
|
@ -6,96 +6,88 @@ import (
|
|||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
wgTun "github.com/sagernet/wireguard-go/tun"
|
||||
)
|
||||
|
||||
var _ Device = (*SystemDevice)(nil)
|
||||
var _ Device = (*systemDevice)(nil)
|
||||
|
||||
type SystemDevice struct {
|
||||
type systemDevice struct {
|
||||
options DeviceOptions
|
||||
dialer N.Dialer
|
||||
device tun.Tun
|
||||
batchDevice tun.LinuxTUN
|
||||
name string
|
||||
mtu uint32
|
||||
inet4Addresses []netip.Prefix
|
||||
inet6Addresses []netip.Prefix
|
||||
gso bool
|
||||
events chan wgTun.Event
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
|
||||
var inet4Addresses []netip.Prefix
|
||||
var inet6Addresses []netip.Prefix
|
||||
for _, prefixes := range localPrefixes {
|
||||
if prefixes.Addr().Is4() {
|
||||
inet4Addresses = append(inet4Addresses, prefixes)
|
||||
} else {
|
||||
inet6Addresses = append(inet6Addresses, prefixes)
|
||||
func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
|
||||
if options.Name == "" {
|
||||
options.Name = tun.CalculateInterfaceName("wg")
|
||||
}
|
||||
}
|
||||
if interfaceName == "" {
|
||||
interfaceName = tun.CalculateInterfaceName("wg")
|
||||
}
|
||||
|
||||
return &SystemDevice{
|
||||
dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{
|
||||
BindInterface: interfaceName,
|
||||
})),
|
||||
name: interfaceName,
|
||||
mtu: mtu,
|
||||
inet4Addresses: inet4Addresses,
|
||||
inet6Addresses: inet6Addresses,
|
||||
gso: gso,
|
||||
return &systemDevice{
|
||||
options: options,
|
||||
dialer: options.CreateDialer(options.Name),
|
||||
events: make(chan wgTun.Event, 1),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
return w.dialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return w.dialer.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Inet4Address() netip.Addr {
|
||||
if len(w.inet4Addresses) == 0 {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return w.inet4Addresses[0].Addr()
|
||||
func (w *systemDevice) SetDevice(device *device.Device) {
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Inet6Address() netip.Addr {
|
||||
if len(w.inet6Addresses) == 0 {
|
||||
return netip.Addr{}
|
||||
func (w *systemDevice) Start() error {
|
||||
networkManager := service.FromContext[adapter.NetworkManager](w.options.Context)
|
||||
tunOptions := tun.Options{
|
||||
Name: w.options.Name,
|
||||
Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is4()
|
||||
}),
|
||||
Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is6()
|
||||
}),
|
||||
MTU: w.options.MTU,
|
||||
GSO: w.options.GSO,
|
||||
InterfaceScope: true,
|
||||
Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is4()
|
||||
}),
|
||||
Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }),
|
||||
InterfaceMonitor: networkManager.InterfaceMonitor(),
|
||||
InterfaceFinder: networkManager.InterfaceFinder(),
|
||||
}
|
||||
return w.inet6Addresses[0].Addr()
|
||||
// works with Linux, macOS with IFSCOPE routes, not tested on Windows
|
||||
if runtime.GOOS == "darwin" {
|
||||
tunOptions.AutoRoute = true
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Start() error {
|
||||
tunInterface, err := tun.New(tun.Options{
|
||||
Name: w.name,
|
||||
Inet4Address: w.inet4Addresses,
|
||||
Inet6Address: w.inet6Addresses,
|
||||
MTU: w.mtu,
|
||||
GSO: w.gso,
|
||||
})
|
||||
tunInterface, err := tun.New(tunOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tunInterface.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.options.Logger.Info("started at ", w.options.Name)
|
||||
w.device = tunInterface
|
||||
if w.gso {
|
||||
if w.options.GSO {
|
||||
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
|
||||
if !isBatchTUN {
|
||||
tunInterface.Close()
|
||||
|
@ -107,15 +99,15 @@ func (w *SystemDevice) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) File() *os.File {
|
||||
func (w *systemDevice) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
||||
if w.batchDevice != nil {
|
||||
count, err = w.batchDevice.BatchRead(bufs, offset, sizes)
|
||||
count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes)
|
||||
} else {
|
||||
sizes[0], err = w.device.Read(bufs[0][offset:])
|
||||
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
|
||||
if err == nil {
|
||||
count = 1
|
||||
} else if errors.Is(err, tun.ErrTooManySegments) {
|
||||
|
@ -125,12 +117,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int,
|
|||
return
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
if w.batchDevice != nil {
|
||||
return 0, w.batchDevice.BatchWrite(bufs, offset)
|
||||
return w.batchDevice.BatchWrite(bufs, offset)
|
||||
} else {
|
||||
for _, b := range bufs {
|
||||
_, err = w.device.Write(b[offset:])
|
||||
for _, packet := range bufs {
|
||||
if tun.PacketOffset > 0 {
|
||||
common.ClearArray(packet[offset-tun.PacketOffset : offset])
|
||||
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
|
||||
}
|
||||
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -140,28 +136,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Flush() error {
|
||||
func (w *systemDevice) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) MTU() (int, error) {
|
||||
return int(w.mtu), nil
|
||||
func (w *systemDevice) MTU() (int, error) {
|
||||
return int(w.options.MTU), nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Name() (string, error) {
|
||||
return w.name, nil
|
||||
func (w *systemDevice) Name() (string, error) {
|
||||
return w.options.Name, nil
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Events() <-chan wgTun.Event {
|
||||
func (w *systemDevice) Events() <-chan wgTun.Event {
|
||||
return w.events
|
||||
}
|
||||
|
||||
func (w *SystemDevice) Close() error {
|
||||
func (w *systemDevice) Close() error {
|
||||
close(w.events)
|
||||
return w.device.Close()
|
||||
}
|
||||
|
||||
func (w *SystemDevice) BatchSize() int {
|
||||
func (w *systemDevice) BatchSize() int {
|
||||
if w.batchDevice != nil {
|
||||
return w.batchDevice.BatchSize()
|
||||
}
|
||||
|
|
182
transport/wireguard/device_system_stack.go
Normal file
182
transport/wireguard/device_system_stack.go
Normal file
|
@ -0,0 +1,182 @@
|
|||
//go:build with_gvisor
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
)
|
||||
|
||||
var _ Device = (*systemStackDevice)(nil)
|
||||
|
||||
type systemStackDevice struct {
|
||||
*systemDevice
|
||||
stack *stack.Stack
|
||||
endpoint *deviceEndpoint
|
||||
writeBufs [][]byte
|
||||
}
|
||||
|
||||
func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
|
||||
system, err := newSystemDevice(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
endpoint := &deviceEndpoint{
|
||||
mtu: options.MTU,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
ipStack, err := tun.NewGVisorStack(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
||||
return &systemStackDevice{
|
||||
systemDevice: system,
|
||||
stack: ipStack,
|
||||
endpoint: endpoint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) SetDevice(device *device.Device) {
|
||||
w.endpoint.device = device
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
|
||||
if w.batchDevice != nil {
|
||||
w.writeBufs = w.writeBufs[:0]
|
||||
for _, packet := range bufs {
|
||||
if !w.writeStack(packet[offset:]) {
|
||||
w.writeBufs = append(w.writeBufs, packet)
|
||||
}
|
||||
}
|
||||
if len(w.writeBufs) > 0 {
|
||||
return w.batchDevice.BatchWrite(bufs, offset)
|
||||
}
|
||||
} else {
|
||||
for _, packet := range bufs {
|
||||
if !w.writeStack(packet[offset:]) {
|
||||
if tun.PacketOffset > 0 {
|
||||
common.ClearArray(packet[offset-tun.PacketOffset : offset])
|
||||
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
|
||||
}
|
||||
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// WireGuard will not read count
|
||||
return
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) Close() error {
|
||||
close(w.endpoint.done)
|
||||
w.stack.Close()
|
||||
for _, endpoint := range w.stack.CleanupEndpoints() {
|
||||
endpoint.Abort()
|
||||
}
|
||||
w.stack.Wait()
|
||||
return w.systemDevice.Close()
|
||||
}
|
||||
|
||||
func (w *systemStackDevice) writeStack(packet []byte) bool {
|
||||
var (
|
||||
networkProtocol tcpip.NetworkProtocolNumber
|
||||
destination netip.Addr
|
||||
)
|
||||
switch header.IPVersion(packet) {
|
||||
case header.IPv4Version:
|
||||
networkProtocol = header.IPv4ProtocolNumber
|
||||
destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4())
|
||||
case header.IPv6Version:
|
||||
networkProtocol = header.IPv6ProtocolNumber
|
||||
destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16())
|
||||
}
|
||||
for _, prefix := range w.options.Address {
|
||||
if prefix.Contains(destination) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
|
||||
packetBuffer.DecRef()
|
||||
return true
|
||||
}
|
||||
|
||||
type deviceEndpoint struct {
|
||||
mtu uint32
|
||||
done chan struct{}
|
||||
device *device.Device
|
||||
dispatcher stack.NetworkDispatcher
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) MTU() uint32 {
|
||||
return ep.mtu
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) SetMTU(mtu uint32) {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityRXChecksumOffload
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
ep.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) IsAttached() bool {
|
||||
return ep.dispatcher != nil
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Wait() {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
|
||||
for _, packetBuffer := range list.AsSlice() {
|
||||
destination := packetBuffer.Network().DestinationAddress()
|
||||
ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
|
||||
}
|
||||
return list.Len(), nil
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) Close() {
|
||||
}
|
||||
|
||||
func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
|
||||
}
|
|
@ -1,35 +1,254 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/sing/service/pause"
|
||||
"github.com/sagernet/wireguard-go/conn"
|
||||
"github.com/sagernet/wireguard-go/device"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
var _ conn.Endpoint = (*Endpoint)(nil)
|
||||
|
||||
type Endpoint netip.AddrPort
|
||||
|
||||
func (e Endpoint) ClearSrc() {
|
||||
type Endpoint struct {
|
||||
options EndpointOptions
|
||||
peers []peerConfig
|
||||
ipcConf string
|
||||
allowedAddress []netip.Prefix
|
||||
tunDevice Device
|
||||
device *device.Device
|
||||
pauseManager pause.Manager
|
||||
pauseCallback *list.Element[pause.Callback]
|
||||
}
|
||||
|
||||
func (e Endpoint) SrcToString() string {
|
||||
return ""
|
||||
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
|
||||
if options.PrivateKey == "" {
|
||||
return nil, E.New("missing private key")
|
||||
}
|
||||
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode private key")
|
||||
}
|
||||
privateKey := hex.EncodeToString(privateKeyBytes)
|
||||
ipcConf := "private_key=" + privateKey
|
||||
if options.ListenPort != 0 {
|
||||
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
|
||||
}
|
||||
var peers []peerConfig
|
||||
for peerIndex, rawPeer := range options.Peers {
|
||||
peer := peerConfig{
|
||||
allowedIPs: rawPeer.AllowedIPs,
|
||||
keepalive: rawPeer.PersistentKeepaliveInterval,
|
||||
}
|
||||
if rawPeer.Endpoint.Addr.IsValid() {
|
||||
peer.endpoint = rawPeer.Endpoint.AddrPort()
|
||||
} else if rawPeer.Endpoint.IsFqdn() {
|
||||
peer.destination = rawPeer.Endpoint
|
||||
}
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
||||
}
|
||||
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
|
||||
if rawPeer.PreSharedKey != "" {
|
||||
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
||||
}
|
||||
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
|
||||
}
|
||||
if len(rawPeer.AllowedIPs) == 0 {
|
||||
return nil, E.New("missing allowed ips for peer ", peerIndex)
|
||||
}
|
||||
if len(rawPeer.Reserved) > 0 {
|
||||
if len(rawPeer.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
|
||||
}
|
||||
copy(peer.reserved[:], rawPeer.Reserved[:])
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
var allowedPrefixBuilder netipx.IPSetBuilder
|
||||
for _, peer := range options.Peers {
|
||||
for _, prefix := range peer.AllowedIPs {
|
||||
allowedPrefixBuilder.AddPrefix(prefix)
|
||||
}
|
||||
}
|
||||
allowedIPSet, err := allowedPrefixBuilder.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedAddresses := allowedIPSet.Prefixes()
|
||||
if options.MTU == 0 {
|
||||
options.MTU = 1408
|
||||
}
|
||||
deviceOptions := DeviceOptions{
|
||||
Context: options.Context,
|
||||
Logger: options.Logger,
|
||||
System: options.System,
|
||||
Handler: options.Handler,
|
||||
UDPTimeout: options.UDPTimeout,
|
||||
CreateDialer: options.CreateDialer,
|
||||
Name: options.Name,
|
||||
MTU: options.MTU,
|
||||
GSO: options.GSO,
|
||||
Address: options.Address,
|
||||
AllowedAddress: allowedAddresses,
|
||||
}
|
||||
tunDevice, err := NewDevice(deviceOptions)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create WireGuard device")
|
||||
}
|
||||
return &Endpoint{
|
||||
options: options,
|
||||
peers: peers,
|
||||
ipcConf: ipcConf,
|
||||
allowedAddress: allowedAddresses,
|
||||
tunDevice: tunDevice,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e Endpoint) DstToString() string {
|
||||
return (netip.AddrPort)(e).String()
|
||||
func (e *Endpoint) Start(resolve bool) error {
|
||||
if common.Any(e.peers, func(peer peerConfig) bool {
|
||||
return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
|
||||
}) {
|
||||
if !resolve {
|
||||
return nil
|
||||
}
|
||||
for peerIndex, peer := range e.peers {
|
||||
if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
|
||||
continue
|
||||
}
|
||||
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
|
||||
}
|
||||
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
|
||||
}
|
||||
} else if resolve {
|
||||
return nil
|
||||
}
|
||||
var bind conn.Bind
|
||||
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
|
||||
if isWgListener {
|
||||
bind = conn.NewStdNetBind(wgListener)
|
||||
} else {
|
||||
var (
|
||||
isConnect bool
|
||||
connectAddr netip.AddrPort
|
||||
reserved [3]uint8
|
||||
)
|
||||
peerLen := len(e.peers)
|
||||
if peerLen == 1 {
|
||||
isConnect = true
|
||||
connectAddr = e.peers[0].endpoint
|
||||
reserved = e.peers[0].reserved
|
||||
}
|
||||
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
|
||||
}
|
||||
err := e.tunDevice.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger := &device.Logger{
|
||||
Verbosef: func(format string, args ...interface{}) {
|
||||
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
Errorf: func(format string, args ...interface{}) {
|
||||
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
||||
},
|
||||
}
|
||||
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
|
||||
e.tunDevice.SetDevice(wgDevice)
|
||||
ipcConf := e.ipcConf
|
||||
for _, peer := range e.peers {
|
||||
ipcConf += peer.GenerateIpcLines()
|
||||
}
|
||||
err = wgDevice.IpcSet(ipcConf)
|
||||
if err != nil {
|
||||
return E.Cause(err, "setup wireguard: \n", ipcConf)
|
||||
}
|
||||
e.device = wgDevice
|
||||
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
|
||||
if e.pauseManager != nil {
|
||||
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e Endpoint) DstToBytes() []byte {
|
||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||
return b
|
||||
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
return e.tunDevice.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (e Endpoint) DstIP() netip.Addr {
|
||||
return (netip.AddrPort)(e).Addr()
|
||||
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
|
||||
}
|
||||
return e.tunDevice.ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
func (e Endpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
func (e *Endpoint) BindUpdate() error {
|
||||
return e.device.BindUpdate()
|
||||
}
|
||||
|
||||
func (e *Endpoint) Close() error {
|
||||
if e.device != nil {
|
||||
e.device.Close()
|
||||
}
|
||||
if e.pauseCallback != nil {
|
||||
e.pauseManager.UnregisterCallback(e.pauseCallback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Endpoint) onPauseUpdated(event int) {
|
||||
switch event {
|
||||
case pause.EventDevicePaused:
|
||||
e.device.Down()
|
||||
case pause.EventDeviceWake:
|
||||
e.device.Up()
|
||||
}
|
||||
}
|
||||
|
||||
type peerConfig struct {
|
||||
destination M.Socksaddr
|
||||
endpoint netip.AddrPort
|
||||
publicKeyHex string
|
||||
preSharedKeyHex string
|
||||
allowedIPs []netip.Prefix
|
||||
keepalive uint16
|
||||
reserved [3]uint8
|
||||
}
|
||||
|
||||
func (c peerConfig) GenerateIpcLines() string {
|
||||
ipcLines := "\npublic_key=" + c.publicKeyHex
|
||||
if c.endpoint.IsValid() {
|
||||
ipcLines += "\nendpoint=" + c.endpoint.String()
|
||||
}
|
||||
if c.preSharedKeyHex != "" {
|
||||
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
|
||||
}
|
||||
for _, allowedIP := range c.allowedIPs {
|
||||
ipcLines += "\nallowed_ip=" + allowedIP.String()
|
||||
}
|
||||
if c.keepalive > 0 {
|
||||
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
|
||||
}
|
||||
return ipcLines
|
||||
}
|
||||
|
|
40
transport/wireguard/endpoint_options.go
Normal file
40
transport/wireguard/endpoint_options.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type EndpointOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.ContextLogger
|
||||
System bool
|
||||
Handler tun.Handler
|
||||
UDPTimeout time.Duration
|
||||
Dialer N.Dialer
|
||||
CreateDialer func(interfaceName string) N.Dialer
|
||||
Name string
|
||||
MTU uint32
|
||||
GSO bool
|
||||
Address []netip.Prefix
|
||||
PrivateKey string
|
||||
ListenPort uint16
|
||||
ResolvePeer func(domain string) (netip.Addr, error)
|
||||
Peers []PeerOptions
|
||||
Workers int
|
||||
}
|
||||
|
||||
type PeerOptions struct {
|
||||
Endpoint M.Socksaddr
|
||||
PublicKey string
|
||||
PreSharedKey string
|
||||
AllowedIPs []netip.Prefix
|
||||
PersistentKeepaliveInterval uint16
|
||||
Reserved []uint8
|
||||
}
|
|
@ -1,148 +0,0 @@
|
|||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-dns"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type PeerConfig struct {
|
||||
destination M.Socksaddr
|
||||
domainStrategy dns.DomainStrategy
|
||||
Endpoint netip.AddrPort
|
||||
PublicKey string
|
||||
PreSharedKey string
|
||||
AllowedIPs []string
|
||||
Reserved [3]uint8
|
||||
}
|
||||
|
||||
func (c PeerConfig) GenerateIpcLines() string {
|
||||
ipcLines := "\npublic_key=" + c.PublicKey
|
||||
ipcLines += "\nendpoint=" + c.Endpoint.String()
|
||||
if c.PreSharedKey != "" {
|
||||
ipcLines += "\npreshared_key=" + c.PreSharedKey
|
||||
}
|
||||
for _, allowedIP := range c.AllowedIPs {
|
||||
ipcLines += "\nallowed_ip=" + allowedIP
|
||||
}
|
||||
return ipcLines
|
||||
}
|
||||
|
||||
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
|
||||
var peers []PeerConfig
|
||||
if len(options.Peers) > 0 {
|
||||
for peerIndex, rawPeer := range options.Peers {
|
||||
peer := PeerConfig{
|
||||
AllowedIPs: rawPeer.AllowedIPs,
|
||||
}
|
||||
destination := rawPeer.ServerOptions.Build()
|
||||
if destination.IsFqdn() {
|
||||
peer.destination = destination
|
||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
||||
} else {
|
||||
peer.Endpoint = destination.AddrPort()
|
||||
}
|
||||
{
|
||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
|
||||
}
|
||||
peer.PublicKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if rawPeer.PreSharedKey != "" {
|
||||
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
|
||||
}
|
||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if len(rawPeer.AllowedIPs) == 0 {
|
||||
return nil, E.New("missing allowed_ips for peer ", peerIndex)
|
||||
}
|
||||
if len(rawPeer.Reserved) > 0 {
|
||||
if len(rawPeer.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
|
||||
}
|
||||
copy(peer.Reserved[:], options.Reserved)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
} else {
|
||||
peer := PeerConfig{}
|
||||
var (
|
||||
addressHas4 bool
|
||||
addressHas6 bool
|
||||
)
|
||||
for _, localAddress := range options.LocalAddress {
|
||||
if localAddress.Addr().Is4() {
|
||||
addressHas4 = true
|
||||
} else {
|
||||
addressHas6 = true
|
||||
}
|
||||
}
|
||||
if addressHas4 {
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
|
||||
}
|
||||
if addressHas6 {
|
||||
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
|
||||
}
|
||||
destination := options.ServerOptions.Build()
|
||||
if destination.IsFqdn() {
|
||||
peer.destination = destination
|
||||
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
|
||||
} else {
|
||||
peer.Endpoint = destination.AddrPort()
|
||||
}
|
||||
{
|
||||
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode peer public key")
|
||||
}
|
||||
peer.PublicKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if options.PreSharedKey != "" {
|
||||
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode pre shared key")
|
||||
}
|
||||
peer.PreSharedKey = hex.EncodeToString(bytes)
|
||||
}
|
||||
if len(options.Reserved) > 0 {
|
||||
if len(options.Reserved) != 3 {
|
||||
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
|
||||
}
|
||||
copy(peer.Reserved[:], options.Reserved)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
|
||||
for peerIndex, peer := range peers {
|
||||
if peer.Endpoint.IsValid() {
|
||||
continue
|
||||
}
|
||||
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
|
||||
if err != nil {
|
||||
if len(peers) == 1 {
|
||||
return E.Cause(err, "resolve endpoint domain")
|
||||
} else {
|
||||
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
|
||||
}
|
||||
}
|
||||
if len(destinationAddresses) == 0 {
|
||||
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
|
||||
}
|
||||
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Reference in a new issue