diff --git a/src/ahriman/web/views/v1/auditlog/event_bus.py b/src/ahriman/web/views/v1/auditlog/event_bus.py index 28d2fa63..c4c2b0b8 100644 --- a/src/ahriman/web/views/v1/auditlog/event_bus.py +++ b/src/ahriman/web/views/v1/auditlog/event_bus.py @@ -19,7 +19,7 @@ # import json -from aiohttp.web import HTTPBadRequest, StreamResponse +from aiohttp.web import HTTPBadRequest, Request, StreamResponse from aiohttp_sse import EventSourceResponse, sse_response from asyncio import Queue, QueueShutDown, wait_for from typing import ClassVar @@ -37,12 +37,48 @@ class EventBusView(BaseView): event bus SSE view Attributes: - GET_PERMISSION(UserAccess): (class attribute) get permissions of self + READ_EVENTS(set[EventType]): (class attribute) events which are allowed for read-only users """ - GET_PERMISSION: ClassVar[UserAccess] = UserAccess.Full + READ_EVENTS: ClassVar[set[EventType]] = { + EventType.PackageHeld, + EventType.PackageOutdated, + EventType.PackageRemoved, + EventType.PackageStatusChanged, + EventType.PackageUpdateFailed, + EventType.PackageUpdated, + EventType.ServiceStatusChanged, + } ROUTES = ["/api/v1/events/stream"] + @classmethod + async def get_permission(cls, request: Request) -> UserAccess: + """ + retrieve user permission from the request + + Args: + request(Request): request object + + Returns: + UserAccess: extracted permission + """ + if request.method.upper() not in ("GET", "HEAD"): + return await BaseView.get_permission(request) + + permission = UserAccess.Full + event_filter = request.query.getall("event", []) if request.query is not None else [] + + if event_filter: + try: + topics = {EventType(event) for event in event_filter} + except ValueError: + pass + else: + if topics.issubset(cls.READ_EVENTS): + permission = UserAccess.Read + + return permission + @staticmethod async def _run(response: EventSourceResponse, queue: Queue[SSEvent]) -> None: """ @@ -66,7 +102,7 @@ class EventBusView(BaseView): tags=["Audit log"], summary="Live updates", description="Stream live updates via SSE", - permission=GET_PERMISSION, + permission=UserAccess.Full, error_400_enabled=True, error_404_description="Repository is unknown", schema=SSESchema(many=True), diff --git a/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py b/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py index ce01dd27..73eed13a 100644 --- a/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py +++ b/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py @@ -3,6 +3,7 @@ import pytest from aiohttp.test_utils import TestClient from asyncio import Queue +from multidict import MultiDict from pytest_mock import MockerFixture from unittest.mock import AsyncMock @@ -11,6 +12,7 @@ from ahriman.models.event import EventType from ahriman.models.package import Package from ahriman.models.user_access import UserAccess from ahriman.web.keys import WatcherKey +from ahriman.web.views.base import BaseView from ahriman.web.views.v1.auditlog.event_bus import EventBusView @@ -38,6 +40,51 @@ async def test_get_permission() -> None: assert await EventBusView.get_permission(request) == UserAccess.Full +async def test_get_permission_build_log() -> None: + """ + must return full permission for build log stream + """ + request = pytest.helpers.request("", "", "GET", params=MultiDict(event=EventType.BuildLog)) + assert await EventBusView.get_permission(request) == UserAccess.Full + + +async def test_get_permission_build_log_with_read_events() -> None: + """ + must return full permission for mixed build log and read event stream + """ + request = pytest.helpers.request("", "", "GET", params=MultiDict([ + ("event", EventType.BuildLog), + ("event", EventType.PackageUpdated), + ])) + assert await EventBusView.get_permission(request) == UserAccess.Full + + +async def test_get_permission_invalid_event() -> None: + """ + must return full permission for invalid event type + """ + request = pytest.helpers.request("", "", "GET", params=MultiDict(event="invalid")) + assert await EventBusView.get_permission(request) == UserAccess.Full + + +async def test_get_permission_post() -> None: + """ + must use default permission for non-get requests + """ + request = pytest.helpers.request("", "", "POST", params=MultiDict(event=EventType.PackageUpdated)) + assert await EventBusView.get_permission(request) == await BaseView.get_permission(request) + + +async def test_get_permission_read_events() -> None: + """ + must return read permission for package and status streams + """ + request = pytest.helpers.request("", "", "GET", params=MultiDict( + ("event", event_type) for event_type in EventBusView.READ_EVENTS + )) + assert await EventBusView.get_permission(request) == UserAccess.Read + + def test_routes() -> None: """ must return correct routes