Using the serialization plugin

Our next improvement is to leverage the SQLAlchemySerializationPlugin so that we can receive and return our SQLAlchemy models directly to and from our handlers.

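Here’s the code. It is shown three times, once each for Python 3.8+, 3.9+ and 3.10+ type-hint syntax; the listings are otherwise identical: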

 1from contextlib import asynccontextmanager
 2from typing import AsyncGenerator, List, Optional
 3
 4from sqlalchemy import select
 5from sqlalchemy.exc import IntegrityError, NoResultFound
 6from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
 7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 8
 9from litestar import Litestar, get, post, put
10from litestar.datastructures import State
11from litestar.exceptions import ClientException, NotFoundException
12from litestar.plugins.sqlalchemy import SQLAlchemySerializationPlugin
13from litestar.status_codes import HTTP_409_CONFLICT
14
15
16class Base(DeclarativeBase): ...
17
18
19class TodoItem(Base):
20    __tablename__ = "todo_items"
21
22    title: Mapped[str] = mapped_column(primary_key=True)
23    done: Mapped[bool]
24
25
26@asynccontextmanager
27async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
28    engine = getattr(app.state, "engine", None)
29    if engine is None:
30        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite", echo=True)
31        app.state.engine = engine
32
33    async with engine.begin() as conn:
34        await conn.run_sync(Base.metadata.create_all)
35
36    try:
37        yield
38    finally:
39        await engine.dispose()
40
41
42sessionmaker = async_sessionmaker(expire_on_commit=False)
43
44
45async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
46    async with sessionmaker(bind=state.engine) as session:
47        try:
48            async with session.begin():
49                yield session
50        except IntegrityError as exc:
51            raise ClientException(
52                status_code=HTTP_409_CONFLICT,
53                detail=str(exc),
54            ) from exc
55
56
57async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
58    query = select(TodoItem).where(TodoItem.title == todo_name)
59    result = await session.execute(query)
60    try:
61        return result.scalar_one()
62    except NoResultFound as e:
63        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
64
65
66async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
67    query = select(TodoItem)
68    if done is not None:
69        query = query.where(TodoItem.done.is_(done))
70
71    result = await session.execute(query)
72    return list(result.scalars().all())
73
74
75@get("/")
76async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> List[TodoItem]:
77    return await get_todo_list(done, transaction)
78
79
80@post("/")
81async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
82    transaction.add(data)
83    return data
84
85
86@put("/{item_title:str}")
87async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
88    todo_item = await get_todo_by_title(item_title, transaction)
89    todo_item.title = data.title
90    todo_item.done = data.done
91    return todo_item
92
93
94app = Litestar(
95    [get_list, add_item, update_item],
96    dependencies={"transaction": provide_transaction},
97    lifespan=[db_connection],
98    plugins=[SQLAlchemySerializationPlugin()],
99)
  1from contextlib import asynccontextmanager
  2from typing import Optional
  3from collections.abc import AsyncGenerator
  4
  5from sqlalchemy import select
  6from sqlalchemy.exc import IntegrityError, NoResultFound
  7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
  9
 10from litestar import Litestar, get, post, put
 11from litestar.datastructures import State
 12from litestar.exceptions import ClientException, NotFoundException
 13from litestar.plugins.sqlalchemy import SQLAlchemySerializationPlugin
 14from litestar.status_codes import HTTP_409_CONFLICT
 15
 16
 17class Base(DeclarativeBase): ...
 18
 19
 20class TodoItem(Base):
 21    __tablename__ = "todo_items"
 22
 23    title: Mapped[str] = mapped_column(primary_key=True)
 24    done: Mapped[bool]
 25
 26
 27@asynccontextmanager
 28async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
 29    engine = getattr(app.state, "engine", None)
 30    if engine is None:
 31        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite", echo=True)
 32        app.state.engine = engine
 33
 34    async with engine.begin() as conn:
 35        await conn.run_sync(Base.metadata.create_all)
 36
 37    try:
 38        yield
 39    finally:
 40        await engine.dispose()
 41
 42
 43sessionmaker = async_sessionmaker(expire_on_commit=False)
 44
 45
 46async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
 47    async with sessionmaker(bind=state.engine) as session:
 48        try:
 49            async with session.begin():
 50                yield session
 51        except IntegrityError as exc:
 52            raise ClientException(
 53                status_code=HTTP_409_CONFLICT,
 54                detail=str(exc),
 55            ) from exc
 56
 57
 58async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
 59    query = select(TodoItem).where(TodoItem.title == todo_name)
 60    result = await session.execute(query)
 61    try:
 62        return result.scalar_one()
 63    except NoResultFound as e:
 64        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
 65
 66
 67async def get_todo_list(done: Optional[bool], session: AsyncSession) -> list[TodoItem]:
 68    query = select(TodoItem)
 69    if done is not None:
 70        query = query.where(TodoItem.done.is_(done))
 71
 72    result = await session.execute(query)
 73    return list(result.scalars().all())
 74
 75
 76@get("/")
 77async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> list[TodoItem]:
 78    return await get_todo_list(done, transaction)
 79
 80
 81@post("/")
 82async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
 83    transaction.add(data)
 84    return data
 85
 86
 87@put("/{item_title:str}")
 88async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
 89    todo_item = await get_todo_by_title(item_title, transaction)
 90    todo_item.title = data.title
 91    todo_item.done = data.done
 92    return todo_item
 93
 94
 95app = Litestar(
 96    [get_list, add_item, update_item],
 97    dependencies={"transaction": provide_transaction},
 98    lifespan=[db_connection],
 99    plugins=[SQLAlchemySerializationPlugin()],
100)
 1from contextlib import asynccontextmanager
 2from collections.abc import AsyncGenerator
 3
 4from sqlalchemy import select
 5from sqlalchemy.exc import IntegrityError, NoResultFound
 6from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
 7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 8
 9from litestar import Litestar, get, post, put
10from litestar.datastructures import State
11from litestar.exceptions import ClientException, NotFoundException
12from litestar.plugins.sqlalchemy import SQLAlchemySerializationPlugin
13from litestar.status_codes import HTTP_409_CONFLICT
14
15
16class Base(DeclarativeBase): ...
17
18
19class TodoItem(Base):
20    __tablename__ = "todo_items"
21
22    title: Mapped[str] = mapped_column(primary_key=True)
23    done: Mapped[bool]
24
25
26@asynccontextmanager
27async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
28    engine = getattr(app.state, "engine", None)
29    if engine is None:
30        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite", echo=True)
31        app.state.engine = engine
32
33    async with engine.begin() as conn:
34        await conn.run_sync(Base.metadata.create_all)
35
36    try:
37        yield
38    finally:
39        await engine.dispose()
40
41
42sessionmaker = async_sessionmaker(expire_on_commit=False)
43
44
45async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
46    async with sessionmaker(bind=state.engine) as session:
47        try:
48            async with session.begin():
49                yield session
50        except IntegrityError as exc:
51            raise ClientException(
52                status_code=HTTP_409_CONFLICT,
53                detail=str(exc),
54            ) from exc
55
56
57async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
58    query = select(TodoItem).where(TodoItem.title == todo_name)
59    result = await session.execute(query)
60    try:
61        return result.scalar_one()
62    except NoResultFound as e:
63        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
64
65
66async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
67    query = select(TodoItem)
68    if done is not None:
69        query = query.where(TodoItem.done.is_(done))
70
71    result = await session.execute(query)
72    return list(result.scalars().all())
73
74
75@get("/")
76async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
77    return await get_todo_list(done, transaction)
78
79
80@post("/")
81async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
82    transaction.add(data)
83    return data
84
85
86@put("/{item_title:str}")
87async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
88    todo_item = await get_todo_by_title(item_title, transaction)
89    todo_item.title = data.title
90    todo_item.done = data.done
91    return todo_item
92
93
94app = Litestar(
95    [get_list, add_item, update_item],
96    dependencies={"transaction": provide_transaction},
97    lifespan=[db_connection],
98    plugins=[SQLAlchemySerializationPlugin()],
99)

We’ve simply imported the plugin and added it to our app’s plugins list, and now we can receive and return our SQLAlchemy data models directly to and from our handlers.

We’ve also been able to remove the TodoType and TodoCollectionType aliases, and the serialize_todo() function, making the implementation even more concise.

Compare handlers before and after the serialization plugin

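Once more, let’s compare the sets of application handlers before and after our refactoring. The first three listings show the handlers after the refactoring (Python 3.8+, 3.9+ and 3.10+ variants); the last two show them before it: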

 1from typing import List, Optional
 2
 3from sqlalchemy.ext.asyncio import AsyncSession
 4
 5from litestar import Litestar, get, post, put
 6from litestar.plugins.sqlalchemy import SQLAlchemySerializationPlugin
 7
 8
 9@get("/")
10async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> List[TodoItem]:
11    return await get_todo_list(done, transaction)
12
13
14@post("/")
15async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
16    transaction.add(data)
17    return data
18
19
20@put("/{item_title:str}")
21async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
22    todo_item = await get_todo_by_title(item_title, transaction)
23    todo_item.title = data.title
24    todo_item.done = data.done
25    return todo_item
26
27
28app = Litestar(
29    [get_list, add_item, update_item],
30    dependencies={"transaction": provide_transaction},
31    lifespan=[db_connection],
32    plugins=[SQLAlchemySerializationPlugin()],
33)
 1from typing import Optional
 2
 3from sqlalchemy.ext.asyncio import AsyncSession
 4
 5from litestar import Litestar, get, post, put
 6from litestar.plugins.sqlalchemy import SQLAlchemySerializationPlugin
 7
 8
 9@get("/")
10async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> list[TodoItem]:
11    return await get_todo_list(done, transaction)
12
13
14@post("/")
15async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
16    transaction.add(data)
17    return data
18
19
20@put("/{item_title:str}")
21async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
22    todo_item = await get_todo_by_title(item_title, transaction)
23    todo_item.title = data.title
24    todo_item.done = data.done
25    return todo_item
26
27
28app = Litestar(
29    [get_list, add_item, update_item],
30    dependencies={"transaction": provide_transaction},
31    lifespan=[db_connection],
32    plugins=[SQLAlchemySerializationPlugin()],
33)
 1from sqlalchemy.ext.asyncio import AsyncSession
 2
 3from litestar import Litestar, get, post, put
 4from litestar.plugins.sqlalchemy import SQLAlchemySerializationPlugin
 5
 6
 7@get("/")
 8async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
 9    return await get_todo_list(done, transaction)
10
11
12@post("/")
13async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
14    transaction.add(data)
15    return data
16
17
18@put("/{item_title:str}")
19async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
20    todo_item = await get_todo_by_title(item_title, transaction)
21    todo_item.title = data.title
22    todo_item.done = data.done
23    return todo_item
24
25
26app = Litestar(
27    [get_list, add_item, update_item],
28    dependencies={"transaction": provide_transaction},
29    lifespan=[db_connection],
30    plugins=[SQLAlchemySerializationPlugin()],
31)
 1from typing import Optional
 2
 3from sqlalchemy.exc import IntegrityError
 4
 5from litestar import Litestar, get, post, put
 6from litestar.datastructures import State
 7from litestar.exceptions import ClientException
 8from litestar.status_codes import HTTP_409_CONFLICT
 9
10
11@get("/")
12async def get_list(state: State, done: Optional[bool] = None) -> TodoCollectionType:
13    async with sessionmaker(bind=state.engine) as session:
14        return [serialize_todo(todo) for todo in await get_todo_list(done, session)]
15
16
17@post("/")
18async def add_item(data: TodoType, state: State) -> TodoType:
19    new_todo = TodoItem(title=data["title"], done=data["done"])
20    async with sessionmaker(bind=state.engine) as session:
21        try:
22            async with session.begin():
23                session.add(new_todo)
24        except IntegrityError as e:
25            raise ClientException(
26                status_code=HTTP_409_CONFLICT,
27                detail=f"TODO {new_todo.title!r} already exists",
28            ) from e
29
30    return serialize_todo(new_todo)
31
32
33@put("/{item_title:str}")
34async def update_item(item_title: str, data: TodoType, state: State) -> TodoType:
35    async with sessionmaker(bind=state.engine) as session, session.begin():
36        todo_item = await get_todo_by_title(item_title, session)
37        todo_item.title = data["title"]
38        todo_item.done = data["done"]
39    return serialize_todo(todo_item)
40
41
42app = Litestar([get_list, add_item, update_item], lifespan=[db_connection])
 1from sqlalchemy.exc import IntegrityError
 2
 3from litestar import Litestar, get, post, put
 4from litestar.datastructures import State
 5from litestar.exceptions import ClientException
 6from litestar.status_codes import HTTP_409_CONFLICT
 7
 8
 9@get("/")
10async def get_list(state: State, done: bool | None = None) -> TodoCollectionType:
11    async with sessionmaker(bind=state.engine) as session:
12        return [serialize_todo(todo) for todo in await get_todo_list(done, session)]
13
14
15@post("/")
16async def add_item(data: TodoType, state: State) -> TodoType:
17    new_todo = TodoItem(title=data["title"], done=data["done"])
18    async with sessionmaker(bind=state.engine) as session:
19        try:
20            async with session.begin():
21                session.add(new_todo)
22        except IntegrityError as e:
23            raise ClientException(
24                status_code=HTTP_409_CONFLICT,
25                detail=f"TODO {new_todo.title!r} already exists",
26            ) from e
27
28    return serialize_todo(new_todo)
29
30
31@put("/{item_title:str}")
32async def update_item(item_title: str, data: TodoType, state: State) -> TodoType:
33    async with sessionmaker(bind=state.engine) as session, session.begin():
34        todo_item = await get_todo_by_title(item_title, session)
35        todo_item.title = data["title"]
36        todo_item.done = data["done"]
37    return serialize_todo(todo_item)
38
39
40app = Litestar([get_list, add_item, update_item], lifespan=[db_connection])

Very nice! But we can do better.

Next steps

So far, we’ve had to build a bit of scaffolding to integrate SQLAlchemy with our application: we defined the db_connection() lifespan context manager and the provide_transaction() dependency provider.

Next we’ll look at how the SQLAlchemyInitPlugin can help us.