Source code for litestar_vite.inertia.helpers

import inspect
from collections import defaultdict
from collections.abc import Coroutine, Generator, Iterable, Mapping
from contextlib import contextmanager
from functools import lru_cache
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, cast, overload

from anyio.from_thread import BlockingPortal, start_blocking_portal
from litestar.exceptions import ImproperlyConfiguredException
from litestar.utils.empty import value_or_default
from litestar.utils.scope.state import ScopeState
from markupsafe import Markup
from typing_extensions import ParamSpec, TypeGuard

if TYPE_CHECKING:
    from litestar.connection import ASGIConnection

    from litestar_vite.inertia.plugin import InertiaPlugin
    from litestar_vite.inertia.routes import Routes

# Generic placeholder for the value a prop holds or that its callable produces.
T = TypeVar("T")
# Parameter specification; not referenced by the definitions visible in this section of the module.
T_ParamSpec = ParamSpec("T_ParamSpec")
# Type of a prop's key; bounded to ``str``.
PropKeyT = TypeVar("PropKeyT", bound=str)
# Type of the value wrapped by ``StaticProp``.
StaticT = TypeVar("StaticT", bound=object)


# NOTE: overload order matters to type checkers (first match wins). The
# specific callable shapes are listed before the unbounded ``T`` catch-all,
# and only the ``None`` overload carries a default, because omitting the
# argument yields a ``StaticProp`` at runtime.
@overload
def lazy(key: str, value_or_callable: "None" = None) -> "StaticProp[str, None]": ...


@overload
def lazy(key: str, value_or_callable: "Callable[..., None]") -> "DeferredProp[str, None]": ...


@overload
def lazy(key: str, value_or_callable: "Callable[..., Coroutine[Any, Any, None]]") -> "DeferredProp[str, None]": ...


@overload
def lazy(
    key: str,
    value_or_callable: "Callable[..., Union[T, Coroutine[Any, Any, T]]]",  # pyright: ignore[reportInvalidTypeVarUse]
) -> "DeferredProp[str, T]": ...


@overload
def lazy(key: str, value_or_callable: "T") -> "StaticProp[str, T]": ...


def lazy(
    key: str,
    value_or_callable: "Optional[Union[T, Callable[..., Coroutine[Any, Any, None]], Callable[..., T], Callable[..., Union[T, Coroutine[Any, Any, T]]]]]" = None,
) -> "Union[StaticProp[str, None], StaticProp[str, T], DeferredProp[str, T], DeferredProp[str, None]]":
    """Wrap a value or a callable in a lazy prop.

    ``None`` and non-callable values are wrapped in a ``StaticProp``; any
    callable (sync or async) is wrapped in a ``DeferredProp`` and is only
    invoked when the prop is rendered.

    Args:
        key: The key to store the value under.
        value_or_callable: The value or callable to store. Defaults to ``None``.

    Returns:
        A ``StaticProp`` for ``None``/non-callable values, otherwise a ``DeferredProp``.
    """
    if value_or_callable is None:
        return StaticProp[str, None](key=key, value=None)

    if not callable(value_or_callable):
        return StaticProp[str, T](key=key, value=value_or_callable)

    return DeferredProp[str, T](key=key, value=cast("Callable[..., T | Coroutine[Any, Any, T]]", value_or_callable))
class StaticProp(Generic[PropKeyT, StaticT]):
    """Lazy prop that carries a value which is already known.

    Rendering simply hands back the stored value; nothing is computed.
    """

    def __init__(self, key: "PropKeyT", value: "StaticT") -> None:
        """Store the prop key and its value.

        Args:
            key: The key the prop is registered under.
            value: The value returned by :meth:`render`.
        """
        self._key, self._result = key, value

    def render(self, portal: "Optional[BlockingPortal]" = None) -> "StaticT":
        """Return the stored value.

        Args:
            portal: Accepted for signature parity with ``DeferredProp.render``; not used here.

        Returns:
            The value given at construction time.
        """
        return self._result

    @property
    def key(self) -> "PropKeyT":
        """The key the prop is registered under."""
        return self._key
class DeferredProp(Generic[PropKeyT, T]):
    """A wrapper for deferred property evaluation.

    The wrapped callable is invoked at most once, on the first successful
    :meth:`render`; the result is memoized and returned on later calls.
    """

    def __init__(
        self, key: "PropKeyT", value: "Optional[Callable[..., Optional[Union[T, Coroutine[Any, Any, T]]]]]" = None
    ) -> None:
        """Store the prop key and the callable that produces its value.

        Args:
            key: The key the prop is registered under.
            value: A sync or async callable taking no arguments, or ``None``.
        """
        self._key = key
        self._value = value
        self._evaluated = False
        self._result: "Optional[T]" = None

    @property
    def key(self) -> "PropKeyT":
        """The key the prop is registered under."""
        return self._key

    @staticmethod
    @contextmanager
    def with_portal(portal: "Optional[BlockingPortal]" = None) -> "Generator[BlockingPortal, None, None]":
        """Yield ``portal`` if given, otherwise a freshly started blocking portal.

        A portal started here is closed when the ``with`` block exits; a
        caller-supplied portal is yielded as-is and left open.
        """
        if portal is None:
            with start_blocking_portal() as p:
                yield p
        else:
            yield portal

    @staticmethod
    def _is_awaitable(
        v: "Callable[..., Union[T, Coroutine[Any, Any, T]]]",
    ) -> "TypeGuard[Coroutine[Any, Any, T]]":
        """Return whether ``v`` is a coroutine function (per ``inspect.iscoroutinefunction``)."""
        return inspect.iscoroutinefunction(v)

    def render(self, portal: "Optional[BlockingPortal]" = None) -> "Union[T, None]":
        """Evaluate the wrapped callable once and return the memoized result.

        Args:
            portal: Optional blocking portal used to run async work; a
                temporary one is started when omitted and async work is needed.

        Returns:
            The value produced by the callable, or the stored value itself when
            it is ``None`` or not callable.
        """
        if self._evaluated:
            return self._result

        producer = self._value
        if producer is None or not callable(producer):
            outcome: Any = producer
        elif self._is_awaitable(cast("Callable[..., T]", producer)):
            with self.with_portal(portal) as p:
                outcome = p.call(cast("Callable[..., T]", producer))
        else:
            outcome = producer()
            # A plain callable (lambda, object with an async ``__call__``) may still
            # hand back a coroutine. Resolve it through the portal instead of
            # storing an un-awaited coroutine object as the prop value.
            if inspect.iscoroutine(outcome):
                pending = outcome
                with self.with_portal(portal) as p:
                    outcome = p.call(lambda: pending)

        # Only mark as evaluated once a result was produced, so a failing
        # callable can be retried on the next render.
        self._result = cast("Optional[T]", outcome)
        self._evaluated = True
        return self._result
def is_lazy_prop(value: "Any") -> "TypeGuard[Union[DeferredProp[Any, Any], StaticProp[Any, Any]]]":
    """Tell whether ``value`` is one of the lazy prop wrappers.

    Args:
        value: Any value to check.

    Returns:
        bool: True if ``value`` is a ``DeferredProp`` or a ``StaticProp``.
    """
    lazy_types = (DeferredProp, StaticProp)
    return isinstance(value, lazy_types)
def should_render(value: "Any", partial_data: "Optional[set[str]]" = None) -> "bool":
    """Decide whether ``value`` belongs in the rendered output.

    Args:
        value: Any value to check.
        partial_data: Optional set of keys requested for partial rendering.

    Returns:
        bool: Always True for ordinary values. For lazy props, True only when
        the prop's key is listed in ``partial_data``.
    """
    if not is_lazy_prop(value):
        return True

    requested = partial_data if partial_data else set()
    return value.key in requested
def is_or_contains_lazy_prop(value: "Any") -> "bool":
    """Search ``value`` recursively for a lazy prop.

    Args:
        value: Any value to check.

    Returns:
        bool: True if ``value`` is a lazy prop, or is a mapping/iterable that
        holds one at any depth. Strings are never descended into.
    """
    if is_lazy_prop(value):
        return True
    if isinstance(value, str):
        return False

    # Mappings are searched by value; any other iterable by element.
    if isinstance(value, Mapping):
        children: "Iterable[Any]" = cast("Mapping[str, Any]", value).values()
    elif isinstance(value, Iterable):
        children = cast("Iterable[Any]", value)
    else:
        return False

    for child in children:
        if is_or_contains_lazy_prop(child):
            return True
    return False
def lazy_render(
    value: "T", partial_data: "Optional[set[str]]" = None, portal: "Optional[BlockingPortal]" = None
) -> "T":
    """Render requested lazy props inside ``value`` and drop the rest.

    Mappings, lists and tuples are processed recursively; lazy props whose key
    is not in ``partial_data`` are removed from those containers.

    Args:
        value: The value to filter.
        partial_data: Keys requested for partial rendering.
        portal: Optional portal to use for async rendering.

    Returns:
        The filtered value.
    """
    wanted = partial_data or set()

    if isinstance(value, str):
        return cast("T", value)

    if isinstance(value, Mapping):
        rendered: "dict[str, Any]" = {}
        for name, item in cast("Mapping[str, Any]", value).items():
            if should_render(item, wanted):
                rendered[name] = lazy_render(item, wanted, portal)
        return cast("T", rendered)

    if isinstance(value, (list, tuple)):
        kept: "list[Any]" = []
        for item in cast("Iterable[Any]", value):
            if should_render(item, wanted):
                kept.append(lazy_render(item, wanted, portal))
        # Rebuild the same sequence type that was passed in.
        return cast("T", type(value)(kept))  # pyright: ignore[reportUnknownArgumentType]

    if is_lazy_prop(value) and should_render(value, wanted):
        return cast("T", value.render(portal))

    return cast("T", value)
def get_shared_props(
    request: "ASGIConnection[Any, Any, Any, Any]",
    partial_data: "Optional[set[str]]" = None,
) -> "dict[str, Any]":
    """Collect the shared props for a request.

    Pops ``_errors``, ``_shared`` and ``_messages`` from the session, renders
    the shared lazy props that were requested, and merges in the extra static
    and session props configured on the Inertia plugin. If the session is
    unavailable a warning is logged and whatever was gathered so far is kept.

    Args:
        request: The ASGI connection.
        partial_data: Optional set of keys for partial rendering.

    Returns:
        Dict[str, Any]: The shared props, always including ``flash``,
        ``errors`` and ``csrf_token``.

    Note:
        Flash messages are removed from the session by this call, so invoke it
        before anything else that needs to read them.
    """
    props: "dict[str, Any]" = {}
    flash: "dict[str, list[str]]" = defaultdict(list)
    errors: "dict[str, Any]" = {}
    error_bag = request.headers.get("X-Inertia-Error-Bag", None)

    try:
        errors = request.session.pop("_errors", {})
        shared = cast("dict[str,Any]", request.session.pop("_shared", {}))
        plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))

        # Lazy props are rendered only when requested; plain values always pass through.
        for name, shared_value in shared.items():
            if not should_render(shared_value, partial_data):
                continue
            if is_lazy_prop(shared_value):
                props[name] = shared_value.render(plugin.portal)
            else:
                props[name] = shared_value

        pending_messages = cast("list[dict[str,Any]]", request.session.pop("_messages", []))
        for message in pending_messages:
            flash[message["category"]].append(message["message"])

        props.update(plugin.config.extra_static_page_props)

        # Session-backed props never override a prop that is already set.
        for session_prop in plugin.config.extra_session_page_props:
            if session_prop in props or session_prop not in request.session:
                continue
            props[session_prop] = request.session.get(session_prop)
    except (AttributeError, ImproperlyConfiguredException):
        msg = "Unable to generate all shared props.  A valid session was not found for this request."
        request.logger.warning(msg)

    props["flash"] = flash
    if error_bag is None:
        props["errors"] = errors
    else:
        # Namespace the errors under the bag named by the request header.
        props["errors"] = {error_bag: errors}
    props["csrf_token"] = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
    return props
def share(
    connection: "ASGIConnection[Any, Any, Any, Any]",
    key: "str",
    value: "Any",
) -> "None":
    """Store a value under the session's ``_shared`` mapping.

    A warning is logged, and nothing is stored, when no usable session exists.

    Args:
        connection: The ASGI connection.
        key: The key to store the value under.
        value: The value to store.
    """
    try:
        shared_bucket = connection.session.setdefault("_shared", {})
        shared_bucket.update({key: value})
    except (AttributeError, ImproperlyConfiguredException):
        msg = "Unable to set `share` session state.  A valid session was not found for this request."
        connection.logger.warning(msg)
def error(
    connection: "ASGIConnection[Any, Any, Any, Any]",
    key: "str",
    message: "str",
) -> "None":
    """Record an error message under the session's ``_errors`` mapping.

    A warning is logged, and nothing is stored, when no usable session exists.

    Args:
        connection: The ASGI connection.
        key: The key to store the error under.
        message: The error message.
    """
    try:
        error_bucket = connection.session.setdefault("_errors", {})
        error_bucket.update({key: message})
    except (AttributeError, ImproperlyConfiguredException):
        msg = "Unable to set `error` session state.  A valid session was not found for this request."
        connection.logger.warning(msg)
@lru_cache
def _escape_json_for_js_literal(js_routes: "str") -> "Markup":
    """Escape a JSON document for embedding in a single-quoted JS string inside ``<script>``.

    Defined at module level so the ``lru_cache`` persists across calls; a
    cached closure rebuilt on every call would never produce a cache hit.

    Args:
        js_routes: The JSON text to escape.

    Returns:
        The escaped text, marked as safe markup.
    """
    js = (
        # Backslashes go first: the JS string literal consumes one level of
        # escaping before ``JSON.parse`` runs, so JSON escapes such as ``\"``
        # or ``\n`` must be doubled to reach the parser intact.
        js_routes.replace("\\", "\\\\")
        # These are emitted as JS unicode escapes so the payload can neither
        # close the ``<script>`` element nor terminate the quoted literal.
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
        .replace("'", "\\u0027")
    )
    return Markup(js)


def js_routes_script(js_routes: "Routes") -> "Markup":
    """Build the ``<script>`` element that exposes the routes as ``globalThis.routes``.

    Args:
        js_routes: Routes object whose ``formatted_routes`` attribute holds the
            JSON text to embed.

    Returns:
        Markup for a module script that parses the embedded JSON at load time.
    """
    return Markup(
        dedent(f"""
        <script type="module">
        globalThis.routes = JSON.parse('{_escape_json_for_js_literal(js_routes.formatted_routes)}')
        </script>
        """),
    )