[v2,3/8] dts: add locks for parallel node connections

Message ID 20220711145126.295427-4-juraj.linkes@pantheon.tech (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series ssh connection to a node |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Juraj Linkeš July 11, 2022, 2:51 p.m. UTC
  Each lock is held per node. The lock assures that multiple connections
to the same node don't execute anything at the same time, removing the
possibility of race conditions.

Signed-off-by: Juraj Linkeš <juraj.linkes@pantheon.tech>
---
 dts/framework/ssh_pexpect.py | 14 ++++--
 dts/framework/utils.py       | 88 ++++++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+), 3 deletions(-)
  

Patch

diff --git a/dts/framework/ssh_pexpect.py b/dts/framework/ssh_pexpect.py
index c73c1048a4..01ebd1c010 100644
--- a/dts/framework/ssh_pexpect.py
+++ b/dts/framework/ssh_pexpect.py
@@ -12,7 +12,7 @@ 
 from .exception import (SSHConnectionException, SSHSessionDeadException,
                         TimeoutException)
 from .logger import DTSLOG
-from .utils import GREEN, RED
+from .utils import GREEN, RED, parallel_lock
 
 """
 Module handles ssh sessions to TG and SUT.
@@ -33,6 +33,7 @@  def __init__(
         username: str,
         password: Optional[str],
         logger: DTSLOG,
+        sut_id: int,
     ):
         self.magic_prompt = "MAGIC PROMPT"
         self.logger = logger
@@ -42,11 +43,18 @@  def __init__(
         self.password = password or ""
         self.logger.info(f"ssh {self.username}@{self.node}")
 
-        self._connect_host()
+        self._connect_host(sut_id=sut_id)
 
-    def _connect_host(self) -> None:
+    @parallel_lock(num=8)
+    def _connect_host(self, sut_id: int = 0) -> None:
         """
         Create connection to assigned node.
+        Parameter sut_id will be used in parallel_lock thus can assure
+        isolated locks for each node.
+        Parallel ssh connections are limited to MaxStartups option in SSHD
+        configuration file. By default concurrent number is 10, so default
+        threads number is limited to 8 which less than 10. Lock number can
+        be modified along with MaxStartups value.
         """
         retry_times = 10
         try:
diff --git a/dts/framework/utils.py b/dts/framework/utils.py
index 7036843dd7..a637c4641e 100644
--- a/dts/framework/utils.py
+++ b/dts/framework/utils.py
@@ -1,7 +1,95 @@ 
 # SPDX-License-Identifier: BSD-3-Clause
 # Copyright(c) 2010-2014 Intel Corporation
+# Copyright(c) 2022 PANTHEON.tech s.r.o.
+# Copyright(c) 2022 University of New Hampshire
 #
 
+import threading
+from functools import wraps
+from typing import Any, Callable, TypeVar
+
+locks_info: list[dict[str, Any]] = list()
+
+T = TypeVar("T")
+
+
+def parallel_lock(num: int = 1) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """
+    Wrapper function for protect parallel threads, allow multiple threads
+    share one lock. Locks are created based on function name. Thread locks are
+    separated between SUTs according to argument 'sut_id'.
+    Parameter:
+        num: Number of parallel threads for the lock
+    """
+    global locks_info
+
+    def decorate(func: Callable[..., T]) -> Callable[..., T]:
+        # mypy does not know how to handle the types of this function, so Any is required
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> T:
+            if "sut_id" in kwargs:
+                sut_id = kwargs["sut_id"]
+            else:
+                sut_id = 0
+
+            # in case function arguments is not correct
+            if sut_id >= len(locks_info):
+                sut_id = 0
+
+            lock_info = locks_info[sut_id]
+            uplock = lock_info["update_lock"]
+
+            name = func.__name__
+            uplock.acquire()
+
+            if name not in lock_info:
+                lock_info[name] = dict()
+                lock_info[name]["lock"] = threading.RLock()
+                lock_info[name]["current_thread"] = 1
+            else:
+                lock_info[name]["current_thread"] += 1
+
+            lock = lock_info[name]["lock"]
+
+            # make sure when owned global lock, should also own update lock
+            if lock_info[name]["current_thread"] >= num:
+                if lock._is_owned():
+                    print(
+                        RED(
+                            f"SUT{sut_id:d} {threading.current_thread().name} waiting for func lock {func.__name__}"
+                        )
+                    )
+                lock.acquire()
+            else:
+                uplock.release()
+
+            try:
+                ret = func(*args, **kwargs)
+            except Exception as e:
+                if not uplock._is_owned():
+                    uplock.acquire()
+
+                if lock._is_owned():
+                    lock.release()
+                    lock_info[name]["current_thread"] = 0
+                uplock.release()
+                raise e
+
+            if not uplock._is_owned():
+                uplock.acquire()
+
+            if lock._is_owned():
+                lock.release()
+                lock_info[name]["current_thread"] = 0
+
+            uplock.release()
+
+            return ret
+
+        return wrapper
+
+    return decorate
+
 
 def RED(text: str) -> str:
     return f"\u001B[31;1m{str(text)}\u001B[0m"