Add multi-peer support for wireguard outbound

This commit is contained in:
世界 2023-03-31 12:31:26 +08:00
parent b484d9bca6
commit 6f2cc9761d
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 194 additions and 77 deletions

View file

@ -13,6 +13,18 @@
"10.0.0.2/32" "10.0.0.2/32"
], ],
"private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=", "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
"peers": [
{
"server": "127.0.0.1",
"server_port": 1080,
"public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=",
"pre_shared_key": "31aIhAPwktDGpH4JDhA8GNvjFXEf/a6+UaQRyOAiyfM=",
"allowed_ips": [
"0.0.0.0/0"
],
"reserved": [0, 0, 0]
}
],
"peer_public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=", "peer_public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=",
"pre_shared_key": "31aIhAPwktDGpH4JDhA8GNvjFXEf/a6+UaQRyOAiyfM=", "pre_shared_key": "31aIhAPwktDGpH4JDhA8GNvjFXEf/a6+UaQRyOAiyfM=",
"reserved": [0, 0, 0], "reserved": [0, 0, 0],
@ -36,13 +48,13 @@
#### server #### server
==Required== ==Required if multi-peer disabled==
The server address. The server address.
#### server_port #### server_port
==Required== ==Required if multi-peer disabled==
The server port. The server port.
@ -75,9 +87,25 @@ wg genkey
echo "private key" || wg pubkey echo "private key" || wg pubkey
``` ```
#### peers
Multi-peer support.
If enabled, `server, server_port, peer_public_key, pre_shared_key` will be ignored.
#### peers.allowed_ips
WireGuard allowed IPs.
#### peers.reserved
WireGuard reserved field bytes.
`$outbound.reserved` will be used if empty.
#### peer_public_key #### peer_public_key
==Required== ==Required if multi-peer disabled==
WireGuard peer public key. WireGuard peer public key.

View file

@ -2,11 +2,12 @@ package option
type WireGuardOutboundOptions struct { type WireGuardOutboundOptions struct {
DialerOptions DialerOptions
ServerOptions
SystemInterface bool `json:"system_interface,omitempty"` SystemInterface bool `json:"system_interface,omitempty"`
InterfaceName string `json:"interface_name,omitempty"` InterfaceName string `json:"interface_name,omitempty"`
LocalAddress Listable[ListenPrefix] `json:"local_address"` LocalAddress Listable[ListenPrefix] `json:"local_address"`
PrivateKey string `json:"private_key"` PrivateKey string `json:"private_key"`
Peers []WireGuardPeer `json:"peers,omitempty"`
ServerOptions
PeerPublicKey string `json:"peer_public_key"` PeerPublicKey string `json:"peer_public_key"`
PreSharedKey string `json:"pre_shared_key,omitempty"` PreSharedKey string `json:"pre_shared_key,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"` Reserved []uint8 `json:"reserved,omitempty"`
@ -15,3 +16,11 @@ type WireGuardOutboundOptions struct {
Network NetworkList `json:"network,omitempty"` Network NetworkList `json:"network,omitempty"`
IPRewrite bool `json:"ip_rewrite,omitempty"` IPRewrite bool `json:"ip_rewrite,omitempty"`
} }
type WireGuardPeer struct {
ServerOptions
PublicKey string `json:"public_key,omitempty"`
PreSharedKey string `json:"pre_shared_key,omitempty"`
AllowedIPs Listable[string] `json:"allowed_ips,omitempty"`
Reserved []uint8 `json:"reserved,omitempty"`
}

View file

@ -57,13 +57,22 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
} }
copy(reserved[:], options.Reserved) copy(reserved[:], options.Reserved)
} }
peerAddr := options.ServerOptions.Build() var isConnect bool
outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr, reserved) var connectAddr M.Socksaddr
if len(options.Peers) < 2 {
isConnect = true
if len(options.Peers) == 1 {
connectAddr = options.Peers[0].ServerOptions.Build()
} else {
connectAddr = options.ServerOptions.Build()
}
}
outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved)
localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build) localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
if len(localPrefixes) == 0 { if len(localPrefixes) == 0 {
return nil, E.New("missing local address") return nil, E.New("missing local address")
} }
var privateKey, peerPublicKey, preSharedKey string var privateKey string
{ {
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
if err != nil { if err != nil {
@ -71,6 +80,46 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
} }
privateKey = hex.EncodeToString(bytes) privateKey = hex.EncodeToString(bytes)
} }
ipcConf := "private_key=" + privateKey
if len(options.Peers) > 0 {
for i, peer := range options.Peers {
var peerPublicKey, preSharedKey string
{
bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", i)
}
peerPublicKey = hex.EncodeToString(bytes)
}
if peer.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", i)
}
preSharedKey = hex.EncodeToString(bytes)
}
destination := peer.ServerOptions.Build()
ipcConf += "\npublic_key=" + peerPublicKey
ipcConf += "\nendpoint=" + destination.String()
if preSharedKey != "" {
ipcConf += "\npreshared_key=" + preSharedKey
}
if len(peer.AllowedIPs) == 0 {
return nil, E.New("missing allowed_ips for peer ", i)
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "\nallowed_ip=" + allowedIP
}
if len(peer.Reserved) > 0 {
if len(peer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved))
}
copy(reserved[:], options.Reserved)
outbound.bind.SetReservedForEndpoint(destination, reserved)
}
}
} else {
var peerPublicKey, preSharedKey string
{ {
bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey) bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
if err != nil { if err != nil {
@ -85,9 +134,8 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
} }
preSharedKey = hex.EncodeToString(bytes) preSharedKey = hex.EncodeToString(bytes)
} }
ipcConf := "private_key=" + privateKey
ipcConf += "\npublic_key=" + peerPublicKey ipcConf += "\npublic_key=" + peerPublicKey
ipcConf += "\nendpoint=" + peerAddr.String() ipcConf += "\nendpoint=" + options.ServerOptions.Build().String()
if preSharedKey != "" { if preSharedKey != "" {
ipcConf += "\npreshared_key=" + preSharedKey ipcConf += "\npreshared_key=" + preSharedKey
} }
@ -105,6 +153,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
if has6 { if has6 {
ipcConf += "\nallowed_ip=::/0" ipcConf += "\nallowed_ip=::/0"
} }
}
mtu := options.MTU mtu := options.MTU
if mtu == 0 { if mtu == 0 {
mtu = 1408 mtu = 1408

View file

@ -3,9 +3,12 @@ package wireguard
import ( import (
"context" "context"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/conn" "github.com/sagernet/wireguard-go/conn"
@ -16,22 +19,30 @@ var _ conn.Bind = (*ClientBind)(nil)
type ClientBind struct { type ClientBind struct {
ctx context.Context ctx context.Context
dialer N.Dialer dialer N.Dialer
peerAddr M.Socksaddr reservedForEndpoint map[M.Socksaddr][3]uint8
reserved [3]uint8
connAccess sync.Mutex connAccess sync.Mutex
conn *wireConn conn *wireConn
done chan struct{} done chan struct{}
isConnect bool
connectAddr M.Socksaddr
reserved [3]uint8
} }
func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind { func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
return &ClientBind{ return &ClientBind{
ctx: ctx, ctx: ctx,
dialer: dialer, dialer: dialer,
peerAddr: peerAddr, reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
isConnect: isConnect,
connectAddr: connectAddr,
reserved: reserved, reserved: reserved,
} }
} }
func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
c.reservedForEndpoint[destination] = reserved
}
func (c *ClientBind) connect() (*wireConn, error) { func (c *ClientBind) connect() (*wireConn, error) {
serverConn := c.conn serverConn := c.conn
if serverConn != nil { if serverConn != nil {
@ -53,14 +64,28 @@ func (c *ClientBind) connect() (*wireConn, error) {
return serverConn, nil return serverConn, nil
} }
} }
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr) if c.isConnect {
udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
if err != nil { if err != nil {
return nil, &wireError{err} return nil, &wireError{err}
} }
c.conn = &wireConn{ c.conn = &wireConn{
Conn: udpConn, NetPacketConn: &bufio.UnbindPacketConn{
ExtendedConn: bufio.NewExtendedConn(udpConn),
Addr: c.connectAddr,
},
done: make(chan struct{}), done: make(chan struct{}),
} }
} else {
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
if err != nil {
return nil, &wireError{err}
}
c.conn = &wireConn{
NetPacketConn: bufio.NewPacketConn(udpConn),
done: make(chan struct{}),
}
}
return c.conn, nil return c.conn, nil
} }
@ -80,7 +105,8 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
err = &wireError{err} err = &wireError{err}
return return
} }
n, err = udpConn.Read(b) buffer := buf.With(b)
destination, err := udpConn.ReadPacket(buffer)
if err != nil { if err != nil {
udpConn.Close() udpConn.Close()
select { select {
@ -90,12 +116,16 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
} }
return return
} }
n = buffer.Len()
if buffer.Start() > 0 {
copy(b, buffer.Bytes())
}
if n > 3 { if n > 3 {
b[1] = 0 b[1] = 0
b[2] = 0 b[2] = 0
b[3] = 0 b[3] = 0
} }
ep = Endpoint(c.peerAddr) ep = Endpoint(destination)
return return
} }
@ -127,12 +157,17 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
if err != nil { if err != nil {
return err return err
} }
destination := M.Socksaddr(ep.(Endpoint))
if len(b) > 3 { if len(b) > 3 {
b[1] = c.reserved[0] reserved, loaded := c.reservedForEndpoint[destination]
b[2] = c.reserved[1] if !loaded {
b[3] = c.reserved[2] reserved = c.reserved
} }
_, err = udpConn.Write(b) b[1] = reserved[0]
b[2] = reserved[1]
b[3] = reserved[2]
}
err = udpConn.WritePacket(buf.As(b), destination)
if err != nil { if err != nil {
udpConn.Close() udpConn.Close()
} }
@ -140,15 +175,11 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
} }
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return Endpoint(c.peerAddr), nil return Endpoint(M.ParseSocksaddr(s)), nil
}
func (c *ClientBind) Endpoint() conn.Endpoint {
return Endpoint(c.peerAddr)
} }
type wireConn struct { type wireConn struct {
net.Conn N.NetPacketConn
access sync.Mutex access sync.Mutex
done chan struct{} done chan struct{}
} }
@ -161,7 +192,7 @@ func (w *wireConn) Close() error {
return net.ErrClosed return net.ErrClosed
default: default:
} }
w.Conn.Close() w.NetPacketConn.Close()
close(w.done) close(w.done)
return nil return nil
} }