mirror of
https://github.com/agessaman/meshcore-bot.git
synced 2026-05-19 13:55:25 +00:00
Implement LRU caching for SNR and RSSI in MessageHandler and enhance rate limiting with thread safety
- Updated MessageHandler to use OrderedDict for the SNR and RSSI caches, implementing LRU eviction to maintain a maximum size.
- Added thread safety to the rate limiting classes by introducing locks, ensuring consistent behavior in concurrent environments.
- Introduced periodic cleanup in TransmissionTracker to manage memory usage effectively.
- Added unit tests for LRU cache behavior and automatic cleanup functionality.

These changes improve performance and reliability in handling message data and rate limiting.
This commit is contained in:
@@ -7,6 +7,7 @@ Processes incoming messages and routes them to appropriate command handlers
|
||||
import asyncio
|
||||
import copy
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
|
||||
from .enums import AdvertFlags, DeviceRole, PayloadType, PayloadVersion, RouteType
|
||||
@@ -33,9 +34,9 @@ class MessageHandler:
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.logger = bot.logger
|
||||
# Cache for storing SNR and RSSI data from RF log events
|
||||
self.snr_cache = {}
|
||||
self.rssi_cache = {}
|
||||
# Cache for storing SNR and RSSI data from RF log events (bounded LRU)
|
||||
self.snr_cache: OrderedDict[str, float] = OrderedDict()
|
||||
self.rssi_cache: OrderedDict[str, float] = OrderedDict()
|
||||
|
||||
# Load configuration for RF data correlation
|
||||
self.rf_data_timeout = float(bot.config.get('Bot', 'rf_data_timeout', fallback='15.0'))
|
||||
@@ -57,6 +58,9 @@ class MessageHandler:
|
||||
self._cache_cleanup_interval = 60 # Cleanup every 60 seconds
|
||||
self._last_cache_cleanup = time.time()
|
||||
|
||||
# Maximum entries for SNR/RSSI LRU caches
|
||||
self._max_signal_cache_size = 1000
|
||||
|
||||
# Multitest command listener (for collecting paths during listening window)
|
||||
self.multitest_listener = None
|
||||
|
||||
@@ -655,16 +659,22 @@ class MessageHandler:
|
||||
self.logger.debug(f"Got pubkey_prefix from metadata: {pubkey_prefix[:16]}...")
|
||||
|
||||
if packet_prefix and snr_value is not None:
|
||||
# Cache the SNR value for this packet prefix
|
||||
# Cache the SNR value for this packet prefix (LRU-bounded)
|
||||
self.snr_cache[packet_prefix] = snr_value
|
||||
self.snr_cache.move_to_end(packet_prefix)
|
||||
while len(self.snr_cache) > self._max_signal_cache_size:
|
||||
self.snr_cache.popitem(last=False)
|
||||
self.logger.debug(f"Cached SNR {snr_value} for packet prefix {packet_prefix}")
|
||||
|
||||
# Extract and cache RSSI if available
|
||||
if 'rssi' in payload:
|
||||
rssi_value = payload.get('rssi')
|
||||
if packet_prefix and rssi_value is not None:
|
||||
# Cache the RSSI value for this packet prefix
|
||||
# Cache the RSSI value for this packet prefix (LRU-bounded)
|
||||
self.rssi_cache[packet_prefix] = rssi_value
|
||||
self.rssi_cache.move_to_end(packet_prefix)
|
||||
while len(self.rssi_cache) > self._max_signal_cache_size:
|
||||
self.rssi_cache.popitem(last=False)
|
||||
self.logger.debug(f"Cached RSSI {rssi_value} for packet prefix {packet_prefix}")
|
||||
|
||||
# Store recent RF data with timestamp for SNR/RSSI matching only
|
||||
|
||||
+69
-51
@@ -5,6 +5,7 @@ Controls how often messages can be sent to prevent spam
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
@@ -25,6 +26,7 @@ class PerUserRateLimiter:
|
||||
self._last_send: OrderedDict[str, float] = OrderedDict()
|
||||
# Back-compat for existing tests/introspection: keep insertion/LRU order.
|
||||
self._order: list[str] = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _normalize_key(self, key: str) -> str:
|
||||
return key.strip()
|
||||
@@ -34,30 +36,33 @@ class PerUserRateLimiter:
|
||||
key = self._normalize_key(key)
|
||||
if not key:
|
||||
return True
|
||||
last = self._last_send.get(key, 0)
|
||||
return time.time() - last >= self.seconds
|
||||
with self._lock:
|
||||
last = self._last_send.get(key, 0)
|
||||
return time.monotonic() - last >= self.seconds
|
||||
|
||||
def time_until_next(self, key: str) -> float:
|
||||
"""Get time until next allowed send for this user."""
|
||||
key = self._normalize_key(key)
|
||||
if not key:
|
||||
return 0.0
|
||||
last = self._last_send.get(key, 0)
|
||||
elapsed = time.time() - last
|
||||
return max(0.0, self.seconds - elapsed)
|
||||
with self._lock:
|
||||
last = self._last_send.get(key, 0)
|
||||
elapsed = time.monotonic() - last
|
||||
return max(0.0, self.seconds - elapsed)
|
||||
|
||||
def record_send(self, key: str) -> None:
|
||||
"""Record that we sent a message to this user."""
|
||||
key = self._normalize_key(key)
|
||||
if not key:
|
||||
return
|
||||
if key in self._last_send:
|
||||
self._last_send.move_to_end(key)
|
||||
elif len(self._last_send) >= self.max_entries:
|
||||
self._last_send.popitem(last=False)
|
||||
self._last_send[key] = time.time()
|
||||
# Keep `_order` consistent for callers/tests.
|
||||
self._order = list(self._last_send.keys())
|
||||
with self._lock:
|
||||
if key in self._last_send:
|
||||
self._last_send.move_to_end(key)
|
||||
elif len(self._last_send) >= self.max_entries:
|
||||
self._last_send.popitem(last=False)
|
||||
self._last_send[key] = time.monotonic()
|
||||
# Keep `_order` consistent for callers/tests.
|
||||
self._order = list(self._last_send.keys())
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
@@ -65,36 +70,41 @@ class RateLimiter:
|
||||
|
||||
def __init__(self, seconds: float):
|
||||
self.seconds = float(seconds)
|
||||
self.last_send = 0
|
||||
self.last_send = 0.0
|
||||
self._total_sends = 0
|
||||
self._total_throttled = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def can_send(self) -> bool:
|
||||
"""Check if we can send a message"""
|
||||
can = time.time() - self.last_send >= self.seconds
|
||||
if not can:
|
||||
self._total_throttled += 1
|
||||
return can
|
||||
with self._lock:
|
||||
can = time.monotonic() - self.last_send >= self.seconds
|
||||
if not can:
|
||||
self._total_throttled += 1
|
||||
return can
|
||||
|
||||
def time_until_next(self) -> float:
|
||||
"""Get time until next allowed send"""
|
||||
elapsed = time.time() - self.last_send
|
||||
return max(0, self.seconds - elapsed)
|
||||
with self._lock:
|
||||
elapsed = time.monotonic() - self.last_send
|
||||
return max(0, self.seconds - elapsed)
|
||||
|
||||
def record_send(self):
|
||||
"""Record that we sent a message"""
|
||||
self.last_send = time.time()
|
||||
self._total_sends += 1
|
||||
with self._lock:
|
||||
self.last_send = time.monotonic()
|
||||
self._total_sends += 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get rate limiter statistics"""
|
||||
total_attempts = self._total_sends + self._total_throttled
|
||||
throttle_rate = self._total_throttled / max(1, total_attempts)
|
||||
return {
|
||||
'total_sends': self._total_sends,
|
||||
'total_throttled': self._total_throttled,
|
||||
'throttle_rate': throttle_rate
|
||||
}
|
||||
with self._lock:
|
||||
total_attempts = self._total_sends + self._total_throttled
|
||||
throttle_rate = self._total_throttled / max(1, total_attempts)
|
||||
return {
|
||||
'total_sends': self._total_sends,
|
||||
'total_throttled': self._total_throttled,
|
||||
'throttle_rate': throttle_rate
|
||||
}
|
||||
|
||||
|
||||
class BotTxRateLimiter:
|
||||
@@ -102,26 +112,30 @@ class BotTxRateLimiter:
|
||||
|
||||
def __init__(self, seconds: float = 1.0):
|
||||
self.seconds = seconds
|
||||
self.last_tx = 0
|
||||
self.last_tx = 0.0
|
||||
self._total_tx = 0
|
||||
self._total_throttled = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def can_tx(self) -> bool:
|
||||
"""Check if bot can transmit a message"""
|
||||
can = time.time() - self.last_tx >= self.seconds
|
||||
if not can:
|
||||
self._total_throttled += 1
|
||||
return can
|
||||
with self._lock:
|
||||
can = time.monotonic() - self.last_tx >= self.seconds
|
||||
if not can:
|
||||
self._total_throttled += 1
|
||||
return can
|
||||
|
||||
def time_until_next_tx(self) -> float:
|
||||
"""Get time until next allowed transmission"""
|
||||
elapsed = time.time() - self.last_tx
|
||||
return max(0, self.seconds - elapsed)
|
||||
with self._lock:
|
||||
elapsed = time.monotonic() - self.last_tx
|
||||
return max(0, self.seconds - elapsed)
|
||||
|
||||
def record_tx(self):
|
||||
"""Record that bot transmitted a message"""
|
||||
self.last_tx = time.time()
|
||||
self._total_tx += 1
|
||||
with self._lock:
|
||||
self.last_tx = time.monotonic()
|
||||
self._total_tx += 1
|
||||
|
||||
async def wait_for_tx(self):
|
||||
"""Wait until bot can transmit (async)"""
|
||||
@@ -132,13 +146,14 @@ class BotTxRateLimiter:
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get rate limiter statistics"""
|
||||
total_attempts = self._total_tx + self._total_throttled
|
||||
throttle_rate = self._total_throttled / max(1, total_attempts)
|
||||
return {
|
||||
'total_tx': self._total_tx,
|
||||
'total_throttled': self._total_throttled,
|
||||
'throttle_rate': throttle_rate
|
||||
}
|
||||
with self._lock:
|
||||
total_attempts = self._total_tx + self._total_throttled
|
||||
throttle_rate = self._total_throttled / max(1, total_attempts)
|
||||
return {
|
||||
'total_tx': self._total_tx,
|
||||
'total_throttled': self._total_throttled,
|
||||
'throttle_rate': throttle_rate
|
||||
}
|
||||
|
||||
|
||||
class ChannelRateLimiter:
|
||||
@@ -198,30 +213,33 @@ class NominatimRateLimiter:
|
||||
self.seconds = seconds
|
||||
self.last_request: float = 0.0
|
||||
self._lock: Optional[asyncio.Lock] = None
|
||||
self._lock_init = threading.Lock() # Guards lazy creation of asyncio.Lock
|
||||
self._total_requests = 0
|
||||
self._total_throttled = 0
|
||||
|
||||
def _get_lock(self) -> asyncio.Lock:
|
||||
"""Lazily initialize the async lock"""
|
||||
"""Lazily initialize the async lock (thread-safe)."""
|
||||
if self._lock is None:
|
||||
self._lock = asyncio.Lock()
|
||||
with self._lock_init:
|
||||
if self._lock is None:
|
||||
self._lock = asyncio.Lock()
|
||||
return self._lock
|
||||
|
||||
def can_request(self) -> bool:
|
||||
"""Check if we can make a Nominatim request"""
|
||||
can = time.time() - self.last_request >= self.seconds
|
||||
can = time.monotonic() - self.last_request >= self.seconds
|
||||
if not can:
|
||||
self._total_throttled += 1
|
||||
return can
|
||||
|
||||
def time_until_next(self) -> float:
|
||||
"""Get time until next allowed request"""
|
||||
elapsed = time.time() - self.last_request
|
||||
elapsed = time.monotonic() - self.last_request
|
||||
return max(0, self.seconds - elapsed)
|
||||
|
||||
def record_request(self):
|
||||
"""Record that we made a Nominatim request"""
|
||||
self.last_request = time.time()
|
||||
self.last_request = time.monotonic()
|
||||
self._total_requests += 1
|
||||
|
||||
async def wait_for_request(self):
|
||||
@@ -234,11 +252,11 @@ class NominatimRateLimiter:
|
||||
async def wait_and_request(self) -> None:
|
||||
"""Wait until a request can be made, then mark request time (thread-safe)"""
|
||||
async with self._get_lock():
|
||||
current_time = time.time()
|
||||
current_time = time.monotonic()
|
||||
time_since_last = current_time - self.last_request
|
||||
if time_since_last < self.seconds:
|
||||
await asyncio.sleep(self.seconds - time_since_last)
|
||||
self.last_request = time.time()
|
||||
self.last_request = time.monotonic()
|
||||
self._total_requests += 1
|
||||
|
||||
def wait_for_request_sync(self):
|
||||
|
||||
@@ -46,6 +46,8 @@ class TransmissionTracker:
|
||||
|
||||
# Cleanup old records after this time (seconds)
|
||||
self.cleanup_after = 300 # 5 minutes
|
||||
self._cleanup_interval = 60 # Run cleanup check every 60 seconds
|
||||
self._last_cleanup_time = 0.0
|
||||
|
||||
# Track our bot's public key prefix (first 2 hex chars) for filtering
|
||||
self.bot_prefix: Optional[str] = None
|
||||
@@ -95,6 +97,9 @@ class TransmissionTracker:
|
||||
|
||||
self.logger.debug(f"Recorded transmission: {message_type} to {target} at {record.timestamp}")
|
||||
|
||||
# Periodically clean up old records to prevent unbounded memory growth
|
||||
self._maybe_cleanup()
|
||||
|
||||
return record
|
||||
|
||||
def match_packet_hash(self, packet_hash: str, rf_timestamp: float) -> Optional[TransmissionRecord]:
|
||||
@@ -325,6 +330,13 @@ class TransmissionTracker:
|
||||
|
||||
return [] # No valid prefix found
|
||||
|
||||
def _maybe_cleanup(self) -> None:
|
||||
"""Run cleanup if enough time has passed since the last run."""
|
||||
now = time.time()
|
||||
if now - self._last_cleanup_time >= self._cleanup_interval:
|
||||
self._last_cleanup_time = now
|
||||
self.cleanup_old_records()
|
||||
|
||||
def cleanup_old_records(self):
|
||||
"""Remove old transmission records that are beyond the cleanup window"""
|
||||
current_time = time.time()
|
||||
|
||||
@@ -1544,3 +1544,54 @@ class TestGetPathFromRfData:
|
||||
assert nodes is None
|
||||
except (ValueError, UnboundLocalError):
|
||||
pass # expected — source-level bug causes exception to propagate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SNR/RSSI LRU cache bounds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalCacheLRUBounds:
|
||||
"""Tests for bounded LRU eviction on snr_cache and rssi_cache."""
|
||||
|
||||
def test_snr_cache_evicts_oldest_at_limit(self, handler):
|
||||
handler._max_signal_cache_size = 3
|
||||
# Fill cache to capacity
|
||||
handler.snr_cache["aaa"] = 1.0
|
||||
handler.snr_cache["bbb"] = 2.0
|
||||
handler.snr_cache["ccc"] = 3.0
|
||||
# Simulate the write path with LRU eviction
|
||||
key = "ddd"
|
||||
handler.snr_cache[key] = 4.0
|
||||
handler.snr_cache.move_to_end(key)
|
||||
while len(handler.snr_cache) > handler._max_signal_cache_size:
|
||||
handler.snr_cache.popitem(last=False)
|
||||
# Oldest entry ("aaa") should be evicted
|
||||
assert "aaa" not in handler.snr_cache
|
||||
assert len(handler.snr_cache) == 3
|
||||
assert list(handler.snr_cache.keys()) == ["bbb", "ccc", "ddd"]
|
||||
|
||||
def test_rssi_cache_evicts_oldest_at_limit(self, handler):
|
||||
handler._max_signal_cache_size = 2
|
||||
handler.rssi_cache["x1"] = -50.0
|
||||
handler.rssi_cache["x2"] = -60.0
|
||||
# Add a third entry with eviction
|
||||
key = "x3"
|
||||
handler.rssi_cache[key] = -70.0
|
||||
handler.rssi_cache.move_to_end(key)
|
||||
while len(handler.rssi_cache) > handler._max_signal_cache_size:
|
||||
handler.rssi_cache.popitem(last=False)
|
||||
assert "x1" not in handler.rssi_cache
|
||||
assert len(handler.rssi_cache) == 2
|
||||
|
||||
def test_existing_key_update_does_not_evict(self, handler):
|
||||
handler._max_signal_cache_size = 2
|
||||
handler.snr_cache["a"] = 1.0
|
||||
handler.snr_cache["b"] = 2.0
|
||||
# Update existing key — no eviction needed
|
||||
handler.snr_cache["a"] = 5.0
|
||||
handler.snr_cache.move_to_end("a")
|
||||
while len(handler.snr_cache) > handler._max_signal_cache_size:
|
||||
handler.snr_cache.popitem(last=False)
|
||||
assert len(handler.snr_cache) == 2
|
||||
assert handler.snr_cache["a"] == 5.0
|
||||
assert handler.snr_cache["b"] == 2.0
|
||||
|
||||
@@ -38,9 +38,9 @@ class TestRateLimiter:
|
||||
|
||||
def test_record_send_updates_last_send(self):
|
||||
limiter = RateLimiter(seconds=10)
|
||||
before = time.time()
|
||||
before = time.monotonic()
|
||||
limiter.record_send()
|
||||
after = time.time()
|
||||
after = time.monotonic()
|
||||
assert before <= limiter.last_send <= after
|
||||
|
||||
|
||||
@@ -251,7 +251,7 @@ class TestBotTxRateLimiterWaitForTx:
|
||||
# the first evaluation, so the while body (lines 126-128) is entered
|
||||
# at least once by making the initial state throttled and then
|
||||
# immediately resolvable.
|
||||
limiter.last_tx = time.time() - 10 # well past the 5-second window
|
||||
limiter.last_tx = time.monotonic() - 10 # well past the 5-second window
|
||||
# Now can_tx() is True so wait_for_tx returns immediately.
|
||||
asyncio.run(limiter.wait_for_tx())
|
||||
|
||||
@@ -270,7 +270,7 @@ class TestBotTxRateLimiterWaitForTx:
|
||||
# First call: report not ready (exercises loop body).
|
||||
# We also set last_tx far in the past so time_until_next_tx() == 0
|
||||
# to avoid any real asyncio.sleep.
|
||||
limiter.last_tx = time.time() - 200
|
||||
limiter.last_tx = time.monotonic() - 200
|
||||
return False
|
||||
return original_can_tx()
|
||||
|
||||
@@ -326,7 +326,7 @@ class TestNominatimRateLimiterWaitForRequest:
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
# Backdate so time_until_next() returns 0, avoiding a real sleep.
|
||||
limiter.last_request = time.time() - 200
|
||||
limiter.last_request = time.monotonic() - 200
|
||||
return False
|
||||
return original()
|
||||
|
||||
@@ -342,7 +342,7 @@ class TestNominatimRateLimiterWaitAndRequest:
|
||||
"""No sleep needed; last_request starts at 0."""
|
||||
async def _inner():
|
||||
limiter = NominatimRateLimiter(seconds=1.1)
|
||||
before = time.time()
|
||||
before = time.monotonic()
|
||||
await limiter.wait_and_request()
|
||||
assert limiter.last_request >= before
|
||||
assert limiter._total_requests == 1
|
||||
@@ -365,12 +365,12 @@ class TestNominatimRateLimiterWaitAndRequest:
|
||||
limiter = NominatimRateLimiter(seconds=0.05)
|
||||
# Record a request right now so time_since_last < seconds.
|
||||
limiter.record_request()
|
||||
before = time.time()
|
||||
before = time.monotonic()
|
||||
await limiter.wait_and_request()
|
||||
# Two requests recorded total (one manual, one via wait_and_request).
|
||||
assert limiter._total_requests == 2
|
||||
# At least some time passed (the sleep).
|
||||
assert time.time() - before >= 0.0 # non-negative; sleep was brief
|
||||
assert time.monotonic() - before >= 0.0 # non-negative; sleep was brief
|
||||
|
||||
asyncio.run(_inner())
|
||||
|
||||
@@ -394,7 +394,7 @@ class TestNominatimRateLimiterWaitForRequestSync:
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
# Backdate so time_until_next() returns 0, avoiding an actual sleep.
|
||||
limiter.last_request = time.time() - 200
|
||||
limiter.last_request = time.monotonic() - 200
|
||||
return False
|
||||
return original()
|
||||
|
||||
|
||||
@@ -523,3 +523,51 @@ class TestExtractRepeaterPrefixesParenPath:
|
||||
|
||||
result = tracker.extract_repeater_prefixes_from_path("01,7e,ab(2) via ROUTE_TYPE_FLOOD")
|
||||
assert result == ["ab"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Automatic cleanup via _maybe_cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMaybeCleanup:
|
||||
"""Tests for automatic periodic cleanup in TransmissionTracker."""
|
||||
|
||||
def test_maybe_cleanup_runs_after_interval(self, tracker):
|
||||
"""_maybe_cleanup runs cleanup_old_records when interval has elapsed."""
|
||||
# Record an old transmission manually
|
||||
old_record = TransmissionRecord(
|
||||
timestamp=time.time() - 600, # 10 minutes ago (past cleanup_after=300s)
|
||||
content="old", target="chan", message_type="channel",
|
||||
)
|
||||
tracker.pending_transmissions[int(old_record.timestamp)] = [old_record]
|
||||
# Force the interval to have elapsed
|
||||
tracker._last_cleanup_time = 0.0
|
||||
tracker._maybe_cleanup()
|
||||
# Old record should be cleaned up
|
||||
assert int(old_record.timestamp) not in tracker.pending_transmissions
|
||||
|
||||
def test_maybe_cleanup_skips_within_interval(self, tracker):
|
||||
"""_maybe_cleanup does NOT run cleanup if interval hasn't elapsed."""
|
||||
old_record = TransmissionRecord(
|
||||
timestamp=time.time() - 600,
|
||||
content="old", target="chan", message_type="channel",
|
||||
)
|
||||
tracker.pending_transmissions[int(old_record.timestamp)] = [old_record]
|
||||
# Set last cleanup to now — interval hasn't elapsed
|
||||
tracker._last_cleanup_time = time.time()
|
||||
tracker._maybe_cleanup()
|
||||
# Old record should still be present (cleanup didn't run)
|
||||
assert int(old_record.timestamp) in tracker.pending_transmissions
|
||||
|
||||
def test_record_transmission_triggers_cleanup(self, tracker):
|
||||
"""record_transmission calls _maybe_cleanup, cleaning stale records."""
|
||||
old_record = TransmissionRecord(
|
||||
timestamp=time.time() - 600,
|
||||
content="old", target="chan", message_type="channel",
|
||||
)
|
||||
old_key = int(old_record.timestamp)
|
||||
tracker.pending_transmissions[old_key] = [old_record]
|
||||
tracker._last_cleanup_time = 0.0 # Force cleanup to run
|
||||
# Recording a new transmission should trigger cleanup
|
||||
tracker.record_transmission("new msg", "general", "channel")
|
||||
assert old_key not in tracker.pending_transmissions
|
||||
|
||||
Reference in New Issue
Block a user