diff --git a/modules/commands/multitest_command.py b/modules/commands/multitest_command.py index 09731bd..dffa312 100644 --- a/modules/commands/multitest_command.py +++ b/modules/commands/multitest_command.py @@ -6,12 +6,25 @@ Listens for a period of time and collects all unique paths from incoming message import asyncio import time -from typing import Set, Optional +from typing import Set, Optional, Dict +from dataclasses import dataclass from .base_command import BaseCommand from ..models import MeshMessage from ..utils import calculate_packet_hash +@dataclass +class MultitestSession: + """Represents an active multitest listening session""" + user_id: str + target_packet_hash: str + triggering_timestamp: float + listening_start_time: float + listening_duration: float + collected_paths: Set[str] + initial_path: Optional[str] = None + + class MultitestCommand(BaseCommand): """Handles the multitest command - listens for multiple path variations""" @@ -29,14 +42,19 @@ class MultitestCommand(BaseCommand): def __init__(self, bot): super().__init__(bot) self.multitest_enabled = self.get_config_value('Multitest_Command', 'enabled', fallback=True, value_type='bool') - self.listening = False - self.collected_paths: Set[str] = set() - self.listening_start_time = 0 - self.listening_duration = 6.0 # 6 seconds listening window - self.target_packet_hash: Optional[str] = None # Hash of the message we're tracking - self.triggering_timestamp: float = 0.0 # Timestamp of the triggering message + # Track active sessions per user to prevent race conditions + # Key: user_id, Value: MultitestSession + self._active_sessions: Dict[str, MultitestSession] = {} + # Lock to prevent concurrent execution from interfering (lazily initialized) + self._execution_lock: Optional[asyncio.Lock] = None self._load_config() + def _get_execution_lock(self) -> asyncio.Lock: + """Get or create the execution lock (lazy initialization)""" + if self._execution_lock is None: + self._execution_lock = asyncio.Lock() + return self._execution_lock + def can_execute(self, message: MeshMessage) -> bool: """Check if this command can be executed with the given message. @@ -211,13 +229,11 @@ class MultitestCommand(BaseCommand): return None def on_message_received(self, message: MeshMessage): - """Callback method called by message handler when a message is received during listening""" - if not self.listening or not self.target_packet_hash: - return + """Callback method called by message handler when a message is received during listening. - # Check if we're still in the listening window - elapsed = time.time() - self.listening_start_time - if elapsed >= self.listening_duration: + Checks all active sessions to see if this message matches any of them. + """ + if not self._active_sessions: return # Get RF data for this message (contains pre-calculated packet hash) @@ -247,34 +263,46 @@ class MultitestCommand(BaseCommand): self.logger.debug(f"Skipping message - could not determine packet hash (sender: {message.sender_id})") return - # CRITICAL: Only collect paths if this message has the same hash as the target - # This ensures we only track variations of the same original message - if message_hash == self.target_packet_hash: - # Try to extract path from RF data first (more reliable) - path = self.extract_path_from_rf_data(rf_data) + # Check all active sessions to see if this message matches any of them + current_time = time.time() + for user_id, session in list(self._active_sessions.items()): + # Check if we're still in the listening window for this session + elapsed = current_time - session.listening_start_time + if elapsed >= session.listening_duration: + continue # Session expired, skip it - # Fallback to message path if RF data extraction failed - if not path: - path = self.extract_path_from_message(message) - - if path: - self.collected_paths.add(path) - self.logger.info(f"✓ Collected path: {path} (hash: {message_hash[:8]}...)") - else: - # Log when we have a matching hash but can't extract path - routing_info = rf_data.get('routing_info', {}) - path_length = routing_info.get('path_length', 0) - if path_length == 0: - self.logger.debug(f"Matched hash {message_hash[:8]}... but path is direct (0 hops)") + # CRITICAL: Only collect paths if this message has the same hash as the target + # This ensures we only track variations of the same original message + if message_hash == session.target_packet_hash: + # Try to extract path from RF data first (more reliable) + path = self.extract_path_from_rf_data(rf_data) + + # Fallback to message path if RF data extraction failed + if not path: + path = self.extract_path_from_message(message) + + if path: + session.collected_paths.add(path) + self.logger.info(f"✓ Collected path for user {user_id}: {path} (hash: {message_hash[:8]}...)") else: - self.logger.debug(f"Matched hash {message_hash[:8]}... but couldn't extract path from routing_info: {routing_info}") - else: - # Log hash mismatches for debugging (but limit to avoid spam) - self.logger.debug(f"✗ Hash mismatch - target: {self.target_packet_hash[:8]}..., received: {message_hash[:8]}... (sender: {message.sender_id})") + # Log when we have a matching hash but can't extract path + routing_info = rf_data.get('routing_info', {}) + path_length = routing_info.get('path_length', 0) + if path_length == 0: + self.logger.debug(f"Matched hash {message_hash[:8]}... but path is direct (0 hops) for user {user_id}") + else: + self.logger.debug(f"Matched hash {message_hash[:8]}... but couldn't extract path from routing_info: {routing_info} for user {user_id}") + else: + # Log hash mismatches for debugging (but limit to avoid spam) + self.logger.debug(f"✗ Hash mismatch for user {user_id} - target: {session.target_packet_hash[:8]}..., received: {message_hash[:8]}... (sender: {message.sender_id})") - def _scan_recent_rf_data(self): - """Scan recent RF data for packets with matching hash (for messages that haven't been processed yet)""" - if not self.target_packet_hash: + def _scan_recent_rf_data(self, session: MultitestSession): + """Scan recent RF data for packets with matching hash (for messages that haven't been processed yet) + + Args: + session: The multitest session to scan for + """ + if not session.target_packet_hash: return try: @@ -290,135 +318,165 @@ class MultitestCommand(BaseCommand): # Only include RF data from the triggering message timestamp onwards # This prevents collecting packets from earlier messages that happen to have the same hash - rf_timestamp = rf_data.get('timestamp', 0) - if rf_timestamp >= self.triggering_timestamp and time_diff <= self.listening_duration: + if rf_timestamp >= session.triggering_timestamp and time_diff <= session.listening_duration: packet_hash = rf_data.get('packet_hash') # CRITICAL: Only process if hash matches exactly and is not None/empty - if packet_hash and packet_hash == self.target_packet_hash: + if packet_hash and packet_hash == session.target_packet_hash: matching_count += 1 # Extract path from this RF data path = self.extract_path_from_rf_data(rf_data) if path: - self.collected_paths.add(path) - self.logger.info(f"✓ Collected path from RF scan: {path} (hash: {packet_hash[:8]}..., time: {time_diff:.2f}s)") + session.collected_paths.add(path) + self.logger.info(f"✓ Collected path from RF scan for user {session.user_id}: {path} (hash: {packet_hash[:8]}..., time: {time_diff:.2f}s)") else: - self.logger.debug(f"Matched hash {packet_hash[:8]}... in RF scan but couldn't extract path") + self.logger.debug(f"Matched hash {packet_hash[:8]}... in RF scan but couldn't extract path for user {session.user_id}") elif packet_hash: mismatching_count += 1 # Only log first few mismatches to avoid spam if mismatching_count <= 3: - self.logger.debug(f"✗ RF scan hash mismatch - target: {self.target_packet_hash[:8]}..., found: {packet_hash[:8]}... (time: {time_diff:.2f}s)") + self.logger.debug(f"✗ RF scan hash mismatch for user {session.user_id} - target: {session.target_packet_hash[:8]}..., found: {packet_hash[:8]}... (time: {time_diff:.2f}s)") if matching_count > 0 or mismatching_count > 0: - self.logger.debug(f"RF scan complete: {matching_count} matching, {mismatching_count} mismatching packets") + self.logger.debug(f"RF scan complete for user {session.user_id}: {matching_count} matching, {mismatching_count} mismatching packets") except Exception as e: - self.logger.debug(f"Error scanning recent RF data: {e}") + self.logger.debug(f"Error scanning recent RF data for user {session.user_id}: {e}") async def execute(self, message: MeshMessage) -> bool: """Execute the multitest command""" - # Determine listening duration based on command variant - content = message.content.strip() - if content.startswith('!'): - content = content[1:].strip() + user_id = message.sender_id or "unknown" - content_lower = content.lower() - # Check for variants: "mt long", "mt xlong", "multitest long", "multitest xlong" - if content_lower.startswith('mt ') or content_lower.startswith('multitest '): - parts = content_lower.split() - if len(parts) >= 2 and parts[0] in ['mt', 'multitest']: - variant = parts[1] - if variant == 'long': - self.listening_duration = 10.0 - self.logger.info("Multitest command (long) executed - starting 10 second listening window") - elif variant == 'xlong': - self.listening_duration = 14.0 - self.logger.info("Multitest command (xlong) executed - starting 14 second listening window") + # Use lock to prevent concurrent execution from interfering + async with self._get_execution_lock(): + # Check if user already has an active session + if user_id in self._active_sessions: + existing_session = self._active_sessions[user_id] + elapsed = time.time() - existing_session.listening_start_time + if elapsed < existing_session.listening_duration: + # User already has an active session, reject this one + remaining = existing_session.listening_duration - elapsed + response = f"Multitest already in progress. Please wait {remaining:.1f} seconds." + await self.send_response(message, response) + return True + + # Record execution time BEFORE starting async work to prevent race conditions + self.record_execution(user_id) + + # Determine listening duration based on command variant + content = message.content.strip() + if content.startswith('!'): + content = content[1:].strip() + + content_lower = content.lower() + listening_duration = 6.0 # Default + # Check for variants: "mt long", "mt xlong", "multitest long", "multitest xlong" + if content_lower.startswith('mt ') or content_lower.startswith('multitest '): + parts = content_lower.split() + if len(parts) >= 2 and parts[0] in ['mt', 'multitest']: + variant = parts[1] + if variant == 'long': + listening_duration = 10.0 + self.logger.info(f"Multitest command (long) executed by {user_id} - starting 10 second listening window") + elif variant == 'xlong': + listening_duration = 14.0 + self.logger.info(f"Multitest command (xlong) executed by {user_id} - starting 14 second listening window") + else: + self.logger.info(f"Multitest command executed by {user_id} - starting 6 second listening window") else: - self.listening_duration = 6.0 - self.logger.info("Multitest command executed - starting 6 second listening window") + self.logger.info(f"Multitest command executed by {user_id} - starting 6 second listening window") else: - self.listening_duration = 6.0 - self.logger.info("Multitest command executed - starting 6 second listening window") - else: - self.listening_duration = 6.0 - self.logger.info("Multitest command executed - starting 6 second listening window") - - # Get RF data for the triggering message (contains pre-calculated packet hash) - rf_data = self.get_rf_data_for_message(message) - if not rf_data: - response = "Error: Could not find packet data for this message. Please try again." - await self.send_response(message, response) - return True - - # Use pre-calculated packet hash if available, otherwise calculate it - packet_hash = rf_data.get('packet_hash') - if not packet_hash and rf_data.get('raw_hex'): - # Fallback: calculate hash if not stored (for older RF data) - # IMPORTANT: Must use same payload_type that was used during ingestion - payload_type = None - routing_info = rf_data.get('routing_info', {}) - if routing_info: - payload_type = routing_info.get('payload_type') - packet_hash = calculate_packet_hash(rf_data['raw_hex'], payload_type) - - if not packet_hash: - response = "Error: Could not calculate packet hash for this message. Please try again." - await self.send_response(message, response) - return True - - # Store the packet hash to track - self.target_packet_hash = packet_hash - - # Store the timestamp of the triggering message to avoid collecting older packets - triggering_rf_timestamp = rf_data.get('timestamp', time.time()) - self.triggering_timestamp = triggering_rf_timestamp - - self.logger.info(f"Tracking packet hash: {self.target_packet_hash[:16]}... (full: {self.target_packet_hash})") - self.logger.debug(f"Triggering message timestamp: {triggering_rf_timestamp}") - - # Also extract path from the triggering message itself - initial_path = self.extract_path_from_message(message) - # Also try to extract from RF data (more reliable) - if not initial_path and rf_data: - initial_path = self.extract_path_from_rf_data(rf_data) - - if initial_path: - self.logger.debug(f"Initial path from triggering message: {initial_path}") - - # Register this command instance as the active listener - # Store reference in message handler so it can call on_message_received - self.bot.message_handler.multitest_listener = self - - # Start listening - self.listening = True - self.collected_paths = set() - if initial_path: - self.collected_paths.add(initial_path) # Include the initial path - self.listening_start_time = time.time() + self.logger.info(f"Multitest command executed by {user_id} - starting 6 second listening window") + + # Get RF data for the triggering message (contains pre-calculated packet hash) + rf_data = self.get_rf_data_for_message(message) + if not rf_data: + response = "Error: Could not find packet data for this message. Please try again." + await self.send_response(message, response) + return True + + # Use pre-calculated packet hash if available, otherwise calculate it + packet_hash = rf_data.get('packet_hash') + if not packet_hash and rf_data.get('raw_hex'): + # Fallback: calculate hash if not stored (for older RF data) + # IMPORTANT: Must use same payload_type that was used during ingestion + payload_type = None + routing_info = rf_data.get('routing_info', {}) + if routing_info: + payload_type = routing_info.get('payload_type') + packet_hash = calculate_packet_hash(rf_data['raw_hex'], payload_type) + + if not packet_hash: + response = "Error: Could not calculate packet hash for this message. Please try again." + await self.send_response(message, response) + return True + + # Store the timestamp of the triggering message to avoid collecting older packets + triggering_rf_timestamp = rf_data.get('timestamp', time.time()) + + # Also extract path from the triggering message itself + initial_path = self.extract_path_from_message(message) + # Also try to extract from RF data (more reliable) + if not initial_path and rf_data: + initial_path = self.extract_path_from_rf_data(rf_data) + + if initial_path: + self.logger.debug(f"Initial path from triggering message for user {user_id}: {initial_path}") + + # Create a new session for this user + session = MultitestSession( + user_id=user_id, + target_packet_hash=packet_hash, + triggering_timestamp=triggering_rf_timestamp, + listening_start_time=time.time(), + listening_duration=listening_duration, + collected_paths=set(), + initial_path=initial_path + ) + + # Add initial path if available + if initial_path: + session.collected_paths.add(initial_path) + + # Register this session + self._active_sessions[user_id] = session + + # Register this command instance as the active listener (if not already registered) + # Store reference in message handler so it can call on_message_received + if self.bot.message_handler.multitest_listener is None: + self.bot.message_handler.multitest_listener = self + + self.logger.info(f"Tracking packet hash for user {user_id}: {packet_hash[:16]}... (full: {packet_hash})") + self.logger.debug(f"Triggering message timestamp for user {user_id}: {triggering_rf_timestamp}") + # Release lock before async sleep to allow other users to start their sessions # Also scan recent RF data for matching hashes (in case messages haven't been processed yet) # But only include packets that arrived at or after the triggering message - self._scan_recent_rf_data() + self._scan_recent_rf_data(session) try: # Wait for the listening duration - await asyncio.sleep(self.listening_duration) + await asyncio.sleep(session.listening_duration) finally: - # Stop listening and unregister (but keep target_packet_hash for error messages) - self.listening = False - self.bot.message_handler.multitest_listener = None + # Re-acquire lock to clean up session + async with self._get_execution_lock(): + # Remove this session + if user_id in self._active_sessions: + del self._active_sessions[user_id] + + # Unregister listener if no more active sessions + if not self._active_sessions and self.bot.message_handler.multitest_listener == self: + self.bot.message_handler.multitest_listener = None # Do a final scan of RF data in case any matching packets arrived - self._scan_recent_rf_data() + self._scan_recent_rf_data(session) # Store hash for error message before clearing it - tracking_hash = self.target_packet_hash + tracking_hash = session.target_packet_hash # Format the collected paths - if self.collected_paths: + if session.collected_paths: # Sort paths for consistent output - sorted_paths = sorted(self.collected_paths) + sorted_paths = sorted(session.collected_paths) paths_text = "\n".join(sorted_paths) path_count = len(sorted_paths) @@ -429,7 +487,7 @@ class MultitestCommand(BaseCommand): sender=message.sender_id or "Unknown", path_count=path_count, paths=paths_text, - listening_duration=int(self.listening_duration) + listening_duration=int(session.listening_duration) ) except (KeyError, ValueError) as e: # If formatting fails, fall back to default @@ -454,12 +512,9 @@ class MultitestCommand(BaseCommand): f"(hash: {tracking_hash}). " f"Packets may be direct (0 hops) or path extraction failed.") else: - response = (f"No matching packets found during {self.listening_duration}s window. " + response = (f"No matching packets found during {session.listening_duration}s window. " f"Tracking hash: {tracking_hash}. ") - # Clear the hash after we're done with it - self.target_packet_hash = None - # Wait for bot TX rate limiter cooldown to expire before sending # This ensures we respond even if another command put the bot on cooldown await self.bot.bot_tx_rate_limiter.wait_for_tx()