mirror of
https://github.com/isc-projects/bind9.git
synced 2026-06-19 08:28:53 -04:00
Rather than using the dnspython's facilities and defaults to create the queries, use the isctest.query.create function in all the cases that don't require special handling to have consistent defaults.
213 lines
6.1 KiB
Python
213 lines
6.1 KiB
Python
# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
|
|
#
|
|
# SPDX-License-Identifier: MPL-2.0
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
|
|
#
|
|
# See the COPYRIGHT file distributed with this work for additional
|
|
# information regarding copyright ownership.
|
|
|
|
import difflib
|
|
import shutil
|
|
import os
|
|
from typing import Optional
|
|
|
|
import dns.flags
|
|
import dns.message
|
|
import dns.rcode
|
|
import dns.zone
|
|
|
|
import isctest.log
|
|
from isctest.compat import dns_rcode
|
|
|
|
|
|
def rcode(message: dns.message.Message, expected_rcode) -> None:
|
|
assert message.rcode() == expected_rcode, str(message)
|
|
|
|
|
|
def noerror(message: dns.message.Message) -> None:
|
|
rcode(message, dns_rcode.NOERROR)
|
|
|
|
|
|
def notimp(message: dns.message.Message) -> None:
|
|
rcode(message, dns_rcode.NOTIMP)
|
|
|
|
|
|
def refused(message: dns.message.Message) -> None:
|
|
rcode(message, dns_rcode.REFUSED)
|
|
|
|
|
|
def servfail(message: dns.message.Message) -> None:
|
|
rcode(message, dns_rcode.SERVFAIL)
|
|
|
|
|
|
def adflag(message: dns.message.Message) -> None:
|
|
assert (message.flags & dns.flags.AD) != 0, str(message)
|
|
|
|
|
|
def noadflag(message: dns.message.Message) -> None:
|
|
assert (message.flags & dns.flags.AD) == 0, str(message)
|
|
|
|
|
|
def rdflag(message: dns.message.Message) -> None:
|
|
assert (message.flags & dns.flags.RD) != 0, str(message)
|
|
|
|
|
|
def nordflag(message: dns.message.Message) -> None:
|
|
assert (message.flags & dns.flags.RD) == 0, str(message)
|
|
|
|
|
|
def raflag(message: dns.message.Message) -> None:
|
|
assert (message.flags & dns.flags.RA) != 0, str(message)
|
|
|
|
|
|
def noraflag(message: dns.message.Message) -> None:
|
|
assert (message.flags & dns.flags.RA) == 0, str(message)
|
|
|
|
|
|
def section_equal(first_section: list, second_section: list) -> None:
|
|
for rrset in first_section:
|
|
assert (
|
|
rrset in second_section
|
|
), f"No corresponding RRset found in second section: {rrset}"
|
|
for rrset in second_section:
|
|
assert (
|
|
rrset in first_section
|
|
), f"No corresponding RRset found in first section: {rrset}"
|
|
|
|
|
|
def same_data(res1: dns.message.Message, res2: dns.message.Message):
|
|
section_equal(res1.question, res2.question)
|
|
section_equal(res1.answer, res2.answer)
|
|
section_equal(res1.authority, res2.authority)
|
|
section_equal(res1.additional, res2.additional)
|
|
assert res1.rcode() == res2.rcode()
|
|
|
|
|
|
def same_answer(res1: dns.message.Message, res2: dns.message.Message):
|
|
section_equal(res1.question, res2.question)
|
|
section_equal(res1.answer, res2.answer)
|
|
assert res1.rcode() == res2.rcode()
|
|
|
|
|
|
def rrsets_equal(
|
|
first_rrset: dns.rrset.RRset,
|
|
second_rrset: dns.rrset.RRset,
|
|
compare_ttl: Optional[bool] = False,
|
|
) -> None:
|
|
"""Compare two RRset (optionally including TTL)"""
|
|
|
|
def compare_rrs(rr1, rrset):
|
|
rr2 = next((other_rr for other_rr in rrset if rr1 == other_rr), None)
|
|
assert rr2 is not None, f"No corresponding RR found for: {rr1}"
|
|
if compare_ttl:
|
|
assert rr1.ttl == rr2.ttl
|
|
|
|
isctest.log.debug(
|
|
"%s() first RRset:\n%s",
|
|
rrsets_equal.__name__,
|
|
"\n".join([str(rr) for rr in first_rrset]),
|
|
)
|
|
isctest.log.debug(
|
|
"%s() second RRset:\n%s",
|
|
rrsets_equal.__name__,
|
|
"\n".join([str(rr) for rr in second_rrset]),
|
|
)
|
|
for rr in first_rrset:
|
|
compare_rrs(rr, second_rrset)
|
|
for rr in second_rrset:
|
|
compare_rrs(rr, first_rrset)
|
|
|
|
|
|
def zones_equal(
|
|
first_zone: dns.zone.Zone,
|
|
second_zone: dns.zone.Zone,
|
|
compare_ttl: Optional[bool] = False,
|
|
) -> None:
|
|
"""Compare two zones (optionally including TTL)"""
|
|
|
|
isctest.log.debug(
|
|
"%s() first zone:\n%s",
|
|
zones_equal.__name__,
|
|
first_zone.to_text(relativize=False),
|
|
)
|
|
isctest.log.debug(
|
|
"%s() second zone:\n%s",
|
|
zones_equal.__name__,
|
|
second_zone.to_text(relativize=False),
|
|
)
|
|
assert first_zone == second_zone
|
|
if compare_ttl:
|
|
for name, node in first_zone.nodes.items():
|
|
for rdataset in node:
|
|
found_rdataset = second_zone.find_rdataset(
|
|
name=name, rdtype=rdataset.rdtype
|
|
)
|
|
assert found_rdataset
|
|
assert found_rdataset.ttl == rdataset.ttl
|
|
|
|
|
|
def is_executable(cmd: str, errmsg: str) -> None:
|
|
executable = shutil.which(cmd)
|
|
assert executable is not None, errmsg
|
|
|
|
|
|
def named_alive(named_proc, resolver_ip):
|
|
assert named_proc.poll() is None, "named isn't running"
|
|
msg = isctest.query.create("version.bind", "TXT", "CH")
|
|
isctest.query.tcp(msg, resolver_ip, expected_rcode=dns_rcode.NOERROR)
|
|
|
|
|
|
def notauth(message: dns.message.Message) -> None:
|
|
rcode(message, dns.rcode.NOTAUTH)
|
|
|
|
|
|
def nxdomain(message: dns.message.Message) -> None:
|
|
rcode(message, dns.rcode.NXDOMAIN)
|
|
|
|
|
|
def single_question(message: dns.message.Message) -> None:
|
|
assert len(message.question) == 1, str(message)
|
|
|
|
|
|
def empty_answer(message: dns.message.Message) -> None:
|
|
assert not message.answer, str(message)
|
|
|
|
|
|
def rr_count_eq(section: list, expected: int):
|
|
# NOTE: OPT and TSIG records aren't included in the count for ADDITIONAL section
|
|
count = sum(len(rrset) for rrset in section)
|
|
assert count == expected, str(section)
|
|
|
|
|
|
def is_response_to(response: dns.message.Message, query: dns.message.Message) -> None:
|
|
single_question(response)
|
|
single_question(query)
|
|
assert query.is_response(response), str(response)
|
|
|
|
|
|
def file_contents_equal(file1, file2):
|
|
def normalize_line(line):
|
|
# remove trailing&leading whitespace and replace multiple whitespaces
|
|
return " ".join(line.split())
|
|
|
|
def read_lines(file_path):
|
|
with open(file_path, "r", encoding="utf-8") as file:
|
|
return [normalize_line(line) for line in file.readlines()]
|
|
|
|
lines1 = read_lines(file1)
|
|
lines2 = read_lines(file2)
|
|
|
|
differ = difflib.Differ()
|
|
diff = differ.compare(lines1, lines2)
|
|
|
|
for line in diff:
|
|
assert not line.startswith("+ ") and not line.startswith(
|
|
"- "
|
|
), f'file contents of "{file1}" and "{file2}" differ'
|
|
|
|
|
|
def file_empty(file):
|
|
assert os.path.getsize(file) == 0
|