Enhance rate limiting functionality by normalizing channel names and keys

- Updated `MeshCoreBot` to normalize channel names when setting rate limits.
- Refactored `PerUserRateLimiter` to use `OrderedDict` for efficient key management and added normalization for keys.
- Improved `ChannelRateLimiter` to normalize channel names during initialization and when checking limits, ensuring consistent behavior.
This commit is contained in:
agessaman
2026-03-18 21:14:06 -07:00
parent afc36d4bf5
commit 579dc3ce8c
2 changed files with 37 additions and 22 deletions
+2 -1
View File
@@ -328,7 +328,8 @@ class MeshCoreBot:
if key.startswith('channel.') and key.endswith('_seconds'):
channel_name = key[len('channel.'):-len('_seconds')]
try:
limits[channel_name] = float(value)
# Normalize now; limiter will also normalize at use-time.
limits[channel_name.strip().lower()] = float(value)
except ValueError:
self.logger.warning(f"Invalid channel rate limit for {key}: {value!r}")
return ChannelRateLimiter(limits)
+35 -21
View File
@@ -6,6 +6,7 @@ Controls how often messages can be sent to prevent spam
import asyncio
import time
from collections import OrderedDict
from typing import Optional
@@ -20,19 +21,17 @@ class PerUserRateLimiter:
def __init__(self, seconds: float, max_entries: int = 1000):
self.seconds = seconds
self.max_entries = max_entries
self._last_send: dict[str, float] = {}
self._order: list[str] = [] # keys in insertion order for oldest-first eviction
# OrderedDict provides O(1) move-to-end + oldest-first eviction.
self._last_send: OrderedDict[str, float] = OrderedDict()
# Back-compat for existing tests/introspection: keep insertion/LRU order.
self._order: list[str] = []
def _evict_if_needed(self, new_key: str) -> None:
"""Evict oldest entry if at capacity and new_key is not already present."""
if new_key in self._last_send:
return
while len(self._last_send) >= self.max_entries and self._order:
oldest = self._order.pop(0)
self._last_send.pop(oldest, None)
def _normalize_key(self, key: str) -> str:
return key.strip()
def can_send(self, key: str) -> bool:
"""Check if we can send a message to this user (key)."""
key = self._normalize_key(key)
if not key:
return True
last = self._last_send.get(key, 0)
@@ -40,6 +39,7 @@ class PerUserRateLimiter:
def time_until_next(self, key: str) -> float:
"""Get time until next allowed send for this user."""
key = self._normalize_key(key)
if not key:
return 0.0
last = self._last_send.get(key, 0)
@@ -48,20 +48,23 @@ class PerUserRateLimiter:
def record_send(self, key: str) -> None:
"""Record that we sent a message to this user."""
key = self._normalize_key(key)
if not key:
return
self._evict_if_needed(key)
if key in self._last_send:
self._last_send.move_to_end(key)
elif len(self._last_send) >= self.max_entries:
self._last_send.popitem(last=False)
self._last_send[key] = time.time()
if key in self._order:
self._order.remove(key)
self._order.append(key)
# Keep `_order` consistent for callers/tests.
self._order = list(self._last_send.keys())
class RateLimiter:
"""Rate limiting for message sending"""
def __init__(self, seconds: int):
self.seconds = seconds
def __init__(self, seconds: float):
self.seconds = float(seconds)
self.last_send = 0
self._total_sends = 0
self._total_throttled = 0
@@ -147,22 +150,33 @@ class ChannelRateLimiter:
"""
def __init__(self, channel_limits: dict[str, float]):
normalized: dict[str, float] = {}
for channel, seconds in channel_limits.items():
ch = self._normalize_channel(channel)
try:
sec = float(seconds)
except (TypeError, ValueError):
continue
if ch and sec > 0:
normalized[ch] = sec
self._limiters: dict[str, RateLimiter] = {
channel: RateLimiter(int(max(1, seconds)))
for channel, seconds in channel_limits.items()
if seconds > 0
channel: RateLimiter(seconds)
for channel, seconds in normalized.items()
}
def _normalize_channel(self, channel: str) -> str:
return channel.strip().lower()
def can_send(self, channel: str) -> bool:
limiter = self._limiters.get(channel)
limiter = self._limiters.get(self._normalize_channel(channel))
return limiter.can_send() if limiter else True
def time_until_next(self, channel: str) -> float:
limiter = self._limiters.get(channel)
limiter = self._limiters.get(self._normalize_channel(channel))
return limiter.time_until_next() if limiter else 0.0
def record_send(self, channel: str) -> None:
limiter = self._limiters.get(channel)
limiter = self._limiters.get(self._normalize_channel(channel))
if limiter:
limiter.record_send()