mirror of
https://github.com/agessaman/meshcore-bot.git
synced 2026-04-26 19:05:17 +00:00
feat: Refactor multitest command to support user sessions and improve path tracking
- Introduced a new MultitestSession dataclass to manage active multitest sessions per user, enhancing concurrency handling. - Replaced global variables with session-specific attributes to track listening duration, collected paths, and target packet hashes. - Updated message handling logic to check active sessions and prevent race conditions during multitest execution. - Enhanced path collection and RF data scanning to be user-specific, improving accuracy and logging for each session. - Implemented execution locks to ensure thread safety during multitest command execution.
This commit is contained in:
@@ -6,12 +6,25 @@ Listens for a period of time and collects all unique paths from incoming message
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Set, Optional
|
||||
from typing import Set, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
from .base_command import BaseCommand
|
||||
from ..models import MeshMessage
|
||||
from ..utils import calculate_packet_hash
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultitestSession:
|
||||
"""Represents an active multitest listening session"""
|
||||
user_id: str
|
||||
target_packet_hash: str
|
||||
triggering_timestamp: float
|
||||
listening_start_time: float
|
||||
listening_duration: float
|
||||
collected_paths: Set[str]
|
||||
initial_path: Optional[str] = None
|
||||
|
||||
|
||||
class MultitestCommand(BaseCommand):
|
||||
"""Handles the multitest command - listens for multiple path variations"""
|
||||
|
||||
@@ -29,14 +42,19 @@ class MultitestCommand(BaseCommand):
|
||||
def __init__(self, bot):
|
||||
super().__init__(bot)
|
||||
self.multitest_enabled = self.get_config_value('Multitest_Command', 'enabled', fallback=True, value_type='bool')
|
||||
self.listening = False
|
||||
self.collected_paths: Set[str] = set()
|
||||
self.listening_start_time = 0
|
||||
self.listening_duration = 6.0 # 6 seconds listening window
|
||||
self.target_packet_hash: Optional[str] = None # Hash of the message we're tracking
|
||||
self.triggering_timestamp: float = 0.0 # Timestamp of the triggering message
|
||||
# Track active sessions per user to prevent race conditions
|
||||
# Key: user_id, Value: MultitestSession
|
||||
self._active_sessions: Dict[str, MultitestSession] = {}
|
||||
# Lock to prevent concurrent execution from interfering (lazily initialized)
|
||||
self._execution_lock: Optional[asyncio.Lock] = None
|
||||
self._load_config()
|
||||
|
||||
def _get_execution_lock(self) -> asyncio.Lock:
|
||||
"""Get or create the execution lock (lazy initialization)"""
|
||||
if self._execution_lock is None:
|
||||
self._execution_lock = asyncio.Lock()
|
||||
return self._execution_lock
|
||||
|
||||
def can_execute(self, message: MeshMessage) -> bool:
|
||||
"""Check if this command can be executed with the given message.
|
||||
|
||||
@@ -211,13 +229,11 @@ class MultitestCommand(BaseCommand):
|
||||
return None
|
||||
|
||||
def on_message_received(self, message: MeshMessage):
|
||||
"""Callback method called by message handler when a message is received during listening"""
|
||||
if not self.listening or not self.target_packet_hash:
|
||||
return
|
||||
"""Callback method called by message handler when a message is received during listening.
|
||||
|
||||
# Check if we're still in the listening window
|
||||
elapsed = time.time() - self.listening_start_time
|
||||
if elapsed >= self.listening_duration:
|
||||
Checks all active sessions to see if this message matches any of them.
|
||||
"""
|
||||
if not self._active_sessions:
|
||||
return
|
||||
|
||||
# Get RF data for this message (contains pre-calculated packet hash)
|
||||
@@ -247,34 +263,46 @@ class MultitestCommand(BaseCommand):
|
||||
self.logger.debug(f"Skipping message - could not determine packet hash (sender: {message.sender_id})")
|
||||
return
|
||||
|
||||
# CRITICAL: Only collect paths if this message has the same hash as the target
|
||||
# This ensures we only track variations of the same original message
|
||||
if message_hash == self.target_packet_hash:
|
||||
# Try to extract path from RF data first (more reliable)
|
||||
path = self.extract_path_from_rf_data(rf_data)
|
||||
# Check all active sessions to see if this message matches any of them
|
||||
current_time = time.time()
|
||||
for user_id, session in list(self._active_sessions.items()):
|
||||
# Check if we're still in the listening window for this session
|
||||
elapsed = current_time - session.listening_start_time
|
||||
if elapsed >= session.listening_duration:
|
||||
continue # Session expired, skip it
|
||||
|
||||
# Fallback to message path if RF data extraction failed
|
||||
if not path:
|
||||
path = self.extract_path_from_message(message)
|
||||
|
||||
if path:
|
||||
self.collected_paths.add(path)
|
||||
self.logger.info(f"✓ Collected path: {path} (hash: {message_hash[:8]}...)")
|
||||
else:
|
||||
# Log when we have a matching hash but can't extract path
|
||||
routing_info = rf_data.get('routing_info', {})
|
||||
path_length = routing_info.get('path_length', 0)
|
||||
if path_length == 0:
|
||||
self.logger.debug(f"Matched hash {message_hash[:8]}... but path is direct (0 hops)")
|
||||
# CRITICAL: Only collect paths if this message has the same hash as the target
|
||||
# This ensures we only track variations of the same original message
|
||||
if message_hash == session.target_packet_hash:
|
||||
# Try to extract path from RF data first (more reliable)
|
||||
path = self.extract_path_from_rf_data(rf_data)
|
||||
|
||||
# Fallback to message path if RF data extraction failed
|
||||
if not path:
|
||||
path = self.extract_path_from_message(message)
|
||||
|
||||
if path:
|
||||
session.collected_paths.add(path)
|
||||
self.logger.info(f"✓ Collected path for user {user_id}: {path} (hash: {message_hash[:8]}...)")
|
||||
else:
|
||||
self.logger.debug(f"Matched hash {message_hash[:8]}... but couldn't extract path from routing_info: {routing_info}")
|
||||
else:
|
||||
# Log hash mismatches for debugging (but limit to avoid spam)
|
||||
self.logger.debug(f"✗ Hash mismatch - target: {self.target_packet_hash[:8]}..., received: {message_hash[:8]}... (sender: {message.sender_id})")
|
||||
# Log when we have a matching hash but can't extract path
|
||||
routing_info = rf_data.get('routing_info', {})
|
||||
path_length = routing_info.get('path_length', 0)
|
||||
if path_length == 0:
|
||||
self.logger.debug(f"Matched hash {message_hash[:8]}... but path is direct (0 hops) for user {user_id}")
|
||||
else:
|
||||
self.logger.debug(f"Matched hash {message_hash[:8]}... but couldn't extract path from routing_info: {routing_info} for user {user_id}")
|
||||
else:
|
||||
# Log hash mismatches for debugging (but limit to avoid spam)
|
||||
self.logger.debug(f"✗ Hash mismatch for user {user_id} - target: {session.target_packet_hash[:8]}..., received: {message_hash[:8]}... (sender: {message.sender_id})")
|
||||
|
||||
def _scan_recent_rf_data(self):
|
||||
"""Scan recent RF data for packets with matching hash (for messages that haven't been processed yet)"""
|
||||
if not self.target_packet_hash:
|
||||
def _scan_recent_rf_data(self, session: MultitestSession):
|
||||
"""Scan recent RF data for packets with matching hash (for messages that haven't been processed yet)
|
||||
|
||||
Args:
|
||||
session: The multitest session to scan for
|
||||
"""
|
||||
if not session.target_packet_hash:
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -290,135 +318,165 @@ class MultitestCommand(BaseCommand):
|
||||
|
||||
# Only include RF data from the triggering message timestamp onwards
|
||||
# This prevents collecting packets from earlier messages that happen to have the same hash
|
||||
rf_timestamp = rf_data.get('timestamp', 0)
|
||||
if rf_timestamp >= self.triggering_timestamp and time_diff <= self.listening_duration:
|
||||
if rf_timestamp >= session.triggering_timestamp and time_diff <= session.listening_duration:
|
||||
packet_hash = rf_data.get('packet_hash')
|
||||
|
||||
# CRITICAL: Only process if hash matches exactly and is not None/empty
|
||||
if packet_hash and packet_hash == self.target_packet_hash:
|
||||
if packet_hash and packet_hash == session.target_packet_hash:
|
||||
matching_count += 1
|
||||
# Extract path from this RF data
|
||||
path = self.extract_path_from_rf_data(rf_data)
|
||||
if path:
|
||||
self.collected_paths.add(path)
|
||||
self.logger.info(f"✓ Collected path from RF scan: {path} (hash: {packet_hash[:8]}..., time: {time_diff:.2f}s)")
|
||||
session.collected_paths.add(path)
|
||||
self.logger.info(f"✓ Collected path from RF scan for user {session.user_id}: {path} (hash: {packet_hash[:8]}..., time: {time_diff:.2f}s)")
|
||||
else:
|
||||
self.logger.debug(f"Matched hash {packet_hash[:8]}... in RF scan but couldn't extract path")
|
||||
self.logger.debug(f"Matched hash {packet_hash[:8]}... in RF scan but couldn't extract path for user {session.user_id}")
|
||||
elif packet_hash:
|
||||
mismatching_count += 1
|
||||
# Only log first few mismatches to avoid spam
|
||||
if mismatching_count <= 3:
|
||||
self.logger.debug(f"✗ RF scan hash mismatch - target: {self.target_packet_hash[:8]}..., found: {packet_hash[:8]}... (time: {time_diff:.2f}s)")
|
||||
self.logger.debug(f"✗ RF scan hash mismatch for user {session.user_id} - target: {session.target_packet_hash[:8]}..., found: {packet_hash[:8]}... (time: {time_diff:.2f}s)")
|
||||
|
||||
if matching_count > 0 or mismatching_count > 0:
|
||||
self.logger.debug(f"RF scan complete: {matching_count} matching, {mismatching_count} mismatching packets")
|
||||
self.logger.debug(f"RF scan complete for user {session.user_id}: {matching_count} matching, {mismatching_count} mismatching packets")
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error scanning recent RF data: {e}")
|
||||
self.logger.debug(f"Error scanning recent RF data for user {session.user_id}: {e}")
|
||||
|
||||
async def execute(self, message: MeshMessage) -> bool:
|
||||
"""Execute the multitest command"""
|
||||
# Determine listening duration based on command variant
|
||||
content = message.content.strip()
|
||||
if content.startswith('!'):
|
||||
content = content[1:].strip()
|
||||
user_id = message.sender_id or "unknown"
|
||||
|
||||
content_lower = content.lower()
|
||||
# Check for variants: "mt long", "mt xlong", "multitest long", "multitest xlong"
|
||||
if content_lower.startswith('mt ') or content_lower.startswith('multitest '):
|
||||
parts = content_lower.split()
|
||||
if len(parts) >= 2 and parts[0] in ['mt', 'multitest']:
|
||||
variant = parts[1]
|
||||
if variant == 'long':
|
||||
self.listening_duration = 10.0
|
||||
self.logger.info("Multitest command (long) executed - starting 10 second listening window")
|
||||
elif variant == 'xlong':
|
||||
self.listening_duration = 14.0
|
||||
self.logger.info("Multitest command (xlong) executed - starting 14 second listening window")
|
||||
# Use lock to prevent concurrent execution from interfering
|
||||
async with self._get_execution_lock():
|
||||
# Check if user already has an active session
|
||||
if user_id in self._active_sessions:
|
||||
existing_session = self._active_sessions[user_id]
|
||||
elapsed = time.time() - existing_session.listening_start_time
|
||||
if elapsed < existing_session.listening_duration:
|
||||
# User already has an active session, reject this one
|
||||
remaining = existing_session.listening_duration - elapsed
|
||||
response = f"Multitest already in progress. Please wait {remaining:.1f} seconds."
|
||||
await self.send_response(message, response)
|
||||
return True
|
||||
|
||||
# Record execution time BEFORE starting async work to prevent race conditions
|
||||
self.record_execution(user_id)
|
||||
|
||||
# Determine listening duration based on command variant
|
||||
content = message.content.strip()
|
||||
if content.startswith('!'):
|
||||
content = content[1:].strip()
|
||||
|
||||
content_lower = content.lower()
|
||||
listening_duration = 6.0 # Default
|
||||
# Check for variants: "mt long", "mt xlong", "multitest long", "multitest xlong"
|
||||
if content_lower.startswith('mt ') or content_lower.startswith('multitest '):
|
||||
parts = content_lower.split()
|
||||
if len(parts) >= 2 and parts[0] in ['mt', 'multitest']:
|
||||
variant = parts[1]
|
||||
if variant == 'long':
|
||||
listening_duration = 10.0
|
||||
self.logger.info(f"Multitest command (long) executed by {user_id} - starting 10 second listening window")
|
||||
elif variant == 'xlong':
|
||||
listening_duration = 14.0
|
||||
self.logger.info(f"Multitest command (xlong) executed by {user_id} - starting 14 second listening window")
|
||||
else:
|
||||
self.logger.info(f"Multitest command executed by {user_id} - starting 6 second listening window")
|
||||
else:
|
||||
self.listening_duration = 6.0
|
||||
self.logger.info("Multitest command executed - starting 6 second listening window")
|
||||
self.logger.info(f"Multitest command executed by {user_id} - starting 6 second listening window")
|
||||
else:
|
||||
self.listening_duration = 6.0
|
||||
self.logger.info("Multitest command executed - starting 6 second listening window")
|
||||
else:
|
||||
self.listening_duration = 6.0
|
||||
self.logger.info("Multitest command executed - starting 6 second listening window")
|
||||
|
||||
# Get RF data for the triggering message (contains pre-calculated packet hash)
|
||||
rf_data = self.get_rf_data_for_message(message)
|
||||
if not rf_data:
|
||||
response = "Error: Could not find packet data for this message. Please try again."
|
||||
await self.send_response(message, response)
|
||||
return True
|
||||
|
||||
# Use pre-calculated packet hash if available, otherwise calculate it
|
||||
packet_hash = rf_data.get('packet_hash')
|
||||
if not packet_hash and rf_data.get('raw_hex'):
|
||||
# Fallback: calculate hash if not stored (for older RF data)
|
||||
# IMPORTANT: Must use same payload_type that was used during ingestion
|
||||
payload_type = None
|
||||
routing_info = rf_data.get('routing_info', {})
|
||||
if routing_info:
|
||||
payload_type = routing_info.get('payload_type')
|
||||
packet_hash = calculate_packet_hash(rf_data['raw_hex'], payload_type)
|
||||
|
||||
if not packet_hash:
|
||||
response = "Error: Could not calculate packet hash for this message. Please try again."
|
||||
await self.send_response(message, response)
|
||||
return True
|
||||
|
||||
# Store the packet hash to track
|
||||
self.target_packet_hash = packet_hash
|
||||
|
||||
# Store the timestamp of the triggering message to avoid collecting older packets
|
||||
triggering_rf_timestamp = rf_data.get('timestamp', time.time())
|
||||
self.triggering_timestamp = triggering_rf_timestamp
|
||||
|
||||
self.logger.info(f"Tracking packet hash: {self.target_packet_hash[:16]}... (full: {self.target_packet_hash})")
|
||||
self.logger.debug(f"Triggering message timestamp: {triggering_rf_timestamp}")
|
||||
|
||||
# Also extract path from the triggering message itself
|
||||
initial_path = self.extract_path_from_message(message)
|
||||
# Also try to extract from RF data (more reliable)
|
||||
if not initial_path and rf_data:
|
||||
initial_path = self.extract_path_from_rf_data(rf_data)
|
||||
|
||||
if initial_path:
|
||||
self.logger.debug(f"Initial path from triggering message: {initial_path}")
|
||||
|
||||
# Register this command instance as the active listener
|
||||
# Store reference in message handler so it can call on_message_received
|
||||
self.bot.message_handler.multitest_listener = self
|
||||
|
||||
# Start listening
|
||||
self.listening = True
|
||||
self.collected_paths = set()
|
||||
if initial_path:
|
||||
self.collected_paths.add(initial_path) # Include the initial path
|
||||
self.listening_start_time = time.time()
|
||||
self.logger.info(f"Multitest command executed by {user_id} - starting 6 second listening window")
|
||||
|
||||
# Get RF data for the triggering message (contains pre-calculated packet hash)
|
||||
rf_data = self.get_rf_data_for_message(message)
|
||||
if not rf_data:
|
||||
response = "Error: Could not find packet data for this message. Please try again."
|
||||
await self.send_response(message, response)
|
||||
return True
|
||||
|
||||
# Use pre-calculated packet hash if available, otherwise calculate it
|
||||
packet_hash = rf_data.get('packet_hash')
|
||||
if not packet_hash and rf_data.get('raw_hex'):
|
||||
# Fallback: calculate hash if not stored (for older RF data)
|
||||
# IMPORTANT: Must use same payload_type that was used during ingestion
|
||||
payload_type = None
|
||||
routing_info = rf_data.get('routing_info', {})
|
||||
if routing_info:
|
||||
payload_type = routing_info.get('payload_type')
|
||||
packet_hash = calculate_packet_hash(rf_data['raw_hex'], payload_type)
|
||||
|
||||
if not packet_hash:
|
||||
response = "Error: Could not calculate packet hash for this message. Please try again."
|
||||
await self.send_response(message, response)
|
||||
return True
|
||||
|
||||
# Store the timestamp of the triggering message to avoid collecting older packets
|
||||
triggering_rf_timestamp = rf_data.get('timestamp', time.time())
|
||||
|
||||
# Also extract path from the triggering message itself
|
||||
initial_path = self.extract_path_from_message(message)
|
||||
# Also try to extract from RF data (more reliable)
|
||||
if not initial_path and rf_data:
|
||||
initial_path = self.extract_path_from_rf_data(rf_data)
|
||||
|
||||
if initial_path:
|
||||
self.logger.debug(f"Initial path from triggering message for user {user_id}: {initial_path}")
|
||||
|
||||
# Create a new session for this user
|
||||
session = MultitestSession(
|
||||
user_id=user_id,
|
||||
target_packet_hash=packet_hash,
|
||||
triggering_timestamp=triggering_rf_timestamp,
|
||||
listening_start_time=time.time(),
|
||||
listening_duration=listening_duration,
|
||||
collected_paths=set(),
|
||||
initial_path=initial_path
|
||||
)
|
||||
|
||||
# Add initial path if available
|
||||
if initial_path:
|
||||
session.collected_paths.add(initial_path)
|
||||
|
||||
# Register this session
|
||||
self._active_sessions[user_id] = session
|
||||
|
||||
# Register this command instance as the active listener (if not already registered)
|
||||
# Store reference in message handler so it can call on_message_received
|
||||
if self.bot.message_handler.multitest_listener is None:
|
||||
self.bot.message_handler.multitest_listener = self
|
||||
|
||||
self.logger.info(f"Tracking packet hash for user {user_id}: {packet_hash[:16]}... (full: {packet_hash})")
|
||||
self.logger.debug(f"Triggering message timestamp for user {user_id}: {triggering_rf_timestamp}")
|
||||
|
||||
# Release lock before async sleep to allow other users to start their sessions
|
||||
# Also scan recent RF data for matching hashes (in case messages haven't been processed yet)
|
||||
# But only include packets that arrived at or after the triggering message
|
||||
self._scan_recent_rf_data()
|
||||
self._scan_recent_rf_data(session)
|
||||
|
||||
try:
|
||||
# Wait for the listening duration
|
||||
await asyncio.sleep(self.listening_duration)
|
||||
await asyncio.sleep(session.listening_duration)
|
||||
finally:
|
||||
# Stop listening and unregister (but keep target_packet_hash for error messages)
|
||||
self.listening = False
|
||||
self.bot.message_handler.multitest_listener = None
|
||||
# Re-acquire lock to clean up session
|
||||
async with self._get_execution_lock():
|
||||
# Remove this session
|
||||
if user_id in self._active_sessions:
|
||||
del self._active_sessions[user_id]
|
||||
|
||||
# Unregister listener if no more active sessions
|
||||
if not self._active_sessions and self.bot.message_handler.multitest_listener == self:
|
||||
self.bot.message_handler.multitest_listener = None
|
||||
|
||||
# Do a final scan of RF data in case any matching packets arrived
|
||||
self._scan_recent_rf_data()
|
||||
self._scan_recent_rf_data(session)
|
||||
|
||||
# Store hash for error message before clearing it
|
||||
tracking_hash = self.target_packet_hash
|
||||
tracking_hash = session.target_packet_hash
|
||||
|
||||
# Format the collected paths
|
||||
if self.collected_paths:
|
||||
if session.collected_paths:
|
||||
# Sort paths for consistent output
|
||||
sorted_paths = sorted(self.collected_paths)
|
||||
sorted_paths = sorted(session.collected_paths)
|
||||
paths_text = "\n".join(sorted_paths)
|
||||
path_count = len(sorted_paths)
|
||||
|
||||
@@ -429,7 +487,7 @@ class MultitestCommand(BaseCommand):
|
||||
sender=message.sender_id or "Unknown",
|
||||
path_count=path_count,
|
||||
paths=paths_text,
|
||||
listening_duration=int(self.listening_duration)
|
||||
listening_duration=int(session.listening_duration)
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
# If formatting fails, fall back to default
|
||||
@@ -454,12 +512,9 @@ class MultitestCommand(BaseCommand):
|
||||
f"(hash: {tracking_hash}). "
|
||||
f"Packets may be direct (0 hops) or path extraction failed.")
|
||||
else:
|
||||
response = (f"No matching packets found during {self.listening_duration}s window. "
|
||||
response = (f"No matching packets found during {session.listening_duration}s window. "
|
||||
f"Tracking hash: {tracking_hash}. ")
|
||||
|
||||
# Clear the hash after we're done with it
|
||||
self.target_packet_hash = None
|
||||
|
||||
# Wait for bot TX rate limiter cooldown to expire before sending
|
||||
# This ensures we respond even if another command put the bot on cooldown
|
||||
await self.bot.bot_tx_rate_limiter.wait_for_tx()
|
||||
|
||||
Reference in New Issue
Block a user