diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py new file mode 100644 index 00000000..1e0bdcea --- /dev/null +++ b/taskiq/brokers/nng/__init__.py @@ -0,0 +1,53 @@ +"""NNG broker package for taskiq.""" +from .hub import HubConfig, NNGHub +from .protocol import ( + ControlMessage, + ControlResponse, + MessageKind, + TaskEnvelope, + WorkerState, + WorkerStatus, +) +from .storage import ( + AffinityPolicy, + InMemoryStore, + LeastLoaded, + PowerOfTwoChoices, + PriorityScheduler, + QueueFullError, + RoutingPolicy, + RoundRobin, + Scheduler, + StoreConfig, + TaskContext, + WorkerView, + make_routing_policy, +) + +__all__ = [ + "HubConfig", + "NNGHub", + # protocol + "ControlMessage", + "ControlResponse", + "MessageKind", + "TaskEnvelope", + "WorkerState", + "WorkerStatus", + # store + "QueueFullError", + "InMemoryStore", + "StoreConfig", + # routing + "TaskContext", + "WorkerView", + "RoutingPolicy", + "AffinityPolicy", + "LeastLoaded", + "PowerOfTwoChoices", + "RoundRobin", + "make_routing_policy", + # scheduler + "Scheduler", + "PriorityScheduler", +] diff --git a/taskiq/brokers/nng/broker.py b/taskiq/brokers/nng/broker.py new file mode 100644 index 00000000..a6273e41 --- /dev/null +++ b/taskiq/brokers/nng/broker.py @@ -0,0 +1,328 @@ +"""NNG broker for taskiq — backed by a standalone :class:`NNGHub`.""" +from __future__ import annotations + +import asyncio +import base64 +import logging +import os +import tempfile +import time +import uuid +from collections.abc import AsyncGenerator, Callable +from contextlib import suppress +from typing import Any, TypeVar + +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.acks import AckableMessage +from taskiq.message import BrokerMessage + +from .protocol import ( + ControlMessage, + ControlResponse, + TaskEnvelope, + WorkerState, + WorkerStatus, +) + +try: + import pynng # type: ignore +except ImportError: + pynng = None # type: ignore[assignment] + +_T = TypeVar("_T") + +logger = logging.getLogger(__name__) + + +def _ipc_addr(prefix: str = "taskiq-nng") -> str: + name = f"{prefix}-{os.getpid()}-{uuid.uuid4().hex[:8]}.ipc" + return f"ipc://{os.path.join(tempfile.gettempdir(), name)}" + + +class NNGBroker(AsyncBroker): + """ + Taskiq broker backed by a standalone :class:`~taskiq.brokers.nng_hub.NNGHub`. + + The hub must be running before workers or clients start. Launch it with:: + + taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc + + **Client mode** (``is_worker_process = False``) + Only the control socket is opened. :meth:`kick` submits tasks to the + hub via a Req0 → Rep0 round-trip. + + **Worker mode** (``is_worker_process = True``) + In addition to the control socket the broker opens a unique Pull0 + socket, registers with the hub, and runs a heartbeat loop. + :meth:`listen` yields :class:`~taskiq.acks.AckableMessage` instances + whose ``ack`` callback sends the correct ``lease_id`` back to the hub. + + Thread / coroutine safety + ───────────────────────── + ``Req0`` is strictly serial (one request in-flight per socket). + ``_ctrl_lock`` serialises all :meth:`_send_control` calls so that + concurrent coroutines (heartbeat + ack + kick) never interleave frames. + + Ack correctness + ─────────────── + The hub embeds the dispatch-generated ``lease_id`` inside every + :class:`~taskiq.brokers.nng_protocol.TaskEnvelope`. 
The ack closure + captures it directly, so validation on the hub side always succeeds for + genuine acks and correctly rejects late/duplicate ones. + """ + + def __init__( + self, + control_addr: str, + *, + result_backend: "AsyncResultBackend[_T] | None" = None, + task_id_generator: Callable[[], str] | None = None, + worker_task_addr: str | None = None, + worker_id: str | None = None, + heartbeat_interval: float = 5.0, + lease_timeout: float = 20.0, + capacity: int = 1, + max_retries: int = 0, + retry_backoff: float = 1.0, + retry_jitter: float = 0.0, + recv_timeout_ms: int = 5_000, + send_timeout_ms: int = 5_000, + ) -> None: + """ + Initialise the NNG broker. + + :param control_addr: NNG address of the hub's Rep0 control socket. + :param result_backend: optional result backend. + :param task_id_generator: optional task ID generator. + :param worker_task_addr: NNG address this worker's Pull0 listens on. + Defaults to a unique per-process IPC path. + :param worker_id: stable identifier for this worker process. + Defaults to ``-``. + :param heartbeat_interval: seconds between heartbeat messages to hub. + :param lease_timeout: seconds a dispatched task lease remains valid. + :param capacity: max concurrent tasks this worker will accept. + :param max_retries: default max retries for submitted tasks. + :param retry_backoff: base seconds for exponential backoff. + :param retry_jitter: jitter multiplier added to backoff (0 = no jitter). + :param recv_timeout_ms: Req0 recv timeout in milliseconds. + :param send_timeout_ms: Req0 send timeout in milliseconds. + """ + if pynng is None: + raise RuntimeError( + "pynng is required to use NNGBroker. " + "Install it with: pip install taskiq[nng]", + ) + super().__init__( + result_backend=result_backend, + task_id_generator=task_id_generator, + ) + self.control_addr = control_addr + self.worker_task_addr = worker_task_addr or _ipc_addr() + self.worker_id = worker_id or f"{os.getpid()}-{uuid.uuid4().hex[:12]}" + self.heartbeat_interval = heartbeat_interval + self.lease_timeout = lease_timeout + self.capacity = capacity + self.max_retries = max_retries + self.retry_backoff = retry_backoff + self.retry_jitter = retry_jitter + self.recv_timeout_ms = recv_timeout_ms + self.send_timeout_ms = send_timeout_ms + + self._ctrl_sock: Any = None # pynng.Req0 + self._task_sock: Any = None # pynng.Pull0 (worker mode only) + self._heartbeat_task: asyncio.Task[None] | None = None + # Req0 allows exactly one request in-flight; this lock enforces that. + self._ctrl_lock = asyncio.Lock() + + # ── lifecycle ───────────────────────────────────────────────────────────── + + async def startup(self) -> None: + """Open sockets, register with hub (worker mode), and start heartbeat.""" + self._ctrl_sock = pynng.Req0( + dial=self.control_addr, + recv_timeout=self.recv_timeout_ms, + send_timeout=self.send_timeout_ms, + ) + if self.is_worker_process: + # recv_buffer_size lets the hub pre-queue up to `capacity` task + # messages in NNG's recv buffer before listen() calls arecv(). 
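+ # The hub itself caps in-flight tasks per worker at `capacity`
+ # (InMemoryStore.choose_worker only picks workers with inflight < capacity),
+ # so the recv buffer never needs to grow beyond that.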
+ self._task_sock = pynng.Pull0( + listen=self.worker_task_addr, + recv_buffer_size=self.capacity, + ) + resp = await self._send_control( + "register", + { + "worker_id": self.worker_id, + "task_addr": self.worker_task_addr, + "capacity": self.capacity, + "inflight": 0, + "last_seen": time.time(), + "heartbeat_interval": self.heartbeat_interval, + "lease_timeout": self.lease_timeout, + "draining": False, + "status": str(WorkerStatus.STARTING), + "version": "taskiq-nng", + }, + ) + if not resp.ok: + raise RuntimeError(f"Worker registration failed: {resp.error}") + logger.info( + "Worker %s registered with hub at %s", + self.worker_id, + self.control_addr, + ) + self._heartbeat_task = asyncio.create_task( + self._heartbeat_loop(), + name=f"nng-hb-{self.worker_id[:8]}", + ) + await super().startup() + + async def shutdown(self) -> None: + """Drain, unregister, cancel heartbeat, and close all sockets.""" + if self.is_worker_process: + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + with suppress(asyncio.CancelledError): + await self._heartbeat_task + if self._ctrl_sock is not None: + with suppress(Exception): + await self._send_control( + "drain", {"worker_id": self.worker_id}, + ) + await self._send_control( + "unregister", {"worker_id": self.worker_id}, + ) + if self._task_sock is not None: + with suppress(Exception): + self._task_sock.close() + if self._ctrl_sock is not None: + with suppress(Exception): + self._ctrl_sock.close() + await super().shutdown() + + # ── internal helpers ────────────────────────────────────────────────────── + + async def _send_control( + self, kind: str, payload: dict[str, Any], + ) -> ControlResponse: + if self._ctrl_sock is None: + raise RuntimeError("Control socket is not open (call startup() first)") + async with self._ctrl_lock: + await self._ctrl_sock.asend( + ControlMessage(kind=kind, payload=payload).to_bytes(), + ) + raw = await self._ctrl_sock.arecv() + return ControlResponse.from_bytes(raw) + + async def _heartbeat_loop(self) -> None: + while True: + try: + await asyncio.sleep(self.heartbeat_interval) + resp = await self._send_control( + "heartbeat", {"worker_id": self.worker_id}, + ) + if not resp.ok: + logger.warning("Heartbeat rejected by hub: %s", resp.error) + except asyncio.CancelledError: + raise + except Exception as exc: + # Hub may be temporarily unreachable; log and keep trying. + logger.warning("Heartbeat failed: %s", exc) + + # ── AsyncBroker API ─────────────────────────────────────────────────────── + + async def kick(self, message: BrokerMessage) -> None: + """ + Submit a task to the hub for dispatch. + + :param message: broker message to submit. + :raises RuntimeError: if the broker has not been started or the hub + rejects the submission (e.g. queue full). 
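+
+ Dispatch hints (``priority``, ``max_retries``, ``retry_backoff``,
+ ``retry_jitter``) are read from task labels when present. A sketch,
+ using taskiq's standard kicker API (``my_task`` is a placeholder)::
+
+ await my_task.kicker().with_labels(priority=5, max_retries=3).kiq()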
+ """ + if self._ctrl_sock is None: + raise RuntimeError("Broker is not started") + payload: dict[str, Any] = { + "task_id": message.task_id, + "task_name": message.task_name, + "payload_b64": base64.b64encode(message.message).decode("ascii"), + "labels": message.labels, + "lease_id": "", # hub assigns the real lease_id at dispatch time + "attempts": int(message.labels.get("attempts", 0)), + "max_retries": int( + message.labels.get("max_retries", self.max_retries), + ), + "retry_backoff": float( + message.labels.get("retry_backoff", self.retry_backoff), + ), + "retry_jitter": float( + message.labels.get("retry_jitter", self.retry_jitter), + ), + "priority": int(message.labels.get("priority", 0)), + "created_at": time.time(), + } + resp = await self._send_control("submit", payload) + if not resp.ok: + raise RuntimeError(resp.error or "task submission failed") + + async def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: + """ + Yield incoming tasks as :class:`~taskiq.acks.AckableMessage` instances. + + Each message's ``ack`` callback sends the hub-issued ``lease_id`` back + so the hub can validate the ack and reject any late/duplicate ones. + + :raises RuntimeError: if called outside worker mode or before startup. + :yields: ackable task messages. + """ + if not self.is_worker_process: + raise RuntimeError("listen() is only valid in worker mode") + if self._task_sock is None: + raise RuntimeError("Task socket is not open (call startup() first)") + + while True: + try: + raw = await self._task_sock.arecv() + except pynng.Closed: + logger.info("Task socket closed; stopping listen()") + return + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Task arecv error: %s", exc) + continue + + try: + envelope = TaskEnvelope.from_bytes(raw) + except Exception as exc: + logger.error("Malformed task envelope discarded: %s", exc) + continue + + task_id = envelope.task_id + worker_id = self.worker_id + lease_id = envelope.lease_id # hub-assigned; correct by construction + + async def _ack( + _task_id: str = task_id, + _worker_id: str = worker_id, + _lease_id: str = lease_id, + ) -> None: + try: + resp = await self._send_control( + "ack", + { + "task_id": _task_id, + "worker_id": _worker_id, + "lease_id": _lease_id, + }, + ) + if not resp.ok: + logger.debug( + "Ack rejected for %s (late/duplicate): %s", + _task_id, resp.error, + ) + except Exception as exc: + logger.warning("Ack send failed for %s: %s", _task_id, exc) + + yield AckableMessage(data=envelope.payload, ack=_ack) diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py new file mode 100644 index 00000000..f3920c6f --- /dev/null +++ b/taskiq/brokers/nng/hub.py @@ -0,0 +1,463 @@ +""" +NNG hub: central control plane, task dispatcher, and lease manager. + +Run as a standalone process:: + + taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc + +Or embed it in an application for testing:: + + hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc")) + await hub.start() + ... 
+ await hub.stop() +""" +from __future__ import annotations + +import argparse +import asyncio +import base64 +import logging +import os +import signal +import time +import uuid +from contextlib import suppress +from dataclasses import dataclass, field +from typing import Any + +try: + import pynng # type: ignore +except ImportError: + pynng = None # type: ignore[assignment] + +from .protocol import ( + ControlMessage, + ControlResponse, + TaskEnvelope, + WorkerState, +) +from .storage import ( + InMemoryStore, + PriorityScheduler, + QueueFullError, + RoutingPolicy, + Scheduler, + StoreConfig, + TaskContext, + make_routing_policy, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class HubConfig: + """Configuration for :class:`NNGHub`.""" + + control_addr: str + task_db: str = "" # kept for API compat; ignored by in-memory store + max_pending: int = 10_000 + heartbeat_timeout: float = 15.0 + lease_timeout: float = 20.0 + dispatch_interval: float = 0.05 + reaper_interval: float = 0.5 + routing_policy: RoutingPolicy | str = "least_loaded" + scheduler: Scheduler | None = None + backoff_cap: float = 60.0 + # Number of concurrent Rep0 contexts. Each context handles one req/rep + # pair independently; N contexts ≈ N simultaneous control-plane clients. + control_concurrency: int = 16 + dispatch_batch: int = 50 + # Per-context recv timeout in ms. Allows the stop event to be checked + # even when there are no incoming messages. + recv_timeout_ms: int = 1_000 + + +class NNGHub: + """ + Stateful central hub: control plane, task dispatcher, and lease manager. + + Architecture + ──────────── + **Control plane** — ``Rep0`` socket with ``control_concurrency`` + independent ``nng_ctx`` contexts running concurrently. Each context + handles one request-reply at a time, so N workers can + register/heartbeat/ack simultaneously without queuing behind each other. + + **Data plane** — One ``Push0`` socket per registered worker, dialed to + the worker's own ``Pull0`` listen address. The hub explicitly targets + the least-loaded worker instead of relying on NNG round-robin. + + **State** — :class:`~taskiq.brokers.nng.storage.InMemoryStore`. All + store operations are synchronous and execute directly on the asyncio event + loop without blocking (no I/O, no syscalls). + + **Recovery** — On startup, any tasks that were leased before the hub last + stopped (within the same process lifetime) are automatically requeued by + :meth:`~InMemoryStore.recover_dead_workers`. + """ + + def __init__(self, config: HubConfig) -> None: + """ + Initialise the hub with the given configuration. + + :param config: hub configuration. + """ + if pynng is None: + raise RuntimeError( + "pynng is required to use NNGHub. " + "Install it with: pip install taskiq[nng]" + ) + self.config = config + self.store = InMemoryStore( + StoreConfig( + max_pending=config.max_pending, + lease_timeout=config.lease_timeout, + backoff_cap=config.backoff_cap, + ), + ) + # Resolve once at construction so RoundRobin and similar stateful + # policies maintain their counter across dispatch calls. 
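+ # Custom RoutingPolicy instances are passed through unchanged; string names
+ # resolve to the shared built-ins (least_loaded, p2c, round_robin, affinity).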
+ self._routing: RoutingPolicy = make_routing_policy(config.routing_policy) + self._scheduler: Scheduler = config.scheduler or PriorityScheduler() + self._stop = asyncio.Event() + self._ctrl_sock: Any = None # pynng.Rep0 + self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 + self._tasks: list[asyncio.Task[None]] = [] + + # ── lifecycle ───────────────────────────────────────────────────────────── + + async def start(self) -> None: + """Start the hub: recover orphaned tasks, open sockets, spawn loops.""" + self.store.recover_dead_workers(self.config.heartbeat_timeout) + + self._ctrl_sock = pynng.Rep0(listen=self.config.control_addr) + self._ctrl_sock.recv_timeout = self.config.recv_timeout_ms + + self._tasks = [ + asyncio.create_task(self._dispatch_loop(), name="hub-dispatch"), + asyncio.create_task(self._reaper_loop(), name="hub-reaper"), + ] + for i in range(self.config.control_concurrency): + ctx = self._ctrl_sock.new_context() + self._tasks.append( + asyncio.create_task( + self._control_handler(ctx), name=f"hub-ctrl-{i}" + ), + ) + logger.info("NNG hub started on %s", self.config.control_addr) + + async def stop(self) -> None: + """Gracefully stop all hub loops and close sockets.""" + logger.info("NNG hub stopping…") + self._stop.set() + for t in self._tasks: + t.cancel() + with suppress(asyncio.CancelledError): + await t + for sock in self._worker_push.values(): + with suppress(Exception): + sock.close() + self._worker_push.clear() + if self._ctrl_sock is not None: + with suppress(Exception): + self._ctrl_sock.close() + logger.info("NNG hub stopped") + + # ── control plane ───────────────────────────────────────────────────────── + + async def _control_handler(self, ctx: Any) -> None: + """Run one Rep0 context: receive → dispatch → reply, in a loop.""" + while not self._stop.is_set(): + try: + raw = await ctx.arecv() + except pynng.Timeout: + continue + except (pynng.Closed, asyncio.CancelledError): + break + except Exception as exc: + logger.warning("Control recv error: %s", exc) + continue + + try: + response = await self._handle(raw) + except Exception as exc: + logger.exception("Unhandled error in control handler") + response = ControlResponse(ok=False, error=str(exc)) + + try: + await ctx.asend(response.to_bytes()) + except (pynng.Closed, asyncio.CancelledError): + break + except Exception as exc: + logger.warning("Control send error: %s", exc) + + async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 + """Dispatch a raw control message to the appropriate handler.""" + msg = ControlMessage.from_bytes(raw) + + if msg.kind == "ping": + return ControlResponse(ok=True, payload={"pong": True}) + + if msg.kind == "submit": + return await self._handle_submit(msg.payload) + + if msg.kind == "register": + return await self._handle_register(msg.payload) + + if msg.kind == "heartbeat": + self.store.heartbeat(msg.payload["worker_id"]) + return ControlResponse(ok=True, payload={"ok": True}) + + if msg.kind == "unregister": + return await self._handle_unregister(msg.payload["worker_id"]) + + if msg.kind == "drain": + self.store.mark_draining(msg.payload["worker_id"]) + return ControlResponse(ok=True, payload={"draining": True}) + + if msg.kind == "ack": + ok = self.store.ack( + msg.payload["task_id"], + msg.payload["worker_id"], + msg.payload["lease_id"], + ) + return ControlResponse(ok=ok, payload={"acked": ok}) + + if msg.kind == "nack": + ok = self.store.nack( + msg.payload["task_id"], + msg.payload["worker_id"], + msg.payload["lease_id"], + 
msg.payload.get("error", "unknown error"), + ) + return ControlResponse(ok=ok, payload={"nacked": ok}) + + if msg.kind == "status": + task = self.store.get_task(msg.payload["task_id"]) + return ControlResponse(ok=bool(task), payload=task or {}) + + if msg.kind == "stats": + return ControlResponse(ok=True, payload=self.store.stats()) + + return ControlResponse(ok=False, error=f"unknown kind: {msg.kind!r}") + + async def _handle_submit(self, payload: dict[str, Any]) -> ControlResponse: + envelope = TaskEnvelope(**payload) + try: + self.store.submit(envelope) + return ControlResponse(ok=True, payload={"task_id": envelope.task_id}) + except QueueFullError: + return ControlResponse(ok=False, error="queue full") + + async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse: + worker = WorkerState(**payload) + self.store.register_worker(worker) + if worker.worker_id not in self._worker_push: + try: + sock = pynng.Push0(dial=worker.task_addr) + self._worker_push[worker.worker_id] = sock + except Exception as exc: + logger.error( + "Failed to dial worker %s at %s: %s", + worker.worker_id, worker.task_addr, exc, + ) + return ControlResponse(ok=False, error=f"dial failed: {exc}") + return ControlResponse(ok=True, payload={"registered": True}) + + async def _handle_unregister(self, worker_id: str) -> ControlResponse: + self.store.unregister_worker(worker_id) + sock = self._worker_push.pop(worker_id, None) + if sock is not None: + with suppress(Exception): + sock.close() + return ControlResponse(ok=True, payload={"unregistered": True}) + + # ── dispatch loop ───────────────────────────────────────────────────────── + + async def _dispatch_loop(self) -> None: + while not self._stop.is_set(): + try: + sent = await self._dispatch_once() + if not sent: + await asyncio.sleep(self.config.dispatch_interval) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Dispatch loop error") + await asyncio.sleep(self.config.dispatch_interval) + + async def _dispatch_once(self) -> bool: + """Dispatch up to ``dispatch_batch`` due tasks to available workers.""" + due = self._scheduler.select(self.store, self.config.dispatch_batch) + if not due: + return False + sent_any = False + for row in due: + task_ctx = TaskContext( + task_id=row["task_id"], + task_name=row["task_name"], + labels=row["labels"], + priority=int(row["priority"]), + attempts=int(row["attempts"]), + ) + worker = self.store.choose_worker( + self._routing, + heartbeat_timeout=self.config.heartbeat_timeout, + task=task_ctx, + ) + if worker is None: + return sent_any # no capacity; leave remaining tasks in queue + + worker_id = worker["worker_id"] + lease_id = uuid.uuid4().hex + lease_until = time.time() + self.config.lease_timeout + + if not self.store.mark_leased( + row["task_id"], worker_id, lease_id, lease_until, + ): + continue # concurrent dispatch race; task already taken + + sock = self._worker_push.get(worker_id) + if sock is None: + logger.warning( + "No push socket for worker %s, requeueing %s", + worker_id, row["task_id"], + ) + self.store.nack(row["task_id"], worker_id, lease_id, "no socket") + continue + + envelope = TaskEnvelope( + task_id=row["task_id"], + task_name=row["task_name"], + payload_b64=base64.b64encode(row["payload"]).decode("ascii"), + labels=row["labels"], + lease_id=lease_id, + attempts=int(row["attempts"]) + 1, + max_retries=int(row["max_retries"]), + retry_backoff=float(row["retry_backoff"]), + retry_jitter=float(row["retry_jitter"]), + priority=int(row["priority"]), + 
created_at=float(row["created_at"]), + ) + try: + await sock.asend(envelope.to_bytes()) + sent_any = True + except Exception as exc: + logger.warning( + "Failed to deliver %s to worker %s: %s", + row["task_id"], worker_id, exc, + ) + self.store.nack( + row["task_id"], worker_id, lease_id, + f"dispatch send failed: {exc}", + ) + return sent_any + + # ── reaper loop ─────────────────────────────────────────────────────────── + + async def _reaper_loop(self) -> None: + while not self._stop.is_set(): + try: + await asyncio.sleep(self.config.reaper_interval) + reaped = self.store.reap_expired_leases() + if reaped: + logger.debug("Reaped %d expired leases", reaped) + recovered = self.store.recover_dead_workers( + self.config.heartbeat_timeout, + ) + if recovered: + logger.info("Requeued %d tasks from dead workers", recovered) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Reaper loop error") + + +# ── standalone CLI entry point ──────────────────────────────────────────────── + +def _build_config() -> HubConfig: + p = argparse.ArgumentParser( + description="taskiq-nng-hub — NNG task router, dispatcher, and lease manager", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument( + "--control-addr", + default=os.getenv("NNG_CONTROL_ADDR", "ipc:///tmp/taskiq-nng.ipc"), + help="NNG address the hub listens on. Env: NNG_CONTROL_ADDR", + ) + p.add_argument( + "--max-pending", + type=int, + default=int(os.getenv("NNG_MAX_PENDING", "10000")), + ) + p.add_argument( + "--heartbeat-timeout", + type=float, + default=float(os.getenv("NNG_HEARTBEAT_TIMEOUT", "15.0")), + help="Seconds of silence before a worker is declared dead.", + ) + p.add_argument( + "--lease-timeout", + type=float, + default=float(os.getenv("NNG_LEASE_TIMEOUT", "20.0")), + help="Seconds before an unacked task lease is reaped.", + ) + p.add_argument( + "--routing-policy", + choices=["least_loaded", "p2c", "round_robin"], + default=os.getenv("NNG_ROUTING_POLICY", "least_loaded"), + ) + p.add_argument( + "--control-concurrency", + type=int, + default=int(os.getenv("NNG_CONTROL_CONCURRENCY", "16")), + help="Number of concurrent Rep0 contexts.", + ) + p.add_argument( + "--log-level", + default=os.getenv("NNG_LOG_LEVEL", "INFO"), + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + ) + args = p.parse_args() + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s %(name)-24s %(levelname)-8s %(message)s", + ) + return HubConfig( + control_addr=args.control_addr, + max_pending=args.max_pending, + heartbeat_timeout=args.heartbeat_timeout, + lease_timeout=args.lease_timeout, + routing_policy=args.routing_policy, + control_concurrency=args.control_concurrency, + ) + + +async def _run(config: HubConfig) -> None: + hub = NNGHub(config) + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def _on_signal() -> None: + logger.info("Shutdown signal received") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, _on_signal) + + await hub.start() + try: + await stop_event.wait() + finally: + await hub.stop() + + +def main() -> None: + """Entry point for the ``taskiq-nng-hub`` CLI command.""" + config = _build_config() + try: + asyncio.run(_run(config)) + except KeyboardInterrupt: + pass diff --git a/taskiq/brokers/nng/protocol.py b/taskiq/brokers/nng/protocol.py new file mode 100644 index 00000000..9b0b4d8e --- /dev/null +++ b/taskiq/brokers/nng/protocol.py @@ -0,0 +1,159 @@ +"""Wire protocol types for the NNG 
broker.""" +from __future__ import annotations + +import base64 +import enum +import json +from dataclasses import asdict, dataclass, field +from typing import Any + + +class _StrValue(str, enum.Enum): + """Base for string enums whose str() returns the plain value (Python 3.10+).""" + + def __str__(self) -> str: + return self.value + + +class MessageKind(_StrValue): + """Kinds of control-plane messages sent between broker/client and hub.""" + + SUBMIT = "submit" + REGISTER = "register" + HEARTBEAT = "heartbeat" + UNREGISTER = "unregister" + DRAIN = "drain" + ACK = "ack" + NACK = "nack" + STATUS = "status" + STATS = "stats" + PING = "ping" + + +class TaskState(_StrValue): + """Lifecycle state of a task in the hub store.""" + + READY = "ready" + LEASED = "leased" + DONE = "done" + FAILED = "failed" + + +class WorkerStatus(_StrValue): + """Lifecycle status of a registered worker.""" + + STARTING = "starting" + LISTENING = "listening" + DRAINING = "draining" + OFFLINE = "offline" + DEAD = "dead" + + +@dataclass +class TaskEnvelope: + """ + Task payload sent from hub to worker over the data plane. + + ``lease_id`` is the UUID assigned by the hub at dispatch time. + Workers must echo it back in the ACK so the hub can validate + that the ack is not stale (e.g. after lease expiry and requeue). + """ + + task_id: str + task_name: str + payload_b64: str + labels: dict[str, Any] = field(default_factory=dict) + lease_id: str = "" + attempts: int = 0 + max_retries: int = 0 + retry_backoff: float = 1.0 + retry_jitter: float = 0.0 + priority: int = 0 + created_at: float = 0.0 + + @property + def payload(self) -> bytes: + """Decode the base-64 task payload.""" + return base64.b64decode(self.payload_b64.encode("ascii")) + + @classmethod + def from_bytes(cls, raw: bytes) -> TaskEnvelope: + """Deserialise from JSON bytes.""" + return cls(**json.loads(raw.decode("utf-8"))) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + asdict(self), separators=(",", ":"), ensure_ascii=False + ).encode("utf-8") + + +@dataclass +class ControlMessage: + """Request sent over the control plane (Req0 → Rep0).""" + + kind: str + payload: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_bytes(cls, raw: bytes) -> ControlMessage: + """Deserialise from JSON bytes.""" + data = json.loads(raw.decode("utf-8")) + return cls(kind=data["kind"], payload=data.get("payload", {})) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + {"kind": self.kind, "payload": self.payload}, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + + +@dataclass +class ControlResponse: + """Response sent back over the control plane (Rep0 → Req0).""" + + ok: bool + payload: dict[str, Any] = field(default_factory=dict) + error: str | None = None + + @classmethod + def from_bytes(cls, raw: bytes) -> ControlResponse: + """Deserialise from JSON bytes.""" + data = json.loads(raw.decode("utf-8")) + return cls( + ok=data["ok"], + payload=data.get("payload", {}), + error=data.get("error"), + ) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + {"ok": self.ok, "payload": self.payload, "error": self.error}, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + + +@dataclass +class WorkerState: + """Snapshot of a worker's identity and capacity at registration time.""" + + worker_id: str + task_addr: str + capacity: int + inflight: int = 0 + last_seen: float = 0.0 + heartbeat_interval: float = 5.0 + 
lease_timeout: float = 15.0 + draining: bool = False + status: WorkerStatus = WorkerStatus.STARTING + version: str = "unknown" + + def to_dict(self) -> dict[str, Any]: + """Convert to a plain dict, serialising the status enum to its string value.""" + d = asdict(self) + d["status"] = str(self.status) + return d diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py new file mode 100644 index 00000000..b804400a --- /dev/null +++ b/taskiq/brokers/nng/storage.py @@ -0,0 +1,666 @@ +"""Pure in-memory task store for the NNG hub — no external dependencies.""" +from __future__ import annotations + +import functools +import inspect +import random +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from .protocol import TaskEnvelope, WorkerState + + +@dataclass +class StoreConfig: + """Configuration for :class:`InMemoryStore`.""" + + path: str = "" # kept for API compat; not used + max_pending: int = 10_000 + lease_timeout: float = 30.0 + backoff_base: float = 1.0 + backoff_cap: float = 60.0 + + +class QueueFullError(RuntimeError): + """Raised when a submission is attempted on a full queue.""" + + +@dataclass +class _Task: + task_id: str + task_name: str + payload: bytes + labels: dict[str, Any] + state: str # ready / leased / done / failed + attempts: int = 0 + max_retries: int = 0 + retry_backoff: float = 1.0 + retry_jitter: float = 0.0 + priority: int = 0 + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + next_run_at: float = field(default_factory=time.time) + lease_id: str | None = None + leased_worker_id: str | None = None + lease_until: float | None = None + last_error: str | None = None + + def as_dict(self) -> dict[str, Any]: + """Return a dict view of this task record.""" + return { + "task_id": self.task_id, + "task_name": self.task_name, + "payload": self.payload, + "labels": self.labels, + "state": self.state, + "attempts": self.attempts, + "max_retries": self.max_retries, + "retry_backoff": self.retry_backoff, + "retry_jitter": self.retry_jitter, + "priority": self.priority, + "created_at": self.created_at, + "updated_at": self.updated_at, + "next_run_at": self.next_run_at, + "lease_id": self.lease_id, + "leased_worker_id": self.leased_worker_id, + "lease_until": self.lease_until, + "last_error": self.last_error, + } + + def as_status_dict(self) -> dict[str, Any]: + """Return a JSON-safe dict (no raw bytes) for control-plane status responses.""" + d = self.as_dict() + d.pop("payload", None) + return d + + +@dataclass +class _Worker: + worker_id: str + task_addr: str + capacity: int + inflight: int = 0 + last_seen: float = 0.0 + heartbeat_interval: float = 5.0 + lease_timeout: float = 15.0 + draining: bool = False + status: str = "starting" + version: str = "unknown" + + def as_dict(self) -> dict[str, Any]: + """Return a dict view of this worker record.""" + return { + "worker_id": self.worker_id, + "task_addr": self.task_addr, + "capacity": self.capacity, + "inflight": self.inflight, + "last_seen": self.last_seen, + "heartbeat_interval": self.heartbeat_interval, + "lease_timeout": self.lease_timeout, + "draining": self.draining, + "status": self.status, + "version": self.version, + } + + +# ── task context ───────────────────────────────────────────────────────────── + + +@dataclass +class TaskContext: + """Task metadata passed to context-aware routing policies (e.g. 
affinity).""" + + task_id: str + task_name: str + labels: dict[str, Any] + priority: int = 0 + attempts: int = 0 + + +# ── routing policy abstraction ──────────────────────────────────────────────── + + +@dataclass(frozen=True) +class WorkerView: + """Immutable worker snapshot passed to :class:`RoutingPolicy` implementations.""" + + worker_id: str + inflight: int + capacity: int + + @property + def load(self) -> float: + """Fractional load: 0.0 idle → 1.0 at capacity.""" + return self.inflight / max(self.capacity, 1) + + +@runtime_checkable +class RoutingPolicy(Protocol): + """Strategy interface for selecting a dispatch target from available workers.""" + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the chosen worker, or None to hold off dispatch.""" + ... + + +class LeastLoaded: + """Pick the worker with the lowest inflight / capacity ratio.""" + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the least-loaded worker.""" + if not workers: + return None + return min(workers, key=lambda w: w.load) + + +class PowerOfTwoChoices: + """ + Power-of-two-choices routing. + + Samples two workers uniformly at random and returns the less loaded one. + Reduces hot-spot probability under high concurrency compared to pure random. + """ + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the less loaded of two randomly sampled workers.""" + if not workers: + return None + if len(workers) == 1: + return workers[0] + a, b = random.sample(workers, k=2) # noqa: S311 + return a if a.load <= b.load else b + + +class RoundRobin: + """ + Round-robin routing — cycle through workers in alphabetical ID order. + + Ignores load; useful when tasks are homogeneous and worker capacity is equal. + The counter is per-instance, so each :class:`NNGHub` maintains its own cycle. + """ + + def __init__(self) -> None: + """Initialise the cycle counter.""" + self._idx: int = 0 + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the next worker in the cycle.""" + if not workers: + return None + w = workers[self._idx % len(workers)] + self._idx += 1 + return w + + +class AffinityPolicy: + """ + Sticky routing: tasks with the same ``affinity_key`` label always go to the + same worker. Falls back to least-loaded when the preferred worker is gone. + + The affinity table is per-instance and lives only in memory. 
+ """ + + def __init__(self) -> None: + """Initialise an empty affinity table.""" + self._table: dict[str, str] = {} # affinity_key → worker_id + + def choose( + self, + workers: list[WorkerView], + task: "TaskContext | None" = None, + ) -> WorkerView | None: + """Return the sticky worker for the task's affinity key, or least-loaded.""" + if not workers: + return None + if task is not None: + key = str(task.labels.get("affinity_key", "")) + if key and key in self._table: + match = next( + (w for w in workers if w.worker_id == self._table[key]), None + ) + if match is not None: + return match + chosen = min(workers, key=lambda w: w.load) + if task is not None: + key = str(task.labels.get("affinity_key", "")) + if key: + self._table[key] = chosen.worker_id + return chosen + + +@functools.lru_cache(maxsize=None) +def _policy_accepts_task(policy_cls: type) -> bool: + """Return True if policy.choose accepts a ``task`` keyword argument.""" + try: + return "task" in inspect.signature(policy_cls.choose).parameters + except (ValueError, TypeError): + return False + + +def _choose_with_context( + policy: RoutingPolicy, + views: list[WorkerView], + task: "TaskContext | None", +) -> "WorkerView | None": + """Call policy.choose, passing ``task`` only when the policy supports it.""" + if task is not None and _policy_accepts_task(type(policy)): + return policy.choose(views, task=task) # type: ignore[call-arg] + return policy.choose(views) + + +# Singletons for stateless built-ins; RoundRobin/AffinityPolicy singletons are +# fine for single-hub processes. Users needing isolated state should pass their +# own instance. +_BUILTIN_POLICIES: dict[str, RoutingPolicy] = { + "least_loaded": LeastLoaded(), + "p2c": PowerOfTwoChoices(), + "round_robin": RoundRobin(), + "affinity": AffinityPolicy(), # type: ignore[dict-item] +} + + +def make_routing_policy(policy: "RoutingPolicy | str") -> RoutingPolicy: + """ + Resolve a routing policy name or pass through an instance. + + :param policy: ``'least_loaded'``, ``'p2c'``, ``'round_robin'``, or a + :class:`RoutingPolicy` instance. + :return: concrete routing policy. + :raises ValueError: for unknown string names. + """ + if isinstance(policy, str): + resolved = _BUILTIN_POLICIES.get(policy) + if resolved is None: + raise ValueError( + f"Unknown routing policy {policy!r}; " + f"available: {sorted(_BUILTIN_POLICIES)}" + ) + return resolved + return policy + + +# ── scheduler abstraction ───────────────────────────────────────────────────── + + +@runtime_checkable +class Scheduler(Protocol): + """Strategy interface for selecting which tasks to dispatch next.""" + + def select(self, store: "InMemoryStore", limit: int) -> list[dict[str, Any]]: + """Return up to ``limit`` tasks ready for dispatch.""" + ... + + +class PriorityScheduler: + """Default scheduler: highest-priority due tasks first.""" + + def select(self, store: "InMemoryStore", limit: int) -> list[dict[str, Any]]: + """Delegate to :meth:`InMemoryStore.due_tasks`.""" + return store.due_tasks(limit) + + +# ── store ───────────────────────────────────────────────────────────────────── + + +class InMemoryStore: + """ + Pure in-memory task store for the NNG hub. + + All methods are synchronous and safe to call from a single asyncio event + loop — asyncio's cooperative scheduling makes them effectively atomic (no + ``await`` between reads and writes). + + State is lost when the process exits. 
For persistent task queues use a + dedicated result backend; the NNG broker is designed for low-latency + in-process delivery, not durable storage. + """ + + def __init__(self, config: StoreConfig) -> None: + """Initialise an empty store with the given configuration.""" + self.config = config + self._tasks: dict[str, _Task] = {} + self._workers: dict[str, _Worker] = {} + + # ── helpers ─────────────────────────────────────────────────────────────── + + def _backoff(self, attempts: int, backoff_base: float) -> float: + return min(self.config.backoff_cap, backoff_base * (2 ** max(0, attempts - 1))) + + def _requeue_or_fail(self, task: _Task, worker_id: str, error: str) -> bool: + now = time.time() + if task.attempts > task.max_retries: + task.state = "failed" + else: + task.state = "ready" + task.next_run_at = now + self._backoff(task.attempts, task.retry_backoff) + task.last_error = error + task.lease_id = None + task.leased_worker_id = None + task.lease_until = None + task.updated_at = now + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight = max(0, worker.inflight - 1) + return True + + # ── task lifecycle ──────────────────────────────────────────────────────── + + def pending_count(self) -> int: + """Return the count of ready and leased tasks.""" + return sum(1 for t in self._tasks.values() if t.state in ("ready", "leased")) + + def submit(self, envelope: TaskEnvelope) -> None: + """ + Accept a new task into the store. + + :param envelope: task envelope to store. + :raises QueueFullError: when ``max_pending`` is reached. + """ + if self.pending_count() >= self.config.max_pending: + raise QueueFullError("Task queue is full.") + now = time.time() + self._tasks[envelope.task_id] = _Task( + task_id=envelope.task_id, + task_name=envelope.task_name, + payload=envelope.payload, + labels=envelope.labels, + state="ready", + max_retries=envelope.max_retries, + retry_backoff=envelope.retry_backoff, + retry_jitter=envelope.retry_jitter, + priority=envelope.priority, + created_at=envelope.created_at or now, + updated_at=now, + next_run_at=now, + ) + + def due_tasks(self, limit: int = 50) -> list[dict[str, Any]]: + """ + Return ready tasks whose ``next_run_at`` is in the past. + + Results are ordered by priority (descending) then creation time. + + :param limit: maximum number of rows to return. + :return: list of task dicts. + """ + now = time.time() + ready = [ + t for t in self._tasks.values() + if t.state == "ready" and t.next_run_at <= now + ] + ready.sort(key=lambda t: (-t.priority, t.created_at)) + return [t.as_dict() for t in ready[:limit]] + + def mark_leased( + self, + task_id: str, + worker_id: str, + lease_id: str, + lease_until: float, + ) -> bool: + """ + Atomically transition a task from 'ready' to 'leased'. + + :param task_id: task to lease. + :param worker_id: worker receiving the task. + :param lease_id: unique token for this dispatch attempt. + :param lease_until: absolute epoch deadline for the lease. + :return: True on success; False if the task is not in 'ready' state. 
+ """ + task = self._tasks.get(task_id) + if task is None or task.state != "ready": + return False + now = time.time() + task.state = "leased" + task.leased_worker_id = worker_id + task.lease_id = lease_id + task.lease_until = lease_until + task.attempts += 1 + task.updated_at = now + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight += 1 + return True + + def ack(self, task_id: str, worker_id: str, lease_id: str) -> bool: + """ + Mark a task as successfully completed. + + Late or duplicate acks (mismatched ``lease_id`` or state ≠ 'leased') + are silently rejected. + + :param task_id: task being acknowledged. + :param worker_id: worker sending the ack. + :param lease_id: lease token issued at dispatch. + :return: True if the ack was accepted. + """ + task = self._tasks.get(task_id) + if task is None or task.state != "leased": + return False + if task.lease_id != lease_id or task.leased_worker_id != worker_id: + return False + now = time.time() + task.state = "done" + task.updated_at = now + task.lease_id = None + task.leased_worker_id = None + task.lease_until = None + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight = max(0, worker.inflight - 1) + return True + + def nack( + self, task_id: str, worker_id: str, lease_id: str, error: str + ) -> bool: + """ + Explicitly fail a task, triggering retry or permanent failure. + + :param task_id: task being nacked. + :param worker_id: worker sending the nack. + :param lease_id: lease token issued at dispatch. + :param error: human-readable failure reason. + :return: True if the nack was accepted. + """ + task = self._tasks.get(task_id) + if ( + task is None + or task.state != "leased" + or task.lease_id != lease_id + or task.leased_worker_id != worker_id + ): + return False + return self._requeue_or_fail(task, worker_id, error) + + # ── reaper / recovery ───────────────────────────────────────────────────── + + def reap_expired_leases(self) -> int: + """ + Requeue or permanently fail tasks whose lease deadline has passed. + + :return: number of tasks reaped. + """ + now = time.time() + expired = [ + t for t in self._tasks.values() + if t.state == "leased" + and t.lease_until is not None + and t.lease_until < now + ] + for task in expired: + self._requeue_or_fail(task, task.leased_worker_id or "", "lease expired") + return len(expired) + + def recover_dead_workers(self, heartbeat_timeout: float) -> int: + """ + Mark workers that missed their heartbeat deadline as dead and requeue their tasks. + + :param heartbeat_timeout: seconds of silence before a worker is dead. + :return: number of tasks requeued. + """ + cutoff = time.time() - heartbeat_timeout + dead = [ + w for w in self._workers.values() + if w.last_seen < cutoff and w.status != "dead" + ] + requeued = 0 + for worker in dead: + worker.status = "dead" + worker.draining = True + leased = [ + t for t in self._tasks.values() + if t.state == "leased" and t.leased_worker_id == worker.worker_id + ] + for task in leased: + self._requeue_or_fail(task, worker.worker_id, "worker died") + requeued += 1 + return requeued + + # ── worker lifecycle ────────────────────────────────────────────────────── + + def register_worker(self, worker: WorkerState) -> None: + """ + Upsert a worker record, resetting drain state on re-registration. + + :param worker: worker state snapshot from the registration message. 
+ """ + now = time.time() + existing = self._workers.get(worker.worker_id) + if existing is not None: + existing.task_addr = worker.task_addr + existing.capacity = worker.capacity + existing.last_seen = now + existing.heartbeat_interval = worker.heartbeat_interval + existing.lease_timeout = worker.lease_timeout + existing.draining = False + existing.status = "listening" + existing.version = worker.version + else: + self._workers[worker.worker_id] = _Worker( + worker_id=worker.worker_id, + task_addr=worker.task_addr, + capacity=worker.capacity, + inflight=0, + last_seen=now, + heartbeat_interval=worker.heartbeat_interval, + lease_timeout=worker.lease_timeout, + draining=False, + status="listening", + version=worker.version, + ) + + def heartbeat(self, worker_id: str) -> None: + """ + Record a heartbeat, resetting the worker's last_seen timestamp. + + :param worker_id: ID of the worker sending the heartbeat. + """ + worker = self._workers.get(worker_id) + if worker is not None: + worker.last_seen = time.time() + worker.status = "listening" + + def unregister_worker(self, worker_id: str) -> None: + """ + Remove a worker from the registry (graceful shutdown path). + + :param worker_id: ID of the worker unregistering. + """ + self._workers.pop(worker_id, None) + + def mark_draining(self, worker_id: str) -> None: + """ + Mark a worker as draining so the hub stops dispatching new tasks to it. + + :param worker_id: ID of the worker entering drain mode. + """ + worker = self._workers.get(worker_id) + if worker is not None: + worker.draining = True + worker.status = "draining" + + # ── routing ─────────────────────────────────────────────────────────────── + + def choose_worker( + self, + policy: "RoutingPolicy | str" = "least_loaded", + *, + heartbeat_timeout: float = 15.0, + task: "TaskContext | None" = None, + ) -> dict[str, Any] | None: + """ + Select the best available worker using a routing policy. + + Accepts a :class:`RoutingPolicy` instance or a string name + (``'least_loaded'``, ``'p2c'``, ``'round_robin'``, ``'affinity'``). + + Context-aware policies (e.g. :class:`AffinityPolicy`) receive the + optional ``task`` argument when they declare it in their ``choose`` + signature. + + :param policy: routing policy or name. + :param heartbeat_timeout: seconds before a worker is considered stale. + :param task: optional task context for context-aware policies. + :return: chosen worker dict, or None if no worker has capacity. + """ + cutoff = time.time() - heartbeat_timeout + available = [ + w for w in self._workers.values() + if w.status in ("starting", "listening") + and not w.draining + and w.last_seen >= cutoff + and w.inflight < w.capacity + ] + if not available: + return None + # Stable sort so RoundRobin cycles in a predictable, deterministic order. + views = sorted( + [WorkerView(w.worker_id, w.inflight, w.capacity) for w in available], + key=lambda v: v.worker_id, + ) + routing = make_routing_policy(policy) + chosen = _choose_with_context(routing, views, task) + if chosen is None: + return None + worker = self._workers.get(chosen.worker_id) + return worker.as_dict() if worker is not None else None + + # ── observability ───────────────────────────────────────────────────────── + + def get_task(self, task_id: str) -> dict[str, Any] | None: + """ + Fetch task status by ID (no raw bytes in result). + + :param task_id: ID of the task to look up. + :return: status dict or None if not found. 
+ """ + task = self._tasks.get(task_id) + return task.as_status_dict() if task is not None else None + + def list_workers(self) -> list[dict[str, Any]]: + """Return all registered workers ordered by most-recently-seen.""" + return [ + w.as_dict() + for w in sorted( + self._workers.values(), key=lambda w: w.last_seen, reverse=True + ) + ] + + def stats(self) -> dict[str, int]: + """Return a summary dict with task state counts and active worker count.""" + counts: dict[str, int] = {} + for t in self._tasks.values(): + counts[t.state] = counts.get(t.state, 0) + 1 + active = sum( + 1 for w in self._workers.values() + if w.status in ("starting", "listening") and not w.draining + ) + return { + "ready": counts.get("ready", 0), + "leased": counts.get("leased", 0), + "done": counts.get("done", 0), + "failed": counts.get("failed", 0), + "active_workers": active, + } diff --git a/taskiq/brokers/nng_broker.py b/taskiq/brokers/nng_broker.py deleted file mode 100644 index 15ab3aaa..00000000 --- a/taskiq/brokers/nng_broker.py +++ /dev/null @@ -1,48 +0,0 @@ -from collections.abc import AsyncGenerator - -import pynng - -from taskiq.abc.broker import AsyncBroker -from taskiq.message import BrokerMessage - - -class NNGBroker(AsyncBroker): - """ - NanoMSG next generation broker. - - This broker is very much alike to the ZMQ broker, - It has a similar Idea, but slightly different - implementation. - """ - - def __init__(self, addr: str) -> None: - """ - Initialize the broker. - - :param addr: address which is used by both worker and client. - """ - super().__init__() - self.socket = pynng.Pair1(polyamorous=True) - self.addr = addr - - async def startup(self) -> None: - """Start the socket.""" - await super().startup() - if self.is_worker_process: - self.socket.listen(self.addr) - else: - self.socket.dial(self.addr, block=True) - - async def shutdown(self) -> None: - """Close the socket.""" - await super().shutdown() - self.socket.close() - - async def kick(self, message: BrokerMessage) -> None: - """Send a message.""" - await self.socket.ascend(message.message) - - async def listen(self) -> AsyncGenerator[bytes, None]: - """Infinite loop that receives messages.""" - while True: - yield await self.socket.arecv() diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py new file mode 100644 index 00000000..7bb7b2db --- /dev/null +++ b/tests/brokers/test_nng_broker.py @@ -0,0 +1,883 @@ +""" +Tests for the NNG broker, hub, storage, and protocol. + +The test suite is split into three layers: + +1. **Protocol** — pure serialisation roundtrips; no NNG sockets needed. +2. **Storage** — InMemoryStore unit tests; no NNG sockets needed. +3. **Integration** — real NNG sockets, single asyncio event loop. + Uses ``FakeWorker`` / ``FakeClient`` helpers that speak the wire protocol + directly so we can inject faults precisely (crash before ack, late ack, etc.). + +All NNG tests are skipped when ``pynng`` is not installed. 
+""" +from __future__ import annotations + +import asyncio +import os +import sys +import tempfile +import textwrap +import time +import uuid + +import pytest + +pynng = pytest.importorskip("pynng") + +from taskiq.brokers.nng import ( + AffinityPolicy, + HubConfig, + NNGHub, + ControlMessage, + ControlResponse, + InMemoryStore, + LeastLoaded, + MessageKind, + PowerOfTwoChoices, + PriorityScheduler, + QueueFullError, + RoutingPolicy, + RoundRobin, + Scheduler, + StoreConfig, + TaskContext, + TaskEnvelope, + WorkerState, + WorkerStatus, + WorkerView, + make_routing_policy, +) + + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def _ipc(tag: str = "") -> str: + name = f"nng-test-{tag}-{uuid.uuid4().hex[:8]}.ipc" + return f"ipc://{os.path.join(tempfile.gettempdir(), name)}" + + +def _envelope(**kwargs: object) -> TaskEnvelope: + defaults: dict[str, object] = { + "task_id": uuid.uuid4().hex, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": 0, + "retry_backoff": 1.0, + "retry_jitter": 0.0, + "priority": 0, + "created_at": time.time(), + } + defaults.update(kwargs) + return TaskEnvelope(**defaults) # type: ignore[arg-type] + + +def _worker_state( + worker_id: str | None = None, + task_addr: str | None = None, + capacity: int = 2, +) -> WorkerState: + wid = worker_id or uuid.uuid4().hex + return WorkerState( + worker_id=wid, + task_addr=task_addr or f"ipc:///tmp/{wid}.ipc", + capacity=capacity, + heartbeat_interval=5.0, + lease_timeout=10.0, + ) + + +def _hub(control_addr: str, db_path: str, **kwargs: object) -> NNGHub: + defaults: dict[str, object] = { + "max_pending": 100, + "heartbeat_timeout": 2.0, + "lease_timeout": 2.0, + "dispatch_interval": 0.02, + "reaper_interval": 0.1, + "control_concurrency": 4, + } + defaults.update(kwargs) + cfg = HubConfig( + control_addr=control_addr, + task_db=db_path, + **defaults, # type: ignore[arg-type] + ) + return NNGHub(cfg) + + +@pytest.fixture +def db_path(tmp_path: object) -> str: + import pathlib + return str(pathlib.Path(str(tmp_path)) / "hub.db") # type: ignore[arg-type] + + +@pytest.fixture +def ctrl_addr() -> str: + return _ipc("ctrl") + + +class FakeWorker: + """Minimal NNG worker that speaks the control + task protocol.""" + + def __init__( + self, + control_addr: str, + task_addr: str | None = None, + capacity: int = 1, + ) -> None: + self.worker_id = uuid.uuid4().hex[:8] + self.task_addr = task_addr or _ipc("worker") + self._ctrl = pynng.Req0( + dial=control_addr, recv_timeout=3000, send_timeout=3000 + ) + self._pull = pynng.Pull0(listen=self.task_addr, recv_timeout=3000) + self._lock = asyncio.Lock() + self.capacity = capacity + + async def ctrl(self, kind: str, payload: dict[str, object]) -> ControlResponse: + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind=kind, payload=payload).to_bytes() + ) + raw = await self._ctrl.arecv() + return ControlResponse.from_bytes(raw) + + async def register(self) -> None: + resp = await self.ctrl( + "register", + { + "worker_id": self.worker_id, + "task_addr": self.task_addr, + "capacity": self.capacity, + "inflight": 0, + "last_seen": time.time(), + "heartbeat_interval": 1.0, + "lease_timeout": 2.0, + "draining": False, + "status": str(WorkerStatus.STARTING), + "version": "test", + }, + ) + assert resp.ok, f"register failed: {resp.error}" + + async def recv_task(self, timeout: float = 3.0) -> TaskEnvelope: + raw = await asyncio.wait_for(self._pull.arecv(), timeout=timeout) 
+ return TaskEnvelope.from_bytes(raw) + + async def ack(self, task_id: str, lease_id: str) -> bool: + resp = await self.ctrl( + "ack", + { + "task_id": task_id, + "worker_id": self.worker_id, + "lease_id": lease_id, + }, + ) + return resp.ok + + async def heartbeat(self) -> None: + await self.ctrl("heartbeat", {"worker_id": self.worker_id}) + + async def drain_and_unregister(self) -> None: + await self.ctrl("drain", {"worker_id": self.worker_id}) + await self.ctrl("unregister", {"worker_id": self.worker_id}) + + def close(self) -> None: + self._ctrl.close() + self._pull.close() + + +class FakeClient: + """Minimal NNG client that can submit tasks and query hub status.""" + + def __init__(self, control_addr: str) -> None: + self._ctrl = pynng.Req0( + dial=control_addr, recv_timeout=3000, send_timeout=3000 + ) + self._lock = asyncio.Lock() + + async def submit(self, **labels: object) -> str: + tid = uuid.uuid4().hex + payload: dict[str, object] = { + "task_id": tid, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": labels.pop("max_retries", 0), + "retry_backoff": labels.pop("retry_backoff", 1.0), + "retry_jitter": 0.0, + "priority": labels.pop("priority", 0), + "created_at": time.time(), + } + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind="submit", payload=payload).to_bytes() + ) + raw = await self._ctrl.arecv() + resp = ControlResponse.from_bytes(raw) + assert resp.ok, f"submit failed: {resp.error}" + return tid + + async def ping(self) -> bool: + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind="ping", payload={}).to_bytes() + ) + raw = await self._ctrl.arecv() + return ControlResponse.from_bytes(raw).ok + + def close(self) -> None: + self._ctrl.close() + + +# ── 1. Protocol tests ───────────────────────────────────────────────────────── + + +def test_control_message_roundtrip() -> None: + msg = ControlMessage(kind=MessageKind.HEARTBEAT, payload={"worker_id": "w1"}) + assert ControlMessage.from_bytes(msg.to_bytes()) == msg + + +def test_control_response_roundtrip() -> None: + resp = ControlResponse(ok=True, payload={"task_id": "abc"}, error=None) + assert ControlResponse.from_bytes(resp.to_bytes()) == resp + + +def test_task_envelope_lease_id_preserved() -> None: + """Regression: v2 omitted lease_id from the envelope, breaking ack validation.""" + env = TaskEnvelope( + task_id="x", task_name="m:f", payload_b64="YQ==", lease_id="abc123" + ) + rt = TaskEnvelope.from_bytes(env.to_bytes()) + assert rt.lease_id == "abc123" + + +def test_task_envelope_payload_decode() -> None: + env = _envelope(payload_b64="dGVzdA==") + assert env.payload == b"test" + + +# ── 2. 
Storage tests ────────────────────────────────────────────────────────── + + +@pytest.fixture +def store(db_path: str) -> InMemoryStore: + return InMemoryStore(StoreConfig(path=db_path, max_pending=50, lease_timeout=5.0)) + + +def test_submit_and_pending(store: InMemoryStore) -> None: + store.submit(_envelope()) + assert store.pending_count() == 1 + + +def test_submit_queue_full(db_path: str) -> None: + s = InMemoryStore(StoreConfig(path=db_path, max_pending=2)) + s.submit(_envelope()) + s.submit(_envelope()) + with pytest.raises(QueueFullError): + s.submit(_envelope()) + + +def test_due_tasks_ordered_by_priority(store: InMemoryStore) -> None: + store.submit(_envelope(task_id="lo", priority=0)) + store.submit(_envelope(task_id="hi", priority=10)) + due = store.due_tasks(limit=10) + assert due[0]["task_id"] == "hi" + assert due[1]["task_id"] == "lo" + + +def test_ack_happy_path(store: InMemoryStore) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + assert store.mark_leased(env.task_id, w.worker_id, "L1", time.time() + 60) + assert store.ack(env.task_id, w.worker_id, "L1") + assert store.get_task(env.task_id)["state"] == "done" + + +def test_ack_wrong_lease_rejected(store: InMemoryStore) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "real", time.time() + 60) + assert not store.ack(env.task_id, w.worker_id, "wrong") + + +def test_late_ack_after_requeue_ignored(store: InMemoryStore) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L2", time.time() - 1) + assert store.reap_expired_leases() == 1 + assert not store.ack(env.task_id, w.worker_id, "L2") + + +def test_nack_requeues_with_backoff(store: InMemoryStore) -> None: + env = _envelope(max_retries=2, retry_backoff=1.0) + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L3", time.time() + 60) + assert store.nack(env.task_id, w.worker_id, "L3", "boom") + task = store.get_task(env.task_id) + assert task["state"] == "ready" + assert float(task["next_run_at"]) > time.time() + + +def test_nack_exceeds_retries_fails(store: InMemoryStore) -> None: + env = _envelope(max_retries=0) + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L4", time.time() + 60) + store.nack(env.task_id, w.worker_id, "L4", "error") + assert store.get_task(env.task_id)["state"] == "failed" + + +def test_dead_worker_tasks_requeued(store: InMemoryStore) -> None: + w = _worker_state() + store.register_worker(w) + env = _envelope(max_retries=3) + store.submit(env) + store.mark_leased(env.task_id, w.worker_id, "L5", time.time() + 60) + store._workers[w.worker_id].last_seen = 0 # simulate missed heartbeats + assert store.recover_dead_workers(heartbeat_timeout=1.0) == 1 + assert store.get_task(env.task_id)["state"] == "ready" + + +def test_choose_worker_least_loaded(store: InMemoryStore) -> None: + w1 = _worker_state(worker_id="w1", capacity=4) + w2 = _worker_state(worker_id="w2", capacity=4) + store.register_worker(w1) + store.register_worker(w2) + store._workers["w1"].inflight = 3 # w1 heavily loaded + chosen = store.choose_worker("least_loaded", heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] == "w2" + + +def test_stats(store: InMemoryStore) -> None: + w = _worker_state() + 
store.register_worker(w) + store.submit(_envelope()) + s = store.stats() + assert s["ready"] == 1 + assert s["active_workers"] == 1 + + +# ── 3. Integration tests ────────────────────────────────────────────────────── + + +async def test_ping(ctrl_addr: str, db_path: str) -> None: + hub = _hub(ctrl_addr, db_path) + await hub.start() + client = FakeClient(ctrl_addr) + try: + assert await client.ping() + finally: + client.close() + await hub.stop() + + +async def test_submit_dispatch_ack(ctrl_addr: str, db_path: str) -> None: + """Golden path: one task, one worker, full round-trip.""" + hub = _hub(ctrl_addr, db_path) + await hub.start() + worker = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await worker.register() + tid = await client.submit() + env = await worker.recv_task(timeout=3.0) + assert env.task_id == tid + assert env.lease_id != "", "Hub must populate lease_id in envelope" + assert await worker.ack(env.task_id, env.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + worker.close() + client.close() + await hub.stop() + + +async def test_multiple_workers_load_balanced(ctrl_addr: str, db_path: str) -> None: + """Both workers must receive at least one task — no single hot-spot.""" + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=4) + w2 = FakeWorker(ctrl_addr, capacity=4) + client = FakeClient(ctrl_addr) + try: + await w1.register() + await w2.register() + task_ids = [await client.submit() for _ in range(6)] + received: dict[str, list[str]] = {w1.worker_id: [], w2.worker_id: []} + pending = set(task_ids) + + async def drain(w: FakeWorker) -> None: + while pending: + try: + env = await w.recv_task(timeout=0.5) + received[w.worker_id].append(env.task_id) + pending.discard(env.task_id) + await w.ack(env.task_id, env.lease_id) + except asyncio.TimeoutError: + break + + await asyncio.gather(drain(w1), drain(w2)) + assert not pending, f"Tasks not delivered: {pending}" + assert len(received[w1.worker_id]) > 0 + assert len(received[w2.worker_id]) > 0 + finally: + w1.close() + w2.close() + client.close() + await hub.stop() + + +async def test_worker_crash_before_ack_task_requeued( + ctrl_addr: str, db_path: str +) -> None: + """ + Worker receives a task but dies before acking. + After lease expiry the hub must requeue it for a second worker. + """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await w1.register() + tid = await client.submit(max_retries=3) + env1 = await w1.recv_task(timeout=3.0) + assert env1.task_id == tid + w1.close() # simulate crash without acking + + await asyncio.sleep(3.5) # lease_timeout=2s + reaper_interval=0.1s + + assert hub.store.get_task(tid)["state"] == "ready" + + w2 = FakeWorker(ctrl_addr, capacity=1) + try: + await w2.register() + env2 = await w2.recv_task(timeout=3.0) + assert env2.task_id == tid + assert env2.lease_id != env1.lease_id + assert await w2.ack(env2.task_id, env2.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + w2.close() + finally: + client.close() + await hub.stop() + + +async def test_late_ack_after_requeue_rejected( + ctrl_addr: str, db_path: str +) -> None: + """ + Sequence: dispatch to w1 → lease expires → requeue → dispatch to w2. + w1's late ack must be rejected; w2's ack must succeed. 
+ """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await w1.register() + tid = await client.submit(max_retries=3) + env1 = await w1.recv_task(timeout=3.0) + await asyncio.sleep(3.5) # let lease expire + + w2 = FakeWorker(ctrl_addr, capacity=1) + try: + await w2.register() + env2 = await w2.recv_task(timeout=3.0) + + # w1's stale ack must be rejected + assert not await w1.ack(env1.task_id, env1.lease_id) + # w2's valid ack succeeds + assert await w2.ack(env2.task_id, env2.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + w2.close() + finally: + w1.close() + client.close() + await hub.stop() + + +@pytest.mark.skip( + reason="In-memory store has no persistence; restart recovery requires a durable backend." +) +async def test_hub_restart_recovers_orphaned_tasks( + ctrl_addr: str, db_path: str +) -> None: + """Persistence across restarts is not supported by the in-memory store.""" + + +async def test_concurrent_heartbeats(ctrl_addr: str, db_path: str) -> None: + """ + N workers heartbeat simultaneously. With concurrent Rep0 contexts all + must succeed without serialisation stalls. + """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + workers = [FakeWorker(ctrl_addr, capacity=2) for _ in range(8)] + try: + await asyncio.gather(*[w.register() for w in workers]) + results = await asyncio.gather( + *[w.heartbeat() for w in workers], + return_exceptions=True, + ) + errors = [r for r in results if isinstance(r, Exception)] + assert not errors, f"Concurrent heartbeats failed: {errors}" + finally: + for w in workers: + w.close() + await hub.stop() + + +async def test_graceful_drain_and_unregister(ctrl_addr: str, db_path: str) -> None: + hub = _hub(ctrl_addr, db_path) + await hub.start() + worker = FakeWorker(ctrl_addr, capacity=2) + try: + await worker.register() + assert len(hub.store.list_workers()) == 1 + await worker.drain_and_unregister() + await asyncio.sleep(0.1) + assert len(hub.store.list_workers()) == 0 + finally: + worker.close() + await hub.stop() + + +# ── 2b. 
Routing policy unit tests ───────────────────────────────────────────── + + +def test_least_loaded_picks_idle_worker() -> None: + policy = LeastLoaded() + workers = [WorkerView("w1", inflight=3, capacity=4), WorkerView("w2", inflight=0, capacity=4)] + assert policy.choose(workers).worker_id == "w2" # type: ignore[union-attr] + + +def test_least_loaded_empty_returns_none() -> None: + assert LeastLoaded().choose([]) is None + + +def test_p2c_returns_a_worker() -> None: + policy = PowerOfTwoChoices() + workers = [WorkerView("w1", 1, 4), WorkerView("w2", 2, 4), WorkerView("w3", 0, 4)] + chosen = policy.choose(workers) + assert chosen is not None + assert chosen.worker_id in {"w1", "w2", "w3"} + + +def test_p2c_single_worker() -> None: + policy = PowerOfTwoChoices() + workers = [WorkerView("only", 0, 4)] + assert policy.choose(workers).worker_id == "only" # type: ignore[union-attr] + + +def test_round_robin_cycles() -> None: + policy = RoundRobin() + workers = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4), WorkerView("w3", 0, 4)] + ids = [policy.choose(workers).worker_id for _ in range(6)] # type: ignore[union-attr] + assert ids == ["w1", "w2", "w3", "w1", "w2", "w3"] + + +def test_make_routing_policy_string() -> None: + assert isinstance(make_routing_policy("least_loaded"), LeastLoaded) + assert isinstance(make_routing_policy("p2c"), PowerOfTwoChoices) + assert isinstance(make_routing_policy("round_robin"), RoundRobin) + + +def test_make_routing_policy_instance_passthrough() -> None: + policy = LeastLoaded() + assert make_routing_policy(policy) is policy + + +def test_make_routing_policy_unknown_raises() -> None: + with pytest.raises(ValueError, match="Unknown routing policy"): + make_routing_policy("no_such_policy") + + +def test_custom_routing_policy_accepted(store: InMemoryStore) -> None: + """Users can pass a RoutingPolicy instance directly to choose_worker.""" + + class AlwaysFirstPolicy: + """Trivial policy: always pick the worker with the lexicographically smallest ID.""" + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + return min(workers, key=lambda w: w.worker_id) if workers else None + + policy = AlwaysFirstPolicy() + # Verify it satisfies the Protocol at runtime. + assert isinstance(policy, RoutingPolicy) + + w1 = _worker_state(worker_id="aaa", capacity=4) + w2 = _worker_state(worker_id="zzz", capacity=4) + store.register_worker(w1) + store.register_worker(w2) + chosen = store.choose_worker(policy, heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] == "aaa" + + +def test_choose_worker_p2c(store: InMemoryStore) -> None: + """P2C routing returns one of the registered workers.""" + for i in range(4): + store.register_worker(_worker_state(worker_id=f"w{i}", capacity=4)) + chosen = store.choose_worker("p2c", heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] in {f"w{i}" for i in range(4)} + + +def test_hub_accepts_policy_instance(ctrl_addr: str, db_path: str) -> None: + """HubConfig.routing_policy accepts a RoutingPolicy instance.""" + hub = NNGHub(HubConfig( + control_addr=ctrl_addr, + routing_policy=RoundRobin(), + max_pending=100, + )) + assert isinstance(hub._routing, RoundRobin) + + +# ── 3b. 
Backpressure integration test ──────────────────────────────────────── + + +async def test_backpressure_hub_rejects_when_full( + ctrl_addr: str, db_path: str +) -> None: + """Hub returns error=queue full when max_pending is reached.""" + hub = _hub(ctrl_addr, db_path, max_pending=1) + await hub.start() + client = FakeClient(ctrl_addr) + try: + await client.submit() # fills the one slot (no worker → stays queued) + # Second submission must be rejected + payload: dict[str, object] = { + "task_id": uuid.uuid4().hex, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": 0, + "retry_backoff": 1.0, + "retry_jitter": 0.0, + "priority": 0, + "created_at": time.time(), + } + async with client._lock: + await client._ctrl.asend( + ControlMessage(kind="submit", payload=payload).to_bytes() + ) + raw = await client._ctrl.arecv() + resp = ControlResponse.from_bytes(raw) + assert not resp.ok + assert resp.error == "queue full" + finally: + client.close() + await hub.stop() + + +# ── 2c. AffinityPolicy unit tests ──────────────────────────────────────────── + + +def test_affinity_policy_sticks_to_worker() -> None: + """Same affinity_key must route to the same worker across calls.""" + policy = AffinityPolicy() + workers = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {"affinity_key": "user-42"}) + first = policy.choose(workers, task=task) + assert first is not None + for _ in range(10): + chosen = policy.choose(workers, task=task) + assert chosen is not None + assert chosen.worker_id == first.worker_id + + +def test_affinity_policy_falls_back_when_worker_gone() -> None: + """When the sticky worker is no longer available, fall back to least-loaded.""" + policy = AffinityPolicy() + workers_full = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {"affinity_key": "key-x"}) + first = policy.choose(workers_full, task=task) + assert first is not None + # Remove the sticky worker — only the other one remains. + remaining = [w for w in workers_full if w.worker_id != first.worker_id] + fallback = policy.choose(remaining, task=task) + assert fallback is not None + assert fallback.worker_id != first.worker_id + + +def test_affinity_policy_no_key_uses_least_loaded() -> None: + """Tasks without affinity_key get least-loaded routing.""" + policy = AffinityPolicy() + workers = [WorkerView("w1", 3, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {}) + chosen = policy.choose(workers, task=task) + assert chosen is not None + assert chosen.worker_id == "w2" + + +def test_affinity_policy_is_routing_policy() -> None: + assert isinstance(AffinityPolicy(), RoutingPolicy) + + +def test_choose_worker_affinity_string(store: InMemoryStore) -> None: + """String 'affinity' resolves to the singleton AffinityPolicy via choose_worker.""" + for wid in ("a1", "a2"): + store.register_worker(_worker_state(worker_id=wid, capacity=4)) + task = TaskContext("t1", "fn", {"affinity_key": "session-1"}) + first = store.choose_worker("affinity", heartbeat_timeout=30.0, task=task) + assert first is not None + for _ in range(5): + chosen = store.choose_worker("affinity", heartbeat_timeout=30.0, task=task) + assert chosen is not None + assert chosen["worker_id"] == first["worker_id"] + + +# ── 2d. 
Scheduler unit tests ───────────────────────────────────────────────── + + +def test_priority_scheduler_delegates_to_due_tasks(store: InMemoryStore) -> None: + store.submit(_envelope(task_id="lo", priority=0)) + store.submit(_envelope(task_id="hi", priority=5)) + sched = PriorityScheduler() + rows = sched.select(store, limit=10) + assert rows[0]["task_id"] == "hi" + + +def test_priority_scheduler_is_scheduler() -> None: + assert isinstance(PriorityScheduler(), Scheduler) + + +def test_custom_scheduler_used_by_hub(ctrl_addr: str, db_path: str) -> None: + """HubConfig.scheduler accepts a custom Scheduler instance.""" + + class NoopScheduler: + """Never returns tasks — useful for verifying it is actually called.""" + called = False + + def select( + self, store: InMemoryStore, limit: int + ) -> list[dict[str, object]]: + NoopScheduler.called = True + return [] + + scheduler = NoopScheduler() + assert isinstance(scheduler, Scheduler) + hub = NNGHub(HubConfig( + control_addr=ctrl_addr, + scheduler=scheduler, + max_pending=10, + )) + assert hub._scheduler is scheduler + + +# ── 4. Multiprocess integration test ───────────────────────────────────────── + +_WORKER_SCRIPT = textwrap.dedent("""\ + import asyncio, sys, os + sys.path.insert(0, {root!r}) + try: + import pynng # noqa: F401 + from taskiq.brokers.nng.broker import NNGBroker + except Exception as exc: + sys.stdout.write(f"SKIP:{{exc}}\\n") + sys.stdout.flush() + sys.exit(0) + + async def main() -> None: + broker = NNGBroker( + {ctrl_addr!r}, + worker_task_addr={task_addr!r}, + worker_id={worker_id!r}, + capacity=1, + heartbeat_interval=1.0, + recv_timeout_ms=3000, + send_timeout_ms=3000, + ) + broker.is_worker_process = True + await broker.startup() + sys.stdout.write("READY\\n") + sys.stdout.flush() + async for msg in broker.listen(): + sys.stdout.write(f"TASK:{{msg.data.decode()}}\\n") + sys.stdout.flush() + await msg.ack() + break + await broker.shutdown() + + asyncio.run(main()) +""") + + +async def test_multiprocess_worker_receives_task( + ctrl_addr: str, db_path: str +) -> None: + """A real subprocess worker (separate OS process) receives and acks a task.""" + repo_root = str( + __import__("pathlib").Path(__file__).parent.parent.parent.resolve() + ) + task_addr = _ipc("mp-worker") + worker_id = f"mp-{uuid.uuid4().hex[:8]}" + + script = _WORKER_SCRIPT.format( + root=repo_root, + ctrl_addr=ctrl_addr, + task_addr=task_addr, + worker_id=worker_id, + ) + + hub = _hub(ctrl_addr, db_path) + await hub.start() + client = FakeClient(ctrl_addr) + + proc = await asyncio.create_subprocess_exec( + sys.executable, "-c", script, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + async def _read_line(timeout: float = 10.0) -> str: + assert proc.stdout is not None + line = await asyncio.wait_for(proc.stdout.readline(), timeout=timeout) + return line.decode().strip() + + try: + first_line = await _read_line(timeout=10.0) + if first_line.startswith("SKIP:"): + pytest.skip(f"Worker subprocess skipped: {first_line[5:]}") + + assert first_line == "READY", f"Expected READY, got: {first_line!r}" + + # Submit a task now that the worker is registered and listening. + tid = await client.submit() + + task_line = await _read_line(timeout=10.0) + assert task_line.startswith("TASK:"), f"Expected TASK:..., got: {task_line!r}" + + await proc.wait() + + # Give hub's reaper a tick to process the ack. 
+ await asyncio.sleep(0.2) + state = hub.store.get_task(tid) + assert state is not None + assert state["state"] == "done", f"Expected done, got {state['state']!r}" + finally: + if proc.returncode is None: + proc.terminate() + await proc.wait() + client.close() + await hub.stop()
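+
+
+# ── 5. Exploratory sketches ──────────────────────────────────────────────────
+# The tests below are illustrative sketches built only on the helpers and
+# fixtures defined above; wherever an assertion rests on assumed hub/store
+# semantics rather than behaviour exercised earlier in this file, the
+# docstring says so explicitly.
+
+
+# Sketch: relies on the assumption that due_tasks honours next_run_at, which
+# the name and the nack/backoff test above strongly suggest.
+def test_nack_backoff_delays_redispatch_sketch(store: InMemoryStore) -> None:
+    """Sketch: after a nack with a large backoff the task returns to "ready"
+    but should not be due yet (assumes due_tasks filters on next_run_at)."""
+    env = _envelope(max_retries=2, retry_backoff=30.0)
+    store.submit(env)
+    w = _worker_state()
+    store.register_worker(w)
+    store.mark_leased(env.task_id, w.worker_id, "L6", time.time() + 60)
+    assert store.nack(env.task_id, w.worker_id, "L6", "transient failure")
+    assert store.get_task(env.task_id)["state"] == "ready"
+    # next_run_at lies ~30s in the future, so the task should not be due now.
+    due_ids = {row["task_id"] for row in store.due_tasks(limit=10)}
+    assert env.task_id not in due_ids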
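+
+
+# Sketch: assumes choose_worker excludes workers whose last heartbeat is older
+# than heartbeat_timeout, which is what the parameter name implies.
+def test_choose_worker_skips_stale_heartbeat_sketch(store: InMemoryStore) -> None:
+    """Sketch: a worker with a stale heartbeat should be ignored by routing."""
+    store.register_worker(_worker_state(worker_id="stale", capacity=4))
+    store._workers["stale"].last_seen = 0  # long before any plausible timeout
+    assert store.choose_worker("least_loaded", heartbeat_timeout=1.0) is None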
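+
+
+# Sketch: assumes the hub's default scheduler is priority-ordered, as the
+# PriorityScheduler unit test above suggests. Both tasks are queued before the
+# worker registers so dispatch order depends only on priority, not timing.
+async def test_priority_dispatch_order_sketch(ctrl_addr: str, db_path: str) -> None:
+    """Sketch: the higher-priority task should be dispatched first."""
+    hub = _hub(ctrl_addr, db_path)
+    await hub.start()
+    client = FakeClient(ctrl_addr)
+    worker = FakeWorker(ctrl_addr, capacity=1)
+    try:
+        lo = await client.submit(priority=0)
+        hi = await client.submit(priority=10)
+        await worker.register()
+        first = await worker.recv_task(timeout=3.0)
+        assert first.task_id == hi, "high-priority task should be dispatched first"
+        assert await worker.ack(first.task_id, first.lease_id)
+        second = await worker.recv_task(timeout=3.0)
+        assert second.task_id == lo
+        assert await worker.ack(second.task_id, second.lease_id)
+    finally:
+        worker.close()
+        client.close()
+        await hub.stop()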
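+
+
+# Sketch: assumes the hub stops routing new tasks to a worker once it has
+# requested draining (the "drain" control kind and the draining flag in the
+# register payload suggest this, but it is not exercised elsewhere in this file).
+async def test_draining_worker_receives_no_new_tasks_sketch(
+    ctrl_addr: str, db_path: str
+) -> None:
+    """Sketch: after a drain request a fresh submission stays queued as "ready"."""
+    hub = _hub(ctrl_addr, db_path)
+    await hub.start()
+    worker = FakeWorker(ctrl_addr, capacity=2)
+    client = FakeClient(ctrl_addr)
+    try:
+        await worker.register()
+        await worker.ctrl("drain", {"worker_id": worker.worker_id})
+        tid = await client.submit()
+        with pytest.raises(asyncio.TimeoutError):
+            await worker.recv_task(timeout=1.0)
+        assert hub.store.get_task(tid)["state"] == "ready"
+    finally:
+        worker.close()
+        client.close()
+        await hub.stop()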