Fix a concurrency issue in fakedns

In rare cases different domains asking for dns will return the same IP. Add a mutex.
This commit is contained in:
yuhan6665 2022-03-12 10:35:06 -05:00
parent 03ade23022
commit c1a54ae58e
2 changed files with 35 additions and 1 deletions

View file

@ -5,6 +5,7 @@ import (
"math" "math"
"math/big" "math/big"
gonet "net" gonet "net"
"sync"
"time" "time"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
@ -16,6 +17,7 @@ import (
type Holder struct { type Holder struct {
domainToIP cache.Lru domainToIP cache.Lru
ipRange *gonet.IPNet ipRange *gonet.IPNet
mu *sync.Mutex
config *FakeDnsPool config *FakeDnsPool
} }
@ -49,6 +51,7 @@ func (fkdns *Holder) Start() error {
func (fkdns *Holder) Close() error { func (fkdns *Holder) Close() error {
fkdns.domainToIP = nil fkdns.domainToIP = nil
fkdns.ipRange = nil fkdns.ipRange = nil
fkdns.mu = nil
return nil return nil
} }
@ -67,7 +70,7 @@ func NewFakeDNSHolder() (*Holder, error) {
} }
func NewFakeDNSHolderConfigOnly(conf *FakeDnsPool) (*Holder, error) { func NewFakeDNSHolderConfigOnly(conf *FakeDnsPool) (*Holder, error) {
return &Holder{nil, nil, conf}, nil return &Holder{nil, nil, nil, conf}, nil
} }
func (fkdns *Holder) initializeFromConfig() error { func (fkdns *Holder) initializeFromConfig() error {
@ -89,11 +92,14 @@ func (fkdns *Holder) initialize(ipPoolCidr string, lruSize int) error {
} }
fkdns.domainToIP = cache.NewLru(lruSize) fkdns.domainToIP = cache.NewLru(lruSize)
fkdns.ipRange = ipRange fkdns.ipRange = ipRange
fkdns.mu = new(sync.Mutex)
return nil return nil
} }
// GetFakeIPForDomain checks and generates a fake IP for a domain name // GetFakeIPForDomain checks and generates a fake IP for a domain name
func (fkdns *Holder) GetFakeIPForDomain(domain string) []net.Address { func (fkdns *Holder) GetFakeIPForDomain(domain string) []net.Address {
fkdns.mu.Lock()
defer fkdns.mu.Unlock()
if v, ok := fkdns.domainToIP.Get(domain); ok { if v, ok := fkdns.domainToIP.Get(domain); ok {
return []net.Address{v.(net.Address)} return []net.Address{v.(net.Address)}
} }

View file

@ -2,10 +2,13 @@ package fakedns
import ( import (
gonet "net" gonet "net"
"strconv"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/uuid" "github.com/xtls/xray-core/common/uuid"
@ -66,6 +69,31 @@ func TestFakeDnsHolderCreateMappingManySingleDomain(t *testing.T) {
assert.Equal(t, addr[0].IP().String(), addr2[0].IP().String()) assert.Equal(t, addr[0].IP().String(), addr2[0].IP().String())
} }
func TestGetFakeIPForDomainConcurrently(t *testing.T) {
fkdns, err := NewFakeDNSHolder()
common.Must(err)
total := 200
addr := make([][]net.Address, total)
var errg errgroup.Group
for i := 0; i < total; i++ {
errg.Go(testGetFakeIP(i, addr, fkdns))
}
errg.Wait()
for i := 0; i < total; i++ {
for j := i + 1; j < total; j++ {
assert.NotEqual(t, addr[i][0].IP().String(), addr[j][0].IP().String())
}
}
}
func testGetFakeIP(index int, addr [][]net.Address, fkdns *Holder) func() error {
return func() error {
addr[index] = fkdns.GetFakeIPForDomain("fakednstest" + strconv.Itoa(index) + ".example.com")
return nil
}
}
func TestFakeDnsHolderCreateMappingAndRollOver(t *testing.T) { func TestFakeDnsHolderCreateMappingAndRollOver(t *testing.T) {
fkdns, err := NewFakeDNSHolderConfigOnly(&FakeDnsPool{ fkdns, err := NewFakeDNSHolderConfigOnly(&FakeDnsPool{
IpPool: dns.FakeIPv4Pool, IpPool: dns.FakeIPv4Pool,