DPDK patches and discussions
* [RFC PATCH 0/2] dts: add basic scope to improve shell handling
@ 2024-12-20 17:23 Luca Vizzarro
  2024-12-20 17:24 ` [RFC PATCH 1/2] dts: add scoping and shell registration to Node Luca Vizzarro
  2024-12-20 17:24 ` [RFC PATCH 2/2] dts: revert back shell split Luca Vizzarro
  0 siblings, 2 replies; 3+ messages in thread
From: Luca Vizzarro @ 2024-12-20 17:23 UTC
  To: dev; +Cc: Paul Szczepanek, Patrick Robb, Luca Vizzarro

Hi there,

To make the framework easier to use for test developers, I have been
looking for a good way to improve shell handling and make it more
consistent. At the moment there are two patterns for doing this, which
could confuse users.

I think a good approach, and the one I am proposing, is to introduce a
scoping mechanism in DTS. This means associating any shells or
modifications with one of two scopes: the test suite as a whole, or an
individual test case.
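To illustrate the scoping idea, here is a minimal, self-contained Python
sketch. It is not the DTS code from this series: `ScopedNode` and
`FakeShell` are hypothetical stand-ins, and `RuntimeError` replaces the
framework's `InternalError`. Each shell registers itself with its node
under whichever scope is current, and exiting a scope closes only the
shells that were created in it.

```python
from typing import Literal

# Scope names follow this RFC; "unknown" means no scope has been entered yet.
Scope = Literal["unknown", "suite", "case"]


class FakeShell:
    """Hypothetical stand-in for an interactive shell such as testpmd."""

    def __init__(self, node: "ScopedNode", name: str) -> None:
        self.name = name
        self.closed = False
        # The shell registers itself on creation, so the test writer
        # does not need a `with` block to guarantee that it gets closed.
        node.register_shell(self)

    def close(self) -> None:
        self.closed = True


class ScopedNode:
    """Hypothetical node that keeps a stack of scopes and a pool of shells."""

    def __init__(self) -> None:
        self._scope_stack: list[Scope] = []
        self._active_shells: list[tuple[Scope, FakeShell]] = []

    @property
    def current_scope(self) -> Scope:
        return self._scope_stack[-1] if self._scope_stack else "unknown"

    def enter_scope(self, scope: Scope) -> None:
        self._scope_stack.append(scope)

    def register_shell(self, shell: FakeShell) -> None:
        # Tag the shell with the scope that was active when it was created.
        self._active_shells.append((self.current_scope, shell))

    def exit_scope(self) -> Scope:
        if not self._scope_stack:
            raise RuntimeError("Attempted to exit a scope when not in any.")
        scope = self._scope_stack.pop()
        # Close the shells that belong to the scope being exited; keep the rest.
        remaining: list[tuple[Scope, FakeShell]] = []
        for shell_scope, shell in self._active_shells:
            if shell_scope == scope:
                shell.close()
            else:
                remaining.append((shell_scope, shell))
        self._active_shells = remaining
        return scope


node = ScopedNode()
node.enter_scope("suite")
suite_shell = FakeShell(node, "suite-wide")
node.enter_scope("case")
case_shell = FakeShell(node, "per-case")
node.exit_scope()  # leaving the test case closes only the case shell
print(case_shell.closed, suite_shell.closed)  # True False
node.exit_scope()  # leaving the test suite closes the suite shell
print(suite_shell.closed)  # True
```

The test writer only creates shells; closing them becomes the node's job
when the runner leaves the corresponding scope.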

Here's an RFC for this, please have a look. Looking forward to your
feedback!

This could be improved substantially, but that would require many more
changes. One idea that I think is worth pursuing is turning the
execution into a finite-state machine (FSM).
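One possible reading of the FSM idea, as a rough Python sketch: the
runner moves between explicit stages, and any move not listed in a
transition table is rejected. The `Stage` values, the `TRANSITIONS`
table and `ExecutionFSM` are invented for illustration and are not taken
from `DtsStage` or any existing DTS code.

```python
from enum import Enum, auto


class Stage(Enum):
    """Hypothetical execution stages; names are illustrative only."""

    IDLE = auto()
    SUITE_SETUP = auto()
    CASE = auto()
    SUITE_TEARDOWN = auto()
    DONE = auto()


# Each stage maps to the set of stages it is allowed to move to next.
TRANSITIONS: dict[Stage, set[Stage]] = {
    Stage.IDLE: {Stage.SUITE_SETUP, Stage.DONE},
    Stage.SUITE_SETUP: {Stage.CASE, Stage.SUITE_TEARDOWN},
    Stage.CASE: {Stage.CASE, Stage.SUITE_TEARDOWN},
    Stage.SUITE_TEARDOWN: {Stage.SUITE_SETUP, Stage.DONE},
    Stage.DONE: set(),
}


class ExecutionFSM:
    """Tracks the current stage and refuses transitions not in the table."""

    def __init__(self) -> None:
        self.state = Stage.IDLE
        self.history: list[Stage] = [self.state]

    def advance(self, next_stage: Stage) -> None:
        if next_stage not in TRANSITIONS[self.state]:
            raise ValueError(f"Illegal transition {self.state.name} -> {next_stage.name}")
        self.state = next_stage
        self.history.append(next_stage)


fsm = ExecutionFSM()
# One suite with two test cases.
for stage in (Stage.SUITE_SETUP, Stage.CASE, Stage.CASE, Stage.SUITE_TEARDOWN, Stage.DONE):
    fsm.advance(stage)
print(fsm.state.name)  # DONE
```

Scope handling could then hang off the transitions, e.g. entering a case
scope on each move into `CASE` and exiting it on the way out; that wiring
is deliberately left out of this sketch.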

Best,
Luca

Luca Vizzarro (2):
  dts: add scoping and shell registration to Node
  dts: revert back shell split

 dts/framework/remote_session/dpdk_shell.py    |   8 +-
 .../remote_session/interactive_shell.py       | 262 +++++++++++++++--
 .../single_active_interactive_shell.py        | 266 ------------------
 dts/framework/remote_session/testpmd_shell.py |   4 +-
 dts/framework/runner.py                       |  15 +-
 dts/framework/testbed_model/capability.py     |  35 ++-
 dts/framework/testbed_model/node.py           |  65 ++++-
 dts/framework/testbed_model/sut_node.py       |  14 +-
 dts/tests/TestSuite_blocklist.py              |  16 +-
 dts/tests/TestSuite_checksum_offload.py       | 168 +++++------
 dts/tests/TestSuite_dynamic_queue_conf.py     |  52 ++--
 dts/tests/TestSuite_l2fwd.py                  |  20 +-
 dts/tests/TestSuite_mac_filter.py             | 124 ++++----
 dts/tests/TestSuite_pmd_buffer_scatter.py     |  26 +-
 dts/tests/TestSuite_smoke_tests.py            |   4 +-
 dts/tests/TestSuite_vlan.py                   |  42 +--
 16 files changed, 575 insertions(+), 546 deletions(-)
 delete mode 100644 dts/framework/remote_session/single_active_interactive_shell.py

-- 
2.43.0



* [RFC PATCH 1/2] dts: add scoping and shell registration to Node
  2024-12-20 17:23 [RFC PATCH 0/2] dts: add basic scope to improve shell handling Luca Vizzarro
@ 2024-12-20 17:24 ` Luca Vizzarro
  2024-12-20 17:24 ` [RFC PATCH 2/2] dts: revert back shell split Luca Vizzarro
  1 sibling, 0 replies; 3+ messages in thread
From: Luca Vizzarro @ 2024-12-20 17:24 UTC
  To: dev; +Cc: Paul Szczepanek, Patrick Robb, Luca Vizzarro

Add a basic scoping mechanism to Node to improve control over the
environmental changes made by test suites. Moreover, keep a pool of
active shells per scope, allowing shells to register themselves.

Signed-off-by: Luca Vizzarro <luca.vizzarro@arm.com>
---
 .../single_active_interactive_shell.py        |   2 +
 dts/framework/runner.py                       |  15 +-
 dts/framework/testbed_model/capability.py     |  35 ++--
 dts/framework/testbed_model/node.py           |  65 ++++++-
 dts/framework/testbed_model/sut_node.py       |  14 +-
 dts/tests/TestSuite_blocklist.py              |  16 +-
 dts/tests/TestSuite_checksum_offload.py       | 168 +++++++++---------
 dts/tests/TestSuite_dynamic_queue_conf.py     |  52 +++---
 dts/tests/TestSuite_l2fwd.py                  |  20 +--
 dts/tests/TestSuite_mac_filter.py             | 124 +++++++------
 dts/tests/TestSuite_pmd_buffer_scatter.py     |  26 ++-
 dts/tests/TestSuite_smoke_tests.py            |   4 +-
 dts/tests/TestSuite_vlan.py                   |  42 ++---
 13 files changed, 333 insertions(+), 250 deletions(-)
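A cut-down Python sketch of how the SUT-specific `exit_scope` override in
this patch is meant to behave: the base class only pops the scope stack,
and the subclass adds the DPDK clean-up, which now runs when a suite
scope is exited instead of being called explicitly from the runner's
suite teardown. `BaseNode`, `SutNodeSketch` and the `dpdk_cleanups`
counter are hypothetical; the real clean-up kills DPDK apps and cleans
up hugepages.

```python
from typing import Literal

Scope = Literal["unknown", "suite", "case"]


class BaseNode:
    """Cut-down stand-in for Node: only the scope stack is modelled here."""

    def __init__(self) -> None:
        self._scope_stack: list[Scope] = []

    def enter_scope(self, scope: Scope) -> None:
        self._scope_stack.append(scope)

    def exit_scope(self) -> Scope:
        # The real method also closes the shells registered in this scope.
        return self._scope_stack.pop()


class SutNodeSketch(BaseNode):
    """Extends exit_scope() with extra clean-up when leaving a test suite."""

    def __init__(self) -> None:
        super().__init__()
        self.dpdk_cleanups = 0

    def _kill_cleanup_dpdk_apps(self) -> None:
        # Placeholder for killing DPDK processes and cleaning up hugepages.
        self.dpdk_cleanups += 1

    def exit_scope(self) -> Scope:
        previous_scope = super().exit_scope()
        # Only leaving a suite triggers the heavyweight clean-up.
        if previous_scope == "suite":
            self._kill_cleanup_dpdk_apps()
        return previous_scope


sut = SutNodeSketch()
sut.enter_scope("suite")
sut.enter_scope("case")
sut.exit_scope()  # leaving a case: no DPDK clean-up
sut.exit_scope()  # leaving the suite: clean-up runs once
print(sut.dpdk_cleanups)  # 1
```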

diff --git a/dts/framework/remote_session/single_active_interactive_shell.py b/dts/framework/remote_session/single_active_interactive_shell.py
index c43c54e457..910af8f655 100644
--- a/dts/framework/remote_session/single_active_interactive_shell.py
+++ b/dts/framework/remote_session/single_active_interactive_shell.py
@@ -112,7 +112,9 @@ def __init__(
                 the name of the underlying node which it is running on.
             **kwargs: Any additional arguments if any.
         """
+        node.register_shell(self)
         self._node = node
+
         if name is None:
             name = type(self).__name__
         self._logger = get_dts_logger(f"{node.name}.{name}")
diff --git a/dts/framework/runner.py b/dts/framework/runner.py
index 510be1a870..fd3a934a9f 100644
--- a/dts/framework/runner.py
+++ b/dts/framework/runner.py
@@ -460,6 +460,10 @@ def _run_test_suite(
             DtsStage.test_suite_setup, Path(SETTINGS.output_dir, test_suite_name)
         )
         test_suite = test_suite_with_cases.test_suite_class(sut_node, tg_node, topology)
+
+        sut_node.enter_scope("suite")
+        tg_node.enter_scope("suite")
+
         try:
             self._logger.info(f"Starting test suite setup: {test_suite_name}")
             test_suite.set_up_suite()
@@ -479,7 +483,6 @@ def _run_test_suite(
             try:
                 self._logger.set_stage(DtsStage.test_suite_teardown)
                 test_suite.tear_down_suite()
-                sut_node.kill_cleanup_dpdk_apps()
                 test_suite_result.update_teardown(Result.PASS)
             except Exception as e:
                 self._logger.exception(f"Test suite teardown ERROR: {test_suite_name}")
@@ -488,6 +491,10 @@ def _run_test_suite(
                     "the next test suite may be affected."
                 )
                 test_suite_result.update_setup(Result.ERROR, e)
+
+            sut_node.exit_scope()
+            tg_node.exit_scope()
+
             if len(test_suite_result.get_errors()) > 0 and test_suite.is_blocking:
                 raise BlockingTestSuiteError(test_suite_name)
 
@@ -511,6 +518,9 @@ def _execute_test_suite(
         """
         self._logger.set_stage(DtsStage.test_suite)
         for test_case in test_cases:
+            test_suite.sut_node.enter_scope("case")
+            test_suite.tg_node.enter_scope("case")
+
             test_case_name = test_case.__name__
             test_case_result = test_suite_result.add_test_case(test_case_name)
             all_attempts = SETTINGS.re_run + 1
@@ -531,6 +541,9 @@ def _execute_test_suite(
                 )
                 test_case_result.update_setup(Result.SKIP)
 
+            test_suite.sut_node.exit_scope()
+            test_suite.tg_node.exit_scope()
+
     def _run_test_case(
         self,
         test_suite: TestSuite,
diff --git a/dts/framework/testbed_model/capability.py b/dts/framework/testbed_model/capability.py
index 6a7a1f5b6c..e883f59d11 100644
--- a/dts/framework/testbed_model/capability.py
+++ b/dts/framework/testbed_model/capability.py
@@ -221,24 +221,23 @@ def get_supported_capabilities(
         )
         if cls.capabilities_to_check:
             capabilities_to_check_map = cls._get_decorated_capabilities_map()
-            with TestPmdShell(
-                sut_node, privileged=True, disable_device_start=True
-            ) as testpmd_shell:
-                for (
-                    conditional_capability_fn,
-                    capabilities,
-                ) in capabilities_to_check_map.items():
-                    supported_capabilities: set[NicCapability] = set()
-                    unsupported_capabilities: set[NicCapability] = set()
-                    capability_fn = cls._reduce_capabilities(
-                        capabilities, supported_capabilities, unsupported_capabilities
-                    )
-                    if conditional_capability_fn:
-                        capability_fn = conditional_capability_fn(capability_fn)
-                    capability_fn(testpmd_shell)
-                    for capability in capabilities:
-                        if capability.nic_capability in supported_capabilities:
-                            supported_conditional_capabilities.add(capability)
+            testpmd_shell = TestPmdShell(sut_node, privileged=True, disable_device_start=True)
+            for (
+                conditional_capability_fn,
+                capabilities,
+            ) in capabilities_to_check_map.items():
+                supported_capabilities: set[NicCapability] = set()
+                unsupported_capabilities: set[NicCapability] = set()
+                capability_fn = cls._reduce_capabilities(
+                    capabilities, supported_capabilities, unsupported_capabilities
+                )
+                if conditional_capability_fn:
+                    capability_fn = conditional_capability_fn(capability_fn)
+                capability_fn(testpmd_shell)
+                for capability in capabilities:
+                    if capability.nic_capability in supported_capabilities:
+                        supported_conditional_capabilities.add(capability)
+            testpmd_shell._close()
 
         logger.debug(f"Found supported capabilities {supported_conditional_capabilities}.")
         return supported_conditional_capabilities
diff --git a/dts/framework/testbed_model/node.py b/dts/framework/testbed_model/node.py
index c6f12319ca..4f06968adc 100644
--- a/dts/framework/testbed_model/node.py
+++ b/dts/framework/testbed_model/node.py
@@ -14,6 +14,7 @@
 """
 
 from abc import ABC
+from typing import TYPE_CHECKING, Literal, TypeVar
 
 from framework.config import (
     OS,
@@ -21,7 +22,7 @@
     NodeConfiguration,
     TestRunConfiguration,
 )
-from framework.exception import ConfigurationError
+from framework.exception import ConfigurationError, InternalError
 from framework.logger import DTSLogger, get_dts_logger
 
 from .cpu import (
@@ -35,6 +36,15 @@
 from .os_session import OSSession
 from .port import Port
 
+if TYPE_CHECKING:
+    from framework.remote_session.single_active_interactive_shell import (
+        SingleActiveInteractiveShell,
+    )
+
+T = TypeVar("T")
+Scope = Literal["unknown", "suite", "case"]
+ScopedShell = tuple[Scope, "SingleActiveInteractiveShell"]
+
 
 class Node(ABC):
     """The base class for node management.
@@ -62,6 +72,8 @@ class Node(ABC):
     _logger: DTSLogger
     _other_sessions: list[OSSession]
     _test_run_config: TestRunConfiguration
+    _active_shells: list[ScopedShell]
+    _scope_stack: list[Scope]
 
     def __init__(self, node_config: NodeConfiguration):
         """Connect to the node and gather info during initialization.
@@ -90,6 +102,8 @@ def __init__(self, node_config: NodeConfiguration):
 
         self._other_sessions = []
         self._init_ports()
+        self._active_shells = []
+        self._scope_stack = []
 
     def _init_ports(self) -> None:
         self.ports = [Port(self.name, port_config) for port_config in self.config.ports]
@@ -119,6 +133,55 @@ def tear_down_test_run(self) -> None:
         Additional steps can be added by extending the method in subclasses with the use of super().
         """
 
+    @property
+    def current_scope(self) -> Scope:
+        """The current scope of the test run."""
+        try:
+            return self._scope_stack[-1]
+        except IndexError:
+            return "unknown"
+
+    def enter_scope(self, next_scope: Scope) -> None:
+        """Prepare the node for a new testing scope."""
+        self._scope_stack.append(next_scope)
+
+    def exit_scope(self) -> Scope:
+        """Clean up the node after the current testing scope.
+
+        This method must never fail because of a node failure at runtime.
+
+        Returns:
+            The scope before exiting.
+
+        Raises:
+            InternalError: If there was no scope to exit from.
+        """
+        try:
+            current_scope = self._scope_stack.pop()
+        except IndexError:
+            raise InternalError("Attempted to exit a scope when the node wasn't in any.")
+        else:
+            self.clean_up_shells(current_scope)
+            return current_scope
+
+    def register_shell(self, shell: "SingleActiveInteractiveShell") -> None:
+        """Register a new shell to the pool of active shells."""
+        self._active_shells.append((self.current_scope, shell))
+
+    def find_active_shell(self, shell_class: type[T]) -> T | None:
+        """Retrieve an active shell of a specific class."""
+        return next((sh for _, sh in self._active_shells if type(sh) is shell_class), None)
+
+    def clean_up_shells(self, scope: Scope) -> None:
+        """Clean up shells from the given `scope`."""
+        zombie_shells_indices = [
+            i for i, (shell_scope, _) in enumerate(self._active_shells) if scope == shell_scope
+        ]
+
+        for i in reversed(zombie_shells_indices):
+            self._active_shells[i][1]._close()
+            del self._active_shells[i]
+
     def create_session(self, name: str) -> OSSession:
         """Create and return a new OS-aware remote session.
 
diff --git a/dts/framework/testbed_model/sut_node.py b/dts/framework/testbed_model/sut_node.py
index a9dc0a474a..3427596bd0 100644
--- a/dts/framework/testbed_model/sut_node.py
+++ b/dts/framework/testbed_model/sut_node.py
@@ -33,7 +33,7 @@
 from framework.remote_session.remote_session import CommandResult
 from framework.utils import MesonArgs, TarCompressionFormat
 
-from .node import Node
+from .node import Node, Scope
 from .os_session import OSSession, OSSessionInfo
 from .virtual_device import VirtualDevice
 
@@ -458,7 +458,7 @@ def build_dpdk_app(self, app_name: str, **meson_dpdk_args: str | bool) -> PurePa
             self.remote_dpdk_build_dir, "examples", f"dpdk-{app_name}"
         )
 
-    def kill_cleanup_dpdk_apps(self) -> None:
+    def _kill_cleanup_dpdk_apps(self) -> None:
         """Kill all dpdk applications on the SUT, then clean up hugepages."""
         if self._dpdk_kill_session and self._dpdk_kill_session.is_alive():
             # we can use the session if it exists and responds
@@ -468,6 +468,16 @@ def kill_cleanup_dpdk_apps(self) -> None:
             self._dpdk_kill_session = self.create_session("dpdk_kill")
         self.dpdk_prefix_list = []
 
+    def exit_scope(self) -> Scope:
+        """Extend :meth:`~.node.Node.exit_scope`.
+
+        Kill DPDK apps and clean up hugepages when exiting a test suite scope.
+        """
+        previous_scope = super().exit_scope()
+        if previous_scope == "suite":
+            self._kill_cleanup_dpdk_apps()
+        return previous_scope
+
     def run_dpdk_app(
         self, app_path: PurePath, eal_params: EalParams, timeout: float = 30
     ) -> CommandResult:
diff --git a/dts/tests/TestSuite_blocklist.py b/dts/tests/TestSuite_blocklist.py
index b9e9cd1d1a..edce042f38 100644
--- a/dts/tests/TestSuite_blocklist.py
+++ b/dts/tests/TestSuite_blocklist.py
@@ -18,16 +18,16 @@ class TestBlocklist(TestSuite):
 
     def verify_blocklisted_ports(self, ports_to_block: list[Port]):
         """Runs testpmd with the given ports blocklisted and verifies the ports."""
-        with TestPmdShell(self.sut_node, allowed_ports=[], blocked_ports=ports_to_block) as testpmd:
-            allowlisted_ports = {port.device_name for port in testpmd.show_port_info_all()}
-            blocklisted_ports = {port.pci for port in ports_to_block}
+        testpmd = TestPmdShell(self.sut_node, allowed_ports=[], blocked_ports=ports_to_block)
+        allowlisted_ports = {port.device_name for port in testpmd.show_port_info_all()}
+        blocklisted_ports = {port.pci for port in ports_to_block}
 
-            # sanity check
-            allowed_len = len(allowlisted_ports - blocklisted_ports)
-            self.verify(allowed_len > 0, "At least one port should have been allowed")
+        # sanity check
+        allowed_len = len(allowlisted_ports - blocklisted_ports)
+        self.verify(allowed_len > 0, "At least one port should have been allowed")
 
-            blocked = not allowlisted_ports & blocklisted_ports
-            self.verify(blocked, "At least one port was not blocklisted")
+        blocked = not allowlisted_ports & blocklisted_ports
+        self.verify(blocked, "At least one port was not blocklisted")
 
     @func_test
     def no_blocklisted(self):
diff --git a/dts/tests/TestSuite_checksum_offload.py b/dts/tests/TestSuite_checksum_offload.py
index c1680bd388..58b9609849 100644
--- a/dts/tests/TestSuite_checksum_offload.py
+++ b/dts/tests/TestSuite_checksum_offload.py
@@ -117,16 +117,16 @@ def test_insert_checksums(self) -> None:
             Ether(dst=mac_id) / IPv6(src="::1") / UDP() / Raw(payload),
             Ether(dst=mac_id) / IPv6(src="::1") / TCP() / Raw(payload),
         ]
-        with TestPmdShell(node=self.sut_node, enable_rx_cksum=True) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.csum)
-            testpmd.set_verbose(level=1)
-            self.setup_hw_offload(testpmd=testpmd)
-            testpmd.start()
-            self.send_packets_and_verify(packet_list=packet_list, load=payload, should_receive=True)
-            for i in range(0, len(packet_list)):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
-                )
+        testpmd = TestPmdShell(node=self.sut_node, enable_rx_cksum=True)
+        testpmd.set_forward_mode(SimpleForwardingModes.csum)
+        testpmd.set_verbose(level=1)
+        self.setup_hw_offload(testpmd=testpmd)
+        testpmd.start()
+        self.send_packets_and_verify(packet_list=packet_list, load=payload, should_receive=True)
+        for i in range(0, len(packet_list)):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
+            )
 
     @func_test
     def test_no_insert_checksums(self) -> None:
@@ -139,15 +139,15 @@ def test_no_insert_checksums(self) -> None:
             Ether(dst=mac_id) / IPv6(src="::1") / UDP() / Raw(payload),
             Ether(dst=mac_id) / IPv6(src="::1") / TCP() / Raw(payload),
         ]
-        with TestPmdShell(node=self.sut_node, enable_rx_cksum=True) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.csum)
-            testpmd.set_verbose(level=1)
-            testpmd.start()
-            self.send_packets_and_verify(packet_list=packet_list, load=payload, should_receive=True)
-            for i in range(0, len(packet_list)):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
-                )
+        testpmd = TestPmdShell(node=self.sut_node, enable_rx_cksum=True)
+        testpmd.set_forward_mode(SimpleForwardingModes.csum)
+        testpmd.set_verbose(level=1)
+        testpmd.start()
+        self.send_packets_and_verify(packet_list=packet_list, load=payload, should_receive=True)
+        for i in range(0, len(packet_list)):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
+            )
 
     @func_test
     def test_l4_rx_checksum(self) -> None:
@@ -159,18 +159,18 @@ def test_l4_rx_checksum(self) -> None:
             Ether(dst=mac_id) / IP() / UDP(chksum=0xF),
             Ether(dst=mac_id) / IP() / TCP(chksum=0xF),
         ]
-        with TestPmdShell(node=self.sut_node, enable_rx_cksum=True) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.csum)
-            testpmd.set_verbose(level=1)
-            self.setup_hw_offload(testpmd=testpmd)
-            for i in range(0, 2):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
-                )
-            for i in range(2, 4):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
-                )
+        testpmd = TestPmdShell(node=self.sut_node, enable_rx_cksum=True)
+        testpmd.set_forward_mode(SimpleForwardingModes.csum)
+        testpmd.set_verbose(level=1)
+        self.setup_hw_offload(testpmd=testpmd)
+        for i in range(0, 2):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
+            )
+        for i in range(2, 4):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
+            )
 
     @func_test
     def test_l3_rx_checksum(self) -> None:
@@ -182,18 +182,18 @@ def test_l3_rx_checksum(self) -> None:
             Ether(dst=mac_id) / IP(chksum=0xF) / UDP(),
             Ether(dst=mac_id) / IP(chksum=0xF) / TCP(),
         ]
-        with TestPmdShell(node=self.sut_node, enable_rx_cksum=True) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.csum)
-            testpmd.set_verbose(level=1)
-            self.setup_hw_offload(testpmd=testpmd)
-            for i in range(0, 2):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
-                )
-            for i in range(2, 4):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=True, goodIP=False, testpmd=testpmd, id=mac_id
-                )
+        testpmd = TestPmdShell(node=self.sut_node, enable_rx_cksum=True)
+        testpmd.set_forward_mode(SimpleForwardingModes.csum)
+        testpmd.set_verbose(level=1)
+        self.setup_hw_offload(testpmd=testpmd)
+        for i in range(0, 2):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
+            )
+        for i in range(2, 4):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=True, goodIP=False, testpmd=testpmd, id=mac_id
+            )
 
     @func_test
     def test_validate_rx_checksum(self) -> None:
@@ -209,22 +209,22 @@ def test_validate_rx_checksum(self) -> None:
             Ether(dst=mac_id) / IPv6(src="::1") / UDP(chksum=0xF),
             Ether(dst=mac_id) / IPv6(src="::1") / TCP(chksum=0xF),
         ]
-        with TestPmdShell(node=self.sut_node, enable_rx_cksum=True) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.csum)
-            testpmd.set_verbose(level=1)
-            self.setup_hw_offload(testpmd=testpmd)
-            for i in range(0, 4):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
-                )
-            for i in range(4, 6):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=False, goodIP=False, testpmd=testpmd, id=mac_id
-                )
-            for i in range(6, 8):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
-                )
+        testpmd = TestPmdShell(node=self.sut_node, enable_rx_cksum=True)
+        testpmd.set_forward_mode(SimpleForwardingModes.csum)
+        testpmd.set_verbose(level=1)
+        self.setup_hw_offload(testpmd=testpmd)
+        for i in range(0, 4):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
+            )
+        for i in range(4, 6):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=False, goodIP=False, testpmd=testpmd, id=mac_id
+            )
+        for i in range(6, 8):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
+            )
 
     @requires(NicCapability.RX_OFFLOAD_VLAN)
     @func_test
@@ -238,20 +238,20 @@ def test_vlan_checksum(self) -> None:
             Ether(dst=mac_id) / Dot1Q(vlan=1) / IPv6(src="::1") / UDP(chksum=0xF) / Raw(payload),
             Ether(dst=mac_id) / Dot1Q(vlan=1) / IPv6(src="::1") / TCP(chksum=0xF) / Raw(payload),
         ]
-        with TestPmdShell(node=self.sut_node, enable_rx_cksum=True) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.csum)
-            testpmd.set_verbose(level=1)
-            self.setup_hw_offload(testpmd=testpmd)
-            testpmd.start()
-            self.send_packets_and_verify(packet_list=packet_list, load=payload, should_receive=True)
-            for i in range(0, 2):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=False, goodIP=False, testpmd=testpmd, id=mac_id
-                )
-            for i in range(2, 4):
-                self.send_packet_and_verify_checksum(
-                    packet=packet_list[i], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
-                )
+        testpmd = TestPmdShell(node=self.sut_node, enable_rx_cksum=True)
+        testpmd.set_forward_mode(SimpleForwardingModes.csum)
+        testpmd.set_verbose(level=1)
+        self.setup_hw_offload(testpmd=testpmd)
+        testpmd.start()
+        self.send_packets_and_verify(packet_list=packet_list, load=payload, should_receive=True)
+        for i in range(0, 2):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=False, goodIP=False, testpmd=testpmd, id=mac_id
+            )
+        for i in range(2, 4):
+            self.send_packet_and_verify_checksum(
+                packet=packet_list[i], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
+            )
 
     @requires(NicCapability.RX_OFFLOAD_SCTP_CKSUM)
     @func_test
@@ -262,14 +262,14 @@ def test_validate_sctp_checksum(self) -> None:
             Ether(dst=mac_id) / IP() / SCTP(),
             Ether(dst=mac_id) / IP() / SCTP(chksum=0xF),
         ]
-        with TestPmdShell(node=self.sut_node, enable_rx_cksum=True) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.csum)
-            testpmd.set_verbose(level=1)
-            testpmd.csum_set_hw(layers=ChecksumOffloadOptions.sctp)
-            testpmd.start()
-            self.send_packet_and_verify_checksum(
-                packet=packet_list[0], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
-            )
-            self.send_packet_and_verify_checksum(
-                packet=packet_list[1], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
-            )
+        testpmd = TestPmdShell(node=self.sut_node, enable_rx_cksum=True)
+        testpmd.set_forward_mode(SimpleForwardingModes.csum)
+        testpmd.set_verbose(level=1)
+        testpmd.csum_set_hw(layers=ChecksumOffloadOptions.sctp)
+        testpmd.start()
+        self.send_packet_and_verify_checksum(
+            packet=packet_list[0], goodL4=True, goodIP=True, testpmd=testpmd, id=mac_id
+        )
+        self.send_packet_and_verify_checksum(
+            packet=packet_list[1], goodL4=False, goodIP=True, testpmd=testpmd, id=mac_id
+        )
diff --git a/dts/tests/TestSuite_dynamic_queue_conf.py b/dts/tests/TestSuite_dynamic_queue_conf.py
index e55716f545..caf820151e 100644
--- a/dts/tests/TestSuite_dynamic_queue_conf.py
+++ b/dts/tests/TestSuite_dynamic_queue_conf.py
@@ -83,37 +83,37 @@ def wrap(self: "TestDynamicQueueConf", is_rx_testing: bool) -> None:
         while len(queues_to_config) < self.num_ports_to_modify:
             queues_to_config.add(random.randint(1, self.number_of_queues - 1))
         unchanged_queues = set(range(self.number_of_queues)) - queues_to_config
-        with TestPmdShell(
+        testpmd = TestPmdShell(
             self.sut_node,
             port_topology=PortTopology.chained,
             rx_queues=self.number_of_queues,
             tx_queues=self.number_of_queues,
-        ) as testpmd:
-            for q in queues_to_config:
-                testpmd.stop_port_queue(port_id, q, is_rx_testing)
-            testpmd.set_forward_mode(SimpleForwardingModes.mac)
-
-            test_meth(
-                self,
-                port_id,
-                queues_to_config,
-                unchanged_queues,
-                testpmd,
-                is_rx_testing,
-            )
-
-            for queue_id in queues_to_config:
-                testpmd.start_port_queue(port_id, queue_id, is_rx_testing)
+        )
+        for q in queues_to_config:
+            testpmd.stop_port_queue(port_id, q, is_rx_testing)
+        testpmd.set_forward_mode(SimpleForwardingModes.mac)
+
+        test_meth(
+            self,
+            port_id,
+            queues_to_config,
+            unchanged_queues,
+            testpmd,
+            is_rx_testing,
+        )
+
+        for queue_id in queues_to_config:
+            testpmd.start_port_queue(port_id, queue_id, is_rx_testing)
 
-            testpmd.start()
-            self.send_packets_with_different_addresses(self.number_of_packets_to_send)
-            forwarding_stats = testpmd.stop()
-            for queue_id in queues_to_config:
-                self.verify(
-                    self.port_queue_in_stats(port_id, is_rx_testing, queue_id, forwarding_stats),
-                    f"Modified queue {queue_id} on port {port_id} failed to receive traffic after"
-                    "being started again.",
-                )
+        testpmd.start()
+        self.send_packets_with_different_addresses(self.number_of_packets_to_send)
+        forwarding_stats = testpmd.stop()
+        for queue_id in queues_to_config:
+            self.verify(
+                self.port_queue_in_stats(port_id, is_rx_testing, queue_id, forwarding_stats),
+                f"Modified queue {queue_id} on port {port_id} failed to receive traffic after "
+                "being started again.",
+            )
 
     return wrap
 
diff --git a/dts/tests/TestSuite_l2fwd.py b/dts/tests/TestSuite_l2fwd.py
index 0f6ff18907..9acc4365ea 100644
--- a/dts/tests/TestSuite_l2fwd.py
+++ b/dts/tests/TestSuite_l2fwd.py
@@ -44,20 +44,20 @@ def l2fwd_integrity(self) -> None:
         """
         queues = [1, 2, 4, 8]
 
-        with TestPmdShell(
+        shell = TestPmdShell(
             self.sut_node,
             lcore_filter_specifier=LogicalCoreCount(cores_per_socket=4),
             forward_mode=SimpleForwardingModes.mac,
             eth_peer=[EthPeer(1, self.tg_node.ports[1].mac_address)],
             disable_device_start=True,
-        ) as shell:
-            for queues_num in queues:
-                self._logger.info(f"Testing L2 forwarding with {queues_num} queue(s)")
-                shell.set_ports_queues(queues_num)
-                shell.start()
+        )
+        for queues_num in queues:
+            self._logger.info(f"Testing L2 forwarding with {queues_num} queue(s)")
+            shell.set_ports_queues(queues_num)
+            shell.start()
 
-                received_packets = self.send_packets_and_capture(self.packets)
-                expected_packets = self.get_expected_packets(self.packets)
-                self.match_all_packets(expected_packets, received_packets)
+            received_packets = self.send_packets_and_capture(self.packets)
+            expected_packets = self.get_expected_packets(self.packets)
+            self.match_all_packets(expected_packets, received_packets)
 
-                shell.stop()
+            shell.stop()
diff --git a/dts/tests/TestSuite_mac_filter.py b/dts/tests/TestSuite_mac_filter.py
index 11e4b595c7..ac9ceefd85 100644
--- a/dts/tests/TestSuite_mac_filter.py
+++ b/dts/tests/TestSuite_mac_filter.py
@@ -101,22 +101,22 @@ def test_add_remove_mac_addresses(self) -> None:
             Remove the fake mac address from the PMD's address pool.
             Send a packet with the fake mac address to the PMD. (Should not receive)
         """
-        with TestPmdShell(self.sut_node) as testpmd:
-            testpmd.set_promisc(0, enable=False)
-            testpmd.start()
-            mac_address = self._sut_port_ingress.mac_address
-
-            # Send a packet with NIC default mac address
-            self.send_packet_and_verify(mac_address=mac_address, should_receive=True)
-            # Send a packet with different mac address
-            fake_address = "00:00:00:00:00:01"
-            self.send_packet_and_verify(mac_address=fake_address, should_receive=False)
-
-            # Add mac address to pool and rerun tests
-            testpmd.set_mac_addr(0, mac_address=fake_address, add=True)
-            self.send_packet_and_verify(mac_address=fake_address, should_receive=True)
-            testpmd.set_mac_addr(0, mac_address=fake_address, add=False)
-            self.send_packet_and_verify(mac_address=fake_address, should_receive=False)
+        testpmd = TestPmdShell(self.sut_node)
+        testpmd.set_promisc(0, enable=False)
+        testpmd.start()
+        mac_address = self._sut_port_ingress.mac_address
+
+        # Send a packet with NIC default mac address
+        self.send_packet_and_verify(mac_address=mac_address, should_receive=True)
+        # Send a packet with different mac address
+        fake_address = "00:00:00:00:00:01"
+        self.send_packet_and_verify(mac_address=fake_address, should_receive=False)
+
+        # Add mac address to pool and rerun tests
+        testpmd.set_mac_addr(0, mac_address=fake_address, add=True)
+        self.send_packet_and_verify(mac_address=fake_address, should_receive=True)
+        testpmd.set_mac_addr(0, mac_address=fake_address, add=False)
+        self.send_packet_and_verify(mac_address=fake_address, should_receive=False)
 
     @func_test
     def test_invalid_address(self) -> None:
@@ -137,44 +137,42 @@ def test_invalid_address(self) -> None:
             Determine the device's mac address pool size, and fill the pool with fake addresses.
             Attempt to add another fake mac address, overloading the address pool. (Should fail)
         """
-        with TestPmdShell(self.sut_node) as testpmd:
-            testpmd.start()
-            mac_address = self._sut_port_ingress.mac_address
-            try:
-                testpmd.set_mac_addr(0, "00:00:00:00:00:00", add=True)
-                self.verify(False, "Invalid mac address added.")
-            except InteractiveCommandExecutionError:
-                pass
-            try:
-                testpmd.set_mac_addr(0, mac_address, add=False)
-                self.verify(False, "Default mac address removed.")
-            except InteractiveCommandExecutionError:
-                pass
-            # Should be no errors adding this twice
-            testpmd.set_mac_addr(0, "1" + mac_address[1:], add=True)
-            testpmd.set_mac_addr(0, "1" + mac_address[1:], add=True)
-            # Double check to see if default mac address can be removed
-            try:
-                testpmd.set_mac_addr(0, mac_address, add=False)
-                self.verify(False, "Default mac address removed.")
-            except InteractiveCommandExecutionError:
-                pass
-
-            for i in range(testpmd.show_port_info(0).max_mac_addresses_num - 1):
-                # A0 fake address based on the index 'i'.
-                fake_address = str(hex(i)[2:].zfill(12))
-                # Insert ':' characters every two indexes to create a fake mac address.
-                fake_address = ":".join(
-                    fake_address[x : x + 2] for x in range(0, len(fake_address), 2)
-                )
-                testpmd.set_mac_addr(0, fake_address, add=True, verify=False)
-            try:
-                testpmd.set_mac_addr(0, "E" + mac_address[1:], add=True)
-                # We add an extra address to compensate for mac address pool inconsistencies.
-                testpmd.set_mac_addr(0, "F" + mac_address[1:], add=True)
-                self.verify(False, "Mac address limit exceeded.")
-            except InteractiveCommandExecutionError:
-                pass
+        testpmd = TestPmdShell(self.sut_node)
+        testpmd.start()
+        mac_address = self._sut_port_ingress.mac_address
+        try:
+            testpmd.set_mac_addr(0, "00:00:00:00:00:00", add=True)
+            self.verify(False, "Invalid mac address added.")
+        except InteractiveCommandExecutionError:
+            pass
+        try:
+            testpmd.set_mac_addr(0, mac_address, add=False)
+            self.verify(False, "Default mac address removed.")
+        except InteractiveCommandExecutionError:
+            pass
+        # Should be no errors adding this twice
+        testpmd.set_mac_addr(0, "1" + mac_address[1:], add=True)
+        testpmd.set_mac_addr(0, "1" + mac_address[1:], add=True)
+        # Double check to see if default mac address can be removed
+        try:
+            testpmd.set_mac_addr(0, mac_address, add=False)
+            self.verify(False, "Default mac address removed.")
+        except InteractiveCommandExecutionError:
+            pass
+
+        for i in range(testpmd.show_port_info(0).max_mac_addresses_num - 1):
+            # A0 fake address based on the index 'i'.
+            fake_address = str(hex(i)[2:].zfill(12))
+            # Insert ':' characters every two indexes to create a fake mac address.
+            fake_address = ":".join(fake_address[x : x + 2] for x in range(0, len(fake_address), 2))
+            testpmd.set_mac_addr(0, fake_address, add=True, verify=False)
+        try:
+            testpmd.set_mac_addr(0, "E" + mac_address[1:], add=True)
+            # We add an extra address to compensate for mac address pool inconsistencies.
+            testpmd.set_mac_addr(0, "F" + mac_address[1:], add=True)
+            self.verify(False, "Mac address limit exceeded.")
+        except InteractiveCommandExecutionError:
+            pass
 
     @requires(NicCapability.MCAST_FILTERING)
     @func_test
@@ -191,14 +189,14 @@ def test_multicast_filter(self) -> None:
             Remove the fake multicast address from the PMDs multicast address filter.
             Send a packet with the fake multicast address to the PMD. (Should not receive)
         """
-        with TestPmdShell(self.sut_node) as testpmd:
-            testpmd.start()
-            testpmd.set_promisc(0, enable=False)
-            multicast_address = "01:00:5E:00:00:00"
+        testpmd = TestPmdShell(self.sut_node)
+        testpmd.set_promisc(0, enable=False)
+        multicast_address = "01:00:5E:00:00:00"
 
-            testpmd.set_multicast_mac_addr(0, multi_addr=multicast_address, add=True)
-            self.send_packet_and_verify(multicast_address, should_receive=True)
+        testpmd.set_multicast_mac_addr(0, multi_addr=multicast_address, add=True)
+        testpmd.start()
+        self.send_packet_and_verify(multicast_address, should_receive=True)
 
-            # Remove multicast filter and verify the packet was not received.
-            testpmd.set_multicast_mac_addr(0, multicast_address, add=False)
-            self.send_packet_and_verify(multicast_address, should_receive=False)
+        # Remove multicast filter and verify the packet was not received.
+        testpmd.set_multicast_mac_addr(0, multicast_address, add=False)
+        self.send_packet_and_verify(multicast_address, should_receive=False)
diff --git a/dts/tests/TestSuite_pmd_buffer_scatter.py b/dts/tests/TestSuite_pmd_buffer_scatter.py
index b2f42425d4..ffb345ae4e 100644
--- a/dts/tests/TestSuite_pmd_buffer_scatter.py
+++ b/dts/tests/TestSuite_pmd_buffer_scatter.py
@@ -103,26 +103,24 @@ def pmd_scatter(self, mbsize: int) -> None:
         Test:
             Start testpmd and run functional test with preset mbsize.
         """
-        with TestPmdShell(
+        testpmd = TestPmdShell(
             self.sut_node,
             forward_mode=SimpleForwardingModes.mac,
             mbcache=200,
             mbuf_size=[mbsize],
             max_pkt_len=9000,
             tx_offloads=0x00008000,
-        ) as testpmd:
-            testpmd.start()
-
-            for offset in [-1, 0, 1, 4, 5]:
-                recv_payload = self.scatter_pktgen_send_packet(mbsize + offset)
-                self._logger.debug(
-                    f"Payload of scattered packet after forwarding: \n{recv_payload}"
-                )
-                self.verify(
-                    ("58 " * 8).strip() in recv_payload,
-                    "Payload of scattered packet did not match expected payload with offset "
-                    f"{offset}.",
-                )
+        )
+        testpmd.start()
+
+        for offset in [-1, 0, 1, 4, 5]:
+            recv_payload = self.scatter_pktgen_send_packet(mbsize + offset)
+            self._logger.debug(f"Payload of scattered packet after forwarding: \n{recv_payload}")
+            self.verify(
+                ("58 " * 8).strip() in recv_payload,
+                "Payload of scattered packet did not match expected payload with offset "
+                f"{offset}.",
+            )
 
     @requires(NicCapability.SCATTERED_RX_ENABLED)
     @func_test
diff --git a/dts/tests/TestSuite_smoke_tests.py b/dts/tests/TestSuite_smoke_tests.py
index bc3a2a6bf9..0681ad4ea7 100644
--- a/dts/tests/TestSuite_smoke_tests.py
+++ b/dts/tests/TestSuite_smoke_tests.py
@@ -104,8 +104,8 @@ def test_devices_listed_in_testpmd(self) -> None:
         Test:
             List all devices found in testpmd and verify the configured devices are among them.
         """
-        with TestPmdShell(self.sut_node) as testpmd:
-            dev_list = [str(x) for x in testpmd.get_devices()]
+        testpmd = TestPmdShell(self.sut_node)
+        dev_list = [str(x) for x in testpmd.get_devices()]
         for nic in self.nics_in_node:
             self.verify(
                 nic.pci in dev_list,
diff --git a/dts/tests/TestSuite_vlan.py b/dts/tests/TestSuite_vlan.py
index c67520baef..ede1f69495 100644
--- a/dts/tests/TestSuite_vlan.py
+++ b/dts/tests/TestSuite_vlan.py
@@ -124,10 +124,10 @@ def test_vlan_receipt_no_stripping(self) -> None:
         Test:
             Create an interactive testpmd shell and verify a VLAN packet.
         """
-        with TestPmdShell(node=self.sut_node) as testpmd:
-            self.vlan_setup(testpmd=testpmd, port_id=0, filtered_id=1)
-            testpmd.start()
-            self.send_vlan_packet_and_verify(True, strip=False, vlan_id=1)
+        testpmd = TestPmdShell(node=self.sut_node)
+        self.vlan_setup(testpmd=testpmd, port_id=0, filtered_id=1)
+        testpmd.start()
+        self.send_vlan_packet_and_verify(True, strip=False, vlan_id=1)
 
     @requires(NicCapability.RX_OFFLOAD_VLAN_STRIP)
     @func_test
@@ -137,11 +137,11 @@ def test_vlan_receipt_stripping(self) -> None:
         Test:
             Create an interactive testpmd shell and verify a VLAN packet.
         """
-        with TestPmdShell(node=self.sut_node) as testpmd:
-            self.vlan_setup(testpmd=testpmd, port_id=0, filtered_id=1)
-            testpmd.set_vlan_strip(port=0, enable=True)
-            testpmd.start()
-            self.send_vlan_packet_and_verify(should_receive=True, strip=True, vlan_id=1)
+        testpmd = TestPmdShell(node=self.sut_node)
+        self.vlan_setup(testpmd=testpmd, port_id=0, filtered_id=1)
+        testpmd.set_vlan_strip(port=0, enable=True)
+        testpmd.start()
+        self.send_vlan_packet_and_verify(should_receive=True, strip=True, vlan_id=1)
 
     @func_test
     def test_vlan_no_receipt(self) -> None:
@@ -150,10 +150,10 @@ def test_vlan_no_receipt(self) -> None:
         Test:
             Create an interactive testpmd shell and verify a VLAN packet.
         """
-        with TestPmdShell(node=self.sut_node) as testpmd:
-            self.vlan_setup(testpmd=testpmd, port_id=0, filtered_id=1)
-            testpmd.start()
-            self.send_vlan_packet_and_verify(should_receive=False, strip=False, vlan_id=2)
+        testpmd = TestPmdShell(node=self.sut_node)
+        self.vlan_setup(testpmd=testpmd, port_id=0, filtered_id=1)
+        testpmd.start()
+        self.send_vlan_packet_and_verify(should_receive=False, strip=False, vlan_id=2)
 
     @func_test
     def test_vlan_header_insertion(self) -> None:
@@ -162,11 +162,11 @@ def test_vlan_header_insertion(self) -> None:
         Test:
             Create an interactive testpmd shell and verify a non-VLAN packet.
         """
-        with TestPmdShell(node=self.sut_node) as testpmd:
-            testpmd.set_forward_mode(SimpleForwardingModes.mac)
-            testpmd.set_promisc(port=0, enable=False)
-            testpmd.stop_all_ports()
-            testpmd.tx_vlan_set(port=1, enable=True, vlan=51)
-            testpmd.start_all_ports()
-            testpmd.start()
-            self.send_packet_and_verify_insertion(expected_id=51)
+        testpmd = TestPmdShell(node=self.sut_node)
+        testpmd.set_forward_mode(SimpleForwardingModes.mac)
+        testpmd.set_promisc(port=0, enable=False)
+        testpmd.stop_all_ports()
+        testpmd.tx_vlan_set(port=1, enable=True, vlan=51)
+        testpmd.start_all_ports()
+        testpmd.start()
+        self.send_packet_and_verify_insertion(expected_id=51)
-- 
2.43.0



* [RFC PATCH 2/2] dts: revert back shell split
From: Luca Vizzarro @ 2024-12-20 17:24 UTC
  To: dev; +Cc: Paul Szczepanek, Patrick Robb, Luca Vizzarro

The InteractiveShell was previously renamed to
SingleActiveInteractiveShell to represent a shell that can only have
one running instance at a time. The mechanism used to enforce this was
a context manager, which turned out to be too restrictive for test
suite development.

Shell closure is now handled by the scoping mechanism, and the
_single_active_per_node class attribute is used to enforce the single
active shell. The split has also been reverted.

Signed-off-by: Luca Vizzarro <luca.vizzarro@arm.com>
---
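
A minimal sketch of how the single-active constraint can work once shells are registered with their node. `Node`, `register_shell`, `find_active_shell` and `close_scope` below are simplified, hypothetical stand-ins (the real `Node` changes live in patch 1/2 and are not reproduced here), and `InternalError` replaces `framework.exception.InternalError` so the example runs on its own:

```python
from __future__ import annotations

from typing import ClassVar


class InternalError(Exception):
    """Stand-in for framework.exception.InternalError."""


class Node:
    """Hypothetical, simplified node that keeps track of the shells started on it."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._shells: list[InteractiveShell] = []

    def register_shell(self, shell: InteractiveShell) -> None:
        self._shells.append(shell)

    def find_active_shell(self, shell_class: type) -> InteractiveShell | None:
        # Return the first live shell of exactly this class, if there is one.
        for shell in self._shells:
            if type(shell) is shell_class and shell.is_alive:
                return shell
        return None

    def close_scope(self) -> None:
        # Hypothetical end-of-scope hook: close every shell registered in this scope.
        for shell in self._shells:
            shell.close()
        self._shells.clear()


class InteractiveShell:
    #: Mirrors the class attribute added by this patch; subclasses opt in.
    _single_active_per_node: ClassVar[bool] = False
    is_alive: bool = False

    def __init__(self, node: Node) -> None:
        if self._single_active_per_node and node.find_active_shell(type(self)) is not None:
            raise InternalError(
                "Attempted to run a single-active shell while another one was already open."
            )
        node.register_shell(self)
        self.is_alive = True  # stands in for start_application()

    def close(self) -> None:
        self.is_alive = False


class DPDKShell(InteractiveShell):
    _single_active_per_node: ClassVar[bool] = True


node = Node("sut")
testpmd = DPDKShell(node)
try:
    DPDKShell(node)  # rejected: a DPDKShell is already alive on this node
except InternalError as err:
    print(err)
node.close_scope()
testpmd_again = DPDKShell(node)  # allowed once the previous scope has been closed
```

Because the check only considers shells that are still alive, closing a shell (here through the hypothetical `close_scope`) is what frees the node for the next single-active shell, while plain `InteractiveShell` instances are never constrained.
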
 dts/framework/remote_session/dpdk_shell.py    |   8 +-
 .../remote_session/interactive_shell.py       | 262 +++++++++++++++--
 .../single_active_interactive_shell.py        | 268 ------------------
 dts/framework/remote_session/testpmd_shell.py |   4 +-
 dts/framework/testbed_model/capability.py     |   2 +-
 dts/framework/testbed_model/node.py           |  10 +-
 6 files changed, 250 insertions(+), 304 deletions(-)
 delete mode 100644 dts/framework/remote_session/single_active_interactive_shell.py
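
`start_application` in the new `interactive_shell.py` relies on Python's `for ... else`: the `else` branch runs only if the loop finishes without hitting `break`, i.e. every attempt timed out. A standalone sketch of that pattern, with `StartupTimeout`/`StartFailed` standing in for `InteractiveSSHTimeoutError`/`InteractiveCommandExecutionError` and a hypothetical `send_start_command` callable in place of `self.send_command(start_command)`:

```python
from typing import Callable


class StartupTimeout(Exception):
    """Stand-in for InteractiveSSHTimeoutError."""


class StartFailed(Exception):
    """Stand-in for InteractiveCommandExecutionError."""


def start_with_retries(send_start_command: Callable[[], None], init_attempts: int = 5) -> int:
    """Try to start the application up to `init_attempts` times.

    Returns the 1-based number of the attempt that succeeded.
    """
    for attempt in range(init_attempts):
        try:
            send_start_command()
            break
        except StartupTimeout:
            print(
                f"Interactive shell failed to start "
                f"(attempt {attempt + 1} out of {init_attempts})"
            )
    else:
        # Only reached when no attempt hit `break`, i.e. every attempt timed out.
        raise StartFailed("Failed to start application.")
    return attempt + 1


def make_flaky_start(failures: int) -> Callable[[], None]:
    """Hypothetical start command that times out `failures` times before succeeding."""
    calls = 0

    def send() -> None:
        nonlocal calls
        calls += 1
        if calls <= failures:
            raise StartupTimeout()

    return send


print(start_with_retries(make_flaky_start(2)))  # prints two failure messages, then 3
```

With the default of five attempts this gives one initial try plus up to four retries, which is what the docstring's "`self._init_attempts` - 1 times" refers to.
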

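The output-collection loop in `send_command` reads the channel line by line until a line ends with the prompt. The sketch below reproduces only that loop, on an `io.StringIO` standing in for the paramiko `ChannelFile`; the testpmd-like sample output is invented for illustration. The prompt line must be newline-terminated before the line iterator yields it, which is what the extra newline in `_command_extra_chars` provides on a real channel:

```python
import io
from typing import Iterable


def read_until_prompt(stdout: Iterable[str], prompt: str, skip_first_line: bool = False) -> str:
    """Collect lines from `stdout` until one ends with `prompt` (the prompt line is dropped)."""
    out = ""
    for line in stdout:
        if skip_first_line:
            # e.g. drop the echoed command line
            skip_first_line = False
            continue
        if line.rstrip().endswith(prompt):
            break
        out += line
    return out


# Invented sample output: echoed command, two lines of output, then the prompt.
fake_stdout = io.StringIO("show port info 0\nPort 0: up\nMTU: 1500\ntestpmd> \n")
print(read_until_prompt(fake_stdout, "testpmd>", skip_first_line=True), end="")
```

On a real channel, a missing prompt ends the loop with a `TimeoutError` rather than end-of-file, which `send_command` converts into `InteractiveSSHTimeoutError`.
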
diff --git a/dts/framework/remote_session/dpdk_shell.py b/dts/framework/remote_session/dpdk_shell.py
index c11d9ab81c..c37dcb2b62 100644
--- a/dts/framework/remote_session/dpdk_shell.py
+++ b/dts/framework/remote_session/dpdk_shell.py
@@ -8,10 +8,11 @@
 
 from abc import ABC
 from pathlib import PurePath
+from typing import ClassVar
 
 from framework.params.eal import EalParams
-from framework.remote_session.single_active_interactive_shell import (
-    SingleActiveInteractiveShell,
+from framework.remote_session.interactive_shell import (
+    InteractiveShell,
 )
 from framework.settings import SETTINGS
 from framework.testbed_model.cpu import LogicalCoreCount, LogicalCoreList
@@ -61,7 +62,7 @@ def compute_eal_params(
     return params
 
 
-class DPDKShell(SingleActiveInteractiveShell, ABC):
+class DPDKShell(InteractiveShell, ABC):
     """The base class for managing DPDK-based interactive shells.
 
     This class shouldn't be instantiated directly, but instead be extended.
@@ -71,6 +72,7 @@ class DPDKShell(SingleActiveInteractiveShell, ABC):
 
     _node: SutNode
     _app_params: EalParams
+    _single_active_per_node: ClassVar[bool] = True
 
     def __init__(
         self,
diff --git a/dts/framework/remote_session/interactive_shell.py b/dts/framework/remote_session/interactive_shell.py
index 9ca285b604..a136419181 100644
--- a/dts/framework/remote_session/interactive_shell.py
+++ b/dts/framework/remote_session/interactive_shell.py
@@ -1,44 +1,256 @@
 # SPDX-License-Identifier: BSD-3-Clause
-# Copyright(c) 2023 University of New Hampshire
+# Copyright(c) 2024 University of New Hampshire
 # Copyright(c) 2024 Arm Limited
 
-"""Interactive shell with manual stop/start functionality.
+"""Common functionality for interactive shell handling.
 
-Provides a class that doesn't require being started/stopped using a context manager and can instead
-be started and stopped manually, or have the stopping process be handled at the time of garbage
-collection.
+The base class, :class:`InteractiveShell`, is meant to be extended by subclasses that
+contain functionality specific to that shell type. These subclasses will often modify things like
+the prompt to expect or the arguments to pass into the application, but still utilize
+the same method for sending a command and collecting output. How this output is handled, however,
+is often application-specific. If an application needs elevated privileges to start, it is
+expected that the method for gaining those privileges is provided when initializing the class.
+
+The :option:`--timeout` command line argument and the :envvar:`DTS_TIMEOUT`
+environment variable configure the timeout of getting the output from command execution.
 """
 
-import weakref
+from abc import ABC
+from pathlib import PurePath
 from typing import ClassVar
 
-from .single_active_interactive_shell import SingleActiveInteractiveShell
+from paramiko import Channel, channel
+
+from framework.exception import (
+    InteractiveCommandExecutionError,
+    InteractiveSSHSessionDeadError,
+    InteractiveSSHTimeoutError,
+    InternalError,
+)
+from framework.logger import DTSLogger, get_dts_logger
+from framework.params import Params
+from framework.settings import SETTINGS
+from framework.testbed_model.node import Node
+from framework.utils import MultiInheritanceBaseClass
+
+
+class InteractiveShell(MultiInheritanceBaseClass, ABC):
+    """The base class for managing interactive shells.
 
+    This class shouldn't be instantiated directly, but instead be extended. It contains
+    methods for starting interactive shells as well as sending commands to these shells
+    and collecting output until reaching a certain prompt. All interactive applications
+    use the same SSH connection, but each creates its own channel on that
+    session.
 
-class InteractiveShell(SingleActiveInteractiveShell):
-    """Adds manual start and stop functionality to interactive shells.
+    Interactive shells are started when instantiated and registered with their node, whose
+    scoping mechanism closes them at the end of the test case or test suite scope in which they
+    were created.
 
-    Like its super-class, this class should not be instantiated directly and should instead be
-    extended. This class also provides an option for automated cleanup of the application using a
-    weakref and a finalize class. This finalize class allows for cleanup of the class at the time
-    of garbage collection and also ensures that cleanup only happens once. This way if a user
-    initiates the closing of the shell manually it is not repeated at the time of garbage
-    collection.
+    Attributes:
+        is_alive: :data:`True` if the application has started successfully, :data:`False`
+            otherwise.
     """
 
-    _finalizer: weakref.finalize
-    #: One attempt should be enough for shells which don't have to worry about other instances
-    #: closing before starting a new one.
-    _init_attempts: ClassVar[int] = 1
+    _node: Node
+    _stdin: channel.ChannelStdinFile
+    _stdout: channel.ChannelFile
+    _ssh_channel: Channel
+    _logger: DTSLogger
+    _timeout: float
+    _app_params: Params
+    _privileged: bool
+    _real_path: PurePath
+
+    #: The number of times to try starting the application before considering it a failure.
+    _init_attempts: ClassVar[int] = 5
+
+    #: Prompt to expect at the end of output when sending a command.
+    #: This is often overridden by subclasses.
+    _default_prompt: ClassVar[str] = ""
+
+    #: Extra characters to add to the end of every command
+    #: before sending them. This is often overridden by subclasses and is
+    #: most commonly an additional newline character. This additional newline
+    #: character is used to force the line that is currently awaiting input
+    #: into the stdout buffer so that it can be consumed and checked against
+    #: the expected prompt.
+    _command_extra_chars: ClassVar[str] = ""
+
+    #: Whether to prevent more than one shell of this class from running on the same node at
+    #: the same time.
+    _single_active_per_node: ClassVar[bool] = False
+
+    #: Path to the executable to start the interactive application.
+    path: ClassVar[PurePath]
+
+    is_alive: bool = False
+
+    def __init__(
+        self,
+        node: Node,
+        privileged: bool = False,
+        timeout: float = SETTINGS.timeout,
+        app_params: Params = Params(),
+        name: str | None = None,
+        **kwargs,
+    ) -> None:
+        """Register the shell with its node and start the interactive application.
+
+        Additional keyword arguments can be passed through `kwargs` if needed for fulfilling other
+        constructors in the case of multiple inheritance.
+
+        Args:
+            node: The node on which to start the interactive shell.
+            privileged: Enables the shell to run as superuser.
+            timeout: The timeout used for the SSH channel that is dedicated to this interactive
+                shell. This timeout is for collecting output, so if reading from the buffer
+                and no output is gathered within the timeout, an exception is thrown.
+            app_params: The command line parameters to be passed to the application on startup.
+            name: Name for the interactive shell to use for logging. This name will be appended to
+                the name of the underlying node which it is running on.
+            **kwargs: Any additional keyword arguments.
+
+        Raises:
+            InternalError: If :attr:`_single_active_per_node` is :data:`True` and another shell of
+                the same class is already running.
+        """
+        if self._single_active_per_node and node.find_active_shell(type(self)) is not None:
+            raise InternalError(
+                "Attempted to run a single-active shell while another one was already open."
+            )
+
+        node.register_shell(self)
+        self._node = node
+
+        if name is None:
+            name = type(self).__name__
+        self._logger = get_dts_logger(f"{node.name}.{name}")
+        self._app_params = app_params
+        self._privileged = privileged
+        self._timeout = timeout
+        # Ensure path is properly formatted for the host
+        self._update_real_path(self.path)
+        super().__init__(**kwargs)
+
+        self.start_application()
+
+    def _setup_ssh_channel(self) -> None:
+        self._ssh_channel = self._node.main_session.interactive_session.session.invoke_shell()
+        self._stdin = self._ssh_channel.makefile_stdin("w")
+        self._stdout = self._ssh_channel.makefile("r")
+        self._ssh_channel.settimeout(self._timeout)
+        self._ssh_channel.set_combine_stderr(True)  # combines stdout and stderr streams
+
+    def _make_start_command(self) -> str:
+        """Makes the command that starts the interactive shell."""
+        start_command = f"{self._real_path} {self._app_params or ''}"
+        if self._privileged:
+            start_command = self._node.main_session._get_privileged_command(start_command)
+        return start_command
 
     def start_application(self) -> None:
-        """Start the application.
+        """Starts a new interactive application based on the path to the app.
+
+        This method is often overridden by subclasses as their process for starting may look
+        different. Initialization of the shell on the host can be retried up to
+        `self._init_attempts` - 1 times. This is done because some DPDK applications need slightly
+        more time after exiting their script to clean up EAL before others can start.
+
+        Raises:
+            InteractiveCommandExecutionError: If the application fails to start within the allotted
+                number of retries.
+        """
+        self._setup_ssh_channel()
+        self._ssh_channel.settimeout(5)
+        start_command = self._make_start_command()
+        self.is_alive = True
+        for attempt in range(self._init_attempts):
+            try:
+                self.send_command(start_command)
+                break
+            except InteractiveSSHTimeoutError:
+                self._logger.info(
+                    f"Interactive shell failed to start (attempt {attempt+1} out of "
+                    f"{self._init_attempts})"
+                )
+        else:
+            self._ssh_channel.settimeout(self._timeout)
+            self.is_alive = False  # update state on failure to start
+            raise InteractiveCommandExecutionError("Failed to start application.")
+        self._ssh_channel.settimeout(self._timeout)
+
+    def send_command(
+        self, command: str, prompt: str | None = None, skip_first_line: bool = False
+    ) -> str:
+        """Send `command` and get all output before the expected ending string.
+
+        Lines that expect input are not included in the stdout buffer, so they cannot
+        be used for expect.
+
+        Example:
+            If you were prompted to log into something with a username and password,
+            you cannot expect ``username:`` because it won't yet be in the stdout buffer.
+            A workaround for this could be consuming an extra newline character to force
+            the current `prompt` into the stdout buffer.
 
-        After the application has started, use :class:`weakref.finalize` to manage cleanup.
+        Args:
+            command: The command to send.
+            prompt: After sending the command, `send_command` will be expecting this string.
+                If :data:`None`, will use the class's default prompt.
+            skip_first_line: Skip the first line when capturing the output.
+
+        Returns:
+            All output in the buffer before expected string.
+
+        Raises:
+            InteractiveCommandExecutionError: If attempting to send a command to a shell that is
+                not currently running.
+            InteractiveSSHSessionDeadError: The session died while executing the command.
+            InteractiveSSHTimeoutError: If command was sent but prompt could not be found in
+                the output before the timeout.
         """
-        self._start_application()
-        self._finalizer = weakref.finalize(self, self._close)
+        if not self.is_alive:
+            raise InteractiveCommandExecutionError(
+                f"Cannot send command {command} to application because the shell is not running."
+            )
+        self._logger.info(f"Sending: '{command}'")
+        if prompt is None:
+            prompt = self._default_prompt
+        out: str = ""
+        try:
+            self._stdin.write(f"{command}{self._command_extra_chars}\n")
+            self._stdin.flush()
+            for line in self._stdout:
+                if skip_first_line:
+                    skip_first_line = False
+                    continue
+                if line.rstrip().endswith(prompt):
+                    break
+                out += line
+        except TimeoutError as e:
+            self._logger.exception(e)
+            self._logger.debug(
+                f"Prompt ({prompt}) was not found in output from command before timeout."
+            )
+            raise InteractiveSSHTimeoutError(command) from e
+        except OSError as e:
+            self._logger.exception(e)
+            raise InteractiveSSHSessionDeadError(
+                self._node.main_session.interactive_session.hostname
+            ) from e
+        finally:
+            self._logger.debug(f"Got output: {out}")
+        return out
 
     def close(self) -> None:
-        """Free all resources using :class:`weakref.finalize`."""
-        self._finalizer()
+        """Close the shell's stdin and SSH channel and mark it as no longer alive."""
+        if not self._stdin.closed:
+            self._stdin.close()
+        if not self._ssh_channel.closed:
+            self._ssh_channel.close()
+        self.is_alive = False
+
+    def _update_real_path(self, path: PurePath) -> None:
+        """Updates the interactive shell's real path used at command line."""
+        self._real_path = self._node.main_session.join_remote_path(path)
diff --git a/dts/framework/remote_session/single_active_interactive_shell.py b/dts/framework/remote_session/single_active_interactive_shell.py
deleted file mode 100644
index 910af8f655..0000000000
--- a/dts/framework/remote_session/single_active_interactive_shell.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# SPDX-License-Identifier: BSD-3-Clause
-# Copyright(c) 2024 University of New Hampshire
-
-"""Common functionality for interactive shell handling.
-
-The base class, :class:`SingleActiveInteractiveShell`, is meant to be extended by subclasses that
-contain functionality specific to that shell type. These subclasses will often modify things like
-the prompt to expect or the arguments to pass into the application, but still utilize
-the same method for sending a command and collecting output. How this output is handled however
-is often application specific. If an application needs elevated privileges to start it is expected
-that the method for gaining those privileges is provided when initializing the class.
-
-This class is designed for applications like primary applications in DPDK where only one instance
-of the application can be running at a given time and, for this reason, is managed using a context
-manager. This context manager starts the application when you enter the context and cleans up the
-application when you exit. Using a context manager for this is useful since it allows us to ensure
-the application is cleaned up as soon as you leave the block regardless of the reason.
-
-The :option:`--timeout` command line argument and the :envvar:`DTS_TIMEOUT`
-environment variable configure the timeout of getting the output from command execution.
-"""
-
-from abc import ABC
-from pathlib import PurePath
-from typing import ClassVar
-
-from paramiko import Channel, channel
-from typing_extensions import Self
-
-from framework.exception import (
-    InteractiveCommandExecutionError,
-    InteractiveSSHSessionDeadError,
-    InteractiveSSHTimeoutError,
-)
-from framework.logger import DTSLogger, get_dts_logger
-from framework.params import Params
-from framework.settings import SETTINGS
-from framework.testbed_model.node import Node
-from framework.utils import MultiInheritanceBaseClass
-
-
-class SingleActiveInteractiveShell(MultiInheritanceBaseClass, ABC):
-    """The base class for managing interactive shells.
-
-    This class shouldn't be instantiated directly, but instead be extended. It contains
-    methods for starting interactive shells as well as sending commands to these shells
-    and collecting input until reaching a certain prompt. All interactive applications
-    will use the same SSH connection, but each will create their own channel on that
-    session.
-
-    Interactive shells are started and stopped using a context manager. This allows for the start
-    and cleanup of the application to happen at predictable times regardless of exceptions or
-    interrupts.
-
-    Attributes:
-        is_alive: :data:`True` if the application has started successfully, :data:`False`
-            otherwise.
-    """
-
-    _node: Node
-    _stdin: channel.ChannelStdinFile
-    _stdout: channel.ChannelFile
-    _ssh_channel: Channel
-    _logger: DTSLogger
-    _timeout: float
-    _app_params: Params
-    _privileged: bool
-    _real_path: PurePath
-
-    #: The number of times to try starting the application before considering it a failure.
-    _init_attempts: ClassVar[int] = 5
-
-    #: Prompt to expect at the end of output when sending a command.
-    #: This is often overridden by subclasses.
-    _default_prompt: ClassVar[str] = ""
-
-    #: Extra characters to add to the end of every command
-    #: before sending them. This is often overridden by subclasses and is
-    #: most commonly an additional newline character. This additional newline
-    #: character is used to force the line that is currently awaiting input
-    #: into the stdout buffer so that it can be consumed and checked against
-    #: the expected prompt.
-    _command_extra_chars: ClassVar[str] = ""
-
-    #: Path to the executable to start the interactive application.
-    path: ClassVar[PurePath]
-
-    is_alive: bool = False
-
-    def __init__(
-        self,
-        node: Node,
-        privileged: bool = False,
-        timeout: float = SETTINGS.timeout,
-        app_params: Params = Params(),
-        name: str | None = None,
-        **kwargs,
-    ) -> None:
-        """Create an SSH channel during initialization.
-
-        Additional keyword arguments can be passed through `kwargs` if needed for fulfilling other
-        constructors in the case of multiple inheritance.
-
-        Args:
-            node: The node on which to run start the interactive shell.
-            privileged: Enables the shell to run as superuser.
-            timeout: The timeout used for the SSH channel that is dedicated to this interactive
-                shell. This timeout is for collecting output, so if reading from the buffer
-                and no output is gathered within the timeout, an exception is thrown.
-            app_params: The command line parameters to be passed to the application on startup.
-            name: Name for the interactive shell to use for logging. This name will be appended to
-                the name of the underlying node which it is running on.
-            **kwargs: Any additional arguments if any.
-        """
-        node.register_shell(self)
-        self._node = node
-
-        if name is None:
-            name = type(self).__name__
-        self._logger = get_dts_logger(f"{node.name}.{name}")
-        self._app_params = app_params
-        self._privileged = privileged
-        self._timeout = timeout
-        # Ensure path is properly formatted for the host
-        self._update_real_path(self.path)
-        super().__init__()
-
-    def _setup_ssh_channel(self):
-        self._ssh_channel = self._node.main_session.interactive_session.session.invoke_shell()
-        self._stdin = self._ssh_channel.makefile_stdin("w")
-        self._stdout = self._ssh_channel.makefile("r")
-        self._ssh_channel.settimeout(self._timeout)
-        self._ssh_channel.set_combine_stderr(True)  # combines stdout and stderr streams
-
-    def _make_start_command(self) -> str:
-        """Makes the command that starts the interactive shell."""
-        start_command = f"{self._real_path} {self._app_params or ''}"
-        if self._privileged:
-            start_command = self._node.main_session._get_privileged_command(start_command)
-        return start_command
-
-    def _start_application(self) -> None:
-        """Starts a new interactive application based on the path to the app.
-
-        This method is often overridden by subclasses as their process for starting may look
-        different. Initialization of the shell on the host can be retried up to
-        `self._init_attempts` - 1 times. This is done because some DPDK applications need slightly
-        more time after exiting their script to clean up EAL before others can start.
-
-        Raises:
-            InteractiveCommandExecutionError: If the application fails to start within the allotted
-                number of retries.
-        """
-        self._setup_ssh_channel()
-        self._ssh_channel.settimeout(5)
-        start_command = self._make_start_command()
-        self.is_alive = True
-        for attempt in range(self._init_attempts):
-            try:
-                self.send_command(start_command)
-                break
-            except InteractiveSSHTimeoutError:
-                self._logger.info(
-                    f"Interactive shell failed to start (attempt {attempt+1} out of "
-                    f"{self._init_attempts})"
-                )
-        else:
-            self._ssh_channel.settimeout(self._timeout)
-            self.is_alive = False  # update state on failure to start
-            raise InteractiveCommandExecutionError("Failed to start application.")
-        self._ssh_channel.settimeout(self._timeout)
-
-    def send_command(
-        self, command: str, prompt: str | None = None, skip_first_line: bool = False
-    ) -> str:
-        """Send `command` and get all output before the expected ending string.
-
-        Lines that expect input are not included in the stdout buffer, so they cannot
-        be used for expect.
-
-        Example:
-            If you were prompted to log into something with a username and password,
-            you cannot expect ``username:`` because it won't yet be in the stdout buffer.
-            A workaround for this could be consuming an extra newline character to force
-            the current `prompt` into the stdout buffer.
-
-        Args:
-            command: The command to send.
-            prompt: After sending the command, `send_command` will be expecting this string.
-                If :data:`None`, will use the class's default prompt.
-            skip_first_line: Skip the first line when capturing the output.
-
-        Returns:
-            All output in the buffer before expected string.
-
-        Raises:
-            InteractiveCommandExecutionError: If attempting to send a command to a shell that is
-                not currently running.
-            InteractiveSSHSessionDeadError: The session died while executing the command.
-            InteractiveSSHTimeoutError: If command was sent but prompt could not be found in
-                the output before the timeout.
-        """
-        if not self.is_alive:
-            raise InteractiveCommandExecutionError(
-                f"Cannot send command {command} to application because the shell is not running."
-            )
-        self._logger.info(f"Sending: '{command}'")
-        if prompt is None:
-            prompt = self._default_prompt
-        out: str = ""
-        try:
-            self._stdin.write(f"{command}{self._command_extra_chars}\n")
-            self._stdin.flush()
-            for line in self._stdout:
-                if skip_first_line:
-                    skip_first_line = False
-                    continue
-                if line.rstrip().endswith(prompt):
-                    break
-                out += line
-        except TimeoutError as e:
-            self._logger.exception(e)
-            self._logger.debug(
-                f"Prompt ({prompt}) was not found in output from command before timeout."
-            )
-            raise InteractiveSSHTimeoutError(command) from e
-        except OSError as e:
-            self._logger.exception(e)
-            raise InteractiveSSHSessionDeadError(
-                self._node.main_session.interactive_session.hostname
-            ) from e
-        finally:
-            self._logger.debug(f"Got output: {out}")
-        return out
-
-    def _close(self) -> None:
-        self._stdin.close()
-        self._ssh_channel.close()
-
-    def _update_real_path(self, path: PurePath) -> None:
-        """Updates the interactive shell's real path used at command line."""
-        self._real_path = self._node.main_session.join_remote_path(path)
-
-    def __enter__(self) -> Self:
-        """Enter the context block.
-
-        Upon entering a context block with this class, the desired behavior is to create the
-        channel for the application to use, and then start the application.
-
-        Returns:
-            Reference to the object for the application after it has been started.
-        """
-        self._start_application()
-        return self
-
-    def __exit__(self, *_) -> None:
-        """Exit the context block.
-
-        Upon exiting a context block with this class, we want to ensure that the instance of the
-        application is explicitly closed and properly cleaned up using its close method. Note that
-        because this method returns :data:`None` if an exception was raised within the block, it is
-        not handled and will be re-raised after the application is closed.
-
-        The desired behavior is to close the application regardless of the reason for exiting the
-        context and then recreate that reason afterwards. All method arguments are ignored for
-        this reason.
-        """
-        self._close()
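For readers less familiar with the pattern being removed here: the deleted docstring says the context manager starts the application on entry and cleans it up on exit, whatever the reason for leaving the block. The following is a minimal, hypothetical sketch of that lifecycle. `SketchShell` and its `events` list are invented for illustration and are not part of DTS. The SSH work is replaced by simple state changes.

```python
from types import TracebackType
from typing import Optional, Type


class SketchShell:
    """Hypothetical shell: starts on __enter__, always closes on __exit__."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.is_alive = False
        self.events: list[str] = []

    def _start_application(self) -> None:
        # Stand-in for spawning the remote application over SSH.
        self.is_alive = True
        self.events.append("start")

    def close(self) -> None:
        # Stand-in for closing stdin and the SSH channel.
        self.is_alive = False
        self.events.append("close")

    def __enter__(self) -> "SketchShell":
        self._start_application()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        # Returning None (falsy) means any exception raised inside the block
        # propagates to the caller, but only after the shell has been closed.
        self.close()


shell = SketchShell("testpmd")
try:
    with shell:
        raise RuntimeError("test case failed")
except RuntimeError:
    pass

print(shell.events)  # ['start', 'close']
print(shell.is_alive)  # False
```

Because `__exit__` returns None, the RuntimeError still reaches the caller after `close()` has run. With this RFC, the shell registers itself with its Node in `__init__`, and the Node's scope cleanup takes over the job that the `with` block used to do.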
diff --git a/dts/framework/remote_session/testpmd_shell.py b/dts/framework/remote_session/testpmd_shell.py
index d187eaea94..cb8cd6f0ca 100644
--- a/dts/framework/remote_session/testpmd_shell.py
+++ b/dts/framework/remote_session/testpmd_shell.py
@@ -2077,11 +2077,11 @@ def set_verbose(self, level: int, verify: bool = True) -> None:
                     f"Testpmd failed to set verbose level to {level}."
                 )
 
-    def _close(self) -> None:
+    def close(self) -> None:
         """Overrides :meth:`~.interactive_shell.close`."""
         self.stop()
         self.send_command("quit", "Bye...")
-        return super()._close()
+        return super().close()
 
     """
     ====== Capability retrieval methods ======
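After `_close` is renamed to the public `close`, subclasses override it and chain to the base class, as the testpmd hunk does. The sketch below illustrates that ordering under several assumptions. `BaseShell` and `PmdLikeShell` are hypothetical. Commands are recorded in a list instead of being written to an SSH channel. The `"stop"` string sent by `stop()` is also an assumption, since the diff does not show what testpmd's `stop()` actually sends.

```python
from typing import Optional


class BaseShell:
    """Hypothetical base shell; close() releases the (simulated) channel."""

    def __init__(self) -> None:
        self.sent: list[str] = []
        self.channel_open = True

    def send_command(self, command: str, prompt: Optional[str] = None) -> str:
        # Stand-in: record the command instead of writing it to an SSH channel.
        self.sent.append(command)
        return ""

    def close(self) -> None:
        # Stand-in for closing stdin and the SSH channel.
        self.channel_open = False


class PmdLikeShell(BaseShell):
    """Hypothetical subclass following the same close() override pattern."""

    def stop(self) -> None:
        # Assumed command string; the real stop() body is not in this diff.
        self.send_command("stop")

    def close(self) -> None:
        # Shut the application down gracefully while the channel is still
        # open, then delegate channel teardown to the base class.
        self.stop()
        self.send_command("quit", "Bye...")
        super().close()


shell = PmdLikeShell()
shell.close()
print(shell.sent)  # ['stop', 'quit']
print(shell.channel_open)  # False
```

The ordering matters because the application-level `quit` has to be sent while the channel is still open. Calling `super().close()` first would leave no channel to send it on. Making the method public also lets callers such as `capability.py` and `Node.clean_up_shells` invoke it without reaching into a private member.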
diff --git a/dts/framework/testbed_model/capability.py b/dts/framework/testbed_model/capability.py
index e883f59d11..ad497e8ad9 100644
--- a/dts/framework/testbed_model/capability.py
+++ b/dts/framework/testbed_model/capability.py
@@ -237,7 +237,7 @@ def get_supported_capabilities(
                 for capability in capabilities:
                     if capability.nic_capability in supported_capabilities:
                         supported_conditional_capabilities.add(capability)
-            testpmd_shell._close()
+            testpmd_shell.close()
 
         logger.debug(f"Found supported capabilities {supported_conditional_capabilities}.")
         return supported_conditional_capabilities
diff --git a/dts/framework/testbed_model/node.py b/dts/framework/testbed_model/node.py
index 4f06968adc..62d0c09e78 100644
--- a/dts/framework/testbed_model/node.py
+++ b/dts/framework/testbed_model/node.py
@@ -37,13 +37,13 @@
 from .port import Port
 
 if TYPE_CHECKING:
-    from framework.remote_session.single_active_interactive_shell import (
-        SingleActiveInteractiveShell,
+    from framework.remote_session.interactive_shell import (
+        InteractiveShell,
     )
 
 T = TypeVar("T")
 Scope = Literal["unknown", "suite", "case"]
-ScopedShell = tuple[Scope, SingleActiveInteractiveShell]
+ScopedShell = tuple[Scope, InteractiveShell]
 
 
 class Node(ABC):
@@ -164,7 +164,7 @@ def exit_scope(self) -> Scope:
             self.clean_up_shells(current_scope)
             return current_scope
 
-    def register_shell(self, shell: SingleActiveInteractiveShell) -> None:
+    def register_shell(self, shell: InteractiveShell) -> None:
         """Register a new shell to the pool of active shells."""
         self._active_shells.append((self.current_scope, shell))
 
@@ -179,7 +179,7 @@ def clean_up_shells(self, scope: Scope) -> None:
         ]
 
         for i in reversed(zombie_shells_indices):
-            self._active_shells[i][1]._close()
+            self._active_shells[i][1].close()
             del self._active_shells[i]
 
     def create_session(self, name: str) -> OSSession:
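The node.py hunks show shells registered as `(scope, shell)` tuples and closed in `clean_up_shells` by walking indices in reverse. The body of the `zombie_shells_indices` comprehension and the scope bookkeeping fall outside the context shown here, so the following is only a hypothetical sketch. `ScopedNode`, `RecordingShell`, `enter_scope`, and the scope stack are invented. The sketch also assumes that a shell is cleaned up when it was registered under the scope being exited.

```python
from typing import Literal

# Mirrors the Scope alias shown in the node.py hunk.
Scope = Literal["unknown", "suite", "case"]


class RecordingShell:
    """Hypothetical shell that records when it is closed."""

    def __init__(self, name: str, close_log: list[str]) -> None:
        self.name = name
        self.closed = False
        self._close_log = close_log

    def close(self) -> None:
        self.closed = True
        self._close_log.append(self.name)


class ScopedNode:
    """Hypothetical, trimmed-down stand-in for the scoping logic on Node."""

    def __init__(self) -> None:
        # Assumption: scopes are tracked as a stack with "unknown" at the bottom.
        self._scope_stack: list[Scope] = ["unknown"]
        self._active_shells: list[tuple[Scope, RecordingShell]] = []

    @property
    def current_scope(self) -> Scope:
        return self._scope_stack[-1]

    def enter_scope(self, scope: Scope) -> None:
        self._scope_stack.append(scope)

    def exit_scope(self) -> Scope:
        current_scope = self._scope_stack.pop()
        self.clean_up_shells(current_scope)
        return current_scope

    def register_shell(self, shell: RecordingShell) -> None:
        self._active_shells.append((self.current_scope, shell))

    def clean_up_shells(self, scope: Scope) -> None:
        # Assumption: a shell is stale if it was registered under `scope`.
        zombie_shells_indices = [
            i for i, (shell_scope, _) in enumerate(self._active_shells) if shell_scope == scope
        ]
        # Walk the indices backwards: deleting from the end keeps the remaining
        # (smaller) indices valid, and closes shells in LIFO order.
        for i in reversed(zombie_shells_indices):
            self._active_shells[i][1].close()
            del self._active_shells[i]


close_log: list[str] = []
node = ScopedNode()
node.enter_scope("suite")
suite_shell = RecordingShell("suite_shell", close_log)
node.register_shell(suite_shell)

node.enter_scope("case")
node.register_shell(RecordingShell("first_case_shell", close_log))
node.register_shell(RecordingShell("second_case_shell", close_log))
node.exit_scope()

print(close_log)  # ['second_case_shell', 'first_case_shell']
print(suite_shell.closed)  # False
```

Iterating the indices in reverse has two effects. Each `del` leaves the not-yet-visited (smaller) indices valid. Shells are also closed in the reverse of their registration order, much like nested `with` blocks unwinding, while suite-level shells survive until the suite scope itself is exited.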
-- 
2.43.0


