mirror of
https://gitlab.nic.cz/knot/knot-dns.git
synced 2026-02-03 18:49:28 -05:00
630 lines
22 KiB
Python
630 lines
22 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import ipaddress
|
|
import os
|
|
import random
|
|
import shutil
|
|
import socket
|
|
import errno
|
|
import psutil
|
|
import time
|
|
import dns.name
|
|
import dns.zone
|
|
import zone_generate
|
|
from dnstest.utils import *
|
|
from dnstest.context import Context
|
|
import dnstest.params as params
|
|
import dnstest.server
|
|
import dnstest.keys
|
|
import dnstest.zonefile
|
|
|
|
class Test(object):
    '''Specification of DNS test topology.

    A Test owns the servers of one test case: it allocates their ports,
    generates their configuration, starts and stops them, creates zone
    objects and compares AXFR/IXFR responses between servers.
    Creating an instance registers it as Context().test.
    '''

    # Upper bound for start_tries; start() raises Failed once exceeded.
    MAX_START_TRIES = 10

    # Lock file created exclusively (O_EXCL) for the server that gets XDP
    # enabled, so that at most one server across concurrently running test
    # processes has XDP enabled at a time.
    XDP_LOCK_FILE = "/tmp/knottest-xdp-lock"

    LOCAL_ADDR_COMMON = {4: "127.0.0.1", 6: "::1"}

    # Default local addresses; with several configured addresses, each job
    # is assigned one of them based on its job id.
    LOCAL_ADDR_MULTI = LOCAL_ADDR_COMMON
    if params.addresses > 1:
        idx = 1 + Context().job_id % params.addresses
        LOCAL_ADDR_MULTI = {4: "127.0.1.%i" % idx, 6: "::1%i" % idx}

    # Value of the last generated port.
    last_port = None

    # Number of unsuccessful starts of servers. Recursion protection.
    start_tries = 0

    # Wall-clock reference point for rel_sleep().
    rel_time = time.time()
    # time.monotonic() value taken by the last start(), see uptime().
    start_time = 0

    def __init__(self, address=None, tsig=None, stress=True, quic=False, tls=False):
        '''Create a test topology and register it as the current test.

        address -- 4 or 6 selects the common loopback address of that family;
                   any other truthy value is used as the address verbatim;
                   otherwise a random family from LOCAL_ADDR_MULTI is used.
        tsig    -- a dnstest.keys.Tsig instance is used as is; any other
                   truthy value generates a new key; any other non-None
                   falsy value (e.g. False) disables TSIG; None enables it
                   at random.
        stress  -- stored as self.stress (not used by this class itself).
        quic    -- allocate a QUIC port for each server in generate_conf().
        tls     -- allocate a TLS port for each server in generate_conf();
                   combined with quic, QUIC shares the TLS port.

        Raises Exception if Context().out_dir does not exist.
        '''
        if not os.path.exists(Context().out_dir):
            raise Exception("Output directory doesn't exist")

        self.out_dir = Context().out_dir
        self.data_dir = Context().test_dir + "/data/"
        self.zones_dir = self.out_dir + "/zones/"
        self.quic = quic
        self.tls = tls

        if address == 4 or address == 6:
            self.addr = Test.LOCAL_ADDR_COMMON[address]
        elif address:
            self.addr = address
        else:
            self.addr = Test.LOCAL_ADDR_MULTI[random.choice([4, 6])]

        self.tsig = None
        if tsig is not None:
            if isinstance(tsig, dnstest.keys.Tsig):
                self.tsig = tsig
            elif tsig:
                self.tsig = dnstest.keys.Tsig()
        elif random.choice([True, False]):
            self.tsig = dnstest.keys.Tsig()

        self.stress = stress

        self.servers = set()

        # Reset the per-type counters that server() uses to name servers.
        dnstest.server.Knot.count = 0
        dnstest.server.Bind.count = 0
        dnstest.server.Dummy.count = 0

        Context().test = self

    def _check_port(self, port):
        '''Return True if 'port' can be bound on self.addr for both UDP and TCP.

        A falsy port (None/0) is reported as unusable.
        '''
        if not port:
            return False

        if ipaddress.ip_address(self.addr).version == 6:
            family = socket.AF_INET6
        else:
            family = socket.AF_INET

        for stype in (socket.SOCK_DGRAM, socket.SOCK_STREAM):
            try:
                # The context manager closes the probe socket even when
                # bind() fails, so no descriptor is leaked.
                with socket.socket(family, stype) as s:
                    s.bind((self.addr, port))
            except OSError:
                return False

        return True

    def _gen_port(self):
        '''Return a bindable port in the range 1500-65000.

        The port following the last generated one is tried first (wrapping
        around at the upper bound); if it is not free, random ports from the
        range are probed until one is.
        '''
        min_port = 1500
        max_port = 65000

        port = Test.last_port
        if port:
            port = port + 1 if port < max_port else min_port

        while not self._check_port(port):
            port = random.randint(min_port, max_port)

        Test.last_port = port
        return port

    def _gen_lock_file(self, srvname):
        '''Try to atomically create the XDP lock file.

        The file records the PID, the output directory and 'srvname', one
        per line. Returns True when the lock was acquired and False when the
        file already exists. Any other OSError is propagated.
        '''
        try:
            fd = os.open(self.XDP_LOCK_FILE, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except OSError as e:
            if e.errno == errno.EEXIST:
                return False
            raise

        with os.fdopen(fd, "w") as f:
            f.write(str(os.getpid()) + "\n" + self.out_dir + "\n" + srvname + "\n")
        return True

    def _clean_lock_file(self):
        '''Remove the XDP lock file unless another live test still owns it.

        The file is removed when its PID and output directory match this
        test, or when the recorded PID no longer exists (stale lock).
        A missing lock file is not an error; other OSErrors propagate.
        '''
        try:
            with open(self.XDP_LOCK_FILE, "r") as f:
                lines = f.read().splitlines()
            if lines[0] != str(os.getpid()) or lines[1] != self.out_dir:
                if psutil.pid_exists(int(lines[0])):
                    return
            os.unlink(self.XDP_LOCK_FILE)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    @property
    def hostname(self):
        '''Canonical name of the local host as reported by getaddrinfo(),
        or the plain host name when resolution returns no entries.'''
        hostname = socket.gethostname()
        addrinfo = socket.getaddrinfo(hostname, 0, socket.AF_UNSPEC,
                                      socket.SOCK_DGRAM, 0, socket.AI_CANONNAME)
        return addrinfo[0][3] if addrinfo else hostname

    def server(self, server, nsid=None, ident=None, version=None,
               valgrind=None, address=None, port=None, ctlport=None,
               xdp_enable=True, external=False, tsig=None, via=None):
        '''Create a server, add it to the test and return it.

        server     -- "knot", "bind" or "dummy".
        nsid, ident, version -- stored on the server object.
        valgrind   -- True forces and False disables running under valgrind;
                      None enables it for knot only. Requires
                      params.valgrind_bin.
        address    -- 4/6 for the common loopback address of that family,
                      another truthy value verbatim, otherwise self.addr.
        port       -- fixed port (disables port generation for this server);
                      bind then also requires 'ctlport'.
        ctlport    -- fixed remote control port.
        xdp_enable -- allow XDP; it is actually enabled only for knot with an
                      address starting with "::", when params.xdp is set, at
                      random, and only if the XDP lock file is acquired.
        external   -- the server is not started/stopped by this test.
        tsig       -- server TSIG key; with test-wide TSIG and no key given,
                      a new key is generated.
        via        -- True uses the server address; another truthy value is
                      used verbatim.

        Raises Failed for an unsupported server type or a fixed-port bind
        server without 'ctlport'.
        '''
        if server == "knot":
            srv = dnstest.server.Knot()
        elif server == "bind":
            srv = dnstest.server.Bind()
        elif server == "dummy":
            srv = dnstest.server.Dummy()
        else:
            raise Failed("Unsupported server '%s'" % server)

        # Per-type counter, used below to build a unique server name.
        type(srv).count += 1

        srv.data_dir = self.data_dir

        srv.nsid = nsid
        srv.ident = ident
        srv.version = version

        if address == 4 or address == 6:
            srv.addr = Test.LOCAL_ADDR_COMMON[address]
        elif address:
            srv.addr = address
        else:
            srv.addr = self.addr

        if via:
            srv.via = srv.addr if via == True else via

        if port:
            srv.port = int(port)
            srv.fixed_port = True
            if not ctlport and server == "bind":
                raise Failed("Missing remote control port '%s'" % server)

        if ctlport:
            srv.ctlport = int(ctlport)

        if external:
            srv.external = True

        if self.tsig and not tsig:
            srv.tsig = dnstest.keys.Tsig()
        else:
            srv.tsig = tsig

        srv.tsig_test = self.tsig

        srv.name = "%s%s" % (server, srv.count)
        srv.dir = self.out_dir + "/" + srv.name
        srv.fout = srv.dir + "/stdout"
        srv.ferr = srv.dir + "/stderr"
        srv.valgrind_log = srv.dir + "/valgrind"
        srv.session_log = srv.dir + "/secrets.log"
        srv.quic_log = srv.dir + "/quic.log"
        srv.confile = srv.dir + "/%s.conf" % srv.name

        # Short-circuit order matters: the random choice is only drawn and
        # the lock file only created when all preceding conditions hold.
        xdp_enable = (params.xdp and xdp_enable and
                      server == "knot" and
                      srv.addr.startswith("::") and
                      random.choice([False, True]) and
                      self._gen_lock_file(srv.name))
        if xdp_enable:
            # Non-None marks XDP as enabled; generate_conf() assigns a port.
            srv.xdp_port = 0

        prepare_dir(srv.dir)

        if srv.ctlkey:
            srv.ctlkeyfile = srv.dir + "/%s.ctlkey" % srv.name
            srv.ctlkey.dump(srv.ctlkeyfile)

        if params.valgrind_bin and \
           (valgrind or (valgrind is None and server == "knot")):
            # NOTE(review): undefined-value errors are presumably disabled
            # with XDP because valgrind cannot track kernel-filled buffers.
            srv.valgrind = [params.valgrind_bin] + \
                           params.valgrind_flags.split() + \
                           ["--log-file=%s" % srv.valgrind_log] + \
                           (["--undef-value-errors=no"] if xdp_enable else [])
            suppressions_file = "%s/%s.supp" % (params.common_data_dir, server)
            if os.path.isfile(suppressions_file):
                srv.valgrind.append("--suppressions=%s" % suppressions_file)

        self.servers.add(srv)
        return srv

    def server_remove(self, server=None):
        '''Remove 'server' (stopping it if listening) or, if None, all servers.'''
        if server:
            if server.listening():
                server.stop()
            self.servers.discard(server)
            return

        # Iterate over a copy; the recursive call modifies self.servers.
        for srv in list(self.servers):
            self.server_remove(srv)

    def generate_conf(self):
        '''Assign ports to servers without a fixed port, then write all configs.

        The two loops must stay separate: every server's ports are assigned
        before any configuration is generated (configurations presumably
        reference peer servers' ports).
        '''
        for server in self.servers:
            if server.fixed_port:
                continue

            server.port = self._gen_port()
            server.ctlport = self._gen_port()
            if self.tls:
                server.tls_port = self._gen_port()
                server.quic_port = server.tls_port if self.quic else None
            else:
                server.quic_port = self._gen_port() if self.quic else None
            server.xdp_port = self._gen_port() if server.xdp_port is not None else None
            if server.xdp_port:
                # NOTE(review): keeps a TCP socket bound on the XDP port,
                # presumably to reserve it -- confirm.
                server.xdp_cover_sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
                server.xdp_cover_sock.bind((server.addr, server.xdp_port))

        for server in self.servers:
            server.gen_confile()

    def start(self):
        '''Start all test servers.

        Ports and configurations are regenerated first. Non-external servers
        are started in ascending order of the number of their zones that have
        masters, so masters come first. If a server is running but not
        listening, all servers are killed and the whole start is retried.

        Raises Failed if a server is not running after its start, or once
        start_tries exceeds MAX_START_TRIES.
        '''
        if self.start_tries > Test.MAX_START_TRIES:
            raise Failed("Can't start all servers")

        self.start_time = time.monotonic()
        self.start_tries += 1

        self.generate_conf()

        def srv_sort(server):
            # Number of this server's zones that have masters configured.
            return sum(1 for z in server.zones if server.zones[z].masters)

        # Sort server list by number of masters. I.e. masters are preferred.
        for server in sorted(self.servers, key=srv_sort):
            if server.external:
                continue

            server.start(clean=True)

            if not server.running():
                raise Failed("Server '%s' not running" % server.name)

            if not server.listening():
                self.stop(kill=True)
                # The recursive call (re)starts every server and resets
                # start_tries; continuing this loop would start the
                # remaining servers a second time.
                self.start()
                return

        self.start_tries = 0

    def stop(self, check=True, kill=False):
        '''Stop all non-external servers.

        kill  -- kill the servers instead of stopping them.
        check -- passed to server.stop() when not killing.
        '''
        for server in self.servers:
            if server.external:
                continue

            if kill:
                server.kill()
            else:
                server.stop(check=check)

    def end(self):
        '''Finish testing: stop servers, release the XDP lock, unregister.'''
        self.stop(check=True)
        self._clean_lock_file()
        Context().test = None

    def sleep(self, seconds):
        '''Sleep for 'seconds'.'''
        time.sleep(seconds)

    def pause(self, msg=None):
        '''Print test info (and 'msg', if given) and wait for a line of input.'''
        input(test_info() + ((", " + msg) if msg else ""))

    def rel_sleep(self, seconds):
        '''Sleep relative to the previous reference point.

        Returns the time elapsed since the previous reference point. With
        seconds == 0 the reference point is reset to now; otherwise it is
        advanced by 'seconds' and the call sleeps until it is reached.
        '''
        timenow = time.time()
        res = timenow - self.rel_time
        if seconds == 0:
            self.rel_time = timenow
        else:
            self.rel_time += seconds
        to_wait = self.rel_time - timenow
        if to_wait > 0:
            self.sleep(to_wait)
        return res

    def uptime(self):
        '''Seconds elapsed (monotonic clock) since the last start().'''
        return time.monotonic() - self.start_time

    def zone(self, name, file_name=None, storage=None, version=None, exists=True):
        '''Create a zone from an existing zone file; returns a one-item list.

        storage -- "." for the test's data directory, another truthy value
                   for that directory, otherwise params.common_data_dir.
        file_name, version, exists -- passed to ZoneFile.set_file().
        '''
        zone = dnstest.zonefile.ZoneFile(self.zones_dir)
        zone.set_name(name)

        if storage == ".":
            src_dir = self.data_dir
        elif storage:
            src_dir = storage
        else:
            src_dir = params.common_data_dir

        zone.set_file(file_name=file_name, storage=src_dir, version=version,
                      exists=exists)

        return [zone]

    def zone_rnd(self, number, dnssec=None, nsec3=None, records=None, serial=None,
                 ttl=None, exists=True):
        '''Create 'number' zones with generated unique names.

        With 'exists', zone files are generated using the given parameters;
        otherwise only the file name is set. Returns the list of zones.
        '''
        zones = list()

        # Generate unique zone names.
        names = zone_generate.main(["-n", number]).split()
        for name in names:
            zone = dnstest.zonefile.ZoneFile(self.zones_dir)
            zone.set_name(name)
            if exists:
                zone.gen_file(dnssec=dnssec, nsec3=nsec3, records=records,
                              serial=serial, ttl=ttl)
            else:
                zone.file_name = zone.name + "zone"
            zones.append(zone)

        return zones

    def link(self, zones, master, slave=None, ddns=False, ixfr=False, journal_content="changes"):
        '''Make 'master' serve each zone and, if given, 'slave' transfer it.

        Raises Failed if a server does not belong to this test.
        '''
        for zone in zones:
            if master not in self.servers:
                raise Failed("Server is out of testing scope")
            master.set_master(zone, slave, ddns, ixfr, journal_content)

            if slave:
                if slave not in self.servers:
                    raise Failed("Server is out of testing scope")
                slave.set_slave(zone, master, ddns, ixfr, journal_content)

    def _canonize_record(self, rtype, record):
        '''Return a textual record normalized for comparison.

        The owner name is always lowercased. For the name-bearing types
        listed below the whole rest of the record is lowercased; for NAPTR
        only the trailing replacement name is. Other types are kept as is.
        '''
        item_owner_split = record.strip().split(" ", 1)
        owner = item_owner_split[0].lower()
        rest = item_owner_split[1]

        lowercase_types = (dns.rdatatype.SOA, dns.rdatatype.NS,
                           dns.rdatatype.CNAME, dns.rdatatype.PTR,
                           dns.rdatatype.DNAME, dns.rdatatype.MINFO,
                           dns.rdatatype.RP, dns.rdatatype.MX,
                           dns.rdatatype.AFSDB, dns.rdatatype.RT,
                           dns.rdatatype.KX, dns.rdatatype.SRV,
                           dns.rdatatype.NSEC)

        if rtype in lowercase_types:
            rest = rest.lower()
        # RRSIG is intentionally not canonized: dnspython prints the signer
        # name as '@'.
        elif rtype == dns.rdatatype.NAPTR:
            item_dname_split = rest.rsplit(" ", 1)
            rest = item_dname_split[0] + " " + item_dname_split[1].lower()

        return owner + " " + rest

    def _axfr_records(self, resp, zone, no_rrsig_rdata, server=None):
        '''Collect canonized records from an AXFR response.

        Returns (unique, records): the set of distinct records and the list
        of kept records in response order. Duplicate non-SOA records are
        logged and skipped. With 'no_rrsig_rdata' the SOA serial is zeroed
        and RRSIGs are cut to owner, TTL, class, type and type covered
        (such RRSIGs are not checked for duplicates).

        server -- server the response came from; only used to name it in
                  the duplicate-record log message.
        '''
        unique = set()
        records = list()
        server_name = server.name if server is not None else "unknown"
        origin = dns.name.from_text(zone.name)

        for msg in resp.resp:
            for rrset in msg.answer:
                rrs = rrset.to_text(origin=origin, relativize=False).split("\n")

                for rr in rrs:
                    item_lower = self._canonize_record(rrset.rdtype, rr.strip())

                    if no_rrsig_rdata and rrset.rdtype == dns.rdatatype.SOA:
                        # Reset SOA serial (7th field).
                        soa_split = item_lower.split()
                        soa_split[6] = "0"
                        item_lower = " ".join(soa_split)

                    if no_rrsig_rdata and rrset.rdtype == dns.rdatatype.RRSIG:
                        # Trim RRSIG signature part.
                        rrsig_split = item_lower.split(None, 5)
                        item_lower = " ".join(rrsig_split[:5])
                    elif item_lower in unique and rrset.rdtype != dns.rdatatype.SOA:
                        detail_log("!Duplicate record server='%s':" % server_name)
                        detail_log(" %s" % item_lower)
                        continue

                    unique.add(item_lower)
                    records.append(item_lower)

        return unique, records

    def _axfr_diff_resp(self, unique1, rrset1s, unique2, rrsets2, server1, server2):
        '''Report records present on only one side; sets "AXFR DIFF" if any.

        The record lists (rrset1s, rrsets2) are accepted but not used.
        '''
        diff1 = sorted(unique1 - unique2)
        if diff1:
            set_err("AXFR DIFF")
            detail_log("!Extra records server='%s':" % server1.name)
            for record in diff1:
                detail_log(" %s" % record)

        diff2 = sorted(unique2 - unique1)
        if diff2:
            set_err("AXFR DIFF")
            detail_log("!Extra records server='%s':" % server2.name)
            for record in diff2:
                detail_log(" %s" % record)

    def _axfr_diff(self, server1, server2, zone, no_rrsig_rdata):
        '''Query AXFR of 'zone' from both servers and compare the records.'''
        unique1, rrsets1 = self._axfr_records(server1.dig(zone.name, "AXFR", log_no_sep=True),
                                              zone, no_rrsig_rdata, server1)
        unique2, rrsets2 = self._axfr_records(server2.dig(zone.name, "AXFR", log_no_sep=True),
                                              zone, no_rrsig_rdata, server2)

        self._axfr_diff_resp(unique1, rrsets1, unique2, rrsets2, server1, server2)

    class IxfrChange(object):
        '''One IXFR change: old/new SOA plus removed and added records.'''

        def __init__(self):
            self.soa_old = None
            self.soa_new = None
            self.removed = list()
            self.added = list()

        def rem(self, record):
            '''Append a removed record.'''
            self.removed.append(record)

        def add(self, record):
            '''Append an added record.'''
            self.added.append(record)

        def sort(self):
            '''Sort removed and added records in place.'''
            self.removed.sort()
            self.added.sort()

        @staticmethod
        def _cmp_records(kind, mine, theirs):
            '''Compare one record section ('kind' is "remove" or "add").

            Sets "IXFR CHANGE DIFF" when the counts or the contents differ,
            logging the counts and the records missing on either side.
            '''
            if len(mine) != len(theirs):
                set_err("IXFR CHANGE DIFF")
                detail_log("!Number of %s records:" % kind)
                detail_log(" (%i) != (%i)" % (len(mine), len(theirs)))

            mine_set = set(mine)
            theirs_set = set(theirs)
            rec1 = "".join('\n\t' + rec for rec in mine if rec not in theirs_set)
            rec2 = "".join('\n\t' + rec for rec in theirs if rec not in mine_set)

            # Equal counts with different contents are a difference too.
            if sorted(mine) != sorted(theirs):
                set_err("IXFR CHANGE DIFF")

            if rec1 != "" or rec2 != "":
                detail_log("!Extra %s records self:%s" % (kind, rec1))
                detail_log("!Extra %s records other:%s" % (kind, rec2))

        def cmp(self, other):
            '''Compare with another change and report every difference.'''
            if self.soa_old != other.soa_old:
                set_err("IXFR CHANGE DIFF")
                detail_log("!Different remove SOA:")
                detail_log(" %s" % self.soa_old)
                detail_log(" %s" % other.soa_old)

            self._cmp_records("remove", self.removed, other.removed)

            if self.soa_new != other.soa_new:
                set_err("IXFR CHANGE DIFF")
                detail_log("!Different add SOA:")
                detail_log(" %s" % self.soa_new)
                detail_log(" %s" % other.soa_new)

            self._cmp_records("add", self.added, other.added)

    def _ixfr_changes(self, server, zone, serial, udp):
        '''Query IXFR of 'zone' from 'serial' on 'server' and parse it.

        Returns (soa, changes): the canonized leading SOA (None if missing)
        and the list of completed, sorted IxfrChange objects. The response's
        final SOA opens a new change whose soa_old must equal the leading
        SOA; that change is not returned. Format problems are reported via
        set_err("IXFR FORMAT") and detail_log().
        '''
        soa = None
        changes = list()

        resp = server.dig(zone.name, "IXFR", log_no_sep=True, serial=serial,
                          udp=udp)
        origin = dns.name.from_text(zone.name)

        change = Test.IxfrChange()
        for msg in resp.resp:
            for rrset in msg.answer:
                records = rrset.to_text(origin=origin,
                                        relativize=False).split("\n")
                for record in records:
                    item_lower = self._canonize_record(rrset.rdtype, record.strip())

                    if rrset.rdtype == dns.rdatatype.SOA:
                        if not soa:  # IXFR leading SOA.
                            soa = item_lower
                            continue

                        if not change.soa_old:  # Remove change section.
                            change.soa_old = item_lower
                            continue

                        if not change.soa_new:  # Add change section.
                            change.soa_new = item_lower
                            continue

                        # Next change -> store the actual one.
                        change.sort()
                        changes.append(change)
                        change = Test.IxfrChange()
                        change.soa_old = item_lower
                    else:
                        if not soa:
                            set_err("IXFR FORMAT")
                            detail_log("!Missing leading SOA zone='%s', "
                                       "server='%s' before:" %
                                       (zone.name, server.name))
                            detail_log(" %s" % item_lower)

                        if not change.soa_old:
                            set_err("IXFR FORMAT")
                            detail_log("!Expected SOA zone='%s', server='%s' "
                                       "before:" %
                                       (zone.name, server.name))
                            detail_log(" %s" % item_lower)

                        # Records before the new SOA are removals, after it
                        # additions.
                        if not change.soa_new:
                            change.rem(item_lower)
                        else:
                            change.add(item_lower)

        if not soa:
            set_err("IXFR FORMAT")
            detail_log("!Missing leading SOA zone='%s', server='%s'" %
                       (zone.name, server.name))
        elif change.removed or change.added:
            set_err("IXFR FORMAT")
            detail_log("!Missing trailing SOA zone='%s', server='%s'" %
                       (zone.name, server.name))
        elif change.soa_old and change.soa_old != soa:
            set_err("IXFR FORMAT")
            detail_log("!Trailing SOA differs from the leading one "
                       "zone='%s', server='%s'" %
                       (zone.name, server.name))

        return soa, changes

    def _ixfr_diff(self, server1, server2, zone, serial, udp):
        '''Compare IXFR responses of two servers for 'zone' from 'serial'.'''
        soa1, changes1 = self._ixfr_changes(server1, zone, serial, udp)
        soa2, changes2 = self._ixfr_changes(server2, zone, serial, udp)

        if soa1 != soa2:
            set_err("IXFR DIFF")
            detail_log("!Different leading SOA records:")
            detail_log(" %s" % soa1)
            detail_log(" %s" % soa2)

        if len(changes1) != len(changes2):
            set_err("IXFR DIFF")
            detail_log("!Number of changes:")
            detail_log(" (server='%s', num='%i') != (server='%s', num='%i')" %
                       (server1.name, len(changes1),
                        server2.name, len(changes2)))

        for change1, change2 in zip(changes1, changes2):
            change1.cmp(change2)

    def xfr_diff(self, server1, server2, zones, serials=None, udp=False, no_rrsig_rdata=False):
        '''Compare zone transfers of 'zones' between two servers.

        With 'serials' (zone name -> serial) IXFRs are compared, otherwise
        AXFRs. 'no_rrsig_rdata' is only supported for AXFR.
        '''
        for zone in zones:
            check_log("CHECK %sXFR DIFF %s %s<->%s" % ("I" if serials else "A",
                      zone.name, server1.name, server2.name))
            if serials:
                if no_rrsig_rdata:
                    set_err("RRSIG rdata and SOA serial skipping not implemented for IXFR diff")
                self._ixfr_diff(server1, server2, zone, serials[zone.name], udp)
            else:
                self._axfr_diff(server1, server2, zone, no_rrsig_rdata)

            detail_log(SEP)

    def axfr_diff_resp(self, resp1, resp2, server1, server2, zone, no_rrsig_rdata=False):
        '''Compare two already obtained AXFR responses of 'zone'.'''
        unique1, rrsets1 = self._axfr_records(resp1, zone, no_rrsig_rdata, server1)
        unique2, rrsets2 = self._axfr_records(resp2, zone, no_rrsig_rdata, server2)

        self._axfr_diff_resp(unique1, rrsets1, unique2, rrsets2, server1, server2)

    def check_axfr_style_ixfr(self, server, zone_name, serial):
        '''Query IXFR from 'serial' and AXFR of the zone from 'server' and
        delegate the AXFR-style IXFR check to the IXFR response object.'''
        resp_ixfr = server.dig(zone_name, "IXFR", serial=serial)
        resp_axfr = server.dig(zone_name, "AXFR")

        resp_ixfr.check_axfr_style_ixfr(resp_axfr)
|
|
|