Refactor wireguard & add tun support

This commit is contained in:
世界 2022-09-06 00:15:09 +08:00
parent 8e7957d440
commit cb4fea0240
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
20 changed files with 792 additions and 425 deletions

View file

@ -8,7 +8,6 @@
"server": "127.0.0.1", "server": "127.0.0.1",
"server_port": 1080, "server_port": 1080,
"local_address": [ "local_address": [
"10.0.0.1",
"10.0.0.2/32" "10.0.0.2/32"
], ],
"private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=", "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
@ -43,7 +42,7 @@ The server port.
==Required== ==Required==
List of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned to the interface. List of IP (v4 or v6) address prefixes to be assigned to the interface.
#### private_key #### private_key

View file

@ -8,7 +8,6 @@
"server": "127.0.0.1", "server": "127.0.0.1",
"server_port": 1080, "server_port": 1080,
"local_address": [ "local_address": [
"10.0.0.1",
"10.0.0.2/32" "10.0.0.2/32"
], ],
"private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=", "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
@ -45,7 +44,7 @@
接口的 IPv4/IPv6 地址或地址段的列表您。 接口的 IPv4/IPv6 地址或地址段的列表您。
要分配给接口的 IPv4 或 v6地址列表(可以选择带有 CIDR 掩码) 要分配给接口的 IPv4 或 v6地址列表。
#### private_key #### private_key

View file

@ -40,7 +40,7 @@ type Tun struct {
func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunInboundOptions) (*Tun, error) { func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunInboundOptions) (*Tun, error) {
tunName := options.InterfaceName tunName := options.InterfaceName
if tunName == "" { if tunName == "" {
tunName = tun.DefaultInterfaceName() tunName = tun.CalculateInterfaceName("")
} }
tunMTU := options.MTU tunMTU := options.MTU
if tunMTU == 0 { if tunMTU == 0 {
@ -77,8 +77,8 @@ func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger
tunOptions: tun.Options{ tunOptions: tun.Options{
Name: tunName, Name: tunName,
MTU: tunMTU, MTU: tunMTU,
Inet4Address: options.Inet4Address.Build(), Inet4Address: common.Map(options.Inet4Address, option.ListenPrefix.Build),
Inet6Address: options.Inet6Address.Build(), Inet6Address: common.Map(options.Inet6Address, option.ListenPrefix.Build),
AutoRoute: options.AutoRoute, AutoRoute: options.AutoRoute,
StrictRoute: options.StrictRoute, StrictRoute: options.StrictRoute,
IncludeUID: includeUID, IncludeUID: includeUID,
@ -87,6 +87,7 @@ func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger
IncludePackage: options.IncludePackage, IncludePackage: options.IncludePackage,
ExcludePackage: options.ExcludePackage, ExcludePackage: options.ExcludePackage,
InterfaceMonitor: router.InterfaceMonitor(), InterfaceMonitor: router.InterfaceMonitor(),
TableIndex: 2022,
}, },
endpointIndependentNat: options.EndpointIndependentNat, endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: udpTimeout, udpTimeout: udpTimeout,

View file

@ -108,6 +108,7 @@ type DialerOptions struct {
Detour string `json:"detour,omitempty"` Detour string `json:"detour,omitempty"`
BindInterface string `json:"bind_interface,omitempty"` BindInterface string `json:"bind_interface,omitempty"`
BindAddress *ListenAddress `json:"bind_address,omitempty"` BindAddress *ListenAddress `json:"bind_address,omitempty"`
BindAddress6 *ListenAddress `json:"bind_address6,omitempty"`
ProtectPath string `json:"protect_path,omitempty"` ProtectPath string `json:"protect_path,omitempty"`
RoutingMark int `json:"routing_mark,omitempty"` RoutingMark int `json:"routing_mark,omitempty"`
ReuseAddr bool `json:"reuse_addr,omitempty"` ReuseAddr bool `json:"reuse_addr,omitempty"`

View file

@ -1,21 +1,21 @@
package option package option
type TunInboundOptions struct { type TunInboundOptions struct {
InterfaceName string `json:"interface_name,omitempty"` InterfaceName string `json:"interface_name,omitempty"`
MTU uint32 `json:"mtu,omitempty"` MTU uint32 `json:"mtu,omitempty"`
Inet4Address *ListenPrefix `json:"inet4_address,omitempty"` Inet4Address Listable[ListenPrefix] `json:"inet4_address,omitempty"`
Inet6Address *ListenPrefix `json:"inet6_address,omitempty"` Inet6Address Listable[ListenPrefix] `json:"inet6_address,omitempty"`
AutoRoute bool `json:"auto_route,omitempty"` AutoRoute bool `json:"auto_route,omitempty"`
StrictRoute bool `json:"strict_route,omitempty"` StrictRoute bool `json:"strict_route,omitempty"`
IncludeUID Listable[uint32] `json:"include_uid,omitempty"` IncludeUID Listable[uint32] `json:"include_uid,omitempty"`
IncludeUIDRange Listable[string] `json:"include_uid_range,omitempty"` IncludeUIDRange Listable[string] `json:"include_uid_range,omitempty"`
ExcludeUID Listable[uint32] `json:"exclude_uid,omitempty"` ExcludeUID Listable[uint32] `json:"exclude_uid,omitempty"`
ExcludeUIDRange Listable[string] `json:"exclude_uid_range,omitempty"` ExcludeUIDRange Listable[string] `json:"exclude_uid_range,omitempty"`
IncludeAndroidUser Listable[int] `json:"include_android_user,omitempty"` IncludeAndroidUser Listable[int] `json:"include_android_user,omitempty"`
IncludePackage Listable[string] `json:"include_package,omitempty"` IncludePackage Listable[string] `json:"include_package,omitempty"`
ExcludePackage Listable[string] `json:"exclude_package,omitempty"` ExcludePackage Listable[string] `json:"exclude_package,omitempty"`
EndpointIndependentNat bool `json:"endpoint_independent_nat,omitempty"` EndpointIndependentNat bool `json:"endpoint_independent_nat,omitempty"`
UDPTimeout int64 `json:"udp_timeout,omitempty"` UDPTimeout int64 `json:"udp_timeout,omitempty"`
Stack string `json:"stack,omitempty"` Stack string `json:"stack,omitempty"`
InboundOptions InboundOptions
} }

View file

@ -184,9 +184,6 @@ func (p *ListenPrefix) UnmarshalJSON(bytes []byte) error {
return nil return nil
} }
func (p *ListenPrefix) Build() netip.Prefix { func (p ListenPrefix) Build() netip.Prefix {
if p == nil { return netip.Prefix(p)
return netip.Prefix{}
}
return netip.Prefix(*p)
} }

View file

@ -3,10 +3,12 @@ package option
type WireGuardOutboundOptions struct { type WireGuardOutboundOptions struct {
DialerOptions DialerOptions
ServerOptions ServerOptions
LocalAddress Listable[string] `json:"local_address"` SystemInterface bool `json:"system_interface,omitempty"`
PrivateKey string `json:"private_key"` InterfaceName string `json:"interface_name,omitempty"`
PeerPublicKey string `json:"peer_public_key"` LocalAddress Listable[ListenPrefix] `json:"local_address"`
PreSharedKey string `json:"pre_shared_key,omitempty"` PrivateKey string `json:"private_key"`
MTU uint32 `json:"mtu,omitempty"` PeerPublicKey string `json:"peer_public_key"`
Network NetworkList `json:"network,omitempty"` PreSharedKey string `json:"pre_shared_key,omitempty"`
MTU uint32 `json:"mtu,omitempty"`
Network NetworkList `json:"network,omitempty"`
} }

View file

@ -8,49 +8,30 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"os"
"strings" "strings"
"sync"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/debug" "github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
) )
var _ adapter.Outbound = (*WireGuard)(nil) var _ adapter.Outbound = (*WireGuard)(nil)
type WireGuard struct { type WireGuard struct {
myOutboundAdapter myOutboundAdapter
ctx context.Context bind *wireguard.ClientBind
serverAddr M.Socksaddr device *device.Device
dialer N.Dialer tunDevice wireguard.Device
endpoint conn.Endpoint
device *device.Device
tunDevice *wireTunDevice
connAccess sync.Mutex
conn *wireConn
} }
func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) { func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) {
@ -62,39 +43,13 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
logger: logger, logger: logger,
tag: tag, tag: tag,
}, },
ctx: ctx,
serverAddr: options.ServerOptions.Build(),
dialer: dialer.New(router, options.DialerOptions),
} }
var endpointIp netip.Addr peerAddr := options.ServerOptions.Build()
if !outbound.serverAddr.IsFqdn() { outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr)
endpointIp = outbound.serverAddr.Addr localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
} else { if len(localPrefixes) == 0 {
endpointIp = netip.AddrFrom4([4]byte{127, 0, 0, 1})
}
outbound.endpoint = conn.StdNetEndpoint(netip.AddrPortFrom(endpointIp, outbound.serverAddr.Port))
localAddress := make([]tcpip.AddressWithPrefix, len(options.LocalAddress))
if len(localAddress) == 0 {
return nil, E.New("missing local address") return nil, E.New("missing local address")
} }
for index, address := range options.LocalAddress {
if strings.Contains(address, "/") {
prefix, err := netip.ParsePrefix(address)
if err != nil {
return nil, E.Cause(err, "parse local address prefix ", address)
}
localAddress[index] = tcpip.AddressWithPrefix{
Address: tcpip.Address(prefix.Addr().AsSlice()),
PrefixLen: prefix.Bits(),
}
} else {
addr, err := netip.ParseAddr(address)
if err != nil {
return nil, E.Cause(err, "parse local address ", address)
}
localAddress[index] = tcpip.Address(addr.AsSlice()).WithPrefix()
}
}
var privateKey, peerPublicKey, preSharedKey string var privateKey, peerPublicKey, preSharedKey string
{ {
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
@ -119,13 +74,13 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
} }
ipcConf := "private_key=" + privateKey ipcConf := "private_key=" + privateKey
ipcConf += "\npublic_key=" + peerPublicKey ipcConf += "\npublic_key=" + peerPublicKey
ipcConf += "\nendpoint=" + outbound.endpoint.DstToString() ipcConf += "\nendpoint=" + peerAddr.String()
if preSharedKey != "" { if preSharedKey != "" {
ipcConf += "\npreshared_key=" + preSharedKey ipcConf += "\npreshared_key=" + preSharedKey
} }
var has4, has6 bool var has4, has6 bool
for _, address := range localAddress { for _, address := range localPrefixes {
if address.Address.To4() != "" { if address.Addr().Is4() {
has4 = true has4 = true
} else { } else {
has6 = true has6 = true
@ -141,11 +96,17 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
if mtu == 0 { if mtu == 0 {
mtu = 1408 mtu = 1408
} }
wireDevice, err := newWireDevice(localAddress, mtu) var wireTunDevice wireguard.Device
if err != nil { var err error
return nil, err if !options.SystemInterface {
wireTunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu)
} else {
wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu)
} }
wgDevice := device.NewDevice(wireDevice, (*wireClientBind)(outbound), &device.Logger{ if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
wgDevice := device.NewDevice(wireTunDevice, outbound.bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) { Verbosef: func(format string, args ...interface{}) {
logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
}, },
@ -161,7 +122,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
return nil, E.Cause(err, "setup wireguard") return nil, E.Cause(err, "setup wireguard")
} }
outbound.device = wgDevice outbound.device = wgDevice
outbound.tunDevice = wireDevice outbound.tunDevice = wireTunDevice
return outbound, nil return outbound, nil
} }
@ -172,54 +133,19 @@ func (w *WireGuard) DialContext(ctx context.Context, network string, destination
case N.NetworkUDP: case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
} }
addr := tcpip.FullAddress{
NIC: defaultNIC,
Port: destination.Port,
}
if destination.IsFqdn() { if destination.IsFqdn() {
addrs, err := w.router.LookupDefault(ctx, destination.Fqdn) addrs, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
addr.Addr = tcpip.Address(addrs[0].AsSlice()) return N.DialSerial(ctx, w.tunDevice, network, destination, addrs)
} else {
addr.Addr = tcpip.Address(destination.Addr.AsSlice())
}
bind := tcpip.FullAddress{
NIC: defaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.tunDevice.addr4
} else {
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.tunDevice.addr6
}
switch N.NetworkName(network) {
case N.NetworkTCP:
return gonet.DialTCPWithBind(ctx, w.tunDevice.stack, bind, addr, networkProtocol)
case N.NetworkUDP:
return gonet.DialUDP(w.tunDevice.stack, &bind, &addr, networkProtocol)
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
} }
return w.tunDevice.DialContext(ctx, network, destination)
} }
func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination) w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
bind := tcpip.FullAddress{ return w.tunDevice.ListenPacket(ctx, destination)
NIC: defaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() || w.tunDevice.addr6 == "" {
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.tunDevice.addr4
} else {
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.tunDevice.addr6
}
return gonet.DialUDP(w.tunDevice.stack, &bind, nil, networkProtocol)
} }
func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
@ -231,300 +157,13 @@ func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn,
} }
func (w *WireGuard) Start() error { func (w *WireGuard) Start() error {
w.tunDevice.events <- tun.EventUp return w.tunDevice.Start()
return nil
} }
func (w *WireGuard) Close() error { func (w *WireGuard) Close() error {
return common.Close( return common.Close(
common.PtrOrNil(w.tunDevice), w.tunDevice,
common.PtrOrNil(w.device), common.PtrOrNil(w.device),
common.PtrOrNil(w.conn), common.PtrOrNil(w.bind),
) )
} }
var _ conn.Bind = (*wireClientBind)(nil)
type wireClientBind WireGuard
func (c *wireClientBind) connect() (*wireConn, error) {
c.connAccess.Lock()
defer c.connAccess.Unlock()
if c.conn != nil {
select {
case <-c.conn.done:
default:
return c.conn, nil
}
}
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
if err != nil {
return nil, &wireError{err}
}
c.conn = &wireConn{
Conn: udpConn,
done: make(chan struct{}),
}
return c.conn, nil
}
func (c *wireClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return []conn.ReceiveFunc{c.receive}, 0, nil
}
func (c *wireClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
udpConn, err := c.connect()
if err != nil {
err = &wireError{err}
return
}
n, err = udpConn.Read(b)
if err != nil {
udpConn.Close()
err = &wireError{err}
}
ep = c.endpoint
return
}
func (c *wireClientBind) Close() error {
c.connAccess.Lock()
defer c.connAccess.Unlock()
common.Close(common.PtrOrNil(c.conn))
return nil
}
func (c *wireClientBind) SetMark(mark uint32) error {
return nil
}
func (c *wireClientBind) Send(b []byte, ep conn.Endpoint) error {
udpConn, err := c.connect()
if err != nil {
return err
}
_, err = udpConn.Write(b)
if err != nil {
udpConn.Close()
}
return err
}
func (c *wireClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return c.endpoint, nil
}
type wireError struct {
cause error
}
func (w *wireError) Error() string {
return w.cause.Error()
}
func (w *wireError) Timeout() bool {
if cause, causeNet := w.cause.(net.Error); causeNet {
return cause.Timeout()
}
return false
}
func (w *wireError) Temporary() bool {
return true
}
type wireConn struct {
net.Conn
access sync.Mutex
done chan struct{}
}
func (w *wireConn) Close() error {
w.access.Lock()
defer w.access.Unlock()
select {
case <-w.done:
return net.ErrClosed
default:
}
w.Conn.Close()
close(w.done)
return nil
}
var _ tun.Device = (*wireTunDevice)(nil)
const defaultNIC tcpip.NICID = 1
type wireTunDevice struct {
stack *stack.Stack
mtu uint32
events chan tun.Event
outbound chan *stack.PacketBuffer
dispatcher stack.NetworkDispatcher
done chan struct{}
addr4 tcpip.Address
addr6 tcpip.Address
}
func newWireDevice(localAddresses []tcpip.AddressWithPrefix, mtu uint32) (*wireTunDevice, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &wireTunDevice{
stack: ipStack,
mtu: mtu,
events: make(chan tun.Event, 4),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
}
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
if err != nil {
return nil, E.New(err.String())
}
for _, addr := range localAddresses {
var protoAddr tcpip.ProtocolAddress
if len(addr.Address) == net.IPv4len {
tunDevice.addr4 = addr.Address
protoAddr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: addr,
}
} else {
tunDevice.addr6 = addr.Address
protoAddr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: addr,
}
}
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
if err != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
}
}
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
cOpt := tcpip.CongestionControlOption("cubic")
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
return tunDevice, nil
}
func (w *wireTunDevice) File() *os.File {
return nil
}
func (w *wireTunDevice) Read(p []byte, offset int) (n int, err error) {
packetBuffer, ok := <-w.outbound
if !ok {
return 0, os.ErrClosed
}
defer packetBuffer.DecRef()
p = p[offset:]
for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice)
}
return
}
func (w *wireTunDevice) Write(p []byte, offset int) (n int, err error) {
p = p[offset:]
if len(p) == 0 {
return
}
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(p) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(p),
})
defer packetBuffer.DecRef()
w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
n = len(p)
return
}
func (w *wireTunDevice) Flush() error {
return nil
}
func (w *wireTunDevice) MTU() (int, error) {
return int(w.mtu), nil
}
func (w *wireTunDevice) Name() (string, error) {
return "sing-box", nil
}
func (w *wireTunDevice) Events() chan tun.Event {
return w.events
}
func (w *wireTunDevice) Close() error {
select {
case <-w.done:
return os.ErrClosed
default:
}
close(w.done)
w.stack.Close()
for _, endpoint := range w.stack.CleanupEndpoints() {
endpoint.Abort()
}
w.stack.Wait()
close(w.outbound)
return nil
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint wireTunDevice
func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu
}
func (ep *wireEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
ep.dispatcher = dispatcher
}
func (ep *wireEndpoint) IsAttached() bool {
return ep.dispatcher != nil
}
func (ep *wireEndpoint) Wait() {
}
func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
packetBuffer.IncRef()
ep.outbound <- packetBuffer
}
return list.Len(), nil
}

View file

@ -1,6 +1,6 @@
[Interface] [Interface]
PrivateKey = gHWUGzTh5YCEV6k8dneVP537XhVtoQJPIlFNs2zsxlE= PrivateKey = gHWUGzTh5YCEV6k8dneVP537XhVtoQJPIlFNs2zsxlE=
Address = 10.0.0.1/24 Address = 10.0.0.1/32
ListenPort = 10000 ListenPort = 10000
[Peer] [Peer]

View file

@ -2,6 +2,7 @@ package main
import ( import (
"net/netip" "net/netip"
"os"
"testing" "testing"
"time" "time"
@ -40,7 +41,7 @@ func TestWireGuard(t *testing.T) {
Server: "127.0.0.1", Server: "127.0.0.1",
ServerPort: serverPort, ServerPort: serverPort,
}, },
LocalAddress: []string{"10.0.0.2/32"}, LocalAddress: []option.ListenPrefix{option.ListenPrefix(netip.MustParsePrefix("10.0.0.2/32"))},
PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=", PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=", PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
}, },
@ -49,3 +50,49 @@ func TestWireGuard(t *testing.T) {
}) })
testSuitWg(t, clientPort, testPort) testSuitWg(t, clientPort, testPort)
} }
func TestWireGuardSystem(t *testing.T) {
if os.Getuid() != 0 {
t.Skip("requires root")
}
startDockerContainer(t, DockerOptions{
Image: ImageBoringTun,
Cap: []string{"MKNOD", "NET_ADMIN", "NET_RAW"},
Ports: []uint16{serverPort, testPort},
Bind: map[string]string{
"wireguard.conf": "/etc/wireguard/wg0.conf",
},
Cmd: []string{"wg0"},
})
time.Sleep(5 * time.Second)
startInstance(t, option.Options{
Inbounds: []option.Inbound{
{
Type: C.TypeMixed,
MixedOptions: option.HTTPMixedInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.ListenAddress(netip.IPv4Unspecified()),
ListenPort: clientPort,
},
},
},
},
Outbounds: []option.Outbound{
{
Type: C.TypeWireGuard,
WireGuardOptions: option.WireGuardOutboundOptions{
InterfaceName: "wg",
ServerOptions: option.ServerOptions{
Server: "127.0.0.1",
ServerPort: serverPort,
},
LocalAddress: []option.ListenPrefix{option.ListenPrefix(netip.MustParsePrefix("10.0.0.2/32"))},
PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
},
},
},
})
time.Sleep(10 * time.Second)
testSuitWg(t, clientPort, testPort)
}

View file

@ -0,0 +1,132 @@
package wireguard
import (
"context"
"net"
"sync"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.zx2c4.com/wireguard/conn"
)
var _ conn.Bind = (*ClientBind)(nil)
type ClientBind struct {
ctx context.Context
dialer N.Dialer
peerAddr M.Socksaddr
connAccess sync.Mutex
conn *wireConn
}
func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr) *ClientBind {
return &ClientBind{
ctx: ctx,
dialer: dialer,
peerAddr: peerAddr,
}
}
func (c *ClientBind) connect() (*wireConn, error) {
serverConn := c.conn
if serverConn != nil {
select {
case <-serverConn.done:
serverConn = nil
default:
return serverConn, nil
}
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
serverConn = c.conn
if serverConn != nil {
select {
case <-serverConn.done:
serverConn = nil
default:
return serverConn, nil
}
}
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr)
if err != nil {
return nil, &wireError{err}
}
c.conn = &wireConn{
Conn: udpConn,
done: make(chan struct{}),
}
return c.conn, nil
}
func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return []conn.ReceiveFunc{c.receive}, 0, nil
}
func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
udpConn, err := c.connect()
if err != nil {
err = &wireError{err}
return
}
n, err = udpConn.Read(b)
if err != nil {
udpConn.Close()
err = &wireError{err}
}
ep = Endpoint(c.peerAddr)
return
}
func (c *ClientBind) Close() error {
c.connAccess.Lock()
defer c.connAccess.Unlock()
common.Close(common.PtrOrNil(c.conn))
return nil
}
func (c *ClientBind) SetMark(mark uint32) error {
return nil
}
func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
udpConn, err := c.connect()
if err != nil {
return err
}
_, err = udpConn.Write(b)
if err != nil {
udpConn.Close()
}
return err
}
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return Endpoint(c.peerAddr), nil
}
func (c *ClientBind) Endpoint() conn.Endpoint {
return Endpoint(c.peerAddr)
}
type wireConn struct {
net.Conn
access sync.Mutex
done chan struct{}
}
func (w *wireConn) Close() error {
w.access.Lock()
defer w.access.Unlock()
select {
case <-w.done:
return net.ErrClosed
default:
}
w.Conn.Close()
close(w.done)
return nil
}

View file

@ -0,0 +1,14 @@
package wireguard
import (
N "github.com/sagernet/sing/common/network"
"golang.zx2c4.com/wireguard/tun"
)
type Device interface {
tun.Device
N.Dialer
Start() error
// NewEndpoint() (stack.LinkEndpoint, error)
}

View file

@ -0,0 +1,254 @@
//go:build !no_gvisor
package wireguard
import (
"context"
"net"
"net/netip"
"os"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
var _ Device = (*StackDevice)(nil)
const defaultNIC tcpip.NICID = 1
type StackDevice struct {
stack *stack.Stack
mtu uint32
events chan tun.Event
outbound chan *stack.PacketBuffer
dispatcher stack.NetworkDispatcher
done chan struct{}
addr4 tcpip.Address
addr6 tcpip.Address
}
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
events: make(chan tun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
}
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
if err != nil {
return nil, E.New(err.String())
}
for _, prefix := range localAddresses {
addr := tcpip.Address(prefix.Addr().AsSlice())
protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: prefix.Bits(),
},
}
if prefix.Addr().Is4() {
tunDevice.addr4 = addr
protoAddr.Protocol = ipv4.ProtocolNumber
} else {
tunDevice.addr6 = addr
protoAddr.Protocol = ipv6.ProtocolNumber
}
err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
if err != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
}
}
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
cOpt := tcpip.CongestionControlOption("cubic")
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
return tunDevice, nil
}
func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
return (*wireEndpoint)(w), nil
}
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
addr := tcpip.FullAddress{
NIC: defaultNIC,
Port: destination.Port,
Addr: tcpip.Address(destination.Addr.AsSlice()),
}
bind := tcpip.FullAddress{
NIC: defaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4
} else {
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6
}
switch N.NetworkName(network) {
case N.NetworkTCP:
return gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
case N.NetworkUDP:
return gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
}
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
bind := tcpip.FullAddress{
NIC: defaultNIC,
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() || w.addr6 == "" {
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4
} else {
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6
}
return gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
}
func (w *StackDevice) Start() error {
w.events <- tun.EventUp
return nil
}
func (w *StackDevice) File() *os.File {
return nil
}
func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
packetBuffer, ok := <-w.outbound
if !ok {
return 0, os.ErrClosed
}
defer packetBuffer.DecRef()
p = p[offset:]
for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice)
}
return
}
func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
p = p[offset:]
if len(p) == 0 {
return
}
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(p) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(p),
})
defer packetBuffer.DecRef()
w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
n = len(p)
return
}
func (w *StackDevice) Flush() error {
return nil
}
func (w *StackDevice) MTU() (int, error) {
return int(w.mtu), nil
}
func (w *StackDevice) Name() (string, error) {
return "sing-box", nil
}
func (w *StackDevice) Events() chan tun.Event {
return w.events
}
func (w *StackDevice) Close() error {
select {
case <-w.done:
return os.ErrClosed
default:
}
close(w.done)
w.stack.Close()
for _, endpoint := range w.stack.CleanupEndpoints() {
endpoint.Abort()
}
w.stack.Wait()
close(w.outbound)
return nil
}
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
type wireEndpoint StackDevice
func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu
}
func (ep *wireEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
ep.dispatcher = dispatcher
}
func (ep *wireEndpoint) IsAttached() bool {
return ep.dispatcher != nil
}
func (ep *wireEndpoint) Wait() {
}
func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
packetBuffer.IncRef()
ep.outbound <- packetBuffer
}
return list.Len(), nil
}

View file

@ -0,0 +1,9 @@
//go:build no_gvisor
package wireguard
import "github.com/sagernet/sing-tun"
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
return nil, tun.ErrGVisorNotIncluded
}

View file

@ -0,0 +1,110 @@
package wireguard
import (
"context"
"net"
"net/netip"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
wgTun "golang.zx2c4.com/wireguard/tun"
)
var _ Device = (*SystemDevice)(nil)
type SystemDevice struct {
dialer N.Dialer
device tun.Tun
name string
mtu int
events chan wgTun.Event
}
/*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) {
gTun, isGTun := w.device.(tun.GVisorTun)
if !isGTun {
return nil, tun.ErrGVisorUnsupported
}
return gTun.NewEndpoint()
}*/
func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) {
var inet4Addresses []netip.Prefix
var inet6Addresses []netip.Prefix
for _, prefixes := range localPrefixes {
if prefixes.Addr().Is4() {
inet4Addresses = append(inet4Addresses, prefixes)
} else {
inet6Addresses = append(inet6Addresses, prefixes)
}
}
if interfaceName == "" {
interfaceName = tun.CalculateInterfaceName("wg")
}
tunInterface, err := tun.Open(tun.Options{
Name: interfaceName,
Inet4Address: inet4Addresses,
Inet6Address: inet6Addresses,
MTU: mtu,
})
if err != nil {
return nil, err
}
return &SystemDevice{
dialer.NewDefault(router, option.DialerOptions{
BindInterface: interfaceName,
}),
tunInterface, interfaceName, int(mtu), make(chan wgTun.Event),
}, nil
}
func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return w.dialer.DialContext(ctx, network, destination)
}
func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return w.dialer.ListenPacket(ctx, destination)
}
func (w *SystemDevice) Start() error {
w.events <- wgTun.EventUp
return nil
}
func (w *SystemDevice) File() *os.File {
return nil
}
func (w *SystemDevice) Read(bytes []byte, index int) (int, error) {
return w.device.Read(bytes[index-tunPacketOffset:])
}
func (w *SystemDevice) Write(bytes []byte, index int) (int, error) {
return w.device.Write(bytes[index:])
}
func (w *SystemDevice) Flush() error {
return nil
}
func (w *SystemDevice) MTU() (int, error) {
return w.mtu, nil
}
func (w *SystemDevice) Name() (string, error) {
return w.name, nil
}
func (w *SystemDevice) Events() chan wgTun.Event {
return w.events
}
func (w *SystemDevice) Close() error {
return w.device.Close()
}

View file

@ -0,0 +1,37 @@
package wireguard
import (
"net/netip"
M "github.com/sagernet/sing/common/metadata"
"golang.zx2c4.com/wireguard/conn"
)
var _ conn.Endpoint = (*Endpoint)(nil)
type Endpoint M.Socksaddr
func (e Endpoint) ClearSrc() {
}
func (e Endpoint) SrcToString() string {
return ""
}
func (e Endpoint) DstToString() string {
return (M.Socksaddr)(e).String()
}
func (e Endpoint) DstToBytes() []byte {
b, _ := (M.Socksaddr)(e).AddrPort().MarshalBinary()
return b
}
func (e Endpoint) DstIP() netip.Addr {
return (M.Socksaddr)(e).Addr
}
func (e Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
}

View file

@ -0,0 +1,22 @@
package wireguard
import "net"
type wireError struct {
cause error
}
func (w *wireError) Error() string {
return w.cause.Error()
}
func (w *wireError) Timeout() bool {
if cause, causeNet := w.cause.(net.Error); causeNet {
return cause.Timeout()
}
return false
}
func (w *wireError) Temporary() bool {
return true
}

View file

@ -0,0 +1,96 @@
package wireguard
import (
"io"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.zx2c4.com/wireguard/conn"
)
var _ conn.Bind = (*ServerBind)(nil)
type ServerBind struct {
inbound chan serverPacket
done chan struct{}
writeBack N.PacketWriter
}
func NewServerBind(writeBack N.PacketWriter) *ServerBind {
return &ServerBind{
inbound: make(chan serverPacket, 256),
done: make(chan struct{}),
writeBack: writeBack,
}
}
func (s *ServerBind) Abort() error {
select {
case <-s.done:
return io.ErrClosedPipe
default:
close(s.done)
}
return nil
}
type serverPacket struct {
buffer *buf.Buffer
source M.Socksaddr
}
func (s *ServerBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
fns = []conn.ReceiveFunc{s.receive}
return
}
func (s *ServerBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
select {
case packet := <-s.inbound:
defer packet.buffer.Release()
n = copy(b, packet.buffer.Bytes())
ep = Endpoint(packet.source)
return
case <-s.done:
err = io.ErrClosedPipe
return
}
}
func (s *ServerBind) WriteIsThreadUnsafe() {
}
func (s *ServerBind) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
select {
case s.inbound <- serverPacket{
buffer: buffer,
source: destination,
}:
return nil
case <-s.done:
return io.ErrClosedPipe
}
}
func (s *ServerBind) Close() error {
return nil
}
func (s *ServerBind) SetMark(mark uint32) error {
return nil
}
func (s *ServerBind) Send(b []byte, ep conn.Endpoint) error {
return s.writeBack.WritePacket(buf.As(b), M.Socksaddr(ep.(Endpoint)))
}
func (s *ServerBind) ParseEndpoint(addr string) (conn.Endpoint, error) {
destination := M.ParseSocksaddr(addr)
if !destination.IsValid() || destination.Port == 0 {
return nil, E.New("invalid endpoint: ", addr)
}
return Endpoint(destination), nil
}

View file

@ -0,0 +1,3 @@
package wireguard
const tunPacketOffset = 4

View file

@ -0,0 +1,5 @@
//go:build !darwin
package wireguard
const tunPacketOffset = 0