from __future__ import annotations
from typing import TYPE_CHECKING, Any
from litestar.exceptions import MissingDependencyException
from litestar.openapi.spec import OpenAPIFormat, OpenAPIType, Reference, Schema
from litestar.plugins import OpenAPISchemaPlugin
from litestar.plugins.pydantic.utils import (
get_model_info,
is_pydantic_constrained_field,
is_pydantic_model_class,
is_pydantic_root_model,
is_pydantic_undefined,
is_pydantic_v2,
)
from litestar.typing import FieldDefinition
from litestar.utils import is_class_and_subclass
try:
import pydantic as _ # noqa: F401
except ImportError as e:
raise MissingDependencyException("pydantic") from e
try:
import pydantic as pydantic_v2
if not is_pydantic_v2(pydantic_v2):
raise ImportError
from pydantic import v1 as pydantic_v1
except ImportError:
import pydantic as pydantic_v1 # type: ignore[no-redef]
pydantic_v2 = None # type: ignore[assignment]
if TYPE_CHECKING:
from litestar._openapi.schema_generation.schema import SchemaCreator
# OpenAPI schemas for pydantic-specific field types, keyed by the type object itself.
# The keys also feed ``_supported_types``, i.e. they decide which annotations the
# pydantic schema plugin claims.
PYDANTIC_TYPE_MAP: dict[type[Any] | None | Any, Schema] = {
    pydantic_v1.ByteSize: Schema(type=OpenAPIType.INTEGER),
    pydantic_v1.EmailStr: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL),
    pydantic_v1.IPvAnyAddress: Schema(
        one_of=[
            Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.IPV4,
                description="IPv4 address",
            ),
            Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.IPV6,
                description="IPv6 address",
            ),
        ]
    ),
    pydantic_v1.IPvAnyInterface: Schema(
        one_of=[
            Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.IPV4,
                description="IPv4 interface",
            ),
            Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.IPV6,
                description="IPv6 interface",
            ),
        ]
    ),
    pydantic_v1.IPvAnyNetwork: Schema(
        one_of=[
            Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.IPV4,
                description="IPv4 network",
            ),
            Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.IPV6,
                description="IPv6 network",
            ),
        ]
    ),
    pydantic_v1.Json: Schema(type=OpenAPIType.OBJECT, format=OpenAPIFormat.JSON_POINTER),
    pydantic_v1.NameEmail: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL, description="Name and email"),
    # removed in v2
    pydantic_v1.PyObject: Schema(
        type=OpenAPIType.STRING,
        description="dot separated path identifying a python object, e.g. 'decimal.Decimal'",
    ),
    # annotated in v2
    pydantic_v1.UUID1: Schema(
        type=OpenAPIType.STRING,
        format=OpenAPIFormat.UUID,
        description="UUID1 string",
    ),
    pydantic_v1.UUID3: Schema(
        type=OpenAPIType.STRING,
        format=OpenAPIFormat.UUID,
        description="UUID3 string",
    ),
    pydantic_v1.UUID4: Schema(
        type=OpenAPIType.STRING,
        format=OpenAPIFormat.UUID,
        description="UUID4 string",
    ),
    pydantic_v1.UUID5: Schema(
        type=OpenAPIType.STRING,
        format=OpenAPIFormat.UUID,
        description="UUID5 string",
    ),
    pydantic_v1.DirectoryPath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE),
    pydantic_v1.AnyUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL),
    pydantic_v1.AnyHttpUrl: Schema(
        type=OpenAPIType.STRING, format=OpenAPIFormat.URL, description="must be a valid HTTP based URL"
    ),
    pydantic_v1.FilePath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE),
    pydantic_v1.HttpUrl: Schema(
        type=OpenAPIType.STRING,
        format=OpenAPIFormat.URL,
        description="must be a valid HTTP based URL",
        max_length=2083,
    ),
    pydantic_v1.RedisDsn: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="redis DSN"),
    pydantic_v1.PostgresDsn: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="postgres DSN"),
    pydantic_v1.SecretBytes: Schema(type=OpenAPIType.STRING),
    pydantic_v1.SecretStr: Schema(type=OpenAPIType.STRING),
    pydantic_v1.StrictBool: Schema(type=OpenAPIType.BOOLEAN),
    pydantic_v1.StrictBytes: Schema(type=OpenAPIType.STRING),
    pydantic_v1.StrictFloat: Schema(type=OpenAPIType.NUMBER),
    pydantic_v1.StrictInt: Schema(type=OpenAPIType.INTEGER),
    pydantic_v1.StrictStr: Schema(type=OpenAPIType.STRING),
    # Sign-constrained numbers: strict bounds map to exclusive_minimum/exclusive_maximum,
    # inclusive ("non-") bounds map to minimum/maximum. All eight combinations are listed
    # so none of them silently falls through to a different code path.
    pydantic_v1.NegativeFloat: Schema(type=OpenAPIType.NUMBER, exclusive_maximum=0.0),
    pydantic_v1.NegativeInt: Schema(type=OpenAPIType.INTEGER, exclusive_maximum=0),
    pydantic_v1.NonNegativeFloat: Schema(type=OpenAPIType.NUMBER, minimum=0.0),
    pydantic_v1.NonNegativeInt: Schema(type=OpenAPIType.INTEGER, minimum=0),
    pydantic_v1.NonPositiveFloat: Schema(type=OpenAPIType.NUMBER, maximum=0.0),
    pydantic_v1.NonPositiveInt: Schema(type=OpenAPIType.INTEGER, maximum=0),
    pydantic_v1.PaymentCardNumber: Schema(type=OpenAPIType.STRING, min_length=12, max_length=19),
    pydantic_v1.PositiveFloat: Schema(type=OpenAPIType.NUMBER, exclusive_minimum=0.0),
    pydantic_v1.PositiveInt: Schema(type=OpenAPIType.INTEGER, exclusive_minimum=0),
}
if pydantic_v2 is not None:  # pragma: no cover
    from pydantic import networks

    # Schemas for pydantic v2's own special types, registered alongside the
    # pydantic v1 entries already present in PYDANTIC_TYPE_MAP.
    PYDANTIC_TYPE_MAP.update(
        {
            pydantic_v2.SecretStr: Schema(type=OpenAPIType.STRING),
            pydantic_v2.SecretBytes: Schema(type=OpenAPIType.STRING),
            pydantic_v2.ByteSize: Schema(type=OpenAPIType.INTEGER),
            pydantic_v2.EmailStr: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL),
            pydantic_v2.IPvAnyAddress: Schema(
                one_of=[
                    Schema(
                        type=OpenAPIType.STRING,
                        format=OpenAPIFormat.IPV4,
                        description="IPv4 address",
                    ),
                    Schema(
                        type=OpenAPIType.STRING,
                        format=OpenAPIFormat.IPV6,
                        description="IPv6 address",
                    ),
                ]
            ),
            pydantic_v2.IPvAnyInterface: Schema(
                one_of=[
                    Schema(
                        type=OpenAPIType.STRING,
                        format=OpenAPIFormat.IPV4,
                        description="IPv4 interface",
                    ),
                    Schema(
                        type=OpenAPIType.STRING,
                        format=OpenAPIFormat.IPV6,
                        description="IPv6 interface",
                    ),
                ]
            ),
            pydantic_v2.IPvAnyNetwork: Schema(
                one_of=[
                    Schema(
                        type=OpenAPIType.STRING,
                        format=OpenAPIFormat.IPV4,
                        description="IPv4 network",
                    ),
                    Schema(
                        type=OpenAPIType.STRING,
                        format=OpenAPIFormat.IPV6,
                        description="IPv6 network",
                    ),
                ]
            ),
            pydantic_v2.Json: Schema(type=OpenAPIType.OBJECT, format=OpenAPIFormat.JSON_POINTER),
            pydantic_v2.NameEmail: Schema(
                type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL, description="Name and email"
            ),
            pydantic_v2.AnyUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL),
            pydantic_v2.PastDate: Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.DATE,
                description="date with the constraint that the value must be in the past",
            ),
            pydantic_v2.FutureDate: Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.DATE,
                description="date with the constraint that the value must be in the future",
            ),
            pydantic_v2.PastDatetime: Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.DATE_TIME,
                description="datetime with the constraint that the value must be in the past",
            ),
            pydantic_v2.FutureDatetime: Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.DATE_TIME,
                description="datetime with the constraint that the value must be in the future",
            ),
            pydantic_v2.AwareDatetime: Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.DATE_TIME,
                description="datetime with the constraint that the value must have timezone info",
            ),
            pydantic_v2.NaiveDatetime: Schema(
                type=OpenAPIType.STRING,
                format=OpenAPIFormat.DATE_TIME,
                description="datetime with the constraint that the value must lack timezone info",
            ),
        }
    )

    # ``version_short()`` is expected to return "<major>.<minor>". Compare both parts as a
    # tuple: looking at the minor number alone would treat a later major release with a
    # small minor number (e.g. "3.0") as older than 2.10.
    if tuple(int(part) for part in pydantic_v2.version.version_short().split(".")[:2]) >= (2, 10):
        # These were 'Annotated' type aliases before Pydantic 2.10, where they were
        # changed to proper classes. Using subscripted generics type in an 'isinstance'
        # check would raise a 'TypeError' on Python <3.12
        PYDANTIC_TYPE_MAP.update(
            {
                networks.HttpUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL),
                networks.AnyHttpUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL),
            }
        )
# Everything the plugin claims, in lookup order: the pydantic v2 ``BaseModel`` (only
# when pydantic v2 is importable), then the v1 ``BaseModel``, then every special type
# registered in PYDANTIC_TYPE_MAP.
_supported_types = (
    *(() if pydantic_v2 is None else (pydantic_v2.BaseModel,)),
    pydantic_v1.BaseModel,
    *PYDANTIC_TYPE_MAP,
)
class PydanticSchemaPlugin(OpenAPISchemaPlugin):
    """OpenAPI schema plugin for pydantic models and pydantic-specific field types.

    Handles pydantic ``BaseModel`` classes (v1, and v2 when installed) plus every
    special type registered in ``PYDANTIC_TYPE_MAP``, including subclasses of those
    types.
    """

    def __init__(self, prefer_alias: bool = False) -> None:
        """Initialize the plugin.

        Args:
            prefer_alias: When ``True``, the schema creator is switched to prefer field
                aliases while this plugin generates schemas.
        """
        self.prefer_alias = prefer_alias

    @staticmethod
    def is_plugin_supported_type(value: Any) -> bool:
        """Return whether ``value`` is an instance of, or a class derived from, a supported type.

        Args:
            value: A type annotation (or value) to test.

        Returns:
            ``True`` if the plugin can generate a schema for ``value``.
        """
        return isinstance(value, _supported_types) or is_class_and_subclass(value, _supported_types)  # type: ignore[arg-type]

    @staticmethod
    def is_undefined_sentinel(value: Any) -> bool:
        """Return whether ``value`` is pydantic's "undefined" sentinel."""
        return is_pydantic_undefined(value)

    @staticmethod
    def is_constrained_field(field_definition: FieldDefinition) -> bool:
        """Return whether the field's annotation is a pydantic constrained type."""
        return is_pydantic_constrained_field(field_definition.annotation)

    def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema | Reference:
        """Given a type annotation, transform it into an OpenAPI schema class.

        Args:
            field_definition: FieldDefinition instance.
            schema_creator: An instance of the schema creator class

        Returns:
            An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance.

        Raises:
            KeyError: If the annotation is neither a pydantic model nor (a subclass of)
                a type registered in ``PYDANTIC_TYPE_MAP``.
        """
        if schema_creator.prefer_alias != self.prefer_alias:
            # Mutates the shared creator: if either side prefers aliases, the creator
            # ends up preferring them from here on.
            schema_creator.prefer_alias = True
        if is_pydantic_model_class(field_definition.annotation):
            return self.for_pydantic_model(field_definition=field_definition, schema_creator=schema_creator)

        annotation = field_definition.annotation
        if annotation in PYDANTIC_TYPE_MAP:
            return PYDANTIC_TYPE_MAP[annotation]
        # ``is_plugin_supported_type`` also accepts *subclasses* of mapped types (e.g.
        # URL/DSN classes deriving from a mapped ``AnyUrl``), which an exact dict lookup
        # would reject with KeyError. Use the nearest mapped ancestor instead.
        for base in getattr(annotation, "__mro__", ())[1:]:
            if base in PYDANTIC_TYPE_MAP:
                return PYDANTIC_TYPE_MAP[base]
        raise KeyError(annotation)

    @classmethod
    def for_pydantic_model(cls, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema | Reference:  # pyright: ignore
        """Create a schema object for a given pydantic model class.

        Args:
            field_definition: FieldDefinition instance.
            schema_creator: An instance of the schema creator class

        Returns:
            A schema instance.
        """
        model_info = get_model_info(field_definition.annotation, prefer_alias=schema_creator.prefer_alias)

        # Handle RootModel: generate schema for the root field content instead of treating it as a regular field
        if is_pydantic_root_model(field_definition.annotation) and (
            root_field := model_info.field_definitions.get("root")
        ):
            # Re-wrap the root field's annotation while keeping the outer field's
            # name, default and extra metadata.
            root_field_def = FieldDefinition.from_annotation(
                annotation=root_field.annotation,
                name=field_definition.name,
                default=field_definition.default,
                extra=field_definition.extra,
            )
            return schema_creator.for_field_definition(root_field_def)

        return schema_creator.create_component_schema(
            field_definition,
            # Sorted so the generated ``required`` list is stable.
            required=sorted(f.name for f in model_info.field_definitions.values() if f.is_required),
            property_fields=model_info.field_definitions,
            title=model_info.title,
            examples=None if model_info.example is None else [model_info.example],
        )