# Source code for litestar_vite.inertia.helpers

from __future__ import annotations

import inspect
from collections import defaultdict
from collections.abc import Mapping
from contextlib import contextmanager
from functools import lru_cache
from textwrap import dedent
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Coroutine,
    Dict,
    Generator,
    Generic,
    Iterable,
    List,
    TypeVar,
    cast,
    overload,
)

from anyio.from_thread import BlockingPortal, start_blocking_portal
from litestar.exceptions import ImproperlyConfiguredException
from litestar.utils.empty import value_or_default
from litestar.utils.scope.state import ScopeState
from markupsafe import Markup
from typing_extensions import ParamSpec, TypeGuard

if TYPE_CHECKING:
    from litestar.connection import ASGIConnection

    from litestar_vite.inertia.plugin import InertiaPlugin
    from litestar_vite.inertia.routes import Routes

# General-purpose type variable (e.g. the value produced by a DeferredProp, or the value passed through lazy_render).
T = TypeVar("T")
# Parameter specification for the callables accepted by `lazy`.
T_ParamSpec = ParamSpec("T_ParamSpec")
# Type of a prop's key; bound to `str`.
PropKeyT = TypeVar("PropKeyT", bound=str)
# Type of the value held by a StaticProp.
StaticT = TypeVar("StaticT", bound=object)


@overload
def lazy(key: str, value_or_callable: None) -> StaticProp[str, None]: ...


@overload
def lazy(key: str, value_or_callable: T) -> StaticProp[str, T]: ...


@overload
def lazy(key: str, value_or_callable: Callable[..., None] = ...) -> DeferredProp[str, None]: ...


@overload
def lazy(key: str, value_or_callable: Callable[..., Coroutine[Any, Any, None]] = ...) -> DeferredProp[str, None]: ...


@overload
def lazy(
    key: str,
    value_or_callable: Callable[..., T | Coroutine[Any, Any, T]] = ...,  # pyright: ignore[reportInvalidTypeVarUse]
) -> DeferredProp[str, T]: ...


def lazy(
    key: str,
    value_or_callable: None
    | Callable[T_ParamSpec, None | Coroutine[Any, Any, None]]
    | T
    | Callable[T_ParamSpec, T | Coroutine[Any, Any, T]] = None,
) -> StaticProp[str, None] | StaticProp[str, T] | DeferredProp[str, T] | DeferredProp[str, None]:
    """Wrap a value or a callable as a lazy prop under ``key``.

    Lazy props are only included in a response when their key is requested
    for partial rendering (see ``should_render``).

    Args:
        key: The prop name.
        value_or_callable: A plain value (or ``None``), or a sync/async callable
            whose result is computed only when the prop is rendered.

    Returns:
        A ``DeferredProp`` when ``value_or_callable`` is callable, otherwise a
        ``StaticProp`` holding the value as-is.
    """
    if callable(value_or_callable):
        # Callables are deferred: they run only when the prop is actually rendered.
        factory = cast("Callable[..., T | Coroutine[Any, Any, T]]", value_or_callable)
        return DeferredProp[str, T](key=key, value=factory)
    if value_or_callable is None:
        return StaticProp[str, None](key=key, value=None)
    return StaticProp[str, T](key=key, value=value_or_callable)
class StaticProp(Generic[PropKeyT, StaticT]):
    """A lazy prop whose value is supplied up front.

    It exposes the same ``key`` / ``render`` interface as ``DeferredProp``, so
    both kinds can be handled uniformly. Rendering just hands back the stored value.
    """

    def __init__(self, key: PropKeyT, value: StaticT) -> None:
        """Remember the prop name and its already-known value.

        Args:
            key: The prop name.
            value: The value returned by ``render``.
        """
        self._key = key
        self._result = value

    @property
    def key(self) -> PropKeyT:
        """The prop name this value is exposed under."""
        return self._key

    def render(self, portal: BlockingPortal | None = None) -> StaticT:
        """Return the stored value.

        Args:
            portal: Accepted for interface parity with ``DeferredProp.render``; not used.

        Returns:
            The value given at construction time.
        """
        return self._result
class DeferredProp(Generic[PropKeyT, T]):
    """A lazy prop whose value is produced by a callable on first render.

    The callable may be synchronous or asynchronous. Asynchronous work runs
    through an anyio ``BlockingPortal``, so ``render`` can be called from
    synchronous code. The first successful result is cached and returned by
    every later ``render`` call.
    """

    def __init__(
        self, key: PropKeyT, value: Callable[..., None | T | Coroutine[Any, Any, T | None]] | None = None
    ) -> None:
        """Store the prop name and the callable that produces its value.

        Args:
            key: The prop name.
            value: A sync or async callable producing the value, or ``None``.
        """
        self._key = key
        self._value = value
        # Memoization state: once `_evaluated` is True, `_result` is returned as-is.
        self._evaluated = False
        self._result: T | None = None

    @property
    def key(self) -> PropKeyT:
        """The prop name this value is exposed under."""
        return self._key

    @contextmanager
    def with_portal(self, portal: BlockingPortal | None = None) -> Generator[BlockingPortal, None, None]:
        """Yield ``portal``, or start a temporary blocking portal when none is given.

        A temporary portal only lives for the duration of the ``with`` block.
        """
        if portal is None:
            with start_blocking_portal() as p:
                yield p
        else:
            yield portal

    @staticmethod
    def _is_awaitable(
        v: Callable[..., T | Coroutine[Any, Any, T]],
    ) -> TypeGuard[Coroutine[Any, Any, T]]:
        # True only for `async def` functions; see `render` for callables that
        # return an awaitable without being coroutine functions themselves.
        return inspect.iscoroutinefunction(v)

    @staticmethod
    async def _await_result(awaitable: Any) -> Any:
        """Await an awaitable that was returned by a synchronous callable."""
        return await awaitable

    def render(self, portal: BlockingPortal | None = None) -> T | None:
        """Evaluate the wrapped callable (once) and return its result.

        Args:
            portal: Portal used to run async work. When omitted, a temporary
                portal is started just for this evaluation.

        Returns:
            The callable's result, or ``None`` when no callable was supplied.
        """
        if self._evaluated:
            return self._result
        value = self._value
        if value is None or not callable(value):
            self._result = value
        elif self._is_awaitable(cast("Callable[..., T]", value)):
            with self.with_portal(portal) as p:
                self._result = p.call(cast("Callable[..., T]", value))
        else:
            result = value()
            if inspect.isawaitable(result):
                # e.g. `lazy("x", lambda: fetch())` or an object with an async `__call__`:
                # iscoroutinefunction() misses these, so await the returned object
                # instead of leaking an un-awaited coroutine into the props.
                with self.with_portal(portal) as p:
                    self._result = cast("T", p.call(self._await_result, result))
            else:
                self._result = cast("T", result)
        # Only mark as evaluated after success, so a raising callable is retried next time.
        self._evaluated = True
        return self._result
def is_lazy_prop(value: Any) -> TypeGuard[DeferredProp[Any, Any] | StaticProp[Any, Any]]:
    """Check if value is a lazy property, i.e. a ``DeferredProp`` or a ``StaticProp``.

    Args:
        value: Any value to check

    Returns:
        bool: True if value is a ``DeferredProp`` or ``StaticProp`` instance
    """
    # The TypeGuard names both classes: narrowing to DeferredProp alone would
    # mis-type StaticProp instances, which this check also accepts.
    return isinstance(value, (DeferredProp, StaticProp))
def should_render(value: Any, partial_data: set[str] | None = None) -> bool:
    """Decide whether ``value`` belongs in the rendered output.

    Ordinary values are always rendered. Lazy props (``DeferredProp`` or
    ``StaticProp``) are rendered only when their key was explicitly requested.

    Args:
        value: Any value to check
        partial_data: Optional set of keys for partial rendering

    Returns:
        bool: True if value should be rendered
    """
    if not is_lazy_prop(value):
        return True
    requested = partial_data or set()
    return value.key in requested
def is_or_contains_lazy_prop(value: Any) -> bool:
    """Report whether ``value`` is a lazy prop or (recursively) holds one.

    Strings are treated as atoms rather than iterables of characters. Mappings
    are searched by value only (keys are ignored). Any other iterable is
    searched element by element, which may partially consume one-shot iterators.

    Args:
        value: Any value to check

    Returns:
        bool: True if value is or contains a deferred property
    """
    if is_lazy_prop(value):
        return True
    if isinstance(value, str):
        return False
    if isinstance(value, Mapping):
        children = cast("Mapping[str, Any]", value).values()
    elif isinstance(value, Iterable):
        children = cast("Iterable[Any]", value)
    else:
        return False
    return any(is_or_contains_lazy_prop(child) for child in children)
def lazy_render(value: T, partial_data: set[str] | None = None, portal: BlockingPortal | None = None) -> T:
    """Filter deferred properties from the value based on partial data.

    Mappings, lists and tuples are walked recursively. Inside them, lazy props
    whose key is not in ``partial_data`` are dropped, and the remaining lazy
    props are rendered. At the top level, a requested lazy prop is rendered;
    anything else (including strings) is returned unchanged.

    Args:
        value: The value to filter
        partial_data: Keys for partial rendering
        portal: Optional portal to use for async rendering

    Returns:
        The filtered value
    """
    partial_data = partial_data or set()
    if isinstance(value, str):
        return cast("T", value)
    if isinstance(value, Mapping):
        return cast(
            "T",
            {
                k: lazy_render(v, partial_data, portal)
                for k, v in cast("Mapping[str, Any]", value).items()
                if should_render(v, partial_data)
            },
        )
    if isinstance(value, (list, tuple)):
        filtered = [
            lazy_render(v, partial_data, portal)
            for v in cast("Iterable[Any]", value)
            if should_render(v, partial_data)
        ]
        if isinstance(value, tuple) and hasattr(value, "_fields"):
            # namedtuple constructors take one positional argument per field, not a
            # single iterable. If items were filtered out, the field count no longer
            # matches, so fall back to a plain tuple instead of raising TypeError.
            if len(filtered) == len(value):
                return cast("T", type(value)(*filtered))
            return cast("T", tuple(filtered))
        return cast("T", type(value)(filtered))  # pyright: ignore[reportUnknownArgumentType]
    if is_lazy_prop(value) and should_render(value, partial_data):
        return cast("T", value.render(portal))
    return cast("T", value)
def get_shared_props(
    request: ASGIConnection[Any, Any, Any, Any],
    partial_data: set[str] | None = None,
) -> dict[str, Any]:
    """Return shared session props for a request.

    Steps:
        1. Pop the ``_errors``, ``_shared`` and ``_messages`` entries from the session.
        2. Keep the shared props selected by ``should_render``, rendering any lazy
           ones with the Inertia plugin's portal.
        3. Group flash messages by category.
        4. Merge in the plugin's configured extra static and session page props.

    If the session is unavailable, a warning is logged and only the default
    entries are returned.

    Args:
        request: The ASGI connection.
        partial_data: Optional set of keys for partial rendering. Lazy shared
            props are only included when their key is listed here.

    Returns:
        Dict[str, Any]: The shared props, always including ``flash``, ``errors`` and ``csrf_token``.

    Note:
        Be sure to call this before `self.create_template_context` if you would like to include the
        `flash` message details.
    """
    props: dict[str, Any] = {}
    flash: dict[str, list[str]] = defaultdict(list)
    errors: dict[str, Any] = {}
    error_bag = request.headers.get("X-Inertia-Error-Bag", None)
    try:
        errors = request.session.pop("_errors", {})
        shared_props = cast("Dict[str,Any]", request.session.pop("_shared", {}))
        inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))

        for key, value in shared_props.items():
            if not should_render(value, partial_data):
                continue
            # Lazy props are evaluated now; async ones run on the plugin's portal.
            props[key] = value.render(inertia_plugin.portal) if is_lazy_prop(value) else value

        for message in cast("List[Dict[str,Any]]", request.session.pop("_messages", [])):
            flash[message["category"]].append(message["message"])

        props.update(inertia_plugin.config.extra_static_page_props)
        for session_prop in inertia_plugin.config.extra_session_page_props:
            # Explicit shared/static props take precedence over raw session values.
            if session_prop not in props and session_prop in request.session:
                props[session_prop] = request.session.get(session_prop)

    except (AttributeError, ImproperlyConfiguredException):
        msg = "Unable to generate all shared props. A valid session was not found for this request."
        request.logger.warning(msg)

    props["flash"] = flash
    # When the client names an error bag, errors are nested under that bag's key.
    props["errors"] = {error_bag: errors} if error_bag is not None else errors
    props["csrf_token"] = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
    return props
def share(
    connection: ASGIConnection[Any, Any, Any, Any],
    key: str,
    value: Any,
) -> None:
    """Store a value in the session's ``_shared`` mapping.

    The mapping is created on first use. If the connection has no usable
    session, a warning is logged instead of raising.

    Args:
        connection: The ASGI connection.
        key: The key to store the value under.
        value: The value to store.
    """
    warning = "Unable to set `share` session state. A valid session was not found for this request."
    try:
        shared = connection.session.setdefault("_shared", {})
        shared.update({key: value})
    except (AttributeError, ImproperlyConfiguredException):
        connection.logger.warning(warning)
def error(
    connection: ASGIConnection[Any, Any, Any, Any],
    key: str,
    message: str,
) -> None:
    """Record an error message in the session's ``_errors`` mapping.

    The mapping is created on first use and is popped by ``get_shared_props``.
    If the connection has no usable session, a warning is logged instead of
    raising.

    Args:
        connection: The ASGI connection.
        key: The key to store the error under.
        message: The error message.
    """
    warning = "Unable to set `error` session state. A valid session was not found for this request."
    try:
        errors = connection.session.setdefault("_errors", {})
        errors.update({key: message})
    except (AttributeError, ImproperlyConfiguredException):
        connection.logger.warning(warning)
def js_routes_script(js_routes: Routes) -> Markup: @lru_cache def _markup_safe_json_dumps(js_routes: str) -> Markup: js = js_routes.replace("<", "\\u003c").replace(">", "\\u003e").replace("&", "\\u0026").replace("'", "\\u0027") return Markup(js) return Markup( dedent(f""" <script type="module"> globalThis.routes = JSON.parse('{_markup_safe_json_dumps(js_routes.formatted_routes)}') </script> """), )