diff --git a/docs/configuration/outbound/wireguard.md b/docs/configuration/outbound/wireguard.md index 705d16da..f4a88108 100644 --- a/docs/configuration/outbound/wireguard.md +++ b/docs/configuration/outbound/wireguard.md @@ -13,6 +13,18 @@ "10.0.0.2/32" ], "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=", + "peers": [ + { + "server": "127.0.0.1", + "server_port": 1080, + "public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=", + "pre_shared_key": "31aIhAPwktDGpH4JDhA8GNvjFXEf/a6+UaQRyOAiyfM=", + "allowed_ips": [ + "0.0.0.0/0" + ], + "reserved": [0, 0, 0] + } + ], "peer_public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=", "pre_shared_key": "31aIhAPwktDGpH4JDhA8GNvjFXEf/a6+UaQRyOAiyfM=", "reserved": [0, 0, 0], @@ -36,13 +48,13 @@ #### server -==Required== +==Required if multi-peer disabled== The server address. #### server_port -==Required== +==Required if multi-peer disabled== The server port. @@ -75,9 +87,25 @@ wg genkey echo "private key" || wg pubkey ``` +#### peers + +Multi-peer support. + +If enabled, `server, server_port, peer_public_key, pre_shared_key` will be ignored. + +#### peers.allowed_ips + +WireGuard allowed IPs. + +#### peers.reserved + +WireGuard reserved field bytes. + +`$outbound.reserved` will be used if empty. + #### peer_public_key -==Required== +==Required if multi-peer disabled== WireGuard peer public key. diff --git a/option/wireguard.go b/option/wireguard.go index ee6e1053..21b6715b 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -2,15 +2,24 @@ package option type WireGuardOutboundOptions struct { DialerOptions - ServerOptions SystemInterface bool `json:"system_interface,omitempty"` InterfaceName string `json:"interface_name,omitempty"` LocalAddress Listable[ListenPrefix] `json:"local_address"` PrivateKey string `json:"private_key"` - PeerPublicKey string `json:"peer_public_key"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` - Workers int `json:"workers,omitempty"` - MTU uint32 `json:"mtu,omitempty"` - Network NetworkList `json:"network,omitempty"` + Peers []WireGuardPeer `json:"peers,omitempty"` + ServerOptions + PeerPublicKey string `json:"peer_public_key"` + PreSharedKey string `json:"pre_shared_key,omitempty"` + Reserved []uint8 `json:"reserved,omitempty"` + Workers int `json:"workers,omitempty"` + MTU uint32 `json:"mtu,omitempty"` + Network NetworkList `json:"network,omitempty"` +} + +type WireGuardPeer struct { + ServerOptions + PublicKey string `json:"public_key,omitempty"` + PreSharedKey string `json:"pre_shared_key,omitempty"` + AllowedIPs Listable[string] `json:"allowed_ips,omitempty"` + Reserved []uint8 `json:"reserved,omitempty"` } diff --git a/outbound/wireguard.go b/outbound/wireguard.go index cdba9812..9ca3a16a 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -54,13 +54,22 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context } copy(reserved[:], options.Reserved) } - peerAddr := options.ServerOptions.Build() - outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr, reserved) + var isConnect bool + var connectAddr M.Socksaddr + if len(options.Peers) < 2 { + isConnect = true + if len(options.Peers) == 1 { + connectAddr = options.Peers[0].ServerOptions.Build() + } else { + connectAddr = options.ServerOptions.Build() + } + } + outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved) localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build) if len(localPrefixes) == 0 { return nil, E.New("missing local address") } - var privateKey, peerPublicKey, preSharedKey string + var privateKey string { bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) if err != nil { @@ -68,39 +77,79 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context } privateKey = hex.EncodeToString(bytes) } - { - bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey) - if err != nil { - return nil, E.Cause(err, "decode peer public key") - } - peerPublicKey = hex.EncodeToString(bytes) - } - if options.PreSharedKey != "" { - bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey) - if err != nil { - return nil, E.Cause(err, "decode pre shared key") - } - preSharedKey = hex.EncodeToString(bytes) - } ipcConf := "private_key=" + privateKey - ipcConf += "\npublic_key=" + peerPublicKey - ipcConf += "\nendpoint=" + peerAddr.String() - if preSharedKey != "" { - ipcConf += "\npreshared_key=" + preSharedKey - } - var has4, has6 bool - for _, address := range localPrefixes { - if address.Addr().Is4() { - has4 = true - } else { - has6 = true + if len(options.Peers) > 0 { + for i, peer := range options.Peers { + var peerPublicKey, preSharedKey string + { + bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey) + if err != nil { + return nil, E.Cause(err, "decode public key for peer ", i) + } + peerPublicKey = hex.EncodeToString(bytes) + } + if peer.PreSharedKey != "" { + bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey) + if err != nil { + return nil, E.Cause(err, "decode pre shared key for peer ", i) + } + preSharedKey = hex.EncodeToString(bytes) + } + destination := peer.ServerOptions.Build() + ipcConf += "\npublic_key=" + peerPublicKey + ipcConf += "\nendpoint=" + destination.String() + if preSharedKey != "" { + ipcConf += "\npreshared_key=" + preSharedKey + } + if len(peer.AllowedIPs) == 0 { + return nil, E.New("missing allowed_ips for peer ", i) + } + for _, allowedIP := range peer.AllowedIPs { + ipcConf += "\nallowed_ip=" + allowedIP + } + if len(peer.Reserved) > 0 { + if len(peer.Reserved) != 3 { + return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved)) + } + copy(reserved[:], options.Reserved) + outbound.bind.SetReservedForEndpoint(destination, reserved) + } + } + } else { + var peerPublicKey, preSharedKey string + { + bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey) + if err != nil { + return nil, E.Cause(err, "decode peer public key") + } + peerPublicKey = hex.EncodeToString(bytes) + } + if options.PreSharedKey != "" { + bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey) + if err != nil { + return nil, E.Cause(err, "decode pre shared key") + } + preSharedKey = hex.EncodeToString(bytes) + } + ipcConf += "\npublic_key=" + peerPublicKey + ipcConf += "\nendpoint=" + options.ServerOptions.Build().String() + if preSharedKey != "" { + ipcConf += "\npreshared_key=" + preSharedKey + } + var has4, has6 bool + for _, address := range localPrefixes { + if address.Addr().Is4() { + has4 = true + } else { + has6 = true + } + } + if has4 { + ipcConf += "\nallowed_ip=0.0.0.0/0" + } + if has6 { + ipcConf += "\nallowed_ip=::/0" } - } - if has4 { - ipcConf += "\nallowed_ip=0.0.0.0/0" - } - if has6 { - ipcConf += "\nallowed_ip=::/0" } mtu := options.MTU if mtu == 0 { diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 570b2831..26fe9967 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -3,9 +3,12 @@ package wireguard import ( "context" "net" + "net/netip" "sync" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/wireguard-go/conn" @@ -14,24 +17,32 @@ import ( var _ conn.Bind = (*ClientBind)(nil) type ClientBind struct { - ctx context.Context - dialer N.Dialer - peerAddr M.Socksaddr - reserved [3]uint8 - connAccess sync.Mutex - conn *wireConn - done chan struct{} + ctx context.Context + dialer N.Dialer + reservedForEndpoint map[M.Socksaddr][3]uint8 + connAccess sync.Mutex + conn *wireConn + done chan struct{} + isConnect bool + connectAddr M.Socksaddr + reserved [3]uint8 } -func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind { +func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind { return &ClientBind{ - ctx: ctx, - dialer: dialer, - peerAddr: peerAddr, - reserved: reserved, + ctx: ctx, + dialer: dialer, + reservedForEndpoint: make(map[M.Socksaddr][3]uint8), + isConnect: isConnect, + connectAddr: connectAddr, + reserved: reserved, } } +func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) { + c.reservedForEndpoint[destination] = reserved +} + func (c *ClientBind) connect() (*wireConn, error) { serverConn := c.conn if serverConn != nil { @@ -53,13 +64,27 @@ func (c *ClientBind) connect() (*wireConn, error) { return serverConn, nil } } - udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr) - if err != nil { - return nil, &wireError{err} - } - c.conn = &wireConn{ - Conn: udpConn, - done: make(chan struct{}), + if c.isConnect { + udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr) + if err != nil { + return nil, &wireError{err} + } + c.conn = &wireConn{ + NetPacketConn: &bufio.UnbindPacketConn{ + ExtendedConn: bufio.NewExtendedConn(udpConn), + Addr: c.connectAddr, + }, + done: make(chan struct{}), + } + } else { + udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) + if err != nil { + return nil, &wireError{err} + } + c.conn = &wireConn{ + NetPacketConn: bufio.NewPacketConn(udpConn), + done: make(chan struct{}), + } } return c.conn, nil } @@ -80,7 +105,8 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { err = &wireError{err} return } - n, err = udpConn.Read(b) + buffer := buf.With(b) + destination, err := udpConn.ReadPacket(buffer) if err != nil { udpConn.Close() select { @@ -90,12 +116,16 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { } return } + n = buffer.Len() + if buffer.Start() > 0 { + copy(b, buffer.Bytes()) + } if n > 3 { b[1] = 0 b[2] = 0 b[3] = 0 } - ep = Endpoint(c.peerAddr) + ep = Endpoint(destination) return } @@ -127,12 +157,17 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error { if err != nil { return err } + destination := M.Socksaddr(ep.(Endpoint)) if len(b) > 3 { - b[1] = c.reserved[0] - b[2] = c.reserved[1] - b[3] = c.reserved[2] + reserved, loaded := c.reservedForEndpoint[destination] + if !loaded { + reserved = c.reserved + } + b[1] = reserved[0] + b[2] = reserved[1] + b[3] = reserved[2] } - _, err = udpConn.Write(b) + err = udpConn.WritePacket(buf.As(b), destination) if err != nil { udpConn.Close() } @@ -140,15 +175,11 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error { } func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return Endpoint(c.peerAddr), nil -} - -func (c *ClientBind) Endpoint() conn.Endpoint { - return Endpoint(c.peerAddr) + return Endpoint(M.ParseSocksaddr(s)), nil } type wireConn struct { - net.Conn + N.NetPacketConn access sync.Mutex done chan struct{} } @@ -161,7 +192,7 @@ func (w *wireConn) Close() error { return net.ErrClosed default: } - w.Conn.Close() + w.NetPacketConn.Close() close(w.done) return nil }