diff --git a/src/ahriman/web/views/v1/auditlog/event_bus.py b/src/ahriman/web/views/v1/auditlog/event_bus.py index c3a019a4..61d75452 100644 --- a/src/ahriman/web/views/v1/auditlog/event_bus.py +++ b/src/ahriman/web/views/v1/auditlog/event_bus.py @@ -19,7 +19,7 @@ # import json -from aiohttp.web import StreamResponse +from aiohttp.web import HTTPBadRequest, StreamResponse from aiohttp_sse import EventSourceResponse, sse_response from asyncio import Queue, QueueShutDown, wait_for from typing import ClassVar @@ -69,6 +69,7 @@ class EventBusView(BaseView): summary="Live updates", description="Stream live updates via SSE", permission=GET_PERMISSION, + error_400_enabled=True, error_404_description="Repository is unknown", schema=SSESchema(many=True), query_schema=EventBusFilterSchema, @@ -79,8 +80,14 @@ class EventBusView(BaseView): Returns: StreamResponse: 200 with streaming updates + + Raises: + HTTPBadRequest: if invalid event type is supplied """ - topics = [EventType(event) for event in self.request.query.getall("event", [])] or None + try: + topics = [EventType(event) for event in self.request.query.getall("event", [])] or None + except ValueError as ex: + raise HTTPBadRequest(reason=str(ex)) event_bus = self.service().event_bus async with sse_response(self.request) as response: diff --git a/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py b/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py index 5e7bad0c..1621b314 100644 --- a/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py +++ b/tests/ahriman/web/views/v1/auditlog/test_view_v1_auditlog_event_bus.py @@ -33,8 +33,9 @@ async def test_get_permission() -> None: """ must return correct permission for the request """ - request = pytest.helpers.request("", "", "GET") - assert await EventBusView.get_permission(request) == UserAccess.Full + for method in ("GET",): + request = pytest.helpers.request("", "", method) + assert await EventBusView.get_permission(request) == UserAccess.Full def test_routes() -> None: @@ -98,6 +99,17 @@ async def test_get_with_topic_filter(client: TestClient, package_ahriman: Packag assert EventType.PackageRemoved not in body +async def test_get_bad_request(client: TestClient) -> None: + """ + must return bad request for invalid event type + """ + response_schema = pytest.helpers.schema_response(EventBusView.get, code=400) + + response = await client.get("/api/v1/events/stream", params={"event": "invalid"}) + assert response.status == 400 + assert not response_schema.validate(await response.json()) + + async def test_get_not_found(client: TestClient) -> None: """ must return not found for unknown repository @@ -114,6 +126,5 @@ async def test_get_connection_reset(client: TestClient, mocker: MockerFixture) - must handle connection reset """ mocker.patch.object(EventBusView, "_run", side_effect=ConnectionResetError) - response = await client.get("/api/v1/events/stream") assert response.status == 200