# Source code for litestar.middleware.base

from __future__ import annotations

import abc
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable

from litestar.enums import ScopeType
from litestar.middleware._utils import (
    build_exclude_path_pattern,
    should_bypass_middleware,
)
from litestar.utils.deprecation import warn_deprecation

__all__ = (
    "ASGIMiddleware",
    "AbstractMiddleware",
    "DefineMiddleware",
    "MiddlewareProtocol",
)

if TYPE_CHECKING:
    from litestar.middleware.constraints import MiddlewareConstraints
    from litestar.types import RouteHandlerType, Scopes
    from litestar.types.asgi_types import ASGIApp, Receive, Scope, Send


@runtime_checkable
class MiddlewareProtocol(Protocol):
    """Structural interface of an ASGI middleware.

    An object satisfies this protocol when it exposes an ``app`` attribute and an
    async ``__call__`` accepting ``(scope, receive, send)``. Because the protocol is
    decorated with :func:`typing.runtime_checkable`, it can be used with
    ``isinstance``.
    """

    __slots__ = ("app",)

    # The ASGI application this middleware wraps.
    app: ASGIApp

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the middleware for a single ASGI call.

        Invoked by the previous middleware in the stack unless that middleware has
        already awaited a response. Once its own work is done, an implementation
        should await the next ASGI handler - or await a response it created itself.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None
        """
class DefineMiddleware:
    """Bundle a middleware class or factory with the arguments it should be built with.

    The stored ``*args`` / ``**kwargs`` are applied later, when the instance is called
    with the next ASGI app in the stack.
    """

    __slots__ = ("args", "kwargs", "middleware")

    def __init__(self, middleware: Callable[..., ASGIApp], *args: Any, **kwargs: Any) -> None:
        """Store the callable and the arguments to forward to it.

        Args:
            middleware: A callable that returns an ASGIApp.
            *args: Positional arguments forwarded to ``middleware``.
            **kwargs: Keyword arguments forwarded to ``middleware``.

        Notes:
            ``middleware`` is additionally called with a keyword argument ``app`` - the
            next ASGI app in the middleware stack - so it has to accept that keyword.
        """
        self.middleware = middleware
        self.args, self.kwargs = args, kwargs

    def __call__(self, app: ASGIApp) -> ASGIApp:
        """Build the middleware around ``app``.

        Args:
            app: The next ASGI handler to call in the middleware stack.

        Returns:
            Whatever :class:`DefineMiddleware.middleware <.DefineMiddleware>` returns when
            called with the stored arguments plus ``app=app``.
        """
        factory = self.middleware
        return factory(*self.args, app=app, **self.kwargs)
class AbstractMiddleware:
    """Base class for middlewares that can be bypassed depending on the connection.

    Subclasses provide ``__call__``. On subclass creation that ``__call__`` is replaced
    by a wrapper which first asks ``should_bypass_middleware`` - passing the configured
    scope types, exclude path pattern and ``opt`` key - whether the middleware should be
    skipped; if so, the next ASGI app is awaited directly instead.

    Subclassing this class without any other class from a ``litestar*`` module in the
    MRO emits a deprecation warning (version 2.15) that names
    ``litestar.middleware.ASGIMiddleware`` as the alternative.
    """

    scopes: Scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET}
    exclude: str | list[str] | None = None
    exclude_opt_key: str | None = None

    def __init__(
        self,
        app: ASGIApp,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str | None = None,
        scopes: Scopes | None = None,
    ) -> None:
        """Set up the middleware instance.

        Every optional argument falls back to the class attribute of the same name
        when it is falsy.

        Args:
            app: The ``next`` ASGI app to call.
            exclude: A pattern or list of patterns to match against a request's path.
                If a match is found, the middleware will be skipped.
            exclude_opt_key: An identifier that is set in the route handler
                ``opt`` key which allows skipping the middleware.
            scopes: ASGI scope types, should be a set including
                either or both 'ScopeType.HTTP' and 'ScopeType.WEBSOCKET'.
        """
        self.app = app
        self.scopes = scopes if scopes else self.scopes
        self.exclude_opt_key = exclude_opt_key if exclude_opt_key else self.exclude_opt_key
        effective_exclude = exclude if exclude else self.exclude
        self.exclude_pattern = build_exclude_path_pattern(exclude=effective_exclude, middleware_cls=type(self))

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Only warn when the user subclasses 'AbstractMiddleware' directly. If any class
        # in the MRO (other than 'AbstractMiddleware' itself) lives in a 'litestar*'
        # module, the user is building on another Litestar middleware and is not warned.
        builds_on_litestar_middleware = any(
            klass is not AbstractMiddleware and klass.__module__.startswith("litestar") for klass in cls.mro()
        )
        if not builds_on_litestar_middleware:
            warn_deprecation(
                version="2.15",
                deprecated_name="AbstractMiddleware",
                kind="class",
                alternative="litestar.middleware.ASGIMiddleware",
            )

        super().__init_subclass__(**kwargs)

        # Whatever '__call__' resolves to on the new class becomes the wrapped target.
        subclass_call = cls.__call__

        async def wrapped_call(self: AbstractMiddleware, scope: Scope, receive: Receive, send: Send) -> None:
            skip = should_bypass_middleware(
                scope=scope,
                scopes=self.scopes,
                exclude_path_pattern=self.exclude_pattern,
                exclude_opt_key=self.exclude_opt_key,
            )
            if not skip:
                await subclass_call(self, scope, receive, send)  # pyright: ignore
                return
            # bypassed: hand the call straight to the next ASGI app
            await self.app(scope, receive, send)

        # 'setattr' instead of plain assignment, see
        # https://github.com/python/mypy/issues/2427#issuecomment-384229898
        setattr(cls, "__call__", wrapped_call)

    @abstractmethod
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the middleware for a single ASGI call.

        Invoked by the previous middleware in the stack unless that middleware has
        already awaited a response. Once its own work is done, an implementation
        should await the next ASGI handler - or await a response it created itself.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("abstract method must be implemented")
class ASGIMiddleware(abc.ABC):
    """An abstract base class to easily construct ASGI middlewares, providing functionality
    to dynamically skip the middleware based on ASGI ``scope["type"]``, handler ``opt``
    keys or path patterns and a simple way to pass configuration to middlewares.

    This base class does not implement an ``__init__`` method, so subclasses are free
    to use it to customize the middleware's configuration.

    .. important::
        An instance of the individual middlewares will be created *once* and used to
        build up the internal middleware stack. As such, middlewares should *not* be
        stateful, as this state will be shared across all requests. Any
        connection-specific state should be scoped to the ``handle`` implementation.
        Not doing so would typically lead to conflicting variable reads / writes
        across requests, and - most likely - bugs.

    .. code-block:: python

        class MyMiddleware(ASGIMiddleware):
            scopes = (ScopeType.HTTP,)
            exclude_path_pattern = ("/not/this/path",)
            exclude_opt_key = "exclude_my_middleware"

            def __init__(self, my_logger: Logger) -> None:
                self.logger = my_logger

            async def handle(
                self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp
            ) -> None:
                self.logger.debug("Received request for path %s", scope["path"])
                await next_app(scope, receive, send)
                self.logger.debug("Processed request for path %s", scope["path"])


        app = Litestar(..., middleware=[MyMiddleware(my_logger=my_logger)])

    .. versionadded:: 2.15
    """

    scopes: tuple[ScopeType, ...] = (
        ScopeType.HTTP,
        ScopeType.WEBSOCKET,
        ScopeType.ASGI,
    )
    """Scope types this middleware should be applied to"""

    exclude_path_pattern: str | tuple[str, ...] | None = None
    r"""
    A regex pattern (or tuple of patterns) to exclude this middleware from route handlers
    whose path matches any of the provided patterns.

    .. important::
        Pattern matching is performed against the **handler's path** (e.g.,
        ``/user/{user_id:int}/``), NOT against the actual **request path** (e.g.,
        ``/user/1234/``).

        This is a critical distinction for dynamic routes. If you need to exclude based
        on paths dynamically, use
        :attr:`~litestar.middleware.ASGIMiddleware.should_bypass_for_scope` instead,
        matching on ``scope["path"]``.

    **Example 1: Static path**

    Handler path::

        /api/health

    To exclude this handler, use a pattern like::

        exclude_path_pattern = r"^/api/health$"

    **Example 2: Dynamic path (path parameters)**

    Handler path::

        /user/{user_id:int}/profile
              └─────┬──────┘
                    └─ This is what the pattern matches against

    Actual request paths that match this handler::

        /user/1234/profile
        /user/5678/profile
        /user/9999/profile

    To exclude this handler, the pattern must match the **handler**, not the actual
    request path::

        exclude_path_pattern = "/user/{user_id:int}/profile"
        exclude_path_pattern = "/user/\{.+?\}/"
    """

    exclude_opt_key: str | None = None
    """
    Exclude this middleware for handlers with an opt-key of this name that is truthy
    """

    should_bypass_for_scope: Callable[[Scope], bool] | None = None
    r"""
    A callable that takes in the :class:`~litestar.types.Scope` of the current
    connection and returns a boolean, indicating if the middleware should be skipped for
    the current request.

    This can for example be used to exclude a middleware based on a dynamic path::

        should_bypass_for_scope = lambda scope: scope["path"].endswith(".jpg")

    Applied to a route with a dynamic path like ``/static/{file_name:str}``, it would
    be skipped *only* if ``file_name`` has a ``.jpg`` extension.

    .. note::
        If it is not required to dynamically match the path of a request,
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_path_pattern` should be used
        instead. Since its exclusion is done statically at startup time, it has no
        performance cost at runtime.

    .. versionadded:: 3.0
    """

    # Optional constraints object; it is not read by any code visible in this module.
    constraints: MiddlewareConstraints | None = None

    def should_bypass_for_handler(self, handler: RouteHandlerType) -> bool:
        """Return ``True`` if this middleware should be bypassed for ``handler``, according
        to :attr:`~litestar.middleware.ASGIMiddleware.scopes`,
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_path_pattern` or
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_opt_key`, otherwise ``False``.

        Args:
            handler: The route handler to check. Its type, ``opt`` mapping and ``paths``
                are inspected.

        Returns:
            ``True`` when the handler's scope type is missing from ``scopes``, when the
            handler's ``opt`` holds a truthy value under ``exclude_opt_key``, or when
            any of the handler's paths matches ``exclude_path_pattern``; ``False``
            otherwise.
        """
        # NOTE(review): imported at call time - presumably to avoid a circular import
        # at module load; confirm before moving it to module level.
        from litestar.handlers import ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler

        # 1. the handler's kind is not among the scope types this middleware serves
        if isinstance(handler, HTTPRouteHandler) and ScopeType.HTTP not in self.scopes:
            return True
        if isinstance(handler, WebsocketRouteHandler) and ScopeType.WEBSOCKET not in self.scopes:
            return True
        if isinstance(handler, ASGIRouteHandler) and ScopeType.ASGI not in self.scopes:
            return True

        # 2. the handler opted out through its ``opt`` mapping
        if self.exclude_opt_key and handler.opt.get(self.exclude_opt_key):
            return True

        # 3. one of the handler's declared paths (not the request path) matches
        pattern = build_exclude_path_pattern(exclude=self.exclude_path_pattern, middleware_cls=type(self))
        return bool(pattern and any(pattern.search(path) for path in handler.paths))

    def __call__(self, app: ASGIApp) -> ASGIApp:
        """Create the actual middleware callable

        Args:
            app: The next ASGI application in the middleware stack.

        Returns:
            An async ASGI callable that forwards each call to :meth:`handle` with
            ``next_app=app`` or, when ``should_bypass_for_scope`` is set and returns a
            truthy value for the scope, awaits ``app`` directly.
        """
        handle = self.handle
        # Looked up through the instance: a plain function or lambda stored directly in
        # a class body arrives here as a *bound method* and is called with ``self``
        # first. A one-argument callable must be wrapped in ``staticmethod`` or assigned
        # on the instance (e.g. in ``__init__``).
        should_bypass_for_scope = self.should_bypass_for_scope

        # The ``None`` check happens once here, not inside the closure on every call.
        if should_bypass_for_scope is None:

            async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
                await handle(scope=scope, receive=receive, send=send, next_app=app)

        else:

            async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
                if should_bypass_for_scope(scope):
                    await app(scope, receive, send)
                else:
                    await handle(scope=scope, receive=receive, send=send, next_app=app)

        return middleware

    @abc.abstractmethod
    async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
        """Handle ASGI call.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function
            next_app: The next ASGI application in the middleware stack to call

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError