From 9584992c8331b3198f5d6d370eaaceb4ab8c037a Mon Sep 17 00:00:00 2001 From: sh <37271604+shumvgolove@users.noreply.github.com> Date: Wed, 13 May 2026 15:51:00 +0000 Subject: [PATCH] simplex-chat-python: split Client from Bot, add request/response API (#6976) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * simplex-chat-python: split Client from Bot, add request/response API Client is now the base class for SimpleX participants that talk TO services (monitors, probes, automated participants). Bot extends Client with server features (address, auto-accept, welcome, commands). New methods on Client (inherited by Bot): connect_to(link) idempotent contact handshake send_and_wait(id, text) send a message and await the reply events() async iterator over chat events @on_message(contact_id=N) filter by sender in decorators BotProfile renamed to Profile (alias kept). New ContactAlreadyExistsError subclass for cleaner error handling. * simplex-chat-python: narrow event payload type per @on_event tag @client.on_event("contactConnected") now types the handler's event parameter as CEvt.ContactConnected instead of the unnarrowed CEvt.ChatEvent union — mirroring how @on_message narrows by content_type. The 50 overloads are generated by the Haskell codegen into _events.py (as a Protocol class), so new events stay in sync automatically. Client.on_event is exposed as a property typed as that Protocol; the runtime implementation is unchanged. --- bots/src/API/Docs/Generate/Python.hs | 38 +- .../src/simplex_chat/__init__.py | 15 +- .../src/simplex_chat/api.py | 18 +- .../src/simplex_chat/bot.py | 713 ++----------- .../src/simplex_chat/client.py | 955 ++++++++++++++++++ .../src/simplex_chat/filters.py | 9 + .../src/simplex_chat/types/_events.py | 318 +++++- .../tests/test_bot_registration.py | 36 +- .../tests/test_client_and_waiters.py | 616 +++++++++++ .../simplex-chat-python/tests/test_filters.py | 34 +- 10 files changed, 2089 insertions(+), 663 deletions(-) create mode 100644 packages/simplex-chat-python/src/simplex_chat/client.py create mode 100644 packages/simplex-chat-python/tests/test_client_and_waiters.py diff --git a/bots/src/API/Docs/Generate/Python.hs b/bots/src/API/Docs/Generate/Python.hs index a144aa4376..64aa1d1062 100644 --- a/bots/src/API/Docs/Generate/Python.hs +++ b/bots/src/API/Docs/Generate/Python.hs @@ -83,12 +83,48 @@ responsesCodeText = eventsCodeText :: Text eventsCodeText = ("# API Events\n# " <> autoGenerated <> "\n") - <> pythonImports + <> "from __future__ import annotations\n" + <> "from collections.abc import Awaitable, Callable\n" + <> "from typing import Literal, NotRequired, Protocol, TypedDict, overload\n" + <> "from . import _types as T\n" <> unionTypeCodePy moduleMember "T." "ChatEvent" chatEventConstrs + <> onEventProtocolCode chatEventConstrs where chatEventConstrs = L.fromList $ concatMap catEvents chatEventsDocs catEvents CECategory {mainEvents, otherEvents} = map eventType $ mainEvents ++ otherEvents +-- | Render the `OnEventDecorator` Protocol — one `__call__` overload per +-- event tag, narrowing the handler's event parameter from the unnarrowed +-- `ChatEvent` union to the specific tagged TypedDict. Plus a fallback +-- overload for `event: str` that keeps the unnarrowed shape so non-literal +-- tags don't trigger a type error. +-- +-- `Client.on_event` is typed as a `OnEventDecorator` (via a property) so +-- callers get per-tag narrowing without per-tag handwritten overloads +-- in client.py. +onEventProtocolCode :: L.NonEmpty ATUnionMember -> Text +onEventProtocolCode members = + "\n\nclass OnEventDecorator(Protocol):\n" + <> " \"\"\"Per-tag narrowing protocol for ``Client.on_event``.\n" + <> "\n" + <> " ``@client.on_event(\"contactConnected\")`` types the handler's\n" + <> " ``evt`` parameter as :class:`ContactConnected` rather than the\n" + <> " unnarrowed :data:`ChatEvent` union.\n" + <> " \"\"\"\n" + <> foldMap overloadCode (L.toList members) + <> "\n @overload\n" + <> " def __call__(self, event: str, /) -> Callable[\n" + <> " [Callable[[\"ChatEvent\"], Awaitable[None]]],\n" + <> " Callable[[\"ChatEvent\"], Awaitable[None]],\n" + <> " ]: ...\n" + where + overloadCode (ATUnionMember tag _) = + "\n @overload\n" + <> " def __call__(self, event: Literal[\"" <> T.pack tag <> "\"], /) -> Callable[\n" + <> " [Callable[[\"" <> pyConstrName tag <> "\"], Awaitable[None]]],\n" + <> " Callable[[\"" <> pyConstrName tag <> "\"], Awaitable[None]],\n" + <> " ]: ...\n" + typesCodeText :: Text typesCodeText = ("# API Types\n# " <> autoGenerated <> "\n") diff --git a/packages/simplex-chat-python/src/simplex_chat/__init__.py b/packages/simplex-chat-python/src/simplex_chat/__init__.py index dfafef123a..c353b74935 100644 --- a/packages/simplex-chat-python/src/simplex_chat/__init__.py +++ b/packages/simplex-chat-python/src/simplex_chat/__init__.py @@ -1,12 +1,21 @@ """SimpleX Chat — Python client library for chat bots.""" from ._version import __version__ -from .api import ChatApi, ChatCommandError, ConnReqType, Db, PostgresDb, SqliteDb +from .api import ( + ChatApi, + ChatCommandError, + ConnReqType, + ContactAlreadyExistsError, + Db, + PostgresDb, + SqliteDb, +) from .bot import ( Bot, BotCommand, BotProfile, ChatMessage, + Client, CommandHandler, EventHandler, FileMessage, @@ -16,6 +25,7 @@ from .bot import ( MessageHandler, Middleware, ParsedCommand, + Profile, ReportMessage, TextMessage, UnknownMessage, @@ -35,8 +45,10 @@ __all__ = [ "ChatCommandError", "ChatInitError", "ChatMessage", + "Client", "CommandHandler", "ConnReqType", + "ContactAlreadyExistsError", "CryptoArgs", "Db", "EventHandler", @@ -49,6 +61,7 @@ __all__ = [ "MigrationConfirmation", "ParsedCommand", "PostgresDb", + "Profile", "ReportMessage", "SqliteDb", "TextMessage", diff --git a/packages/simplex-chat-python/src/simplex_chat/api.py b/packages/simplex-chat-python/src/simplex_chat/api.py index 8f116c903f..ef37e28384 100644 --- a/packages/simplex-chat-python/src/simplex_chat/api.py +++ b/packages/simplex-chat-python/src/simplex_chat/api.py @@ -40,10 +40,26 @@ def _db_to_migrate_args(db: Db) -> tuple[str, str, _native.Backend]: class ChatCommandError(Exception): + """A chat command returned an unexpected response type. + + `response` is the raw wire response; `response_type` exposes its `type` + discriminator for quick checks. Subclasses cover known recoverable cases + so callers can `except ContactAlreadyExistsError` instead of inspecting + `response.get("type")` themselves. + """ + def __init__(self, message: str, response: CR.ChatResponse): super().__init__(message) self.response = response + @property + def response_type(self) -> str: + return self.response.get("type", "") # type: ignore[return-value] + + +class ContactAlreadyExistsError(ChatCommandError): + """`api_connect`/`api_connect_active_user` was called for an existing contact.""" + class ChatApi: def __init__(self, ctrl: int): @@ -481,7 +497,7 @@ class ChatApi: if r["type"] == "sentInvitation": return "contact" if r["type"] == "contactAlreadyExists": - raise ChatCommandError("contact already exists", r) + raise ContactAlreadyExistsError("contact already exists", r) raise ChatCommandError("connection error", r) async def api_accept_contact_request(self, contact_req_id: int) -> T.Contact: diff --git a/packages/simplex-chat-python/src/simplex_chat/bot.py b/packages/simplex-chat-python/src/simplex_chat/bot.py index 7414f28b87..fb511e2818 100644 --- a/packages/simplex-chat-python/src/simplex_chat/bot.py +++ b/packages/simplex-chat-python/src/simplex_chat/bot.py @@ -1,32 +1,34 @@ -"""User-facing `Bot` API: decorators, filters, Message wrapper, lifecycle.""" +"""`Bot` — Client extended with server-side features (address, auto-accept, commands).""" from __future__ import annotations -import asyncio -import logging -import os -import signal as _signal -from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, overload from . import util -from .api import ChatApi, Db -from .core import ChatAPIError, MigrationConfirmation -from .filters import compile_message_filter -from .types import CEvt, T - -log = logging.getLogger("simplex_chat") - -C = TypeVar("C", bound="T.MsgContent") - - -@dataclass(slots=True) -class BotProfile: - display_name: str - full_name: str = "" - short_descr: str | None = None - image: str | None = None +from .api import Db +from .client import ( + BotProfile, + ChatMessage, + Client, + CommandHandler, + EventHandler, + FileMessage, + ImageMessage, + LinkMessage, + Message, + MessageHandler, + Middleware, + ParsedCommand, + Profile, + ReportMessage, + TextMessage, + UnknownMessage, + VideoMessage, + VoiceMessage, + log, +) +from .core import MigrationConfirmation +from .types import T @dataclass(slots=True) @@ -35,85 +37,25 @@ class BotCommand: label: str -@dataclass(slots=True, frozen=True) -class ParsedCommand: - keyword: str - args: str +class Bot(Client): + """SimpleX bot — Client extended with server-side features. + On top of `Client` (identity + messaging + connect_to/send_and_wait/events), + a Bot: + - creates and announces its own contact address + - auto-accepts incoming contact requests (configurable) + - advertises a list of slash-commands in its profile preferences + - sets `peerType=bot` and disables calls/voice in profile prefs + - sends a `welcome` message to new contacts via the auto-reply address setting -@dataclass(slots=True, frozen=True) -class Message(Generic[C]): - chat_item: T.AChatItem - content: C - bot: "Bot" - - @property - def chat_info(self) -> T.ChatInfo: - return self.chat_item["chatInfo"] - - @property - def text(self) -> str | None: - c = self.content - if isinstance(c, dict): - return c.get("text") # type: ignore[return-value] - return None - - async def reply(self, text: str) -> "Message[T.MsgContent]": - items = await self.bot.api.api_send_text_reply(self.chat_item, text) - ci = items[0] - content = ci["chatItem"]["content"] - # content is CIContent — snd variant has msgContent; cast for type safety. - msg_content: T.MsgContent = content["msgContent"] # type: ignore[index] - return Message(chat_item=ci, content=msg_content, bot=self.bot) - - async def reply_content(self, content: T.MsgContent) -> "Message[T.MsgContent]": - items = await self.bot.api.api_send_messages( - self.chat_info, [{"msgContent": content, "mentions": {}}] - ) - ci = items[0] - ci_content = ci["chatItem"]["content"] - msg_content: T.MsgContent = ci_content["msgContent"] # type: ignore[index] - return Message(chat_item=ci, content=msg_content, bot=self.bot) - - -# Concrete narrowed aliases — one per MsgContent_ variant in _types.py. -TextMessage = Message[T.MsgContent_text] -LinkMessage = Message[T.MsgContent_link] -ImageMessage = Message[T.MsgContent_image] -VideoMessage = Message[T.MsgContent_video] -VoiceMessage = Message[T.MsgContent_voice] -FileMessage = Message[T.MsgContent_file] -ReportMessage = Message[T.MsgContent_report] -ChatMessage = Message[T.MsgContent_chat] -UnknownMessage = Message[T.MsgContent_unknown] - -MessageHandler = Callable[[Message[Any]], Awaitable[None]] -CommandHandler = Callable[[Message[Any], ParsedCommand], Awaitable[None]] -EventHandler = Callable[[CEvt.ChatEvent], Awaitable[None]] - - -class Middleware: - """Override `__call__` to wrap message handlers with cross-cutting logic. - - `handler` is the next stage in the chain — call it with `(message, data)` - to continue, or skip the call to short-circuit. `data` is a per-dispatch - dict that middleware can use to pass values down the chain. + If you want just identity + messaging without any of that, use `Client` + directly. """ - async def __call__( - self, - handler: Callable[[Message[Any], dict[str, object]], Awaitable[None]], - message: Message[Any], - data: dict[str, object], - ) -> None: - await handler(message, data) - - -class Bot: def __init__( self, *, - profile: BotProfile, + profile: Profile, db: Db, welcome: str | T.MsgContent | None = None, commands: list[BotCommand] | None = None, @@ -124,423 +66,42 @@ class Bot: auto_accept: bool = True, business_address: bool = False, allow_files: bool = False, - use_bot_profile: bool = True, log_contacts: bool = True, log_network: bool = False, ) -> None: - self._profile = profile - self._db = db + super().__init__( + profile=profile, + db=db, + confirm_migrations=confirm_migrations, + update_profile=update_profile, + log_contacts=log_contacts, + log_network=log_network, + ) self._welcome = welcome self._commands = commands or [] - self._confirm_migrations = confirm_migrations - self._opts = { - "create_address": create_address, - "update_address": update_address, - "update_profile": update_profile, - "auto_accept": auto_accept, - "business_address": business_address, - "allow_files": allow_files, - "use_bot_profile": use_bot_profile, - "log_contacts": log_contacts, - "log_network": log_network, - } - self._api: ChatApi | None = None - self._serving = False - self._stop_event = asyncio.Event() - self._message_handlers: list[tuple[Callable[[Message[Any]], bool], MessageHandler]] = [] - self._command_handlers: list[ - tuple[tuple[str, ...], Callable[[Message[Any]], bool], CommandHandler] - ] = [] - self._event_handlers: dict[str, list[EventHandler]] = {} - self._middleware: list[Middleware] = [] - # Track default-handler registration so __aenter__ on a re-used bot - # doesn't accumulate duplicate log/error handlers. - self._defaults_registered = False - - @property - def api(self) -> ChatApi: - if self._api is None: - raise RuntimeError("Bot not initialized — call bot.run() or use `async with bot:`") - return self._api + self._create_address = create_address + self._update_address = update_address + self._auto_accept = auto_accept + self._business_address = business_address + self._allow_files = allow_files # ------------------------------------------------------------------ # - # Decorators + # Profile + address sync (overrides hooks in Client) # ------------------------------------------------------------------ # - @overload - def on_message( - self, *, content_type: Literal["text"], **rest: Any - ) -> Callable[ - [Callable[[TextMessage], Awaitable[None]]], - Callable[[TextMessage], Awaitable[None]], - ]: ... + async def _post_start(self, user: T.User) -> None: + """Bots sync address first, then embed the link in the profile.""" + link = await self._sync_address(user) + await self._maybe_sync_profile(user, contact_link=link) - @overload - def on_message( - self, *, content_type: Literal["link"], **rest: Any - ) -> Callable[ - [Callable[[LinkMessage], Awaitable[None]]], - Callable[[LinkMessage], Awaitable[None]], - ]: ... - - @overload - def on_message( - self, *, content_type: Literal["image"], **rest: Any - ) -> Callable[ - [Callable[[ImageMessage], Awaitable[None]]], - Callable[[ImageMessage], Awaitable[None]], - ]: ... - - @overload - def on_message( - self, *, content_type: Literal["video"], **rest: Any - ) -> Callable[ - [Callable[[VideoMessage], Awaitable[None]]], - Callable[[VideoMessage], Awaitable[None]], - ]: ... - - @overload - def on_message( - self, *, content_type: Literal["voice"], **rest: Any - ) -> Callable[ - [Callable[[VoiceMessage], Awaitable[None]]], - Callable[[VoiceMessage], Awaitable[None]], - ]: ... - - @overload - def on_message( - self, *, content_type: Literal["file"], **rest: Any - ) -> Callable[ - [Callable[[FileMessage], Awaitable[None]]], - Callable[[FileMessage], Awaitable[None]], - ]: ... - - @overload - def on_message( - self, *, content_type: Literal["report"], **rest: Any - ) -> Callable[ - [Callable[[ReportMessage], Awaitable[None]]], - Callable[[ReportMessage], Awaitable[None]], - ]: ... - - @overload - def on_message( - self, *, content_type: Literal["chat"], **rest: Any - ) -> Callable[ - [Callable[[ChatMessage], Awaitable[None]]], - Callable[[ChatMessage], Awaitable[None]], - ]: ... - - @overload - def on_message( - self, *, content_type: Literal["unknown"], **rest: Any - ) -> Callable[ - [Callable[[UnknownMessage], Awaitable[None]]], - Callable[[UnknownMessage], Awaitable[None]], - ]: ... - - @overload - def on_message(self, **rest: Any) -> Callable[[MessageHandler], MessageHandler]: ... - - def on_message(self, **filter_kw: Any) -> Callable[[MessageHandler], MessageHandler]: - predicate = compile_message_filter(filter_kw) - - def deco(fn: MessageHandler) -> MessageHandler: - self._message_handlers.append((predicate, fn)) - return fn - - return deco - - def on_command( - self, name: str | tuple[str, ...], **filter_kw: Any - ) -> Callable[[CommandHandler], CommandHandler]: - names = (name,) if isinstance(name, str) else tuple(name) - predicate = compile_message_filter(filter_kw) - - def deco(fn: CommandHandler) -> CommandHandler: - self._command_handlers.append((names, predicate, fn)) - return fn - - return deco - - def on_event(self, event: CEvt.ChatEvent_Tag, /) -> Callable[[EventHandler], EventHandler]: - def deco(fn: EventHandler) -> EventHandler: - self._event_handlers.setdefault(event, []).append(fn) - return fn - - return deco - - def use(self, middleware: Middleware) -> None: - self._middleware.append(middleware) - - # ------------------------------------------------------------------ # - # Lifecycle - # ------------------------------------------------------------------ # - - async def __aenter__(self) -> "Bot": - # Order matters: libsimplex `/_start` requires an active user, so - # ensure (or create) the user first, THEN start the chat, THEN - # do address + profile sync. Mirrors Node bot.ts:48-64. - self._api = await ChatApi.init(self._db, self._confirm_migrations) - user = await self._ensure_active_user() - await self._api.start_chat() - await self._sync_address_and_profile(user) - self._register_log_handlers() - return self - - async def __aexit__(self, *exc_info: object) -> None: - self.stop() - if self._api is not None: - try: - await self._api.stop_chat() - finally: - await self._api.close() - self._api = None - - def run(self) -> None: - """Blocking entry: runs serve_forever() with SIGINT/SIGTERM handlers installed. - - Configures `logging.basicConfig(level=INFO)` if the root logger has no - handlers yet, so the bot's startup messages and the announced address - are visible without callers having to set up logging. Embedders that - manage logging themselves are unaffected (basicConfig is a no-op when - handlers already exist). - """ - if not logging.getLogger().handlers: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(name)s %(message)s", - ) - - async def _main() -> None: - async with self: - loop = asyncio.get_running_loop() - # First Ctrl+C → graceful stop (~500ms, bounded by the - # receive-loop poll interval). Second Ctrl+C → force-exit - # immediately (in case stop_chat / close hang on a wedged - # FFI call). Standard CLI UX (jupyter, ipython, …). - sigint_count = 0 - - def on_interrupt() -> None: - nonlocal sigint_count - sigint_count += 1 - if sigint_count == 1: - log.info("stopping bot... (press Ctrl+C again to force exit)") - self.stop() - else: - os._exit(130) # 128 + SIGINT - - if hasattr(_signal, "SIGINT"): - try: - loop.add_signal_handler(_signal.SIGINT, on_interrupt) - loop.add_signal_handler(_signal.SIGTERM, self.stop) - except NotImplementedError: # Windows - _signal.signal(_signal.SIGINT, lambda *_: on_interrupt()) - await self.serve_forever() - - asyncio.run(_main()) - - async def serve_forever(self) -> None: - if self._serving: - raise RuntimeError("already serving") - self._serving = True - self._stop_event.clear() - try: - await self._receive_loop() - finally: - self._serving = False - - def stop(self) -> None: - self._stop_event.set() - - async def _receive_loop(self) -> None: - # Catch broad Exception so a single malformed event or transient - # native error doesn't crash the whole bot. CancelledError must - # always re-raise so `bot.stop()` and asyncio cancellation work. - # `wait_us=500_000` (500ms) bounds the worst-case Ctrl+C latency: - # the C call blocks the worker thread until timeout, and the loop - # only checks `_stop_event` between polls. - while not self._stop_event.is_set(): - try: - event = await self.api.recv_chat_event(wait_us=500_000) - except asyncio.CancelledError: - raise - except ChatAPIError as e: - # Async chat errors emitted via the Haskell `eToView` path — - # routine soft errors (stale connections after a peer deletes - # a chat, file cleanup failures, etc.) intermixed with - # CRITICAL agent failures the operator must see. Mirror the - # desktop policy in SimpleXAPI.kt:3332-3340: escalate - # CRITICAL agent errors, keep everything else at debug. - chat_err: Any = e.chat_error or {} - agent_err: Any = ( - chat_err.get("agentError", {}) if chat_err.get("type") == "errorAgent" else {} - ) - if agent_err.get("type") == "CRITICAL": - log.error( - "chat agent CRITICAL: %s (offerRestart=%s)", - agent_err.get("criticalErr"), - agent_err.get("offerRestart"), - ) - else: - log.debug("chat event error: %s", chat_err.get("type")) - continue - except Exception: - log.exception("recv_chat_event failed") - # Bound the spin rate when the FFI is wedged on a persistent - # error (vs the timeout path, which already paces itself). - await asyncio.sleep(0.5) - continue - if event is None: - continue - try: - await self._dispatch_event(event) - except asyncio.CancelledError: - raise - except Exception: - log.exception("dispatch_event failed for tag=%s", event.get("type")) - - # ------------------------------------------------------------------ # - # Dispatch - # ------------------------------------------------------------------ # - - async def _dispatch_event(self, event: CEvt.ChatEvent) -> None: - tag = event["type"] - for h in self._event_handlers.get(tag, []): - try: - await h(event) - except Exception: - log.exception("on_event handler failed") - if tag == "newChatItems": - evt: CEvt.NewChatItems = event # type: ignore[assignment] - for ci in evt["chatItems"]: - content = ci["chatItem"]["content"] - if content["type"] != "rcvMsgContent": - continue - msg_content = content["msgContent"] # type: ignore[index] - msg: Message[T.MsgContent] = Message(chat_item=ci, content=msg_content, bot=self) - await self._dispatch_message(msg) - - async def _dispatch_message(self, msg: Message[Any]) -> None: - # First-match-wins. The squaring bot's `@on_message(text=NUMBER_RE)` - # and catch-all `@on_message(content_type="text")` both match a number - # like "1"; we want only the first to fire. Registration order is the - # priority order — register the most-specific filters first. - # - # Slash-commands are tried first against command handlers; if no - # command handler matches, fall through to message handlers (so - # `@on_message` can still catch unknown slash-commands). - cmd = self._parse_command(msg) - if cmd is not None: - for names, predicate, handler in self._command_handlers: - if cmd.keyword in names and predicate(msg): - await self._invoke_command_with_middleware(handler, msg, cmd) - return - for predicate, handler in self._message_handlers: - if predicate(msg): - await self._invoke_with_middleware(handler, msg) - return - - async def _invoke_with_middleware(self, handler: MessageHandler, message: Message[Any]) -> None: - # Fast path: most bots register no middleware. Skip the closure-chain - # construction and the empty-data dict on every dispatch. - if not self._middleware: - try: - await handler(message) - except Exception: - log.exception("message handler failed") - return - - async def call(m: Message[Any], _data: dict[str, object]) -> None: - await handler(m) - - chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call - for mw in reversed(self._middleware): - inner = chain - - async def _wrapped( - m: Message[Any], - d: dict[str, object], - mw: Middleware = mw, - inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner, - ) -> None: - await mw(inner, m, d) - - chain = _wrapped - - try: - await chain(message, {}) - except Exception: - log.exception("message handler failed") - - async def _invoke_command_with_middleware( - self, handler: CommandHandler, message: Message[Any], cmd: ParsedCommand - ) -> None: - if not self._middleware: - try: - await handler(message, cmd) - except Exception: - log.exception("command handler failed") - return - - async def call(m: Message[Any], _data: dict[str, object]) -> None: - await handler(m, cmd) - - chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call - for mw in reversed(self._middleware): - inner = chain - - async def _wrapped( - m: Message[Any], - d: dict[str, object], - mw: Middleware = mw, - inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner, - ) -> None: - await mw(inner, m, d) - - chain = _wrapped - - try: - await chain(message, {}) - except Exception: - log.exception("command handler failed") - - @staticmethod - def _parse_command(msg: Message[Any]) -> ParsedCommand | None: - parsed = util.ci_bot_command(msg.chat_item["chatItem"]) - if parsed is None: - return None - keyword, args = parsed - return ParsedCommand(keyword=keyword, args=args) - - # ------------------------------------------------------------------ # - # Profile + address sync - # ------------------------------------------------------------------ # - - async def _ensure_active_user(self) -> T.User: - """Get or create the active user. Must run before `start_chat`. - - Mirrors Node `createBotUser` (bot.ts:158-166). The chat controller - won't accept `/_start` without a user, so this phase has to land - before lifecycle proceeds. - """ - api = self.api - user = await api.api_get_active_user() - if user is None: - log.info("No active user in database, creating...") - user = await api.api_create_active_user(self._bot_profile_to_wire()) - log.info("Bot user: %s", user["profile"]["displayName"]) - return user - - async def _sync_address_and_profile(self, user: T.User) -> None: - """Address + profile sync. Runs after `start_chat` (mirrors bot.ts:57-63).""" + async def _sync_address(self, user: T.User) -> str | None: + """Address sync. Returns the public link if any, for embedding in the profile.""" api = self.api user_id = user["userId"] - # 2. Address (numbered to match bot.ts comments — phase 1 was user creation). address = await api.api_get_user_address(user_id) if address is None: - if self._opts["create_address"]: + if self._create_address: log.info("Bot has no address, creating...") await api.api_create_user_address(user_id) address = await api.api_get_user_address(user_id) @@ -549,17 +110,16 @@ class Bot: else: log.warning("Bot has no address") - # Always announce the address — matches Node bot.ts:60. link: str | None = None if address is not None: link = util.contact_address_str(address["connLinkContact"]) log.info("Bot address: %s", link) - # 3. Address settings (auto-accept + welcome message). Mirrors bot.ts:185-194. + # Address settings (auto-accept + welcome message). Mirrors bot.ts:185-194. # autoAccept present → accept; absent → no auto-accept (mirrors Node bot.ts). - if address is not None and self._opts["update_address"]: - desired: T.AddressSettings = {"businessAddress": self._opts["business_address"]} - if self._opts["auto_accept"]: + if address is not None and self._update_address: + desired: T.AddressSettings = {"businessAddress": self._business_address} + if self._auto_accept: desired["autoAccept"] = {"acceptIncognito": False} if self._welcome is not None: desired["autoReply"] = ( @@ -571,154 +131,45 @@ class Bot: log.info("Bot address settings changed, updating...") await api.api_set_address_settings(user_id, desired) - # 4. Profile update. Mirrors Node `updateBotUserProfile` (bot.ts:199-214). - # Field-by-field comparison: user["profile"] is LocalProfile (has extra - # fields profileId, localAlias, preferences, peerType) so a full-dict - # equality would always differ. - new_profile = self._bot_profile_to_wire() - if link is not None and self._opts["use_bot_profile"]: - # Mirrors bot.ts:62 — embed the connection link in the bot's profile - # so contacts that resolve the bot via stored profile data see the - # current address. - new_profile["contactLink"] = link - cur = user["profile"] - changed = ( - cur["displayName"] != new_profile["displayName"] - or cur.get("fullName", "") != new_profile.get("fullName", "") - or cur.get("shortDescr") != new_profile.get("shortDescr") - or cur.get("image") != new_profile.get("image") - or cur.get("preferences") != new_profile.get("preferences") - or cur.get("peerType") != new_profile.get("peerType") - or cur.get("contactLink") != new_profile.get("contactLink") - ) - if changed and self._opts["update_profile"]: - log.info("Bot profile changed, updating...") - await api.api_update_profile(user_id, new_profile) + return link - def _bot_profile_to_wire(self) -> T.Profile: - """Construct wire-format Profile, applying bot conventions when use_bot_profile=True. + def _profile_to_wire(self) -> T.Profile: + """Bot profile: base profile + peerType=bot, command list, calls/voice prefs disabled. - Mirrors Node mkBotProfile (bot.ts:88-102): bots get peerType="bot", - calls/voice prefs disabled, files gated on `allow_files`, and any - registered `commands` embedded in the profile preferences. + Mirrors Node `mkBotProfile` (bot.ts:88-102). """ - p: T.Profile = { - "displayName": self._profile.display_name, - "fullName": self._profile.full_name, + p = super()._profile_to_wire() + prefs: T.Preferences = { + "calls": {"allow": "no"}, + "voice": {"allow": "no"}, + "files": {"allow": "yes" if self._allow_files else "no"}, } - if self._profile.short_descr is not None: - p["shortDescr"] = self._profile.short_descr - if self._profile.image is not None: - p["image"] = self._profile.image - if self._opts["use_bot_profile"]: - prefs: T.Preferences = { - "calls": {"allow": "no"}, - "voice": {"allow": "no"}, - "files": {"allow": "yes" if self._opts["allow_files"] else "no"}, - } - if self._commands: - prefs["commands"] = [ - {"type": "command", "keyword": c.keyword, "label": c.label} - for c in self._commands - ] - p["preferences"] = prefs - p["peerType"] = "bot" - elif self._commands: - raise ValueError( - "use_bot_profile=False but commands were passed; commands are " - "only sent when use_bot_profile=True (they're embedded in the " - "user profile preferences)." - ) + if self._commands: + prefs["commands"] = [ + {"type": "command", "keyword": c.keyword, "label": c.label} + for c in self._commands + ] + p["preferences"] = prefs + p["peerType"] = "bot" return p - # ------------------------------------------------------------------ # - # Log subscriptions (mirror Node subscribeLogEvents bot.ts:142-156) - # ------------------------------------------------------------------ # - def _register_log_handlers(self) -> None: - # Idempotent: a Bot reused across multiple `__aenter__` cycles must - # not stack duplicate log handlers. Always-on error handlers run - # regardless of log_contacts/log_network so messageError/chatError/ - # chatErrors don't disappear into the void. - if self._defaults_registered: - return - self._defaults_registered = True - self._event_handlers.setdefault("messageError", []).append(self._log_message_error) - self._event_handlers.setdefault("chatError", []).append(self._log_chat_error) - self._event_handlers.setdefault("chatErrors", []).append(self._log_chat_errors) - if self._opts["log_contacts"]: - self._event_handlers.setdefault("contactConnected", []).append( - self._log_contact_connected - ) - self._event_handlers.setdefault("contactDeletedByContact", []).append( - self._log_contact_deleted - ) - if self._opts["log_network"]: - self._event_handlers.setdefault("hostConnected", []).append(self._log_host_connected) - self._event_handlers.setdefault("hostDisconnected", []).append( - self._log_host_disconnected - ) - self._event_handlers.setdefault("subscriptionStatus", []).append( - self._log_subscription_status - ) - - @staticmethod - async def _log_contact_connected(evt: CEvt.ChatEvent) -> None: - log.info("%s connected", evt["contact"]["profile"]["displayName"]) # type: ignore[index] - - @staticmethod - async def _log_contact_deleted(evt: CEvt.ChatEvent) -> None: - log.info( - "%s deleted connection with bot", - evt["contact"]["profile"]["displayName"], # type: ignore[index] - ) - - @staticmethod - async def _log_host_connected(evt: CEvt.ChatEvent) -> None: - log.info("connected server %s", evt["transportHost"]) # type: ignore[index] - - @staticmethod - async def _log_host_disconnected(evt: CEvt.ChatEvent) -> None: - log.info("disconnected server %s", evt["transportHost"]) # type: ignore[index] - - @staticmethod - async def _log_subscription_status(evt: CEvt.ChatEvent) -> None: - log.info( - "%d subscription(s) %s", - len(evt["connections"]), # type: ignore[index] - evt["subscriptionStatus"]["type"], # type: ignore[index] - ) - - @staticmethod - async def _log_message_error(evt: CEvt.ChatEvent) -> None: - log.warning("messageError: %s", evt.get("severity", "?")) # type: ignore[union-attr] - - @staticmethod - async def _log_chat_error(evt: CEvt.ChatEvent) -> None: - err = evt.get("chatError") # type: ignore[union-attr] - log.error("chatError: %s", err.get("type") if isinstance(err, dict) else err) - - @staticmethod - async def _log_chat_errors(evt: CEvt.ChatEvent) -> None: - errs = evt.get("chatErrors") or [] # type: ignore[union-attr] - log.error("chatErrors: %d errors", len(errs)) - - -# Suppress unused-import warnings for re-exported names used only at type-check time. __all__ = [ "Bot", "BotCommand", "BotProfile", "ChatMessage", + "Client", + "CommandHandler", + "EventHandler", "FileMessage", "ImageMessage", "LinkMessage", "Message", "MessageHandler", - "CommandHandler", - "EventHandler", "Middleware", "ParsedCommand", + "Profile", "ReportMessage", "TextMessage", "UnknownMessage", diff --git a/packages/simplex-chat-python/src/simplex_chat/client.py b/packages/simplex-chat-python/src/simplex_chat/client.py new file mode 100644 index 0000000000..b0d144b8b9 --- /dev/null +++ b/packages/simplex-chat-python/src/simplex_chat/client.py @@ -0,0 +1,955 @@ +"""Base `Client` API: lifecycle, dispatch, decorators, connect_to / send_and_wait / events. + +Bot extends Client to add server-side features (address, auto-accept, welcome, +commands). Client by itself is suitable for monitors, probes, automated +participants — anything that talks TO services rather than accepting incoming +connections. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import signal as _signal +from collections.abc import AsyncIterator, Awaitable, Callable +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, overload + +from . import util +from .api import ChatApi, ChatCommandError, ContactAlreadyExistsError, Db +from .core import ChatAPIError, MigrationConfirmation +from .filters import compile_message_filter +from .types import CEvt, T + +log = logging.getLogger("simplex_chat") + +C = TypeVar("C", bound="T.MsgContent") + + +@dataclass(slots=True) +class Profile: + """SimpleX user profile fields: display name, optional full name, descr, avatar. + + Universal — used by both `Client` and `Bot`. The bot-specific extensions + (peerType=bot, command list, calls/voice preferences) are added at + wire-conversion time by `Bot`, not stored here. + """ + + display_name: str + full_name: str = "" + short_descr: str | None = None + image: str | None = None + + +# Backwards-compatibility alias — the dataclass was named `BotProfile` before +# the Client/Bot hierarchy was introduced. Keep the old name working so +# `from simplex_chat import BotProfile` doesn't break existing code. +BotProfile = Profile + + +@dataclass(slots=True, frozen=True) +class ParsedCommand: + keyword: str + args: str + + +@dataclass(slots=True, frozen=True) +class Message(Generic[C]): + chat_item: T.AChatItem + content: C + client: "Client" + + @property + def chat_info(self) -> T.ChatInfo: + return self.chat_item["chatInfo"] + + @property + def text(self) -> str | None: + c = self.content + if isinstance(c, dict): + return c.get("text") # type: ignore[return-value] + return None + + async def reply(self, text: str) -> "Message[T.MsgContent]": + items = await self.client.api.api_send_text_reply(self.chat_item, text) + ci = items[0] + content = ci["chatItem"]["content"] + # content is CIContent — snd variant has msgContent; cast for type safety. + msg_content: T.MsgContent = content["msgContent"] # type: ignore[index] + return Message(chat_item=ci, content=msg_content, client=self.client) + + async def reply_content(self, content: T.MsgContent) -> "Message[T.MsgContent]": + items = await self.client.api.api_send_messages( + self.chat_info, [{"msgContent": content, "mentions": {}}] + ) + ci = items[0] + ci_content = ci["chatItem"]["content"] + msg_content: T.MsgContent = ci_content["msgContent"] # type: ignore[index] + return Message(chat_item=ci, content=msg_content, client=self.client) + + +# Concrete narrowed aliases — one per MsgContent_ variant in _types.py. +TextMessage = Message[T.MsgContent_text] +LinkMessage = Message[T.MsgContent_link] +ImageMessage = Message[T.MsgContent_image] +VideoMessage = Message[T.MsgContent_video] +VoiceMessage = Message[T.MsgContent_voice] +FileMessage = Message[T.MsgContent_file] +ReportMessage = Message[T.MsgContent_report] +ChatMessage = Message[T.MsgContent_chat] +UnknownMessage = Message[T.MsgContent_unknown] + +MessageHandler = Callable[[Message[Any]], Awaitable[None]] +CommandHandler = Callable[[Message[Any], ParsedCommand], Awaitable[None]] +EventHandler = Callable[[CEvt.ChatEvent], Awaitable[None]] + + +class Middleware: + """Override `__call__` to wrap message handlers with cross-cutting logic. + + `handler` is the next stage in the chain — call it with `(message, data)` + to continue, or skip the call to short-circuit. `data` is a per-dispatch + dict that middleware can use to pass values down the chain. + """ + + async def __call__( + self, + handler: Callable[[Message[Any], dict[str, object]], Awaitable[None]], + message: Message[Any], + data: dict[str, object], + ) -> None: + await handler(message, data) + + +class Client: + """SimpleX participant — has an identity, sends and receives messages. + + No address, no auto-accept of incoming requests, no bot profile prefs. Use + this for monitors, probes, automated participants — anything that talks + TO services rather than accepting incoming connections. Use `Bot` for the + server-side flavour. + + Typical pattern: + + async with Client(profile=Profile(display_name="m"), db=...) as c: + serve = asyncio.create_task(c.serve_forever()) + contact = await c.connect_to(link) + reply = await c.send_and_wait(contact["contactId"], "/help") + c.stop() + await serve + + The decorator-style handlers (`@on_message`, `@on_command`, `@on_event`) + work too if you want callback-style dispatch instead of async-await. + """ + + def __init__( + self, + *, + profile: Profile, + db: Db, + confirm_migrations: MigrationConfirmation = MigrationConfirmation.YES_UP, + update_profile: bool = True, + log_contacts: bool = False, + log_network: bool = False, + ) -> None: + self._profile = profile + self._db = db + self._confirm_migrations = confirm_migrations + self._update_profile = update_profile + self._log_contacts = log_contacts + self._log_network = log_network + self._api: ChatApi | None = None + self._serving = False + self._stop_event = asyncio.Event() + self._message_handlers: list[tuple[Callable[[Message[Any]], bool], MessageHandler]] = [] + self._command_handlers: list[ + tuple[tuple[str, ...], Callable[[Message[Any]], bool], CommandHandler] + ] = [] + self._event_handlers: dict[str, list[EventHandler]] = {} + self._middleware: list[Middleware] = [] + # Track default-handler registration so __aenter__ on a re-used client + # doesn't accumulate duplicate log/error handlers. + self._defaults_registered = False + # Internal waiters used by `send_and_wait` (keyed by contact_id, FIFO + # within a contact) and `connect_to` (one-shot, resolved on the next + # contactConnected event). Populated by user-async-callers, drained + # in `_dispatch_event` before user handlers run. + self._reply_waiters: dict[int, list[asyncio.Future[Message[Any]]]] = {} + self._connect_waiters: list[asyncio.Future[T.Contact]] = [] + + @property + def api(self) -> ChatApi: + if self._api is None: + raise RuntimeError("Client not initialized — call run() or use `async with client:`") + return self._api + + # ------------------------------------------------------------------ # + # Decorators + # ------------------------------------------------------------------ # + + @overload + def on_message( + self, *, content_type: Literal["text"], **rest: Any + ) -> Callable[ + [Callable[[TextMessage], Awaitable[None]]], + Callable[[TextMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["link"], **rest: Any + ) -> Callable[ + [Callable[[LinkMessage], Awaitable[None]]], + Callable[[LinkMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["image"], **rest: Any + ) -> Callable[ + [Callable[[ImageMessage], Awaitable[None]]], + Callable[[ImageMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["video"], **rest: Any + ) -> Callable[ + [Callable[[VideoMessage], Awaitable[None]]], + Callable[[VideoMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["voice"], **rest: Any + ) -> Callable[ + [Callable[[VoiceMessage], Awaitable[None]]], + Callable[[VoiceMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["file"], **rest: Any + ) -> Callable[ + [Callable[[FileMessage], Awaitable[None]]], + Callable[[FileMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["report"], **rest: Any + ) -> Callable[ + [Callable[[ReportMessage], Awaitable[None]]], + Callable[[ReportMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["chat"], **rest: Any + ) -> Callable[ + [Callable[[ChatMessage], Awaitable[None]]], + Callable[[ChatMessage], Awaitable[None]], + ]: ... + + @overload + def on_message( + self, *, content_type: Literal["unknown"], **rest: Any + ) -> Callable[ + [Callable[[UnknownMessage], Awaitable[None]]], + Callable[[UnknownMessage], Awaitable[None]], + ]: ... + + @overload + def on_message(self, **rest: Any) -> Callable[[MessageHandler], MessageHandler]: ... + + def on_message(self, **filter_kw: Any) -> Callable[[MessageHandler], MessageHandler]: + predicate = compile_message_filter(filter_kw) + + def deco(fn: MessageHandler) -> MessageHandler: + self._message_handlers.append((predicate, fn)) + return fn + + return deco + + def on_command( + self, name: str | tuple[str, ...], **filter_kw: Any + ) -> Callable[[CommandHandler], CommandHandler]: + names = (name,) if isinstance(name, str) else tuple(name) + predicate = compile_message_filter(filter_kw) + + def deco(fn: CommandHandler) -> CommandHandler: + self._command_handlers.append((names, predicate, fn)) + return fn + + return deco + + # `on_event` is exposed as a property typed as the generated + # `OnEventDecorator` Protocol so per-tag narrowing applies — e.g. + # `@client.on_event("contactConnected")` types the handler's event + # parameter as `CEvt.ContactConnected`, not the unnarrowed + # `CEvt.ChatEvent` union. The Protocol's overload chain lives in + # generated code (one entry per event tag) so it stays in sync with + # the wire schema automatically. The runtime implementation is the + # plain handler-registration below. + @property + def on_event(self) -> CEvt.OnEventDecorator: + return self._register_event_handler # type: ignore[return-value] + + def _register_event_handler( + self, event: str, / + ) -> Callable[[EventHandler], EventHandler]: + def deco(fn: EventHandler) -> EventHandler: + self._event_handlers.setdefault(event, []).append(fn) + return fn + + return deco + + def use(self, middleware: Middleware) -> None: + self._middleware.append(middleware) + + # ------------------------------------------------------------------ # + # Lifecycle + # ------------------------------------------------------------------ # + + async def __aenter__(self) -> "Client": + # Order matters: libsimplex `/_start` requires an active user, so + # ensure (or create) the user first, THEN start the chat, THEN + # do post-start setup (profile sync; Bot adds address sync). + # Clear `_stop_event` here (not in `serve_forever`/`events`) so that + # a `stop()` call landing between `__aenter__` and the receive loop + # — e.g. a signal handler firing while signal handlers are being + # wired up — is preserved and causes the loop to exit immediately + # on entry. + self._stop_event.clear() + self._api = await ChatApi.init(self._db, self._confirm_migrations) + try: + user = await self._ensure_active_user() + await self._api.start_chat() + await self._post_start(user) + self._register_log_handlers() + return self + except BaseException: + # __aexit__ is only called when __aenter__ returns successfully — + # roll back the open chat controller here so a failure during + # init doesn't leak the FFI resource. + await self._shutdown_partial_init() + raise + + async def _shutdown_partial_init(self) -> None: + """Best-effort teardown for an `__aenter__` that didn't reach return.""" + api = self._api + if api is None: + return + if api.started: + try: + await api.stop_chat() + except Exception: + log.exception("stop_chat failed during init rollback") + try: + await api.close() + except Exception: + log.exception("close failed during init rollback") + self._api = None + + async def __aexit__(self, *exc_info: object) -> None: + self.stop() + api = self._api + if api is None: + return + # Null out the reference up-front so the Client appears closed even + # if stop_chat / close raise — otherwise `client.api` would still + # hand back a half-shutdown controller after `async with` exits. + self._api = None + try: + await api.stop_chat() + finally: + await api.close() + + async def _post_start(self, user: T.User) -> None: + """Hook for subclasses to add work between `start_chat` and serving. + + Default (Client): sync profile only. Bot overrides to also sync its + address and embed the connection link in the profile. + """ + await self._maybe_sync_profile(user, contact_link=None) + + def run(self) -> None: + """Blocking entry: runs serve_forever() with SIGINT/SIGTERM handlers installed. + + Configures `logging.basicConfig(level=INFO)` if the root logger has no + handlers yet, so startup messages and the announced address are + visible without callers having to set up logging. Embedders that + manage logging themselves are unaffected (basicConfig is a no-op when + handlers already exist). + """ + if not logging.getLogger().handlers: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + async def _main() -> None: + async with self: + loop = asyncio.get_running_loop() + # First Ctrl+C → graceful stop (~500ms, bounded by the + # receive-loop poll interval). Second Ctrl+C → force-exit + # immediately (in case stop_chat / close hang on a wedged + # FFI call). Standard CLI UX (jupyter, ipython, …). + sigint_count = 0 + + def on_interrupt() -> None: + nonlocal sigint_count + sigint_count += 1 + if sigint_count == 1: + log.info("stopping... (press Ctrl+C again to force exit)") + self.stop() + else: + os._exit(130) # 128 + SIGINT + + if hasattr(_signal, "SIGINT"): + try: + loop.add_signal_handler(_signal.SIGINT, on_interrupt) + loop.add_signal_handler(_signal.SIGTERM, self.stop) + except NotImplementedError: # Windows + _signal.signal(_signal.SIGINT, lambda *_: on_interrupt()) + await self.serve_forever() + + asyncio.run(_main()) + + async def serve_forever(self) -> None: + if self._serving: + raise RuntimeError("already serving") + self._serving = True + try: + await self._receive_loop() + finally: + self._serving = False + + def stop(self) -> None: + self._stop_event.set() + + async def events(self) -> AsyncIterator[CEvt.ChatEvent]: + """Yield chat events one at a time — alternative to `serve_forever`. + + Runs the full dispatch pipeline on each event (internal waiters, + user `@on_event`/`@on_message`/`@on_command` handlers), then yields + the raw event for inspection. Use this when you want direct control + over the receive loop, e.g. to surface errors that `serve_forever` + would otherwise swallow, or to compose with other async iterators. + + Mutually exclusive with `serve_forever`. Stops when `stop()` is + called or when the consumer exits the `async for` loop (which + triggers the generator's `aclose`). Async-generator GC alone is + not reliable for cleanup — exit the loop explicitly. + """ + if self._serving: + raise RuntimeError( + "already serving — events() and serve_forever() are mutually exclusive" + ) + self._serving = True + try: + while not self._stop_event.is_set(): + try: + event = await self.api.recv_chat_event(wait_us=500_000) + except asyncio.CancelledError: + raise + if event is None: + continue + try: + await self._dispatch_event(event) + except asyncio.CancelledError: + raise + except Exception: + log.exception("dispatch_event failed for tag=%s", event.get("type")) + yield event + finally: + self._serving = False + + async def connect_to(self, link: str, *, timeout: float = 120.0) -> T.Contact: + """Connect to a SimpleX contact link, returning the resulting Contact. + + Idempotent: if the link is already known (via `api_connect_plan`) + the existing Contact is returned without re-handshaking. Otherwise + initiates the handshake and waits for the `contactConnected` event. + + Requires the receive loop to be running (`serve_forever` or + `events()`), since the handshake completes asynchronously. + + Concurrency caveat: pending `connect_to` waiters are a single FIFO + with no link↔waiter correlation. If you call `connect_to` for two + different links concurrently, or if a third party connects to your + address (Bot subclass with `auto_accept=True`) while a `connect_to` + is in flight, the returned Contact may not be the one you asked + for. Sequence concurrent connects, or call them one at a time. + + Raises: + asyncio.TimeoutError: handshake didn't complete within `timeout` + ValueError: timeout is not positive + RuntimeError: no active user, or receive loop not running + """ + if timeout <= 0: + # Reject upfront — otherwise wait_for raises TimeoutError after + # the handshake side-effect (api_connect_active_user) has + # already gone over the wire, leaving the caller with no + # Contact reference and a half-initiated connection. + raise ValueError(f"timeout must be positive, got {timeout!r}") + if not self._serving: + raise RuntimeError( + "connect_to requires the receive loop to be running — " + "call serve_forever() (in a task) or iterate events() first" + ) + api = self.api + user = await api.api_get_active_user() + if user is None: + raise RuntimeError("no active user") + + existing = await self._lookup_known_contact(user["userId"], link) + if existing is not None: + return existing + + loop = asyncio.get_running_loop() + waiter: asyncio.Future[T.Contact] = loop.create_future() + self._connect_waiters.append(waiter) + try: + try: + await api.api_connect_active_user(link) + except ContactAlreadyExistsError: + # Handshake mid-flight, or a previous incomplete attempt + # left the connection in a known-but-not-connected state. + # Either way: wait for the contactConnected event. + pass + return await asyncio.wait_for(waiter, timeout=timeout) + finally: + if waiter in self._connect_waiters: + self._connect_waiters.remove(waiter) + + async def _lookup_known_contact(self, user_id: int, link: str) -> T.Contact | None: + """Resolve a link to an existing Contact via api_connect_plan, or None. + + Only ChatCommandError is swallowed (malformed link, etc.) — the + connect_to caller will fall back to the full handshake path. + Transport/FFI errors propagate so the caller sees the real cause. + """ + try: + plan, _ = await self.api.api_connect_plan(user_id, link) + except ChatCommandError: + return None + if plan["type"] == "contactAddress": + cap = plan["contactAddressPlan"] + if cap["type"] == "known": + return cap["contact"] + if plan["type"] == "invitationLink": + ilp = plan["invitationLinkPlan"] + if ilp["type"] == "known": + return ilp["contact"] + return None + + async def send_and_wait( + self, + contact_id: int, + text: str, + *, + timeout: float = 30.0, + ) -> "Message[T.MsgContent]": + """Send text to a direct contact and wait for the next reply from them. + + Waiters are FIFO per contact_id: two concurrent calls to the same + contact get two replies in send order. Concurrent calls to *different* + contacts run in parallel. Once a reply matches a waiter, user + message handlers do NOT fire for that message — the awaiter owns it. + + Correlation caveat: matching is by sender contact_id only — there + is no message-id correlation. ANY direct message from `contact_id` + arriving while a waiter is pending will resolve that waiter, even + if it was an unsolicited message (e.g. an auto-reply from a bot's + address settings, a delayed reply from a previous send, a push + notification). For strict request/response semantics, ensure the + peer is otherwise quiet, or use the `@on_message` callback model. + + Requires the receive loop to be running. Raises asyncio.TimeoutError + on timeout, ValueError if timeout is not positive. + """ + if timeout <= 0: + # Reject upfront — otherwise wait_for raises TimeoutError after + # api_send_text_message already went over the wire, surprising + # the caller with a sent message and no Future to await. + raise ValueError(f"timeout must be positive, got {timeout!r}") + if not self._serving: + raise RuntimeError( + "send_and_wait requires the receive loop to be running — " + "call serve_forever() (in a task) or iterate events() first" + ) + loop = asyncio.get_running_loop() + waiter: asyncio.Future[Message[Any]] = loop.create_future() + waiters = self._reply_waiters.setdefault(contact_id, []) + waiters.append(waiter) + try: + await self.api.api_send_text_message(["direct", contact_id], text) + return await asyncio.wait_for(waiter, timeout=timeout) + finally: + # Always clean up our slot, even on send error or timeout. Leaving + # an unresolved Future in the dict would make the next incoming + # message resolve a future no one is waiting on. + if waiter in waiters: + waiters.remove(waiter) + if not waiters: + self._reply_waiters.pop(contact_id, None) + + async def _receive_loop(self) -> None: + # Catch broad Exception so a single malformed event or transient + # native error doesn't crash the whole client. CancelledError must + # always re-raise so `stop()` and asyncio cancellation work. + # `wait_us=500_000` (500ms) bounds the worst-case Ctrl+C latency: + # the C call blocks the worker thread until timeout, and the loop + # only checks `_stop_event` between polls. + while not self._stop_event.is_set(): + try: + event = await self.api.recv_chat_event(wait_us=500_000) + except asyncio.CancelledError: + raise + except ChatAPIError as e: + # Async chat errors emitted via the Haskell `eToView` path — + # routine soft errors (stale connections after a peer deletes + # a chat, file cleanup failures, etc.) intermixed with + # CRITICAL agent failures the operator must see. Mirror the + # desktop policy in SimpleXAPI.kt:3332-3340: escalate + # CRITICAL agent errors, keep everything else at debug. + chat_err: Any = e.chat_error or {} + agent_err: Any = ( + chat_err.get("agentError", {}) if chat_err.get("type") == "errorAgent" else {} + ) + if agent_err.get("type") == "CRITICAL": + log.error( + "chat agent CRITICAL: %s (offerRestart=%s)", + agent_err.get("criticalErr"), + agent_err.get("offerRestart"), + ) + else: + log.debug("chat event error: %s", chat_err.get("type")) + continue + except Exception: + log.exception("recv_chat_event failed") + # Bound the spin rate when the FFI is wedged on a persistent + # error (vs the timeout path, which already paces itself). + await asyncio.sleep(0.5) + continue + if event is None: + continue + try: + await self._dispatch_event(event) + except asyncio.CancelledError: + raise + except Exception: + log.exception("dispatch_event failed for tag=%s", event.get("type")) + + # ------------------------------------------------------------------ # + # Dispatch + # ------------------------------------------------------------------ # + + async def _dispatch_event(self, event: CEvt.ChatEvent) -> None: + tag = event["type"] + # Resolve internal waiters BEFORE user handlers. A pending + # `connect_to` consumes the contactConnected; a pending + # `send_and_wait` consumes the matching incoming message — user + # handlers don't fire for that message. This matches the mental + # model: the awaiter explicitly asked for this event. + if tag == "contactConnected" and self._connect_waiters: + contact: T.Contact = event["contact"] # type: ignore[typeddict-item] + waiter = self._connect_waiters.pop(0) + if not waiter.done(): + waiter.set_result(contact) + for h in self._event_handlers.get(tag, []): + try: + await h(event) + except Exception: + log.exception("on_event handler failed") + if tag == "newChatItems": + evt: CEvt.NewChatItems = event # type: ignore[assignment] + for ci in evt["chatItems"]: + content = ci["chatItem"]["content"] + if content["type"] != "rcvMsgContent": + continue + msg_content = content["msgContent"] # type: ignore[index] + msg: Message[T.MsgContent] = Message(chat_item=ci, content=msg_content, client=self) + # If a send_and_wait is pending for this sender, fulfil it + # and skip the user dispatch chain — the awaiter "owns" this + # reply. FIFO within a contact_id. + if self._maybe_resolve_reply_waiter(msg): + continue + await self._dispatch_message(msg) + + def _maybe_resolve_reply_waiter(self, msg: Message[T.MsgContent]) -> bool: + chat_info = msg.chat_info + if chat_info.get("type") != "direct": + return False + contact_id = chat_info.get("contact", {}).get("contactId") # type: ignore[union-attr] + if contact_id is None: + return False + waiters = self._reply_waiters.get(contact_id) + if not waiters: + return False + # Skip waiters whose callers have already given up (cancelled by + # wait_for timing out at the same loop tick). Without this skip, + # a reply arriving in the narrow timeout-race window would be + # silently dropped because the FIFO would pop a done waiter and + # neither resolve it nor dispatch to user handlers. + while waiters: + waiter = waiters.pop(0) + if not waiter.done(): + if not waiters: + self._reply_waiters.pop(contact_id, None) + waiter.set_result(msg) + return True + self._reply_waiters.pop(contact_id, None) + return False + + async def _dispatch_message(self, msg: Message[Any]) -> None: + # First-match-wins. The squaring bot's `@on_message(text=NUMBER_RE)` + # and catch-all `@on_message(content_type="text")` both match a number + # like "1"; we want only the first to fire. Registration order is the + # priority order — register the most-specific filters first. + # + # Slash-commands are tried first against command handlers; if no + # command handler matches, fall through to message handlers (so + # `@on_message` can still catch unknown slash-commands). + cmd = self._parse_command(msg) + if cmd is not None: + for names, predicate, handler in self._command_handlers: + if cmd.keyword in names and predicate(msg): + await self._invoke_command_with_middleware(handler, msg, cmd) + return + for predicate, handler in self._message_handlers: + if predicate(msg): + await self._invoke_with_middleware(handler, msg) + return + + async def _invoke_with_middleware(self, handler: MessageHandler, message: Message[Any]) -> None: + # Fast path: most clients register no middleware. Skip the closure-chain + # construction and the empty-data dict on every dispatch. + if not self._middleware: + try: + await handler(message) + except Exception: + log.exception("message handler failed") + return + + async def call(m: Message[Any], _data: dict[str, object]) -> None: + await handler(m) + + chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call + for mw in reversed(self._middleware): + inner = chain + + async def _wrapped( + m: Message[Any], + d: dict[str, object], + mw: Middleware = mw, + inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner, + ) -> None: + await mw(inner, m, d) + + chain = _wrapped + + try: + await chain(message, {}) + except Exception: + log.exception("message handler failed") + + async def _invoke_command_with_middleware( + self, handler: CommandHandler, message: Message[Any], cmd: ParsedCommand + ) -> None: + if not self._middleware: + try: + await handler(message, cmd) + except Exception: + log.exception("command handler failed") + return + + async def call(m: Message[Any], _data: dict[str, object]) -> None: + await handler(m, cmd) + + chain: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = call + for mw in reversed(self._middleware): + inner = chain + + async def _wrapped( + m: Message[Any], + d: dict[str, object], + mw: Middleware = mw, + inner: Callable[[Message[Any], dict[str, object]], Awaitable[None]] = inner, + ) -> None: + await mw(inner, m, d) + + chain = _wrapped + + try: + await chain(message, {}) + except Exception: + log.exception("command handler failed") + + @staticmethod + def _parse_command(msg: Message[Any]) -> ParsedCommand | None: + parsed = util.ci_bot_command(msg.chat_item["chatItem"]) + if parsed is None: + return None + keyword, args = parsed + return ParsedCommand(keyword=keyword, args=args) + + # ------------------------------------------------------------------ # + # Profile sync + # ------------------------------------------------------------------ # + + async def _ensure_active_user(self) -> T.User: + """Get or create the active user. Must run before `start_chat`. + + Mirrors Node `createBotUser` (bot.ts:158-166). The chat controller + won't accept `/_start` without a user, so this phase has to land + before lifecycle proceeds. + """ + api = self.api + user = await api.api_get_active_user() + if user is None: + log.info("No active user in database, creating...") + user = await api.api_create_active_user(self._profile_to_wire()) + log.info("user: %s", user["profile"]["displayName"]) + return user + + async def _maybe_sync_profile(self, user: T.User, *, contact_link: str | None) -> None: + """Update the user profile on the wire if its fields changed. + + `contact_link` is only set by Bot (to embed its address). Mirrors + Node `updateBotUserProfile` (bot.ts:199-214). Field-by-field + comparison because user["profile"] is LocalProfile (has extra + fields profileId, localAlias, preferences, peerType) so a full + dict equality would always differ. + """ + if not self._update_profile: + return + new_profile = self._profile_to_wire() + if contact_link is not None: + new_profile["contactLink"] = contact_link + cur = user["profile"] + changed = ( + cur["displayName"] != new_profile["displayName"] + or cur.get("fullName", "") != new_profile.get("fullName", "") + or cur.get("shortDescr") != new_profile.get("shortDescr") + or cur.get("image") != new_profile.get("image") + or cur.get("preferences") != new_profile.get("preferences") + or cur.get("peerType") != new_profile.get("peerType") + or cur.get("contactLink") != new_profile.get("contactLink") + ) + if changed: + log.info("profile changed, updating...") + await self.api.api_update_profile(user["userId"], new_profile) + + def _profile_to_wire(self) -> T.Profile: + """Convert the user-facing Profile dataclass to wire format. + + Base version produces a plain user profile. Bot overrides this to + add the bot-specific extensions (peerType=bot, command list, + calls/voice/files prefs). + """ + p: T.Profile = { + "displayName": self._profile.display_name, + "fullName": self._profile.full_name, + } + if self._profile.short_descr is not None: + p["shortDescr"] = self._profile.short_descr + if self._profile.image is not None: + p["image"] = self._profile.image + return p + + # ------------------------------------------------------------------ # + # Log subscriptions (mirror Node subscribeLogEvents bot.ts:142-156) + # ------------------------------------------------------------------ # + + def _register_log_handlers(self) -> None: + # Idempotent: re-entering the async context must not stack duplicate + # log handlers. Always-on error handlers run regardless of + # log_contacts/log_network so messageError/chatError/chatErrors + # don't disappear into the void. + if self._defaults_registered: + return + self._defaults_registered = True + self._event_handlers.setdefault("messageError", []).append(self._log_message_error) + self._event_handlers.setdefault("chatError", []).append(self._log_chat_error) + self._event_handlers.setdefault("chatErrors", []).append(self._log_chat_errors) + if self._log_contacts: + self._event_handlers.setdefault("contactConnected", []).append( + self._log_contact_connected + ) + self._event_handlers.setdefault("contactDeletedByContact", []).append( + self._log_contact_deleted + ) + if self._log_network: + self._event_handlers.setdefault("hostConnected", []).append(self._log_host_connected) + self._event_handlers.setdefault("hostDisconnected", []).append( + self._log_host_disconnected + ) + self._event_handlers.setdefault("subscriptionStatus", []).append( + self._log_subscription_status + ) + + @staticmethod + async def _log_contact_connected(evt: CEvt.ChatEvent) -> None: + log.info("%s connected", evt["contact"]["profile"]["displayName"]) # type: ignore[index] + + @staticmethod + async def _log_contact_deleted(evt: CEvt.ChatEvent) -> None: + log.info( + "%s deleted connection", + evt["contact"]["profile"]["displayName"], # type: ignore[index] + ) + + @staticmethod + async def _log_host_connected(evt: CEvt.ChatEvent) -> None: + log.info("connected server %s", evt["transportHost"]) # type: ignore[index] + + @staticmethod + async def _log_host_disconnected(evt: CEvt.ChatEvent) -> None: + log.info("disconnected server %s", evt["transportHost"]) # type: ignore[index] + + @staticmethod + async def _log_subscription_status(evt: CEvt.ChatEvent) -> None: + log.info( + "%d subscription(s) %s", + len(evt["connections"]), # type: ignore[index] + evt["subscriptionStatus"]["type"], # type: ignore[index] + ) + + @staticmethod + async def _log_message_error(evt: CEvt.ChatEvent) -> None: + log.warning("messageError: %s", evt.get("severity", "?")) # type: ignore[union-attr] + + @staticmethod + async def _log_chat_error(evt: CEvt.ChatEvent) -> None: + err = evt.get("chatError") # type: ignore[union-attr] + log.error("chatError: %s", err.get("type") if isinstance(err, dict) else err) + + @staticmethod + async def _log_chat_errors(evt: CEvt.ChatEvent) -> None: + errs = evt.get("chatErrors") or [] # type: ignore[union-attr] + log.error("chatErrors: %d errors", len(errs)) + + +__all__ = [ + "BotProfile", # backwards-compat alias for Profile + "ChatMessage", + "Client", + "CommandHandler", + "EventHandler", + "FileMessage", + "ImageMessage", + "LinkMessage", + "Message", + "MessageHandler", + "Middleware", + "ParsedCommand", + "Profile", + "ReportMessage", + "TextMessage", + "UnknownMessage", + "VideoMessage", + "VoiceMessage", +] diff --git a/packages/simplex-chat-python/src/simplex_chat/filters.py b/packages/simplex-chat-python/src/simplex_chat/filters.py index cdce5b7bb6..8af15c1c66 100644 --- a/packages/simplex-chat-python/src/simplex_chat/filters.py +++ b/packages/simplex-chat-python/src/simplex_chat/filters.py @@ -37,6 +37,15 @@ def compile_message_filter(kw: dict[str, Any]) -> Callable[[Any], bool]: predicates.append(gid_match) + if (cid := kw.get("contact_id")) is not None: + cid_set: tuple[int, ...] = (cid,) if isinstance(cid, int) else tuple(cid) + + def cid_match(m: Any) -> bool: + ci = m.chat_item["chatInfo"] + return ci["type"] == "direct" and ci["contact"]["contactId"] in cid_set + + predicates.append(cid_match) + if (when := kw.get("when")) is not None: predicates.append(when) diff --git a/packages/simplex-chat-python/src/simplex_chat/types/_events.py b/packages/simplex-chat-python/src/simplex_chat/types/_events.py index 77484fbf3f..7b7c724c92 100644 --- a/packages/simplex-chat-python/src/simplex_chat/types/_events.py +++ b/packages/simplex-chat-python/src/simplex_chat/types/_events.py @@ -1,7 +1,8 @@ # API Events # This file is generated automatically. from __future__ import annotations -from typing import Literal, NotRequired, TypedDict +from collections.abc import Awaitable, Callable +from typing import Literal, NotRequired, Protocol, TypedDict, overload from . import _types as T class ContactConnected(TypedDict): @@ -377,3 +378,318 @@ ChatEvent = ( ) ChatEvent_Tag = Literal["contactConnected", "contactUpdated", "contactDeletedByContact", "receivedContactRequest", "newMemberContactReceivedInv", "contactSndReady", "newChatItems", "chatItemReaction", "chatItemsDeleted", "chatItemUpdated", "groupChatItemsDeleted", "chatItemsStatusesUpdated", "receivedGroupInvitation", "userJoinedGroup", "groupUpdated", "joinedGroupMember", "memberRole", "deletedMember", "leftMember", "deletedMemberUser", "groupDeleted", "connectedToGroupMember", "memberAcceptedByOther", "memberBlockedForAll", "groupMemberUpdated", "groupLinkDataUpdated", "groupRelayUpdated", "rcvFileDescrReady", "rcvFileComplete", "sndFileCompleteXFTP", "rcvFileStart", "rcvFileSndCancelled", "rcvFileAccepted", "rcvFileError", "rcvFileWarning", "sndFileError", "sndFileWarning", "acceptingContactRequest", "acceptingBusinessRequest", "contactConnecting", "businessLinkConnecting", "joinedGroupMemberConnecting", "sentGroupInvitation", "groupLinkConnecting", "hostConnected", "hostDisconnected", "subscriptionStatus", "messageError", "chatError", "chatErrors"] + + +class OnEventDecorator(Protocol): + """Per-tag narrowing protocol for ``Client.on_event``. + + ``@client.on_event("contactConnected")`` types the handler's + ``evt`` parameter as :class:`ContactConnected` rather than the + unnarrowed :data:`ChatEvent` union. + """ + + @overload + def __call__(self, event: Literal["contactConnected"], /) -> Callable[ + [Callable[["ContactConnected"], Awaitable[None]]], + Callable[["ContactConnected"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["contactUpdated"], /) -> Callable[ + [Callable[["ContactUpdated"], Awaitable[None]]], + Callable[["ContactUpdated"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["contactDeletedByContact"], /) -> Callable[ + [Callable[["ContactDeletedByContact"], Awaitable[None]]], + Callable[["ContactDeletedByContact"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["receivedContactRequest"], /) -> Callable[ + [Callable[["ReceivedContactRequest"], Awaitable[None]]], + Callable[["ReceivedContactRequest"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["newMemberContactReceivedInv"], /) -> Callable[ + [Callable[["NewMemberContactReceivedInv"], Awaitable[None]]], + Callable[["NewMemberContactReceivedInv"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["contactSndReady"], /) -> Callable[ + [Callable[["ContactSndReady"], Awaitable[None]]], + Callable[["ContactSndReady"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["newChatItems"], /) -> Callable[ + [Callable[["NewChatItems"], Awaitable[None]]], + Callable[["NewChatItems"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["chatItemReaction"], /) -> Callable[ + [Callable[["ChatItemReaction"], Awaitable[None]]], + Callable[["ChatItemReaction"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["chatItemsDeleted"], /) -> Callable[ + [Callable[["ChatItemsDeleted"], Awaitable[None]]], + Callable[["ChatItemsDeleted"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["chatItemUpdated"], /) -> Callable[ + [Callable[["ChatItemUpdated"], Awaitable[None]]], + Callable[["ChatItemUpdated"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["groupChatItemsDeleted"], /) -> Callable[ + [Callable[["GroupChatItemsDeleted"], Awaitable[None]]], + Callable[["GroupChatItemsDeleted"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["chatItemsStatusesUpdated"], /) -> Callable[ + [Callable[["ChatItemsStatusesUpdated"], Awaitable[None]]], + Callable[["ChatItemsStatusesUpdated"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["receivedGroupInvitation"], /) -> Callable[ + [Callable[["ReceivedGroupInvitation"], Awaitable[None]]], + Callable[["ReceivedGroupInvitation"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["userJoinedGroup"], /) -> Callable[ + [Callable[["UserJoinedGroup"], Awaitable[None]]], + Callable[["UserJoinedGroup"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["groupUpdated"], /) -> Callable[ + [Callable[["GroupUpdated"], Awaitable[None]]], + Callable[["GroupUpdated"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["joinedGroupMember"], /) -> Callable[ + [Callable[["JoinedGroupMember"], Awaitable[None]]], + Callable[["JoinedGroupMember"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["memberRole"], /) -> Callable[ + [Callable[["MemberRole"], Awaitable[None]]], + Callable[["MemberRole"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["deletedMember"], /) -> Callable[ + [Callable[["DeletedMember"], Awaitable[None]]], + Callable[["DeletedMember"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["leftMember"], /) -> Callable[ + [Callable[["LeftMember"], Awaitable[None]]], + Callable[["LeftMember"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["deletedMemberUser"], /) -> Callable[ + [Callable[["DeletedMemberUser"], Awaitable[None]]], + Callable[["DeletedMemberUser"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["groupDeleted"], /) -> Callable[ + [Callable[["GroupDeleted"], Awaitable[None]]], + Callable[["GroupDeleted"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["connectedToGroupMember"], /) -> Callable[ + [Callable[["ConnectedToGroupMember"], Awaitable[None]]], + Callable[["ConnectedToGroupMember"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["memberAcceptedByOther"], /) -> Callable[ + [Callable[["MemberAcceptedByOther"], Awaitable[None]]], + Callable[["MemberAcceptedByOther"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["memberBlockedForAll"], /) -> Callable[ + [Callable[["MemberBlockedForAll"], Awaitable[None]]], + Callable[["MemberBlockedForAll"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["groupMemberUpdated"], /) -> Callable[ + [Callable[["GroupMemberUpdated"], Awaitable[None]]], + Callable[["GroupMemberUpdated"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["groupLinkDataUpdated"], /) -> Callable[ + [Callable[["GroupLinkDataUpdated"], Awaitable[None]]], + Callable[["GroupLinkDataUpdated"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["groupRelayUpdated"], /) -> Callable[ + [Callable[["GroupRelayUpdated"], Awaitable[None]]], + Callable[["GroupRelayUpdated"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["rcvFileDescrReady"], /) -> Callable[ + [Callable[["RcvFileDescrReady"], Awaitable[None]]], + Callable[["RcvFileDescrReady"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["rcvFileComplete"], /) -> Callable[ + [Callable[["RcvFileComplete"], Awaitable[None]]], + Callable[["RcvFileComplete"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["sndFileCompleteXFTP"], /) -> Callable[ + [Callable[["SndFileCompleteXFTP"], Awaitable[None]]], + Callable[["SndFileCompleteXFTP"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["rcvFileStart"], /) -> Callable[ + [Callable[["RcvFileStart"], Awaitable[None]]], + Callable[["RcvFileStart"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["rcvFileSndCancelled"], /) -> Callable[ + [Callable[["RcvFileSndCancelled"], Awaitable[None]]], + Callable[["RcvFileSndCancelled"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["rcvFileAccepted"], /) -> Callable[ + [Callable[["RcvFileAccepted"], Awaitable[None]]], + Callable[["RcvFileAccepted"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["rcvFileError"], /) -> Callable[ + [Callable[["RcvFileError"], Awaitable[None]]], + Callable[["RcvFileError"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["rcvFileWarning"], /) -> Callable[ + [Callable[["RcvFileWarning"], Awaitable[None]]], + Callable[["RcvFileWarning"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["sndFileError"], /) -> Callable[ + [Callable[["SndFileError"], Awaitable[None]]], + Callable[["SndFileError"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["sndFileWarning"], /) -> Callable[ + [Callable[["SndFileWarning"], Awaitable[None]]], + Callable[["SndFileWarning"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["acceptingContactRequest"], /) -> Callable[ + [Callable[["AcceptingContactRequest"], Awaitable[None]]], + Callable[["AcceptingContactRequest"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["acceptingBusinessRequest"], /) -> Callable[ + [Callable[["AcceptingBusinessRequest"], Awaitable[None]]], + Callable[["AcceptingBusinessRequest"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["contactConnecting"], /) -> Callable[ + [Callable[["ContactConnecting"], Awaitable[None]]], + Callable[["ContactConnecting"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["businessLinkConnecting"], /) -> Callable[ + [Callable[["BusinessLinkConnecting"], Awaitable[None]]], + Callable[["BusinessLinkConnecting"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["joinedGroupMemberConnecting"], /) -> Callable[ + [Callable[["JoinedGroupMemberConnecting"], Awaitable[None]]], + Callable[["JoinedGroupMemberConnecting"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["sentGroupInvitation"], /) -> Callable[ + [Callable[["SentGroupInvitation"], Awaitable[None]]], + Callable[["SentGroupInvitation"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["groupLinkConnecting"], /) -> Callable[ + [Callable[["GroupLinkConnecting"], Awaitable[None]]], + Callable[["GroupLinkConnecting"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["hostConnected"], /) -> Callable[ + [Callable[["HostConnected"], Awaitable[None]]], + Callable[["HostConnected"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["hostDisconnected"], /) -> Callable[ + [Callable[["HostDisconnected"], Awaitable[None]]], + Callable[["HostDisconnected"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["subscriptionStatus"], /) -> Callable[ + [Callable[["SubscriptionStatus"], Awaitable[None]]], + Callable[["SubscriptionStatus"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["messageError"], /) -> Callable[ + [Callable[["MessageError"], Awaitable[None]]], + Callable[["MessageError"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["chatError"], /) -> Callable[ + [Callable[["ChatError"], Awaitable[None]]], + Callable[["ChatError"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: Literal["chatErrors"], /) -> Callable[ + [Callable[["ChatErrors"], Awaitable[None]]], + Callable[["ChatErrors"], Awaitable[None]], + ]: ... + + @overload + def __call__(self, event: str, /) -> Callable[ + [Callable[["ChatEvent"], Awaitable[None]]], + Callable[["ChatEvent"], Awaitable[None]], + ]: ... diff --git a/packages/simplex-chat-python/tests/test_bot_registration.py b/packages/simplex-chat-python/tests/test_bot_registration.py index 7401d2ef5d..f6f245c344 100644 --- a/packages/simplex-chat-python/tests/test_bot_registration.py +++ b/packages/simplex-chat-python/tests/test_bot_registration.py @@ -1,6 +1,6 @@ import pytest -from simplex_chat import Bot, BotCommand, BotProfile, Middleware, SqliteDb +from simplex_chat import Bot, BotCommand, BotProfile, Client, Middleware, Profile, SqliteDb from simplex_chat.api import ChatApi @@ -57,9 +57,9 @@ def test_command_keyword_tuple(): def test_bot_profile_to_wire_default(): - """use_bot_profile=True (default) sets peerType=bot and disables calls/voice.""" + """Bot's profile wire-form sets peerType=bot and disables calls/voice.""" bot = _bot() - p = bot._bot_profile_to_wire() + p = bot._profile_to_wire() assert p["displayName"] == "x" assert p.get("peerType") == "bot" prefs = p.get("preferences") or {} @@ -74,7 +74,7 @@ def test_bot_profile_to_wire_allow_files(): db=SqliteDb(file_prefix="/tmp/test"), allow_files=True, ) - prefs = bot._bot_profile_to_wire().get("preferences") or {} + prefs = bot._profile_to_wire().get("preferences") or {} assert prefs.get("files", {}).get("allow") == "yes" @@ -84,32 +84,26 @@ def test_bot_profile_to_wire_with_commands(): db=SqliteDb(file_prefix="/tmp/test"), commands=[BotCommand(keyword="ping", label="Ping bot"), BotCommand("help", "Show help")], ) - cmds = bot._bot_profile_to_wire().get("preferences", {}).get("commands") or [] + cmds = bot._profile_to_wire().get("preferences", {}).get("commands") or [] assert len(cmds) == 2 assert cmds[0] == {"type": "command", "keyword": "ping", "label": "Ping bot"} assert cmds[1] == {"type": "command", "keyword": "help", "label": "Show help"} -def test_bot_profile_to_wire_no_bot_profile(): - bot = Bot( - profile=BotProfile(display_name="x"), - db=SqliteDb(file_prefix="/tmp/test"), - use_bot_profile=False, - ) - p = bot._bot_profile_to_wire() +def test_client_profile_to_wire_has_no_bot_extras(): + """Client's wire profile has no peerType=bot, no command list, no calls/voice prefs. + That's the whole point of having Client as a separate class.""" + c = Client(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test")) + p = c._profile_to_wire() + assert p["displayName"] == "x" assert "peerType" not in p assert "preferences" not in p -def test_commands_without_bot_profile_raises(): - bot = Bot( - profile=BotProfile(display_name="x"), - db=SqliteDb(file_prefix="/tmp/test"), - use_bot_profile=False, - commands=[BotCommand("ping", "Ping bot")], - ) - with pytest.raises(ValueError, match="use_bot_profile=False"): - bot._bot_profile_to_wire() +def test_bot_profile_alias_is_profile(): + """`BotProfile` is kept as an alias for backwards compatibility.""" + assert BotProfile is Profile + assert BotProfile(display_name="x") == Profile(display_name="x") def test_dispatch_message_first_match_wins(): diff --git a/packages/simplex-chat-python/tests/test_client_and_waiters.py b/packages/simplex-chat-python/tests/test_client_and_waiters.py new file mode 100644 index 0000000000..7c01ae576a --- /dev/null +++ b/packages/simplex-chat-python/tests/test_client_and_waiters.py @@ -0,0 +1,616 @@ +"""Tests for Client class + connect_to / send_and_wait / events plumbing. + +Stubs out ChatApi so we exercise the dispatch and waiter logic without +spinning up the native libsimplex controller. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from simplex_chat import ( + Bot, + BotProfile, + Client, + ContactAlreadyExistsError, + Profile, + SqliteDb, +) + + +class FakeApi: + """Drop-in replacement for ChatApi for tests that don't need the FFI. + + Records api_send_text_message calls; supports scripting api_connect_plan + and api_connect_active_user behaviour. + """ + + def __init__(self) -> None: + self.sent: list[tuple[Any, str]] = [] + self.connect_plan_result: Any = ("error", None) # default: no known contact + self.connect_should_raise: Exception | None = None + self.active_user: dict[str, Any] = {"userId": 1, "profile": {"displayName": "x"}} + + async def api_send_text_message(self, chat, text, in_reply_to=None): + self.sent.append((chat, text)) + return [] + + async def api_connect_plan(self, _user_id, _link): + kind = self.connect_plan_result[0] + if kind == "known_contact_address": + return ( + { + "type": "contactAddress", + "contactAddressPlan": {"type": "known", "contact": self.connect_plan_result[1]}, + }, + {}, + ) + if kind == "known_invitation": + return ( + { + "type": "invitationLink", + "invitationLinkPlan": {"type": "known", "contact": self.connect_plan_result[1]}, + }, + {}, + ) + if kind == "ok": + return ( + { + "type": "contactAddress", + "contactAddressPlan": {"type": "ok"}, + }, + {}, + ) + # default "error" + return ({"type": "error", "chatError": {}}, {}) + + async def api_connect_active_user(self, _link): + if self.connect_should_raise is not None: + raise self.connect_should_raise + return "contact" + + async def api_get_active_user(self): + return self.active_user + + +def _bot_with_fake_api() -> tuple[Bot, FakeApi]: + bot = Bot(profile=BotProfile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test")) + api = FakeApi() + bot._api = api # type: ignore[assignment] + bot._serving = True # pretend receive loop is up + return bot, api + + +# --------------------------------------------------------------------------- +# Client class +# --------------------------------------------------------------------------- + + +def test_client_has_no_address_or_bot_profile_attributes(): + """Client should not carry bot-side state (address creation, auto-accept, + welcome, commands). That's the whole point of separating Client from Bot.""" + c = Client(profile=Profile(display_name="monitor"), db=SqliteDb(file_prefix="/tmp/test")) + for attr in ("_create_address", "_update_address", "_auto_accept", "_welcome", "_commands"): + assert not hasattr(c, attr), f"Client unexpectedly has Bot-only attribute {attr}" + # And the wire profile has no bot peerType + p = c._profile_to_wire() + assert "peerType" not in p + assert "preferences" not in p + + +def test_bot_is_a_client_subclass(): + """Bot should extend Client, so anywhere a Client is accepted, a Bot fits too.""" + assert issubclass(Bot, Client) + + +def test_client_exposes_messaging_methods(): + c = Client(profile=Profile(display_name="m"), db=SqliteDb(file_prefix="/tmp/test")) + assert hasattr(c, "connect_to") + assert hasattr(c, "send_and_wait") + assert hasattr(c, "events") + assert hasattr(c, "on_message") # decorators available on Client too + + +# --------------------------------------------------------------------------- +# send_and_wait +# --------------------------------------------------------------------------- + + +def test_send_and_wait_requires_serving(): + """Without the receive loop running, send_and_wait must raise — otherwise + callers would silently hang waiting for a reply that's never dispatched.""" + bot = Bot(profile=BotProfile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test")) + bot._api = FakeApi() # type: ignore[assignment] + # _serving is False by default + with pytest.raises(RuntimeError, match="receive loop"): + asyncio.run(bot.send_and_wait(1, "hi")) + + +def test_send_and_wait_resolves_on_matching_reply(): + """A reply from the awaited contact should resolve the Future and skip + regular message dispatch.""" + bot, api = _bot_with_fake_api() + fallback_calls: list[str] = [] + + @bot.on_message(content_type="text") + async def fallback(_msg): + fallback_calls.append("fallback") + + async def go() -> str: + send_task = asyncio.create_task(bot.send_and_wait(42, "ping", timeout=2.0)) + # Yield so the task gets to register its waiter. + await asyncio.sleep(0) + evt = {"type": "newChatItems", "chatItems": [ + { + "chatInfo": {"type": "direct", "contact": {"contactId": 42}}, + "chatItem": { + "content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "pong"}}, + }, + } + ]} + await bot._dispatch_event(evt) # type: ignore[arg-type] + reply = await send_task + return reply.text or "" + + result = asyncio.run(go()) + assert result == "pong" + assert api.sent == [(["direct", 42], "ping")] + assert fallback_calls == [], "fallback handler should NOT fire when a waiter consumed the reply" + + +def test_send_and_wait_ignores_other_contacts(): + """Replies from a different contact must not resolve the waiter — that + would mis-correlate responses and is the bug send_and_wait exists to + prevent users from writing themselves.""" + bot, _api = _bot_with_fake_api() + + async def go(): + send_task = asyncio.create_task(bot.send_and_wait(42, "ping", timeout=0.5)) + await asyncio.sleep(0) + evt = {"type": "newChatItems", "chatItems": [ + { + "chatInfo": {"type": "direct", "contact": {"contactId": 99}}, + "chatItem": { + "content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "not for you"}}, + }, + } + ]} + await bot._dispatch_event(evt) # type: ignore[arg-type] + with pytest.raises(asyncio.TimeoutError): + await send_task + + asyncio.run(go()) + + +def test_send_and_wait_fifo_within_contact(): + """Two concurrent waiters on the same contact should resolve in send order.""" + bot, _api = _bot_with_fake_api() + + async def go() -> tuple[str, str]: + first = asyncio.create_task(bot.send_and_wait(42, "first", timeout=2.0)) + await asyncio.sleep(0) + second = asyncio.create_task(bot.send_and_wait(42, "second", timeout=2.0)) + await asyncio.sleep(0) + for text in ("reply1", "reply2"): + evt = {"type": "newChatItems", "chatItems": [ + { + "chatInfo": {"type": "direct", "contact": {"contactId": 42}}, + "chatItem": { + "content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": text}}, + }, + } + ]} + await bot._dispatch_event(evt) # type: ignore[arg-type] + return (await first).text or "", (await second).text or "" + + a, b = asyncio.run(go()) + assert (a, b) == ("reply1", "reply2") + + +def test_send_and_wait_cleans_up_state_on_timeout(): + """Timed-out waiters must be removed so they don't accidentally consume + later replies.""" + bot, _api = _bot_with_fake_api() + + async def go(): + with pytest.raises(asyncio.TimeoutError): + await bot.send_and_wait(42, "ping", timeout=0.05) + assert 42 not in bot._reply_waiters, f"leaked waiters: {bot._reply_waiters}" + + asyncio.run(go()) + + +def test_dispatch_skips_cancelled_waiters_and_falls_through_to_handlers(): + """Race fix: if a waiter is cancelled (wait_for timed out) but still in + the FIFO when a reply arrives, the dispatcher must skip it and either + resolve a live waiter OR fall through to user message handlers — not + silently drop the message.""" + bot, _api = _bot_with_fake_api() + fallback_calls: list[str] = [] + + @bot.on_message(content_type="text") + async def fallback(msg): + fallback_calls.append(msg.text or "") + + async def go(): + # Manually inject a cancelled waiter (simulating wait_for timeout + # cleanup losing the race with the inbound message). + loop = asyncio.get_running_loop() + stale: asyncio.Future = loop.create_future() + stale.cancel() + bot._reply_waiters[42] = [stale] + + evt = {"type": "newChatItems", "chatItems": [ + { + "chatInfo": {"type": "direct", "contact": {"contactId": 42}}, + "chatItem": { + "content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "racing reply"}}, + }, + } + ]} + await bot._dispatch_event(evt) # type: ignore[arg-type] + + asyncio.run(go()) + assert fallback_calls == ["racing reply"], ( + "dispatcher dropped the message instead of falling through to user handlers; " + f"got {fallback_calls}" + ) + assert 42 not in bot._reply_waiters, "cancelled waiter wasn't cleaned up" + + +def test_send_and_wait_parallel_different_contacts(): + """Concurrent send_and_wait to different contacts must not block each other. + + The library docstring promises this; this test pins the behaviour so a + future refactor (e.g., adding a single lock) can't quietly break it.""" + bot, _api = _bot_with_fake_api() + + async def go() -> tuple[str, str]: + t_a = asyncio.create_task(bot.send_and_wait(10, "a", timeout=2.0)) + await asyncio.sleep(0) + t_b = asyncio.create_task(bot.send_and_wait(20, "b", timeout=2.0)) + await asyncio.sleep(0) + # Deliver reply for B first — order shouldn't matter. + await bot._dispatch_event({"type": "newChatItems", "chatItems": [ # type: ignore[arg-type] + { + "chatInfo": {"type": "direct", "contact": {"contactId": 20}}, + "chatItem": {"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "B"}}}, + } + ]}) + await bot._dispatch_event({"type": "newChatItems", "chatItems": [ # type: ignore[arg-type] + { + "chatInfo": {"type": "direct", "contact": {"contactId": 10}}, + "chatItem": {"content": {"type": "rcvMsgContent", "msgContent": {"type": "text", "text": "A"}}}, + } + ]}) + return (await t_a).text or "", (await t_b).text or "" + + a, b = asyncio.run(go()) + assert (a, b) == ("A", "B") + + +# --------------------------------------------------------------------------- +# connect_to +# --------------------------------------------------------------------------- + + +def test_connect_to_returns_known_contact_without_handshake(): + """If the link is already known, connect_to skips api_connect entirely.""" + bot, api = _bot_with_fake_api() + existing = {"contactId": 7, "profile": {"displayName": "SimpleX Directory"}} + api.connect_plan_result = ("known_contact_address", existing) + + contact = asyncio.run(bot.connect_to("link", timeout=2.0)) + assert contact["contactId"] == 7 + # No connect issued: send buffer untouched. + assert api.sent == [] + + +def test_connect_to_waits_for_contactConnected(): + """For unknown links, connect_to issues the handshake and waits for the + contactConnected event before returning.""" + bot, api = _bot_with_fake_api() + api.connect_plan_result = ("ok", None) + new_contact = {"contactId": 11, "profile": {"displayName": "Friend"}} + + async def go(): + connect_task = asyncio.create_task(bot.connect_to("link", timeout=2.0)) + await asyncio.sleep(0) + await bot._dispatch_event({"type": "contactConnected", "contact": new_contact}) # type: ignore[arg-type] + return await connect_task + + contact = asyncio.run(go()) + assert contact["contactId"] == 11 + + +def test_connect_to_tolerates_contact_already_exists(): + """ContactAlreadyExistsError must NOT abort connect_to — a previous + incomplete attempt may have left the connection mid-handshake; the + contactConnected event will still arrive.""" + bot, api = _bot_with_fake_api() + api.connect_plan_result = ("ok", None) + api.connect_should_raise = ContactAlreadyExistsError( + "exists", {"type": "contactAlreadyExists"} # type: ignore[arg-type] + ) + + async def go(): + connect_task = asyncio.create_task(bot.connect_to("link", timeout=2.0)) + await asyncio.sleep(0) + await bot._dispatch_event({"type": "contactConnected", "contact": {"contactId": 5, "profile": {"displayName": "Friend"}}}) # type: ignore[arg-type] + return await connect_task + + contact = asyncio.run(go()) + assert contact["contactId"] == 5 + + +def test_connect_to_requires_serving(): + bot = Bot(profile=BotProfile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test")) + bot._api = FakeApi() # type: ignore[assignment] + with pytest.raises(RuntimeError, match="receive loop"): + asyncio.run(bot.connect_to("link")) + + +def test_connect_to_timeout_cleans_up_waiter(): + bot, api = _bot_with_fake_api() + api.connect_plan_result = ("ok", None) + + async def go(): + with pytest.raises(asyncio.TimeoutError): + await bot.connect_to("link", timeout=0.05) + assert bot._connect_waiters == [], "leaked connect waiter" + + asyncio.run(go()) + + +def test_connect_to_rejects_non_positive_timeout(): + """timeout<=0 must fail upfront — otherwise wait_for raises after the + handshake side-effect has already gone over the wire.""" + bot, _api = _bot_with_fake_api() + + async def go(): + for bad in (0, -1, -0.001): + with pytest.raises(ValueError, match="timeout must be positive"): + await bot.connect_to("link", timeout=bad) + + asyncio.run(go()) + + +def test_send_and_wait_rejects_non_positive_timeout(): + """Same as connect_to: timeout<=0 would surprise the caller with a sent + message and no Future to await.""" + bot, api = _bot_with_fake_api() + + async def go(): + for bad in (0, -1, -0.5): + with pytest.raises(ValueError, match="timeout must be positive"): + await bot.send_and_wait(42, "ping", timeout=bad) + # And nothing was sent. + assert api.sent == [] + + asyncio.run(go()) + + +def test_stop_before_serve_forever_is_preserved(monkeypatch): + """If stop() is called between __aenter__ and serve_forever (e.g. a + signal handler fires during the window where run() wires SIGINT), the + pre-set _stop_event must NOT be cleared by serve_forever — otherwise + the signal is silently lost and the loop runs indefinitely.""" + import simplex_chat.client as client_mod + + class _FakeApi: + @classmethod + async def init(cls, *_a, **_kw): + return cls() + + @property + def started(self): + return False + + async def start_chat(self): + pass + + async def stop_chat(self): + pass + + async def close(self): + pass + + async def api_get_active_user(self): + return {"userId": 1, "profile": {"displayName": "x"}} + + async def recv_chat_event(self, wait_us=0): + # Should NOT be reached — the loop should exit on the pre-set + # stop event before it ever polls for an event. + raise AssertionError("receive loop should have exited immediately") + + # _ensure_active_user / _maybe_sync_profile pokes + async def send_chat_cmd(self, _cmd): + return {"type": "cmdOk"} + + monkeypatch.setattr(client_mod, "ChatApi", _FakeApi) + + c = Client(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test")) + + async def go(): + async with c: + c.stop() # signal fires before serve_forever + await c.serve_forever() # must not block + + asyncio.run(go()) + + +def test_aexit_nulls_api_even_if_close_raises(monkeypatch): + """If `close()` raises inside __aexit__, the Client must still appear + closed — `client.api` should refuse to hand back the half-shutdown + controller, and re-entering the context manager should re-init cleanly.""" + import simplex_chat.client as client_mod + + init_count = [0] + + class _BoomCloseApi: + @classmethod + async def init(cls, *_a, **_kw): + init_count[0] += 1 + return cls() + + @property + def started(self): + return False + + async def start_chat(self): + pass + + async def stop_chat(self): + pass + + async def close(self): + raise RuntimeError("close failed") + + async def api_get_active_user(self): + return {"userId": 1, "profile": {"displayName": "x"}} + + async def send_chat_cmd(self, _cmd): + return {"type": "cmdOk"} + + monkeypatch.setattr(client_mod, "ChatApi", _BoomCloseApi) + + c = Client(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test")) + + async def go(): + with pytest.raises(RuntimeError, match="close failed"): + async with c: + pass + # _api must be None despite close() raising + assert c._api is None, "Client._api leaked after __aexit__ close() raised" + with pytest.raises(RuntimeError, match="not initialized"): + _ = c.api + # Re-enter must work + try: + async with c: + pass + except RuntimeError: + pass # close raises again, fine + assert init_count[0] == 2, "re-entry didn't re-init the controller" + + asyncio.run(go()) + + +def test_aenter_rolls_back_partial_init_on_post_start_failure(monkeypatch): + """If anything in __aenter__ raises after ChatApi.init succeeded — including + _post_start — the controller must be closed. Otherwise the with-block isn't + entered, __aexit__ never runs, and the FFI handle leaks.""" + import simplex_chat.client as client_mod + + closed: list[str] = [] + started: list[bool] = [False] + + class FakeChatApi: + @classmethod + async def init(cls, *_args, **_kwargs): + return cls() + + @property + def started(self) -> bool: + return started[0] + + async def start_chat(self): + started[0] = True + + async def stop_chat(self): + started[0] = False + closed.append("stop") + + async def close(self): + closed.append("close") + + # Stub the bits _ensure_active_user / _maybe_sync_profile reach for. + async def api_get_active_user(self): + return {"userId": 1, "profile": {"displayName": "x"}} + + async def send_chat_cmd(self, _cmd): + return {"type": "cmdOk"} + + monkeypatch.setattr(client_mod, "ChatApi", FakeChatApi) + + class Boom(RuntimeError): + pass + + class BoomClient(Client): + async def _post_start(self, user): + raise Boom("kaboom") + + c = BoomClient(profile=Profile(display_name="x"), db=SqliteDb(file_prefix="/tmp/test")) + + async def go(): + with pytest.raises(Boom): + async with c: + pytest.fail("should not enter the with-block") + + asyncio.run(go()) + assert closed == ["stop", "close"], f"controller not cleaned up: {closed}" + assert c._api is None, "Client._api should be reset to None after rollback" + + +def test_lookup_known_contact_propagates_non_command_errors(): + """_lookup_known_contact must NOT mask transport / FFI errors as 'unknown + link' — only ChatCommandError (malformed link, etc.) should fall through + to the handshake path. Bare Exception catch would hide real bugs.""" + bot, api = _bot_with_fake_api() + + class BoomError(RuntimeError): + pass + + async def boom(_user_id, _link): + raise BoomError("FFI wedged") + + api.api_connect_plan = boom # type: ignore[assignment] + + async def go(): + with pytest.raises(BoomError): + await bot._lookup_known_contact(1, "link") + + asyncio.run(go()) + + +# --------------------------------------------------------------------------- +# Exception subclasses +# --------------------------------------------------------------------------- + + +def test_contact_already_exists_is_chat_command_error_subclass(): + """Callers should be able to catch the base class to handle all command + errors uniformly, and the specific subclass for targeted handling.""" + from simplex_chat import ChatCommandError, ContactAlreadyExistsError + + assert issubclass(ContactAlreadyExistsError, ChatCommandError) + + e = ContactAlreadyExistsError("x", {"type": "contactAlreadyExists"}) # type: ignore[arg-type] + assert isinstance(e, ChatCommandError) + assert e.response_type == "contactAlreadyExists" + + +def test_chat_command_error_response_type_property(): + from simplex_chat import ChatCommandError + + e = ChatCommandError("x", {"type": "someError"}) # type: ignore[arg-type] + assert e.response_type == "someError" + + +# --------------------------------------------------------------------------- +# events() mutual exclusion with serve_forever +# --------------------------------------------------------------------------- + + +def test_events_raises_if_already_serving(): + bot, _api = _bot_with_fake_api() + # _serving=True is set by _bot_with_fake_api + + async def go(): + with pytest.raises(RuntimeError, match="mutually exclusive"): + async for _ in bot.events(): + pass + + asyncio.run(go()) diff --git a/packages/simplex-chat-python/tests/test_filters.py b/packages/simplex-chat-python/tests/test_filters.py index 08fb66ed92..3c909df4df 100644 --- a/packages/simplex-chat-python/tests/test_filters.py +++ b/packages/simplex-chat-python/tests/test_filters.py @@ -3,7 +3,7 @@ import re from simplex_chat.filters import compile_message_filter -def _msg(content_type="text", text=None, chat_type="direct", group_id=None): +def _msg(content_type="text", text=None, chat_type="direct", group_id=None, contact_id=None): """Build a minimal mock Message-like object for filter testing.""" class M: @@ -11,12 +11,12 @@ def _msg(content_type="text", text=None, chat_type="direct", group_id=None): m = M() m.content = {"type": content_type, "text": text} if text is not None else {"type": content_type} - m.chat_item = { - "chatInfo": { - "type": chat_type, - **({"groupInfo": {"groupId": group_id}} if chat_type == "group" else {}), - } - } + chat_info: dict = {"type": chat_type} + if chat_type == "group": + chat_info["groupInfo"] = {"groupId": group_id} + elif chat_type == "direct" and contact_id is not None: + chat_info["contact"] = {"contactId": contact_id} + m.chat_item = {"chatInfo": chat_info} return m @@ -81,3 +81,23 @@ def test_group_id_tuple_or(): f = compile_message_filter({"group_id": (1, 2, 3)}) assert f(_msg(chat_type="group", group_id=2)) assert not f(_msg(chat_type="group", group_id=99)) + + +def test_contact_id_filter(): + f = compile_message_filter({"contact_id": 7}) + assert f(_msg(chat_type="direct", contact_id=7)) + assert not f(_msg(chat_type="direct", contact_id=99)) + assert not f(_msg(chat_type="group", group_id=7)) + + +def test_contact_id_tuple_or(): + f = compile_message_filter({"contact_id": (1, 2, 3)}) + assert f(_msg(chat_type="direct", contact_id=2)) + assert not f(_msg(chat_type="direct", contact_id=99)) + + +def test_contact_id_combined_with_content_type(): + f = compile_message_filter({"content_type": "text", "contact_id": 5}) + assert f(_msg(content_type="text", chat_type="direct", contact_id=5)) + assert not f(_msg(content_type="image", chat_type="direct", contact_id=5)) + assert not f(_msg(content_type="text", chat_type="direct", contact_id=99))