[v5,02/10] dts: add ssh command verification

Message ID 20230223152840.634183-3-juraj.linkes@pantheon.tech (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series dts: add hello world testcase |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Juraj Linkeš Feb. 23, 2023, 3:28 p.m. UTC
This is a basic capability needed to check whether the command execution
was successful or not. If not, raise a RemoteCommandExecutionError. When
a failure is expected, the caller is supposed to catch the exception.

Signed-off-by: Juraj Linkeš <juraj.linkes@pantheon.tech>
---
 dts/framework/exception.py                    | 23 +++++++-
 .../remote_session/remote/remote_session.py   | 55 +++++++++++++------
 .../remote_session/remote/ssh_session.py      | 12 +++-
 3 files changed, 69 insertions(+), 21 deletions(-)
  

Patch

diff --git a/dts/framework/exception.py b/dts/framework/exception.py
index 121a0f7296..e776b42bd9 100644
--- a/dts/framework/exception.py
+++ b/dts/framework/exception.py
@@ -21,7 +21,8 @@  class ErrorSeverity(IntEnum):
     NO_ERR = 0
     GENERIC_ERR = 1
     CONFIG_ERR = 2
-    SSH_ERR = 3
+    REMOTE_CMD_EXEC_ERR = 3
+    SSH_ERR = 4
 
 
 class DTSError(Exception):
@@ -90,3 +91,23 @@  class ConfigurationError(DTSError):
     """
 
     severity: ClassVar[ErrorSeverity] = ErrorSeverity.CONFIG_ERR
+
+
+class RemoteCommandExecutionError(DTSError):
+    """
+    Raised when a command executed on a Node returns a non-zero exit status.
+    """
+
+    command: str
+    command_return_code: int
+    severity: ClassVar[ErrorSeverity] = ErrorSeverity.REMOTE_CMD_EXEC_ERR
+
+    def __init__(self, command: str, command_return_code: int):
+        self.command = command
+        self.command_return_code = command_return_code
+
+    def __str__(self) -> str:
+        return (
+            f"Command {self.command} returned a non-zero exit code: "
+            f"{self.command_return_code}"
+        )
diff --git a/dts/framework/remote_session/remote/remote_session.py b/dts/framework/remote_session/remote/remote_session.py
index 7c7b30225f..5ac395ec79 100644
--- a/dts/framework/remote_session/remote/remote_session.py
+++ b/dts/framework/remote_session/remote/remote_session.py
@@ -7,15 +7,29 @@ 
 from abc import ABC, abstractmethod
 
 from framework.config import NodeConfiguration
+from framework.exception import RemoteCommandExecutionError
 from framework.logger import DTSLOG
 from framework.settings import SETTINGS
 
 
 @dataclasses.dataclass(slots=True, frozen=True)
-class HistoryRecord:
+class CommandResult:
+    """
+    The result of remote execution of a command.
+    """
+
     name: str
     command: str
-    output: str | int
+    stdout: str
+    stderr: str
+    return_code: int
+
+    def __str__(self) -> str:
+        return (
+            f"stdout: '{self.stdout}'\n"
+            f"stderr: '{self.stderr}'\n"
+            f"return_code: '{self.return_code}'"
+        )
 
 
 class RemoteSession(ABC):
@@ -34,7 +48,7 @@  class RemoteSession(ABC):
     port: int | None
     username: str
     password: str
-    history: list[HistoryRecord]
+    history: list[CommandResult]
     _logger: DTSLOG
     _node_config: NodeConfiguration
 
@@ -68,28 +82,33 @@  def _connect(self) -> None:
         Create connection to assigned node.
         """
 
-    def send_command(self, command: str, timeout: float = SETTINGS.timeout) -> str:
+    def send_command(
+        self, command: str, timeout: float = SETTINGS.timeout, verify: bool = False
+    ) -> CommandResult:
         """
-        Send a command and return the output.
+        Send a command to the connected node and return CommandResult.
+        If verify is True, check the return code of the executed command
+        and raise a RemoteCommandExecutionError if the command failed.
         """
-        self._logger.info(f"Sending: {command}")
-        out = self._send_command(command, timeout)
-        self._logger.debug(f"Received from {command}: {out}")
-        self._history_add(command=command, output=out)
-        return out
+        self._logger.info(f"Sending: '{command}'")
+        result = self._send_command(command, timeout)
+        if verify and result.return_code:
+            self._logger.debug(
+                f"Command '{command}' failed with return code '{result.return_code}'"
+            )
+            self._logger.debug(f"stdout: '{result.stdout}'")
+            self._logger.debug(f"stderr: '{result.stderr}'")
+            raise RemoteCommandExecutionError(command, result.return_code)
+        self._logger.debug(f"Received from '{command}':\n{result}")
+        self.history.append(result)
+        return result
 
     @abstractmethod
-    def _send_command(self, command: str, timeout: float) -> str:
+    def _send_command(self, command: str, timeout: float) -> CommandResult:
         """
-        Use the underlying protocol to execute the command and return the output
-        of the command.
+        Use the underlying protocol to execute the command and return CommandResult.
         """
 
-    def _history_add(self, command: str, output: str) -> None:
-        self.history.append(
-            HistoryRecord(name=self.name, command=command, output=output)
-        )
-
     def close(self, force: bool = False) -> None:
         """
         Close the remote session and free all used resources.
diff --git a/dts/framework/remote_session/remote/ssh_session.py b/dts/framework/remote_session/remote/ssh_session.py
index 96175f5284..c2362e2fdf 100644
--- a/dts/framework/remote_session/remote/ssh_session.py
+++ b/dts/framework/remote_session/remote/ssh_session.py
@@ -12,7 +12,7 @@ 
 from framework.logger import DTSLOG
 from framework.utils import GREEN, RED
 
-from .remote_session import RemoteSession
+from .remote_session import CommandResult, RemoteSession
 
 
 class SSHSession(RemoteSession):
@@ -66,6 +66,7 @@  def _connect(self) -> None:
 
             self.send_expect("stty -echo", "#")
             self.send_expect("stty columns 1000", "#")
+            self.send_expect("bind 'set enable-bracketed-paste off'", "#")
         except Exception as e:
             self._logger.error(RED(str(e)))
             if getattr(self, "port", None):
@@ -163,7 +164,14 @@  def _flush(self) -> None:
     def is_alive(self) -> bool:
         return self.session.isalive()
 
-    def _send_command(self, command: str, timeout: float) -> str:
+    def _send_command(self, command: str, timeout: float) -> CommandResult:
+        output = self._send_command_get_output(command, timeout)
+        return_code = int(self._send_command_get_output("echo $?", timeout))
+
+        # we're capturing only stdout
+        return CommandResult(self.name, command, output, "", return_code)
+
+    def _send_command_get_output(self, command: str, timeout: float) -> str:
         try:
             self._clean_session()
             self._send_line(command)