opnsense-src/tests/atf_python/sys/net/rtsock.py
Alexander V. Chernikov 8eb2bee6c0 testing: Add basic atf support to pytest.
Implementation consists of the pytest plugin implementing ATF format and
a simple C++ wrapper, which reorders the provided arguments from ATF format
to the format understandable by pytest. Each test has this wrapper specified
after the shebang. When kyua executes the test, the wrapper calls pytest, which
loads atf plugin, does the work and returns the result. Additionally, a
separate python "package", `/usr/tests/atf_python` has been added to collect
code that may be useful across different tests.

Current limitations:
* Opaque metadata passing via X-Name properties. Requires some fixtures to be written.
* `-s srcdir` parameter passed by the runner is ignored.
* No `atf-c-api(3)` or similar - relying on pytest framework & existing python libraries
* No support for `atf_tc_<get|has>_config_var()` & `atf_tc_set_md_var()`.
 Can probably be implemented with env variables & autoload fixtures.

Differential Revision: https://reviews.freebsd.org/D31084
Reviewed by:	kp, ngie
2022-06-25 19:25:15 +00:00

604 lines
18 KiB
Python
Executable file

#!/usr/local/bin/python3
import os
import socket
import struct
import sys
from ctypes import c_byte
from ctypes import c_char
from ctypes import c_int
from ctypes import c_long
from ctypes import c_ubyte
from ctypes import c_uint32
from ctypes import c_ulong
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
def roundup2(val: int, num: int) -> int:
    """Round ``val`` up to the next multiple of ``num``.

    ``num`` is expected to be a power of two: the bit trick below only
    lands on a multiple of ``num`` in that case. Values that are already
    multiples of ``num`` (including 0) are returned unchanged.
    """
    if not val % num:
        return val
    # Set every bit below the boundary, then carry into the next boundary.
    return (val | (num - 1)) + 1
class RtSockException(OSError):
    """Raised when a routing-socket sockaddr cannot be decoded or printed."""
class RtConst:
    """Routing-socket constants plus helpers that map values back to names.

    Related constants share a name prefix (``RTM_`` message types, ``RTA_``
    address bits, ``RTF_`` route flags, ``RTV_`` metric bits, ``AF_``
    families). The static helpers treat every class attribute carrying a
    given prefix as a lookup table for that value set.
    """

    RTM_VERSION = 5
    # Sockaddrs inside a routing message are padded to this boundary.
    ALIGN = sizeof(c_long)

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # Use the platform constant when available. Otherwise fall back to the
    # FreeBSD value (18) so the module still imports on hosts whose socket
    # module has no AF_LINK.
    AF_LINK = getattr(socket, "AF_LINK", 18)

    RTA_DST = 0x1
    RTA_GATEWAY = 0x2
    RTA_NETMASK = 0x4
    RTA_GENMASK = 0x8
    RTA_IFP = 0x10
    RTA_IFA = 0x20
    RTA_AUTHOR = 0x40
    RTA_BRD = 0x80

    RTM_ADD = 1
    RTM_DELETE = 2
    RTM_CHANGE = 3
    RTM_GET = 4

    RTF_UP = 0x1
    RTF_GATEWAY = 0x2
    RTF_HOST = 0x4
    RTF_REJECT = 0x8
    RTF_DYNAMIC = 0x10
    RTF_MODIFIED = 0x20
    RTF_DONE = 0x40
    RTF_XRESOLVE = 0x200
    RTF_LLINFO = 0x400
    RTF_LLDATA = 0x400
    RTF_STATIC = 0x800
    RTF_BLACKHOLE = 0x1000
    RTF_PROTO2 = 0x4000
    RTF_PROTO1 = 0x8000
    RTF_PROTO3 = 0x40000
    RTF_FIXEDMTU = 0x80000
    RTF_PINNED = 0x100000
    RTF_LOCAL = 0x200000
    RTF_BROADCAST = 0x400000
    RTF_MULTICAST = 0x800000
    RTF_STICKY = 0x10000000
    RTF_RNH_LOCKED = 0x40000000
    RTF_GWFLAG_COMPAT = 0x80000000

    RTV_MTU = 0x1
    RTV_HOPCOUNT = 0x2
    RTV_EXPIRE = 0x4
    RTV_RPIPE = 0x8
    RTV_SPIPE = 0x10
    RTV_SSTHRESH = 0x20
    RTV_RTT = 0x40
    RTV_RTTVAR = 0x80
    RTV_WEIGHT = 0x100

    # Attributes that carry a lookup prefix but are not members of that
    # value set. RTM_VERSION is the protocol version, not a message type;
    # without this exclusion get_name("RTM_", 5) would report "RTM_VERSION".
    _NON_MEMBER_PROPS = frozenset({"RTM_VERSION"})

    @staticmethod
    def get_props(prefix: str) -> List[str]:
        """Return the names (in ``dir()`` order) of constants with ``prefix``."""
        return [
            n
            for n in dir(RtConst)
            if n.startswith(prefix) and n not in RtConst._NON_MEMBER_PROPS
        ]

    @staticmethod
    def get_name(prefix: str, value: int) -> str:
        """Return the first constant name with ``prefix`` whose value is ``value``.

        Unknown values are rendered as ``"U:<prefix>:<value>"``.
        """
        for prop in RtConst.get_props(prefix):
            if getattr(RtConst, prop) == value:
                return prop
        return "U:{}:{}".format(prefix, value)

    @staticmethod
    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
        """Split bitmask ``value`` into ``{bit: name}``, in ascending bit order.

        Bits without a matching constant are named by their hex value.
        """
        props = RtConst.get_props(prefix)
        propmap = {getattr(RtConst, prop): prop for prop in props}
        if value < 0:
            # Flag words are decoded from signed 32-bit C ints (c_int), so a
            # set top bit arrives as a negative int. Reinterpret it as
            # unsigned; otherwise the loop below would never reach zero.
            value &= 0xFFFFFFFF
        ret = {}
        bit = 1
        while value:
            if value & bit:
                ret[bit] = propmap.get(bit, hex(bit))
                value &= ~bit
            bit <<= 1
        return ret

    @staticmethod
    def get_bitmask_str(prefix: str, value: int) -> str:
        """Return the comma-separated names of the bits set in ``value``."""
        return ",".join(RtConst.get_bitmask_map(prefix, value).values())
class RtMetrics(Structure):
    """ctypes view of the per-route metrics block embedded in ``RtMsgHdr``.

    Presumably mirrors the kernel's ``struct rt_metrics``. The field order
    below defines the binary layout, so it must not be rearranged.
    """

    # Twelve unsigned-long counters in layout order, then a two-slot filler.
    _fields_ = [
        (field_name, c_ulong)
        for field_name in (
            "rmx_locks",
            "rmx_mtu",
            "rmx_hopcount",
            "rmx_expire",
            "rmx_recvpipe",
            "rmx_sendpipe",
            "rmx_ssthresh",
            "rmx_rtt",
            "rmx_rttvar",
            "rmx_pksent",
            "rmx_weight",
            "rmx_nhidx",
        )
    ] + [("rmx_filler", c_ulong * 2)]
class RtMsgHdr(Structure):
    """ctypes view of the fixed header at the start of every routing message.

    Decoded from received data with ``from_buffer_copy`` and built directly
    for outgoing messages. The field order defines the binary layout.
    Presumably mirrors ``struct rt_msghdr`` from <net/route.h> -- TODO confirm
    against the header when changing anything here.
    """

    _fields_ = [
        ("rtm_msglen", c_ushort),  # whole message length in bytes, header included
        # NOTE(review): c_byte is signed, so values >= 0x80 would read back
        # negative; current RTM_VERSION/RTM_* values stay below that.
        ("rtm_version", c_byte),
        ("rtm_type", c_byte),  # RTM_* message type
        ("rtm_index", c_ushort),
        ("_rtm_spare1", c_ushort),  # padding; RtsockRtMessage.verify() expects 0
        ("rtm_flags", c_int),  # RTF_* bitmask
        ("rtm_addrs", c_int),  # RTA_* bitmask of the sockaddrs that follow
        ("rtm_pid", c_int),
        ("rtm_seq", c_int),  # sequence number used to match replies
        ("rtm_errno", c_int),
        ("rtm_fmask", c_int),
        ("rtm_inits", c_ulong),
        ("rtm_rmx", RtMetrics),
    ]
class SockaddrIn(Structure):
    """ctypes layout of a 16-byte IPv4 socket address (length-prefixed)."""

    _fields_ = [
        ("sin_len", c_byte),  # total sockaddr length in bytes
        ("sin_family", c_byte),  # address family (AF_INET)
        ("sin_port", c_ushort),
        # Holds the address bytes in network order; SaHelper.ip4_sa fills it
        # with an int decoded in host byte order so the raw bytes round-trip.
        ("sin_addr", c_uint32),
        # NOTE: reading a c_char array field yields bytes cut at the first
        # NUL, not a list of ints; compare raw buffer bytes when checking it.
        ("sin_zero", c_char * 8),
    ]
class SockaddrIn6(Structure):
    """ctypes layout of a 28-byte IPv6 socket address (length-prefixed)."""

    _fields_ = [
        ("sin6_len", c_byte),  # total sockaddr length in bytes
        ("sin6_family", c_byte),  # address family (AF_INET6)
        ("sin6_port", c_ushort),
        ("sin6_flowinfo", c_uint32),
        ("sin6_addr", c_byte * 16),  # raw address bytes, as from inet_pton
        ("sin6_scope_id", c_uint32),  # stored in native (host) byte order
    ]
class SockaddrDl(Structure):
    """ctypes layout of a 16-byte link-layer (AF_LINK) socket address.

    Only 8 bytes of ``sdl_data`` are declared, so this appears to mirror a
    short form of ``struct sockaddr_dl`` -- TODO confirm against <net/if_dl.h>.
    The one-byte header fields are unsigned (``c_ubyte``), so values such
    as interface types >= 0x80 read back as positive numbers instead of
    negative signed bytes.
    """

    _fields_ = [
        ("sdl_len", c_ubyte),  # total sockaddr length in bytes
        ("sdl_family", c_ubyte),  # address family (AF_LINK)
        ("sdl_index", c_ushort),  # interface index, 0 if unset
        ("sdl_type", c_ubyte),  # interface type
        ("sdl_nlen", c_ubyte),  # length of the interface name in sdl_data
        ("sdl_alen", c_ubyte),  # link-level address length
        ("sdl_slen", c_ubyte),  # link-layer selector length
        ("sdl_data", c_byte * 8),  # interface name followed by address bytes
    ]
class SaHelper:
    """Build and pretty-print raw sockaddr byte strings.

    Each sockaddr handled here begins with a one-byte length followed by a
    one-byte address family; ``print_sa`` relies on both.
    """

    @staticmethod
    def is_ipv6(ip: str) -> bool:
        """Return True if ``ip`` is treated as an IPv6 literal (contains ':')."""
        return ":" in ip

    @staticmethod
    def ip_sa(ip: str, scopeid: int = 0) -> bytes:
        """Return a sockaddr for ``ip``; ``scopeid`` is only used for IPv6."""
        if SaHelper.is_ipv6(ip):
            return SaHelper.ip6_sa(ip, scopeid)
        return SaHelper.ip4_sa(ip)

    @staticmethod
    def ip4_sa(ip: str) -> bytes:
        """Return a ``SockaddrIn`` (port 0) for dotted-quad ``ip`` as bytes."""
        # inet_pton yields network-order bytes. Decoding them in host order
        # and storing through the native c_uint32 writes the same bytes back.
        addr_int = int.from_bytes(
            socket.inet_pton(socket.AF_INET, ip), sys.byteorder
        )
        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
        return bytes(sin)

    @staticmethod
    def ip6_sa(ip6: str, scopeid: int) -> bytes:
        """Return a ``SockaddrIn6`` for ``ip6`` with scope id ``scopeid``."""
        addr_bytes = (c_byte * 16)()
        for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)):
            addr_bytes[i] = b
        sin6 = SockaddrIn6(
            sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
        )
        return bytes(sin6)

    @staticmethod
    def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
        """Return a ``SockaddrDl`` carrying only an interface index and type."""
        sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype)
        return bytes(sa)

    @staticmethod
    def pxlen4_sa(pxlen: int) -> bytes:
        """Return an IPv4 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))

    @staticmethod
    def pxlen_to_ip4(pxlen: int) -> str:
        """Return the dotted-quad netmask for prefix length ``pxlen`` (0-32)."""
        if pxlen == 32:
            return "255.255.255.255"
        addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET, struct.pack("!I", addr))

    @staticmethod
    def pxlen6_sa(pxlen: int) -> bytes:
        """Return an IPv6 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))

    @staticmethod
    def pxlen_to_ip6(pxlen: int) -> str:
        """Return the textual IPv6 netmask for prefix length ``pxlen`` (0-128)."""
        ip6_b = [0] * 16
        start = 0
        # Fill whole 0xFF bytes first, then the partial final byte.
        while pxlen > 8:
            ip6_b[start] = 0xFF
            pxlen -= 8
            start += 1
        ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))

    @staticmethod
    def print_sa_inet(sa: bytes):
        """Return the dotted-quad address from an IPv4 sockaddr."""
        if len(sa) < 8:
            raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
        return "{}".format(addr)

    @staticmethod
    def print_sa_inet6(sa: bytes):
        """Return ``"<addr> scopeid <n>"`` for an IPv6 sockaddr."""
        if len(sa) < sizeof(SockaddrIn6):
            raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
        # sin6_scope_id is a host-order uint32 (ip6_sa writes it through a
        # native c_uint32), so it must be decoded natively, not big-endian.
        scopeid = struct.unpack("=I", sa[24:28])[0]
        return "{} scopeid {}".format(addr, scopeid)

    @staticmethod
    def print_sa_link(sa: bytes, hd: Optional[bool] = True):
        """Return ``"link#<idx> ifname:<name> "`` parts present in an AF_LINK sa.

        ``hd`` is accepted for signature compatibility and is unused.
        """
        if len(sa) < sizeof(SockaddrDl):
            raise RtSockException("LINK sa size too small: {}".format(len(sa)))
        sdl = SockaddrDl.from_buffer_copy(sa)
        if sdl.sdl_index:
            ifindex = "link#{} ".format(sdl.sdl_index)
        else:
            ifindex = ""
        if sdl.sdl_nlen:
            # The interface name starts at sdl_data, right after the header.
            iface_offset = 8
            if sdl.sdl_nlen + iface_offset > len(sa):
                raise RtSockException(
                    "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
                )
            ifname = "ifname:{} ".format(
                sa[iface_offset : iface_offset + sdl.sdl_nlen].decode()
            )
        else:
            ifname = ""
        return "{}{}".format(ifindex, ifname)

    @staticmethod
    def print_sa_unknown(sa: bytes):
        """Return a placeholder naming the unrecognised address family."""
        return "unknown_type:{}".format(sa[1])

    @classmethod
    def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
        """Return a human-readable form of sockaddr ``sa``.

        When ``hd`` is true the raw bytes are appended as ``[repr]``.

        Raises:
            Exception: if ``sa`` is shorter than its length/family header or
                its length byte disagrees with the buffer size.
        """
        # Check the length before indexing, so empty or 1-byte input gives
        # a descriptive error instead of an IndexError.
        if len(sa) < 2:
            raise Exception("sa too short: {}".format(len(sa)))
        if sa[0] != len(sa):
            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))
        family = sa[1]
        if family == socket.AF_INET:
            text = cls.print_sa_inet(sa)
        elif family == socket.AF_INET6:
            text = cls.print_sa_inet6(sa)
        elif family == socket.AF_LINK:
            text = cls.print_sa_link(sa)
        else:
            text = cls.print_sa_unknown(sa)
        dump = " [{!r}]".format(sa) if hd else ""
        return "{}{}".format(text, dump)
class BaseRtsockMessage:
    """Common base for routing-socket message wrappers.

    Holds the numeric message type and a ``SaHelper`` instance that
    subclasses use to build sockaddr payloads.
    """

    def __init__(self, rtm_type):
        self.rtm_type = rtm_type
        self.sa = SaHelper()

    @staticmethod
    def print_rtm_type(rtm_type):
        """Return the RTM_* constant name for ``rtm_type``."""
        return RtConst.get_name(prefix="RTM_", value=rtm_type)

    @property
    def rtm_type_str(self):
        """Symbolic name of this message's type."""
        type_name = self.print_rtm_type(self.rtm_type)
        return type_name
class RtsockRtMessage(BaseRtsockMessage):
    """Route message: an ``RtMsgHdr`` followed by optional sockaddrs.

    Sockaddrs are kept in ``_attrs`` keyed by their ``RTA_*`` bit. They are
    serialized in ascending bit order, each padded to ``RtConst.ALIGN``,
    which is also the order ``from_bytes`` decodes them in.
    """

    # RTM_* types this class encodes and decodes.
    messages = [
        RtConst.RTM_ADD,
        RtConst.RTM_DELETE,
        RtConst.RTM_CHANGE,
        RtConst.RTM_GET,
    ]

    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
        """Create a message; ``dst_sa``/``mask_sa`` become RTA_DST/RTA_NETMASK."""
        super().__init__(rtm_type)
        self.rtm_flags = 0
        self.rtm_seq = rtm_seq
        self._attrs = {}
        self.rtm_errno = 0
        self.rtm_pid = 0
        self.rtm_inits = 0
        self.rtm_rmx = RtMetrics()
        # Raw bytes the message was decoded from; only set by from_bytes().
        self._orig_data = None
        if dst_sa:
            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
        if mask_sa:
            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)

    def add_sa_attr(self, attr_type, attr_bytes: bytes):
        """Store raw sockaddr bytes under the ``RTA_*`` bit ``attr_type``."""
        self._attrs[attr_type] = attr_bytes

    def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
        """Store an IPv4 or IPv6 sockaddr built from ``ip_addr``."""
        if self.sa.is_ipv6(ip_addr):
            self.add_ip6_attr(attr_type, ip_addr, scopeid)
        else:
            self.add_ip4_attr(attr_type, ip_addr)

    def add_ip4_attr(self, attr_type, ip: str):
        """Store an IPv4 sockaddr for ``ip``."""
        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))

    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
        """Store an IPv6 sockaddr for ``ip6`` with scope id ``scopeid``."""
        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))

    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
        """Store an AF_LINK sockaddr for interface index ``ifindex``."""
        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))

    def get_sa(self, attr_type) -> Optional[bytes]:
        """Return the sockaddr stored for ``attr_type``, or None if absent."""
        return self._attrs.get(attr_type)

    def print_message(self):
        """Print the header summary and every attached sockaddr."""
        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
        if self._orig_data:
            rtm_len = len(self._orig_data)
        else:
            rtm_len = len(bytes(self))
        print(
            "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
                self.rtm_type_str,
                rtm_len,
                self.rtm_pid,
                self.rtm_seq,
                self.rtm_errno,
                RtConst.get_bitmask_str("RTF_", self.rtm_flags),
            )
        )
        # Keys are distinct single RTA_* bits, so their sum is the mask.
        rtm_addrs = sum(self._attrs)
        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
        for attr in sorted(self._attrs):
            sa_data = SaHelper.print_sa(self._attrs[attr])
            print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))

    def print_in_message(self):
        """Print the message framed as an incoming one."""
        print("vvvvvvvv IN vvvvvvvv")
        self.print_message()
        print()

    def verify_sa_inet(self, sa_data):
        """Assert that ``sa_data`` is a well-formed IPv4 sockaddr.

        Checks that the port is 0 and all eight sin_zero bytes are zero.
        """
        # from_buffer_copy() needs a full SockaddrIn, not just 8 bytes.
        if len(sa_data) < sizeof(SockaddrIn):
            raise Exception("IPv4 sa size too small: {}".format(sa_data))
        if sa_data[0] > len(sa_data):
            raise Exception(
                "IPv4 sin_len too big: {} vs sa size {}: {}".format(
                    sa_data[0], len(sa_data), sa_data
                )
            )
        sin = SockaddrIn.from_buffer_copy(sa_data)
        assert sin.sin_port == 0
        # The sin_zero field reads back as bytes truncated at the first NUL,
        # so comparing it to a list could never succeed. Check the raw bytes.
        assert sa_data[8:16] == bytes(8)

    def compare_sa(self, sa_type, sa_data):
        """Assert that ``sa_data`` equals our stored sockaddr for ``sa_type``."""
        if len(sa_data) < 4:
            sa_type_name = RtConst.get_name("RTA_", sa_type)
            raise Exception(
                "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
            )
        our_sa = self._attrs[sa_type]
        assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
        assert len(sa_data) == len(our_sa)
        assert sa_data == our_sa

    def verify(self, rtm_type: int, rtm_sa):
        """Assert type, errno, padding, and the sockaddrs in ``rtm_sa``.

        ``rtm_sa`` maps RTA_* bits to expected sockaddr bytes. Only valid on
        messages created by ``from_bytes`` (it re-reads ``_orig_data``).
        """
        assert self.rtm_type_str == self.print_rtm_type(rtm_type)
        assert self.rtm_errno == 0
        hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
        assert hdr._rtm_spare1 == 0
        for sa_type, sa_data in rtm_sa.items():
            if sa_type not in self._attrs:
                sa_type_name = RtConst.get_name("RTA_", sa_type)
                raise Exception("SA type {} not present".format(sa_type_name))
            self.compare_sa(sa_type, sa_data)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Decode a message (header plus sockaddrs) from raw socket bytes.

        Raises:
            Exception: if the data is shorter than the header, or a sockaddr
                announced in ``rtm_addrs`` does not fit in ``data``.
        """
        if len(data) < sizeof(RtMsgHdr):
            raise Exception(
                "messages size {} is less than expected {}".format(
                    len(data), sizeof(RtMsgHdr)
                )
            )
        hdr = RtMsgHdr.from_buffer_copy(data)
        msg = cls(hdr.rtm_type)
        msg.rtm_flags = hdr.rtm_flags
        msg.rtm_seq = hdr.rtm_seq
        msg.rtm_errno = hdr.rtm_errno
        msg.rtm_pid = hdr.rtm_pid
        msg.rtm_inits = hdr.rtm_inits
        msg.rtm_rmx = hdr.rtm_rmx
        msg._orig_data = data
        off = sizeof(RtMsgHdr)
        bit = 1
        # rtm_addrs is a signed c_int; mask it so the loop always terminates.
        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
        while addrs_mask:
            if addrs_mask & bit:
                addrs_mask -= bit
                if off >= len(data):
                    raise Exception(
                        "SA for {} starts past message end: {} >= {}".format(
                            RtConst.get_name("RTA_", bit), off, len(data)
                        )
                    )
                sa_len = data[off]
                if off + sa_len > len(data):
                    raise Exception(
                        "SA sizeof for {} > total message length: {}+{} > {}".format(
                            RtConst.get_name("RTA_", bit), off, sa_len, len(data)
                        )
                    )
                msg._attrs[bit] = data[off : off + sa_len]
                # A zero-length sockaddr still occupies one aligned slot
                # (SA_SIZE() in <net/route.h>); otherwise pad to ALIGN.
                off += roundup2(sa_len, RtConst.ALIGN) if sa_len else RtConst.ALIGN
            bit *= 2
        return msg

    def __bytes__(self):
        """Serialize header plus sockaddrs (ascending RTA_* bit, ALIGN-padded)."""
        sz = sizeof(RtMsgHdr)
        addrs_mask = 0
        for k, v in self._attrs.items():
            sz += roundup2(len(v), RtConst.ALIGN)
            addrs_mask |= k
        hdr = RtMsgHdr(
            rtm_msglen=sz,
            rtm_version=RtConst.RTM_VERSION,
            rtm_type=self.rtm_type,
            rtm_flags=self.rtm_flags,
            rtm_seq=self.rtm_seq,
            rtm_addrs=addrs_mask,
            rtm_inits=self.rtm_inits,
            rtm_rmx=self.rtm_rmx,
        )
        buf = bytearray(sz)
        buf[0 : sizeof(RtMsgHdr)] = hdr
        off = sizeof(RtMsgHdr)
        for attr in sorted(self._attrs):
            v = self._attrs[attr]
            buf[off : off + len(v)] = v
            off += roundup2(len(v), RtConst.ALIGN)
        return bytes(buf)
class Rtsock:
    """Minimal client for a raw ``AF_ROUTE`` routing socket.

    Builds route messages with increasing sequence numbers, writes them to
    the socket and reads (optionally sequence-matched) replies.
    """

    def __init__(self):
        self.socket = self._setup_rtsock()
        self.rtm_seq = 1
        self.msgmap = self.build_msgmap()

    def build_msgmap(self):
        """Return a dict mapping each supported RTM_* type to its decoder class."""
        classes = [RtsockRtMessage]
        return {message: cls for cls in classes for message in cls.messages}

    def get_seq(self):
        """Return the next sequence number and advance the counter."""
        ret = self.rtm_seq
        self.rtm_seq += 1
        return ret

    def get_weight(self, weight) -> int:
        """Return ``weight``, or 1 when it is falsy (None or 0)."""
        if weight:
            return weight
        return 1  # RT_DEFAULT_WEIGHT

    def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
        """Build a route message of ``msg_type`` for ``prefix`` via ``gw``.

        ``prefix`` is ``"addr"`` or ``"addr/len"``; a netmask sockaddr is
        added only when a length is given. ``gw`` is either raw sockaddr
        bytes or an IP address string.
        """
        px = prefix.split("/")
        addr_sa = SaHelper.ip_sa(px[0])
        if len(px) > 1:
            pxlen = int(px[1])
            if SaHelper.is_ipv6(px[0]):
                mask_sa = SaHelper.pxlen6_sa(pxlen)
            else:
                mask_sa = SaHelper.pxlen4_sa(pxlen)
        else:
            mask_sa = None
        msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
        if isinstance(gw, bytes):
            msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
        else:
            # String
            msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
        return msg

    def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_ADD message (see ``new_rtm_any``)."""
        return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)

    def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_DELETE message (see ``new_rtm_any``)."""
        return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)

    def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_CHANGE message (see ``new_rtm_any``)."""
        return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)

    def _setup_rtsock(self) -> socket.socket:
        """Open a raw routing socket with SO_USELOOPBACK enabled (see route(4))."""
        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
        return s

    def print_hd(self, data: bytes):
        """Print ``data`` as a hex dump, 16 bytes per line."""
        width = 16
        print("==========================================")
        for start in range(0, len(data), width):
            chunk = data[start : start + width]
            print("".join("0x{:02X} ".format(b) for b in chunk))
        print()

    def write_message(self, msg):
        """Print ``msg`` and write its serialized form to the socket."""
        print("vvvvvvvv OUT vvvvvvvv")
        msg.print_message()
        print()
        msg_bytes = bytes(msg)
        written = os.write(self.socket.fileno(), msg_bytes)
        # os.write() raises OSError on failure instead of returning -1, so a
        # short write is the only failure left to detect here.
        assert written == len(msg_bytes)

    def parse_message(self, data: bytes):
        """Decode ``data`` into a message object, or None for unknown types.

        Raises:
            OSError: if ``data`` is too short to contain the message type.
        """
        if len(data) < 4:
            raise OSError("Short read from rtsock: {} bytes".format(len(data)))
        # RtMsgHdr starts with u_short rtm_msglen, then one byte each for
        # rtm_version and rtm_type, so the type lives at byte offset 3.
        rtm_type = data[3]
        msg_cls = self.msgmap.get(rtm_type)
        if msg_cls is None:
            return None
        return msg_cls.from_bytes(data)

    def write_data(self, data: bytes):
        """Send raw bytes on the routing socket."""
        self.socket.send(data)

    def read_data(self, seq: Optional[int] = None) -> bytes:
        """Receive one message; if ``seq`` is given, skip until it matches."""
        while True:
            data = self.socket.recv(4096)
            if seq is None:
                break
            # A header-only message (no sockaddrs) is exactly sizeof(RtMsgHdr)
            # long and must still be matched.
            if len(data) >= sizeof(RtMsgHdr):
                hdr = RtMsgHdr.from_buffer_copy(data)
                if hdr.rtm_seq == seq:
                    break
        return data

    def read_message(self) -> Optional["BaseRtsockMessage"]:
        """Receive and decode the next message (None for unsupported types)."""
        data = self.read_data()
        return self.parse_message(data)