mirror of
https://github.com/simplex-chat/simplex-chat.git
synced 2026-07-02 17:52:00 +00:00
simplex-chat-python: split Client from Bot, add request/response API (#6976)
* simplex-chat-python: split Client from Bot, add request/response API
Client is now the base class for SimpleX participants that talk TO
services (monitors, probes, automated participants). Bot extends Client
with server features (address, auto-accept, welcome, commands).
New methods on Client (inherited by Bot):
connect_to(link) idempotent contact handshake
send_and_wait(id, text) send a message and await the reply
events() async iterator over chat events
@on_message(contact_id=N) filter by sender in decorators
BotProfile renamed to Profile (alias kept). New ContactAlreadyExistsError
subclass for cleaner error handling.
* simplex-chat-python: narrow event payload type per @on_event tag
@client.on_event("contactConnected") now types the handler's event
parameter as CEvt.ContactConnected instead of the unnarrowed
CEvt.ChatEvent union — mirroring how @on_message narrows by
content_type.
The 50 overloads are generated by the Haskell codegen into _events.py
(as a Protocol class), so new events stay in sync automatically.
Client.on_event is exposed as a property typed as that Protocol; the
runtime implementation is unchanged.
This commit is contained in:
@@ -1,12 +1,21 @@
|
||||
"""SimpleX Chat — Python client library for chat bots."""
|
||||
|
||||
from ._version import __version__
|
||||
from .api import ChatApi, ChatCommandError, ConnReqType, Db, PostgresDb, SqliteDb
|
||||
from .api import (
|
||||
ChatApi,
|
||||
ChatCommandError,
|
||||
ConnReqType,
|
||||
ContactAlreadyExistsError,
|
||||
Db,
|
||||
PostgresDb,
|
||||
SqliteDb,
|
||||
)
|
||||
from .bot import (
|
||||
Bot,
|
||||
BotCommand,
|
||||
BotProfile,
|
||||
ChatMessage,
|
||||
Client,
|
||||
CommandHandler,
|
||||
EventHandler,
|
||||
FileMessage,
|
||||
@@ -16,6 +25,7 @@ from .bot import (
|
||||
MessageHandler,
|
||||
Middleware,
|
||||
ParsedCommand,
|
||||
Profile,
|
||||
ReportMessage,
|
||||
TextMessage,
|
||||
UnknownMessage,
|
||||
@@ -35,8 +45,10 @@ __all__ = [
|
||||
"ChatCommandError",
|
||||
"ChatInitError",
|
||||
"ChatMessage",
|
||||
"Client",
|
||||
"CommandHandler",
|
||||
"ConnReqType",
|
||||
"ContactAlreadyExistsError",
|
||||
"CryptoArgs",
|
||||
"Db",
|
||||
"EventHandler",
|
||||
@@ -49,6 +61,7 @@ __all__ = [
|
||||
"MigrationConfirmation",
|
||||
"ParsedCommand",
|
||||
"PostgresDb",
|
||||
"Profile",
|
||||
"ReportMessage",
|
||||
"SqliteDb",
|
||||
"TextMessage",
|
||||
|
||||
@@ -40,10 +40,26 @@ def _db_to_migrate_args(db: Db) -> tuple[str, str, _native.Backend]:
|
||||
|
||||
|
||||
class ChatCommandError(Exception):
|
||||
"""A chat command returned an unexpected response type.
|
||||
|
||||
`response` is the raw wire response; `response_type` exposes its `type`
|
||||
discriminator for quick checks. Subclasses cover known recoverable cases
|
||||
so callers can `except ContactAlreadyExistsError` instead of inspecting
|
||||
`response.get("type")` themselves.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, response: CR.ChatResponse):
|
||||
super().__init__(message)
|
||||
self.response = response
|
||||
|
||||
@property
|
||||
def response_type(self) -> str:
|
||||
return self.response.get("type", "") # type: ignore[return-value]
|
||||
|
||||
|
||||
class ContactAlreadyExistsError(ChatCommandError):
|
||||
"""`api_connect`/`api_connect_active_user` was called for an existing contact."""
|
||||
|
||||
|
||||
class ChatApi:
|
||||
def __init__(self, ctrl: int):
|
||||
@@ -481,7 +497,7 @@ class ChatApi:
|
||||
if r["type"] == "sentInvitation":
|
||||
return "contact"
|
||||
if r["type"] == "contactAlreadyExists":
|
||||
raise ChatCommandError("contact already exists", r)
|
||||
raise ContactAlreadyExistsError("contact already exists", r)
|
||||
raise ChatCommandError("connection error", r)
|
||||
|
||||
async def api_accept_contact_request(self, contact_req_id: int) -> T.Contact:
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
"""User-facing `Bot` API: decorators, filters, Message wrapper, lifecycle."""
|
||||
"""`Bot` — Client extended with server-side features (address, auto-accept, commands)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal as _signal
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from . import util
|
||||
from .api import ChatApi, Db
|
||||
from .core import ChatAPIError, MigrationConfirmation
|
||||
from .filters import compile_message_filter
|
||||
from .types import CEvt, T
|
||||
|
||||
log = logging.getLogger("simplex_chat")
|
||||
|
||||
C = TypeVar("C", bound="T.MsgContent")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BotProfile:
|
||||
display_name: str
|
||||
full_name: str = ""
|
||||
short_descr: str | None = None
|
||||
image: str | None = None
|
||||
from .api import Db
|
||||
from .client import (
|
||||
BotProfile,
|
||||
ChatMessage,
|
||||
Client,
|
||||
CommandHandler,
|
||||
EventHandler,
|
||||
FileMessage,
|
||||
ImageMessage,
|
||||
LinkMessage,
|
||||
Message,
|
||||
MessageHandler,
|
||||
Middleware,
|
||||
ParsedCommand,
|
||||
Profile,
|
||||
ReportMessage,
|
||||
TextMessage,
|
||||
UnknownMessage,
|
||||
VideoMessage,
|
||||
VoiceMessage,
|
||||
log,
|
||||
)
|
||||
from .core import MigrationConfirmation
|
||||
from .types import T
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -35,85 +37,25 @@ class BotCommand:
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class ParsedCommand:
|
||||
keyword: str
|
||||
args: str
|
||||
class Bot(Client):
|
||||
"""SimpleX bot — Client extended with server-side features.
|
||||
|
||||
On top of `Client` (identity + messaging + connect_to/send_and_wait/events),
|
||||
a Bot:
|
||||
- creates and announces its own contact address
|
||||
- auto-accepts incoming contact requests (configurable)
|
||||
- advertises a list of slash-commands in its profile preferences
|
||||
- sets `peerType=bot` and disables calls/voice in profile prefs
|
||||
- sends a `welcome` message to new contacts via the auto-reply address setting
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class Message(Generic[C]):
|
||||
chat_item: T.AChatItem
|
||||
content: C
|
||||
bot: "Bot"
|
||||
|
||||
@property
|
||||
def chat_info(self) -> T.ChatInfo:
|
||||
return self.chat_item["chatInfo"]
|
||||
|
||||
@property
|
||||
def text(self) -> str | None:
|
||||
c = self.content
|
||||
if isinstance(c, dict):
|
||||
return c.get("text") # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
async def reply(self, text: str) -> "Message[T.MsgContent]":
|
||||
items = await self.bot.api.api_send_text_reply(self.chat_item, text)
|
||||
ci = items[0]
|
||||
content = ci["chatItem"]["content"]
|
||||
# content is CIContent — snd variant has msgContent; cast for type safety.
|
||||
msg_content: T.MsgContent = content["msgContent"] # type: ignore[index]
|
||||
return Message(chat_item=ci, content=msg_content, bot=self.bot)
|
||||
|
||||
async def reply_content(self, content: T.MsgContent) -> "Message[T.MsgContent]":
|
||||
items = await self.bot.api.api_send_messages(
|
||||
self.chat_info, [{"msgContent": content, "mentions": {}}]
|
||||
)
|
||||
ci = items[0]
|
||||
ci_content = ci["chatItem"]["content"]
|
||||
msg_content: T.MsgContent = ci_content["msgContent"] # type: ignore[index]
|
||||
return Message(chat_item=ci, content=msg_content, bot=self.bot)
|
||||
|
||||
|
||||
# Concrete narrowed aliases — one per MsgContent_<tag> variant in _types.py.
|
||||
TextMessage = Message[T.MsgContent_text]
|
||||
LinkMessage = Message[T.MsgContent_link]
|
||||
ImageMessage = Message[T.MsgContent_image]
|
||||
VideoMessage = Message[T.MsgContent_video]
|
||||
VoiceMessage = Message[T.MsgContent_voice]
|
||||
FileMessage = Message[T.MsgContent_file]
|
||||
ReportMessage = Message[T.MsgContent_report]
|
||||
ChatMessage = Message[T.MsgContent_chat]
|
||||
UnknownMessage = Message[T.MsgContent_unknown]
|
||||
|
||||
MessageHandler = Callable[[Message[Any]], Awaitable[None]]
|
||||
CommandHandler = Callable[[Message[Any], ParsedCommand], Awaitable[None]]
|
||||
EventHandler = Callable[[CEvt.ChatEvent], Awaitable[None]]
|
||||
|
||||
|
||||
class Middleware:
|
||||
"""Override `__call__` to wrap message handlers with cross-cutting logic.
|
||||
|
||||
`handler` is the next stage in the chain — call it with `(message, data)`
|
||||
to continue, or skip the call to short-circuit. `data` is a per-dispatch
|
||||
dict that middleware can use to pass values down the chain.
|
||||
If you want just identity + messaging without any of that, use `Client`
|
||||
directly.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
handler: Callable[[Message[Any], dict[str, object]], Awaitable[None]],
|
||||
message: Message[Any],
|
||||
data: dict[str, object],
|
||||
) -> None:
|
||||
await handler(message, data)
|
||||
|
||||
|
||||
class Bot:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
profile: BotProfile,
|
||||
profile: Profile,
|
||||
db: Db,
|
||||
welcome: str | T.MsgContent | None = None,
|
||||
commands: list[BotCommand] | None = None,
|
||||
@@ -124,423 +66,42 @@ class Bot:
|
||||
auto_accept: bool = True,
|
||||
business_address: bool = False,
|
||||
allow_files: bool = False,
|
||||
use_bot_profile: bool = True,
|
||||
log_contacts: bool = True,
|
||||
log_network: bool = False,
|
||||
) -> None:
|
||||
self._profile = profile
|
||||
self._db = db
|
||||
super().__init__(
|
||||
profile=profile,
|
||||
db=db,
|
||||
confirm_migrations=confirm_migrations,
|
||||
update_profile=update_profile,
|
||||
log_contacts=log_contacts,
|
||||
log_network=log_network,
|
||||
)
|
||||
self._welcome = welcome
|
||||
self._commands = commands or []
|
||||
self._confirm_migrations = confirm_migrations
|
||||
self._opts = {
|
||||
"create_address": create_address,
|
||||
"update_address": update_address,
|
||||
"update_profile": update_profile,
|
||||
"auto_accept": auto_accept,
|
||||
"business_address": business_address,
|
||||
"allow_files": allow_files,
|
||||
"use_bot_profile": use_bot_profile,
|
||||
"log_contacts": log_contacts,
|
||||
"log_network": log_network,
|
||||
}
|
||||
self._api: ChatApi | None = None
|
||||
self._serving = False
|
||||
self._stop_event = asyncio.Event()
|
||||
self._message_handlers: list[tuple[Callable[[Message[Any]], bool], MessageHandler]] = []
|
||||
self._command_handlers: list[
|
||||
tuple[tuple[str, ...], Callable[[Message[Any]], bool], CommandHandler]
|
||||
] = []
|
||||
self._event_handlers: dict[str, list[EventHandler]] = {}
|
||||
self._middleware: list[Middleware] = []
|
||||
# Track default-handler registration so __aenter__ on a re-used bot
|
||||
# doesn't accumulate duplicate log/error handlers.
|
||||
self._defaults_registered = False
|
||||
|
||||
@property
|
||||
def api(self) -> ChatApi:
|
||||
if self._api is None:
|
||||
raise RuntimeError("Bot not initialized — call bot.run() or use `async with bot:`")
|
||||
return self._api
|
||||
self._create_address = create_address
|
||||
self._update_address = update_address
|
||||
self._auto_accept = auto_accept
|
||||
self._business_address = business_address
|
||||
self._allow_files = allow_files
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Decorators
|
||||
# Profile + address sync (overrides hooks in Client)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["text"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[TextMessage], Awaitable[None]]],
|
||||
Callable[[TextMessage], Awaitable[None]],
|
||||
]: ...
|
||||
async def _post_start(self, user: T.User) -> None:
|
||||
"""Bots sync address first, then embed the link in the profile."""
|
||||
link = await self._sync_address(user)
|
||||
await self._maybe_sync_profile(user, contact_link=link)
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["link"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[LinkMessage], Awaitable[None]]],
|
||||
Callable[[LinkMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["image"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[ImageMessage], Awaitable[None]]],
|
||||
Callable[[ImageMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["video"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[VideoMessage], Awaitable[None]]],
|
||||
Callable[[VideoMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["voice"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[VoiceMessage], Awaitable[None]]],
|
||||
Callable[[VoiceMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["file"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[FileMessage], Awaitable[None]]],
|
||||
Callable[[FileMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["report"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[ReportMessage], Awaitable[None]]],
|
||||
Callable[[ReportMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["chat"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[ChatMessage], Awaitable[None]]],
|
||||
Callable[[ChatMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["unknown"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[UnknownMessage], Awaitable[None]]],
|
||||
Callable[[UnknownMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(self, **rest: Any) -> Callable[[MessageHandler], MessageHandler]: ...
|
||||
|
||||
def on_message(self, **filter_kw: Any) -> Callable[[MessageHandler], MessageHandler]:
|
||||
predicate = compile_message_filter(filter_kw)
|
||||
|
||||
def deco(fn: MessageHandler) -> MessageHandler:
|
||||
self._message_handlers.append((predicate, fn))
|
||||
return fn
|
||||
|
||||
return deco
|
||||
|
||||
def on_command(
|
||||
self, name: str | tuple[str, ...], **filter_kw: Any
|
||||
) -> Callable[[CommandHandler], CommandHandler]:
|
||||
names = (name,) if isinstance(name, str) else tuple(name)
|
||||
predicate = compile_message_filter(filter_kw)
|
||||
|
||||
def deco(fn: CommandHandler) -> CommandHandler:
|
||||
self._command_handlers.append((names, predicate, fn))
|
||||
return fn
|
||||
|
||||
return deco
|
||||
|
||||
def on_event(self, event: CEvt.ChatEvent_Tag, /) -> Callable[[EventHandler], EventHandler]:
|
||||
def deco(fn: EventHandler) -> EventHandler:
|
||||
self._event_handlers.setdefault(event, []).append(fn)
|
||||
return fn
|
||||
|
||||
return deco
|
||||
|
||||
def use(self, middleware: Middleware) -> None:
|
||||
self._middleware.append(middleware)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def __aenter__(self) -> "Bot":
|
||||
# Order matters: libsimplex `/_start` requires an active user, so
|
||||
# ensure (or create) the user first, THEN start the chat, THEN
|
||||
# do address + profile sync. Mirrors Node bot.ts:48-64.
|
||||
self._api = await ChatApi.init(self._db, self._confirm_migrations)
|
||||
user = await self._ensure_active_user()
|
||||
await self._api.start_chat()
|
||||
await self._sync_address_and_profile(user)
|
||||
self._register_log_handlers()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_info: object) -> None:
|
||||
self.stop()
|
||||
if self._api is not None:
|
||||
try:
|
||||
await self._api.stop_chat()
|
||||
finally:
|
||||
await self._api.close()
|
||||
self._api = None
|
||||
|
||||
def run(self) -> None:
|
||||
"""Blocking entry: runs serve_forever() with SIGINT/SIGTERM handlers installed.
|
||||
|
||||
Configures `logging.basicConfig(level=INFO)` if the root logger has no
|
||||
handlers yet, so the bot's startup messages and the announced address
|
||||
are visible without callers having to set up logging. Embedders that
|
||||
manage logging themselves are unaffected (basicConfig is a no-op when
|
||||
handlers already exist).
|
||||
"""
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
async def _main() -> None:
|
||||
async with self:
|
||||
loop = asyncio.get_running_loop()
|
||||
# First Ctrl+C → graceful stop (~500ms, bounded by the
|
||||
# receive-loop poll interval). Second Ctrl+C → force-exit
|
||||
# immediately (in case stop_chat / close hang on a wedged
|
||||
# FFI call). Standard CLI UX (jupyter, ipython, …).
|
||||
sigint_count = 0
|
||||
|
||||
def on_interrupt() -> None:
|
||||
nonlocal sigint_count
|
||||
sigint_count += 1
|
||||
if sigint_count == 1:
|
||||
log.info("stopping bot... (press Ctrl+C again to force exit)")
|
||||
self.stop()
|
||||
else:
|
||||
os._exit(130) # 128 + SIGINT
|
||||
|
||||
if hasattr(_signal, "SIGINT"):
|
||||
try:
|
||||
loop.add_signal_handler(_signal.SIGINT, on_interrupt)
|
||||
loop.add_signal_handler(_signal.SIGTERM, self.stop)
|
||||
except NotImplementedError: # Windows
|
||||
_signal.signal(_signal.SIGINT, lambda *_: on_interrupt())
|
||||
await self.serve_forever()
|
||||
|
||||
asyncio.run(_main())
|
||||
|
||||
async def serve_forever(self) -> None:
|
||||
if self._serving:
|
||||
raise RuntimeError("already serving")
|
||||
self._serving = True
|
||||
self._stop_event.clear()
|
||||
try:
|
||||
await self._receive_loop()
|
||||
finally:
|
||||
self._serving = False
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
# Catch broad Exception so a single malformed event or transient
|
||||
# native error doesn't crash the whole bot. CancelledError must
|
||||
# always re-raise so `bot.stop()` and asyncio cancellation work.
|
||||
# `wait_us=500_000` (500ms) bounds the worst-case Ctrl+C latency:
|
||||
# the C call blocks the worker thread until timeout, and the loop
|
||||
# only checks `_stop_event` between polls.
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
event = await self.api.recv_chat_event(wait_us=500_000)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ChatAPIError as e:
|
||||
# Async chat errors emitted via the Haskell `eToView` path —
|
||||
# routine soft errors (stale connections after a peer deletes
|
||||
# a chat, file cleanup failures, etc.) intermixed with
|
||||
# CRITICAL agent failures the operator must see. Mirror the
|
||||
# desktop policy in SimpleXAPI.kt:3332-3340: escalate
|
||||
# CRITICAL agent errors, keep everything else at debug.
|
||||
chat_err: Any = e.chat_error or {}
|
||||
agent_err: Any = (
|
||||
chat_err.get("agentError", {}) if chat_err.get("type") == "errorAgent" else {}
|
||||
)
|
||||
if agent_err.get("type") == "CRITICAL":
|
||||
log.error(
|
||||
"chat agent CRITICAL: %s (offerRestart=%s)",
|
||||
agent_err.get("criticalErr"),
|
||||
agent_err.get("offerRestart"),
|
||||
)
|
||||
else:
|
||||
log.debug("chat event error: %s", chat_err.get("type"))
|
||||
continue
|
||||
except Exception:
|
||||
log.exception("recv_chat_event failed")
|
||||
# Bound the spin rate when the FFI is wedged on a persistent
|
||||
# error (vs the timeout path, which already paces itself).
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
if event is None:
|
||||
continue
|
||||
try:
|
||||
await self._dispatch_event(event)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
log.exception("dispatch_event failed for tag=%s", event.get("type"))
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Dispatch
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _dispatch_event(self, event: CEvt.ChatEvent) -> None:
|
||||
tag = event["type"]
|
||||
for h in self._event_handlers.get(tag, []):
|
||||
try:
|
||||
await h(event)
|
||||
except Exception:
|
||||
log.exception("on_event handler failed")
|
||||
if tag == "newChatItems":
|
||||
evt: CEvt.NewChatItems = event # type: ignore[assignment]
|
||||
for ci in evt["chatItems"]:
|
||||
content = ci["chatItem"]["content"]
|
||||
if content["type"] != "rcvMsgContent":
|
||||
continue
|
||||
msg_content = content["msgContent"] # type: ignore[index]
|
||||
msg: Message[T.MsgContent] = Message(chat_item=ci, content=msg_content, bot=self)
|
||||
await self._dispatch_message(msg)
|
||||
|
||||
async def _dispatch_message(self, msg: Message[Any]) -> None:
|
||||
# First-match-wins. The squaring bot's `@on_message(text=NUMBER_RE)`
|
||||
# and catch-all `@on_message(content_type="text")` both match a number
|
||||
# like "1"; we want only the first to fire. Registration order is the
|
||||
# priority order — register the most-specific filters first.
|
||||
#
|
||||
# Slash-commands are tried first against command handlers; if no
|
||||
# command handler matches, fall through to message handlers (so
|
||||
# `@on_message` can still catch unknown slash-commands).
|
||||
cmd = self._parse_command(msg)
|
||||
if cmd is not None:
|
||||
for names, predicate, handler in self._command_handlers:
|
||||
if cmd.keyword in names and predicate(msg):
|
||||
await self._invoke_command_with_middleware(handler, msg, cmd)
|
||||
return
|
||||
for predicate, handler in self._message_handlers:
|
||||
if predicate(msg):
|
||||
await self._invoke_with_middleware(handler, msg)
|
||||
return
|
||||
|
||||
async def _invoke_with_middleware(self, handler: MessageHandler, message: Message[Any]) -> None:
|
||||
# Fast path: most bots register no middleware. Skip the closure-chain
|
||||
# construction and the empty-data dict on every dispatch.
|
||||
if not self._middleware:
|
||||
try:
|
||||
await handler(message)
|
||||
except Exception:
|
||||
log.exception("message handler failed")
|
||||
return
|
||||
|
||||
async def call(m: Message[Any], _data: dict[str, object]) -> None:
|
||||
await handler(m)
|
||||
|
||||
chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call
|
||||
for mw in reversed(self._middleware):
|
||||
inner = chain
|
||||
|
||||
async def _wrapped(
|
||||
m: Message[Any],
|
||||
d: dict[str, object],
|
||||
mw: Middleware = mw,
|
||||
inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner,
|
||||
) -> None:
|
||||
await mw(inner, m, d)
|
||||
|
||||
chain = _wrapped
|
||||
|
||||
try:
|
||||
await chain(message, {})
|
||||
except Exception:
|
||||
log.exception("message handler failed")
|
||||
|
||||
async def _invoke_command_with_middleware(
|
||||
self, handler: CommandHandler, message: Message[Any], cmd: ParsedCommand
|
||||
) -> None:
|
||||
if not self._middleware:
|
||||
try:
|
||||
await handler(message, cmd)
|
||||
except Exception:
|
||||
log.exception("command handler failed")
|
||||
return
|
||||
|
||||
async def call(m: Message[Any], _data: dict[str, object]) -> None:
|
||||
await handler(m, cmd)
|
||||
|
||||
chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call
|
||||
for mw in reversed(self._middleware):
|
||||
inner = chain
|
||||
|
||||
async def _wrapped(
|
||||
m: Message[Any],
|
||||
d: dict[str, object],
|
||||
mw: Middleware = mw,
|
||||
inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner,
|
||||
) -> None:
|
||||
await mw(inner, m, d)
|
||||
|
||||
chain = _wrapped
|
||||
|
||||
try:
|
||||
await chain(message, {})
|
||||
except Exception:
|
||||
log.exception("command handler failed")
|
||||
|
||||
@staticmethod
|
||||
def _parse_command(msg: Message[Any]) -> ParsedCommand | None:
|
||||
parsed = util.ci_bot_command(msg.chat_item["chatItem"])
|
||||
if parsed is None:
|
||||
return None
|
||||
keyword, args = parsed
|
||||
return ParsedCommand(keyword=keyword, args=args)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Profile + address sync
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _ensure_active_user(self) -> T.User:
|
||||
"""Get or create the active user. Must run before `start_chat`.
|
||||
|
||||
Mirrors Node `createBotUser` (bot.ts:158-166). The chat controller
|
||||
won't accept `/_start` without a user, so this phase has to land
|
||||
before lifecycle proceeds.
|
||||
"""
|
||||
api = self.api
|
||||
user = await api.api_get_active_user()
|
||||
if user is None:
|
||||
log.info("No active user in database, creating...")
|
||||
user = await api.api_create_active_user(self._bot_profile_to_wire())
|
||||
log.info("Bot user: %s", user["profile"]["displayName"])
|
||||
return user
|
||||
|
||||
async def _sync_address_and_profile(self, user: T.User) -> None:
|
||||
"""Address + profile sync. Runs after `start_chat` (mirrors bot.ts:57-63)."""
|
||||
async def _sync_address(self, user: T.User) -> str | None:
|
||||
"""Address sync. Returns the public link if any, for embedding in the profile."""
|
||||
api = self.api
|
||||
user_id = user["userId"]
|
||||
|
||||
# 2. Address (numbered to match bot.ts comments — phase 1 was user creation).
|
||||
address = await api.api_get_user_address(user_id)
|
||||
if address is None:
|
||||
if self._opts["create_address"]:
|
||||
if self._create_address:
|
||||
log.info("Bot has no address, creating...")
|
||||
await api.api_create_user_address(user_id)
|
||||
address = await api.api_get_user_address(user_id)
|
||||
@@ -549,17 +110,16 @@ class Bot:
|
||||
else:
|
||||
log.warning("Bot has no address")
|
||||
|
||||
# Always announce the address — matches Node bot.ts:60.
|
||||
link: str | None = None
|
||||
if address is not None:
|
||||
link = util.contact_address_str(address["connLinkContact"])
|
||||
log.info("Bot address: %s", link)
|
||||
|
||||
# 3. Address settings (auto-accept + welcome message). Mirrors bot.ts:185-194.
|
||||
# Address settings (auto-accept + welcome message). Mirrors bot.ts:185-194.
|
||||
# autoAccept present → accept; absent → no auto-accept (mirrors Node bot.ts).
|
||||
if address is not None and self._opts["update_address"]:
|
||||
desired: T.AddressSettings = {"businessAddress": self._opts["business_address"]}
|
||||
if self._opts["auto_accept"]:
|
||||
if address is not None and self._update_address:
|
||||
desired: T.AddressSettings = {"businessAddress": self._business_address}
|
||||
if self._auto_accept:
|
||||
desired["autoAccept"] = {"acceptIncognito": False}
|
||||
if self._welcome is not None:
|
||||
desired["autoReply"] = (
|
||||
@@ -571,154 +131,45 @@ class Bot:
|
||||
log.info("Bot address settings changed, updating...")
|
||||
await api.api_set_address_settings(user_id, desired)
|
||||
|
||||
# 4. Profile update. Mirrors Node `updateBotUserProfile` (bot.ts:199-214).
|
||||
# Field-by-field comparison: user["profile"] is LocalProfile (has extra
|
||||
# fields profileId, localAlias, preferences, peerType) so a full-dict
|
||||
# equality would always differ.
|
||||
new_profile = self._bot_profile_to_wire()
|
||||
if link is not None and self._opts["use_bot_profile"]:
|
||||
# Mirrors bot.ts:62 — embed the connection link in the bot's profile
|
||||
# so contacts that resolve the bot via stored profile data see the
|
||||
# current address.
|
||||
new_profile["contactLink"] = link
|
||||
cur = user["profile"]
|
||||
changed = (
|
||||
cur["displayName"] != new_profile["displayName"]
|
||||
or cur.get("fullName", "") != new_profile.get("fullName", "")
|
||||
or cur.get("shortDescr") != new_profile.get("shortDescr")
|
||||
or cur.get("image") != new_profile.get("image")
|
||||
or cur.get("preferences") != new_profile.get("preferences")
|
||||
or cur.get("peerType") != new_profile.get("peerType")
|
||||
or cur.get("contactLink") != new_profile.get("contactLink")
|
||||
)
|
||||
if changed and self._opts["update_profile"]:
|
||||
log.info("Bot profile changed, updating...")
|
||||
await api.api_update_profile(user_id, new_profile)
|
||||
return link
|
||||
|
||||
def _bot_profile_to_wire(self) -> T.Profile:
|
||||
"""Construct wire-format Profile, applying bot conventions when use_bot_profile=True.
|
||||
def _profile_to_wire(self) -> T.Profile:
|
||||
"""Bot profile: base profile + peerType=bot, command list, calls/voice prefs disabled.
|
||||
|
||||
Mirrors Node mkBotProfile (bot.ts:88-102): bots get peerType="bot",
|
||||
calls/voice prefs disabled, files gated on `allow_files`, and any
|
||||
registered `commands` embedded in the profile preferences.
|
||||
Mirrors Node `mkBotProfile` (bot.ts:88-102).
|
||||
"""
|
||||
p: T.Profile = {
|
||||
"displayName": self._profile.display_name,
|
||||
"fullName": self._profile.full_name,
|
||||
p = super()._profile_to_wire()
|
||||
prefs: T.Preferences = {
|
||||
"calls": {"allow": "no"},
|
||||
"voice": {"allow": "no"},
|
||||
"files": {"allow": "yes" if self._allow_files else "no"},
|
||||
}
|
||||
if self._profile.short_descr is not None:
|
||||
p["shortDescr"] = self._profile.short_descr
|
||||
if self._profile.image is not None:
|
||||
p["image"] = self._profile.image
|
||||
if self._opts["use_bot_profile"]:
|
||||
prefs: T.Preferences = {
|
||||
"calls": {"allow": "no"},
|
||||
"voice": {"allow": "no"},
|
||||
"files": {"allow": "yes" if self._opts["allow_files"] else "no"},
|
||||
}
|
||||
if self._commands:
|
||||
prefs["commands"] = [
|
||||
{"type": "command", "keyword": c.keyword, "label": c.label}
|
||||
for c in self._commands
|
||||
]
|
||||
p["preferences"] = prefs
|
||||
p["peerType"] = "bot"
|
||||
elif self._commands:
|
||||
raise ValueError(
|
||||
"use_bot_profile=False but commands were passed; commands are "
|
||||
"only sent when use_bot_profile=True (they're embedded in the "
|
||||
"user profile preferences)."
|
||||
)
|
||||
if self._commands:
|
||||
prefs["commands"] = [
|
||||
{"type": "command", "keyword": c.keyword, "label": c.label}
|
||||
for c in self._commands
|
||||
]
|
||||
p["preferences"] = prefs
|
||||
p["peerType"] = "bot"
|
||||
return p
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Log subscriptions (mirror Node subscribeLogEvents bot.ts:142-156)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _register_log_handlers(self) -> None:
|
||||
# Idempotent: a Bot reused across multiple `__aenter__` cycles must
|
||||
# not stack duplicate log handlers. Always-on error handlers run
|
||||
# regardless of log_contacts/log_network so messageError/chatError/
|
||||
# chatErrors don't disappear into the void.
|
||||
if self._defaults_registered:
|
||||
return
|
||||
self._defaults_registered = True
|
||||
self._event_handlers.setdefault("messageError", []).append(self._log_message_error)
|
||||
self._event_handlers.setdefault("chatError", []).append(self._log_chat_error)
|
||||
self._event_handlers.setdefault("chatErrors", []).append(self._log_chat_errors)
|
||||
if self._opts["log_contacts"]:
|
||||
self._event_handlers.setdefault("contactConnected", []).append(
|
||||
self._log_contact_connected
|
||||
)
|
||||
self._event_handlers.setdefault("contactDeletedByContact", []).append(
|
||||
self._log_contact_deleted
|
||||
)
|
||||
if self._opts["log_network"]:
|
||||
self._event_handlers.setdefault("hostConnected", []).append(self._log_host_connected)
|
||||
self._event_handlers.setdefault("hostDisconnected", []).append(
|
||||
self._log_host_disconnected
|
||||
)
|
||||
self._event_handlers.setdefault("subscriptionStatus", []).append(
|
||||
self._log_subscription_status
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _log_contact_connected(evt: CEvt.ChatEvent) -> None:
|
||||
log.info("%s connected", evt["contact"]["profile"]["displayName"]) # type: ignore[index]
|
||||
|
||||
@staticmethod
|
||||
async def _log_contact_deleted(evt: CEvt.ChatEvent) -> None:
|
||||
log.info(
|
||||
"%s deleted connection with bot",
|
||||
evt["contact"]["profile"]["displayName"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _log_host_connected(evt: CEvt.ChatEvent) -> None:
|
||||
log.info("connected server %s", evt["transportHost"]) # type: ignore[index]
|
||||
|
||||
@staticmethod
|
||||
async def _log_host_disconnected(evt: CEvt.ChatEvent) -> None:
|
||||
log.info("disconnected server %s", evt["transportHost"]) # type: ignore[index]
|
||||
|
||||
@staticmethod
|
||||
async def _log_subscription_status(evt: CEvt.ChatEvent) -> None:
|
||||
log.info(
|
||||
"%d subscription(s) %s",
|
||||
len(evt["connections"]), # type: ignore[index]
|
||||
evt["subscriptionStatus"]["type"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _log_message_error(evt: CEvt.ChatEvent) -> None:
|
||||
log.warning("messageError: %s", evt.get("severity", "?")) # type: ignore[union-attr]
|
||||
|
||||
@staticmethod
|
||||
async def _log_chat_error(evt: CEvt.ChatEvent) -> None:
|
||||
err = evt.get("chatError") # type: ignore[union-attr]
|
||||
log.error("chatError: %s", err.get("type") if isinstance(err, dict) else err)
|
||||
|
||||
@staticmethod
|
||||
async def _log_chat_errors(evt: CEvt.ChatEvent) -> None:
|
||||
errs = evt.get("chatErrors") or [] # type: ignore[union-attr]
|
||||
log.error("chatErrors: %d errors", len(errs))
|
||||
|
||||
|
||||
# Suppress unused-import warnings for re-exported names used only at type-check time.
|
||||
__all__ = [
|
||||
"Bot",
|
||||
"BotCommand",
|
||||
"BotProfile",
|
||||
"ChatMessage",
|
||||
"Client",
|
||||
"CommandHandler",
|
||||
"EventHandler",
|
||||
"FileMessage",
|
||||
"ImageMessage",
|
||||
"LinkMessage",
|
||||
"Message",
|
||||
"MessageHandler",
|
||||
"CommandHandler",
|
||||
"EventHandler",
|
||||
"Middleware",
|
||||
"ParsedCommand",
|
||||
"Profile",
|
||||
"ReportMessage",
|
||||
"TextMessage",
|
||||
"UnknownMessage",
|
||||
|
||||
@@ -0,0 +1,955 @@
|
||||
"""Base `Client` API: lifecycle, dispatch, decorators, connect_to / send_and_wait / events.
|
||||
|
||||
Bot extends Client to add server-side features (address, auto-accept, welcome,
|
||||
commands). Client by itself is suitable for monitors, probes, automated
|
||||
participants — anything that talks TO services rather than accepting incoming
|
||||
connections.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal as _signal
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from . import util
|
||||
from .api import ChatApi, ChatCommandError, ContactAlreadyExistsError, Db
|
||||
from .core import ChatAPIError, MigrationConfirmation
|
||||
from .filters import compile_message_filter
|
||||
from .types import CEvt, T
|
||||
|
||||
log = logging.getLogger("simplex_chat")
|
||||
|
||||
C = TypeVar("C", bound="T.MsgContent")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Profile:
|
||||
"""SimpleX user profile fields: display name, optional full name, descr, avatar.
|
||||
|
||||
Universal — used by both `Client` and `Bot`. The bot-specific extensions
|
||||
(peerType=bot, command list, calls/voice preferences) are added at
|
||||
wire-conversion time by `Bot`, not stored here.
|
||||
"""
|
||||
|
||||
display_name: str
|
||||
full_name: str = ""
|
||||
short_descr: str | None = None
|
||||
image: str | None = None
|
||||
|
||||
|
||||
# Backwards-compatibility alias — the dataclass was named `BotProfile` before
|
||||
# the Client/Bot hierarchy was introduced. Keep the old name working so
|
||||
# `from simplex_chat import BotProfile` doesn't break existing code.
|
||||
BotProfile = Profile
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class ParsedCommand:
|
||||
keyword: str
|
||||
args: str
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class Message(Generic[C]):
|
||||
chat_item: T.AChatItem
|
||||
content: C
|
||||
client: "Client"
|
||||
|
||||
@property
|
||||
def chat_info(self) -> T.ChatInfo:
|
||||
return self.chat_item["chatInfo"]
|
||||
|
||||
@property
|
||||
def text(self) -> str | None:
|
||||
c = self.content
|
||||
if isinstance(c, dict):
|
||||
return c.get("text") # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
async def reply(self, text: str) -> "Message[T.MsgContent]":
|
||||
items = await self.client.api.api_send_text_reply(self.chat_item, text)
|
||||
ci = items[0]
|
||||
content = ci["chatItem"]["content"]
|
||||
# content is CIContent — snd variant has msgContent; cast for type safety.
|
||||
msg_content: T.MsgContent = content["msgContent"] # type: ignore[index]
|
||||
return Message(chat_item=ci, content=msg_content, client=self.client)
|
||||
|
||||
async def reply_content(self, content: T.MsgContent) -> "Message[T.MsgContent]":
|
||||
items = await self.client.api.api_send_messages(
|
||||
self.chat_info, [{"msgContent": content, "mentions": {}}]
|
||||
)
|
||||
ci = items[0]
|
||||
ci_content = ci["chatItem"]["content"]
|
||||
msg_content: T.MsgContent = ci_content["msgContent"] # type: ignore[index]
|
||||
return Message(chat_item=ci, content=msg_content, client=self.client)
|
||||
|
||||
|
||||
# Concrete narrowed aliases — one per MsgContent_<tag> variant in _types.py.
|
||||
TextMessage = Message[T.MsgContent_text]
|
||||
LinkMessage = Message[T.MsgContent_link]
|
||||
ImageMessage = Message[T.MsgContent_image]
|
||||
VideoMessage = Message[T.MsgContent_video]
|
||||
VoiceMessage = Message[T.MsgContent_voice]
|
||||
FileMessage = Message[T.MsgContent_file]
|
||||
ReportMessage = Message[T.MsgContent_report]
|
||||
ChatMessage = Message[T.MsgContent_chat]
|
||||
UnknownMessage = Message[T.MsgContent_unknown]
|
||||
|
||||
MessageHandler = Callable[[Message[Any]], Awaitable[None]]
|
||||
CommandHandler = Callable[[Message[Any], ParsedCommand], Awaitable[None]]
|
||||
EventHandler = Callable[[CEvt.ChatEvent], Awaitable[None]]
|
||||
|
||||
|
||||
class Middleware:
|
||||
"""Override `__call__` to wrap message handlers with cross-cutting logic.
|
||||
|
||||
`handler` is the next stage in the chain — call it with `(message, data)`
|
||||
to continue, or skip the call to short-circuit. `data` is a per-dispatch
|
||||
dict that middleware can use to pass values down the chain.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
handler: Callable[[Message[Any], dict[str, object]], Awaitable[None]],
|
||||
message: Message[Any],
|
||||
data: dict[str, object],
|
||||
) -> None:
|
||||
await handler(message, data)
|
||||
|
||||
|
||||
class Client:
|
||||
"""SimpleX participant — has an identity, sends and receives messages.
|
||||
|
||||
No address, no auto-accept of incoming requests, no bot profile prefs. Use
|
||||
this for monitors, probes, automated participants — anything that talks
|
||||
TO services rather than accepting incoming connections. Use `Bot` for the
|
||||
server-side flavour.
|
||||
|
||||
Typical pattern:
|
||||
|
||||
async with Client(profile=Profile(display_name="m"), db=...) as c:
|
||||
serve = asyncio.create_task(c.serve_forever())
|
||||
contact = await c.connect_to(link)
|
||||
reply = await c.send_and_wait(contact["contactId"], "/help")
|
||||
c.stop()
|
||||
await serve
|
||||
|
||||
The decorator-style handlers (`@on_message`, `@on_command`, `@on_event`)
|
||||
work too if you want callback-style dispatch instead of async-await.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
profile: Profile,
|
||||
db: Db,
|
||||
confirm_migrations: MigrationConfirmation = MigrationConfirmation.YES_UP,
|
||||
update_profile: bool = True,
|
||||
log_contacts: bool = False,
|
||||
log_network: bool = False,
|
||||
) -> None:
|
||||
self._profile = profile
|
||||
self._db = db
|
||||
self._confirm_migrations = confirm_migrations
|
||||
self._update_profile = update_profile
|
||||
self._log_contacts = log_contacts
|
||||
self._log_network = log_network
|
||||
self._api: ChatApi | None = None
|
||||
self._serving = False
|
||||
self._stop_event = asyncio.Event()
|
||||
self._message_handlers: list[tuple[Callable[[Message[Any]], bool], MessageHandler]] = []
|
||||
self._command_handlers: list[
|
||||
tuple[tuple[str, ...], Callable[[Message[Any]], bool], CommandHandler]
|
||||
] = []
|
||||
self._event_handlers: dict[str, list[EventHandler]] = {}
|
||||
self._middleware: list[Middleware] = []
|
||||
# Track default-handler registration so __aenter__ on a re-used client
|
||||
# doesn't accumulate duplicate log/error handlers.
|
||||
self._defaults_registered = False
|
||||
# Internal waiters used by `send_and_wait` (keyed by contact_id, FIFO
|
||||
# within a contact) and `connect_to` (one-shot, resolved on the next
|
||||
# contactConnected event). Populated by user-async-callers, drained
|
||||
# in `_dispatch_event` before user handlers run.
|
||||
self._reply_waiters: dict[int, list[asyncio.Future[Message[Any]]]] = {}
|
||||
self._connect_waiters: list[asyncio.Future[T.Contact]] = []
|
||||
|
||||
@property
|
||||
def api(self) -> ChatApi:
|
||||
if self._api is None:
|
||||
raise RuntimeError("Client not initialized — call run() or use `async with client:`")
|
||||
return self._api
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Decorators
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["text"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[TextMessage], Awaitable[None]]],
|
||||
Callable[[TextMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["link"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[LinkMessage], Awaitable[None]]],
|
||||
Callable[[LinkMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["image"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[ImageMessage], Awaitable[None]]],
|
||||
Callable[[ImageMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["video"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[VideoMessage], Awaitable[None]]],
|
||||
Callable[[VideoMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["voice"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[VoiceMessage], Awaitable[None]]],
|
||||
Callable[[VoiceMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["file"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[FileMessage], Awaitable[None]]],
|
||||
Callable[[FileMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["report"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[ReportMessage], Awaitable[None]]],
|
||||
Callable[[ReportMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["chat"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[ChatMessage], Awaitable[None]]],
|
||||
Callable[[ChatMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(
|
||||
self, *, content_type: Literal["unknown"], **rest: Any
|
||||
) -> Callable[
|
||||
[Callable[[UnknownMessage], Awaitable[None]]],
|
||||
Callable[[UnknownMessage], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def on_message(self, **rest: Any) -> Callable[[MessageHandler], MessageHandler]: ...
|
||||
|
||||
def on_message(self, **filter_kw: Any) -> Callable[[MessageHandler], MessageHandler]:
|
||||
predicate = compile_message_filter(filter_kw)
|
||||
|
||||
def deco(fn: MessageHandler) -> MessageHandler:
|
||||
self._message_handlers.append((predicate, fn))
|
||||
return fn
|
||||
|
||||
return deco
|
||||
|
||||
def on_command(
|
||||
self, name: str | tuple[str, ...], **filter_kw: Any
|
||||
) -> Callable[[CommandHandler], CommandHandler]:
|
||||
names = (name,) if isinstance(name, str) else tuple(name)
|
||||
predicate = compile_message_filter(filter_kw)
|
||||
|
||||
def deco(fn: CommandHandler) -> CommandHandler:
|
||||
self._command_handlers.append((names, predicate, fn))
|
||||
return fn
|
||||
|
||||
return deco
|
||||
|
||||
# `on_event` is exposed as a property typed as the generated
|
||||
# `OnEventDecorator` Protocol so per-tag narrowing applies — e.g.
|
||||
# `@client.on_event("contactConnected")` types the handler's event
|
||||
# parameter as `CEvt.ContactConnected`, not the unnarrowed
|
||||
# `CEvt.ChatEvent` union. The Protocol's overload chain lives in
|
||||
# generated code (one entry per event tag) so it stays in sync with
|
||||
# the wire schema automatically. The runtime implementation is the
|
||||
# plain handler-registration below.
|
||||
@property
|
||||
def on_event(self) -> CEvt.OnEventDecorator:
|
||||
return self._register_event_handler # type: ignore[return-value]
|
||||
|
||||
def _register_event_handler(
|
||||
self, event: str, /
|
||||
) -> Callable[[EventHandler], EventHandler]:
|
||||
def deco(fn: EventHandler) -> EventHandler:
|
||||
self._event_handlers.setdefault(event, []).append(fn)
|
||||
return fn
|
||||
|
||||
return deco
|
||||
|
||||
def use(self, middleware: Middleware) -> None:
|
||||
self._middleware.append(middleware)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def __aenter__(self) -> "Client":
|
||||
# Order matters: libsimplex `/_start` requires an active user, so
|
||||
# ensure (or create) the user first, THEN start the chat, THEN
|
||||
# do post-start setup (profile sync; Bot adds address sync).
|
||||
# Clear `_stop_event` here (not in `serve_forever`/`events`) so that
|
||||
# a `stop()` call landing between `__aenter__` and the receive loop
|
||||
# — e.g. a signal handler firing while signal handlers are being
|
||||
# wired up — is preserved and causes the loop to exit immediately
|
||||
# on entry.
|
||||
self._stop_event.clear()
|
||||
self._api = await ChatApi.init(self._db, self._confirm_migrations)
|
||||
try:
|
||||
user = await self._ensure_active_user()
|
||||
await self._api.start_chat()
|
||||
await self._post_start(user)
|
||||
self._register_log_handlers()
|
||||
return self
|
||||
except BaseException:
|
||||
# __aexit__ is only called when __aenter__ returns successfully —
|
||||
# roll back the open chat controller here so a failure during
|
||||
# init doesn't leak the FFI resource.
|
||||
await self._shutdown_partial_init()
|
||||
raise
|
||||
|
||||
async def _shutdown_partial_init(self) -> None:
|
||||
"""Best-effort teardown for an `__aenter__` that didn't reach return."""
|
||||
api = self._api
|
||||
if api is None:
|
||||
return
|
||||
if api.started:
|
||||
try:
|
||||
await api.stop_chat()
|
||||
except Exception:
|
||||
log.exception("stop_chat failed during init rollback")
|
||||
try:
|
||||
await api.close()
|
||||
except Exception:
|
||||
log.exception("close failed during init rollback")
|
||||
self._api = None
|
||||
|
||||
async def __aexit__(self, *exc_info: object) -> None:
|
||||
self.stop()
|
||||
api = self._api
|
||||
if api is None:
|
||||
return
|
||||
# Null out the reference up-front so the Client appears closed even
|
||||
# if stop_chat / close raise — otherwise `client.api` would still
|
||||
# hand back a half-shutdown controller after `async with` exits.
|
||||
self._api = None
|
||||
try:
|
||||
await api.stop_chat()
|
||||
finally:
|
||||
await api.close()
|
||||
|
||||
async def _post_start(self, user: T.User) -> None:
|
||||
"""Hook for subclasses to add work between `start_chat` and serving.
|
||||
|
||||
Default (Client): sync profile only. Bot overrides to also sync its
|
||||
address and embed the connection link in the profile.
|
||||
"""
|
||||
await self._maybe_sync_profile(user, contact_link=None)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Blocking entry: runs serve_forever() with SIGINT/SIGTERM handlers installed.
|
||||
|
||||
Configures `logging.basicConfig(level=INFO)` if the root logger has no
|
||||
handlers yet, so startup messages and the announced address are
|
||||
visible without callers having to set up logging. Embedders that
|
||||
manage logging themselves are unaffected (basicConfig is a no-op when
|
||||
handlers already exist).
|
||||
"""
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
async def _main() -> None:
|
||||
async with self:
|
||||
loop = asyncio.get_running_loop()
|
||||
# First Ctrl+C → graceful stop (~500ms, bounded by the
|
||||
# receive-loop poll interval). Second Ctrl+C → force-exit
|
||||
# immediately (in case stop_chat / close hang on a wedged
|
||||
# FFI call). Standard CLI UX (jupyter, ipython, …).
|
||||
sigint_count = 0
|
||||
|
||||
def on_interrupt() -> None:
|
||||
nonlocal sigint_count
|
||||
sigint_count += 1
|
||||
if sigint_count == 1:
|
||||
log.info("stopping... (press Ctrl+C again to force exit)")
|
||||
self.stop()
|
||||
else:
|
||||
os._exit(130) # 128 + SIGINT
|
||||
|
||||
if hasattr(_signal, "SIGINT"):
|
||||
try:
|
||||
loop.add_signal_handler(_signal.SIGINT, on_interrupt)
|
||||
loop.add_signal_handler(_signal.SIGTERM, self.stop)
|
||||
except NotImplementedError: # Windows
|
||||
_signal.signal(_signal.SIGINT, lambda *_: on_interrupt())
|
||||
await self.serve_forever()
|
||||
|
||||
asyncio.run(_main())
|
||||
|
||||
async def serve_forever(self) -> None:
|
||||
if self._serving:
|
||||
raise RuntimeError("already serving")
|
||||
self._serving = True
|
||||
try:
|
||||
await self._receive_loop()
|
||||
finally:
|
||||
self._serving = False
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
|
||||
async def events(self) -> AsyncIterator[CEvt.ChatEvent]:
|
||||
"""Yield chat events one at a time — alternative to `serve_forever`.
|
||||
|
||||
Runs the full dispatch pipeline on each event (internal waiters,
|
||||
user `@on_event`/`@on_message`/`@on_command` handlers), then yields
|
||||
the raw event for inspection. Use this when you want direct control
|
||||
over the receive loop, e.g. to surface errors that `serve_forever`
|
||||
would otherwise swallow, or to compose with other async iterators.
|
||||
|
||||
Mutually exclusive with `serve_forever`. Stops when `stop()` is
|
||||
called or when the consumer exits the `async for` loop (which
|
||||
triggers the generator's `aclose`). Async-generator GC alone is
|
||||
not reliable for cleanup — exit the loop explicitly.
|
||||
"""
|
||||
if self._serving:
|
||||
raise RuntimeError(
|
||||
"already serving — events() and serve_forever() are mutually exclusive"
|
||||
)
|
||||
self._serving = True
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
event = await self.api.recv_chat_event(wait_us=500_000)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
if event is None:
|
||||
continue
|
||||
try:
|
||||
await self._dispatch_event(event)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
log.exception("dispatch_event failed for tag=%s", event.get("type"))
|
||||
yield event
|
||||
finally:
|
||||
self._serving = False
|
||||
|
||||
async def connect_to(self, link: str, *, timeout: float = 120.0) -> T.Contact:
|
||||
"""Connect to a SimpleX contact link, returning the resulting Contact.
|
||||
|
||||
Idempotent: if the link is already known (via `api_connect_plan`)
|
||||
the existing Contact is returned without re-handshaking. Otherwise
|
||||
initiates the handshake and waits for the `contactConnected` event.
|
||||
|
||||
Requires the receive loop to be running (`serve_forever` or
|
||||
`events()`), since the handshake completes asynchronously.
|
||||
|
||||
Concurrency caveat: pending `connect_to` waiters are a single FIFO
|
||||
with no link↔waiter correlation. If you call `connect_to` for two
|
||||
different links concurrently, or if a third party connects to your
|
||||
address (Bot subclass with `auto_accept=True`) while a `connect_to`
|
||||
is in flight, the returned Contact may not be the one you asked
|
||||
for. Sequence concurrent connects, or call them one at a time.
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: handshake didn't complete within `timeout`
|
||||
ValueError: timeout is not positive
|
||||
RuntimeError: no active user, or receive loop not running
|
||||
"""
|
||||
if timeout <= 0:
|
||||
# Reject upfront — otherwise wait_for raises TimeoutError after
|
||||
# the handshake side-effect (api_connect_active_user) has
|
||||
# already gone over the wire, leaving the caller with no
|
||||
# Contact reference and a half-initiated connection.
|
||||
raise ValueError(f"timeout must be positive, got {timeout!r}")
|
||||
if not self._serving:
|
||||
raise RuntimeError(
|
||||
"connect_to requires the receive loop to be running — "
|
||||
"call serve_forever() (in a task) or iterate events() first"
|
||||
)
|
||||
api = self.api
|
||||
user = await api.api_get_active_user()
|
||||
if user is None:
|
||||
raise RuntimeError("no active user")
|
||||
|
||||
existing = await self._lookup_known_contact(user["userId"], link)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
waiter: asyncio.Future[T.Contact] = loop.create_future()
|
||||
self._connect_waiters.append(waiter)
|
||||
try:
|
||||
try:
|
||||
await api.api_connect_active_user(link)
|
||||
except ContactAlreadyExistsError:
|
||||
# Handshake mid-flight, or a previous incomplete attempt
|
||||
# left the connection in a known-but-not-connected state.
|
||||
# Either way: wait for the contactConnected event.
|
||||
pass
|
||||
return await asyncio.wait_for(waiter, timeout=timeout)
|
||||
finally:
|
||||
if waiter in self._connect_waiters:
|
||||
self._connect_waiters.remove(waiter)
|
||||
|
||||
async def _lookup_known_contact(self, user_id: int, link: str) -> T.Contact | None:
|
||||
"""Resolve a link to an existing Contact via api_connect_plan, or None.
|
||||
|
||||
Only ChatCommandError is swallowed (malformed link, etc.) — the
|
||||
connect_to caller will fall back to the full handshake path.
|
||||
Transport/FFI errors propagate so the caller sees the real cause.
|
||||
"""
|
||||
try:
|
||||
plan, _ = await self.api.api_connect_plan(user_id, link)
|
||||
except ChatCommandError:
|
||||
return None
|
||||
if plan["type"] == "contactAddress":
|
||||
cap = plan["contactAddressPlan"]
|
||||
if cap["type"] == "known":
|
||||
return cap["contact"]
|
||||
if plan["type"] == "invitationLink":
|
||||
ilp = plan["invitationLinkPlan"]
|
||||
if ilp["type"] == "known":
|
||||
return ilp["contact"]
|
||||
return None
|
||||
|
||||
async def send_and_wait(
|
||||
self,
|
||||
contact_id: int,
|
||||
text: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> "Message[T.MsgContent]":
|
||||
"""Send text to a direct contact and wait for the next reply from them.
|
||||
|
||||
Waiters are FIFO per contact_id: two concurrent calls to the same
|
||||
contact get two replies in send order. Concurrent calls to *different*
|
||||
contacts run in parallel. Once a reply matches a waiter, user
|
||||
message handlers do NOT fire for that message — the awaiter owns it.
|
||||
|
||||
Correlation caveat: matching is by sender contact_id only — there
|
||||
is no message-id correlation. ANY direct message from `contact_id`
|
||||
arriving while a waiter is pending will resolve that waiter, even
|
||||
if it was an unsolicited message (e.g. an auto-reply from a bot's
|
||||
address settings, a delayed reply from a previous send, a push
|
||||
notification). For strict request/response semantics, ensure the
|
||||
peer is otherwise quiet, or use the `@on_message` callback model.
|
||||
|
||||
Requires the receive loop to be running. Raises asyncio.TimeoutError
|
||||
on timeout, ValueError if timeout is not positive.
|
||||
"""
|
||||
if timeout <= 0:
|
||||
# Reject upfront — otherwise wait_for raises TimeoutError after
|
||||
# api_send_text_message already went over the wire, surprising
|
||||
# the caller with a sent message and no Future to await.
|
||||
raise ValueError(f"timeout must be positive, got {timeout!r}")
|
||||
if not self._serving:
|
||||
raise RuntimeError(
|
||||
"send_and_wait requires the receive loop to be running — "
|
||||
"call serve_forever() (in a task) or iterate events() first"
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
waiter: asyncio.Future[Message[Any]] = loop.create_future()
|
||||
waiters = self._reply_waiters.setdefault(contact_id, [])
|
||||
waiters.append(waiter)
|
||||
try:
|
||||
await self.api.api_send_text_message(["direct", contact_id], text)
|
||||
return await asyncio.wait_for(waiter, timeout=timeout)
|
||||
finally:
|
||||
# Always clean up our slot, even on send error or timeout. Leaving
|
||||
# an unresolved Future in the dict would make the next incoming
|
||||
# message resolve a future no one is waiting on.
|
||||
if waiter in waiters:
|
||||
waiters.remove(waiter)
|
||||
if not waiters:
|
||||
self._reply_waiters.pop(contact_id, None)
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
# Catch broad Exception so a single malformed event or transient
|
||||
# native error doesn't crash the whole client. CancelledError must
|
||||
# always re-raise so `stop()` and asyncio cancellation work.
|
||||
# `wait_us=500_000` (500ms) bounds the worst-case Ctrl+C latency:
|
||||
# the C call blocks the worker thread until timeout, and the loop
|
||||
# only checks `_stop_event` between polls.
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
event = await self.api.recv_chat_event(wait_us=500_000)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ChatAPIError as e:
|
||||
# Async chat errors emitted via the Haskell `eToView` path —
|
||||
# routine soft errors (stale connections after a peer deletes
|
||||
# a chat, file cleanup failures, etc.) intermixed with
|
||||
# CRITICAL agent failures the operator must see. Mirror the
|
||||
# desktop policy in SimpleXAPI.kt:3332-3340: escalate
|
||||
# CRITICAL agent errors, keep everything else at debug.
|
||||
chat_err: Any = e.chat_error or {}
|
||||
agent_err: Any = (
|
||||
chat_err.get("agentError", {}) if chat_err.get("type") == "errorAgent" else {}
|
||||
)
|
||||
if agent_err.get("type") == "CRITICAL":
|
||||
log.error(
|
||||
"chat agent CRITICAL: %s (offerRestart=%s)",
|
||||
agent_err.get("criticalErr"),
|
||||
agent_err.get("offerRestart"),
|
||||
)
|
||||
else:
|
||||
log.debug("chat event error: %s", chat_err.get("type"))
|
||||
continue
|
||||
except Exception:
|
||||
log.exception("recv_chat_event failed")
|
||||
# Bound the spin rate when the FFI is wedged on a persistent
|
||||
# error (vs the timeout path, which already paces itself).
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
if event is None:
|
||||
continue
|
||||
try:
|
||||
await self._dispatch_event(event)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
log.exception("dispatch_event failed for tag=%s", event.get("type"))
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Dispatch
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _dispatch_event(self, event: CEvt.ChatEvent) -> None:
|
||||
tag = event["type"]
|
||||
# Resolve internal waiters BEFORE user handlers. A pending
|
||||
# `connect_to` consumes the contactConnected; a pending
|
||||
# `send_and_wait` consumes the matching incoming message — user
|
||||
# handlers don't fire for that message. This matches the mental
|
||||
# model: the awaiter explicitly asked for this event.
|
||||
if tag == "contactConnected" and self._connect_waiters:
|
||||
contact: T.Contact = event["contact"] # type: ignore[typeddict-item]
|
||||
waiter = self._connect_waiters.pop(0)
|
||||
if not waiter.done():
|
||||
waiter.set_result(contact)
|
||||
for h in self._event_handlers.get(tag, []):
|
||||
try:
|
||||
await h(event)
|
||||
except Exception:
|
||||
log.exception("on_event handler failed")
|
||||
if tag == "newChatItems":
|
||||
evt: CEvt.NewChatItems = event # type: ignore[assignment]
|
||||
for ci in evt["chatItems"]:
|
||||
content = ci["chatItem"]["content"]
|
||||
if content["type"] != "rcvMsgContent":
|
||||
continue
|
||||
msg_content = content["msgContent"] # type: ignore[index]
|
||||
msg: Message[T.MsgContent] = Message(chat_item=ci, content=msg_content, client=self)
|
||||
# If a send_and_wait is pending for this sender, fulfil it
|
||||
# and skip the user dispatch chain — the awaiter "owns" this
|
||||
# reply. FIFO within a contact_id.
|
||||
if self._maybe_resolve_reply_waiter(msg):
|
||||
continue
|
||||
await self._dispatch_message(msg)
|
||||
|
||||
def _maybe_resolve_reply_waiter(self, msg: Message[T.MsgContent]) -> bool:
|
||||
chat_info = msg.chat_info
|
||||
if chat_info.get("type") != "direct":
|
||||
return False
|
||||
contact_id = chat_info.get("contact", {}).get("contactId") # type: ignore[union-attr]
|
||||
if contact_id is None:
|
||||
return False
|
||||
waiters = self._reply_waiters.get(contact_id)
|
||||
if not waiters:
|
||||
return False
|
||||
# Skip waiters whose callers have already given up (cancelled by
|
||||
# wait_for timing out at the same loop tick). Without this skip,
|
||||
# a reply arriving in the narrow timeout-race window would be
|
||||
# silently dropped because the FIFO would pop a done waiter and
|
||||
# neither resolve it nor dispatch to user handlers.
|
||||
while waiters:
|
||||
waiter = waiters.pop(0)
|
||||
if not waiter.done():
|
||||
if not waiters:
|
||||
self._reply_waiters.pop(contact_id, None)
|
||||
waiter.set_result(msg)
|
||||
return True
|
||||
self._reply_waiters.pop(contact_id, None)
|
||||
return False
|
||||
|
||||
async def _dispatch_message(self, msg: Message[Any]) -> None:
|
||||
# First-match-wins. The squaring bot's `@on_message(text=NUMBER_RE)`
|
||||
# and catch-all `@on_message(content_type="text")` both match a number
|
||||
# like "1"; we want only the first to fire. Registration order is the
|
||||
# priority order — register the most-specific filters first.
|
||||
#
|
||||
# Slash-commands are tried first against command handlers; if no
|
||||
# command handler matches, fall through to message handlers (so
|
||||
# `@on_message` can still catch unknown slash-commands).
|
||||
cmd = self._parse_command(msg)
|
||||
if cmd is not None:
|
||||
for names, predicate, handler in self._command_handlers:
|
||||
if cmd.keyword in names and predicate(msg):
|
||||
await self._invoke_command_with_middleware(handler, msg, cmd)
|
||||
return
|
||||
for predicate, handler in self._message_handlers:
|
||||
if predicate(msg):
|
||||
await self._invoke_with_middleware(handler, msg)
|
||||
return
|
||||
|
||||
async def _invoke_with_middleware(self, handler: MessageHandler, message: Message[Any]) -> None:
|
||||
# Fast path: most clients register no middleware. Skip the closure-chain
|
||||
# construction and the empty-data dict on every dispatch.
|
||||
if not self._middleware:
|
||||
try:
|
||||
await handler(message)
|
||||
except Exception:
|
||||
log.exception("message handler failed")
|
||||
return
|
||||
|
||||
async def call(m: Message[Any], _data: dict[str, object]) -> None:
|
||||
await handler(m)
|
||||
|
||||
chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call
|
||||
for mw in reversed(self._middleware):
|
||||
inner = chain
|
||||
|
||||
async def _wrapped(
|
||||
m: Message[Any],
|
||||
d: dict[str, object],
|
||||
mw: Middleware = mw,
|
||||
inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner,
|
||||
) -> None:
|
||||
await mw(inner, m, d)
|
||||
|
||||
chain = _wrapped
|
||||
|
||||
try:
|
||||
await chain(message, {})
|
||||
except Exception:
|
||||
log.exception("message handler failed")
|
||||
|
||||
async def _invoke_command_with_middleware(
|
||||
self, handler: CommandHandler, message: Message[Any], cmd: ParsedCommand
|
||||
) -> None:
|
||||
if not self._middleware:
|
||||
try:
|
||||
await handler(message, cmd)
|
||||
except Exception:
|
||||
log.exception("command handler failed")
|
||||
return
|
||||
|
||||
async def call(m: Message[Any], _data: dict[str, object]) -> None:
|
||||
await handler(m, cmd)
|
||||
|
||||
chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call
|
||||
for mw in reversed(self._middleware):
|
||||
inner = chain
|
||||
|
||||
async def _wrapped(
|
||||
m: Message[Any],
|
||||
d: dict[str, object],
|
||||
mw: Middleware = mw,
|
||||
inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner,
|
||||
) -> None:
|
||||
await mw(inner, m, d)
|
||||
|
||||
chain = _wrapped
|
||||
|
||||
try:
|
||||
await chain(message, {})
|
||||
except Exception:
|
||||
log.exception("command handler failed")
|
||||
|
||||
@staticmethod
|
||||
def _parse_command(msg: Message[Any]) -> ParsedCommand | None:
|
||||
parsed = util.ci_bot_command(msg.chat_item["chatItem"])
|
||||
if parsed is None:
|
||||
return None
|
||||
keyword, args = parsed
|
||||
return ParsedCommand(keyword=keyword, args=args)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Profile sync
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _ensure_active_user(self) -> T.User:
|
||||
"""Get or create the active user. Must run before `start_chat`.
|
||||
|
||||
Mirrors Node `createBotUser` (bot.ts:158-166). The chat controller
|
||||
won't accept `/_start` without a user, so this phase has to land
|
||||
before lifecycle proceeds.
|
||||
"""
|
||||
api = self.api
|
||||
user = await api.api_get_active_user()
|
||||
if user is None:
|
||||
log.info("No active user in database, creating...")
|
||||
user = await api.api_create_active_user(self._profile_to_wire())
|
||||
log.info("user: %s", user["profile"]["displayName"])
|
||||
return user
|
||||
|
||||
async def _maybe_sync_profile(self, user: T.User, *, contact_link: str | None) -> None:
|
||||
"""Update the user profile on the wire if its fields changed.
|
||||
|
||||
`contact_link` is only set by Bot (to embed its address). Mirrors
|
||||
Node `updateBotUserProfile` (bot.ts:199-214). Field-by-field
|
||||
comparison because user["profile"] is LocalProfile (has extra
|
||||
fields profileId, localAlias, preferences, peerType) so a full
|
||||
dict equality would always differ.
|
||||
"""
|
||||
if not self._update_profile:
|
||||
return
|
||||
new_profile = self._profile_to_wire()
|
||||
if contact_link is not None:
|
||||
new_profile["contactLink"] = contact_link
|
||||
cur = user["profile"]
|
||||
changed = (
|
||||
cur["displayName"] != new_profile["displayName"]
|
||||
or cur.get("fullName", "") != new_profile.get("fullName", "")
|
||||
or cur.get("shortDescr") != new_profile.get("shortDescr")
|
||||
or cur.get("image") != new_profile.get("image")
|
||||
or cur.get("preferences") != new_profile.get("preferences")
|
||||
or cur.get("peerType") != new_profile.get("peerType")
|
||||
or cur.get("contactLink") != new_profile.get("contactLink")
|
||||
)
|
||||
if changed:
|
||||
log.info("profile changed, updating...")
|
||||
await self.api.api_update_profile(user["userId"], new_profile)
|
||||
|
||||
def _profile_to_wire(self) -> T.Profile:
|
||||
"""Convert the user-facing Profile dataclass to wire format.
|
||||
|
||||
Base version produces a plain user profile. Bot overrides this to
|
||||
add the bot-specific extensions (peerType=bot, command list,
|
||||
calls/voice/files prefs).
|
||||
"""
|
||||
p: T.Profile = {
|
||||
"displayName": self._profile.display_name,
|
||||
"fullName": self._profile.full_name,
|
||||
}
|
||||
if self._profile.short_descr is not None:
|
||||
p["shortDescr"] = self._profile.short_descr
|
||||
if self._profile.image is not None:
|
||||
p["image"] = self._profile.image
|
||||
return p
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Log subscriptions (mirror Node subscribeLogEvents bot.ts:142-156)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _register_log_handlers(self) -> None:
|
||||
# Idempotent: re-entering the async context must not stack duplicate
|
||||
# log handlers. Always-on error handlers run regardless of
|
||||
# log_contacts/log_network so messageError/chatError/chatErrors
|
||||
# don't disappear into the void.
|
||||
if self._defaults_registered:
|
||||
return
|
||||
self._defaults_registered = True
|
||||
self._event_handlers.setdefault("messageError", []).append(self._log_message_error)
|
||||
self._event_handlers.setdefault("chatError", []).append(self._log_chat_error)
|
||||
self._event_handlers.setdefault("chatErrors", []).append(self._log_chat_errors)
|
||||
if self._log_contacts:
|
||||
self._event_handlers.setdefault("contactConnected", []).append(
|
||||
self._log_contact_connected
|
||||
)
|
||||
self._event_handlers.setdefault("contactDeletedByContact", []).append(
|
||||
self._log_contact_deleted
|
||||
)
|
||||
if self._log_network:
|
||||
self._event_handlers.setdefault("hostConnected", []).append(self._log_host_connected)
|
||||
self._event_handlers.setdefault("hostDisconnected", []).append(
|
||||
self._log_host_disconnected
|
||||
)
|
||||
self._event_handlers.setdefault("subscriptionStatus", []).append(
|
||||
self._log_subscription_status
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _log_contact_connected(evt: CEvt.ChatEvent) -> None:
|
||||
log.info("%s connected", evt["contact"]["profile"]["displayName"]) # type: ignore[index]
|
||||
|
||||
@staticmethod
|
||||
async def _log_contact_deleted(evt: CEvt.ChatEvent) -> None:
|
||||
log.info(
|
||||
"%s deleted connection",
|
||||
evt["contact"]["profile"]["displayName"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _log_host_connected(evt: CEvt.ChatEvent) -> None:
|
||||
log.info("connected server %s", evt["transportHost"]) # type: ignore[index]
|
||||
|
||||
@staticmethod
|
||||
async def _log_host_disconnected(evt: CEvt.ChatEvent) -> None:
|
||||
log.info("disconnected server %s", evt["transportHost"]) # type: ignore[index]
|
||||
|
||||
@staticmethod
|
||||
async def _log_subscription_status(evt: CEvt.ChatEvent) -> None:
|
||||
log.info(
|
||||
"%d subscription(s) %s",
|
||||
len(evt["connections"]), # type: ignore[index]
|
||||
evt["subscriptionStatus"]["type"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _log_message_error(evt: CEvt.ChatEvent) -> None:
|
||||
log.warning("messageError: %s", evt.get("severity", "?")) # type: ignore[union-attr]
|
||||
|
||||
@staticmethod
|
||||
async def _log_chat_error(evt: CEvt.ChatEvent) -> None:
|
||||
err = evt.get("chatError") # type: ignore[union-attr]
|
||||
log.error("chatError: %s", err.get("type") if isinstance(err, dict) else err)
|
||||
|
||||
@staticmethod
|
||||
async def _log_chat_errors(evt: CEvt.ChatEvent) -> None:
|
||||
errs = evt.get("chatErrors") or [] # type: ignore[union-attr]
|
||||
log.error("chatErrors: %d errors", len(errs))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BotProfile", # backwards-compat alias for Profile
|
||||
"ChatMessage",
|
||||
"Client",
|
||||
"CommandHandler",
|
||||
"EventHandler",
|
||||
"FileMessage",
|
||||
"ImageMessage",
|
||||
"LinkMessage",
|
||||
"Message",
|
||||
"MessageHandler",
|
||||
"Middleware",
|
||||
"ParsedCommand",
|
||||
"Profile",
|
||||
"ReportMessage",
|
||||
"TextMessage",
|
||||
"UnknownMessage",
|
||||
"VideoMessage",
|
||||
"VoiceMessage",
|
||||
]
|
||||
@@ -37,6 +37,15 @@ def compile_message_filter(kw: dict[str, Any]) -> Callable[[Any], bool]:
|
||||
|
||||
predicates.append(gid_match)
|
||||
|
||||
if (cid := kw.get("contact_id")) is not None:
|
||||
cid_set: tuple[int, ...] = (cid,) if isinstance(cid, int) else tuple(cid)
|
||||
|
||||
def cid_match(m: Any) -> bool:
|
||||
ci = m.chat_item["chatInfo"]
|
||||
return ci["type"] == "direct" and ci["contact"]["contactId"] in cid_set
|
||||
|
||||
predicates.append(cid_match)
|
||||
|
||||
if (when := kw.get("when")) is not None:
|
||||
predicates.append(when)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# API Events
|
||||
# This file is generated automatically.
|
||||
from __future__ import annotations
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Literal, NotRequired, Protocol, TypedDict, overload
|
||||
from . import _types as T
|
||||
|
||||
class ContactConnected(TypedDict):
|
||||
@@ -377,3 +378,318 @@ ChatEvent = (
|
||||
)
|
||||
|
||||
ChatEvent_Tag = Literal["contactConnected", "contactUpdated", "contactDeletedByContact", "receivedContactRequest", "newMemberContactReceivedInv", "contactSndReady", "newChatItems", "chatItemReaction", "chatItemsDeleted", "chatItemUpdated", "groupChatItemsDeleted", "chatItemsStatusesUpdated", "receivedGroupInvitation", "userJoinedGroup", "groupUpdated", "joinedGroupMember", "memberRole", "deletedMember", "leftMember", "deletedMemberUser", "groupDeleted", "connectedToGroupMember", "memberAcceptedByOther", "memberBlockedForAll", "groupMemberUpdated", "groupLinkDataUpdated", "groupRelayUpdated", "rcvFileDescrReady", "rcvFileComplete", "sndFileCompleteXFTP", "rcvFileStart", "rcvFileSndCancelled", "rcvFileAccepted", "rcvFileError", "rcvFileWarning", "sndFileError", "sndFileWarning", "acceptingContactRequest", "acceptingBusinessRequest", "contactConnecting", "businessLinkConnecting", "joinedGroupMemberConnecting", "sentGroupInvitation", "groupLinkConnecting", "hostConnected", "hostDisconnected", "subscriptionStatus", "messageError", "chatError", "chatErrors"]
|
||||
|
||||
|
||||
class OnEventDecorator(Protocol):
|
||||
"""Per-tag narrowing protocol for ``Client.on_event``.
|
||||
|
||||
``@client.on_event("contactConnected")`` types the handler's
|
||||
``evt`` parameter as :class:`ContactConnected` rather than the
|
||||
unnarrowed :data:`ChatEvent` union.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["contactConnected"], /) -> Callable[
|
||||
[Callable[["ContactConnected"], Awaitable[None]]],
|
||||
Callable[["ContactConnected"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["contactUpdated"], /) -> Callable[
|
||||
[Callable[["ContactUpdated"], Awaitable[None]]],
|
||||
Callable[["ContactUpdated"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["contactDeletedByContact"], /) -> Callable[
|
||||
[Callable[["ContactDeletedByContact"], Awaitable[None]]],
|
||||
Callable[["ContactDeletedByContact"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["receivedContactRequest"], /) -> Callable[
|
||||
[Callable[["ReceivedContactRequest"], Awaitable[None]]],
|
||||
Callable[["ReceivedContactRequest"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["newMemberContactReceivedInv"], /) -> Callable[
|
||||
[Callable[["NewMemberContactReceivedInv"], Awaitable[None]]],
|
||||
Callable[["NewMemberContactReceivedInv"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["contactSndReady"], /) -> Callable[
|
||||
[Callable[["ContactSndReady"], Awaitable[None]]],
|
||||
Callable[["ContactSndReady"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["newChatItems"], /) -> Callable[
|
||||
[Callable[["NewChatItems"], Awaitable[None]]],
|
||||
Callable[["NewChatItems"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["chatItemReaction"], /) -> Callable[
|
||||
[Callable[["ChatItemReaction"], Awaitable[None]]],
|
||||
Callable[["ChatItemReaction"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["chatItemsDeleted"], /) -> Callable[
|
||||
[Callable[["ChatItemsDeleted"], Awaitable[None]]],
|
||||
Callable[["ChatItemsDeleted"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["chatItemUpdated"], /) -> Callable[
|
||||
[Callable[["ChatItemUpdated"], Awaitable[None]]],
|
||||
Callable[["ChatItemUpdated"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["groupChatItemsDeleted"], /) -> Callable[
|
||||
[Callable[["GroupChatItemsDeleted"], Awaitable[None]]],
|
||||
Callable[["GroupChatItemsDeleted"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["chatItemsStatusesUpdated"], /) -> Callable[
|
||||
[Callable[["ChatItemsStatusesUpdated"], Awaitable[None]]],
|
||||
Callable[["ChatItemsStatusesUpdated"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["receivedGroupInvitation"], /) -> Callable[
|
||||
[Callable[["ReceivedGroupInvitation"], Awaitable[None]]],
|
||||
Callable[["ReceivedGroupInvitation"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["userJoinedGroup"], /) -> Callable[
|
||||
[Callable[["UserJoinedGroup"], Awaitable[None]]],
|
||||
Callable[["UserJoinedGroup"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["groupUpdated"], /) -> Callable[
|
||||
[Callable[["GroupUpdated"], Awaitable[None]]],
|
||||
Callable[["GroupUpdated"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["joinedGroupMember"], /) -> Callable[
|
||||
[Callable[["JoinedGroupMember"], Awaitable[None]]],
|
||||
Callable[["JoinedGroupMember"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["memberRole"], /) -> Callable[
|
||||
[Callable[["MemberRole"], Awaitable[None]]],
|
||||
Callable[["MemberRole"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["deletedMember"], /) -> Callable[
|
||||
[Callable[["DeletedMember"], Awaitable[None]]],
|
||||
Callable[["DeletedMember"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["leftMember"], /) -> Callable[
|
||||
[Callable[["LeftMember"], Awaitable[None]]],
|
||||
Callable[["LeftMember"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["deletedMemberUser"], /) -> Callable[
|
||||
[Callable[["DeletedMemberUser"], Awaitable[None]]],
|
||||
Callable[["DeletedMemberUser"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["groupDeleted"], /) -> Callable[
|
||||
[Callable[["GroupDeleted"], Awaitable[None]]],
|
||||
Callable[["GroupDeleted"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["connectedToGroupMember"], /) -> Callable[
|
||||
[Callable[["ConnectedToGroupMember"], Awaitable[None]]],
|
||||
Callable[["ConnectedToGroupMember"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["memberAcceptedByOther"], /) -> Callable[
|
||||
[Callable[["MemberAcceptedByOther"], Awaitable[None]]],
|
||||
Callable[["MemberAcceptedByOther"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["memberBlockedForAll"], /) -> Callable[
|
||||
[Callable[["MemberBlockedForAll"], Awaitable[None]]],
|
||||
Callable[["MemberBlockedForAll"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["groupMemberUpdated"], /) -> Callable[
|
||||
[Callable[["GroupMemberUpdated"], Awaitable[None]]],
|
||||
Callable[["GroupMemberUpdated"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["groupLinkDataUpdated"], /) -> Callable[
|
||||
[Callable[["GroupLinkDataUpdated"], Awaitable[None]]],
|
||||
Callable[["GroupLinkDataUpdated"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["groupRelayUpdated"], /) -> Callable[
|
||||
[Callable[["GroupRelayUpdated"], Awaitable[None]]],
|
||||
Callable[["GroupRelayUpdated"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["rcvFileDescrReady"], /) -> Callable[
|
||||
[Callable[["RcvFileDescrReady"], Awaitable[None]]],
|
||||
Callable[["RcvFileDescrReady"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["rcvFileComplete"], /) -> Callable[
|
||||
[Callable[["RcvFileComplete"], Awaitable[None]]],
|
||||
Callable[["RcvFileComplete"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["sndFileCompleteXFTP"], /) -> Callable[
|
||||
[Callable[["SndFileCompleteXFTP"], Awaitable[None]]],
|
||||
Callable[["SndFileCompleteXFTP"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["rcvFileStart"], /) -> Callable[
|
||||
[Callable[["RcvFileStart"], Awaitable[None]]],
|
||||
Callable[["RcvFileStart"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["rcvFileSndCancelled"], /) -> Callable[
|
||||
[Callable[["RcvFileSndCancelled"], Awaitable[None]]],
|
||||
Callable[["RcvFileSndCancelled"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["rcvFileAccepted"], /) -> Callable[
|
||||
[Callable[["RcvFileAccepted"], Awaitable[None]]],
|
||||
Callable[["RcvFileAccepted"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["rcvFileError"], /) -> Callable[
|
||||
[Callable[["RcvFileError"], Awaitable[None]]],
|
||||
Callable[["RcvFileError"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["rcvFileWarning"], /) -> Callable[
|
||||
[Callable[["RcvFileWarning"], Awaitable[None]]],
|
||||
Callable[["RcvFileWarning"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["sndFileError"], /) -> Callable[
|
||||
[Callable[["SndFileError"], Awaitable[None]]],
|
||||
Callable[["SndFileError"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["sndFileWarning"], /) -> Callable[
|
||||
[Callable[["SndFileWarning"], Awaitable[None]]],
|
||||
Callable[["SndFileWarning"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["acceptingContactRequest"], /) -> Callable[
|
||||
[Callable[["AcceptingContactRequest"], Awaitable[None]]],
|
||||
Callable[["AcceptingContactRequest"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["acceptingBusinessRequest"], /) -> Callable[
|
||||
[Callable[["AcceptingBusinessRequest"], Awaitable[None]]],
|
||||
Callable[["AcceptingBusinessRequest"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["contactConnecting"], /) -> Callable[
|
||||
[Callable[["ContactConnecting"], Awaitable[None]]],
|
||||
Callable[["ContactConnecting"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["businessLinkConnecting"], /) -> Callable[
|
||||
[Callable[["BusinessLinkConnecting"], Awaitable[None]]],
|
||||
Callable[["BusinessLinkConnecting"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["joinedGroupMemberConnecting"], /) -> Callable[
|
||||
[Callable[["JoinedGroupMemberConnecting"], Awaitable[None]]],
|
||||
Callable[["JoinedGroupMemberConnecting"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["sentGroupInvitation"], /) -> Callable[
|
||||
[Callable[["SentGroupInvitation"], Awaitable[None]]],
|
||||
Callable[["SentGroupInvitation"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["groupLinkConnecting"], /) -> Callable[
|
||||
[Callable[["GroupLinkConnecting"], Awaitable[None]]],
|
||||
Callable[["GroupLinkConnecting"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["hostConnected"], /) -> Callable[
|
||||
[Callable[["HostConnected"], Awaitable[None]]],
|
||||
Callable[["HostConnected"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["hostDisconnected"], /) -> Callable[
|
||||
[Callable[["HostDisconnected"], Awaitable[None]]],
|
||||
Callable[["HostDisconnected"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["subscriptionStatus"], /) -> Callable[
|
||||
[Callable[["SubscriptionStatus"], Awaitable[None]]],
|
||||
Callable[["SubscriptionStatus"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["messageError"], /) -> Callable[
|
||||
[Callable[["MessageError"], Awaitable[None]]],
|
||||
Callable[["MessageError"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["chatError"], /) -> Callable[
|
||||
[Callable[["ChatError"], Awaitable[None]]],
|
||||
Callable[["ChatError"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: Literal["chatErrors"], /) -> Callable[
|
||||
[Callable[["ChatErrors"], Awaitable[None]]],
|
||||
Callable[["ChatErrors"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __call__(self, event: str, /) -> Callable[
|
||||
[Callable[["ChatEvent"], Awaitable[None]]],
|
||||
Callable[["ChatEvent"], Awaitable[None]],
|
||||
]: ...
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from simplex_chat import Bot, BotCommand, BotProfile, Middleware, SqliteDb
|
||||
from simplex_chat import Bot, BotCommand, BotProfile, Client, Middleware, Profile, SqliteDb
|
||||
from simplex_chat.api import ChatApi
|
||||
|
||||
|
||||
@@ -57,9 +57,9 @@ def test_command_keyword_tuple():
|
||||
|
||||
|
||||
def test_bot_profile_to_wire_default():
|
||||
"""use_bot_profile=True (default) sets peerType=bot and disables calls/voice."""
|
||||
"""Bot's profile wire-form sets peerType=bot and disables calls/voice."""
|
||||
bot = _bot()
|
||||
p = bot._bot_profile_to_wire()
|
||||
p = bot._profile_to_wire()
|
||||
assert p["displayName"] == "x"
|
||||
assert p.get("peerType") == "bot"
|
||||
prefs = p.get("preferences") or {}
|
||||
@@ -74,7 +74,7 @@ def test_bot_profile_to_wire_allow_files():
|
||||
db=SqliteDb(file_prefix="/tmp/test"),
|
||||
allow_files=True,
|
||||
)
|
||||
prefs = bot._bot_profile_to_wire().get("preferences") or {}
|
||||
prefs = bot._profile_to_wire().get("preferences") or {}
|
||||
assert prefs.get("files", {}).get("allow") == "yes"
|
||||
|
||||
|
||||
@@ -84,32 +84,26 @@ def test_bot_profile_to_wire_with_commands():
|
||||
db=SqliteDb(file_prefix="/tmp/test"),
|
||||
commands=[BotCommand(keyword="ping", label="Ping bot"), BotCommand("help", "Show help")],
|
||||
)
|
||||
cmds = bot._bot_profile_to_wire().get("preferences", {}).get("commands") or []
|
||||
cmds = bot._profile_to_wire().get("preferences", {}).get("commands") or []
|
||||
assert len(cmds) == 2
|
||||
assert cmds[0] == {"type": "command", "keyword": "ping", "label": "Ping bot"}
|
||||
assert cmds[1] == {"type": "command", "keyword": "help", "label": "Show help"}
|
||||
|
||||
|
||||
def test_bot_profile_to_wire_no_bot_profile():
|
||||
bot = Bot(
|
||||
profile=BotProfile(display_name="x"),
|
||||
db=SqliteDb(file_prefix="/tmp/test"),
|
||||
use_bot_profile=False,
|
||||
)
|
||||
p = bot._bot_profile_to_wire()
|
||||
def test_client_profile_to_wire_has_no_bot_extras():
|
||||
"""Client's wire profile has no peerType=bot, no command list, no calls/voice prefs.
|
||||
That's the whole point of having Client as a separate class."""
|
||||
c = Client(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
p = c._profile_to_wire()
|
||||
assert p["displayName"] == "x"
|
||||
assert "peerType" not in p
|
||||
assert "preferences" not in p
|
||||
|
||||
|
||||
def test_commands_without_bot_profile_raises():
|
||||
bot = Bot(
|
||||
profile=BotProfile(display_name="x"),
|
||||
db=SqliteDb(file_prefix="/tmp/test"),
|
||||
use_bot_profile=False,
|
||||
commands=[BotCommand("ping", "Ping bot")],
|
||||
)
|
||||
with pytest.raises(ValueError, match="use_bot_profile=False"):
|
||||
bot._bot_profile_to_wire()
|
||||
def test_bot_profile_alias_is_profile():
|
||||
"""`BotProfile` is kept as an alias for backwards compatibility."""
|
||||
assert BotProfile is Profile
|
||||
assert BotProfile(display_name="x") == Profile(display_name="x")
|
||||
|
||||
|
||||
def test_dispatch_message_first_match_wins():
|
||||
|
||||
@@ -0,0 +1,616 @@
|
||||
"""Tests for Client class + connect_to / send_and_wait / events plumbing.
|
||||
|
||||
Stubs out ChatApi so we exercise the dispatch and waiter logic without
|
||||
spinning up the native libsimplex controller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from simplex_chat import (
|
||||
Bot,
|
||||
BotProfile,
|
||||
Client,
|
||||
ContactAlreadyExistsError,
|
||||
Profile,
|
||||
SqliteDb,
|
||||
)
|
||||
|
||||
|
||||
class FakeApi:
|
||||
"""Drop-in replacement for ChatApi for tests that don't need the FFI.
|
||||
|
||||
Records api_send_text_message calls; supports scripting api_connect_plan
|
||||
and api_connect_active_user behaviour.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent: list[tuple[Any, str]] = []
|
||||
self.connect_plan_result: Any = ("error", None) # default: no known contact
|
||||
self.connect_should_raise: Exception | None = None
|
||||
self.active_user: dict[str, Any] = {"userId": 1, "profile": {"displayName": "x"}}
|
||||
|
||||
async def api_send_text_message(self, chat, text, in_reply_to=None):
|
||||
self.sent.append((chat, text))
|
||||
return []
|
||||
|
||||
async def api_connect_plan(self, _user_id, _link):
|
||||
kind = self.connect_plan_result[0]
|
||||
if kind == "known_contact_address":
|
||||
return (
|
||||
{
|
||||
"type": "contactAddress",
|
||||
"contactAddressPlan": {"type": "known", "contact": self.connect_plan_result[1]},
|
||||
},
|
||||
{},
|
||||
)
|
||||
if kind == "known_invitation":
|
||||
return (
|
||||
{
|
||||
"type": "invitationLink",
|
||||
"invitationLinkPlan": {"type": "known", "contact": self.connect_plan_result[1]},
|
||||
},
|
||||
{},
|
||||
)
|
||||
if kind == "ok":
|
||||
return (
|
||||
{
|
||||
"type": "contactAddress",
|
||||
"contactAddressPlan": {"type": "ok"},
|
||||
},
|
||||
{},
|
||||
)
|
||||
# default "error"
|
||||
return ({"type": "error", "chatError": {}}, {})
|
||||
|
||||
async def api_connect_active_user(self, _link):
|
||||
if self.connect_should_raise is not None:
|
||||
raise self.connect_should_raise
|
||||
return "contact"
|
||||
|
||||
async def api_get_active_user(self):
|
||||
return self.active_user
|
||||
|
||||
|
||||
def _bot_with_fake_api() -> tuple[Bot, FakeApi]:
|
||||
bot = Bot(profile=BotProfile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
api = FakeApi()
|
||||
bot._api = api # type: ignore[assignment]
|
||||
bot._serving = True # pretend receive loop is up
|
||||
return bot, api
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_client_has_no_address_or_bot_profile_attributes():
|
||||
"""Client should not carry bot-side state (address creation, auto-accept,
|
||||
welcome, commands). That's the whole point of separating Client from Bot."""
|
||||
c = Client(profile=Profile(display_name="monitor"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
for attr in ("_create_address", "_update_address", "_auto_accept", "_welcome", "_commands"):
|
||||
assert not hasattr(c, attr), f"Client unexpectedly has Bot-only attribute {attr}"
|
||||
# And the wire profile has no bot peerType
|
||||
p = c._profile_to_wire()
|
||||
assert "peerType" not in p
|
||||
assert "preferences" not in p
|
||||
|
||||
|
||||
def test_bot_is_a_client_subclass():
|
||||
"""Bot should extend Client, so anywhere a Client is accepted, a Bot fits too."""
|
||||
assert issubclass(Bot, Client)
|
||||
|
||||
|
||||
def test_client_exposes_messaging_methods():
|
||||
c = Client(profile=Profile(display_name="m"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
assert hasattr(c, "connect_to")
|
||||
assert hasattr(c, "send_and_wait")
|
||||
assert hasattr(c, "events")
|
||||
assert hasattr(c, "on_message") # decorators available on Client too
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_and_wait
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_send_and_wait_requires_serving():
|
||||
"""Without the receive loop running, send_and_wait must raise — otherwise
|
||||
callers would silently hang waiting for a reply that's never dispatched."""
|
||||
bot = Bot(profile=BotProfile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
bot._api = FakeApi() # type: ignore[assignment]
|
||||
# _serving is False by default
|
||||
with pytest.raises(RuntimeError, match="receive loop"):
|
||||
asyncio.run(bot.send_and_wait(1, "hi"))
|
||||
|
||||
|
||||
def test_send_and_wait_resolves_on_matching_reply():
|
||||
"""A reply from the awaited contact should resolve the Future and skip
|
||||
regular message dispatch."""
|
||||
bot, api = _bot_with_fake_api()
|
||||
fallback_calls: list[str] = []
|
||||
|
||||
@bot.on_message(content_type="text")
|
||||
async def fallback(_msg):
|
||||
fallback_calls.append("fallback")
|
||||
|
||||
async def go() -> str:
|
||||
send_task = asyncio.create_task(bot.send_and_wait(42, "ping", timeout=2.0))
|
||||
# Yield so the task gets to register its waiter.
|
||||
await asyncio.sleep(0)
|
||||
evt = {"type": "newChatItems", "chatItems": [
|
||||
{
|
||||
"chatInfo": {"type": "direct", "contact": {"contactId": 42}},
|
||||
"chatItem": {
|
||||
"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "pong"}},
|
||||
},
|
||||
}
|
||||
]}
|
||||
await bot._dispatch_event(evt) # type: ignore[arg-type]
|
||||
reply = await send_task
|
||||
return reply.text or ""
|
||||
|
||||
result = asyncio.run(go())
|
||||
assert result == "pong"
|
||||
assert api.sent == [(["direct", 42], "ping")]
|
||||
assert fallback_calls == [], "fallback handler should NOT fire when a waiter consumed the reply"
|
||||
|
||||
|
||||
def test_send_and_wait_ignores_other_contacts():
|
||||
"""Replies from a different contact must not resolve the waiter — that
|
||||
would mis-correlate responses and is the bug send_and_wait exists to
|
||||
prevent users from writing themselves."""
|
||||
bot, _api = _bot_with_fake_api()
|
||||
|
||||
async def go():
|
||||
send_task = asyncio.create_task(bot.send_and_wait(42, "ping", timeout=0.5))
|
||||
await asyncio.sleep(0)
|
||||
evt = {"type": "newChatItems", "chatItems": [
|
||||
{
|
||||
"chatInfo": {"type": "direct", "contact": {"contactId": 99}},
|
||||
"chatItem": {
|
||||
"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "not for you"}},
|
||||
},
|
||||
}
|
||||
]}
|
||||
await bot._dispatch_event(evt) # type: ignore[arg-type]
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await send_task
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
def test_send_and_wait_fifo_within_contact():
|
||||
"""Two concurrent waiters on the same contact should resolve in send order."""
|
||||
bot, _api = _bot_with_fake_api()
|
||||
|
||||
async def go() -> tuple[str, str]:
|
||||
first = asyncio.create_task(bot.send_and_wait(42, "first", timeout=2.0))
|
||||
await asyncio.sleep(0)
|
||||
second = asyncio.create_task(bot.send_and_wait(42, "second", timeout=2.0))
|
||||
await asyncio.sleep(0)
|
||||
for text in ("reply1", "reply2"):
|
||||
evt = {"type": "newChatItems", "chatItems": [
|
||||
{
|
||||
"chatInfo": {"type": "direct", "contact": {"contactId": 42}},
|
||||
"chatItem": {
|
||||
"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": text}},
|
||||
},
|
||||
}
|
||||
]}
|
||||
await bot._dispatch_event(evt) # type: ignore[arg-type]
|
||||
return (await first).text or "", (await second).text or ""
|
||||
|
||||
a, b = asyncio.run(go())
|
||||
assert (a, b) == ("reply1", "reply2")
|
||||
|
||||
|
||||
def test_send_and_wait_cleans_up_state_on_timeout():
|
||||
"""Timed-out waiters must be removed so they don't accidentally consume
|
||||
later replies."""
|
||||
bot, _api = _bot_with_fake_api()
|
||||
|
||||
async def go():
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await bot.send_and_wait(42, "ping", timeout=0.05)
|
||||
assert 42 not in bot._reply_waiters, f"leaked waiters: {bot._reply_waiters}"
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
def test_dispatch_skips_cancelled_waiters_and_falls_through_to_handlers():
|
||||
"""Race fix: if a waiter is cancelled (wait_for timed out) but still in
|
||||
the FIFO when a reply arrives, the dispatcher must skip it and either
|
||||
resolve a live waiter OR fall through to user message handlers — not
|
||||
silently drop the message."""
|
||||
bot, _api = _bot_with_fake_api()
|
||||
fallback_calls: list[str] = []
|
||||
|
||||
@bot.on_message(content_type="text")
|
||||
async def fallback(msg):
|
||||
fallback_calls.append(msg.text or "")
|
||||
|
||||
async def go():
|
||||
# Manually inject a cancelled waiter (simulating wait_for timeout
|
||||
# cleanup losing the race with the inbound message).
|
||||
loop = asyncio.get_running_loop()
|
||||
stale: asyncio.Future = loop.create_future()
|
||||
stale.cancel()
|
||||
bot._reply_waiters[42] = [stale]
|
||||
|
||||
evt = {"type": "newChatItems", "chatItems": [
|
||||
{
|
||||
"chatInfo": {"type": "direct", "contact": {"contactId": 42}},
|
||||
"chatItem": {
|
||||
"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "racing reply"}},
|
||||
},
|
||||
}
|
||||
]}
|
||||
await bot._dispatch_event(evt) # type: ignore[arg-type]
|
||||
|
||||
asyncio.run(go())
|
||||
assert fallback_calls == ["racing reply"], (
|
||||
"dispatcher dropped the message instead of falling through to user handlers; "
|
||||
f"got {fallback_calls}"
|
||||
)
|
||||
assert 42 not in bot._reply_waiters, "cancelled waiter wasn't cleaned up"
|
||||
|
||||
|
||||
def test_send_and_wait_parallel_different_contacts():
|
||||
"""Concurrent send_and_wait to different contacts must not block each other.
|
||||
|
||||
The library docstring promises this; this test pins the behaviour so a
|
||||
future refactor (e.g., adding a single lock) can't quietly break it."""
|
||||
bot, _api = _bot_with_fake_api()
|
||||
|
||||
async def go() -> tuple[str, str]:
|
||||
t_a = asyncio.create_task(bot.send_and_wait(10, "a", timeout=2.0))
|
||||
await asyncio.sleep(0)
|
||||
t_b = asyncio.create_task(bot.send_and_wait(20, "b", timeout=2.0))
|
||||
await asyncio.sleep(0)
|
||||
# Deliver reply for B first — order shouldn't matter.
|
||||
await bot._dispatch_event({"type": "newChatItems", "chatItems": [ # type: ignore[arg-type]
|
||||
{
|
||||
"chatInfo": {"type": "direct", "contact": {"contactId": 20}},
|
||||
"chatItem": {"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "B"}}},
|
||||
}
|
||||
]})
|
||||
await bot._dispatch_event({"type": "newChatItems", "chatItems": [ # type: ignore[arg-type]
|
||||
{
|
||||
"chatInfo": {"type": "direct", "contact": {"contactId": 10}},
|
||||
"chatItem": {"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "A"}}},
|
||||
}
|
||||
]})
|
||||
return (await t_a).text or "", (await t_b).text or ""
|
||||
|
||||
a, b = asyncio.run(go())
|
||||
assert (a, b) == ("A", "B")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# connect_to
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_connect_to_returns_known_contact_without_handshake():
|
||||
"""If the link is already known, connect_to skips api_connect entirely."""
|
||||
bot, api = _bot_with_fake_api()
|
||||
existing = {"contactId": 7, "profile": {"displayName": "SimpleX Directory"}}
|
||||
api.connect_plan_result = ("known_contact_address", existing)
|
||||
|
||||
contact = asyncio.run(bot.connect_to("link", timeout=2.0))
|
||||
assert contact["contactId"] == 7
|
||||
# No connect issued: send buffer untouched.
|
||||
assert api.sent == []
|
||||
|
||||
|
||||
def test_connect_to_waits_for_contactConnected():
|
||||
"""For unknown links, connect_to issues the handshake and waits for the
|
||||
contactConnected event before returning."""
|
||||
bot, api = _bot_with_fake_api()
|
||||
api.connect_plan_result = ("ok", None)
|
||||
new_contact = {"contactId": 11, "profile": {"displayName": "Friend"}}
|
||||
|
||||
async def go():
|
||||
connect_task = asyncio.create_task(bot.connect_to("link", timeout=2.0))
|
||||
await asyncio.sleep(0)
|
||||
await bot._dispatch_event({"type": "contactConnected", "contact": new_contact}) # type: ignore[arg-type]
|
||||
return await connect_task
|
||||
|
||||
contact = asyncio.run(go())
|
||||
assert contact["contactId"] == 11
|
||||
|
||||
|
||||
def test_connect_to_tolerates_contact_already_exists():
|
||||
"""ContactAlreadyExistsError must NOT abort connect_to — a previous
|
||||
incomplete attempt may have left the connection mid-handshake; the
|
||||
contactConnected event will still arrive."""
|
||||
bot, api = _bot_with_fake_api()
|
||||
api.connect_plan_result = ("ok", None)
|
||||
api.connect_should_raise = ContactAlreadyExistsError(
|
||||
"exists", {"type": "contactAlreadyExists"} # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
async def go():
|
||||
connect_task = asyncio.create_task(bot.connect_to("link", timeout=2.0))
|
||||
await asyncio.sleep(0)
|
||||
await bot._dispatch_event({"type": "contactConnected", "contact": {"contactId": 5, "profile": {"displayName": "Friend"}}}) # type: ignore[arg-type]
|
||||
return await connect_task
|
||||
|
||||
contact = asyncio.run(go())
|
||||
assert contact["contactId"] == 5
|
||||
|
||||
|
||||
def test_connect_to_requires_serving():
|
||||
bot = Bot(profile=BotProfile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
bot._api = FakeApi() # type: ignore[assignment]
|
||||
with pytest.raises(RuntimeError, match="receive loop"):
|
||||
asyncio.run(bot.connect_to("link"))
|
||||
|
||||
|
||||
def test_connect_to_timeout_cleans_up_waiter():
|
||||
bot, api = _bot_with_fake_api()
|
||||
api.connect_plan_result = ("ok", None)
|
||||
|
||||
async def go():
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await bot.connect_to("link", timeout=0.05)
|
||||
assert bot._connect_waiters == [], "leaked connect waiter"
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
def test_connect_to_rejects_non_positive_timeout():
|
||||
"""timeout<=0 must fail upfront — otherwise wait_for raises after the
|
||||
handshake side-effect has already gone over the wire."""
|
||||
bot, _api = _bot_with_fake_api()
|
||||
|
||||
async def go():
|
||||
for bad in (0, -1, -0.001):
|
||||
with pytest.raises(ValueError, match="timeout must be positive"):
|
||||
await bot.connect_to("link", timeout=bad)
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
def test_send_and_wait_rejects_non_positive_timeout():
|
||||
"""Same as connect_to: timeout<=0 would surprise the caller with a sent
|
||||
message and no Future to await."""
|
||||
bot, api = _bot_with_fake_api()
|
||||
|
||||
async def go():
|
||||
for bad in (0, -1, -0.5):
|
||||
with pytest.raises(ValueError, match="timeout must be positive"):
|
||||
await bot.send_and_wait(42, "ping", timeout=bad)
|
||||
# And nothing was sent.
|
||||
assert api.sent == []
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
def test_stop_before_serve_forever_is_preserved(monkeypatch):
|
||||
"""If stop() is called between __aenter__ and serve_forever (e.g. a
|
||||
signal handler fires during the window where run() wires SIGINT), the
|
||||
pre-set _stop_event must NOT be cleared by serve_forever — otherwise
|
||||
the signal is silently lost and the loop runs indefinitely."""
|
||||
import simplex_chat.client as client_mod
|
||||
|
||||
class _FakeApi:
|
||||
@classmethod
|
||||
async def init(cls, *_a, **_kw):
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def started(self):
|
||||
return False
|
||||
|
||||
async def start_chat(self):
|
||||
pass
|
||||
|
||||
async def stop_chat(self):
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
async def api_get_active_user(self):
|
||||
return {"userId": 1, "profile": {"displayName": "x"}}
|
||||
|
||||
async def recv_chat_event(self, wait_us=0):
|
||||
# Should NOT be reached — the loop should exit on the pre-set
|
||||
# stop event before it ever polls for an event.
|
||||
raise AssertionError("receive loop should have exited immediately")
|
||||
|
||||
# _ensure_active_user / _maybe_sync_profile pokes
|
||||
async def send_chat_cmd(self, _cmd):
|
||||
return {"type": "cmdOk"}
|
||||
|
||||
monkeypatch.setattr(client_mod, "ChatApi", _FakeApi)
|
||||
|
||||
c = Client(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
|
||||
async def go():
|
||||
async with c:
|
||||
c.stop() # signal fires before serve_forever
|
||||
await c.serve_forever() # must not block
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
def test_aexit_nulls_api_even_if_close_raises(monkeypatch):
|
||||
"""If `close()` raises inside __aexit__, the Client must still appear
|
||||
closed — `client.api` should refuse to hand back the half-shutdown
|
||||
controller, and re-entering the context manager should re-init cleanly."""
|
||||
import simplex_chat.client as client_mod
|
||||
|
||||
init_count = [0]
|
||||
|
||||
class _BoomCloseApi:
|
||||
@classmethod
|
||||
async def init(cls, *_a, **_kw):
|
||||
init_count[0] += 1
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def started(self):
|
||||
return False
|
||||
|
||||
async def start_chat(self):
|
||||
pass
|
||||
|
||||
async def stop_chat(self):
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
raise RuntimeError("close failed")
|
||||
|
||||
async def api_get_active_user(self):
|
||||
return {"userId": 1, "profile": {"displayName": "x"}}
|
||||
|
||||
async def send_chat_cmd(self, _cmd):
|
||||
return {"type": "cmdOk"}
|
||||
|
||||
monkeypatch.setattr(client_mod, "ChatApi", _BoomCloseApi)
|
||||
|
||||
c = Client(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
|
||||
async def go():
|
||||
with pytest.raises(RuntimeError, match="close failed"):
|
||||
async with c:
|
||||
pass
|
||||
# _api must be None despite close() raising
|
||||
assert c._api is None, "Client._api leaked after __aexit__ close() raised"
|
||||
with pytest.raises(RuntimeError, match="not initialized"):
|
||||
_ = c.api
|
||||
# Re-enter must work
|
||||
try:
|
||||
async with c:
|
||||
pass
|
||||
except RuntimeError:
|
||||
pass # close raises again, fine
|
||||
assert init_count[0] == 2, "re-entry didn't re-init the controller"
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
def test_aenter_rolls_back_partial_init_on_post_start_failure(monkeypatch):
|
||||
"""If anything in __aenter__ raises after ChatApi.init succeeded — including
|
||||
_post_start — the controller must be closed. Otherwise the with-block isn't
|
||||
entered, __aexit__ never runs, and the FFI handle leaks."""
|
||||
import simplex_chat.client as client_mod
|
||||
|
||||
closed: list[str] = []
|
||||
started: list[bool] = [False]
|
||||
|
||||
class FakeChatApi:
|
||||
@classmethod
|
||||
async def init(cls, *_args, **_kwargs):
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def started(self) -> bool:
|
||||
return started[0]
|
||||
|
||||
async def start_chat(self):
|
||||
started[0] = True
|
||||
|
||||
async def stop_chat(self):
|
||||
started[0] = False
|
||||
closed.append("stop")
|
||||
|
||||
async def close(self):
|
||||
closed.append("close")
|
||||
|
||||
# Stub the bits _ensure_active_user / _maybe_sync_profile reach for.
|
||||
async def api_get_active_user(self):
|
||||
return {"userId": 1, "profile": {"displayName": "x"}}
|
||||
|
||||
async def send_chat_cmd(self, _cmd):
|
||||
return {"type": "cmdOk"}
|
||||
|
||||
monkeypatch.setattr(client_mod, "ChatApi", FakeChatApi)
|
||||
|
||||
class Boom(RuntimeError):
|
||||
pass
|
||||
|
||||
class BoomClient(Client):
|
||||
async def _post_start(self, user):
|
||||
raise Boom("kaboom")
|
||||
|
||||
c = BoomClient(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test"))
|
||||
|
||||
async def go():
|
||||
with pytest.raises(Boom):
|
||||
async with c:
|
||||
pytest.fail("should not enter the with-block")
|
||||
|
||||
asyncio.run(go())
|
||||
assert closed == ["stop", "close"], f"controller not cleaned up: {closed}"
|
||||
assert c._api is None, "Client._api should be reset to None after rollback"
|
||||
|
||||
|
||||
def test_lookup_known_contact_propagates_non_command_errors():
|
||||
"""_lookup_known_contact must NOT mask transport / FFI errors as 'unknown
|
||||
link' — only ChatCommandError (malformed link, etc.) should fall through
|
||||
to the handshake path. Bare Exception catch would hide real bugs."""
|
||||
bot, api = _bot_with_fake_api()
|
||||
|
||||
class BoomError(RuntimeError):
|
||||
pass
|
||||
|
||||
async def boom(_user_id, _link):
|
||||
raise BoomError("FFI wedged")
|
||||
|
||||
api.api_connect_plan = boom # type: ignore[assignment]
|
||||
|
||||
async def go():
|
||||
with pytest.raises(BoomError):
|
||||
await bot._lookup_known_contact(1, "link")
|
||||
|
||||
asyncio.run(go())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception subclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_contact_already_exists_is_chat_command_error_subclass():
|
||||
"""Callers should be able to catch the base class to handle all command
|
||||
errors uniformly, and the specific subclass for targeted handling."""
|
||||
from simplex_chat import ChatCommandError, ContactAlreadyExistsError
|
||||
|
||||
assert issubclass(ContactAlreadyExistsError, ChatCommandError)
|
||||
|
||||
e = ContactAlreadyExistsError("x", {"type": "contactAlreadyExists"}) # type: ignore[arg-type]
|
||||
assert isinstance(e, ChatCommandError)
|
||||
assert e.response_type == "contactAlreadyExists"
|
||||
|
||||
|
||||
def test_chat_command_error_response_type_property():
|
||||
from simplex_chat import ChatCommandError
|
||||
|
||||
e = ChatCommandError("x", {"type": "someError"}) # type: ignore[arg-type]
|
||||
assert e.response_type == "someError"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# events() mutual exclusion with serve_forever
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_events_raises_if_already_serving():
|
||||
bot, _api = _bot_with_fake_api()
|
||||
# _serving=True is set by _bot_with_fake_api
|
||||
|
||||
async def go():
|
||||
with pytest.raises(RuntimeError, match="mutually exclusive"):
|
||||
async for _ in bot.events():
|
||||
pass
|
||||
|
||||
asyncio.run(go())
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
from simplex_chat.filters import compile_message_filter
|
||||
|
||||
|
||||
def _msg(content_type="text", text=None, chat_type="direct", group_id=None):
|
||||
def _msg(content_type="text", text=None, chat_type="direct", group_id=None, contact_id=None):
|
||||
"""Build a minimal mock Message-like object for filter testing."""
|
||||
|
||||
class M:
|
||||
@@ -11,12 +11,12 @@ def _msg(content_type="text", text=None, chat_type="direct", group_id=None):
|
||||
|
||||
m = M()
|
||||
m.content = {"type": content_type, "text": text} if text is not None else {"type": content_type}
|
||||
m.chat_item = {
|
||||
"chatInfo": {
|
||||
"type": chat_type,
|
||||
**({"groupInfo": {"groupId": group_id}} if chat_type == "group" else {}),
|
||||
}
|
||||
}
|
||||
chat_info: dict = {"type": chat_type}
|
||||
if chat_type == "group":
|
||||
chat_info["groupInfo"] = {"groupId": group_id}
|
||||
elif chat_type == "direct" and contact_id is not None:
|
||||
chat_info["contact"] = {"contactId": contact_id}
|
||||
m.chat_item = {"chatInfo": chat_info}
|
||||
return m
|
||||
|
||||
|
||||
@@ -81,3 +81,23 @@ def test_group_id_tuple_or():
|
||||
f = compile_message_filter({"group_id": (1, 2, 3)})
|
||||
assert f(_msg(chat_type="group", group_id=2))
|
||||
assert not f(_msg(chat_type="group", group_id=99))
|
||||
|
||||
|
||||
def test_contact_id_filter():
|
||||
f = compile_message_filter({"contact_id": 7})
|
||||
assert f(_msg(chat_type="direct", contact_id=7))
|
||||
assert not f(_msg(chat_type="direct", contact_id=99))
|
||||
assert not f(_msg(chat_type="group", group_id=7))
|
||||
|
||||
|
||||
def test_contact_id_tuple_or():
|
||||
f = compile_message_filter({"contact_id": (1, 2, 3)})
|
||||
assert f(_msg(chat_type="direct", contact_id=2))
|
||||
assert not f(_msg(chat_type="direct", contact_id=99))
|
||||
|
||||
|
||||
def test_contact_id_combined_with_content_type():
|
||||
f = compile_message_filter({"content_type": "text", "contact_id": 5})
|
||||
assert f(_msg(content_type="text", chat_type="direct", contact_id=5))
|
||||
assert not f(_msg(content_type="image", chat_type="direct", contact_id=5))
|
||||
assert not f(_msg(content_type="text", chat_type="direct", contact_id=99))
|
||||
|
||||
Reference in New Issue
Block a user