Adding Additional Features to the Repository

While most of the functionality you need is built into the repository, there are still cases where you need to add more. Let’s explore ways to add functionality on top of the repository pattern.

Tip

The full code for this tutorial can be found below in the Full Code section.

Slug Fields

app.py#
 1from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
 2from sqlalchemy.types import String
 3
 4from litestar.plugins.sqlalchemy import (
 5    base,
 6)
 7
 8
 9@declarative_mixin
10class SlugKey:
11    """Slug unique Field Model Mixin."""
12
13    __abstract__ = True
14    slug: Mapped[str] = mapped_column(String(length=100), nullable=False, unique=True, sort_order=-9)
15
16
class BlogPost(base.UUIDAuditBase, SlugKey):
    """Blog post model: a title and body text, plus the unique ``slug`` from ``SlugKey``."""

    title: Mapped[str]  # not unique; several posts may share a title
    content: Mapped[str]

In this example, we are using a BlogPost model to hold blog post titles and contents. The primary key for this model is a UUID. UUIDs and integers are good options for primary keys, but there are a number of reasons you may not want to use them in your routes. For instance, exposing integer-based primary keys in the URL can be a security problem: sequential IDs are easy to guess and enumerate, and they reveal how many records exist. While UUIDs don’t have this problem, they are not user-friendly or easy to remember, and they make URLs long and complex. One way to solve this is to add a user-friendly unique identifier to the table that can be used in URLs. This is often called a “slug”.

First, we’ll create a SlugKey mixin that adds a text-based, URL-friendly, unique slug column to the table. We want the slug value to be derived from the data passed in the title field. For example, a blog post titled “Follow the Yellow Brick Road!” should get the slug “follow-the-yellow-brick-road”.

app.py#
 1from __future__ import annotations
 2
 3import random
 4import re
 5import string
 6import unicodedata
 7from typing import Any
 8
 9from litestar.plugins.sqlalchemy import (
10    repository,
11)
12
13
class SQLAlchemyAsyncSlugRepository(repository.SQLAlchemyAsyncRepository[repository.ModelT]):
    """Extends the async repository with helpers for models that have a unique ``slug`` column.

    Intended for models using the ``SlugKey`` mixin; lookups filter on ``slug``.
    """

    async def get_available_slug(
        self,
        value_to_slugify: str,
        **kwargs: Any,
    ) -> str:
        """Get a slug for the supplied value that is not yet used in the table.

        The value is slugified first; if that slug is free it is returned after a
        single lookup. Otherwise a random lowercase alphanumeric suffix is appended
        and every candidate is checked again until a free one is found, so a suffix
        collision can no longer hand back a slug that already exists. A value that
        slugifies to an empty string (e.g. only punctuation) gets a random slug
        instead of ``""``.

        The check-then-insert is not atomic: a concurrent insert may still claim the
        returned slug before this one is inserted (for ``SlugKey`` models the
        column's unique constraint then rejects the second insert).

        Args:
            value_to_slugify (str): A string that should be converted to a unique slug.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            str: a slug that was not present when checked. It is safe for URLs and
            other unique identifiers.
        """
        slug = self._slugify(value_to_slugify)
        if slug and await self._is_slug_unique(slug):
            return slug
        suffix_length = 4
        while True:
            random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=suffix_length))
            candidate = f"{slug}-{random_string}" if slug else random_string
            if await self._is_slug_unique(candidate):
                return candidate
            # Collisions should be rare; widening the suffix after each one
            # guarantees progress even if the shorter suffix space gets crowded.
            suffix_length += 1

    @staticmethod
    def _slugify(value: str) -> str:
        """Convert ``value`` into a URL-friendly slug.

        NFKD-normalizes the text and drops any non-ASCII characters, lowercases it,
        removes characters that are not word characters, whitespace, or hyphens,
        collapses runs of whitespace/hyphens into a single hyphen, and strips
        leading/trailing hyphens and underscores.

        Args:
            value (str): the string to slugify

        Returns:
            str: the slugified value; may be empty when ``value`` contains no ASCII
            letters or digits.
        """
        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
        value = re.sub(r"[^\w\s-]", "", value.lower())
        return re.sub(r"[-\s]+", "-", value).strip("-_")

    async def _is_slug_unique(
        self,
        slug: str,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` if no row currently has this ``slug`` (``kwargs`` is unused)."""
        return await self.get_one_or_none(slug=slug) is None

Since the BlogPost.title field is not marked as unique, different posts can produce the same slug, so we have to check the slug for uniqueness before the insert. If the initial slug is already taken, a random four-character alphanumeric suffix is appended to the slug to distinguish it.

app.py#
 1from __future__ import annotations
 2
 3from uuid import UUID
 4
 5from pydantic import BaseModel as _BaseModel
 6from sqlalchemy.orm import Mapped
 7
 8from litestar import post
 9from litestar.plugins.sqlalchemy import (
10    base,
11)
12
13
class BaseModel(_BaseModel):
    """Project-wide Pydantic base model.

    ``from_attributes`` lets ``model_validate`` read values from object attributes,
    e.g. the SQLAlchemy model instances returned by the repository.
    """

    model_config = {"from_attributes": True}
18
19
class BlogPost(base.UUIDAuditBase, SlugKey):
    """Blog post model: a title and body text, plus the unique ``slug`` from ``SlugKey``."""

    title: Mapped[str]  # not unique; several posts may share a title
    content: Mapped[str]
23
24
class BlogPostRepository(SQLAlchemyAsyncSlugRepository[BlogPost]):
    """Slug-aware repository bound to the ``BlogPost`` model."""

    # The mapped class this repository queries and persists.
    model_type = BlogPost
29
30
class BlogPostDTO(BaseModel):
    """Response schema for a blog post, validated from a ``BlogPost`` instance."""

    id: UUID | None
    slug: str
    title: str
    content: str
36
37
class BlogPostCreate(BaseModel):
    """Request body for creating a post; the slug is generated server-side."""

    title: str
    content: str
@post(path="/")
async def create_blog(
    blog_post_repo: BlogPostRepository,
    data: BlogPostCreate,
) -> BlogPostDTO:
    """Create a blog post whose slug is derived from its title, and return it."""
    fields = data.model_dump(exclude_unset=True, by_alias=False, exclude_none=True)
    slug = await blog_post_repo.get_available_slug(fields["title"])
    new_post = BlogPost(**{**fields, "slug": slug})
    saved = await blog_post_repo.add(new_post)
    # Commit so the new row is persisted before it is returned to the client.
    await blog_post_repo.session.commit()
    return BlogPostDTO.model_validate(saved)

We are all set to use this in our routes now. First, we convert the incoming Pydantic model to a dictionary. Next, we fetch a unique slug for the title. Finally, we insert the model with the added slug.

Note

Using this method does introduce an additional query on each insert. This should be considered when determining which fields actually need this type of functionality.

Full Code

Full Code (click to toggle)
app.py#
  1from __future__ import annotations
  2
  3import random
  4import re
  5import string
  6import unicodedata
  7from typing import TYPE_CHECKING, Any
  8from uuid import UUID
  9
 10from pydantic import BaseModel as _BaseModel
 11from pydantic import TypeAdapter
 12from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
 13from sqlalchemy.types import String
 14
 15from litestar import Litestar, get, post
 16from litestar.di import Provide
 17from litestar.plugins.sqlalchemy import (
 18    AsyncSessionConfig,
 19    SQLAlchemyAsyncConfig,
 20    SQLAlchemyInitPlugin,
 21    base,
 22    repository,
 23)
 24
 25if TYPE_CHECKING:
 26    from sqlalchemy.ext.asyncio import AsyncSession
 27
 28
 29class BaseModel(_BaseModel):
 30    """Extend Pydantic's BaseModel to enable ORM mode"""
 31
 32    model_config = {"from_attributes": True}
 33
 34
 35# we are going to add a simple "slug" to our model that is a URL safe surrogate key to
 36# our database record.
 37@declarative_mixin
 38class SlugKey:
 39    """Slug unique Field Model Mixin."""
 40
 41    __abstract__ = True
 42    slug: Mapped[str] = mapped_column(String(length=100), nullable=False, unique=True, sort_order=-9)
 43
 44
 45# this class can be re-used with any model that has the `SlugKey` Mixin
 46class SQLAlchemyAsyncSlugRepository(repository.SQLAlchemyAsyncRepository[repository.ModelT]):
 47    """Extends the repository to include slug model features.."""
 48
 49    async def get_available_slug(
 50        self,
 51        value_to_slugify: str,
 52        **kwargs: Any,
 53    ) -> str:
 54        """Get a unique slug for the supplied value.
 55
 56        If the value is found to exist, a random 4 digit character is appended to the end.
 57        There may be a better way to do this, but I wanted to limit the number of
 58        additional database calls.
 59
 60        Args:
 61            value_to_slugify (str): A string that should be converted to a unique slug.
 62            **kwargs: stuff
 63
 64        Returns:
 65            str: a unique slug for the supplied value. This is safe for URLs and other
 66            unique identifiers.
 67        """
 68        slug = self._slugify(value_to_slugify)
 69        if await self._is_slug_unique(slug):
 70            return slug
 71        # generate a random 4 digit alphanumeric string to make the slug unique and
 72        # avoid another DB lookup.
 73        random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
 74        return f"{slug}-{random_string}"
 75
 76    @staticmethod
 77    def _slugify(value: str) -> str:
 78        """slugify.
 79
 80        Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
 81        dashes to single dashes. Remove characters that aren't alphanumerics,
 82        underscores, or hyphens. Convert to lowercase. Also strip leading and
 83        trailing whitespace, dashes, and underscores.
 84
 85        Args:
 86            value (str): the string to slugify
 87
 88        Returns:
 89            str: a slugified string of the value parameter
 90        """
 91        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
 92        value = re.sub(r"[^\w\s-]", "", value.lower())
 93        return re.sub(r"[-\s]+", "-", value).strip("-_")
 94
 95    async def _is_slug_unique(
 96        self,
 97        slug: str,
 98        **kwargs: Any,
 99    ) -> bool:
100        return await self.get_one_or_none(slug=slug) is None
101
102
103# The `UUIDAuditBase` class includes a `UUID`-based primary key (`id`) and 2
104# additional columns: `created_at` and `updated_at`. `created_at` is a timestamp of when the
105# record was created, and `updated_at` is the last time the record was modified.
class BlogPost(base.UUIDAuditBase, SlugKey):
    """Blog post model: a title and body text, plus the unique ``slug`` from ``SlugKey``."""

    title: Mapped[str]  # not unique; several posts may share a title
    content: Mapped[str]
109
110
class BlogPostRepository(SQLAlchemyAsyncSlugRepository[BlogPost]):
    """Slug-aware repository bound to the ``BlogPost`` model."""

    # The mapped class this repository queries and persists.
    model_type = BlogPost
115
116
class BlogPostDTO(BaseModel):
    """Response schema for a blog post, validated from a ``BlogPost`` instance."""

    id: UUID | None
    slug: str
    title: str
    content: str
122
123
class BlogPostCreate(BaseModel):
    """Request body for creating a post; the slug is generated server-side."""

    title: str
    content: str
127
128
# Dependency provider for ``BlogPostRepository``. This is the place where the
# repository could be customised (e.g. overriding its default ``select`` with
# join options); this example uses the repository's defaults.
async def provide_blog_post_repo(db_session: AsyncSession) -> BlogPostRepository:
    """Build a ``BlogPostRepository`` bound to the given database session.

    Args:
        db_session: the async session for the current request (presumably the one
            provided by the SQLAlchemy plugin configured for this app).

    Returns:
        A repository that runs all of its queries through ``db_session``.
    """
    repo = BlogPostRepository(session=db_session)
    return repo
135
136
# Keep ORM objects readable after commit: handlers validate them after ``commit()``.
session_config = AsyncSessionConfig(expire_on_commit=False)
# Async SQLite database in ``test.sqlite``, accessed through the aiosqlite driver.
sqlalchemy_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///test.sqlite",
    session_config=session_config,
)
# Litestar SQLAlchemy plugin built from the configuration above.
sqlalchemy_plugin = SQLAlchemyInitPlugin(config=sqlalchemy_config)
142
143
async def on_startup() -> None:
    """Create any missing tables registered on the ``UUIDAuditBase`` metadata."""
    engine = sqlalchemy_config.get_engine()
    async with engine.begin() as connection:
        # ``metadata.create_all`` is synchronous, so run it via the connection's ``run_sync``.
        await connection.run_sync(base.UUIDAuditBase.metadata.create_all)
148
149
@get(path="/")
async def get_blogs(
    blog_post_repo: BlogPostRepository,
) -> list[BlogPostDTO]:
    """Return every blog post, serialized as ``BlogPostDTO`` objects."""
    posts = await blog_post_repo.list()
    return TypeAdapter(list[BlogPostDTO]).validate_python(posts)
158
159
@get(path="/{post_slug:str}")
async def get_blog_details(
    post_slug: str,
    blog_post_repo: BlogPostRepository,
) -> BlogPostDTO:
    """Return the blog post identified by ``post_slug``.

    Raises:
        NotFoundException: if no post has that slug (sent to the client as HTTP 404).
    """
    from litestar.exceptions import NotFoundException

    # Use ``get_one_or_none`` so an unknown slug becomes an explicit 404 instead of
    # depending on how the repository's own not-found error would be handled.
    obj = await blog_post_repo.get_one_or_none(slug=post_slug)
    if obj is None:
        raise NotFoundException(detail=f"No blog post with slug {post_slug!r}")
    return BlogPostDTO.model_validate(obj)
168
169
@post(path="/")
async def create_blog(
    blog_post_repo: BlogPostRepository,
    data: BlogPostCreate,
) -> BlogPostDTO:
    """Create a blog post whose slug is derived from its title, and return it."""
    fields = data.model_dump(exclude_unset=True, by_alias=False, exclude_none=True)
    slug = await blog_post_repo.get_available_slug(fields["title"])
    new_post = BlogPost(**{**fields, "slug": slug})
    saved = await blog_post_repo.add(new_post)
    # Commit so the row is persisted; ``expire_on_commit=False`` keeps ``saved`` readable.
    await blog_post_repo.session.commit()
    return BlogPostDTO.model_validate(saved)
181
182
app = Litestar(
    route_handlers=[create_blog, get_blogs, get_blog_details],
    # ``provide_blog_post_repo`` is a coroutine function, so ``sync_to_thread``
    # (which only concerns synchronous callables) is not passed.
    dependencies={"blog_post_repo": Provide(provide_blog_post_repo)},
    on_startup=[on_startup],
    # Reuse the plugin instance built from ``sqlalchemy_config`` rather than
    # constructing a second, otherwise-identical plugin.
    plugins=[sqlalchemy_plugin],
)