from __future__ import annotations
from typing import TYPE_CHECKING, Any
from litestar.plugins import InitPlugin
from litestar.plugins.pydantic.dto import PydanticDTO
from litestar.plugins.pydantic.plugins.di import PydanticDIPlugin
from litestar.plugins.pydantic.plugins.init import PydanticInitPlugin
from litestar.plugins.pydantic.plugins.schema import PydanticSchemaPlugin
if TYPE_CHECKING:
    from pydantic import BaseModel
    from pydantic.v1 import BaseModel as BaseModelV1
    from litestar.config.app import AppConfig
    from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType
# Public API of this package: the umbrella plugin plus the individual plugins and DTO it builds on.
__all__ = ("PydanticDIPlugin", "PydanticDTO", "PydanticInitPlugin", "PydanticPlugin", "PydanticSchemaPlugin")
def _model_dump(
    model: BaseModel | BaseModelV1,
    *,
    by_alias: bool = False,
    round_trip: bool = False,
) -> dict[str, Any]:
    return (
        model.model_dump(mode="json", by_alias=by_alias, round_trip=round_trip)  # pyright: ignore
        if hasattr(model, "model_dump")
        else {k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=by_alias).items()}
    )
def _model_dump_json(
    model: BaseModel | BaseModelV1,
    by_alias: bool = False,
    round_trip: bool = False,
) -> str:
    return (
        model.model_dump_json(by_alias=by_alias, round_trip=round_trip)  # pyright: ignore
        if hasattr(model, "model_dump_json")
        else model.json(by_alias=by_alias)  # pyright: ignore
    )
class PydanticPlugin(InitPlugin):
    """A plugin that provides Pydantic integration.

    On application init it registers :class:`PydanticInitPlugin`, :class:`PydanticSchemaPlugin`
    and :class:`PydanticDIPlugin`, forwarding the options given to this plugin.
    """

    __slots__ = (
        "exclude",
        "exclude_defaults",
        "exclude_none",
        "exclude_unset",
        "include",
        "prefer_alias",
        "round_trip",
        "validate_strict",
    )

    def __init__(
        self,
        exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        exclude_unset: bool = False,
        include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
        prefer_alias: bool = False,
        validate_strict: bool = False,
        round_trip: bool = False,
    ) -> None:
        """Pydantic Plugin to support serialization / validation of Pydantic types / models.

        All options are stored as-is and forwarded to :class:`PydanticInitPlugin` in
        :meth:`on_app_init`; ``prefer_alias`` is also forwarded to :class:`PydanticSchemaPlugin`.

        :param exclude: Fields to exclude during serialization
        :param exclude_defaults: Exclude fields from serialization when they are set to their default value
        :param exclude_none: Exclude fields from serialization when they are set to ``None``
        :param exclude_unset: Exclude fields from serialization when they are not set
        :param include: Fields to include during serialization
        :param prefer_alias: Use the ``by_alias=True`` flag when dumping models
        :param validate_strict: Use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models
        :param round_trip: Use ``round_trip=True`` when calling ``.model_dump``
          and ``.model_dump_json`` on Pydantic 2.x models
        """
        self.exclude = exclude
        self.exclude_defaults = exclude_defaults
        self.exclude_none = exclude_none
        self.exclude_unset = exclude_unset
        self.include = include
        self.prefer_alias = prefer_alias
        self.validate_strict = validate_strict
        self.round_trip = round_trip

    def on_app_init(self, app_config: AppConfig) -> AppConfig:
        """Configure application for use with Pydantic.

        Appends three plugins to ``app_config.plugins``:
        a :class:`PydanticInitPlugin` configured with all of this plugin's options,
        a :class:`PydanticSchemaPlugin` configured with ``prefer_alias``,
        and a :class:`PydanticDIPlugin`.

        Args:
            app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.

        Returns:
            The same ``app_config`` instance, with the plugins appended.
        """
        app_config.plugins.extend(
            [
                PydanticInitPlugin(
                    exclude=self.exclude,
                    exclude_defaults=self.exclude_defaults,
                    exclude_none=self.exclude_none,
                    exclude_unset=self.exclude_unset,
                    include=self.include,
                    prefer_alias=self.prefer_alias,
                    validate_strict=self.validate_strict,
                    round_trip=self.round_trip,
                ),
                # Only ``prefer_alias`` affects schema generation, so only that option is passed on.
                PydanticSchemaPlugin(prefer_alias=self.prefer_alias),
                PydanticDIPlugin(),
            ]
        )
        return app_config