Implementing Custom Authentication

Litestar exports AbstractAuthenticationMiddleware, which is an abstract base class (ABC) that implements the MiddlewareProtocol. To add authentication to your app using this class as a basis, subclass it and implement the abstract method authenticate_request():

Adding authentication to your app by subclassing AbstractAuthenticationMiddleware
from litestar.connection import ASGIConnection
from litestar.middleware import (
    AbstractAuthenticationMiddleware,
    AuthenticationResult,
)


class MyAuthenticationMiddleware(AbstractAuthenticationMiddleware):
    async def authenticate_request(
        self, connection: ASGIConnection
    ) -> AuthenticationResult:
        # Inspect the connection (headers, cookies, query params, ...), then
        # return AuthenticationResult(user=..., auth=...) or raise
        # NotAuthorizedException if the request cannot be authenticated.
        ...

As you can see, authenticate_request is an async method that receives a connection instance and must return an AuthenticationResult instance. AuthenticationResult is a dataclass with two attributes:

  1. user: a required value representing the user. It is typed as Any, so it accepts any value, including None.

  2. auth: an optional value representing the authentication scheme. Defaults to None.

These values are then stored in the ASGI scope dictionary. HTTP route handlers can read them as Request.user and Request.auth, and websocket route handlers can read them as WebSocket.user and WebSocket.auth.
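The following is a deliberately simplified, stdlib-only model of the flow. It is not Litestar's implementation. `SimpleAuthMiddleware`, `AuthResult`, and the plain-dict `scope` are hypothetical stand-ins. The base class calls the abstract `authenticate_request()` hook, writes `user` and `auth` into the scope, and then hands control to the wrapped app.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict

Scope = Dict[str, Any]
App = Callable[[Scope], Awaitable[None]]


@dataclass
class AuthResult:
    """Hypothetical stand-in for litestar's AuthenticationResult."""

    user: Any  # required, but may be None
    auth: Any = None  # optional: token, session, scheme, ...


class SimpleAuthMiddleware(ABC):
    """Simplified model: authenticate, store the result in the scope, call the next app."""

    def __init__(self, app: App) -> None:
        self.app = app

    @abstractmethod
    async def authenticate_request(self, scope: Scope) -> AuthResult:
        """Subclasses decide who the user is."""

    async def __call__(self, scope: Scope) -> None:
        result = await self.authenticate_request(scope)
        # Handlers later read these keys (Litestar exposes them as .user / .auth).
        scope["user"] = result.user
        scope["auth"] = result.auth
        await self.app(scope)


class HeaderAuth(SimpleAuthMiddleware):
    """Hypothetical subclass that trusts an 'x-user' header (illustration only)."""

    async def authenticate_request(self, scope: Scope) -> AuthResult:
        return AuthResult(user=scope["headers"].get("x-user"), auth="header")


async def downstream(scope: Scope) -> None:
    # Stand-in for a route handler reading what the middleware stored.
    print(scope["user"], scope["auth"])


scope: Scope = {"headers": {"x-user": "alice"}}
asyncio.run(HeaderAuth(downstream)(scope))  # prints: alice header
```

Because `authenticate_request` is abstract, `SimpleAuthMiddleware` itself cannot be instantiated. This mirrors the requirement that a subclass must implement the hook.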

Creating a Custom Authentication Middleware

Since this is hard to grasp in the abstract, let us walk through an example.

We start off by creating a user model and a token model. They can be implemented using msgspec, Pydantic, an ODM, an ORM, etc. For this example, we use plain dataclasses:

user and token models
from dataclasses import dataclass


@dataclass
class MyUser:
    name: str


@dataclass
class MyToken:
    api_key: str

We can now create our authentication middleware:

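authentication_middleware.py
from litestar.connection import ASGIConnection
from litestar.exceptions import NotAuthorizedException
from litestar.middleware import AbstractAuthenticationMiddleware, AuthenticationResult

# Illustrative values: the header name and this in-memory "database" stand in
# for your real configuration and user store. MyUser and MyToken are the
# models defined above.
API_KEY_HEADER = "X-API-KEY"
TOKEN_USER_DATABASE = {"1": "user_authorized"}


class CustomAuthenticationMiddleware(AbstractAuthenticationMiddleware):
    async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
        """Given a request, parse the API key stored in the header and retrieve the user correlating to the token from the DB."""

        # retrieve the auth header
        auth_header = connection.headers.get(API_KEY_HEADER)
        if not auth_header:
            raise NotAuthorizedException()

        # this would be a database call
        token = MyToken(api_key=auth_header)
        user_name = TOKEN_USER_DATABASE.get(token.api_key)
        if user_name is None:
            raise NotAuthorizedException()
        return AuthenticationResult(user=MyUser(name=user_name), auth=token)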
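Without the middleware plumbing, the lookup inside `authenticate_request()` is just a mapping from header to token to user. The sketch below reproduces that logic with the standard library only, so it can be exercised without a running app. `NotAuthorized`, `resolve_user`, and the sample database contents are hypothetical. A real app would replace the dict with a database query and raise Litestar's `NotAuthorizedException` instead.

```python
from dataclasses import dataclass
from typing import Mapping, Optional, Tuple


@dataclass
class MyUser:
    name: str


@dataclass
class MyToken:
    api_key: str


# Hypothetical sample data standing in for a database table.
API_KEY_HEADER = "X-API-KEY"
TOKEN_USER_DATABASE = {"1": "user_authorized"}


class NotAuthorized(Exception):
    """Hypothetical stand-in for litestar.exceptions.NotAuthorizedException."""


def resolve_user(headers: Mapping[str, str]) -> Tuple[MyUser, MyToken]:
    """Mirror of the middleware logic: header -> token -> user."""
    api_key: Optional[str] = headers.get(API_KEY_HEADER)
    if not api_key:
        raise NotAuthorized("missing API key header")
    user_name = TOKEN_USER_DATABASE.get(api_key)
    if user_name is None:
        raise NotAuthorized("unknown API key")
    return MyUser(name=user_name), MyToken(api_key=api_key)


user, token = resolve_user({"X-API-KEY": "1"})
print(user, token)  # MyUser(name='user_authorized') MyToken(api_key='1')
```

Litestar's `connection.headers` is case-insensitive. The plain dict used here is not, so this sketch would reject `{"x-api-key": "1"}`.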

Finally, we wrap the middleware in DefineMiddleware and pass it to the Litestar constructor:

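main.py
from litestar import Litestar
from litestar.di import Provide
from litestar.middleware.base import DefineMiddleware

# site_index, my_http_handler, my_ws_handler and my_dependency are defined in
# the snippets below; CustomAuthenticationMiddleware comes from
# authentication_middleware.py.

# You can optionally exclude certain paths from authentication. `exclude` is a
# regex pattern (or list of patterns) matched against the request path, so
# "schema" skips any path containing "schema", including everything under /schema.
auth_mw = DefineMiddleware(CustomAuthenticationMiddleware, exclude="schema")

app = Litestar(
    route_handlers=[site_index, my_http_handler, my_ws_handler],
    middleware=[auth_mw],
    dependencies={"some_dependency": Provide(my_dependency)},
)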
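This sketch rests on an assumption: that `exclude` is treated as one or more regular-expression patterns searched within the request path. If so, a bare `"schema"` matches any path containing that substring, not only paths that start with `/schema`. The code models that check with the `re` module. `build_exclude_pattern` and `is_excluded` are hypothetical helpers written for illustration, and Litestar's internal matching may differ in detail.

```python
from __future__ import annotations

import re
from typing import Sequence


def build_exclude_pattern(exclude: str | Sequence[str] | None) -> re.Pattern[str] | None:
    """Join one or more patterns into a single alternation regex (hypothetical helper)."""
    if not exclude:
        return None
    patterns = [exclude] if isinstance(exclude, str) else list(exclude)
    return re.compile("|".join(patterns))


def is_excluded(path: str, pattern: re.Pattern[str] | None) -> bool:
    # search() looks for a match anywhere in the path, not only at the start.
    return pattern is not None and pattern.search(path) is not None


loose = build_exclude_pattern("schema")
anchored = build_exclude_pattern("^/schema")

for path in ["/schema", "/schema/openapi.json", "/api/schema-notes", "/users"]:
    print(f"{path:22} loose={is_excluded(path, loose)} anchored={is_excluded(path, anchored)}")
```

With the loose pattern, `/api/schema-notes` is skipped as well. The anchored pattern only skips paths that start with `/schema`, which makes the intent explicit.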

That is it. CustomAuthenticationMiddleware will now run for every request that is not excluded. The user and auth values can be accessed in an HTTP route handler as follows:

Accessing the user and auth in an HTTP route handler with CustomAuthenticationMiddleware
from litestar import Request, get
from litestar.datastructures import State


@get("/me")  # site_index is served at "/", so this handler uses its own path
async def my_http_handler(request: Request[MyUser, MyToken, State]) -> None:
    user = request.user  # correctly typed as MyUser
    auth = request.auth  # correctly typed as MyToken
    assert isinstance(user, MyUser)
    assert isinstance(auth, MyToken)

Or for a websocket route:

Accessing the user and auth in a websocket route handler with CustomAuthenticationMiddleware
from litestar import WebSocket, websocket
from litestar.datastructures import State


@websocket("/")
async def my_ws_handler(socket: WebSocket[MyUser, MyToken, State]) -> None:
    user = socket.user  # correctly typed as MyUser
    auth = socket.auth  # correctly typed as MyToken
    assert isinstance(user, MyUser)
    assert isinstance(auth, MyToken)
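The type parameters in `Request[MyUser, MyToken, State]` and `WebSocket[MyUser, MyToken, State]` are read by static type checkers. At runtime, `.user` is simply whatever the middleware stored. The following stdlib model illustrates this with a hypothetical generic `Connection` class (not Litestar's). The subscript lets a tool such as mypy infer the types, while the `isinstance` checks are what actually verify them when the code runs.

```python
from dataclasses import dataclass
from typing import Any, Dict, Generic, TypeVar

UserT = TypeVar("UserT")
AuthT = TypeVar("AuthT")


@dataclass
class MyUser:
    name: str


@dataclass
class MyToken:
    api_key: str


class Connection(Generic[UserT, AuthT]):
    """Hypothetical generic connection exposing scope values with declared types."""

    def __init__(self, scope: Dict[str, Any]) -> None:
        self.scope = scope

    @property
    def user(self) -> UserT:
        return self.scope["user"]  # not validated: trusts the middleware

    @property
    def auth(self) -> AuthT:
        return self.scope["auth"]


def handler(conn: Connection[MyUser, MyToken]) -> str:
    user = conn.user  # a type checker infers MyUser here
    assert isinstance(user, MyUser)  # the runtime check is separate
    return user.name


print(handler(Connection({"user": MyUser("alice"), "auth": MyToken("1")})))
```

Nothing stops a scope from holding a value of another type (`Connection({"user": 1, "auth": None}).user` returns `1`). The annotations document a contract with your middleware rather than enforcing one.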

To exclude an individual route that the exclude pattern does not cover, pass exclude_from_auth=True to its route decorator:

Excluding individual routes from CustomAuthenticationMiddleware
import anyio

from litestar import MediaType, Response, get
from litestar.exceptions import NotFoundException


@get(path="/", exclude_from_auth=True)
async def site_index() -> Response:
    """Site index"""
    exists = await anyio.Path("index.html").exists()
    if exists:
        async with await anyio.open_file(anyio.Path("index.html")) as file:
            content = await file.read()
            return Response(content=content, status_code=200, media_type=MediaType.HTML)
    raise NotFoundException("Site index was not found")
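This sketch assumes a particular reading of Litestar's design, which is worth checking against the API reference. Under that reading, extra keyword arguments such as `exclude_from_auth=True` are stored in the handler's `opt` mapping. The middleware then skips authentication when the key named by its `exclude_from_auth_key` argument (`"exclude_from_auth"` by default) is truthy. The code models how that per-handler flag combines with the path-based `exclude` pattern. `HandlerInfo` and `should_authenticate` are hypothetical names, not Litestar's API.

```python
from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Any


@dataclass
class HandlerInfo:
    """Hypothetical stand-in for a route handler and its `opt` mapping."""

    path: str
    opt: dict[str, Any] = field(default_factory=dict)


def should_authenticate(
    handler: HandlerInfo,
    exclude_pattern: re.Pattern[str] | None,
    exclude_from_auth_key: str = "exclude_from_auth",  # assumed default key name
) -> bool:
    # 1. Path-level exclusion configured on the middleware.
    if exclude_pattern is not None and exclude_pattern.search(handler.path):
        return False
    # 2. Per-handler opt-out stored in the handler's opt mapping.
    if handler.opt.get(exclude_from_auth_key):
        return False
    return True


pattern = re.compile("schema")
handlers = [
    HandlerInfo("/", opt={"exclude_from_auth": True}),  # like site_index
    HandlerInfo("/profile"),
    HandlerInfo("/schema/openapi.json"),
]
for h in handlers:
    print(h.path, should_authenticate(h, pattern))
```

Either mechanism alone is enough to bypass authentication. The pattern suits groups of paths, and the per-handler flag suits one-off public routes.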

The same typed access also works inside dependencies:

Using CustomAuthenticationMiddleware in a dependency
from typing import Any

from litestar import Request
from litestar.datastructures import State


async def my_dependency(request: Request[MyUser, MyToken, State]) -> Any:
    user = request.user  # correctly typed as MyUser
    auth = request.auth  # correctly typed as MyToken
    assert isinstance(user, MyUser)
    assert isinstance(auth, MyToken)