Source code for litestar.connection.request

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, cast

from litestar._multipart import parse_content_header, parse_multipart_form
from litestar._parsers import parse_url_encoded_form_data
from litestar.connection.base import (
    ASGIConnection,
    AuthT,
    StateT,
    UserT,
    empty_receive,
    empty_send,
)
from litestar.datastructures.headers import Accept
from litestar.datastructures.multi_dicts import FormMultiDict
from litestar.enums import ASGIExtension, RequestEncodingType
from litestar.exceptions import (
    ClientException,
    InternalServerException,
    LitestarException,
    LitestarWarning,
)
from litestar.exceptions.http_exceptions import RequestEntityTooLarge
from litestar.serialization import decode_json, decode_msgpack
from litestar.types import Empty, HTTPReceiveMessage

# Public names exported by this module.
__all__ = ("Request",)


if TYPE_CHECKING:
    from litestar.handlers.http_handlers import HTTPRouteHandler  # noqa: F401
    from litestar.types.asgi_types import HTTPScope, Method, Receive, Scope, Send
    from litestar.types.empty import EmptyType


# Only these (lower-case) request headers are copied onto a push promise
# emitted by ``Request.send_push_promise``; every other header is dropped.
SERVER_PUSH_HEADERS = {"accept", "accept-encoding", "accept-language", "cache-control", "user-agent"}


class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]):
    """The Litestar Request class.

    Wraps an HTTP ASGI connection and exposes lazily evaluated accessors for the request's headers and body.
    Parsed values (content type, accept header, body, JSON, MessagePack and form data) are cached both on the
    instance and on ``self._connection_state``. The connection state is presumably shared by connections built
    from the same scope (see ``ASGIConnection``), which lets them reuse already-parsed values.
    """

    __slots__ = (
        "_json",
        "_form",
        "_body",
        "_msgpack",
        "_content_type",
        "_accept",
        "_content_length",
        "is_connected",
        "supports_push_promise",
    )

    scope: HTTPScope  # pyright: ignore
    """The ASGI scope attached to the connection."""
    receive: Receive
    """The ASGI receive function."""
    send: Send
    """The ASGI send function."""

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
        """Initialize ``Request``.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.
        """
        super().__init__(scope, receive, send)
        # flipped to False once the request body has been fully read from 'receive'
        self.is_connected: bool = True
        # per-instance caches; ``Empty`` means "not computed yet"
        self._body: bytes | EmptyType = Empty
        self._form: FormMultiDict | EmptyType = Empty
        self._json: Any = Empty
        self._msgpack: Any = Empty
        self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty
        self._accept: Accept | EmptyType = Empty
        self._content_length: int | None | EmptyType = Empty
        self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions

    @property
    def method(self) -> Method:
        """Return the request method.

        Returns:
            The request :class:`Method <litestar.types.Method>`
        """
        return self.scope["method"]

    @property
    def content_type(self) -> tuple[str, dict[str, str]]:
        """Parse the request's 'Content-Type' header, returning the header value and any options as a dictionary.

        A missing header is parsed as an empty string. The result is cached on the instance and the connection state.

        Returns:
            A tuple with the parsed value and a dictionary containing any options sent in it.
        """
        if self._content_type is Empty:
            if (content_type := self._connection_state.content_type) is not Empty:
                self._content_type = content_type
            else:
                self._content_type = self._connection_state.content_type = parse_content_header(
                    self.headers.get("Content-Type", "")
                )
        return self._content_type

    @property
    def accept(self) -> Accept:
        """Parse the request's 'Accept' header, returning an :class:`Accept <litestar.datastructures.headers.Accept>` instance.

        A missing header is treated as ``*/*``. The result is cached on the instance and the connection state.

        Returns:
            An :class:`Accept <litestar.datastructures.headers.Accept>` instance, representing the list of acceptable media types.
        """
        if self._accept is Empty:
            if (accept := self._connection_state.accept) is not Empty:
                self._accept = accept
            else:
                self._accept = self._connection_state.accept = Accept(self.headers.get("Accept", "*/*"))
        return self._accept

    async def json(self) -> Any:
        """Retrieve the json request body from the request.

        The body is decoded with the type decoders resolved from the route handler. An empty body is decoded as
        ``null`` (i.e. ``None``). The result is cached on the instance and the connection state.

        Returns:
            An arbitrary value
        """
        if self._json is Empty:
            if (json_ := self._connection_state.json) is not Empty:
                self._json = json_
            else:
                body = await self.body()
                self._json = self._connection_state.json = decode_json(
                    body or b"null", type_decoders=self.route_handler.resolve_type_decoders()
                )
        return self._json

    async def msgpack(self) -> Any:
        """Retrieve the MessagePack request body from the request.

        The body is decoded with the type decoders resolved from the route handler. An empty body is decoded as
        ``b"\\xc0"`` (the MessagePack ``nil`` marker, i.e. ``None``). The result is cached on the instance and the
        connection state.

        Returns:
            An arbitrary value
        """
        if self._msgpack is Empty:
            if (msgpack := self._connection_state.msgpack) is not Empty:
                self._msgpack = msgpack
            else:
                body = await self.body()
                self._msgpack = self._connection_state.msgpack = decode_msgpack(
                    body or b"\xc0", type_decoders=self.route_handler.resolve_type_decoders()
                )
        return self._msgpack

    @property
    def content_length(self) -> int | None:
        """Return the value of the request's 'Content-Length' header as an integer.

        The result (including ``None``) is cached on the instance.

        Returns:
            The announced body length, or ``None`` if the header is absent.

        Raises:
            ClientException: If the header is not a valid, non-negative integer.
        """
        cached_content_length = self._content_length
        if cached_content_length is not Empty:
            return cached_content_length

        content_length_header = self.headers.get("content-length")
        if content_length_header is None:
            self._content_length = None
            return None

        try:
            content_length = int(content_length_header)
        except ValueError:
            raise ClientException(f"Invalid content-length: {content_length_header!r}") from None

        # a negative length is never valid (Content-Length is defined as 1*DIGIT)
        if content_length < 0:
            raise ClientException(f"Invalid content-length: {content_length_header!r}")

        self._content_length = content_length
        return content_length

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Return an async generator that streams chunks of bytes.

        If the body has already been read, it is yielded as a single chunk. This covers a body read by this instance
        or one cached on the connection state. Otherwise, chunks are read from ``receive`` and yielded as they arrive.
        The announced 'Content-Length' and the route handler's maximum body size are enforced along the way. The
        generator always finishes with an empty ``b""`` chunk.

        Returns:
            An async generator.

        Raises:
            InternalServerException: If the stream has already been consumed, or the client disconnected before the
                full body was received.
            RequestEntityTooLarge: If the announced or received body size exceeds the maximum allowed size.
            ClientException: If the 'Content-Length' header is invalid, or more bytes are received than it announced.
        """
        if self._body is Empty and (cached_body := self._connection_state.body) is not Empty:
            # the body was already read in full through the shared connection state; 'receive' has been drained,
            # so reading from it again would not return the body
            self._body = cached_body

        if self._body is not Empty:
            yield self._body
            yield b""
            return

        if not self.is_connected:
            raise InternalServerException("stream consumed")

        announced_content_length = self.content_length
        # setting this to 'math.inf' as a micro-optimisation; Comparing against a
        # float is slightly faster than checking if a value is 'None' and then
        # comparing it to an int. since we expect a limit to be set most of the
        # time, this is a bit more efficient. note that a falsy limit (None or 0)
        # therefore disables the size check.
        max_content_length = self.route_handler.resolve_request_max_body_size() or math.inf

        # if the 'content-length' header is set, and exceeds the limit, we can bail
        # out early before reading anything
        if announced_content_length is not None and announced_content_length > max_content_length:
            raise RequestEntityTooLarge

        total_bytes_streamed: int = 0
        while event := cast("HTTPReceiveMessage", await self.receive()):
            if event["type"] == "http.request":
                # per the ASGI spec 'body' is optional and defaults to b""
                body = event.get("body", b"")
                if body:
                    total_bytes_streamed += len(body)

                    # if a 'content-length' header was set (including an explicit 0),
                    # check if we have received more bytes than specified. in most
                    # cases this should be caught before it hits the application layer
                    # and an ASGI server (e.g. uvicorn) will not allow this, but since
                    # it's not forbidden according to the HTTP or ASGI spec, we err on
                    # the side of caution and still perform this check.
                    #
                    # uvicorn documented behaviour for this case:
                    # https://github.com/encode/uvicorn/blob/fe3910083e3990695bc19c2ef671dd447262ae18/docs/server-behavior.md?plain=1#L11
                    if announced_content_length is not None:
                        if total_bytes_streamed > announced_content_length:
                            raise ClientException("Malformed request")

                    # we don't have a 'content-length' header, likely a chunked
                    # transfer. we don't really care and simply check if we have
                    # received more bytes than allowed
                    elif total_bytes_streamed > max_content_length:
                        raise RequestEntityTooLarge

                    yield body

                if not event.get("more_body", False):
                    break

            if event["type"] == "http.disconnect":
                raise InternalServerException("client disconnected prematurely")

        self.is_connected = False
        yield b""

    async def body(self) -> bytes:
        """Return the body of the request.

        The body is read in full via :meth:`stream` on first access. It is cached on the instance and the connection
        state.

        Returns:
            A byte-string representing the body of the request.
        """
        if self._body is Empty:
            if (body := self._connection_state.body) is not Empty:
                self._body = body
            else:
                self._body = self._connection_state.body = b"".join([c async for c in self.stream()])
        return self._body

    async def form(self) -> FormMultiDict:
        """Retrieve form data from the request.

        If the request is either a 'multipart/form-data' or an 'application/x-www-form-urlencoded' request, the
        returned FormMultiDict is populated with the values sent in the request. Otherwise it is empty. The raw
        parsed data is cached on the connection state, and the ``FormMultiDict`` on the instance.

        Returns:
            A FormMultiDict instance
        """
        if self._form is Empty:
            if (form_data := self._connection_state.form) is Empty:
                content_type, options = self.content_type
                if content_type == RequestEncodingType.MULTI_PART:
                    form_data = parse_multipart_form(
                        body=await self.body(),
                        boundary=options.get("boundary", "").encode(),
                        multipart_form_part_limit=self.app.multipart_form_part_limit,
                    )
                elif content_type == RequestEncodingType.URL_ENCODED:
                    form_data = parse_url_encoded_form_data(
                        await self.body(),
                    )
                else:
                    form_data = {}

                self._connection_state.form = form_data

            # form_data is a dict[str, list[str] | str | UploadFile]. Flatten it to a
            # list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so
            # multi-keys can be accessed properly
            items = [
                (key, sub_value)
                for key, value in form_data.items()
                for sub_value in (value if isinstance(value, list) else (value,))
            ]
            self._form = FormMultiDict(items)

        return self._form

    async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None:
        """Send a push promise.

        This method requires the `http.response.push` extension to be sent from the ASGI server. Only the request
        headers listed in ``SERVER_PUSH_HEADERS`` are forwarded with the promise.

        Args:
            path: Path to send the promise to.
            raise_if_unavailable: Raise an exception if server push is not supported by
                the server

        Returns:
            None

        Raises:
            LitestarException: If server push is unsupported and ``raise_if_unavailable`` is ``True``. Otherwise a
                ``LitestarWarning`` is emitted and nothing is sent.
        """
        if not self.supports_push_promise:
            if raise_if_unavailable:
                raise LitestarException("Attempted to send a push promise but the server does not support it")

            warnings.warn(
                "Attempted to send a push promise but the server does not support it. In a future version, this will "
                "raise an exception. To enable this behaviour in the current version, set raise_if_unavailable=True. "
                "To prevent this behaviour, make sure that the server you are using supports the 'http.response.push' "
                "ASGI extension, or check this dynamically via "
                ":attr:`~litestar.connection.Request.supports_push_promise`",
                stacklevel=2,
                category=LitestarWarning,
            )

            return

        # one (name, value) pair per header value, so repeated headers are all forwarded
        raw_headers = [
            (header_name.encode("latin-1"), value.encode("latin-1"))
            for header_name in (self.headers.keys() & SERVER_PUSH_HEADERS)
            for value in self.headers.getall(header_name, [])
        ]
        await self.send({"type": "http.response.push", "path": path, "headers": raw_headers})