refactor: WireGuard endpoint

This commit is contained in:
世界 2024-11-21 18:10:41 +08:00
parent 7dbc105f89
commit e5bfd9e6b1
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
76 changed files with 1646 additions and 672 deletions

28
adapter/endpoint.go Normal file
View file

@ -0,0 +1,28 @@
package adapter
import (
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
type Endpoint interface {
Lifecycle
Type() string
Tag() string
Outbound
}
type EndpointRegistry interface {
option.EndpointOptionsRegistry
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) (Endpoint, error)
}
type EndpointManager interface {
Lifecycle
Endpoints() []Endpoint
Get(tag string) (Endpoint, bool)
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) error
}

View file

@ -0,0 +1,43 @@
package endpoint
import "github.com/sagernet/sing-box/option"
type Adapter struct {
endpointType string
endpointTag string
network []string
dependencies []string
}
func NewAdapter(endpointType string, endpointTag string, network []string, dependencies []string) Adapter {
return Adapter{
endpointType: endpointType,
endpointTag: endpointTag,
network: network,
dependencies: dependencies,
}
}
func NewAdapterWithDialerOptions(endpointType string, endpointTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string
if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour}
}
return NewAdapter(endpointType, endpointTag, network, dependencies)
}
func (a *Adapter) Type() string {
return a.endpointType
}
func (a *Adapter) Tag() string {
return a.endpointTag
}
func (a *Adapter) Network() []string {
return a.network
}
func (a *Adapter) Dependencies() []string {
return a.dependencies
}

147
adapter/endpoint/manager.go Normal file
View file

@ -0,0 +1,147 @@
package endpoint
import (
"context"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
var _ adapter.EndpointManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.EndpointRegistry
access sync.Mutex
started bool
stage adapter.StartStage
endpoints []adapter.Endpoint
endpointByTag map[string]adapter.Endpoint
}
func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
endpointByTag: make(map[string]adapter.Endpoint),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
if stage == adapter.StartStateStart {
// started with outbound manager
return nil
}
for _, endpoint := range m.endpoints {
err := adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if !m.started {
return nil
}
m.started = false
endpoints := m.endpoints
m.endpoints = nil
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, endpoint := range endpoints {
monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
err = E.Append(err, endpoint.Close(), func(err error) error {
return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
})
monitor.Finish()
}
return nil
}
func (m *Manager) Endpoints() []adapter.Endpoint {
m.access.Lock()
defer m.access.Unlock()
return m.endpoints
}
func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
m.access.Lock()
defer m.access.Unlock()
endpoint, found := m.endpointByTag[tag]
return endpoint, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
endpoint, found := m.endpointByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.endpointByTag, tag)
index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == endpoint
})
if index == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return endpoint.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
}
if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
if m.started {
err = existsEndpoint.Close()
if err != nil {
return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
}
}
existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == existsEndpoint
})
if existsIndex == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
}
m.endpoints = append(m.endpoints, endpoint)
m.endpointByTag[tag] = endpoint
return nil
}

View file

@ -0,0 +1,72 @@
package endpoint
import (
"context"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options T) (adapter.Endpoint, error)
func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any {
return new(Options)
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Endpoint, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
})
}
var _ adapter.EndpointRegistry = (*Registry)(nil)
type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Endpoint, error)
)
type Registry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructor map[string]constructorFunc
}
func NewRegistry() *Registry {
return &Registry{
optionsType: make(map[string]optionsConstructorFunc),
constructor: make(map[string]constructorFunc),
}
}
func (m *Registry) CreateOptions(outboundType string) (any, bool) {
m.access.Lock()
defer m.access.Unlock()
optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}
func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Endpoint, error) {
m.access.Lock()
defer m.access.Unlock()
constructor, loaded := m.constructor[outboundType]
if !loaded {
return nil, E.New("outbound type not found: " + outboundType)
}
return constructor(ctx, router, logger, tag, options)
}
func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
m.access.Lock()
defer m.access.Unlock()
m.optionsType[outboundType] = optionsConstructor
m.constructor[outboundType] = constructor
}

View file

@ -13,7 +13,7 @@ import (
)
type Inbound interface {
Service
Lifecycle
Type() string
Tag() string
}

View file

@ -18,6 +18,7 @@ var _ adapter.InboundManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.InboundRegistry
endpoint adapter.EndpointManager
access sync.Mutex
started bool
stage adapter.StartStage
@ -25,10 +26,11 @@ type Manager struct {
inboundByTag map[string]adapter.Inbound
}
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager {
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endpoint adapter.EndpointManager) *Manager {
return &Manager{
logger: logger,
registry: registry,
endpoint: endpoint,
inboundByTag: make(map[string]adapter.Inbound),
}
}
@ -79,9 +81,12 @@ func (m *Manager) Inbounds() []adapter.Inbound {
func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
m.access.Lock()
defer m.access.Unlock()
inbound, found := m.inboundByTag[tag]
return inbound, found
m.access.Unlock()
if found {
return inbound, true
}
return m.endpoint.Get(tag)
}
func (m *Manager) Remove(tag string) error {

View file

@ -1,6 +1,9 @@
package adapter
func LegacyStart(starter any, stage StartStage) error {
if lifecycle, isLifecycle := starter.(Lifecycle); isLifecycle {
return lifecycle.Start(stage)
}
switch stage {
case StartStateInitialize:
if preStarter, isPreStarter := starter.(interface {

View file

@ -5,35 +5,35 @@ import (
)
type Adapter struct {
protocol string
outboundType string
outboundTag string
network []string
tag string
dependencies []string
}
func NewAdapter(protocol string, network []string, tag string, dependencies []string) Adapter {
func NewAdapter(outboundType string, outboundTag string, network []string, dependencies []string) Adapter {
return Adapter{
protocol: protocol,
outboundType: outboundType,
outboundTag: outboundTag,
network: network,
tag: tag,
dependencies: dependencies,
}
}
func NewAdapterWithDialerOptions(protocol string, network []string, tag string, dialOptions option.DialerOptions) Adapter {
func NewAdapterWithDialerOptions(outboundType string, outboundTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string
if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour}
}
return NewAdapter(protocol, network, tag, dependencies)
return NewAdapter(outboundType, outboundTag, network, dependencies)
}
func (a *Adapter) Type() string {
return a.protocol
return a.outboundType
}
func (a *Adapter) Tag() string {
return a.tag
return a.outboundTag
}
func (a *Adapter) Network() []string {

View file

@ -21,6 +21,7 @@ var _ adapter.OutboundManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.OutboundRegistry
endpoint adapter.EndpointManager
defaultTag string
access sync.Mutex
started bool
@ -32,10 +33,11 @@ type Manager struct {
defaultOutboundFallback adapter.Outbound
}
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager {
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
return &Manager{
logger: logger,
registry: registry,
endpoint: endpoint,
defaultTag: defaultTag,
outboundByTag: make(map[string]adapter.Outbound),
dependByTag: make(map[string][]string),
@ -56,7 +58,14 @@ func (m *Manager) Start(stage adapter.StartStage) error {
outbounds := m.outbounds
m.access.Unlock()
if stage == adapter.StartStateStart {
return m.startOutbounds(outbounds)
if m.defaultTag != "" && m.defaultOutbound == nil {
defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
if !loaded {
return E.New("default outbound not found: ", m.defaultTag)
}
m.defaultOutbound = defaultEndpoint
}
return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
} else {
for _, outbound := range outbounds {
err := adapter.LegacyStart(outbound, stage)
@ -87,7 +96,14 @@ func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
}
started[outboundTag] = true
canContinue = true
if starter, isStarter := outboundToStart.(interface {
if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start(adapter.StartStateStart)
monitor.Finish()
if err != nil {
return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
} else if starter, isStarter := outboundToStart.(interface {
Start() error
}); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
@ -160,9 +176,12 @@ func (m *Manager) Outbounds() []adapter.Outbound {
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag]
return outbound, found
m.access.Unlock()
if found {
return outbound, true
}
return m.endpoint.Get(tag)
}
func (m *Manager) Default() adapter.Outbound {

51
box.go
View file

@ -9,6 +9,7 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
@ -36,6 +37,7 @@ type Box struct {
logFactory log.Factory
logger log.ContextLogger
network *route.NetworkManager
endpoint *endpoint.Manager
inbound *inbound.Manager
outbound *outbound.Manager
connection *route.ConnectionManager
@ -54,6 +56,7 @@ func Context(
ctx context.Context,
inboundRegistry adapter.InboundRegistry,
outboundRegistry adapter.OutboundRegistry,
endpointRegistry adapter.EndpointRegistry,
) context.Context {
if service.FromContext[option.InboundOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.InboundRegistry](ctx) == nil {
@ -65,6 +68,11 @@ func Context(
ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry)
ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry)
}
if service.FromContext[option.EndpointOptionsRegistry](ctx) == nil ||
service.FromContext[adapter.EndpointRegistry](ctx) == nil {
ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry)
ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry)
}
return ctx
}
@ -76,12 +84,16 @@ func New(options Options) (*Box, error) {
}
ctx = service.ContextWithDefaultRegistry(ctx)
endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
if endpointRegistry == nil {
return nil, E.New("missing endpoint registry in context")
}
if inboundRegistry == nil {
return nil, E.New("missing inbound registry in context")
}
outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context")
}
@ -119,8 +131,10 @@ func New(options Options) (*Box, error) {
}
routeOptions := common.PtrValueOrDefault(options.Route)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final)
endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final)
service.MustRegister[adapter.EndpointManager](ctx, endpointManager)
service.MustRegister[adapter.InboundManager](ctx, inboundManager)
service.MustRegister[adapter.OutboundManager](ctx, outboundManager)
@ -135,6 +149,24 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, E.Cause(err, "initialize router")
}
for i, endpointOptions := range options.Endpoints {
var tag string
if endpointOptions.Tag != "" {
tag = endpointOptions.Tag
} else {
tag = F.ToString(i)
}
err = endpointManager.Create(ctx,
router,
logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")),
tag,
endpointOptions.Type,
endpointOptions.Options,
)
if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]")
}
}
for i, inboundOptions := range options.Inbounds {
var tag string
if inboundOptions.Tag != "" {
@ -241,6 +273,7 @@ func New(options Options) (*Box, error) {
}
return &Box{
network: networkManager,
endpoint: endpointManager,
inbound: inboundManager,
outbound: outboundManager,
connection: connectionManager,
@ -303,7 +336,7 @@ func (s *Box) preStart() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound)
err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}
@ -327,7 +360,11 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound)
err = adapter.Start(adapter.StartStateStart, s.endpoint)
if err != nil {
return err
}
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound, s.endpoint)
if err != nil {
return err
}
@ -335,7 +372,7 @@ func (s *Box) start() error {
if err != nil {
return err
}
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound)
err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound, s.endpoint)
if err != nil {
return err
}

View file

@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) {
configPaths = append(configPaths, "config.json")
}
globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger()))
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry())
globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
}

View file

@ -279,7 +279,7 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina
}
func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
return trackPacketConn(d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay))
return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address))
}
func trackConn(conn net.Conn, err error) (net.Conn, error) {

View file

@ -109,6 +109,15 @@ var OptionDestinationOverrideFields = Note{
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options",
}
var OptionWireGuardOutbound = Note{
Name: "wireguard-outbound",
Description: "legacy wireguard outbound",
DeprecatedVersion: "1.11.0",
ScheduledVersion: "1.13.0",
EnvName: "WIREGUARD_OUTBOUND",
MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-wireguard-outbound-to-endpoint",
}
var Options = []Note{
OptionBadMatchSource,
OptionGEOIP,
@ -117,4 +126,5 @@ var Options = []Note{
OptionSpecialOutbounds,
OptionInboundOptions,
OptionDestinationOverrideFields,
OptionWireGuardOutbound,
}

View file

@ -30,7 +30,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err
}
func CheckConfig(configContent string) error {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
options, err := parseConfig(ctx, configContent)
if err != nil {
return err
@ -131,7 +131,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica
}
func FormatConfig(configContent string) (string, error) {
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()), configContent)
options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent)
if err != nil {
return "", err
}

View file

@ -44,7 +44,7 @@ type BoxService struct {
}
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
options, err := parseConfig(ctx, configContent)

3
go.mod
View file

@ -36,7 +36,7 @@ require (
github.com/sagernet/sing-vmess v0.1.12
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
github.com/sagernet/utls v1.6.7
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8
github.com/sagernet/wireguard-go v0.0.1-beta.4
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0
@ -92,6 +92,7 @@ require (
golang.org/x/text v0.20.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.org/x/tools v0.24.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

6
go.sum
View file

@ -132,8 +132,8 @@ github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxe
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
github.com/sagernet/utls v1.6.7 h1:Ep3+aJ8FUGGta+II2IEVNUc3EDhaRCZINWkj/LloIA8=
github.com/sagernet/utls v1.6.7/go.mod h1:Uua1TKO/FFuAhLr9rkaVnnrTmmiItzDjv1BUb2+ERwM=
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8 h1:R0OMYAScomNAVpTfbHFpxqJpvwuhxSRi+g6z7gZhABs=
github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8/go.mod h1:K4J7/npM+VAMUeUmTa2JaA02JmyheP0GpRBOUvn3ecc=
github.com/sagernet/wireguard-go v0.0.1-beta.4 h1:8uyM5fxfEXdu4RH05uOK+v25i3lTNdCYMPSAUJ14FnI=
github.com/sagernet/wireguard-go v0.0.1-beta.4/go.mod h1:jGXij2Gn2wbrWuYNUmmNhf1dwcZtvyAvQoe8Xd8MbUo=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc=
github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
@ -195,6 +195,8 @@ golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY=

View file

@ -4,6 +4,7 @@ import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
@ -82,6 +83,14 @@ func OutboundRegistry() *outbound.Registry {
return registry
}
func EndpointRegistry() *endpoint.Registry {
registry := endpoint.NewRegistry()
registerWireGuardEndpoint(registry)
return registry
}
func registerStubForRemovedInbounds(registry *inbound.Registry) {
inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0")

View file

@ -3,6 +3,7 @@
package include
import (
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/protocol/wireguard"
)
@ -10,3 +11,7 @@ import (
func registerWireGuardOutbound(registry *outbound.Registry) {
wireguard.RegisterOutbound(registry)
}
func registerWireGuardEndpoint(registry *endpoint.Registry) {
wireguard.RegisterEndpoint(registry)
}

View file

@ -6,6 +6,7 @@ import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
@ -14,7 +15,13 @@ import (
)
func registerWireGuardOutbound(registry *outbound.Registry) {
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
})
}
func registerWireGuardEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`)
})
}

47
option/endpoint.go Normal file
View file

@ -0,0 +1,47 @@
package option
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/service"
)
type EndpointOptionsRegistry interface {
CreateOptions(endpointType string) (any, bool)
}
type _Endpoint struct {
Type string `json:"type"`
Tag string `json:"tag,omitempty"`
Options any `json:"-"`
}
type Endpoint _Endpoint
func (h *Endpoint) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return badjson.MarshallObjectsContext(ctx, (*_Endpoint)(h), h.Options)
}
func (h *Endpoint) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.UnmarshalContext(ctx, content, (*_Endpoint)(h))
if err != nil {
return err
}
registry := service.FromContext[EndpointOptionsRegistry](ctx)
if registry == nil {
return E.New("missing Endpoint fields registry in context")
}
options, loaded := registry.CreateOptions(h.Type)
if !loaded {
return E.New("unknown inbound type: ", h.Type)
}
err = badjson.UnmarshallExcludedContext(ctx, content, (*_Endpoint)(h), options)
if err != nil {
return err
}
h.Options = options
return nil
}

View file

@ -28,7 +28,7 @@ func (h *Inbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
}
func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.Unmarshal(content, (*_Inbound)(h))
err := json.UnmarshalContext(ctx, content, (*_Inbound)(h))
if err != nil {
return err
}

View file

@ -13,6 +13,7 @@ type _Options struct {
Log *LogOptions `json:"log,omitempty"`
DNS *DNSOptions `json:"dns,omitempty"`
NTP *NTPOptions `json:"ntp,omitempty"`
Endpoints []Endpoint `json:"endpoints,omitempty"`
Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"`

View file

@ -30,7 +30,7 @@ func (h *Outbound) MarshalJSONContext(ctx context.Context) ([]byte, error) {
}
func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error {
err := json.Unmarshal(content, (*_Outbound)(h))
err := json.UnmarshalContext(ctx, content, (*_Outbound)(h))
if err != nil {
return err
}

View file

@ -6,14 +6,38 @@ import (
"github.com/sagernet/sing/common/json/badoption"
)
type WireGuardOutboundOptions struct {
type WireGuardEndpointOptions struct {
System bool `json:"system,omitempty"`
Name string `json:"name,omitempty"`
MTU uint32 `json:"mtu,omitempty"`
GSO bool `json:"gso,omitempty"`
Address badoption.Listable[netip.Prefix] `json:"address"`
PrivateKey string `json:"private_key"`
ListenPort uint16 `json:"listen_port,omitempty"`
Peers []WireGuardPeer `json:"peers,omitempty"`
UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"`
Workers int `json:"workers,omitempty"`
DialerOptions
}
type WireGuardPeer struct {
Address string `json:"address,omitempty"`
Port uint16 `json:"port,omitempty"`
PublicKey string `json:"public_key,omitempty"`
PreSharedKey string `json:"pre_shared_key,omitempty"`
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"`
}
type LegacyWireGuardOutboundOptions struct {
DialerOptions
SystemInterface bool `json:"system_interface,omitempty"`
GSO bool `json:"gso,omitempty"`
InterfaceName string `json:"interface_name,omitempty"`
LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"`
PrivateKey string `json:"private_key"`
Peers []WireGuardPeer `json:"peers,omitempty"`
Peers []LegacyWireGuardPeer `json:"peers,omitempty"`
ServerOptions
PeerPublicKey string `json:"peer_public_key"`
PreSharedKey string `json:"pre_shared_key,omitempty"`
@ -23,10 +47,10 @@ type WireGuardOutboundOptions struct {
Network NetworkList `json:"network,omitempty"`
}
type WireGuardPeer struct {
type LegacyWireGuardPeer struct {
ServerOptions
PublicKey string `json:"public_key,omitempty"`
PreSharedKey string `json:"pre_shared_key,omitempty"`
AllowedIPs badoption.Listable[string] `json:"allowed_ips,omitempty"`
AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"`
}

View file

@ -26,7 +26,7 @@ type Outbound struct {
func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) {
return &Outbound{
Adapter: outbound.NewAdapter(C.TypeBlock, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil),
Adapter: outbound.NewAdapter(C.TypeBlock, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
logger: logger,
}, nil
}

View file

@ -68,7 +68,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (i *Inbound) Start() error {
func (i *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return i.listener.Start()
}

View file

@ -52,7 +52,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
logger: logger,
domainStrategy: dns.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay),

View file

@ -28,7 +28,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
return &Outbound{
Adapter: outbound.NewAdapter(C.TypeDNS, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil),
Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
router: router,
logger: logger,
}, nil

View file

@ -38,7 +38,7 @@ type Selector struct {
func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) {
outbound := &Selector{
Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds),
Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds),
ctx: ctx,
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger,

View file

@ -49,7 +49,7 @@ type URLTest struct {
func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) {
outbound := &URLTest{
Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds),
Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds),
ctx: ctx,
router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx),

View file

@ -61,7 +61,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {

View file

@ -39,7 +39,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, []string{N.NetworkTCP}, tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, tag, []string{N.NetworkTCP}, options.DialerOptions),
logger: logger,
client: sHTTP.NewClient(sHTTP.Options{
Dialer: detour,

View file

@ -160,7 +160,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {

View file

@ -95,7 +95,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, networkList, tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions),
logger: logger,
client: client,
}, nil

View file

@ -171,7 +171,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {

View file

@ -81,7 +81,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, networkList, tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions),
logger: logger,
client: client,
}, nil

View file

@ -54,7 +54,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}

View file

@ -78,7 +78,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (n *Inbound) Start() error {
func (n *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
var tlsConfig *tls.STDConfig
if n.tlsConfig != nil {
err := n.tlsConfig.Start()

View file

@ -42,7 +42,10 @@ func NewRedirect(ctx context.Context, router adapter.Router, logger log.ContextL
return redirect, nil
}
func (h *Redirect) Start() error {
func (h *Redirect) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}

View file

@ -61,7 +61,10 @@ func NewTProxy(ctx context.Context, router adapter.Router, logger log.ContextLog
return tproxy, nil
}
func (t *TProxy) Start() error {
func (t *TProxy) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := t.listener.Start()
if err != nil {
return err

View file

@ -93,7 +93,10 @@ func newInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, err
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}

View file

@ -101,7 +101,10 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont
return inbound, err
}
func (h *MultiInbound) Start() error {
func (h *MultiInbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}

View file

@ -86,7 +86,10 @@ func newRelayInbound(ctx context.Context, router adapter.Router, logger log.Cont
return inbound, err
}
func (h *RelayInbound) Start() error {
func (h *RelayInbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}

View file

@ -49,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, options.Network.Build(), tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, tag, options.Network.Build(), options.DialerOptions),
logger: logger,
dialer: outboundDialer,
method: method,

View file

@ -90,7 +90,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}

View file

@ -29,7 +29,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) {
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, []string{N.NetworkTCP}, tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, tag, []string{N.NetworkTCP}, options.DialerOptions),
}
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired

View file

@ -50,7 +50,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}

View file

@ -50,7 +50,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, options.Network.Build(), tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
router: router,
logger: logger,
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),

View file

@ -54,7 +54,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, []string{N.NetworkTCP}, tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, tag, []string{N.NetworkTCP}, options.DialerOptions),
ctx: ctx,
logger: logger,
dialer: outboundDialer,

View file

@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, []string{N.NetworkTCP}, tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, tag, []string{N.NetworkTCP}, options.DialerOptions),
ctx: ctx,
logger: logger,
proxy: NewProxyListener(ctx, logger, outboundDialer),

View file

@ -110,7 +110,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {

View file

@ -43,7 +43,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, options.Network.Build(), tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, tag, options.Network.Build(), options.DialerOptions),
logger: logger,
dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(),

View file

@ -142,7 +142,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {

View file

@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, options.Network.Build(), tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, tag, options.Network.Build(), options.DialerOptions),
logger: logger,
client: client,
udpStream: options.UDPOverStream,

View file

@ -300,7 +300,9 @@ func (t *Inbound) Tag() string {
return t.tag
}
func (t *Inbound) Start() error {
func (t *Inbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
if C.IsAndroid && t.platformInterface == nil {
t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager())
}
@ -348,10 +350,7 @@ func (t *Inbound) Start() error {
}
t.tunStack = tunStack
t.logger.Info("started at ", t.tunOptions.Name)
return nil
}
func (t *Inbound) PostStart() error {
case adapter.StartStatePostStart:
monitor := taskmonitor.New(t.logger, C.StartTimeout)
monitor.Start("starting tun stack")
err := t.tunStack.Start()
@ -399,6 +398,7 @@ func (t *Inbound) PostStart() error {
t.routeAddressSet = nil
t.routeExcludeAddressSet = nil
}
}
return nil
}

View file

@ -89,7 +89,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {

View file

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, options.Network.Build(), tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, tag, options.Network.Build(), options.DialerOptions),
logger: logger,
dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(),

View file

@ -99,7 +99,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return inbound, nil
}
func (h *Inbound) Start() error {
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
err := h.service.Start()
if err != nil {
return err

View file

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, options.Network.Build(), tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, tag, options.Network.Build(), options.DialerOptions),
logger: logger,
dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(),

View file

@ -0,0 +1,211 @@
package wireguard
import (
"context"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterEndpoint(registry *endpoint.Registry) {
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
}
var (
_ adapter.Endpoint = (*Endpoint)(nil)
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)
type Endpoint struct {
endpoint.Adapter
ctx context.Context
router adapter.Router
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
}
func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) {
ep := &Endpoint{
Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
logger: logger,
localAddresses: options.Address,
}
if options.Detour == "" {
options.IsWireGuardListener = true
} else if options.GSO {
return nil, E.New("gso is conflict with detour")
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil {
return nil, err
}
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx,
Logger: logger,
System: options.System,
Handler: ep,
UDPTimeout: time.Duration(options.UDPTimeout),
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
BindInterface: interfaceName,
}))
},
Name: options.Name,
MTU: options.MTU,
GSO: options.GSO,
Address: options.Address,
PrivateKey: options.PrivateKey,
ListenPort: options.ListenPort,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
return endpointAddresses[0], nil
},
Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions {
return wireguard.PeerOptions{
Endpoint: M.ParseSocksaddrHostPort(it.Address, it.Port),
PublicKey: it.PublicKey,
PreSharedKey: it.PreSharedKey,
AllowedIPs: it.AllowedIPs,
PersistentKeepaliveInterval: it.PersistentKeepaliveInterval,
Reserved: it.Reserved,
}
}),
Workers: options.Workers,
})
if err != nil {
return nil, err
}
ep.endpoint = wgEndpoint
return ep, nil
}
func (w *Endpoint) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
return w.endpoint.Start(false)
case adapter.StartStatePostStart:
return w.endpoint.Start(true)
}
return nil
}
func (w *Endpoint) Close() error {
return w.endpoint.Close()
}
func (w *Endpoint) InterfaceUpdated() {
w.endpoint.BindUpdate()
return
}
func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error {
return w.router.PreMatch(adapter.InboundContext{
Inbound: w.Tag(),
InboundType: w.Type(),
Network: network,
Source: source,
Destination: destination,
})
}
func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext
metadata.Inbound = w.Tag()
metadata.InboundType = w.Type()
metadata.Source = source
for _, localPrefix := range w.localAddresses {
if localPrefix.Contains(destination.Addr) {
metadata.OriginDestination = destination
if destination.Addr.Is4() {
destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
} else {
destination.Addr = netip.IPv6Loopback()
}
break
}
}
metadata.Destination = destination
w.logger.InfoContext(ctx, "inbound connection from ", source)
w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
w.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}
func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
var metadata adapter.InboundContext
metadata.Inbound = w.Tag()
metadata.InboundType = w.Type()
metadata.Source = source
metadata.Destination = destination
for _, localPrefix := range w.localAddresses {
if localPrefix.Contains(destination.Addr) {
metadata.OriginDestination = destination
if destination.Addr.Is4() {
metadata.Destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1})
} else {
metadata.Destination.Addr = netip.IPv6Loopback()
}
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
}
}
w.logger.InfoContext(ctx, "inbound packet connection from ", source)
w.logger.InfoContext(ctx, "inbound packet connection to ", destination)
w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network {
case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
}
return w.endpoint.DialContext(ctx, network, destination)
}
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, err
}
return w.endpoint.ListenPacket(ctx, destination)
}

View file

@ -2,231 +2,153 @@ package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
}
var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
var (
_ adapter.Endpoint = (*Endpoint)(nil)
_ adapter.InterfaceUpdateListener = (*Endpoint)(nil)
)
type Outbound struct {
outbound.Adapter
ctx context.Context
router adapter.Router
logger logger.ContextLogger
workers int
peers []wireguard.PeerConfig
useStdNetBind bool
listener N.Dialer
ipcConf string
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
bind conn.Bind
device *device.Device
tunDevice wireguard.Device
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) {
deprecated.Report(ctx, deprecated.OptionWireGuardOutbound)
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions),
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
logger: logger,
workers: options.Workers,
pauseManager: service.FromContext[pause.Manager](ctx),
localAddresses: options.LocalAddress,
}
peers, err := wireguard.ParsePeers(options)
if err != nil {
return nil, err
}
outbound.peers = peers
if len(options.LocalAddress) == 0 {
return nil, E.New("missing local address")
}
if options.GSO {
if options.GSO && options.Detour != "" {
if options.Detour == "" {
options.IsWireGuardListener = true
} else if options.GSO {
return nil, E.New("gso is conflict with detour")
}
options.IsWireGuardListener = true
outbound.useStdNetBind = true
}
listener, err := dialer.New(ctx, options.DialerOptions)
outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil {
return nil, err
}
outbound.listener = listener
var privateKey string
{
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
Context: ctx,
Logger: logger,
System: options.SystemInterface,
Dialer: outboundDialer,
CreateDialer: func(interfaceName string) N.Dialer {
return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{
BindInterface: interfaceName,
}))
},
Name: options.InterfaceName,
MTU: options.MTU,
GSO: options.GSO,
Address: options.LocalAddress,
PrivateKey: options.PrivateKey,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
return endpointAddresses[0], nil
},
Peers: common.Map(options.Peers, func(it option.LegacyWireGuardPeer) wireguard.PeerOptions {
return wireguard.PeerOptions{
Endpoint: it.ServerOptions.Build(),
PublicKey: it.PublicKey,
PreSharedKey: it.PreSharedKey,
AllowedIPs: it.AllowedIPs,
// PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval),
Reserved: it.Reserved,
}
}),
Workers: options.Workers,
})
if err != nil {
return nil, E.Cause(err, "decode private key")
return nil, err
}
privateKey = hex.EncodeToString(bytes)
}
outbound.ipcConf = "private_key=" + privateKey
mtu := options.MTU
if mtu == 0 {
mtu = 1408
}
var wireTunDevice wireguard.Device
if !options.SystemInterface && tun.WithGVisor {
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
} else {
wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
}
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
outbound.tunDevice = wireTunDevice
outbound.endpoint = wgEndpoint
return outbound, nil
}
func (w *Outbound) Start() error {
if common.Any(w.peers, func(peer wireguard.PeerConfig) bool {
return !peer.Endpoint.IsValid()
}) {
// wait for all outbounds to be started and continue in PortStart
return nil
}
return w.start()
}
func (w *Outbound) PostStart() error {
if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
return peer.Endpoint.IsValid()
}) {
return nil
}
return w.start()
}
func (w *Outbound) start() error {
err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
if err != nil {
return err
}
var bind conn.Bind
if w.useStdNetBind {
bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
peerLen := len(w.peers)
if peerLen == 1 {
isConnect = true
connectAddr = w.peers[0].Endpoint
reserved = w.peers[0].Reserved
}
bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
}
err = w.tunDevice.Start()
if err != nil {
return err
}
wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) {
w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}, w.workers)
ipcConf := w.ipcConf
for _, peer := range w.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
w.device = wgDevice
w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
return nil
}
func (w *Outbound) Close() error {
if w.device != nil {
w.device.Close()
}
if w.pauseCallback != nil {
w.pauseManager.UnregisterCallback(w.pauseCallback)
func (o *Outbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
return o.endpoint.Start(false)
case adapter.StartStatePostStart:
return o.endpoint.Start(true)
}
return nil
}
func (w *Outbound) InterfaceUpdated() {
w.device.BindUpdate()
func (o *Outbound) Close() error {
return o.endpoint.Close()
}
func (o *Outbound) InterfaceUpdated() {
o.endpoint.BindUpdate()
return
}
func (w *Outbound) onPauseUpdated(event int) {
switch event {
case pause.EventDevicePaused:
w.device.Down()
case pause.EventDeviceWake:
w.device.Up()
}
}
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network {
case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination)
o.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
return N.DialSerial(ctx, o.endpoint, network, destination, destinationAddresses)
} else if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
}
return w.tunDevice.DialContext(ctx, network, destination)
return o.endpoint.DialContext(ctx, network, destination)
}
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
packetConn, _, err := N.ListenSerial(ctx, o.endpoint, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, err
}
return w.tunDevice.ListenPacket(ctx, destination)
return o.endpoint.ListenPacket(ctx, destination)
}

View file

@ -41,15 +41,15 @@ type NetworkManager struct {
autoDetectInterface bool
defaultOptions adapter.NetworkOptions
autoRedirectOutputMark uint32
networkMonitor tun.NetworkUpdateMonitor
interfaceMonitor tun.DefaultInterfaceMonitor
packageManager tun.PackageManager
powerListener winpowrprof.EventListener
pauseManager pause.Manager
platformInterface platform.Interface
inboundManager adapter.InboundManager
outboundManager adapter.OutboundManager
endpoint adapter.EndpointManager
inbound adapter.InboundManager
outbound adapter.OutboundManager
wifiState adapter.WIFIState
started bool
}
@ -69,7 +69,9 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp
},
pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx),
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
endpoint: service.FromContext[adapter.EndpointManager](ctx),
inbound: service.FromContext[adapter.InboundManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
}
if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault {
if routeOptions.DefaultInterface != "" {
@ -358,14 +360,21 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState {
func (r *NetworkManager) ResetNetwork() {
conntrack.Close()
for _, inbound := range r.inboundManager.Inbounds() {
for _, endpoint := range r.endpoint.Endpoints() {
listener, isListener := endpoint.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()
}
}
for _, inbound := range r.inbound.Inbounds() {
listener, isListener := inbound.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()
}
}
for _, outbound := range r.outboundManager.Outbounds() {
for _, outbound := range r.outbound.Outbounds() {
listener, isListener := outbound.(adapter.InterfaceUpdateListener)
if isListener {
listener.InterfaceUpdated()

View file

@ -11,7 +11,7 @@ import (
C "github.com/sagernet/sing-box/constant"
R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-dns"
tun "github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"

View file

@ -32,7 +32,7 @@ func TestMain(m *testing.M) {
var globalCtx context.Context
func init() {
globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry())
}
func startInstance(t *testing.T, options option.Options) *box.Box {

View file

@ -37,12 +37,12 @@ func _TestWireGuard(t *testing.T) {
Outbounds: []option.Outbound{
{
Type: C.TypeWireGuard,
Options: &option.WireGuardOutboundOptions{
Options: &option.WireGuardEndpointOptions{
ServerOptions: option.ServerOptions{
Server: "127.0.0.1",
ServerPort: serverPort,
},
LocalAddress: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
Address: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
},

View file

@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
select {
case <-c.done:
default:
c.logger.Error(context.Background(), E.Cause(err, "read packet"))
c.logger.Error(E.Cause(err, "read packet"))
err = nil
}
return
@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
b := packets[0]
common.ClearArray(b[1:4])
}
eps[0] = Endpoint(M.AddrPortFromNet(addr))
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
count = 1
return
}
@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
time.Sleep(time.Second)
return err
}
destination := netip.AddrPort(ep.(Endpoint))
destination := netip.AddrPort(ep.(remoteEndpoint))
for _, b := range bufs {
if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination]
@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
if err != nil {
return nil, err
}
return Endpoint(ap), nil
return remoteEndpoint(ap), nil
}
func (c *ClientBind) BatchSize() int {
@ -229,3 +229,31 @@ func (w *wireConn) Close() error {
close(w.done)
return nil
}
var _ conn.Endpoint = (*remoteEndpoint)(nil)
type remoteEndpoint netip.AddrPort
func (e remoteEndpoint) ClearSrc() {
}
func (e remoteEndpoint) SrcToString() string {
return ""
}
func (e remoteEndpoint) DstToString() string {
return (netip.AddrPort)(e).String()
}
func (e remoteEndpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
}
func (e remoteEndpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
}
func (e remoteEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}

View file

@ -1,13 +1,44 @@
package wireguard
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/tun"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
)
type Device interface {
tun.Device
wgTun.Device
N.Dialer
Start() error
// NewEndpoint() (stack.LinkEndpoint, error)
SetDevice(device *device.Device)
}
type DeviceOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
AllowedAddress []netip.Prefix
}
func NewDevice(options DeviceOptions) (Device, error) {
if !options.System {
return newStackDevice(options)
} else if options.Handler == nil {
return newSystemDevice(options)
} else {
return newSystemStackDevice(options)
}
}

View file

@ -5,7 +5,6 @@ package wireguard
import (
"context"
"net"
"net/netip"
"os"
"github.com/sagernet/gvisor/pkg/buffer"
@ -15,52 +14,41 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
)
var _ Device = (*StackDevice)(nil)
var _ Device = (*stackDevice)(nil)
const defaultNIC tcpip.NICID = 1
type StackDevice struct {
type stackDevice struct {
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
}
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
tunDevice := &stackDevice{
mtu: options.MTU,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}),
}
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
if err != nil {
return nil, E.New(err.String())
return nil, err
}
for _, prefix := range localAddresses {
for _, prefix := range options.Address {
addr := tun.AddressFromAddr(prefix.Addr())
protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{
@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
tunDevice.addr6 = addr
protoAddr.Protocol = ipv6.ProtocolNumber
}
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
if err != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
}
}
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
cOpt := tcpip.CongestionControlOption("cubic")
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
tunDevice.stack = ipStack
if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
}
return tunDevice, nil
}
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
return (*wireEndpoint)(w), nil
}
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
addr := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
Port: destination.Port,
Addr: tun.AddressFromAddr(destination.Addr),
}
bind := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati
}
}
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
bind := tcpip.FullAddress{
NIC: defaultNIC,
NIC: tun.DefaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil
}
func (w *StackDevice) Inet4Address() netip.Addr {
return tun.AddrFromAddress(w.addr4)
func (w *stackDevice) SetDevice(device *device.Device) {
}
func (w *StackDevice) Inet6Address() netip.Addr {
return tun.AddrFromAddress(w.addr6)
}
func (w *StackDevice) Start() error {
func (w *stackDevice) Start() error {
w.events <- wgTun.EventUp
return nil
}
func (w *StackDevice) File() *os.File {
func (w *stackDevice) File() *os.File {
return nil
}
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select {
case packetBuffer, ok := <-w.outbound:
if !ok {
@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e
sizes[0] = n
count = 1
return
case packet := <-w.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
count = 1
return
case <-w.done:
return 0, os.ErrClosed
}
}
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs {
b = b[offset:]
if len(b) == 0 {
@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return
}
func (w *StackDevice) Flush() error {
func (w *stackDevice) Flush() error {
return nil
}
func (w *StackDevice) MTU() (int, error) {
func (w *stackDevice) MTU() (int, error) {
return int(w.mtu), nil
}
func (w *StackDevice) Name() (string, error) {
func (w *stackDevice) Name() (string, error) {
return "sing-box", nil
}
func (w *StackDevice) Events() <-chan wgTun.Event {
func (w *stackDevice) Events() <-chan wgTun.Event {
return w.events
}
func (w *StackDevice) Close() error {
func (w *stackDevice) Close() error {
close(w.done)
close(w.events)
w.stack.Close()
@ -240,13 +213,13 @@ func (w *StackDevice) Close() error {
return nil
}
func (w *StackDevice) BatchSize() int {
func (w *stackDevice) BatchSize() int {
return 1
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint StackDevice
type wireEndpoint stackDevice
func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu

View file

@ -2,12 +2,12 @@
package wireguard
import (
"net/netip"
import "github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun"
)
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
func newStackDevice(options DeviceOptions) (Device, error) {
return nil, tun.ErrGVisorNotIncluded
}
func newSystemStackDevice(options DeviceOptions) (Device, error) {
return nil, tun.ErrGVisorNotIncluded
}

View file

@ -6,96 +6,88 @@ import (
"net"
"net/netip"
"os"
"runtime"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
"github.com/sagernet/wireguard-go/device"
wgTun "github.com/sagernet/wireguard-go/tun"
)
var _ Device = (*SystemDevice)(nil)
var _ Device = (*systemDevice)(nil)
type SystemDevice struct {
type systemDevice struct {
options DeviceOptions
dialer N.Dialer
device tun.Tun
batchDevice tun.LinuxTUN
name string
mtu uint32
inet4Addresses []netip.Prefix
inet6Addresses []netip.Prefix
gso bool
events chan wgTun.Event
closeOnce sync.Once
}
func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) {
var inet4Addresses []netip.Prefix
var inet6Addresses []netip.Prefix
for _, prefixes := range localPrefixes {
if prefixes.Addr().Is4() {
inet4Addresses = append(inet4Addresses, prefixes)
} else {
inet6Addresses = append(inet6Addresses, prefixes)
func newSystemDevice(options DeviceOptions) (*systemDevice, error) {
if options.Name == "" {
options.Name = tun.CalculateInterfaceName("wg")
}
}
if interfaceName == "" {
interfaceName = tun.CalculateInterfaceName("wg")
}
return &SystemDevice{
dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{
BindInterface: interfaceName,
})),
name: interfaceName,
mtu: mtu,
inet4Addresses: inet4Addresses,
inet6Addresses: inet6Addresses,
gso: gso,
return &systemDevice{
options: options,
dialer: options.CreateDialer(options.Name),
events: make(chan wgTun.Event, 1),
}, nil
}
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return w.dialer.DialContext(ctx, network, destination)
}
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return w.dialer.ListenPacket(ctx, destination)
}
func (w *SystemDevice) Inet4Address() netip.Addr {
if len(w.inet4Addresses) == 0 {
return netip.Addr{}
}
return w.inet4Addresses[0].Addr()
func (w *systemDevice) SetDevice(device *device.Device) {
}
func (w *SystemDevice) Inet6Address() netip.Addr {
if len(w.inet6Addresses) == 0 {
return netip.Addr{}
func (w *systemDevice) Start() error {
networkManager := service.FromContext[adapter.NetworkManager](w.options.Context)
tunOptions := tun.Options{
Name: w.options.Name,
Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool {
return it.Addr().Is6()
}),
MTU: w.options.MTU,
GSO: w.options.GSO,
InterfaceScope: true,
Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool {
return it.Addr().Is4()
}),
Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }),
InterfaceMonitor: networkManager.InterfaceMonitor(),
InterfaceFinder: networkManager.InterfaceFinder(),
}
return w.inet6Addresses[0].Addr()
}
func (w *SystemDevice) Start() error {
tunInterface, err := tun.New(tun.Options{
Name: w.name,
Inet4Address: w.inet4Addresses,
Inet6Address: w.inet6Addresses,
MTU: w.mtu,
GSO: w.gso,
})
// works with Linux, macOS with IFSCOPE routes, not tested on Windows
if runtime.GOOS == "darwin" {
tunOptions.AutoRoute = true
}
tunInterface, err := tun.New(tunOptions)
if err != nil {
return err
}
err = tunInterface.Start()
if err != nil {
return err
}
w.options.Logger.Info("started at ", w.options.Name)
w.device = tunInterface
if w.gso {
if w.options.GSO {
batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN)
if !isBatchTUN {
tunInterface.Close()
@ -107,15 +99,15 @@ func (w *SystemDevice) Start() error {
return nil
}
func (w *SystemDevice) File() *os.File {
func (w *systemDevice) File() *os.File {
return nil
}
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
if w.batchDevice != nil {
count, err = w.batchDevice.BatchRead(bufs, offset, sizes)
count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes)
} else {
sizes[0], err = w.device.Read(bufs[0][offset:])
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
if err == nil {
count = 1
} else if errors.Is(err, tun.ErrTooManySegments) {
@ -125,12 +117,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int,
return
}
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil {
return 0, w.batchDevice.BatchWrite(bufs, offset)
return w.batchDevice.BatchWrite(bufs, offset)
} else {
for _, b := range bufs {
_, err = w.device.Write(b[offset:])
for _, packet := range bufs {
if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
if err != nil {
return
}
@ -140,28 +136,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
return
}
func (w *SystemDevice) Flush() error {
func (w *systemDevice) Flush() error {
return nil
}
func (w *SystemDevice) MTU() (int, error) {
return int(w.mtu), nil
func (w *systemDevice) MTU() (int, error) {
return int(w.options.MTU), nil
}
func (w *SystemDevice) Name() (string, error) {
return w.name, nil
func (w *systemDevice) Name() (string, error) {
return w.options.Name, nil
}
func (w *SystemDevice) Events() <-chan wgTun.Event {
func (w *systemDevice) Events() <-chan wgTun.Event {
return w.events
}
func (w *SystemDevice) Close() error {
func (w *systemDevice) Close() error {
close(w.events)
return w.device.Close()
}
func (w *SystemDevice) BatchSize() int {
func (w *systemDevice) BatchSize() int {
if w.batchDevice != nil {
return w.batchDevice.BatchSize()
}

View file

@ -0,0 +1,182 @@
//go:build with_gvisor
package wireguard
import (
"net/netip"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/wireguard-go/device"
)
var _ Device = (*systemStackDevice)(nil)
type systemStackDevice struct {
*systemDevice
stack *stack.Stack
endpoint *deviceEndpoint
writeBufs [][]byte
}
func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
system, err := newSystemDevice(options)
if err != nil {
return nil, err
}
endpoint := &deviceEndpoint{
mtu: options.MTU,
done: make(chan struct{}),
}
ipStack, err := tun.NewGVisorStack(endpoint)
if err != nil {
return nil, err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
return &systemStackDevice{
systemDevice: system,
stack: ipStack,
endpoint: endpoint,
}, nil
}
func (w *systemStackDevice) SetDevice(device *device.Device) {
w.endpoint.device = device
}
func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
if w.batchDevice != nil {
w.writeBufs = w.writeBufs[:0]
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
w.writeBufs = append(w.writeBufs, packet)
}
}
if len(w.writeBufs) > 0 {
return w.batchDevice.BatchWrite(bufs, offset)
}
} else {
for _, packet := range bufs {
if !w.writeStack(packet[offset:]) {
if tun.PacketOffset > 0 {
common.ClearArray(packet[offset-tun.PacketOffset : offset])
tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
}
_, err = w.device.Write(packet[offset-tun.PacketOffset:])
}
if err != nil {
return
}
}
}
// WireGuard will not read count
return
}
func (w *systemStackDevice) Close() error {
close(w.endpoint.done)
w.stack.Close()
for _, endpoint := range w.stack.CleanupEndpoints() {
endpoint.Abort()
}
w.stack.Wait()
return w.systemDevice.Close()
}
func (w *systemStackDevice) writeStack(packet []byte) bool {
var (
networkProtocol tcpip.NetworkProtocolNumber
destination netip.Addr
)
switch header.IPVersion(packet) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4())
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16())
}
for _, prefix := range w.options.Address {
if prefix.Contains(destination) {
return false
}
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
packetBuffer.DecRef()
return true
}
type deviceEndpoint struct {
mtu uint32
done chan struct{}
device *device.Device
dispatcher stack.NetworkDispatcher
}
func (ep *deviceEndpoint) MTU() uint32 {
return ep.mtu
}
func (ep *deviceEndpoint) SetMTU(mtu uint32) {
}
func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityRXChecksumOffload
}
func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
ep.dispatcher = dispatcher
}
func (ep *deviceEndpoint) IsAttached() bool {
return ep.dispatcher != nil
}
func (ep *deviceEndpoint) Wait() {
}
func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
return true
}
func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
destination := packetBuffer.Network().DestinationAddress()
ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
}
return list.Len(), nil
}
func (ep *deviceEndpoint) Close() {
}
func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
}

View file

@ -1,35 +1,254 @@
package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"os"
"strings"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
"go4.org/netipx"
)
var _ conn.Endpoint = (*Endpoint)(nil)
type Endpoint netip.AddrPort
func (e Endpoint) ClearSrc() {
type Endpoint struct {
options EndpointOptions
peers []peerConfig
ipcConf string
allowedAddress []netip.Prefix
tunDevice Device
device *device.Device
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
}
func (e Endpoint) SrcToString() string {
return ""
func NewEndpoint(options EndpointOptions) (*Endpoint, error) {
if options.PrivateKey == "" {
return nil, E.New("missing private key")
}
privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey := hex.EncodeToString(privateKeyBytes)
ipcConf := "private_key=" + privateKey
if options.ListenPort != 0 {
ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort)
}
var peers []peerConfig
for peerIndex, rawPeer := range options.Peers {
peer := peerConfig{
allowedIPs: rawPeer.AllowedIPs,
keepalive: rawPeer.PersistentKeepaliveInterval,
}
if rawPeer.Endpoint.Addr.IsValid() {
peer.endpoint = rawPeer.Endpoint.AddrPort()
} else if rawPeer.Endpoint.IsFqdn() {
peer.destination = rawPeer.Endpoint
}
publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.publicKeyHex = hex.EncodeToString(publicKeyBytes)
if rawPeer.PreSharedKey != "" {
preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved))
}
copy(peer.reserved[:], rawPeer.Reserved[:])
}
peers = append(peers, peer)
}
var allowedPrefixBuilder netipx.IPSetBuilder
for _, peer := range options.Peers {
for _, prefix := range peer.AllowedIPs {
allowedPrefixBuilder.AddPrefix(prefix)
}
}
allowedIPSet, err := allowedPrefixBuilder.IPSet()
if err != nil {
return nil, err
}
allowedAddresses := allowedIPSet.Prefixes()
if options.MTU == 0 {
options.MTU = 1408
}
deviceOptions := DeviceOptions{
Context: options.Context,
Logger: options.Logger,
System: options.System,
Handler: options.Handler,
UDPTimeout: options.UDPTimeout,
CreateDialer: options.CreateDialer,
Name: options.Name,
MTU: options.MTU,
GSO: options.GSO,
Address: options.Address,
AllowedAddress: allowedAddresses,
}
tunDevice, err := NewDevice(deviceOptions)
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
return &Endpoint{
options: options,
peers: peers,
ipcConf: ipcConf,
allowedAddress: allowedAddresses,
tunDevice: tunDevice,
}, nil
}
func (e Endpoint) DstToString() string {
return (netip.AddrPort)(e).String()
func (e *Endpoint) Start(resolve bool) error {
if common.Any(e.peers, func(peer peerConfig) bool {
return !peer.endpoint.IsValid() && peer.destination.IsFqdn()
}) {
if !resolve {
return nil
}
for peerIndex, peer := range e.peers {
if peer.endpoint.IsValid() || !peer.destination.IsFqdn() {
continue
}
destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn)
if err != nil {
return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination)
}
e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port)
}
} else if resolve {
return nil
}
var bind conn.Bind
wgListener, isWgListener := e.options.Dialer.(conn.Listener)
if isWgListener {
bind = conn.NewStdNetBind(wgListener)
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
peerLen := len(e.peers)
if peerLen == 1 {
isConnect = true
connectAddr = e.peers[0].endpoint
reserved = e.peers[0].reserved
}
bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved)
}
err := e.tunDevice.Start()
if err != nil {
return err
}
logger := &device.Logger{
Verbosef: func(format string, args ...interface{}) {
e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}
wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers)
e.tunDevice.SetDevice(wgDevice)
ipcConf := e.ipcConf
for _, peer := range e.peers {
ipcConf += peer.GenerateIpcLines()
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
}
e.device = wgDevice
e.pauseManager = service.FromContext[pause.Manager](e.options.Context)
if e.pauseManager != nil {
e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated)
}
return nil
}
func (e Endpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.DialContext(ctx, network, destination)
}
func (e Endpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if !destination.Addr.IsValid() {
return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination")
}
return e.tunDevice.ListenPacket(ctx, destination)
}
func (e Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
func (e *Endpoint) BindUpdate() error {
return e.device.BindUpdate()
}
func (e *Endpoint) Close() error {
if e.device != nil {
e.device.Close()
}
if e.pauseCallback != nil {
e.pauseManager.UnregisterCallback(e.pauseCallback)
}
return nil
}
func (e *Endpoint) onPauseUpdated(event int) {
switch event {
case pause.EventDevicePaused:
e.device.Down()
case pause.EventDeviceWake:
e.device.Up()
}
}
type peerConfig struct {
destination M.Socksaddr
endpoint netip.AddrPort
publicKeyHex string
preSharedKeyHex string
allowedIPs []netip.Prefix
keepalive uint16
reserved [3]uint8
}
func (c peerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.publicKeyHex
if c.endpoint.IsValid() {
ipcLines += "\nendpoint=" + c.endpoint.String()
}
if c.preSharedKeyHex != "" {
ipcLines += "\npreshared_key=" + c.preSharedKeyHex
}
for _, allowedIP := range c.allowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP.String()
}
if c.keepalive > 0 {
ipcLines += "\npersistent_keepalive_interval=" + F.ToString(c.keepalive)
}
return ipcLines
}

View file

@ -0,0 +1,40 @@
package wireguard
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type EndpointOptions struct {
Context context.Context
Logger logger.ContextLogger
System bool
Handler tun.Handler
UDPTimeout time.Duration
Dialer N.Dialer
CreateDialer func(interfaceName string) N.Dialer
Name string
MTU uint32
GSO bool
Address []netip.Prefix
PrivateKey string
ListenPort uint16
ResolvePeer func(domain string) (netip.Addr, error)
Peers []PeerOptions
Workers int
}
type PeerOptions struct {
Endpoint M.Socksaddr
PublicKey string
PreSharedKey string
AllowedIPs []netip.Prefix
PersistentKeepaliveInterval uint16
Reserved []uint8
}

View file

@ -1,148 +0,0 @@
package wireguard
import (
"context"
"encoding/base64"
"encoding/hex"
"net/netip"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type PeerConfig struct {
destination M.Socksaddr
domainStrategy dns.DomainStrategy
Endpoint netip.AddrPort
PublicKey string
PreSharedKey string
AllowedIPs []string
Reserved [3]uint8
}
func (c PeerConfig) GenerateIpcLines() string {
ipcLines := "\npublic_key=" + c.PublicKey
ipcLines += "\nendpoint=" + c.Endpoint.String()
if c.PreSharedKey != "" {
ipcLines += "\npreshared_key=" + c.PreSharedKey
}
for _, allowedIP := range c.AllowedIPs {
ipcLines += "\nallowed_ip=" + allowedIP
}
return ipcLines
}
func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) {
var peers []PeerConfig
if len(options.Peers) > 0 {
for peerIndex, rawPeer := range options.Peers {
peer := PeerConfig{
AllowedIPs: rawPeer.AllowedIPs,
}
destination := rawPeer.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", peerIndex)
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if rawPeer.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex)
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(rawPeer.AllowedIPs) == 0 {
return nil, E.New("missing allowed_ips for peer ", peerIndex)
}
if len(rawPeer.Reserved) > 0 {
if len(rawPeer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
} else {
peer := PeerConfig{}
var (
addressHas4 bool
addressHas6 bool
)
for _, localAddress := range options.LocalAddress {
if localAddress.Addr().Is4() {
addressHas4 = true
} else {
addressHas6 = true
}
}
if addressHas4 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String())
}
if addressHas6 {
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String())
}
destination := options.ServerOptions.Build()
if destination.IsFqdn() {
peer.destination = destination
peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy)
} else {
peer.Endpoint = destination.AddrPort()
}
{
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
if err != nil {
return nil, E.Cause(err, "decode peer public key")
}
peer.PublicKey = hex.EncodeToString(bytes)
}
if options.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key")
}
peer.PreSharedKey = hex.EncodeToString(bytes)
}
if len(options.Reserved) > 0 {
if len(options.Reserved) != 3 {
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved))
}
copy(peer.Reserved[:], options.Reserved)
}
peers = append(peers, peer)
}
return peers, nil
}
func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error {
for peerIndex, peer := range peers {
if peer.Endpoint.IsValid() {
continue
}
destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy)
if err != nil {
if len(peers) == 1 {
return E.Cause(err, "resolve endpoint domain")
} else {
return E.Cause(err, "resolve endpoint domain for peer ", peerIndex)
}
}
if len(destinationAddresses) == 0 {
return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn)
}
peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port)
}
return nil
}