@@ -2,13 +2,29 @@
# Copyright(c) 2023 PANTHEON.tech s.r.o.
# Copyright(c) 2023 University of New Hampshire
+import json
+from typing import TypedDict
+from typing_extensions import NotRequired
+
+
from framework.exception import RemoteCommandExecutionError
from framework.testbed_model import LogicalCore
+from framework.testbed_model.hw.port import PortIdentifier
from framework.utils import expand_range
from .posix_session import PosixSession
+class LshwOutputConfigurationDict(TypedDict):
+ link: str
+
+
+class LshwOutputDict(TypedDict):
+ businfo: str
+ logicalname: NotRequired[str]
+ configuration: LshwOutputConfigurationDict
+
+
class LinuxSession(PosixSession):
"""
The implementation of non-Posix compliant parts of Linux remote sessions.
@@ -105,3 +121,42 @@ def _configure_huge_pages(
self.remote_session.send_command(
f"echo {amount} | sudo tee {hugepage_config_path}"
)
+
+ def get_lshw_info(self) -> list[LshwOutputDict]:
+ output = self.remote_session.send_expect("lshw -quiet -json -C network", "#")
+ assert not isinstance(
+ output, int
+ ), "send_expect returned an int when it should have been a string"
+ return json.loads(output)
+
+ def get_logical_name_of_port(self, id: PortIdentifier) -> str | None:
+ self._logger.debug(f"Searching for logical name of {id.pci}")
+ assert (
+ id.node == self.name
+ ), "Attempted to get the logical port name on the wrong node"
+ port_info_list: list[LshwOutputDict] = self.get_lshw_info()
+ for port_info in port_info_list:
+ if f"pci@{id.pci}" == port_info.get("businfo"):
+ if "logicalname" in port_info:
+ self._logger.debug(
+ f"Found logical name for port {id.pci}, {port_info.get('logicalname')}"
+ )
+ return port_info.get("logicalname")
+ else:
+ self._logger.warning(
+ f"Attempted to get the logical name of {id.pci}, but none existed"
+ )
+ return None
+ self._logger.warning(f"No port at pci address {id.pci} found.")
+ return None
+
+ def check_link_is_up(self, id: PortIdentifier) -> bool | None:
+ self._logger.debug(f"Checking link status for {id.pci}")
+ port_info_list: list[LshwOutputDict] = self.get_lshw_info()
+ for port_info in port_info_list:
+ if f"pci@{id.pci}" == port_info.get("businfo"):
+ status = port_info["configuration"]["link"]
+ self._logger.debug(f"Found link status for port {id.pci}, {status}")
+ return status == "up"
+ self._logger.warning(f"No port at pci address {id.pci} found.")
+ return None
@@ -84,6 +84,13 @@ def _connect(self) -> None:
Create connection to assigned node.
"""
+ @abstractmethod
+ def send_expect(
+ self, command: str, prompt: str, timeout: float = 15,
+ verify: bool = False
+ ) -> str | int:
+ """"""
+
def send_command(
self,
command: str,
new file mode 100644
@@ -0,0 +1,348 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright(c) 2022 University of New Hampshire
+#
+
+import inspect
+import json
+import marshal
+import types
+import xmlrpc.client
+from typing import TypedDict
+from xmlrpc.server import SimpleXMLRPCServer
+
+import scapy.all
+from scapy.packet import Packet
+from typing_extensions import NotRequired
+
+from framework.config import OS
+from framework.logger import getLogger
+from .tg_node import TGNode
+from .hw.port import Port, PortIdentifier
+from .capturing_traffic_generator import (
+ CapturingTrafficGenerator,
+ _get_default_capture_name,
+)
+from framework.settings import SETTINGS
+from framework.remote_session import OSSession
+
+"""
+========= BEGIN RPC FUNCTIONS =========
+
+All of the functions in this section are intended to be exported to a python
+shell which runs a scapy RPC server. These functions are made available via that
+RPC server to the packet generator. To add a new function to the RPC server,
+first write the function in this section. Then, if you need any imports, make sure to
+add them to SCAPY_RPC_SERVER_IMPORTS as well. After that, add the function to the list
+in EXPORTED_FUNCTIONS. Note that kwargs (keyword arguments) do not work via xmlrpc,
+so you may need to construct wrapper functions around many scapy types.
+"""
+
+"""
+Add the line needed to import something in a normal python environment
+as an entry to this array. It will be imported before any functions are
+sent to the server.
+"""
+SCAPY_RPC_SERVER_IMPORTS = [
+ "from scapy.all import *",
+ "import xmlrpc",
+ "import sys",
+ "from xmlrpc.server import SimpleXMLRPCServer",
+ "import marshal",
+ "import pickle",
+ "import types",
+]
+
+
+def scapy_sr1_different_interfaces(
+ packets: list[Packet], send_iface: str, recv_iface: str, timeout_s: int
+) -> bytes:
+ packets = [scapy.all.Packet(packet.data) for packet in packets]
+ sniffer = scapy.all.AsyncSniffer(
+ iface=recv_iface,
+ store=True,
+ timeout=timeout_s,
+ started_callback=lambda _: scapy.all.sendp(packets, iface=send_iface),
+ stop_filter=lambda _: True,
+ )
+ sniffer.start()
+ packets = sniffer.stop(join=True)
+ assert len(packets) != 0, "Not enough packets were sniffed"
+ assert len(packets) == 1, "More packets than expected were sniffed"
+ return packets[0].build()
+
+
+def scapy_send_packets_and_capture(
+ packets: list[Packet], send_iface: str, recv_iface: str, duration_s: int
+) -> list[bytes]:
+ packets = [scapy.all.Packet(packet.data) for packet in packets]
+ sniffer = scapy.all.AsyncSniffer(
+ iface=recv_iface,
+ store=True,
+ timeout=duration_s,
+ started_callback=lambda _: scapy.all.sendp(packets, iface=send_iface),
+ )
+ sniffer.start()
+ return [packet.build() for packet in sniffer.stop(join=True)]
+
+
+def scapy_send_packets(packets: list[xmlrpc.client.Binary], send_iface: str) -> None:
+ packets = [scapy.all.Packet(packet.data) for packet in packets]
+ scapy.all.sendp(packets, iface=send_iface, realtime=True, verbose=True)
+
+
+"""
+Functions to be exposed by the scapy RPC server.
+"""
+RPC_FUNCTIONS = [
+ scapy_send_packets,
+ scapy_send_packets_and_capture,
+ scapy_sr1_different_interfaces,
+]
+
+
+class QuittableXMLRPCServer(SimpleXMLRPCServer):
+ def __init__(self, *args, **kwargs):
+ kwargs["allow_none"] = True
+ super().__init__(*args, **kwargs)
+ self.register_introspection_functions()
+ self.register_function(self.quit)
+ self.register_function(self.add_rpc_function)
+
+ def quit(self) -> None:
+ self._BaseServer__shutdown_request = True
+ return None
+
+ def add_rpc_function(self, name: str, function_bytes: xmlrpc.client.Binary):
+ function_code = marshal.loads(function_bytes.data)
+ function = types.FunctionType(function_code, globals(), name)
+ self.register_function(function)
+
+ def serve_forever(self, poll_interval: float = 0.5) -> None:
+ print("XMLRPC OK")
+ super().serve_forever(poll_interval)
+
+
+"""
+========= END RPC FUNCTIONS =========
+"""
+
+
+class NetworkInfoDict(TypedDict):
+ businfo: str
+ logicalname: NotRequired[str]
+
+
+class ScapyTrafficGenerator(CapturingTrafficGenerator):
+ """
+ Provides access to scapy functions via an RPC interface
+ """
+
+ tg_node: TGNode
+ ports: list[Port]
+ session: OSSession
+ scapy: xmlrpc.client.ServerProxy
+ iface_names: dict[PortIdentifier, str]
+
+ def __init__(self, tg_node: TGNode, ports: list[Port]):
+ self.tg_node = tg_node
+
+ assert tg_node.config.os == OS.linux, (
+ "Linux is the only supported OS for scapy traffic generation"
+ )
+
+ self.session = tg_node.create_session("scapy")
+ self.logger = getLogger("scapy-pktgen-messages", node=tg_node.name)
+ self.ports = ports
+
+ # No fancy colors
+
+ prompt_str = "<PROMPT>"
+ self.session.remote_session.send_expect(f'export PS1="{prompt_str}"', prompt_str)
+
+ network_info_str: str = self.session.remote_session.send_expect(
+ "lshw -quiet -json -C network", prompt_str, timeout=10
+ )
+
+ network_info_list: list[NetworkInfoDict] = json.loads(network_info_str)
+ network_info_lookup: dict[str, str] = {
+ network_info["businfo"]: network_info.get("logicalname")
+ for network_info in network_info_list
+ }
+
+ self.iface_names = dict()
+ for port in self.ports:
+ businfo_str = f"pci@{port.pci}"
+ assert businfo_str in network_info_lookup, (
+ f"Expected '{businfo_str}' in lshw output for {self.tg_node.name}, but "
+ f"it was not present."
+ )
+
+ self.iface_names[port.identifier] = network_info_lookup[businfo_str]
+
+ assert (
+ self.iface_names[port.identifier] is not None
+ ), f"No interface was present for {port.pci} on {self.tg_node.name}"
+
+ self._run_command("python3")
+
+ self._add_helper_functions_to_scapy()
+ self.session.remote_session.send_expect(
+ 'server = QuittableXMLRPCServer(("0.0.0.0", 8000)); server.serve_forever()',
+ "XMLRPC OK",
+ timeout=5,
+ )
+
+ server_url: str = f"http://{self.tg_node.config.hostname}:8000"
+
+ self.scapy = xmlrpc.client.ServerProxy(
+ server_url, allow_none=True, verbose=SETTINGS.verbose
+ )
+
+ for function in RPC_FUNCTIONS:
+ # A slightly hacky way to move a function to the remote server.
+ # It is constructed from the name and code on the other side.
+ # Pickle cannot handle functions, nor can any of the other serialization
+ # frameworks aside from the libraries used to generate pyc files, which
+ # are even more messy to work with.
+ function_bytes = marshal.dumps(function.__code__)
+ self.scapy.add_rpc_function(function.__name__, function_bytes)
+
+ def _add_helper_functions_to_scapy(self):
+ for import_statement in SCAPY_RPC_SERVER_IMPORTS:
+ self._run_command(import_statement + "\r\n")
+
+ for helper_function in {QuittableXMLRPCServer}:
+ # load the source of the function
+ src = inspect.getsource(helper_function)
+ # Lines with only whitespace break the repl if in the middle of a function
+ # or class, so strip all lines containing only whitespace
+ src = "\n".join(
+ [line for line in src.splitlines() if not line.isspace() and line != ""]
+ )
+
+ spacing = "\n" * 4
+
+ # execute it in the python terminal
+ self._run_command(spacing + src + spacing)
+
+ def _run_command(self, command: str) -> str:
+ return self.session.remote_session.send_expect(command, ">>>")
+
+ def _get_port_interface_or_error(self, port: PortIdentifier) -> str:
+ match self.iface_names.get(port):
+ case None:
+ assert (
+ False
+ ), f"{port} is not a valid port on this packet generator on {self.tg_node.name}."
+ case iface:
+ return iface
+
+ def send_packet(self, port: PortIdentifier, packet: Packet) -> None:
+ iface = self._get_port_interface_or_error(port)
+ self.logger.info("Sending packet")
+ self.logger.debug("Packet contents: \n" + packet._do_summary()[1])
+ self.scapy.scapy_send_packets([packet.build()], iface)
+
+ def send_packets(self, port: PortIdentifier, packets: list[Packet]) -> None:
+ iface = self._get_port_interface_or_error(port)
+ self.logger.info("Sending packets")
+ packet_summaries = json.dumps(
+ list(map(lambda pkt: pkt._do_summary()[1], packets)), indent=4
+ )
+ packets = [packet.build() for packet in packets]
+ self.logger.debug("Packet contents: \n" + packet_summaries)
+ self.scapy.scapy_send_packets(packets, iface)
+
+ def send_packet_and_capture(
+ self,
+ send_port_id: PortIdentifier,
+ packet: Packet,
+ receive_port_id: PortIdentifier,
+ duration_s: int,
+ capture_name: str = _get_default_capture_name(),
+ ) -> list[Packet]:
+ packets = self.scapy.scapy_send_packets_and_capture(
+ [packet.build()], send_port_id, receive_port_id, duration_s
+ )
+ self._write_capture_from_packets(capture_name, packets)
+ return packets
+
+ def send_packets_and_capture(
+ self,
+ send_port_id: PortIdentifier,
+ packets: Packet,
+ receive_port_id: PortIdentifier,
+ duration_s: int,
+ capture_name: str = _get_default_capture_name(),
+ ) -> list[Packet]:
+ packets: list[bytes] = [packet.build() for packet in packets]
+ packets: list[bytes] = self.scapy.scapy_send_packets_and_capture(
+ packets, send_port_id, receive_port_id, duration_s
+ )
+ packets: list[Packet] = [scapy.all.Packet(packet) for packet in packets]
+ self._write_capture_from_packets(capture_name, packets)
+ return packets
+
+ def send_packet_and_expect_packet(
+ self,
+ send_port_id: PortIdentifier,
+ packet: Packet,
+ receive_port_id: PortIdentifier,
+ expected_packet: Packet,
+ timeout: int = SETTINGS.timeout,
+ capture_name: str = _get_default_capture_name(),
+ ) -> None:
+ self.send_packets_and_expect_packets(
+ send_port_id,
+ [packet],
+ receive_port_id,
+ [expected_packet],
+ timeout,
+ capture_name,
+ )
+
+ def send_packets_and_expect_packets(
+ self,
+ send_port_id: PortIdentifier,
+ packets: list[Packet],
+ receive_port_id: PortIdentifier,
+ expected_packets: list[Packet],
+ timeout: int = SETTINGS.timeout,
+ capture_name: str = _get_default_capture_name(),
+ ) -> None:
+ send_iface = self._get_port_interface_or_error(send_port_id)
+ recv_iface = self._get_port_interface_or_error(receive_port_id)
+
+ packets = [packet.build() for packet in packets]
+
+ received_packets = self.scapy.scapy_sr1_different_interfaces(
+ packets, send_iface, recv_iface, timeout
+ )
+
+ received_packets = [scapy.all.Packet(packet) for packet in received_packets]
+
+ self._write_capture_from_packets(capture_name, received_packets)
+
+ assert len(received_packets) == len(
+ expected_packets
+ ), "Incorrect number of packets received"
+ for i, expected_packet in enumerate(expected_packets):
+ assert (
+ received_packets[i] == expected_packet
+ ), f"Received packet {i} differed from expected packet"
+
+ def close(self):
+ try:
+ self.scapy.quit()
+ except ConnectionRefusedError:
+ # Because the python instance closes, we get no RPC response.
+ # Thus, this error is expected
+ pass
+ try:
+ self.session.close(force=True)
+ except TimeoutError:
+ # Pexpect does not like being in a python prompt when it closes
+ pass
+
+ def assert_port_is_connected(self, id: PortIdentifier) -> None:
+ self.tg_node.main_session.check_link_is_up(id)