Source code for litestar.response.streaming

from __future__ import annotations

from collections.abc import AsyncGenerator, AsyncIterable, Iterable
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Union

from anyio import Event, create_task_group

from litestar.enums import MediaType
from litestar.response.base import ASGIResponse, Response
from litestar.types.helper_types import StreamType
from litestar.utils.helpers import get_enum_string_value
from litestar.utils.sync import AsyncIteratorWrapper

if TYPE_CHECKING:
    from litestar.background_tasks import BackgroundTask, BackgroundTasks
    from litestar.connection import Request
    from litestar.datastructures.cookie import Cookie
    from litestar.enums import OpenAPIMediaType
    from litestar.types import HTTPResponseBodyEvent, Receive, ResponseCookies, ResponseHeaders, Send, TypeEncodersMap

__all__ = (
    "ASGIStreamingResponse",
    "Stream",
)


class ASGIStreamingResponse(ASGIResponse):
    """A low-level ASGI response that sends its body as a stream of chunks.

    The body is produced by iterating ``iterator`` and emitting one
    ``http.response.body`` event per chunk. While streaming, a background task
    watches the ``receive`` channel for a disconnect message so that streaming
    stops once the client has gone away.
    """

    __slots__ = ("disconnect_event", "iterator")

    # Passed through to the base class machinery as a class-level flag; a
    # stream's total size is not known up front.
    _should_set_content_length = False

    def __init__(
        self,
        *,
        iterator: StreamType,
        background: BackgroundTask | BackgroundTasks | None = None,
        content_length: int | None = None,
        cookies: Iterable[Cookie] | None = None,
        encoding: str = "utf-8",
        headers: dict[str, Any] | None = None,
        is_head_response: bool = False,
        media_type: MediaType | str | None = None,
        status_code: int | None = None,
    ) -> None:
        """A low-level ASGI streaming response.

        Args:
            background: A background task or a list of background tasks to be executed after the response is sent.
            content_length: The response content length.
            cookies: The response cookies.
            encoding: The response encoding. Also used to encode ``str`` chunks yielded by ``iterator``.
            headers: The response headers.
            is_head_response: A boolean indicating if the response is a HEAD response.
            iterator: An async iterator or iterable. Synchronous iterables are wrapped in an
                :class:`AsyncIteratorWrapper` so they can be consumed with ``async for``.
            media_type: The response media type.
            status_code: The response status code.
        """
        super().__init__(
            background=background,
            content_length=content_length,
            cookies=cookies,
            encoding=encoding,
            headers=headers,
            is_head_response=is_head_response,
            media_type=media_type,
            status_code=status_code,
        )
        # Set by ``_listen_for_disconnect`` and polled by ``_stream`` before each chunk is sent.
        self.disconnect_event = Event()
        self.iterator: AsyncIterable[str | bytes] | AsyncGenerator[str | bytes, None] = (
            iterator if isinstance(iterator, AsyncIterable) else AsyncIteratorWrapper(iterator)
        )

    async def _listen_for_disconnect(self, receive: Receive) -> None:
        """Wait for a disconnect message and flag it on ``disconnect_event``.

        Messages are read from ``receive`` until one whose type ends with
        ``".disconnect"`` arrives (or ``receive`` returns a falsy value), after
        which ``disconnect_event`` is set so that ``_stream`` stops sending.

        Args:
            receive: The ASGI receive function.

        Returns:
            None
        """
        while message := await receive():
            if message["type"].endswith(".disconnect"):
                break
        self.disconnect_event.set()

    async def _stream(self, send: Send) -> None:
        """Send the chunks from the iterator as a stream of ASGI 'http.response.body' events.

        ``str`` chunks are encoded with ``self.encoding``; ``bytes`` chunks are sent as-is.
        If a disconnect has been flagged, streaming stops immediately and no terminating
        event is sent. Otherwise an empty event with ``more_body=False`` closes the body.

        Args:
            send: The ASGI Send function.

        Returns:
            None
        """
        async for chunk in self.iterator:
            if self.disconnect_event.is_set():
                return
            # Build a fresh event for every chunk: a ``send`` wrapper (e.g. buffering or
            # caching middleware) may keep a reference to the event, so re-using and
            # mutating a single dict would retroactively change already-sent events.
            chunk_event: HTTPResponseBodyEvent = {
                "type": "http.response.body",
                "body": chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding),
                "more_body": True,
            }
            await send(chunk_event)

        terminus_event: HTTPResponseBodyEvent = {
            "type": "http.response.body",
            "body": b"",
            "more_body": False,
        }
        await send(terminus_event)

    async def send_body(self, send: Send, receive: Receive) -> None:
        """Emit a stream of events correlating with the response body.

        The disconnect listener runs as a background task while the body is streamed.

        Args:
            send: The ASGI send function.
            receive: The ASGI receive function.

        Returns:
            None
        """
        async with create_task_group() as task_group:
            task_group.start_soon(partial(self._listen_for_disconnect, receive))
            await self._stream(send)
            # Leaving the task group waits for all child tasks. Without this, the
            # listener would keep awaiting ``receive()`` and this method would not
            # return until the server delivers a disconnect message.
            task_group.cancel_scope.cancel()
class Stream(Response[StreamType[Union[str, bytes]]]):
    """An HTTP response that streams the response data as a series of ASGI ``http.response.body`` events."""

    __slots__ = ("iterator",)

    def __init__(
        self,
        content: StreamType[str | bytes] | Callable[[], StreamType[str | bytes]],
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: ResponseCookies | None = None,
        encoding: str = "utf-8",
        headers: ResponseHeaders | None = None,
        media_type: MediaType | OpenAPIMediaType | str | None = None,
        status_code: int | None = None,
    ) -> None:
        """Initialize the response.

        Args:
            content: A sync or async iterator or iterable, or a zero-argument callable returning one.
                It is stored as-is; a callable is only invoked when the ASGI response is created.
            background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or
                :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
                Defaults to None.
            cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response
                ``Set-Cookie`` header.
            encoding: The encoding passed to the base response and forwarded to the ASGI response.
            headers: A string keyed dictionary of response headers. Header keys are insensitive.
            media_type: A value for the response ``Content-Type`` header.
            status_code: An HTTP status code.
        """
        # The base class receives an empty placeholder body; the actual data source
        # is kept on ``self.iterator`` below.
        super().__init__(
            content=b"",  # type: ignore[arg-type]
            status_code=status_code,
            media_type=media_type,
            headers=headers,
            cookies=cookies,
            encoding=encoding,
            background=background,
        )
        self.iterator = content

    def to_asgi_response(
        self,
        request: Request,
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: Iterable[Cookie] | None = None,
        headers: dict[str, str] | None = None,
        is_head_response: bool = False,
        media_type: MediaType | str | None = None,
        status_code: int | None = None,
        type_encoders: TypeEncodersMap | None = None,
    ) -> ASGIResponse:
        """Create an ASGIStreamingResponse from a Stream instance.

        Args:
            background: Background task(s) to be executed after the response is sent. Only used if the
                response has no background task(s) of its own.
            cookies: A list of cookies to be set on the response, in addition to the response's own cookies.
            headers: Additional headers to be merged with the response headers. Response headers take precedence.
            is_head_response: Whether the response is a HEAD response.
            media_type: Media type for the response. When given, it takes precedence over the ``media_type``
                set on the response; if neither is set, JSON is used.
            request: The :class:`Request <.connection.Request>` instance. Not used by this implementation.
            status_code: Status code for the response. If ``status_code`` is already set on the response, this is
                ignored.
            type_encoders: A dictionary of type encoders to use for encoding the response content. Not used by
                this implementation.

        Returns:
            An ASGIStreamingResponse instance.
        """
        # Resolve the data source first: a plain callable (i.e. not itself iterable)
        # is treated as a factory and invoked to obtain the stream.
        stream_source = self.iterator
        if callable(stream_source) and not isinstance(stream_source, (Iterable, AsyncIterable)):
            stream_source = stream_source()

        extra_cookies = cookies if cookies else ()

        if headers is None:
            merged_headers = self.headers
        else:
            # Later keys win, so the response's own headers override the passed ones.
            merged_headers = {**headers, **self.headers}

        resolved_media_type = get_enum_string_value(media_type or self.media_type or MediaType.JSON)

        return ASGIStreamingResponse(
            iterator=stream_source,
            background=self.background or background,
            content_length=0,
            cookies=chain(self.cookies, extra_cookies),
            encoding=self.encoding,
            headers=merged_headers,
            is_head_response=is_head_response,
            media_type=resolved_media_type,
            status_code=self.status_code or status_code,
        )