diff --git a/modules/core.py b/modules/core.py index 8a46e9d..e9ee6f6 100644 --- a/modules/core.py +++ b/modules/core.py @@ -328,7 +328,8 @@ class MeshCoreBot: if key.startswith('channel.') and key.endswith('_seconds'): channel_name = key[len('channel.'):-len('_seconds')] try: - limits[channel_name] = float(value) + # Normalize now; limiter will also normalize at use-time. + limits[channel_name.strip().lower()] = float(value) except ValueError: self.logger.warning(f"Invalid channel rate limit for {key}: {value!r}") return ChannelRateLimiter(limits) diff --git a/modules/rate_limiter.py b/modules/rate_limiter.py index 8a9b86b..78fd90a 100644 --- a/modules/rate_limiter.py +++ b/modules/rate_limiter.py @@ -6,6 +6,7 @@ Controls how often messages can be sent to prevent spam import asyncio import time +from collections import OrderedDict from typing import Optional @@ -20,19 +21,17 @@ class PerUserRateLimiter: def __init__(self, seconds: float, max_entries: int = 1000): self.seconds = seconds self.max_entries = max_entries - self._last_send: dict[str, float] = {} - self._order: list[str] = [] # keys in insertion order for oldest-first eviction + # OrderedDict provides O(1) move-to-end + oldest-first eviction. + self._last_send: OrderedDict[str, float] = OrderedDict() + # Back-compat for existing tests/introspection: keep insertion/LRU order. + self._order: list[str] = [] - def _evict_if_needed(self, new_key: str) -> None: - """Evict oldest entry if at capacity and new_key is not already present.""" - if new_key in self._last_send: - return - while len(self._last_send) >= self.max_entries and self._order: - oldest = self._order.pop(0) - self._last_send.pop(oldest, None) + def _normalize_key(self, key: str) -> str: + return key.strip() def can_send(self, key: str) -> bool: """Check if we can send a message to this user (key).""" + key = self._normalize_key(key) if not key: return True last = self._last_send.get(key, 0) @@ -40,6 +39,7 @@ class PerUserRateLimiter: def time_until_next(self, key: str) -> float: """Get time until next allowed send for this user.""" + key = self._normalize_key(key) if not key: return 0.0 last = self._last_send.get(key, 0) @@ -48,20 +48,23 @@ class PerUserRateLimiter: def record_send(self, key: str) -> None: """Record that we sent a message to this user.""" + key = self._normalize_key(key) if not key: return - self._evict_if_needed(key) + if key in self._last_send: + self._last_send.move_to_end(key) + elif len(self._last_send) >= self.max_entries: + self._last_send.popitem(last=False) self._last_send[key] = time.time() - if key in self._order: - self._order.remove(key) - self._order.append(key) + # Keep `_order` consistent for callers/tests. + self._order = list(self._last_send.keys()) class RateLimiter: """Rate limiting for message sending""" - def __init__(self, seconds: int): - self.seconds = seconds + def __init__(self, seconds: float): + self.seconds = float(seconds) self.last_send = 0 self._total_sends = 0 self._total_throttled = 0 @@ -147,22 +150,33 @@ class ChannelRateLimiter: """ def __init__(self, channel_limits: dict[str, float]): + normalized: dict[str, float] = {} + for channel, seconds in channel_limits.items(): + ch = self._normalize_channel(channel) + try: + sec = float(seconds) + except (TypeError, ValueError): + continue + if ch and sec > 0: + normalized[ch] = sec self._limiters: dict[str, RateLimiter] = { - channel: RateLimiter(int(max(1, seconds))) - for channel, seconds in channel_limits.items() - if seconds > 0 + channel: RateLimiter(seconds) + for channel, seconds in normalized.items() } + def _normalize_channel(self, channel: str) -> str: + return channel.strip().lower() + def can_send(self, channel: str) -> bool: - limiter = self._limiters.get(channel) + limiter = self._limiters.get(self._normalize_channel(channel)) return limiter.can_send() if limiter else True def time_until_next(self, channel: str) -> float: - limiter = self._limiters.get(channel) + limiter = self._limiters.get(self._normalize_channel(channel)) return limiter.time_until_next() if limiter else 0.0 def record_send(self, channel: str) -> None: - limiter = self._limiters.get(channel) + limiter = self._limiters.get(self._normalize_channel(channel)) if limiter: limiter.record_send()