diff --git a/modules/message_handler.py b/modules/message_handler.py index b026071..f36527b 100644 --- a/modules/message_handler.py +++ b/modules/message_handler.py @@ -7,6 +7,7 @@ Processes incoming messages and routes them to appropriate command handlers import asyncio import copy import time +from collections import OrderedDict from typing import Any, Optional from .enums import AdvertFlags, DeviceRole, PayloadType, PayloadVersion, RouteType @@ -33,9 +34,9 @@ class MessageHandler: def __init__(self, bot): self.bot = bot self.logger = bot.logger - # Cache for storing SNR and RSSI data from RF log events - self.snr_cache = {} - self.rssi_cache = {} + # Cache for storing SNR and RSSI data from RF log events (bounded LRU) + self.snr_cache: OrderedDict[str, float] = OrderedDict() + self.rssi_cache: OrderedDict[str, float] = OrderedDict() # Load configuration for RF data correlation self.rf_data_timeout = float(bot.config.get('Bot', 'rf_data_timeout', fallback='15.0')) @@ -57,6 +58,9 @@ class MessageHandler: self._cache_cleanup_interval = 60 # Cleanup every 60 seconds self._last_cache_cleanup = time.time() + # Maximum entries for SNR/RSSI LRU caches + self._max_signal_cache_size = 1000 + # Multitest command listener (for collecting paths during listening window) self.multitest_listener = None @@ -655,16 +659,22 @@ class MessageHandler: self.logger.debug(f"Got pubkey_prefix from metadata: {pubkey_prefix[:16]}...") if packet_prefix and snr_value is not None: - # Cache the SNR value for this packet prefix + # Cache the SNR value for this packet prefix (LRU-bounded) self.snr_cache[packet_prefix] = snr_value + self.snr_cache.move_to_end(packet_prefix) + while len(self.snr_cache) > self._max_signal_cache_size: + self.snr_cache.popitem(last=False) self.logger.debug(f"Cached SNR {snr_value} for packet prefix {packet_prefix}") # Extract and cache RSSI if available if 'rssi' in payload: rssi_value = payload.get('rssi') if packet_prefix and rssi_value is not None: - # Cache the RSSI value for this packet prefix + # Cache the RSSI value for this packet prefix (LRU-bounded) self.rssi_cache[packet_prefix] = rssi_value + self.rssi_cache.move_to_end(packet_prefix) + while len(self.rssi_cache) > self._max_signal_cache_size: + self.rssi_cache.popitem(last=False) self.logger.debug(f"Cached RSSI {rssi_value} for packet prefix {packet_prefix}") # Store recent RF data with timestamp for SNR/RSSI matching only diff --git a/modules/rate_limiter.py b/modules/rate_limiter.py index 78fd90a..a9ab342 100644 --- a/modules/rate_limiter.py +++ b/modules/rate_limiter.py @@ -5,6 +5,7 @@ Controls how often messages can be sent to prevent spam """ import asyncio +import threading import time from collections import OrderedDict from typing import Optional @@ -25,6 +26,7 @@ class PerUserRateLimiter: self._last_send: OrderedDict[str, float] = OrderedDict() # Back-compat for existing tests/introspection: keep insertion/LRU order. self._order: list[str] = [] + self._lock = threading.Lock() def _normalize_key(self, key: str) -> str: return key.strip() @@ -34,30 +36,33 @@ class PerUserRateLimiter: key = self._normalize_key(key) if not key: return True - last = self._last_send.get(key, 0) - return time.time() - last >= self.seconds + with self._lock: + last = self._last_send.get(key, 0) + return time.monotonic() - last >= self.seconds def time_until_next(self, key: str) -> float: """Get time until next allowed send for this user.""" key = self._normalize_key(key) if not key: return 0.0 - last = self._last_send.get(key, 0) - elapsed = time.time() - last - return max(0.0, self.seconds - elapsed) + with self._lock: + last = self._last_send.get(key, 0) + elapsed = time.monotonic() - last + return max(0.0, self.seconds - elapsed) def record_send(self, key: str) -> None: """Record that we sent a message to this user.""" key = self._normalize_key(key) if not key: return - if key in self._last_send: - self._last_send.move_to_end(key) - elif len(self._last_send) >= self.max_entries: - self._last_send.popitem(last=False) - self._last_send[key] = time.time() - # Keep `_order` consistent for callers/tests. - self._order = list(self._last_send.keys()) + with self._lock: + if key in self._last_send: + self._last_send.move_to_end(key) + elif len(self._last_send) >= self.max_entries: + self._last_send.popitem(last=False) + self._last_send[key] = time.monotonic() + # Keep `_order` consistent for callers/tests. + self._order = list(self._last_send.keys()) class RateLimiter: @@ -65,36 +70,41 @@ class RateLimiter: def __init__(self, seconds: float): self.seconds = float(seconds) - self.last_send = 0 + self.last_send = 0.0 self._total_sends = 0 self._total_throttled = 0 + self._lock = threading.Lock() def can_send(self) -> bool: """Check if we can send a message""" - can = time.time() - self.last_send >= self.seconds - if not can: - self._total_throttled += 1 - return can + with self._lock: + can = time.monotonic() - self.last_send >= self.seconds + if not can: + self._total_throttled += 1 + return can def time_until_next(self) -> float: """Get time until next allowed send""" - elapsed = time.time() - self.last_send - return max(0, self.seconds - elapsed) + with self._lock: + elapsed = time.monotonic() - self.last_send + return max(0, self.seconds - elapsed) def record_send(self): """Record that we sent a message""" - self.last_send = time.time() - self._total_sends += 1 + with self._lock: + self.last_send = time.monotonic() + self._total_sends += 1 def get_stats(self) -> dict: """Get rate limiter statistics""" - total_attempts = self._total_sends + self._total_throttled - throttle_rate = self._total_throttled / max(1, total_attempts) - return { - 'total_sends': self._total_sends, - 'total_throttled': self._total_throttled, - 'throttle_rate': throttle_rate - } + with self._lock: + total_attempts = self._total_sends + self._total_throttled + throttle_rate = self._total_throttled / max(1, total_attempts) + return { + 'total_sends': self._total_sends, + 'total_throttled': self._total_throttled, + 'throttle_rate': throttle_rate + } class BotTxRateLimiter: @@ -102,26 +112,30 @@ class BotTxRateLimiter: def __init__(self, seconds: float = 1.0): self.seconds = seconds - self.last_tx = 0 + self.last_tx = 0.0 self._total_tx = 0 self._total_throttled = 0 + self._lock = threading.Lock() def can_tx(self) -> bool: """Check if bot can transmit a message""" - can = time.time() - self.last_tx >= self.seconds - if not can: - self._total_throttled += 1 - return can + with self._lock: + can = time.monotonic() - self.last_tx >= self.seconds + if not can: + self._total_throttled += 1 + return can def time_until_next_tx(self) -> float: """Get time until next allowed transmission""" - elapsed = time.time() - self.last_tx - return max(0, self.seconds - elapsed) + with self._lock: + elapsed = time.monotonic() - self.last_tx + return max(0, self.seconds - elapsed) def record_tx(self): """Record that bot transmitted a message""" - self.last_tx = time.time() - self._total_tx += 1 + with self._lock: + self.last_tx = time.monotonic() + self._total_tx += 1 async def wait_for_tx(self): """Wait until bot can transmit (async)""" @@ -132,13 +146,14 @@ class BotTxRateLimiter: def get_stats(self) -> dict: """Get rate limiter statistics""" - total_attempts = self._total_tx + self._total_throttled - throttle_rate = self._total_throttled / max(1, total_attempts) - return { - 'total_tx': self._total_tx, - 'total_throttled': self._total_throttled, - 'throttle_rate': throttle_rate - } + with self._lock: + total_attempts = self._total_tx + self._total_throttled + throttle_rate = self._total_throttled / max(1, total_attempts) + return { + 'total_tx': self._total_tx, + 'total_throttled': self._total_throttled, + 'throttle_rate': throttle_rate + } class ChannelRateLimiter: @@ -198,30 +213,33 @@ class NominatimRateLimiter: self.seconds = seconds self.last_request: float = 0.0 self._lock: Optional[asyncio.Lock] = None + self._lock_init = threading.Lock() # Guards lazy creation of asyncio.Lock self._total_requests = 0 self._total_throttled = 0 def _get_lock(self) -> asyncio.Lock: - """Lazily initialize the async lock""" + """Lazily initialize the async lock (thread-safe).""" if self._lock is None: - self._lock = asyncio.Lock() + with self._lock_init: + if self._lock is None: + self._lock = asyncio.Lock() return self._lock def can_request(self) -> bool: """Check if we can make a Nominatim request""" - can = time.time() - self.last_request >= self.seconds + can = time.monotonic() - self.last_request >= self.seconds if not can: self._total_throttled += 1 return can def time_until_next(self) -> float: """Get time until next allowed request""" - elapsed = time.time() - self.last_request + elapsed = time.monotonic() - self.last_request return max(0, self.seconds - elapsed) def record_request(self): """Record that we made a Nominatim request""" - self.last_request = time.time() + self.last_request = time.monotonic() self._total_requests += 1 async def wait_for_request(self): @@ -234,11 +252,11 @@ class NominatimRateLimiter: async def wait_and_request(self) -> None: """Wait until a request can be made, then mark request time (thread-safe)""" async with self._get_lock(): - current_time = time.time() + current_time = time.monotonic() time_since_last = current_time - self.last_request if time_since_last < self.seconds: await asyncio.sleep(self.seconds - time_since_last) - self.last_request = time.time() + self.last_request = time.monotonic() self._total_requests += 1 def wait_for_request_sync(self): diff --git a/modules/transmission_tracker.py b/modules/transmission_tracker.py index ba9b0ac..0e5b074 100644 --- a/modules/transmission_tracker.py +++ b/modules/transmission_tracker.py @@ -46,6 +46,8 @@ class TransmissionTracker: # Cleanup old records after this time (seconds) self.cleanup_after = 300 # 5 minutes + self._cleanup_interval = 60 # Run cleanup check every 60 seconds + self._last_cleanup_time = 0.0 # Track our bot's public key prefix (first 2 hex chars) for filtering self.bot_prefix: Optional[str] = None @@ -95,6 +97,9 @@ class TransmissionTracker: self.logger.debug(f"Recorded transmission: {message_type} to {target} at {record.timestamp}") + # Periodically clean up old records to prevent unbounded memory growth + self._maybe_cleanup() + return record def match_packet_hash(self, packet_hash: str, rf_timestamp: float) -> Optional[TransmissionRecord]: @@ -325,6 +330,13 @@ class TransmissionTracker: return [] # No valid prefix found + def _maybe_cleanup(self) -> None: + """Run cleanup if enough time has passed since the last run.""" + now = time.time() + if now - self._last_cleanup_time >= self._cleanup_interval: + self._last_cleanup_time = now + self.cleanup_old_records() + def cleanup_old_records(self): """Remove old transmission records that are beyond the cleanup window""" current_time = time.time() diff --git a/tests/test_message_handler.py b/tests/test_message_handler.py index 7727f60..7f17e55 100644 --- a/tests/test_message_handler.py +++ b/tests/test_message_handler.py @@ -1544,3 +1544,54 @@ class TestGetPathFromRfData: assert nodes is None except (ValueError, UnboundLocalError): pass # expected — source-level bug causes exception to propagate + + +# --------------------------------------------------------------------------- +# SNR/RSSI LRU cache bounds +# --------------------------------------------------------------------------- + +class TestSignalCacheLRUBounds: + """Tests for bounded LRU eviction on snr_cache and rssi_cache.""" + + def test_snr_cache_evicts_oldest_at_limit(self, handler): + handler._max_signal_cache_size = 3 + # Fill cache to capacity + handler.snr_cache["aaa"] = 1.0 + handler.snr_cache["bbb"] = 2.0 + handler.snr_cache["ccc"] = 3.0 + # Simulate the write path with LRU eviction + key = "ddd" + handler.snr_cache[key] = 4.0 + handler.snr_cache.move_to_end(key) + while len(handler.snr_cache) > handler._max_signal_cache_size: + handler.snr_cache.popitem(last=False) + # Oldest entry ("aaa") should be evicted + assert "aaa" not in handler.snr_cache + assert len(handler.snr_cache) == 3 + assert list(handler.snr_cache.keys()) == ["bbb", "ccc", "ddd"] + + def test_rssi_cache_evicts_oldest_at_limit(self, handler): + handler._max_signal_cache_size = 2 + handler.rssi_cache["x1"] = -50.0 + handler.rssi_cache["x2"] = -60.0 + # Add a third entry with eviction + key = "x3" + handler.rssi_cache[key] = -70.0 + handler.rssi_cache.move_to_end(key) + while len(handler.rssi_cache) > handler._max_signal_cache_size: + handler.rssi_cache.popitem(last=False) + assert "x1" not in handler.rssi_cache + assert len(handler.rssi_cache) == 2 + + def test_existing_key_update_does_not_evict(self, handler): + handler._max_signal_cache_size = 2 + handler.snr_cache["a"] = 1.0 + handler.snr_cache["b"] = 2.0 + # Update existing key — no eviction needed + handler.snr_cache["a"] = 5.0 + handler.snr_cache.move_to_end("a") + while len(handler.snr_cache) > handler._max_signal_cache_size: + handler.snr_cache.popitem(last=False) + assert len(handler.snr_cache) == 2 + assert handler.snr_cache["a"] == 5.0 + assert handler.snr_cache["b"] == 2.0 diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py index 1a0c264..1bad475 100644 --- a/tests/test_rate_limiter.py +++ b/tests/test_rate_limiter.py @@ -38,9 +38,9 @@ class TestRateLimiter: def test_record_send_updates_last_send(self): limiter = RateLimiter(seconds=10) - before = time.time() + before = time.monotonic() limiter.record_send() - after = time.time() + after = time.monotonic() assert before <= limiter.last_send <= after @@ -251,7 +251,7 @@ class TestBotTxRateLimiterWaitForTx: # the first evaluation, so the while body (lines 126-128) is entered # at least once by making the initial state throttled and then # immediately resolvable. - limiter.last_tx = time.time() - 10 # well past the 5-second window + limiter.last_tx = time.monotonic() - 10 # well past the 5-second window # Now can_tx() is True so wait_for_tx returns immediately. asyncio.run(limiter.wait_for_tx()) @@ -270,7 +270,7 @@ class TestBotTxRateLimiterWaitForTx: # First call: report not ready (exercises loop body). # We also set last_tx far in the past so time_until_next_tx() == 0 # to avoid any real asyncio.sleep. - limiter.last_tx = time.time() - 200 + limiter.last_tx = time.monotonic() - 200 return False return original_can_tx() @@ -326,7 +326,7 @@ class TestNominatimRateLimiterWaitForRequest: call_count[0] += 1 if call_count[0] == 1: # Backdate so time_until_next() returns 0, avoiding a real sleep. - limiter.last_request = time.time() - 200 + limiter.last_request = time.monotonic() - 200 return False return original() @@ -342,7 +342,7 @@ class TestNominatimRateLimiterWaitAndRequest: """No sleep needed; last_request starts at 0.""" async def _inner(): limiter = NominatimRateLimiter(seconds=1.1) - before = time.time() + before = time.monotonic() await limiter.wait_and_request() assert limiter.last_request >= before assert limiter._total_requests == 1 @@ -365,12 +365,12 @@ class TestNominatimRateLimiterWaitAndRequest: limiter = NominatimRateLimiter(seconds=0.05) # Record a request right now so time_since_last < seconds. limiter.record_request() - before = time.time() + before = time.monotonic() await limiter.wait_and_request() # Two requests recorded total (one manual, one via wait_and_request). assert limiter._total_requests == 2 # At least some time passed (the sleep). - assert time.time() - before >= 0.0 # non-negative; sleep was brief + assert time.monotonic() - before >= 0.0 # non-negative; sleep was brief asyncio.run(_inner()) @@ -394,7 +394,7 @@ class TestNominatimRateLimiterWaitForRequestSync: call_count[0] += 1 if call_count[0] == 1: # Backdate so time_until_next() returns 0, avoiding an actual sleep. - limiter.last_request = time.time() - 200 + limiter.last_request = time.monotonic() - 200 return False return original() diff --git a/tests/test_transmission_tracker.py b/tests/test_transmission_tracker.py index d0a9f94..ae34bd9 100644 --- a/tests/test_transmission_tracker.py +++ b/tests/test_transmission_tracker.py @@ -523,3 +523,51 @@ class TestExtractRepeaterPrefixesParenPath: result = tracker.extract_repeater_prefixes_from_path("01,7e,ab(2) via ROUTE_TYPE_FLOOD") assert result == ["ab"] + + +# --------------------------------------------------------------------------- +# Automatic cleanup via _maybe_cleanup +# --------------------------------------------------------------------------- + +class TestMaybeCleanup: + """Tests for automatic periodic cleanup in TransmissionTracker.""" + + def test_maybe_cleanup_runs_after_interval(self, tracker): + """_maybe_cleanup runs cleanup_old_records when interval has elapsed.""" + # Record an old transmission manually + old_record = TransmissionRecord( + timestamp=time.time() - 600, # 10 minutes ago (past cleanup_after=300s) + content="old", target="chan", message_type="channel", + ) + tracker.pending_transmissions[int(old_record.timestamp)] = [old_record] + # Force the interval to have elapsed + tracker._last_cleanup_time = 0.0 + tracker._maybe_cleanup() + # Old record should be cleaned up + assert int(old_record.timestamp) not in tracker.pending_transmissions + + def test_maybe_cleanup_skips_within_interval(self, tracker): + """_maybe_cleanup does NOT run cleanup if interval hasn't elapsed.""" + old_record = TransmissionRecord( + timestamp=time.time() - 600, + content="old", target="chan", message_type="channel", + ) + tracker.pending_transmissions[int(old_record.timestamp)] = [old_record] + # Set last cleanup to now — interval hasn't elapsed + tracker._last_cleanup_time = time.time() + tracker._maybe_cleanup() + # Old record should still be present (cleanup didn't run) + assert int(old_record.timestamp) in tracker.pending_transmissions + + def test_record_transmission_triggers_cleanup(self, tracker): + """record_transmission calls _maybe_cleanup, cleaning stale records.""" + old_record = TransmissionRecord( + timestamp=time.time() - 600, + content="old", target="chan", message_type="channel", + ) + old_key = int(old_record.timestamp) + tracker.pending_transmissions[old_key] = [old_record] + tracker._last_cleanup_time = 0.0 # Force cleanup to run + # Recording a new transmission should trigger cleanup + tracker.record_transmission("new msg", "general", "channel") + assert old_key not in tracker.pending_transmissions