from __future__ import annotations
from typing import Any, Callable, Generic, TypedDict, TypeVar, cast
from typing_extensions import ParamSpec
from polyfactory.exceptions import MissingParamException, ParameterException
from polyfactory.field_meta import Null
from polyfactory.utils import deprecation
from polyfactory.utils.predicates import is_safe_subclass
# Return type of the callable wrapped by ``Use``; also the value type carried by ``Param``.
T = TypeVar("T")
# NOTE(review): not referenced anywhere in the visible part of this module - confirm it is still needed.
U = TypeVar("U")
# Parameter specification of the callable wrapped by ``Use``; types the ``*args``/``**kwargs`` it forwards.
P = ParamSpec("P")
class WrappedCallable(TypedDict):
    """A ref storing a callable. This class is a utility meant to prevent binding of methods."""

    # The stored callable; kept under a dict key instead of being assigned as an attribute.
    value: Callable
class Require:
    """A factory field that marks an attribute as a required build-time kwarg."""
class Ignore:
    """A factory field that marks an attribute as ignored."""
class Use(Generic[P, T]):
    """Factory field used to wrap a callable.

    The callable will be invoked whenever building the given factory attribute.
    """

    __slots__ = ("args", "fn", "kwargs")

    def __init__(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
        """Wrap a callable.

        :param fn: A callable to wrap.
        :param args: Any args to pass to the callable.
        :param kwargs: Any kwargs to pass to the callable.
        """
        # The callable is kept inside a ``WrappedCallable`` dict rather than
        # directly on the instance; that helper exists to prevent method binding.
        self.fn: WrappedCallable = {"value": fn}
        self.kwargs = kwargs
        self.args = args

    def to_value(self) -> T:
        """Invoke the callable with the args and kwargs captured at construction.

        :returns: The output of the callable.
        """
        return cast("T", self.fn["value"](*self.args, **self.kwargs))
class PostGenerated:
    """Factory field that allows generating values after other fields are generated by the factory."""

    __slots__ = ("args", "fn", "kwargs")

    def __init__(self, fn: Callable, *args: Any, **kwargs: Any) -> None:
        """Designate field as post-generated.

        :param fn: A callable. It is invoked as ``fn(name, values, *args, **kwargs)``.
        :param args: Extra positional args for the callable, passed after ``name`` and ``values``.
        :param kwargs: Kwargs for the callable.
        """
        # The callable is kept inside a ``WrappedCallable`` dict rather than
        # directly on the instance; that helper exists to prevent method binding.
        self.fn: WrappedCallable = {"value": fn}
        self.kwargs = kwargs
        self.args = args

    def to_value(self, name: str, values: dict[str, Any]) -> Any:
        """Invoke the post-generation callback passing to it the build results.

        :param name: Field name.
        :param values: Generated values.
        :returns: An arbitrary value.
        """
        return self.fn["value"](name, values, *self.args, **self.kwargs)
class Fixture:
    """Factory field to create a pytest fixture from a factory."""

    __slots__ = ("kwargs", "ref", "size")

    @deprecation.deprecated(version="2.20.0", alternative="Use factory directly")
    def __init__(self, fixture: Callable, size: int | None = None, **kwargs: Any) -> None:
        """Create a fixture from a factory.

        :param fixture: A factory that was registered as a fixture.
        :param size: Optional batch size.
        :param kwargs: Any build kwargs.
        """
        # The factory is kept inside a ``WrappedCallable`` dict rather than
        # directly on the instance; that helper exists to prevent method binding.
        self.ref: WrappedCallable = {"value": fixture}
        self.size = size
        self.kwargs = kwargs

    def to_value(self) -> Any:
        """Call the factory's build or batch method.

        ``batch`` is used when a size was given at construction, ``build`` otherwise.

        :raises: ParameterException
        :returns: The build result.
        """
        # NOTE(review): imported at call time rather than module level,
        # presumably to avoid a circular import - confirm.
        from polyfactory.factories.base import BaseFactory

        factory = self.ref["value"]
        if not is_safe_subclass(factory, BaseFactory):
            msg = "fixture has not been registered using the register_factory decorator"
            raise ParameterException(msg)

        if self.size is not None:
            return factory.batch(self.size, **self.kwargs)
        return factory.build(**self.kwargs)
class Param(Generic[T]):
    """A constant parameter that can be used by other fields but will not be
    passed to the final object.

    If a value for the parameter is not passed in the field's definition, it must
    be passed at build time. Otherwise, a MissingParamException will be raised.
    """

    __slots__ = ("is_callable", "kwargs", "param")

    def __init__(
        self, param: T | Callable[..., T] | type[Null] = Null, is_callable: bool = False, **kwargs: Any
    ) -> None:
        """Designate a parameter.

        :param param: A constant or an unpassed value that can be referenced later
        :param is_callable: When True, the parameter is a callable that is invoked
            to produce the value instead of being used as the value itself.
        :param kwargs: Keyword arguments passed to the callable; only allowed when
            ``is_callable`` is True.
        :raises: ParameterException
        """
        if param is not Null and is_callable and not callable(param):
            msg = "If an object is passed to param, a callable must be passed when is_callable is True"
            raise ParameterException(msg)
        if not is_callable and kwargs:
            msg = "kwargs can only be used with callable parameters"
            raise ParameterException(msg)

        self.param = param
        self.is_callable = is_callable
        self.kwargs = kwargs

    def to_value(self, from_build: T | Callable[..., T] | type[Null] = Null, **kwargs: Any) -> T:
        """Determine the value to use at build time.

        A value passed at build time takes precedence over the one passed to the
        constructor. If neither was passed, a MissingParamException is raised.
        For callable parameters the selected callable is invoked with the
        constructor kwargs merged with ``kwargs``.

        :param from_build: The value (or callable) passed at build time (if any).
        :param kwargs: Extra keyword arguments for the callable; on key clashes
            these override the kwargs given to the constructor.
        :returns: The value
        :raises: MissingParamException
        :raises: TypeError
        """
        if not self.is_callable:
            # Constant mode: prefer the build-time value, then the one from
            # initialization; with neither there is nothing to return.
            if from_build is not Null:
                return cast("T", from_build)
            if self.param is Null:
                msg = "Expected a parameter value to be passed at build time"
                raise MissingParamException(msg)
            return cast("T", self.param)

        # Callable mode: pick the callable (build-time overrides initialization).
        if from_build is Null:
            if self.param is Null:
                msg = "Expected a callable to be passed at build time"
                raise MissingParamException(msg)
            fn = self.param
        elif callable(from_build):
            fn = from_build
        else:
            # Reject a non-callable build-time value explicitly, whether or not
            # a callable was also supplied at initialization.
            msg = "The value passed at build time is not callable"
            raise TypeError(msg)

        # Build-time kwargs win over the kwargs captured at initialization.
        return cast("Callable[..., T]", fn)(**{**self.kwargs, **kwargs})