mirror of
https://gitlab.nic.cz/knot/knot-dns.git
synced 2026-02-03 18:49:28 -05:00
472 lines
17 KiB
Python
472 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import binascii
|
|
import dns.name
|
|
import collections.abc
|
|
import itertools
|
|
from dnstest.utils import *
|
|
|
|
class Response(object):
    '''Dig output context.

    Wraps a response (a single DNS message, or an iterable of messages for
    zone transfers) together with the query and the dig arguments that
    produced it, and offers check helpers that report mismatches through
    the dnstest.utils helpers (compare, isset, set_err, check_log,
    detail_log).
    '''

    # Numeric mask used for the "Z" header flag, which this class handles
    # itself instead of passing the name to dns.flags.from_text().
    _Z_FLAG = 64

    def __init__(self, server, response, query, args):
        '''Stores the response context and normalizes the question fields.

        args is the dict of dig arguments; "rname", "rtype" and "rclass"
        are read from it. "rtype" and "rclass" may be given either as text
        or as already converted values.
        '''

        self.resp = response
        self.query = query
        self.args = args
        self.srv = server

        self.rname = dns.name.from_text(self.args["rname"])

        rtype = self.args["rtype"]
        if isinstance(rtype, str):
            self.rtype = dns.rdatatype.from_text(rtype)
        else:
            self.rtype = rtype

        rclass = self.args["rclass"]
        if isinstance(rclass, str):
            self.rclass = dns.rdataclass.from_text(rclass)
        else:
            self.rclass = rclass

    @staticmethod
    def _section_rrsets(msg, section):
        '''Returns the RRset list of the named section of a message.

        An empty/None section name selects the answer section. Raises
        ValueError for an unknown section name.
        '''

        if not section or section == "answer":
            return msg.answer
        if section == "additional":
            return msg.additional
        if section == "authority":
            return msg.authority
        raise ValueError("Unknown section '%s'" % section)

    @classmethod
    def _flag_value(cls, flag):
        '''Converts a header flag name to its numeric mask.'''

        if flag == "Z":
            return cls._Z_FLAG
        return dns.flags.from_text(flag)

    @staticmethod
    def _report_nsec_error(err, unexpected=None):
        '''Reports an NSEC(3) check failure, optionally listing records.'''

        set_err(err)
        check_log("ERROR: %s" % err)
        if unexpected is not None:
            detail_log("!Unexpected records:")
            for rr in unexpected:
                detail_log("  %s" % rr)
        detail_log(SEP)

    def _check_question(self):
        '''Compares the first question entry with the queried name/class/type.'''

        question = self.resp.question[0]
        compare(question.name, self.rname, "QNAME")
        compare(question.rdclass, self.rclass, "QCLASS")
        compare(question.rdtype, self.rtype, "QTYPE")

    def _check_flags(self, flags, noflags):
        '''Checks that all header flags in flags are set and none in noflags.

        Both arguments are whitespace separated flag names.
        '''

        for flag in flags.split():
            isset(self.resp.flags & self._flag_value(flag), "%s FLAG" % flag)

        for flag in noflags.split():
            isset(not (self.resp.flags & self._flag_value(flag)),
                  "NO %s FLAG" % flag)

    def _check_eflags(self, eflags, noeflags):
        '''Checks that all EDNS flags in eflags are set and none in noeflags.

        Both arguments are whitespace separated flag names.
        '''

        for flag in eflags.split():
            flag_val = dns.flags.edns_from_text(flag)
            isset(self.resp.ednsflags & flag_val, "%s FLAG" % flag)

        for flag in noeflags.split():
            flag_val = dns.flags.edns_from_text(flag)
            isset(not (self.resp.ednsflags & flag_val), "NO %s FLAG" % flag)

    def _check_rr(self, expect, section=None, rname=None, rtype=None):
        """
        Check for a presence of a RR with given name and type.

        expect tells whether such a RR is expected to be present (True)
        or absent (False). The section defaults to "answer"; rname and
        rtype are given as text and at least one of them is required.
        """
        if section is None:
            section = "answer"
        if rname is not None:
            rname = dns.name.from_text(rname)
        if rtype is not None:
            rtype = dns.rdatatype.from_text(rtype)

        assert section in ["answer", "authority", "additional"]
        assert rname or rtype

        # A RRset matches when every given criterion agrees with it.
        found = any(
            (rname is None or rname == rrset.name) and
            (rtype is None or rtype == rrset.rdtype)
            for rrset in getattr(self.resp, section)
        )

        if found != expect:
            set_err("CHECK RR PRESENCE")
            check_log("ERROR: CHECK RR PRESENCE")
            detail_log("!%s RR name=%s type=%s section=%s" % (
                "Missing" if expect else "Extra",
                str(rname) if rname is not None else "",
                dns.rdatatype.to_text(rtype) if rtype is not None else "",
                section
            ))
            detail_log(SEP)

    def check_rr(self, section=None, rname=None, rtype=None):
        '''Checks that a RR with given name and/or type is present.'''

        self._check_rr(True, section, rname, rtype)

    def check_no_rr(self, section=None, rname=None, rtype=None):
        '''Checks that no RR with given name and/or type is present.'''

        self._check_rr(False, section, rname, rtype)

    def check_record(self, section="answer", name=None, rtype=None, ttl=None, rdata=None,
                     nordata=None):
        '''Checks given section for particular record/rdata

        rdata is a textual rdata expected to be present (optionally only
        in RRsets owned by name), nordata a textual rdata expected to be
        absent. rtype defaults to the queried type. If ttl is given, the
        TTL of the RRset holding rdata is compared with it.
        '''

        sect = getattr(self.resp, section)
        if not rtype:
            rtype = self.rtype
        elif isinstance(rtype, str):
            rtype = dns.rdatatype.from_text(rtype)

        # Check rdata presence.
        if rdata:
            # We work with just one rdata with TTL=0 (this TTL is not used).
            rrset = dns.rdataset.from_text(self.rclass, rtype, 0, rdata)
            ref = str(list(rrset)[0])

            # Check if the section contains the reference rdata.
            found = False
            for data in sect:
                if name is not None and str(data.name) != str(name):
                    continue
                if any(str(rd) == ref for rd in data.to_rdataset()):
                    # Check CLASS.
                    compare(data.rdclass, self.rclass, "CLASS")
                    # Check TYPE.
                    compare(data.rdtype, rtype, "TYPE")
                    # Check TTL if specified.
                    if ttl is not None:
                        compare(data.ttl, int(ttl), "TTL")
                    found = True
                    # Only the first match is examined. Do not return here,
                    # the nordata check below must still be performed.
                    break

            if not found:
                set_err("CHECK RDATA")
                check_log("ERROR: CHECK RDATA")
                detail_log("!Missing data in %s section:" % section)
                detail_log("  %s" % ref)
                detail_log(SEP)

        # Check rdata absence.
        if nordata:
            # We work with just one rdata with TTL=0 (this TTL is not used).
            rrset = dns.rdataset.from_text(self.rclass, rtype, 0, nordata)
            ref = str(list(rrset)[0])

            # Check if the section contains the reference rdata.
            for data in sect:
                if data.rdtype != rtype:
                    continue
                if any(str(rd) == ref for rd in data.to_rdataset()):
                    set_err("CHECK RDATA")
                    check_log("ERROR: CHECK RDATA")
                    detail_log("!Unwanted data in %s section:" % section)
                    detail_log("  %s" % ref)
                    detail_log(SEP)
                    # One report is enough.
                    return

    def check(self, rdata=None, ttl=None, rcode="NOERROR", nordata=None,
              edns_version=None, flags="", noflags="", eflags="", noeflags=""):
        '''Flags are text strings separated by whitespace character

        Checks header flags, EDNS flags, the question, the EDNS version
        (the version of the query unless edns_version is given), the rcode
        (skipped if rcode is None; text or numeric otherwise) and, if the
        expected rcode is NOERROR or not checked, the answer section
        against rdata/nordata/ttl.
        '''

        self._check_flags(flags, noflags)
        self._check_eflags(eflags, noeflags)
        self._check_question()

        # Check EDNS version.
        if edns_version is not None:
            edns_ver = int(edns_version)
        else:
            edns_ver = self.query.edns
        compare(edns_ver, self.resp.edns, "EDNS VERSION")

        # Check rcode.
        rc = None
        if rcode is not None:
            rc = rcode if isinstance(rcode, str) else dns.rcode.to_text(rcode)
            compare(dns.rcode.to_text(self.resp.rcode()), rc, "RCODE")

        # Check rdata only if NOERROR.
        if rcode is None or rc == "NOERROR":
            self.check_record(section="answer", rtype=self.rtype, ttl=ttl,
                              rdata=rdata, nordata=nordata)

    def check_xfr(self, rcode="NOERROR"):
        '''Checks XFR message

        Only the first message of the transfer is examined: its question
        class/type and its rcode. A dns.query.TransferError raised while
        reading the transfer is compared by its rcode instead.
        '''

        # Iterate over a copy so that self.resp can be read again later.
        self.resp, iter_copy = itertools.tee(self.resp)

        rc = rcode if isinstance(rcode, str) else dns.rcode.to_text(rcode)

        # Get the first message.
        try:
            for msg in iter_copy:
                question = msg.question[0]
                compare(question.rdclass, self.rclass, "QCLASS")
                compare(question.rdtype, self.rtype, "QTYPE")

                # Check rcode.
                compare(dns.rcode.to_text(msg.rcode()), rc, "RCODE")

                # Check the first message only.
                break
        except dns.query.TransferError as e:
            compare(dns.rcode.to_text(e.rcode), rc, "RCODE")

    # Checks whether the transfer is an AXFR-style IXFR
    def check_axfr_style_ixfr(self, axfr=None):
        '''Checks whether the transfer is an AXFR-style IXFR

        If axfr (another Response) is given, the total number of answer
        records of both transfers is compared as well.
        '''

        # 1) QTYPE == IXFR && RCODE == NOERROR
        self.check_xfr()

        # 2) Check if Answer contains AXFR data (first SOA, second non-SOA)
        rr_count = 0

        self.resp, iter_copy = itertools.tee(self.resp)
        for msg in iter_copy:
            for rrset in msg.answer:
                for rr in rrset:
                    if rr_count == 0:
                        if rr.rdtype != dns.rdatatype.SOA:
                            set_err("First RR is not SOA")
                            return
                    elif rr_count == 1:
                        if rr.rdtype == dns.rdatatype.SOA:
                            set_err("Second RR is SOA")
                            return

                    rr_count += 1

        # 3) Check that number of records in IXFR and AXFR is the same
        if axfr:
            compare(self.count("ANY"), axfr.count("ANY"),
                    "Count of RRs in Answer")

    def check_nsid(self, nsid=None):
        '''Checks the NSID option of the response.

        Exactly one NSID option is expected if nsid is given, none
        otherwise. A value starting with "0x" is compared as hex string,
        any other value as ASCII text.
        '''

        compare(self.resp.edns, 0, "EDNS VERSION")

        nsid_opts = [opt for opt in self.resp.options
                     if opt.otype == dns.edns.NSID]

        compare(len(nsid_opts), 1 if nsid else 0, "NUMBER OF NSID OPTIONS")
        if nsid and nsid_opts:
            # If there are more options, the last one is examined.
            val = nsid_opts[-1].to_wire()
            if nsid[:2] == "0x":
                compare(binascii.hexlify(val).decode('ascii'),
                        nsid[2:], "HEX NSID")
            else:
                compare(val.decode('ascii'), nsid, "TXT NSID")

    def diff(self, resp, flags=True, answer=True, authority=True,
             additional=False, rcode=True):
        '''Compares specified response sections against another response'''

        if rcode:
            compare(self.resp.rcode(), resp.resp.rcode(), "RCODE")

        if flags:
            compare(dns.flags.to_text(self.resp.flags),
                    dns.flags.to_text(resp.resp.flags), "FLAGS")
            compare(dns.flags.edns_to_text(self.resp.ednsflags),
                    dns.flags.edns_to_text(resp.resp.ednsflags), "EDNS FLAGS")
        if answer:
            compare_sections(self.resp.answer, self.srv.name,
                             resp.resp.answer, resp.srv.name,
                             "ANSWER")
        if authority:
            compare_sections(self.resp.authority, self.srv.name,
                             resp.resp.authority, resp.srv.name,
                             "AUTHORITY")
        if additional:
            compare_sections(self.resp.additional, self.srv.name,
                             resp.resp.additional, resp.srv.name,
                             "ADDITIONAL")

    def cmp(self, server, flags=True, answer=True, authority=True,
            additional=False, rcode=True):
        '''
        Asks server for the same question and compares specified sections

        The Additional section is not compared by default.
        '''

        resp = server.dig(**self.args)
        self.diff(resp, flags, answer, authority, additional, rcode)

    def count(self, rtype=None, section="answer"):
        '''Returns number of records of given type in specified section

        rtype defaults to the queried type and may be given as text; the
        type ANY counts records of all types. If the response is an
        iterable of messages, the records of all messages are summed.
        '''

        if not rtype:
            rtype = self.rtype
        elif isinstance(rtype, str):
            rtype = dns.rdatatype.from_text(rtype)

        if isinstance(self.resp, collections.abc.Iterable):
            # Iterate over a copy so that self.resp can be read again later.
            self.resp, messages = itertools.tee(self.resp)
        else:
            messages = [self.resp]

        cnt = 0
        for msg in messages:
            for rrset in self._section_rrsets(msg, section):
                if rrset.rdtype == rtype or rtype == dns.rdatatype.ANY:
                    cnt += len(rrset)

        return cnt

    def check_count(self, expected, rtype=None, section="answer"):
        '''Checks the number of records of given type in specified section'''

        found = self.count(rtype, section)
        if found != expected:
            set_err("CHECK RR COUNT")
            check_log("ERROR: CHECK RR COUNT")
            detail_log("!Invalid RR count type=%s section=%s %d!=%d" % (
                rtype if rtype is not None else "",
                section, found, expected
            ))
            detail_log(SEP)

    def check_counts(self, answer=None, authority=None, additional=None):
        '''Checks the total number of records in each section.

        A section whose expected count is None is not checked.
        '''

        expectations = (
            ("answer", answer),
            ("authority", authority),
            ("additional", additional),
        )
        for section, expected in expectations:
            if expected is None:
                continue
            section_count = self.count(rtype="ANY", section=section)
            if section_count != expected:
                set_err("CHECK RR COUNT")
                check_log("ERROR: CHECK RR COUNT")
                detail_log("!RR count %i != %i in section=%s" %
                           (section_count, expected, section))
                detail_log(SEP)

    def check_empty(self, section="answer"):
        '''Checks that the section holds no record of the queried type.'''

        self.check_count(0, None, section)

    def msg_count(self):
        '''Returns number of response messages'''

        # Iterate over a copy so that self.resp can be read again later.
        self.resp, iter_copy = itertools.tee(self.resp)
        return sum(1 for _ in iter_copy)

    def soa_serial(self, section="answer"):
        '''Returns the SOA serial found in specified section.

        If the section does not hold exactly one SOA record, an error is
        reported and 0 is returned.
        '''

        if self.count("SOA", section) != 1:
            set_err("CHECK SOA PRESENCE")
            detail_log("SOA not present in response section " + section)
            return 0

        # Pick the SOA by type, it needn't be the first RRset of the section.
        soa_rrset = next(rrset
                         for rrset in self._section_rrsets(self.resp, section)
                         if rrset.rdtype == dns.rdatatype.SOA)

        # Text fields: TTL CLASS TYPE MNAME RNAME SERIAL ...
        soa = str(soa_rrset.to_rdataset())
        return int(soa.split()[5])

    def check_soa_serial(self, expect, section="answer"):
        '''Checks the SOA serial found in specified section.'''

        found = self.soa_serial(section)
        if found != expect:
            set_err("CHECK SOA SERIAL")
            detail_log("SOA serial different than expected: %d != %d" % (found, expect))

    def check_auth_soa_ttl(self, dnssec=False):
        '''Checks the TTL of the SOA in the authority section.

        The TTL must not be higher than the SOA minimum-ttl field. With
        dnssec set, a RRSIG covering the SOA must be present in the
        authority section with the same TTL as the SOA.
        '''

        if self.count("SOA", "authority") != 1:
            set_err("CHECK SOA PRESENCE")
            detail_log("SOA not present in response section authority")
            # Nothing to examine; continuing would read a missing or
            # non-SOA record.
            return

        # Pick the SOA by type, it needn't be the first RRset of the section.
        soa_rrset = next(rrset for rrset in self.resp.authority
                         if rrset.rdtype == dns.rdatatype.SOA)

        # Text fields: TTL CLASS TYPE MNAME RNAME SERIAL REFRESH RETRY
        # EXPIRE MINIMUM
        soa = str(soa_rrset.to_rdataset()).split()
        if int(soa[0]) > int(soa[9]):
            set_err("AUTHORITY SOA TTL")
            detail_log("SOA TTL %d is higher that its minimum-ttl %d" % (int(soa[0]), int(soa[9])))

        if dnssec:
            rrsig = None
            for record in self.resp.authority:
                # Text fields: TTL CLASS TYPE TYPE-COVERED ...
                candidate = str(record.to_rdataset()).split()
                if candidate[2] == "RRSIG" and candidate[3] == "SOA":
                    rrsig = candidate

            if rrsig is None:
                set_err("CHECK RRSIG PRESENCE")
                detail_log("SOA not signed in response section authority")

            elif int(rrsig[0]) != int(soa[0]):
                set_err("AUTHORITY SOA RRSIG TTL")
                detail_log("RRSIG TTL %d differs from SOA TTL %d" % (int(rrsig[0]), int(soa[0])))

    def check_nsec(self, nsec3=False, nonsec=False):
        '''Checks if the response contains NSEC(3) records.

        With nonsec set, neither NSEC nor NSEC3 records may be present in
        the authority section. Otherwise NSEC3 (if nsec3 is set) or NSEC
        records must be present and records of the other kind must not.
        '''

        nsec_rrs = list()
        nsec3_rrs = list()
        for data in self.resp.authority:
            rdtype = data.to_rdataset().rdtype
            if rdtype == dns.rdatatype.NSEC:
                nsec_rrs.extend(data.to_text().split("\n"))
            elif rdtype == dns.rdatatype.NSEC3:
                nsec3_rrs.extend(data.to_text().split("\n"))

        if nonsec:
            if nsec_rrs or nsec3_rrs:
                self._report_nsec_error("CHECK NSEC(3) ABSENCE",
                                        nsec_rrs + nsec3_rrs)
            return

        if nsec3:
            if not nsec3_rrs:
                self._report_nsec_error("CHECK NSEC3 PRESENCE")
            if nsec_rrs:
                self._report_nsec_error("CHECK NSEC3", nsec_rrs)
        else:
            if not nsec_rrs:
                self._report_nsec_error("CHECK NSEC PRESENCE")
            if nsec3_rrs:
                self._report_nsec_error("CHECK NSEC", nsec3_rrs)

    def query_size(self):
        '''Return query size.'''

        return len(self.query.to_wire())

    def response_size(self):
        '''Return response size.'''

        return len(self.resp.to_wire())

    def rcode(self):
        '''Returns the response rcode as text.'''

        return dns.rcode.to_text(self.resp.rcode())
|