from __future__ import annotations
import itertools
from collections.abc import Mapping
from mimetypes import guess_type
from pathlib import PurePath
from typing import (
TYPE_CHECKING,
Any,
Iterable,
TypeVar,
cast,
)
from urllib.parse import quote, urlparse, urlunparse
from litestar import Litestar, MediaType, Request, Response
from litestar.datastructures.cookie import Cookie
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response import Redirect
from litestar.response.base import ASGIResponse
from litestar.serialization import get_serializer
from litestar.status_codes import HTTP_200_OK, HTTP_303_SEE_OTHER, HTTP_307_TEMPORARY_REDIRECT, HTTP_409_CONFLICT
from litestar.utils.deprecation import warn_deprecation
from litestar.utils.empty import value_or_default
from litestar.utils.helpers import get_enum_string_value
from litestar.utils.scope.state import ScopeState
from litestar_vite.inertia._utils import get_headers
from litestar_vite.inertia.helpers import (
get_shared_props,
is_or_contains_lazy_prop,
js_routes_script,
lazy_render,
should_render,
)
from litestar_vite.inertia.plugin import InertiaPlugin
from litestar_vite.inertia.types import InertiaHeaderType, PageProps
from litestar_vite.plugin import VitePlugin
if TYPE_CHECKING:
from litestar.app import Litestar
from litestar.background_tasks import BackgroundTask, BackgroundTasks
from litestar.connection.base import AuthT, StateT, UserT
from litestar.types import ResponseCookies, ResponseHeaders, TypeEncodersMap
# Type of the ``content`` payload carried by ``InertiaResponse`` (used as ``Response[T]``).
T = TypeVar("T")
class InertiaResponse(Response[T]):
    """Inertia.js response.

    ``to_asgi_response`` produces one of three bodies depending on the request:

    * Inertia not enabled for the request: ``content`` serialized as-is (JSON unless a
      media type is given).
    * Inertia (XHR) request: the serialized page object (component, props, version, url).
    * Otherwise: a full page rendered through the application's template engine, with
      the JSON-encoded page object exposed to the template as ``inertia``.
    """

    def __init__(
        self,
        content: T,
        *,
        template_name: str | None = None,
        template_str: str | None = None,
        background: BackgroundTask | BackgroundTasks | None = None,
        context: dict[str, Any] | None = None,
        cookies: ResponseCookies | None = None,
        encoding: str = "utf-8",
        headers: ResponseHeaders | None = None,
        media_type: MediaType | str | None = None,
        status_code: int = HTTP_200_OK,
        type_encoders: TypeEncodersMap | None = None,
    ) -> None:
        """Store the response content and rendering options.

        Args:
            content: The page data. For Inertia renders it becomes the ``content`` prop
                (possibly filtered for partial reloads); otherwise it is serialized directly.
            template_name: Path-like name of the template used for full-page renders, e.g.
                ``index.html``. When neither this nor ``template_str`` is given, the Inertia
                plugin's ``root_template`` is used.
            template_str: A template source string, e.g. ``"Hello <strong>World</strong>"``.
            background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` or
                :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after
                the response is finished. Defaults to ``None``.
            context: Extra key/value pairs passed to the template engine's render method.
            cookies: Either a ``{name: value}`` mapping or an iterable of
                :class:`Cookie <.datastructures.Cookie>` instances to set via ``Set-Cookie``.
            encoding: Encoding used for the rendered template body.
            headers: Either a mapping of header names to values or an iterable of header
                objects exposing ``name`` and ``value``.
            media_type: Media type used for the serialized (non-template) responses. Template
                renders ignore it and use the media type passed to ``to_asgi_response``, falling
                back to ``text/html``.
            status_code: The HTTP status code of the response.
            type_encoders: A mapping of types to callables that transform them into types
                supported for serialization.

        Raises:
            ValueError: If both ``template_name`` and ``template_str`` are given.
        """
        if template_name and template_str:
            msg = "Either template_name or template_str must be provided, not both."
            raise ValueError(msg)
        self.content = content
        self.background = background
        # Normalize cookies to a list of ``Cookie`` objects.
        self.cookies: list[Cookie] = (
            [Cookie(key=key, value=value) for key, value in cookies.items()]
            if isinstance(cookies, Mapping)
            else list(cookies or [])
        )
        self.encoding = encoding
        # Normalize headers to a plain ``{name: value}`` dict.
        self.headers: dict[str, Any] = (
            dict(headers) if isinstance(headers, Mapping) else {h.name: h.value for h in headers or ()}
        )
        self.media_type = media_type
        self.status_code = status_code
        # Encoders passed to this response take precedence over ``self.type_encoders``.
        self.response_type_encoders = {**(self.type_encoders or {}), **(type_encoders or {})}
        self.context = context or {}
        self.template_name = template_name
        self.template_str = template_str

    def create_template_context(
        self,
        request: Request[UserT, AuthT, StateT],
        page_props: PageProps[T],
        type_encoders: TypeEncodersMap | None = None,
    ) -> dict[str, Any]:
        """Build the context passed to the page template.

        Args:
            request: The current request. Its app state supplies ``js_routes`` and its scope
                supplies the CSRF token.
            page_props: The Inertia page object to embed in the page.
            type_encoders: Encoders used when serializing ``page_props`` to JSON.

        Returns:
            ``self.context`` extended with ``inertia`` (the page object as a JSON string),
            ``js_routes``, ``request`` and ``csrf_input`` (a hidden ``_csrf_token`` form field;
            the token defaults to ``""`` when the scope has none). These keys override
            same-named entries from ``self.context``.
        """
        csrf_token = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
        inertia_props = self.render(page_props, MediaType.JSON, get_serializer(type_encoders)).decode()
        return {
            **self.context,
            "inertia": inertia_props,
            "js_routes": js_routes_script(request.app.state.js_routes),
            "request": request,
            "csrf_input": f'<input type="hidden" name="_csrf_token" value="{csrf_token}" />',
        }

    def to_asgi_response(  # noqa: C901, PLR0912
        self,
        app: Litestar | None,
        request: Request[UserT, AuthT, StateT],
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: Iterable[Cookie] | None = None,
        encoded_headers: Iterable[tuple[bytes, bytes]] | None = None,
        headers: dict[str, str] | None = None,
        is_head_response: bool = False,
        media_type: MediaType | str | None = None,
        status_code: int | None = None,
        type_encoders: TypeEncodersMap | None = None,
    ) -> ASGIResponse:
        """Render this response into an :class:`ASGIResponse`.

        Args:
            app: Deprecated since 2.1; use ``request.app`` instead.
            request: The current request. Its ``inertia_enabled``, ``is_inertia``,
                ``is_partial_render`` and ``partial_keys`` attributes (each optional) select the
                rendering mode.
            background: Background task(s) used when this response has none of its own.
            cookies: Additional cookies, sent after ``self.cookies``.
            encoded_headers: Pre-encoded headers forwarded to the ASGI response.
            headers: Additional headers; entries in ``self.headers`` take precedence.
            is_head_response: Whether the body should be omitted (HEAD request).
            media_type: Fallback media type. For full-page template renders it is used as-is,
                defaulting to ``text/html``.
            status_code: Status code used when ``self.status_code`` is falsy.
            type_encoders: Extra encoders; ``self.response_type_encoders`` take precedence.

        Returns:
            The ASGI response.

        Raises:
            ImproperlyConfiguredException: If a full-page render is needed but the
                application has no template engine.
        """
        if app is not None:
            warn_deprecation(
                version="2.1",
                deprecated_name="app",
                kind="parameter",
                removal_in="3.0.0",
                alternative="request.app",
            )
        inertia_enabled = cast(
            "bool",
            getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False),
        )
        is_inertia = cast("bool", getattr(request, "is_inertia", False))
        # Always build a fresh dict: the Inertia branches below call ``headers.update`` and
        # must not mutate ``self.headers`` as a side effect of rendering.
        headers = {**(headers or {}), **self.headers}
        cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)
        type_encoders = (
            {**type_encoders, **(self.response_type_encoders or {})} if type_encoders else self.response_type_encoders
        )

        if not inertia_enabled:
            # Plain (non-Inertia) response: serialize ``content`` directly.
            media_type = get_enum_string_value(self.media_type or media_type or MediaType.JSON)
            return ASGIResponse(
                background=self.background or background,
                body=self.render(self.content, media_type, get_serializer(type_encoders)),
                cookies=cookies,
                encoded_headers=encoded_headers,
                encoding=self.encoding,
                headers=headers,
                is_head_response=is_head_response,
                media_type=media_type,
                status_code=self.status_code or status_code,
            )

        is_partial_render = cast("bool", getattr(request, "is_partial_render", False))
        partial_keys = cast("set[str]", getattr(request, "partial_keys", set()))
        vite_plugin = request.app.plugins.get(VitePlugin)
        inertia_plugin = request.app.plugins.get(InertiaPlugin)
        template_engine = request.app.template_engine  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
        headers.update(
            {"Vary": "Accept", **get_headers(InertiaHeaderType(enabled=True))},
        )
        shared_props = get_shared_props(
            request,
            partial_data=partial_keys if is_partial_render else None,
        )
        # Lazy props are resolved (and filtered for partial reloads) before being exposed.
        if is_or_contains_lazy_prop(self.content):
            filtered_content = lazy_render(
                self.content,
                partial_keys if is_partial_render else None,
                inertia_plugin.portal,
            )
            if filtered_content is not None:
                shared_props["content"] = filtered_content
        elif should_render(self.content, partial_keys):
            shared_props["content"] = self.content

        page_props = PageProps[T](
            component=request.inertia.route_component,  # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType,reportAttributeAccessIssue]
            props=shared_props,  # pyright: ignore[reportArgumentType]
            version=vite_plugin.asset_loader.version_id,
            url=request.url.path,
        )

        if is_inertia:
            # Inertia XHR request: return the page object itself.
            media_type = get_enum_string_value(self.media_type or media_type or MediaType.JSON)
            body = self.render(page_props, media_type, get_serializer(type_encoders))
            return ASGIResponse(  # pyright: ignore[reportUnknownMemberType]
                background=self.background or background,
                body=body,
                cookies=cookies,
                encoded_headers=encoded_headers,
                encoding=self.encoding,
                headers=headers,
                is_head_response=is_head_response,
                media_type=media_type,
                status_code=self.status_code or status_code,
            )

        # Full-page load: render the page object into a template.
        if not template_engine:
            msg = "Template engine is not configured"
            raise ImproperlyConfiguredException(msg)
        # Only the caller-supplied media type is consulted here, defaulting to HTML.
        # NOTE(review): ``self.media_type`` is deliberately not used for template renders
        # (unlike the JSON branches) — presumably because it may carry a handler's JSON
        # default; confirm before changing.
        media_type = media_type or MediaType.HTML
        context = self.create_template_context(request, page_props, type_encoders)  # pyright: ignore[reportUnknownMemberType]
        if self.template_str is not None:
            body = template_engine.render_string(self.template_str, context).encode(self.encoding)  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
        else:
            template_name = self.template_name or inertia_plugin.config.root_template
            template = template_engine.get_template(template_name)  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            body = template.render(**context).encode(self.encoding)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
        return ASGIResponse(  # pyright: ignore[reportUnknownMemberType]
            background=self.background or background,
            body=body,  # pyright: ignore[reportUnknownArgumentType]
            cookies=cookies,
            encoded_headers=encoded_headers,
            encoding=self.encoding,
            headers=headers,
            is_head_response=is_head_response,
            media_type=media_type,
            status_code=self.status_code or status_code,
        )
class InertiaExternalRedirect(Response[Any]):
    """Inertia external redirect.

    Responds with ``409 Conflict`` (required by Inertia) and an ``X-Inertia-Location``
    header carrying the target URL.
    """

    def __init__(
        self,
        request: Request[Any, Any, Any],
        redirect_to: str,
        **kwargs: Any,
    ) -> None:
        """Initialize the external redirect.

        Args:
            request: The current request. Its cookies are sent back on the response unless
                ``cookies`` is passed explicitly.
            redirect_to: Target URL. It is percent-encoded, leaving URL delimiters such as
                ``/``, ``?``, ``#``, ``&`` and ``=`` intact.
            **kwargs: Extra keyword arguments for :class:`Response`. ``headers`` (a mapping
                or an iterable of header objects with ``name``/``value``) are merged with
                ``X-Inertia-Location``, which always wins. ``cookies`` replaces the request
                cookies. Previously, passing either of these raised ``TypeError``.
        """
        extra_headers = kwargs.pop("headers", None)
        merged_headers: dict[str, Any] = (
            dict(extra_headers)
            if isinstance(extra_headers, Mapping)
            else {h.name: h.value for h in extra_headers or ()}
        )
        merged_headers["X-Inertia-Location"] = quote(redirect_to, safe="/#%[]=:;$&()+,!?*@'~")
        cookies = kwargs.pop("cookies", request.cookies)
        super().__init__(
            content=b"",
            status_code=HTTP_409_CONFLICT,
            headers=merged_headers,
            cookies=cookies,
            **kwargs,
        )
class InertiaRedirect(Redirect):
    """Inertia redirect that aligns absolute targets with the client's URL scheme."""

    def __init__(
        self,
        request: Request[Any, Any, Any],
        redirect_to: str,
        **kwargs: Any,
    ) -> None:
        """Initialize the redirect.

        ``GET`` requests are answered with ``307 Temporary Redirect``; all other methods
        with ``303 See Other``.

        Args:
            request: The current request. Its ``Referer`` header (or, failing that, its base
                URL) supplies the scheme for absolute targets. Its cookies are sent back
                unless ``cookies`` is passed explicitly.
            redirect_to: Target URL or path. If it includes a host, its scheme is replaced
                by the referer's scheme. Host-less paths such as ``/dashboard`` are used
                unchanged.
            **kwargs: Extra keyword arguments for :class:`Redirect`. ``cookies`` replaces
                the request cookies (previously it raised ``TypeError``).
        """
        reference = urlparse(request.headers.get("Referer", str(request.base_url)))
        target = urlparse(redirect_to)
        # Only rewrite targets that name a host: adding a scheme to a bare path yields
        # e.g. "https:///dashboard", which browsers resolve to the host "dashboard".
        if target.netloc and reference.scheme:
            redirect_to = urlunparse(target._replace(scheme=reference.scheme))
        cookies = kwargs.pop("cookies", request.cookies)
        super().__init__(
            path=redirect_to,
            status_code=HTTP_307_TEMPORARY_REDIRECT if request.method == "GET" else HTTP_303_SEE_OTHER,
            cookies=cookies,
            **kwargs,
        )
class InertiaBack(Redirect):
    """Inertia redirect back to the page the request came from."""

    def __init__(
        self,
        request: Request[Any, Any, Any],
        **kwargs: Any,
    ) -> None:
        """Initialize the redirect to the ``Referer`` (or the app's base URL if absent).

        ``GET`` requests are answered with ``307 Temporary Redirect``; all other methods
        with ``303 See Other``.

        Args:
            request: The current request. Its ``Referer`` header provides the target, and its
                cookies are sent back unless ``cookies`` is passed explicitly.
            **kwargs: Extra keyword arguments for :class:`Redirect`. ``cookies`` replaces
                the request cookies (previously it raised ``TypeError``).
        """
        cookies = kwargs.pop("cookies", request.cookies)
        super().__init__(
            path=request.headers.get("Referer", str(request.base_url)),
            status_code=HTTP_307_TEMPORARY_REDIRECT if request.method == "GET" else HTTP_303_SEE_OTHER,
            cookies=cookies,
            **kwargs,
        )