mirror of
https://github.com/arcan1s/ahriman.git
synced 2026-05-26 16:46:15 +00:00
Compare commits
3 Commits
b94cba4d25
..
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 84649ea399 | |||
| fa9fa73078 | |||
| 3e1e24cb50 |
+1
-1
@@ -3,7 +3,7 @@ version: 2
|
|||||||
build:
|
build:
|
||||||
os: ubuntu-lts-latest
|
os: ubuntu-lts-latest
|
||||||
tools:
|
tools:
|
||||||
python: "3.12"
|
python: "3.13"
|
||||||
apt_packages:
|
apt_packages:
|
||||||
- graphviz
|
- graphviz
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,8 @@
|
|||||||
#
|
#
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from asyncio import Lock, Queue, QueueFull
|
from asyncio import Lock, Queue, QueueFull, QueueShutDown
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ahriman.core.log import LazyLogging
|
from ahriman.core.log import LazyLogging
|
||||||
@@ -29,6 +30,22 @@ from ahriman.models.event import EventType
|
|||||||
SSEvent = tuple[str, dict[str, Any]]
|
SSEvent = tuple[str, dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _Subscription:
|
||||||
|
"""
|
||||||
|
internal event bus subscription record
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
topics(list[EventType] | None): event type filter, ``None`` means all
|
||||||
|
object_id(str | None): object identifier filter, ``None`` means all
|
||||||
|
queue(Queue[SSEvent]): per-subscriber event queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
topics: list[EventType] | None
|
||||||
|
object_id: str | None
|
||||||
|
queue: Queue[SSEvent]
|
||||||
|
|
||||||
|
|
||||||
class EventBus(LazyLogging):
|
class EventBus(LazyLogging):
|
||||||
"""
|
"""
|
||||||
event bus implementation
|
event bus implementation
|
||||||
@@ -45,7 +62,7 @@ class EventBus(LazyLogging):
|
|||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
|
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
self._subscribers: dict[str, tuple[list[EventType] | None, str | None, Queue[SSEvent | None]]] = {}
|
self._subscribers: dict[str, _Subscription] = {}
|
||||||
|
|
||||||
async def broadcast(self, event_type: EventType, object_id: str | None, **kwargs: Any) -> None:
|
async def broadcast(self, event_type: EventType, object_id: str | None, **kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -60,30 +77,31 @@ class EventBus(LazyLogging):
|
|||||||
event.update(kwargs)
|
event.update(kwargs)
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
for subscriber_id, (topics, filter_object_id, queue) in self._subscribers.items():
|
snapshot = list(self._subscribers.items())
|
||||||
if topics is not None and event_type not in topics:
|
|
||||||
|
for subscriber_id, subscription in snapshot:
|
||||||
|
if subscription.topics is not None and event_type not in subscription.topics:
|
||||||
continue
|
continue
|
||||||
if filter_object_id is not None and object_id != filter_object_id:
|
if subscription.object_id is not None and object_id != subscription.object_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
queue.put_nowait((event_type, event))
|
subscription.queue.put_nowait((event_type, event))
|
||||||
except QueueFull:
|
except QueueFull:
|
||||||
self.logger.warning("discard message to slow subscriber %s", subscriber_id)
|
self.logger.warning("discard message to slow subscriber %s", subscriber_id)
|
||||||
|
except QueueShutDown:
|
||||||
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
"""
|
"""
|
||||||
gracefully shutdown all subscribers
|
gracefully shutdown all subscribers
|
||||||
"""
|
"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
for _, _, queue in self._subscribers.values():
|
for subscription in self._subscribers.values():
|
||||||
try:
|
subscription.queue.shutdown()
|
||||||
queue.put_nowait(None)
|
|
||||||
except QueueFull:
|
|
||||||
pass
|
|
||||||
queue.shutdown()
|
|
||||||
|
|
||||||
async def subscribe(self, topics: list[EventType] | None = None,
|
async def subscribe(self, topics: list[EventType] | None = None,
|
||||||
object_id: str | None = None) -> tuple[str, Queue[SSEvent | None]]:
|
object_id: str | None = None) -> tuple[str, Queue[SSEvent]]:
|
||||||
"""
|
"""
|
||||||
register new subscriber
|
register new subscriber
|
||||||
|
|
||||||
@@ -94,13 +112,13 @@ class EventBus(LazyLogging):
|
|||||||
events for all objects will be delivered (Default value = None)
|
events for all objects will be delivered (Default value = None)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[str, Queue[SSEvent | None]]: subscriber identifier and associated queue
|
tuple[str, Queue[SSEvent]]: subscriber identifier and associated queue
|
||||||
"""
|
"""
|
||||||
subscriber_id = str(uuid.uuid4())
|
subscriber_id = str(uuid.uuid4())
|
||||||
queue: Queue[SSEvent | None] = Queue(self.max_size)
|
queue: Queue[SSEvent] = Queue(self.max_size)
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._subscribers[subscriber_id] = (topics, object_id, queue)
|
self._subscribers[subscriber_id] = _Subscription(topics=topics, object_id=object_id, queue=queue)
|
||||||
|
|
||||||
return subscriber_id, queue
|
return subscriber_id, queue
|
||||||
|
|
||||||
@@ -112,7 +130,6 @@ class EventBus(LazyLogging):
|
|||||||
subscriber_id(str): subscriber unique identifier
|
subscriber_id(str): subscriber unique identifier
|
||||||
"""
|
"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
result = self._subscribers.pop(subscriber_id, None)
|
subscription = self._subscribers.pop(subscriber_id, None)
|
||||||
if result is not None:
|
if subscription is not None:
|
||||||
_, _, queue = result
|
subscription.queue.shutdown()
|
||||||
queue.shutdown()
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from aiohttp.web import HTTPBadRequest, StreamResponse
|
from aiohttp.web import HTTPBadRequest, Request, StreamResponse
|
||||||
from aiohttp_sse import EventSourceResponse, sse_response
|
from aiohttp_sse import EventSourceResponse, sse_response
|
||||||
from asyncio import Queue, QueueShutDown, wait_for
|
from asyncio import Queue, QueueShutDown, wait_for
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
@@ -35,32 +35,63 @@ from ahriman.web.views.base import BaseView
|
|||||||
class EventBusView(BaseView):
|
class EventBusView(BaseView):
|
||||||
"""
|
"""
|
||||||
event bus SSE view
|
event bus SSE view
|
||||||
|
|
||||||
Attributes:
|
|
||||||
GET_PERMISSION(UserAccess): (class attribute) get permissions of self
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
GET_PERMISSION: ClassVar[UserAccess] = UserAccess.Full
|
READ_EVENTS: ClassVar[set[EventType]] = {
|
||||||
|
EventType.PackageHeld,
|
||||||
|
EventType.PackageOutdated,
|
||||||
|
EventType.PackageRemoved,
|
||||||
|
EventType.PackageStatusChanged,
|
||||||
|
EventType.PackageUpdateFailed,
|
||||||
|
EventType.PackageUpdated,
|
||||||
|
EventType.ServiceStatusChanged,
|
||||||
|
}
|
||||||
ROUTES = ["/api/v1/events/stream"]
|
ROUTES = ["/api/v1/events/stream"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_permission(cls, request: Request) -> UserAccess:
|
||||||
|
"""
|
||||||
|
retrieve user permission from the request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request(Request): request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserAccess: extracted permission
|
||||||
|
"""
|
||||||
|
if request.method.upper() not in ("GET", "HEAD"):
|
||||||
|
return await BaseView.get_permission(request)
|
||||||
|
|
||||||
|
permission = UserAccess.Full
|
||||||
|
event_filter = request.query.getall("event", []) if request.query is not None else []
|
||||||
|
|
||||||
|
if event_filter:
|
||||||
|
try:
|
||||||
|
topics = {EventType(event) for event in event_filter}
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if topics.issubset(cls.READ_EVENTS):
|
||||||
|
permission = UserAccess.Read
|
||||||
|
|
||||||
|
return permission
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _run(response: EventSourceResponse, queue: Queue[SSEvent | None]) -> None:
|
async def _run(response: EventSourceResponse, queue: Queue[SSEvent]) -> None:
|
||||||
"""
|
"""
|
||||||
read events from queue and send them to the client
|
read events from queue and send them to the client
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response(EventSourceResponse): SSE response instance
|
response(EventSourceResponse): SSE response instance
|
||||||
queue(Queue[SSEvent | None]): subscriber queue
|
queue(Queue[SSEvent]): subscriber queue
|
||||||
"""
|
"""
|
||||||
while response.is_connected():
|
while response.is_connected():
|
||||||
try:
|
try:
|
||||||
message = await wait_for(queue.get(), timeout=response.ping_interval)
|
event_type, data = await wait_for(queue.get(), timeout=response.ping_interval)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
except QueueShutDown:
|
||||||
if message is None:
|
break
|
||||||
break # terminate queue on sentinel event
|
|
||||||
event_type, data = message
|
|
||||||
|
|
||||||
await response.send(json.dumps(data), event=event_type)
|
await response.send(json.dumps(data), event=event_type)
|
||||||
|
|
||||||
@@ -68,7 +99,7 @@ class EventBusView(BaseView):
|
|||||||
tags=["Audit log"],
|
tags=["Audit log"],
|
||||||
summary="Live updates",
|
summary="Live updates",
|
||||||
description="Stream live updates via SSE",
|
description="Stream live updates via SSE",
|
||||||
permission=GET_PERMISSION,
|
permission=UserAccess.Full,
|
||||||
error_400_enabled=True,
|
error_400_enabled=True,
|
||||||
error_404_description="Repository is unknown",
|
error_404_description="Repository is unknown",
|
||||||
schema=SSESchema(many=True),
|
schema=SSESchema(many=True),
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from asyncio import QueueShutDown
|
||||||
|
|
||||||
from ahriman.core.status.event_bus import EventBus
|
from ahriman.core.status.event_bus import EventBus
|
||||||
from ahriman.models.event import EventType
|
from ahriman.models.event import EventType
|
||||||
from ahriman.models.package import Package
|
from ahriman.models.package import Package
|
||||||
@@ -49,15 +51,25 @@ async def test_broadcast_queue_full(event_bus: EventBus, package_ahriman: Packag
|
|||||||
assert queue.qsize() == 1
|
assert queue.qsize() == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_broadcast_queue_shutdown(event_bus: EventBus, package_ahriman: Package) -> None:
|
||||||
|
"""
|
||||||
|
must skip subscriber whose queue was shutdown concurrently
|
||||||
|
"""
|
||||||
|
_, queue = await event_bus.subscribe()
|
||||||
|
queue.shutdown()
|
||||||
|
|
||||||
|
await event_bus.broadcast(EventType.PackageUpdated, package_ahriman.base)
|
||||||
|
|
||||||
|
|
||||||
async def test_shutdown(event_bus: EventBus) -> None:
|
async def test_shutdown(event_bus: EventBus) -> None:
|
||||||
"""
|
"""
|
||||||
must send sentinel to all subscribers on shutdown
|
must shutdown all subscriber queues on shutdown
|
||||||
"""
|
"""
|
||||||
_, queue = await event_bus.subscribe()
|
_, queue = await event_bus.subscribe()
|
||||||
|
|
||||||
await event_bus.shutdown()
|
await event_bus.shutdown()
|
||||||
message = queue.get_nowait()
|
with pytest.raises(QueueShutDown):
|
||||||
assert message is None
|
queue.get_nowait()
|
||||||
|
|
||||||
|
|
||||||
async def test_shutdown_queue_full(event_bus: EventBus, package_ahriman: Package) -> None:
|
async def test_shutdown_queue_full(event_bus: EventBus, package_ahriman: Package) -> None:
|
||||||
@@ -105,8 +117,7 @@ async def test_subscribe_with_topics(event_bus: EventBus) -> None:
|
|||||||
must register subscriber with topic filter
|
must register subscriber with topic filter
|
||||||
"""
|
"""
|
||||||
subscriber_id, _ = await event_bus.subscribe([EventType.BuildLog])
|
subscriber_id, _ = await event_bus.subscribe([EventType.BuildLog])
|
||||||
topics, _, _ = event_bus._subscribers[subscriber_id]
|
assert event_bus._subscribers[subscriber_id].topics == [EventType.BuildLog]
|
||||||
assert topics == [EventType.BuildLog]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_subscribe_with_object_id(event_bus: EventBus, package_ahriman: Package) -> None:
|
async def test_subscribe_with_object_id(event_bus: EventBus, package_ahriman: Package) -> None:
|
||||||
@@ -114,8 +125,7 @@ async def test_subscribe_with_object_id(event_bus: EventBus, package_ahriman: Pa
|
|||||||
must register subscriber with object_id filter
|
must register subscriber with object_id filter
|
||||||
"""
|
"""
|
||||||
subscriber_id, _ = await event_bus.subscribe(object_id=package_ahriman.base)
|
subscriber_id, _ = await event_bus.subscribe(object_id=package_ahriman.base)
|
||||||
_, object_id, _ = event_bus._subscribers[subscriber_id]
|
assert event_bus._subscribers[subscriber_id].object_id == package_ahriman.base
|
||||||
assert object_id == package_ahriman.base
|
|
||||||
|
|
||||||
|
|
||||||
async def test_unsubscribe(event_bus: EventBus) -> None:
|
async def test_unsubscribe(event_bus: EventBus) -> None:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import pytest
|
|||||||
|
|
||||||
from aiohttp.test_utils import TestClient
|
from aiohttp.test_utils import TestClient
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
|
from multidict import MultiDict
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ from ahriman.models.event import EventType
|
|||||||
from ahriman.models.package import Package
|
from ahriman.models.package import Package
|
||||||
from ahriman.models.user_access import UserAccess
|
from ahriman.models.user_access import UserAccess
|
||||||
from ahriman.web.keys import WatcherKey
|
from ahriman.web.keys import WatcherKey
|
||||||
|
from ahriman.web.views.base import BaseView
|
||||||
from ahriman.web.views.v1.auditlog.event_bus import EventBusView
|
from ahriman.web.views.v1.auditlog.event_bus import EventBusView
|
||||||
|
|
||||||
|
|
||||||
@@ -38,6 +40,51 @@ async def test_get_permission() -> None:
|
|||||||
assert await EventBusView.get_permission(request) == UserAccess.Full
|
assert await EventBusView.get_permission(request) == UserAccess.Full
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_permission_build_log() -> None:
|
||||||
|
"""
|
||||||
|
must return full permission for build log stream
|
||||||
|
"""
|
||||||
|
request = pytest.helpers.request("", "", "GET", params=MultiDict(event=EventType.BuildLog))
|
||||||
|
assert await EventBusView.get_permission(request) == UserAccess.Full
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_permission_build_log_with_read_events() -> None:
|
||||||
|
"""
|
||||||
|
must return full permission for mixed build log and read event stream
|
||||||
|
"""
|
||||||
|
request = pytest.helpers.request("", "", "GET", params=MultiDict([
|
||||||
|
("event", EventType.BuildLog),
|
||||||
|
("event", EventType.PackageUpdated),
|
||||||
|
]))
|
||||||
|
assert await EventBusView.get_permission(request) == UserAccess.Full
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_permission_invalid_event() -> None:
|
||||||
|
"""
|
||||||
|
must return full permission for invalid event type
|
||||||
|
"""
|
||||||
|
request = pytest.helpers.request("", "", "GET", params=MultiDict(event="invalid"))
|
||||||
|
assert await EventBusView.get_permission(request) == UserAccess.Full
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_permission_post() -> None:
|
||||||
|
"""
|
||||||
|
must use default permission for non-get requests
|
||||||
|
"""
|
||||||
|
request = pytest.helpers.request("", "", "POST", params=MultiDict(event=EventType.PackageUpdated))
|
||||||
|
assert await EventBusView.get_permission(request) == await BaseView.get_permission(request)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_permission_read_events() -> None:
|
||||||
|
"""
|
||||||
|
must return read permission for package and status streams
|
||||||
|
"""
|
||||||
|
request = pytest.helpers.request("", "", "GET", params=MultiDict(
|
||||||
|
("event", event_type) for event_type in EventBusView.READ_EVENTS
|
||||||
|
))
|
||||||
|
assert await EventBusView.get_permission(request) == UserAccess.Read
|
||||||
|
|
||||||
|
|
||||||
def test_routes() -> None:
|
def test_routes() -> None:
|
||||||
"""
|
"""
|
||||||
must return correct routes
|
must return correct routes
|
||||||
@@ -53,7 +100,7 @@ async def test_run_timeout() -> None:
|
|||||||
|
|
||||||
async def _shutdown() -> None:
|
async def _shutdown() -> None:
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
await queue.put(None)
|
queue.shutdown()
|
||||||
|
|
||||||
response = AsyncMock()
|
response = AsyncMock()
|
||||||
response.is_connected = lambda: True
|
response.is_connected = lambda: True
|
||||||
|
|||||||
Reference in New Issue
Block a user