Implement LRU caching for SNR and RSSI in MessageHandler and enhance rate limiting with thread safety

- Updated MessageHandler to use OrderedDict for SNR and RSSI caches, implementing LRU eviction to maintain a maximum size.
- Added thread safety to rate limiting classes by introducing locks, ensuring consistent behavior in concurrent environments.
- Introduced periodic cleanup in TransmissionTracker to manage memory usage effectively.
- Added unit tests for LRU cache behavior and automatic cleanup functionality.

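These changes bound memory use in long-running bots and make rate limiting safe under concurrent access. Rate limiter timestamps also move from time.time() to time.monotonic(), so wall-clock adjustments no longer affect throttling.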
This commit is contained in:
agessaman
2026-03-29 14:11:12 -07:00
parent 12c8c0b787
commit ea0e25d746
6 changed files with 204 additions and 65 deletions
+15 -5
View File
@@ -7,6 +7,7 @@ Processes incoming messages and routes them to appropriate command handlers
import asyncio
import copy
import time
from collections import OrderedDict
from typing import Any, Optional
from .enums import AdvertFlags, DeviceRole, PayloadType, PayloadVersion, RouteType
@@ -33,9 +34,9 @@ class MessageHandler:
def __init__(self, bot):
self.bot = bot
self.logger = bot.logger
# Cache for storing SNR and RSSI data from RF log events
self.snr_cache = {}
self.rssi_cache = {}
# Cache for storing SNR and RSSI data from RF log events (bounded LRU)
self.snr_cache: OrderedDict[str, float] = OrderedDict()
self.rssi_cache: OrderedDict[str, float] = OrderedDict()
# Load configuration for RF data correlation
self.rf_data_timeout = float(bot.config.get('Bot', 'rf_data_timeout', fallback='15.0'))
@@ -57,6 +58,9 @@ class MessageHandler:
self._cache_cleanup_interval = 60 # Cleanup every 60 seconds
self._last_cache_cleanup = time.time()
# Maximum entries for SNR/RSSI LRU caches
self._max_signal_cache_size = 1000
# Multitest command listener (for collecting paths during listening window)
self.multitest_listener = None
@@ -655,16 +659,22 @@ class MessageHandler:
self.logger.debug(f"Got pubkey_prefix from metadata: {pubkey_prefix[:16]}...")
if packet_prefix and snr_value is not None:
# Cache the SNR value for this packet prefix
# Cache the SNR value for this packet prefix (LRU-bounded)
self.snr_cache[packet_prefix] = snr_value
self.snr_cache.move_to_end(packet_prefix)
while len(self.snr_cache) > self._max_signal_cache_size:
self.snr_cache.popitem(last=False)
self.logger.debug(f"Cached SNR {snr_value} for packet prefix {packet_prefix}")
# Extract and cache RSSI if available
if 'rssi' in payload:
rssi_value = payload.get('rssi')
if packet_prefix and rssi_value is not None:
# Cache the RSSI value for this packet prefix
# Cache the RSSI value for this packet prefix (LRU-bounded)
self.rssi_cache[packet_prefix] = rssi_value
self.rssi_cache.move_to_end(packet_prefix)
while len(self.rssi_cache) > self._max_signal_cache_size:
self.rssi_cache.popitem(last=False)
self.logger.debug(f"Cached RSSI {rssi_value} for packet prefix {packet_prefix}")
# Store recent RF data with timestamp for SNR/RSSI matching only
+69 -51
View File
@@ -5,6 +5,7 @@ Controls how often messages can be sent to prevent spam
"""
import asyncio
import threading
import time
from collections import OrderedDict
from typing import Optional
@@ -25,6 +26,7 @@ class PerUserRateLimiter:
self._last_send: OrderedDict[str, float] = OrderedDict()
# Back-compat for existing tests/introspection: keep insertion/LRU order.
self._order: list[str] = []
self._lock = threading.Lock()
def _normalize_key(self, key: str) -> str:
return key.strip()
@@ -34,30 +36,33 @@ class PerUserRateLimiter:
key = self._normalize_key(key)
if not key:
return True
last = self._last_send.get(key, 0)
return time.time() - last >= self.seconds
with self._lock:
last = self._last_send.get(key, 0)
return time.monotonic() - last >= self.seconds
def time_until_next(self, key: str) -> float:
"""Get time until next allowed send for this user."""
key = self._normalize_key(key)
if not key:
return 0.0
last = self._last_send.get(key, 0)
elapsed = time.time() - last
return max(0.0, self.seconds - elapsed)
with self._lock:
last = self._last_send.get(key, 0)
elapsed = time.monotonic() - last
return max(0.0, self.seconds - elapsed)
def record_send(self, key: str) -> None:
"""Record that we sent a message to this user."""
key = self._normalize_key(key)
if not key:
return
if key in self._last_send:
self._last_send.move_to_end(key)
elif len(self._last_send) >= self.max_entries:
self._last_send.popitem(last=False)
self._last_send[key] = time.time()
# Keep `_order` consistent for callers/tests.
self._order = list(self._last_send.keys())
with self._lock:
if key in self._last_send:
self._last_send.move_to_end(key)
elif len(self._last_send) >= self.max_entries:
self._last_send.popitem(last=False)
self._last_send[key] = time.monotonic()
# Keep `_order` consistent for callers/tests.
self._order = list(self._last_send.keys())
class RateLimiter:
@@ -65,36 +70,41 @@ class RateLimiter:
def __init__(self, seconds: float):
self.seconds = float(seconds)
self.last_send = 0
self.last_send = 0.0
self._total_sends = 0
self._total_throttled = 0
self._lock = threading.Lock()
def can_send(self) -> bool:
"""Check if we can send a message"""
can = time.time() - self.last_send >= self.seconds
if not can:
self._total_throttled += 1
return can
with self._lock:
can = time.monotonic() - self.last_send >= self.seconds
if not can:
self._total_throttled += 1
return can
def time_until_next(self) -> float:
"""Get time until next allowed send"""
elapsed = time.time() - self.last_send
return max(0, self.seconds - elapsed)
with self._lock:
elapsed = time.monotonic() - self.last_send
return max(0, self.seconds - elapsed)
def record_send(self):
"""Record that we sent a message"""
self.last_send = time.time()
self._total_sends += 1
with self._lock:
self.last_send = time.monotonic()
self._total_sends += 1
def get_stats(self) -> dict:
"""Get rate limiter statistics"""
total_attempts = self._total_sends + self._total_throttled
throttle_rate = self._total_throttled / max(1, total_attempts)
return {
'total_sends': self._total_sends,
'total_throttled': self._total_throttled,
'throttle_rate': throttle_rate
}
with self._lock:
total_attempts = self._total_sends + self._total_throttled
throttle_rate = self._total_throttled / max(1, total_attempts)
return {
'total_sends': self._total_sends,
'total_throttled': self._total_throttled,
'throttle_rate': throttle_rate
}
class BotTxRateLimiter:
@@ -102,26 +112,30 @@ class BotTxRateLimiter:
def __init__(self, seconds: float = 1.0):
self.seconds = seconds
self.last_tx = 0
self.last_tx = 0.0
self._total_tx = 0
self._total_throttled = 0
self._lock = threading.Lock()
def can_tx(self) -> bool:
"""Check if bot can transmit a message"""
can = time.time() - self.last_tx >= self.seconds
if not can:
self._total_throttled += 1
return can
with self._lock:
can = time.monotonic() - self.last_tx >= self.seconds
if not can:
self._total_throttled += 1
return can
def time_until_next_tx(self) -> float:
"""Get time until next allowed transmission"""
elapsed = time.time() - self.last_tx
return max(0, self.seconds - elapsed)
with self._lock:
elapsed = time.monotonic() - self.last_tx
return max(0, self.seconds - elapsed)
def record_tx(self):
"""Record that bot transmitted a message"""
self.last_tx = time.time()
self._total_tx += 1
with self._lock:
self.last_tx = time.monotonic()
self._total_tx += 1
async def wait_for_tx(self):
"""Wait until bot can transmit (async)"""
@@ -132,13 +146,14 @@ class BotTxRateLimiter:
def get_stats(self) -> dict:
"""Get rate limiter statistics"""
total_attempts = self._total_tx + self._total_throttled
throttle_rate = self._total_throttled / max(1, total_attempts)
return {
'total_tx': self._total_tx,
'total_throttled': self._total_throttled,
'throttle_rate': throttle_rate
}
with self._lock:
total_attempts = self._total_tx + self._total_throttled
throttle_rate = self._total_throttled / max(1, total_attempts)
return {
'total_tx': self._total_tx,
'total_throttled': self._total_throttled,
'throttle_rate': throttle_rate
}
class ChannelRateLimiter:
@@ -198,30 +213,33 @@ class NominatimRateLimiter:
self.seconds = seconds
self.last_request: float = 0.0
self._lock: Optional[asyncio.Lock] = None
self._lock_init = threading.Lock() # Guards lazy creation of asyncio.Lock
self._total_requests = 0
self._total_throttled = 0
def _get_lock(self) -> asyncio.Lock:
"""Lazily initialize the async lock"""
"""Lazily initialize the async lock (thread-safe)."""
if self._lock is None:
self._lock = asyncio.Lock()
with self._lock_init:
if self._lock is None:
self._lock = asyncio.Lock()
return self._lock
def can_request(self) -> bool:
"""Check if we can make a Nominatim request"""
can = time.time() - self.last_request >= self.seconds
can = time.monotonic() - self.last_request >= self.seconds
if not can:
self._total_throttled += 1
return can
def time_until_next(self) -> float:
"""Get time until next allowed request"""
elapsed = time.time() - self.last_request
elapsed = time.monotonic() - self.last_request
return max(0, self.seconds - elapsed)
def record_request(self):
"""Record that we made a Nominatim request"""
self.last_request = time.time()
self.last_request = time.monotonic()
self._total_requests += 1
async def wait_for_request(self):
@@ -234,11 +252,11 @@ class NominatimRateLimiter:
async def wait_and_request(self) -> None:
"""Wait until a request can be made, then mark request time (thread-safe)"""
async with self._get_lock():
current_time = time.time()
current_time = time.monotonic()
time_since_last = current_time - self.last_request
if time_since_last < self.seconds:
await asyncio.sleep(self.seconds - time_since_last)
self.last_request = time.time()
self.last_request = time.monotonic()
self._total_requests += 1
def wait_for_request_sync(self):
+12
View File
@@ -46,6 +46,8 @@ class TransmissionTracker:
# Cleanup old records after this time (seconds)
self.cleanup_after = 300 # 5 minutes
self._cleanup_interval = 60 # Run cleanup check every 60 seconds
self._last_cleanup_time = 0.0
# Track our bot's public key prefix (first 2 hex chars) for filtering
self.bot_prefix: Optional[str] = None
@@ -95,6 +97,9 @@ class TransmissionTracker:
self.logger.debug(f"Recorded transmission: {message_type} to {target} at {record.timestamp}")
# Periodically clean up old records to prevent unbounded memory growth
self._maybe_cleanup()
return record
def match_packet_hash(self, packet_hash: str, rf_timestamp: float) -> Optional[TransmissionRecord]:
@@ -325,6 +330,13 @@ class TransmissionTracker:
return [] # No valid prefix found
def _maybe_cleanup(self) -> None:
    """Run ``cleanup_old_records`` when the cleanup interval has elapsed.

    Compares the current wall-clock time with ``self._last_cleanup_time``.
    Cleanup is due when at least ``self._cleanup_interval`` seconds have
    passed, or when the elapsed time is negative. A negative value means the
    system clock was stepped backwards since the last run; without this
    check cleanup would stay suppressed until the clock caught up again.
    """
    now = time.time()
    elapsed = now - self._last_cleanup_time
    if elapsed < 0 or elapsed >= self._cleanup_interval:
        # Stamp first: if cleanup raises, the next attempt still waits a
        # full interval instead of retrying on every call.
        self._last_cleanup_time = now
        self.cleanup_old_records()
def cleanup_old_records(self):
"""Remove old transmission records that are beyond the cleanup window"""
current_time = time.time()
+51
View File
@@ -1544,3 +1544,54 @@ class TestGetPathFromRfData:
assert nodes is None
except (ValueError, UnboundLocalError):
pass # expected — source-level bug causes exception to propagate
# ---------------------------------------------------------------------------
# SNR/RSSI LRU cache bounds
# ---------------------------------------------------------------------------
class TestSignalCacheLRUBounds:
    """Bounded-LRU behaviour of ``snr_cache`` and ``rssi_cache``.

    These tests do not call the handler's own write path; each one replays
    the same insert / move-to-end / evict-oldest sequence on the cache.
    """

    @staticmethod
    def _store(cache, limit, key, value):
        """Insert *key* as the newest entry, then trim *cache* to *limit*."""
        cache[key] = value
        cache.move_to_end(key)
        while len(cache) > limit:
            cache.popitem(last=False)

    def test_snr_cache_evicts_oldest_at_limit(self, handler):
        handler._max_signal_cache_size = 3
        # Seed the cache exactly up to capacity.
        for prefix, snr in (("aaa", 1.0), ("bbb", 2.0), ("ccc", 3.0)):
            handler.snr_cache[prefix] = snr

        # One more entry pushes the cache over the limit.
        self._store(handler.snr_cache, handler._max_signal_cache_size, "ddd", 4.0)

        # The first-inserted key ("aaa") is the one that gets dropped.
        assert "aaa" not in handler.snr_cache
        assert len(handler.snr_cache) == 3
        assert list(handler.snr_cache.keys()) == ["bbb", "ccc", "ddd"]

    def test_rssi_cache_evicts_oldest_at_limit(self, handler):
        handler._max_signal_cache_size = 2
        handler.rssi_cache["x1"] = -50.0
        handler.rssi_cache["x2"] = -60.0

        # A third key forces eviction of the oldest one.
        self._store(handler.rssi_cache, handler._max_signal_cache_size, "x3", -70.0)

        assert "x1" not in handler.rssi_cache
        assert len(handler.rssi_cache) == 2

    def test_existing_key_update_does_not_evict(self, handler):
        handler._max_signal_cache_size = 2
        handler.snr_cache["a"] = 1.0
        handler.snr_cache["b"] = 2.0

        # Overwriting a present key keeps the size unchanged, so nothing
        # should be evicted.
        self._store(handler.snr_cache, handler._max_signal_cache_size, "a", 5.0)

        assert len(handler.snr_cache) == 2
        assert dict(handler.snr_cache) == {"a": 5.0, "b": 2.0}
+9 -9
View File
@@ -38,9 +38,9 @@ class TestRateLimiter:
def test_record_send_updates_last_send(self):
limiter = RateLimiter(seconds=10)
before = time.time()
before = time.monotonic()
limiter.record_send()
after = time.time()
after = time.monotonic()
assert before <= limiter.last_send <= after
@@ -251,7 +251,7 @@ class TestBotTxRateLimiterWaitForTx:
# the first evaluation, so the while body (lines 126-128) is entered
# at least once by making the initial state throttled and then
# immediately resolvable.
limiter.last_tx = time.time() - 10 # well past the 5-second window
limiter.last_tx = time.monotonic() - 10 # well past the 5-second window
# Now can_tx() is True so wait_for_tx returns immediately.
asyncio.run(limiter.wait_for_tx())
@@ -270,7 +270,7 @@ class TestBotTxRateLimiterWaitForTx:
# First call: report not ready (exercises loop body).
# We also set last_tx far in the past so time_until_next_tx() == 0
# to avoid any real asyncio.sleep.
limiter.last_tx = time.time() - 200
limiter.last_tx = time.monotonic() - 200
return False
return original_can_tx()
@@ -326,7 +326,7 @@ class TestNominatimRateLimiterWaitForRequest:
call_count[0] += 1
if call_count[0] == 1:
# Backdate so time_until_next() returns 0, avoiding a real sleep.
limiter.last_request = time.time() - 200
limiter.last_request = time.monotonic() - 200
return False
return original()
@@ -342,7 +342,7 @@ class TestNominatimRateLimiterWaitAndRequest:
"""No sleep needed; last_request starts at 0."""
async def _inner():
limiter = NominatimRateLimiter(seconds=1.1)
before = time.time()
before = time.monotonic()
await limiter.wait_and_request()
assert limiter.last_request >= before
assert limiter._total_requests == 1
@@ -365,12 +365,12 @@ class TestNominatimRateLimiterWaitAndRequest:
limiter = NominatimRateLimiter(seconds=0.05)
# Record a request right now so time_since_last < seconds.
limiter.record_request()
before = time.time()
before = time.monotonic()
await limiter.wait_and_request()
# Two requests recorded total (one manual, one via wait_and_request).
assert limiter._total_requests == 2
# At least some time passed (the sleep).
assert time.time() - before >= 0.0 # non-negative; sleep was brief
assert time.monotonic() - before >= 0.0 # non-negative; sleep was brief
asyncio.run(_inner())
@@ -394,7 +394,7 @@ class TestNominatimRateLimiterWaitForRequestSync:
call_count[0] += 1
if call_count[0] == 1:
# Backdate so time_until_next() returns 0, avoiding an actual sleep.
limiter.last_request = time.time() - 200
limiter.last_request = time.monotonic() - 200
return False
return original()
+48
View File
@@ -523,3 +523,51 @@ class TestExtractRepeaterPrefixesParenPath:
result = tracker.extract_repeater_prefixes_from_path("01,7e,ab(2) via ROUTE_TYPE_FLOOD")
assert result == ["ab"]
# ---------------------------------------------------------------------------
# Automatic cleanup via _maybe_cleanup
# ---------------------------------------------------------------------------
class TestMaybeCleanup:
    """Periodic cleanup driven by ``TransmissionTracker._maybe_cleanup``."""

    @staticmethod
    def _plant_stale_record(tracker):
        """Add a record stamped 10 minutes in the past; return its bucket key.

        Ten minutes is older than the 300 s cleanup window the original
        tests assume for the tracker.
        """
        stale = TransmissionRecord(
            timestamp=time.time() - 600,
            content="old", target="chan", message_type="channel",
        )
        bucket = int(stale.timestamp)
        tracker.pending_transmissions[bucket] = [stale]
        return bucket

    def test_maybe_cleanup_runs_after_interval(self, tracker):
        """Cleanup runs once the interval since the last run has elapsed."""
        bucket = self._plant_stale_record(tracker)

        # A zero stamp makes the interval look long since elapsed.
        tracker._last_cleanup_time = 0.0
        tracker._maybe_cleanup()

        # The stale bucket must be gone.
        assert bucket not in tracker.pending_transmissions

    def test_maybe_cleanup_skips_within_interval(self, tracker):
        """Cleanup is skipped while still inside the interval."""
        bucket = self._plant_stale_record(tracker)

        # Pretend cleanup has just run.
        tracker._last_cleanup_time = time.time()
        tracker._maybe_cleanup()

        # Nothing was cleaned, so the stale bucket is still there.
        assert bucket in tracker.pending_transmissions

    def test_record_transmission_triggers_cleanup(self, tracker):
        """Recording a transmission also invokes the periodic cleanup."""
        bucket = self._plant_stale_record(tracker)
        tracker._last_cleanup_time = 0.0  # make cleanup due

        # The new transmission should sweep the stale bucket on its way in.
        tracker.record_transmission("new msg", "general", "channel")

        assert bucket not in tracker.pending_transmissions