Support matching ip interface id for ip_cidr rule

This commit is contained in:
blip 2024-10-21 22:14:09 +00:00
parent 65c424f7eb
commit 0586dc4317

View file

@ -1,6 +1,8 @@
package rule package rule
import ( import (
"errors"
"net"
"net/netip" "net/netip"
"strings" "strings"
@ -14,18 +16,28 @@ var _ RuleItem = (*IPCIDRItem)(nil)
type IPCIDRItem struct { type IPCIDRItem struct {
ipSet *netipx.IPSet ipSet *netipx.IPSet
ipifSet ipInterfaceSet
isSource bool isSource bool
description string description string
} }
func NewIPCIDRItem(isSource bool, prefixStrings []string) (*IPCIDRItem, error) { func NewIPCIDRItem(isSource bool, prefixStrings []string) (*IPCIDRItem, error) {
var builder netipx.IPSetBuilder var builder netipx.IPSetBuilder
ipifs := make([]ipInterface, 0)
for i, prefixString := range prefixStrings { for i, prefixString := range prefixStrings {
prefix, err := netip.ParsePrefix(prefixString) prefix, err := netip.ParsePrefix(prefixString)
if err == nil { if err == nil {
builder.AddPrefix(prefix) builder.AddPrefix(prefix)
continue continue
} }
ipif, addrErr := parseIPInterface(prefixString)
if addrErr == nil {
ipifs = append(ipifs, ipif)
continue
}
if addrErr != errNotIPInterface {
return nil, E.Cause(addrErr, "parse [", i, "]")
}
addr, addrErr := netip.ParseAddr(prefixString) addr, addrErr := netip.ParseAddr(prefixString)
if addrErr == nil { if addrErr == nil {
builder.Add(addr) builder.Add(addr)
@ -52,6 +64,7 @@ func NewIPCIDRItem(isSource bool, prefixStrings []string) (*IPCIDRItem, error) {
} }
return &IPCIDRItem{ return &IPCIDRItem{
ipSet: ipSet, ipSet: ipSet,
ipifSet: ipInterfaceSet(ipifs),
isSource: isSource, isSource: isSource,
description: description, description: description,
}, nil }, nil
@ -74,16 +87,25 @@ func NewRawIPCIDRItem(isSource bool, ipSet *netipx.IPSet) *IPCIDRItem {
func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool {
if r.isSource || metadata.IPCIDRMatchSource { if r.isSource || metadata.IPCIDRMatchSource {
return r.ipSet.Contains(metadata.Source.Addr) if r.ipSet.Contains(metadata.Source.Addr) {
return true
}
return r.ipifSet.Contains(metadata.Source.Addr)
} }
if metadata.Destination.IsIP() { if metadata.Destination.IsIP() {
return r.ipSet.Contains(metadata.Destination.Addr) if r.ipSet.Contains(metadata.Destination.Addr) {
return true
}
return r.ipifSet.Contains(metadata.Destination.Addr)
} }
if len(metadata.DestinationAddresses) > 0 { if len(metadata.DestinationAddresses) > 0 {
for _, address := range metadata.DestinationAddresses { for _, address := range metadata.DestinationAddresses {
if r.ipSet.Contains(address) { if r.ipSet.Contains(address) {
return true return true
} }
if r.ipifSet.Contains(address) {
return true
}
} }
return false return false
} }
@ -93,3 +115,69 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool {
func (r *IPCIDRItem) String() string { func (r *IPCIDRItem) String() string {
return r.description return r.description
} }
type ipInterfaceSet []ipInterface
func (ipifs ipInterfaceSet) Contains(ip netip.Addr) bool {
for _, ipif := range ipifs {
if ipif.EqualInterfaceID(ip) {
return true
}
}
return false
}
type ipInterface struct {
id netip.Addr
bits int
}
var errNotIPInterface = errors.New("not in ::1/::ffff form")
func parseIPInterface(s string) (ipInterface, error) {
var ipif ipInterface
parts := strings.Split(s, "/")
if len(parts) != 2 || !strings.ContainsRune(parts[0], ':') || !strings.ContainsRune(parts[1], ':') {
return ipif, errNotIPInterface
}
idip, err := netip.ParseAddr(parts[0])
if err != nil {
return ipif, err
}
maskip, err := netip.ParseAddr(parts[1])
if err != nil {
return ipif, err
}
ms := maskip.AsSlice()
for i, b := range ms {
ms[i] = ^b
}
mask := net.IPMask(ms)
ones, bits := mask.Size()
if ones == 0 && bits == 0 || ones == idip.BitLen() {
return ipif, errors.New("invalid mask: " + parts[1])
}
ipif.id = maskNetwork(idip, ones)
ipif.bits = ones
return ipif, nil
}
func (ipif ipInterface) EqualInterfaceID(ip netip.Addr) bool {
idip := maskNetwork(ip, ipif.bits)
return ipif.id == idip
}
func maskNetwork(ip netip.Addr, bits int) netip.Addr {
n := bits / 8
m := bits % 8
s := ip.AsSlice()
for i := 0; i < n; i++ {
s[i] = 0
}
if m != 0 {
mask := byte((1 << (8 - m)) - 1)
s[n] &= mask
}
masked, _ := netip.AddrFromSlice(s)
return masked
}