diff mbox series

[v1,2/8] dts: add locks for parallel node connections

Message ID 20220622121448.3304251-3-juraj.linkes@pantheon.tech (mailing list archive)
State New
Delegated to: Thomas Monjalon
Headers show
Series dts: ssh connection to a node | expand

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Juraj Linkeš June 22, 2022, 12:14 p.m. UTC
Each lock is held per node. The lock assures that multiple connections
to the same node don't execute anything at the same time, removing the
possibility of race conditions.

Signed-off-by: Juraj Linkeš <juraj.linkes@pantheon.tech>
---
 dts/framework/ssh_pexpect.py | 15 +++++--
 dts/framework/utils.py       | 81 ++++++++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 4 deletions(-)
diff mbox series

Patch

diff --git a/dts/framework/ssh_pexpect.py b/dts/framework/ssh_pexpect.py
index bccc6fae94..cbdbb91b64 100644
--- a/dts/framework/ssh_pexpect.py
+++ b/dts/framework/ssh_pexpect.py
@@ -3,7 +3,7 @@ 
 from pexpect import pxssh
 
 from .exception import SSHConnectionException, SSHSessionDeadException, TimeoutException
-from .utils import GREEN, RED
+from .utils import GREEN, RED, parallel_lock
 
 """
 Module handles ssh sessions to TG and SUT.
@@ -12,7 +12,7 @@ 
 
 
 class SSHPexpect:
-    def __init__(self, node, username, password):
+    def __init__(self, node, username, password, sut_id):
         self.magic_prompt = "MAGIC PROMPT"
         self.logger = None
 
@@ -20,11 +20,18 @@  def __init__(self, node, username, password):
         self.username = username
         self.password = password
 
-        self._connect_host()
+        self._connect_host(sut_id=sut_id)
 
-    def _connect_host(self):
+    @parallel_lock(num=8)
+    def _connect_host(self, sut_id=0):
         """
         Create connection to assigned node.
+        Parameter sut_id will be used in parallel_lock thus can assure
+        isolated locks for each node.
+        Parallel ssh connections are limited to MaxStartups option in SSHD
+        configuration file. By default concurrent number is 10, so default
+        threads number is limited to 8 which less than 10. Lock number can
+        be modified along with MaxStartups value.
         """
         retry_times = 10
         try:
diff --git a/dts/framework/utils.py b/dts/framework/utils.py
index 0ffd992952..a8e739f7b2 100644
--- a/dts/framework/utils.py
+++ b/dts/framework/utils.py
@@ -2,6 +2,87 @@ 
 # Copyright(c) 2010-2014 Intel Corporation
 #
 
+import threading
+from functools import wraps
+
+
+def parallel_lock(num=1):
+    """
+    Wrapper function for protect parallel threads, allow multiple threads
+    share one lock. Locks are created based on function name. Thread locks are
+    separated between SUTs according to argument 'sut_id'.
+    Parameter:
+        num: Number of parallel threads for the lock
+    """
+    global locks_info
+
+    def decorate(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if "sut_id" in kwargs:
+                sut_id = kwargs["sut_id"]
+            else:
+                sut_id = 0
+
+            # in case function arguments is not correct
+            if sut_id >= len(locks_info):
+                sut_id = 0
+
+            lock_info = locks_info[sut_id]
+            uplock = lock_info["update_lock"]
+
+            name = func.__name__
+            uplock.acquire()
+
+            if name not in lock_info:
+                lock_info[name] = dict()
+                lock_info[name]["lock"] = threading.RLock()
+                lock_info[name]["current_thread"] = 1
+            else:
+                lock_info[name]["current_thread"] += 1
+
+            lock = lock_info[name]["lock"]
+
+            # make sure when owned global lock, should also own update lock
+            if lock_info[name]["current_thread"] >= num:
+                if lock._is_owned():
+                    print(
+                        RED(
+                            "SUT%d %s waiting for func lock %s"
+                            % (sut_id, threading.current_thread().name, func.__name__)
+                        )
+                    )
+                lock.acquire()
+            else:
+                uplock.release()
+
+            try:
+                ret = func(*args, **kwargs)
+            except Exception as e:
+                if not uplock._is_owned():
+                    uplock.acquire()
+
+                if lock._is_owned():
+                    lock.release()
+                    lock_info[name]["current_thread"] = 0
+                uplock.release()
+                raise e
+
+            if not uplock._is_owned():
+                uplock.acquire()
+
+            if lock._is_owned():
+                lock.release()
+                lock_info[name]["current_thread"] = 0
+
+            uplock.release()
+
+            return ret
+
+        return wrapper
+
+    return decorate
+
 
 def RED(text):
     return "\x1B[" + "31;1m" + str(text) + "\x1B[" + "0m"