from __future__ import annotations
import abc
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable
from litestar.enums import ScopeType
from litestar.middleware._utils import (
build_exclude_path_pattern,
should_bypass_middleware,
)
from litestar.utils.deprecation import warn_deprecation
# Public API of this module: the names exported by a star-import.
__all__ = (
"ASGIMiddleware",
"AbstractMiddleware",
"DefineMiddleware",
"MiddlewareProtocol",
)
# These names only appear in annotations, so they are imported for static
# type checkers alone and are skipped when the module is executed.
if TYPE_CHECKING:
    from litestar.middleware.constraints import MiddlewareConstraints
    from litestar.types import RouteHandlerType, Scopes
    from litestar.types.asgi_types import ASGIApp, Receive, Scope, Send
@runtime_checkable
class MiddlewareProtocol(Protocol):
    """Structural (duck-typed) interface of an ASGI middleware.

    An object conforms when it exposes an ``app`` attribute and is callable
    with the ASGI ``(scope, receive, send)`` triple. Because the protocol is
    ``runtime_checkable``, ``isinstance`` checks against it only verify that
    these members exist, not their signatures.
    """

    __slots__ = ("app",)

    app: ASGIApp

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Execute the ASGI middleware.

        Called by the previous middleware in the stack if a response is not
        awaited prior. Upon completion, the middleware should call the next
        ASGI handler and await it - or await a response created in its closure.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None
        """
        ...
class DefineMiddleware:
    """Bundle a middleware class or factory with the arguments to call it with.

    This makes it possible to pass ``*args`` and ``**kwargs`` to middleware
    constructors and factory functions; the call is deferred until the
    instance itself is called with the next ASGI app.
    """

    __slots__ = ("args", "kwargs", "middleware")

    def __init__(self, middleware: Callable[..., ASGIApp], *args: Any, **kwargs: Any) -> None:
        """Store the callable and its arguments for later use.

        Args:
            middleware: A callable that returns an ASGIApp.
            *args: Positional arguments to pass to the callable.
            **kwargs: Keyword arguments to pass to the callable.

        Notes:
            The callable will be passed a kwarg ``app``, which is the next
            ASGI app to call in the middleware stack. It therefore must
            define such a kwarg.
        """
        self.middleware, self.args, self.kwargs = middleware, args, kwargs

    def __call__(self, app: ASGIApp) -> ASGIApp:
        """Invoke the stored constructor or factory.

        Args:
            app: The next ASGI handler to call in the middleware stack. It is
                forwarded to the stored callable as the keyword argument
                ``app``, alongside the stored positional and keyword
                arguments.

        Returns:
            The ASGIApp produced by the stored callable.
        """
        factory = self.middleware
        return factory(*self.args, app=app, **self.kwargs)
class AbstractMiddleware:
    """Base class providing the functionality shared by middlewares for
    dynamically engaging or bypassing them based on paths, ``opt``-keys and
    scope types.

    Every subclass has its ``__call__`` wrapped at class-creation time so the
    bypass check runs before the subclass's own logic. Defining a subclass
    outside of Litestar emits a deprecation warning pointing to
    ``litestar.middleware.ASGIMiddleware``.
    """

    scopes: Scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET}
    exclude: str | list[str] | None = None
    exclude_opt_key: str | None = None

    def __init__(
        self,
        app: ASGIApp,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str | None = None,
        scopes: Scopes | None = None,
    ) -> None:
        """Initialize the middleware.

        Each optional argument falls back to the class-level attribute of the
        same name when it is falsy.

        Args:
            app: The ``next`` ASGI app to call.
            exclude: A pattern or list of patterns to match against a
                request's path. If a match is found, the middleware will be
                skipped.
            exclude_opt_key: An identifier that is set in the route handler
                ``opt`` key which allows skipping the middleware.
            scopes: ASGI scope types, should be a set including either or
                both 'ScopeType.HTTP' and 'ScopeType.WEBSOCKET'.
        """
        self.app = app
        self.scopes = scopes or self.scopes
        self.exclude_opt_key = exclude_opt_key or self.exclude_opt_key

        effective_exclude = exclude or self.exclude
        self.exclude_pattern = build_exclude_path_pattern(
            exclude=effective_exclude,
            middleware_cls=type(self),
        )

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Warn about deprecated usage and install the bypass wrapper.

        Args:
            **kwargs: Class keyword arguments, forwarded to the parent
                ``__init_subclass__``.
        """
        # we don't want to warn about usage of 'AbstractMiddleware' if users aren't
        # directly subclassing it, i.e. they're subclassing another Litestar
        # middleware which itself subclasses 'AbstractMiddleware'
        has_litestar_ancestor = any(
            c.__module__.startswith("litestar") and c is not AbstractMiddleware for c in cls.mro()
        )
        if not has_litestar_ancestor:
            warn_deprecation(
                version="2.15",
                deprecated_name="AbstractMiddleware",
                kind="class",
                alternative="litestar.middleware.ASGIMiddleware",
            )

        super().__init_subclass__(**kwargs)

        # Capture whatever ``__call__`` currently resolves to on the new class;
        # the wrapper below delegates to it when the middleware is not bypassed.
        inner_call = cls.__call__

        async def wrapped_call(self: AbstractMiddleware, scope: Scope, receive: Receive, send: Send) -> None:
            bypass = should_bypass_middleware(
                scope=scope,
                scopes=self.scopes,
                exclude_path_pattern=self.exclude_pattern,
                exclude_opt_key=self.exclude_opt_key,
            )
            if bypass:
                # Skip this middleware entirely and hand over to the next app.
                await self.app(scope, receive, send)
                return
            await inner_call(self, scope, receive, send)  # pyright: ignore

        # https://github.com/python/mypy/issues/2427#issuecomment-384229898
        setattr(cls, "__call__", wrapped_call)

    @abstractmethod
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Execute the ASGI middleware.

        Called by the previous middleware in the stack if a response is not
        awaited prior. Upon completion, the middleware should call the next
        ASGI handler and await it - or await a response created in its closure.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("abstract method must be implemented")
class ASGIMiddleware(abc.ABC):
    """An abstract base class to easily construct ASGI middlewares, providing
    functionality to dynamically skip the middleware based on ASGI ``scope["type"]``,
    handler ``opt`` keys or path patterns and a simple way to pass configuration to
    middlewares.

    This base class does not implement an ``__init__`` method, so subclasses are free
    to use it to customize the middleware's configuration.

    .. important::
        An instance of the individual middlewares will be created *once* and used to
        build up the internal middleware stack. As such, middlewares should *not* be
        stateful, as this state will be shared across all requests.
        Any connection-specific state should be scoped to the ``handle`` implementation.
        Not doing so would typically lead to conflicting variable reads / writes across
        requests, and - most likely - bugs.

    .. code-block:: python

        class MyMiddleware(ASGIMiddleware):
            scopes = (ScopeType.HTTP,)
            exclude_path_pattern = ("/not/this/path",)
            exclude_opt_key = "exclude_my_middleware"

            def __init__(self, my_logger: Logger) -> None:
                self.logger = my_logger

            async def handle(
                self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp
            ) -> None:
                self.logger.debug("Received request for path %s", scope["path"])
                await next_app(scope, receive, send)
                self.logger.debug("Processed request for path %s", scope["path"])


        app = Litestar(..., middleware=[MyMiddleware(my_logger=my_logger)])

    .. versionadded:: 2.15
    """

    scopes: tuple[ScopeType, ...] = (
        ScopeType.HTTP,
        ScopeType.WEBSOCKET,
        ScopeType.ASGI,
    )
    """Scope types this middleware should be applied to"""

    exclude_path_pattern: str | tuple[str, ...] | None = None
    r"""
    A regex pattern (or tuple of patterns) to exclude this middleware from route
    handlers whose path matches any of the provided patterns.

    .. important::
        Pattern matching is performed against the **handler's path** (e.g.,
        ``/user/{user_id:int}/``), NOT against the actual **request path** (e.g.,
        ``/user/1234/``). This is a critical distinction for dynamic routes.

        If you need to exclude based on paths dynamically, use
        :attr:`~litestar.middleware.ASGIMiddleware.should_bypass_for_scope`
        instead, matching on ``scope["path"]``.

    **Example 1: Static path**

    Handler path::

        /api/health

    To exclude this handler, use a pattern like::

        exclude_path_pattern = r"^/api/health$"

    **Example 2: Dynamic path (path parameters)**

    Handler path::

        /user/{user_id:int}/profile
              └─────┬─────┘
                    └─ This is what the pattern matches against

    Actual request paths that match this handler::

        /user/1234/profile
        /user/5678/profile
        /user/9999/profile

    To exclude this handler, the pattern must match the **handler**, not the actual request path::

        exclude_path_pattern = "/user/{user_id:int}/profile"
        exclude_path_pattern = "/user/\{.+?\}/"
    """

    exclude_opt_key: str | None = None
    """
    Exclude this middleware for handlers with an opt-key of this name that is truthy
    """

    should_bypass_for_scope: Callable[[Scope], bool] | None = None
    r"""
    A callable that takes in the :class:`~litestar.types.Scope` of the current
    connection and returns a boolean, indicating if the middleware should be skipped for
    the current request.

    This can for example be used to exclude a middleware based on a dynamic path::

        should_bypass_for_scope = staticmethod(
            lambda scope: scope["path"].endswith(".jpg")
        )

    Applied to a route with a dynamic path like ``/static/{file_name:str}``, it would
    be skipped *only* if ``file_name`` has a ``.jpg`` extension.

    .. important::
        The callable is looked up on the middleware *instance* and invoked with the
        scope as its only argument. A plain function or lambda assigned directly in
        the class body would therefore be bound as a method and receive the instance
        as an additional first argument. Either wrap it in ``staticmethod`` (as
        above), define it as a regular method accepting ``self`` and ``scope``, or
        assign it to the instance in ``__init__``.

    .. note::
        If it is not required to dynamically match the path of a request,
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_path_pattern` should be
        used instead. Since its exclusion is done statically at startup time, it has no
        performance cost at runtime.

    .. versionadded:: 3.0
    """

    # Optional constraints object for this middleware. It is not read anywhere in
    # this class - presumably consumed by the application when it assembles the
    # middleware stack (TODO confirm against litestar.middleware.constraints).
    constraints: MiddlewareConstraints | None = None

    def should_bypass_for_handler(self, handler: RouteHandlerType) -> bool:
        """Return ``True`` if this middleware should be bypassed for ``handler``, according
        to :attr:`~litestar.middleware.ASGIMiddleware.scopes`,
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_path_pattern` or
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_opt_key`, otherwise ``False``.

        Args:
            handler: The route handler to check this middleware against.

        Returns:
            ``True`` when the handler's kind is not covered by ``scopes``, when the
            handler's ``opt`` holds a truthy value under ``exclude_opt_key``, or when
            any of the handler's paths matches ``exclude_path_pattern``.
        """
        # Imported locally rather than at module level - presumably to avoid a
        # circular import between the handlers and middleware packages.
        from litestar.handlers import ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler

        # 1. Scope type: skip if the handler's kind is not listed in ``scopes``.
        if isinstance(handler, HTTPRouteHandler) and ScopeType.HTTP not in self.scopes:
            return True

        if isinstance(handler, WebsocketRouteHandler) and ScopeType.WEBSOCKET not in self.scopes:
            return True

        if isinstance(handler, ASGIRouteHandler) and ScopeType.ASGI not in self.scopes:
            return True

        # 2. Opt key: skip if the handler opted out with a truthy value.
        if self.exclude_opt_key and handler.opt.get(self.exclude_opt_key):
            return True

        # 3. Path pattern: matched against the handler's declared paths, not
        #    against the path of an incoming request.
        pattern = build_exclude_path_pattern(exclude=self.exclude_path_pattern, middleware_cls=type(self))
        if pattern and any(pattern.search(path) for path in handler.paths):
            return True

        return False

    def __call__(self, app: ASGIApp) -> ASGIApp:
        """Create the actual middleware callable.

        ``handle`` and ``should_bypass_for_scope`` are read from the instance once,
        here, and captured by the returned closure. Whether a per-request bypass
        check is needed is decided at this point as well, so the closure built when
        no ``should_bypass_for_scope`` is configured performs no check at all.

        Args:
            app: The next ASGI application in the middleware stack.

        Returns:
            An ASGI callable that either delegates to :meth:`handle` (passing ``app``
            as ``next_app``) or, when ``should_bypass_for_scope`` returns a truthy
            value for the scope, calls ``app`` directly.
        """
        handle = self.handle
        should_bypass_for_scope = self.should_bypass_for_scope

        if should_bypass_for_scope is None:

            async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
                await handle(scope=scope, receive=receive, send=send, next_app=app)

        else:

            async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
                if should_bypass_for_scope(scope):
                    await app(scope, receive, send)
                else:
                    await handle(scope=scope, receive=receive, send=send, next_app=app)

        return middleware

    @abc.abstractmethod
    async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
        """Handle ASGI call.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function
            next_app: The next ASGI application in the middleware stack to call

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError