Source code for polyfactory.fields

from __future__ import annotations

from typing import Any, Callable, Generic, TypedDict, TypeVar, cast

from typing_extensions import ParamSpec

from polyfactory.exceptions import MissingParamException, ParameterException
from polyfactory.field_meta import Null
from polyfactory.utils import deprecation
from polyfactory.utils.predicates import is_safe_subclass

T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")


class WrappedCallable(TypedDict):
    """Single-key mapping that holds a callable under the ``"value"`` key.

    Keeping the callable inside a dict, rather than storing it directly as an
    attribute, is intended to prevent it from being bound as a method.
    """

    value: Callable
class Require:
    """Mark a factory attribute as a keyword argument that is required at build time."""
class Ignore:
    """Mark a factory attribute as one that the factory should ignore."""
class Use(Generic[P, T]):
    """Factory field that wraps a callable.

    Every time the factory builds the attribute, the callable is invoked with
    the positional and keyword arguments captured at construction time.
    """

    __slots__ = ("args", "fn", "kwargs")

    def __init__(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
        """Capture a callable together with the arguments to call it with.

        :param fn: The callable to invoke at build time.
        :param args: Positional arguments forwarded to ``fn``.
        :param kwargs: Keyword arguments forwarded to ``fn``.
        """
        self.args = args
        self.kwargs = kwargs
        # Stored inside a WrappedCallable dict rather than as a bare attribute,
        # following the module's convention for preventing method binding.
        self.fn: WrappedCallable = {"value": fn}

    def to_value(self) -> T:
        """Call the wrapped callable with the captured arguments.

        :returns: Whatever the wrapped callable returns.
        """
        wrapped = self.fn["value"]
        result = wrapped(*self.args, **self.kwargs)
        return cast("T", result)
class PostGenerated:
    """Factory field whose value is computed after the factory's other fields have been generated."""

    __slots__ = ("args", "fn", "kwargs")

    def __init__(self, fn: Callable, *args: Any, **kwargs: Any) -> None:
        """Register a callback that produces the field's value post-generation.

        :param fn: Callback invoked as ``fn(name, values, *args, **kwargs)``.
        :param args: Extra positional arguments appended after ``name`` and ``values``.
        :param kwargs: Extra keyword arguments forwarded to the callback.
        """
        self.args = args
        self.kwargs = kwargs
        # Kept in a WrappedCallable dict, per the module's convention for
        # avoiding method binding.
        self.fn: WrappedCallable = {"value": fn}

    def to_value(self, name: str, values: dict[str, Any]) -> Any:
        """Run the callback against the values generated so far.

        :param name: Name of the field being generated.
        :param values: Values already generated by the factory.
        :returns: The callback's return value.
        """
        callback = self.fn["value"]
        return callback(name, values, *self.args, **self.kwargs)
class Fixture:
    """Factory field that turns a factory into a pytest fixture value."""

    __slots__ = ("kwargs", "ref", "size")

    @deprecation.deprecated(version="2.20.0", alternative="Use factory directly")
    def __init__(self, fixture: Callable, size: int | None = None, **kwargs: Any) -> None:
        """Wrap a factory so it can be used as a fixture.

        :param fixture: A factory registered as a fixture.
        :param size: When given, build a batch of this many instances instead of one.
        :param kwargs: Keyword arguments forwarded to the factory's build call.
        """
        self.size = size
        self.kwargs = kwargs
        self.ref: WrappedCallable = {"value": fixture}

    def to_value(self) -> Any:
        """Build one instance, or a batch when ``size`` was given.

        :raises: ParameterException
        :returns: The result of the factory's ``build`` or ``batch`` call.
        """
        # Imported lazily here, as in the original code.
        from polyfactory.factories.base import BaseFactory

        target = self.ref["value"]
        if not is_safe_subclass(target, BaseFactory):
            msg = "fixture has not been registered using the register_factory decorator"
            raise ParameterException(msg)

        if self.size is None:
            return target.build(**self.kwargs)
        return target.batch(self.size, **self.kwargs)
class Param(Generic[T]):
    """A constant parameter that can be used by other fields but will not be passed to the final object.

    If a value for the parameter is not passed in the field's definition, it must be passed at build time.
    Otherwise, a MissingParamException will be raised.
    """

    __slots__ = ("is_callable", "kwargs", "param")

    def __init__(
        self, param: T | Callable[..., T] | type[Null] = Null, is_callable: bool = False, **kwargs: Any
    ) -> None:
        """Designate a parameter.

        :param param: A constant value, or a callable producing the value when ``is_callable`` is True.
            Leave as ``Null`` to require the value to be supplied at build time.
        :param is_callable: Whether the parameter value is a callable that must be invoked to obtain the value.
        :param kwargs: Keyword arguments passed to the callable. Only allowed when ``is_callable`` is True.
        :raises ParameterException: If ``is_callable`` is True and a non-callable ``param`` is given, or if
            ``kwargs`` are given while ``is_callable`` is False.
        """
        if param is not Null and is_callable and not callable(param):
            msg = "If an object is passed to param, a callable must be passed when is_callable is True"
            raise ParameterException(msg)
        if not is_callable and kwargs:
            msg = "kwargs can only be used with callable parameters"
            raise ParameterException(msg)

        self.param = param
        self.is_callable = is_callable
        self.kwargs = kwargs

    def to_value(self, from_build: T | Callable[..., T] | type[Null] = Null, **kwargs: Any) -> T:
        """Determine the value to use at build time.

        A value passed at build time takes precedence over the one given to the constructor.
        If neither was given, a MissingParamException is raised. When ``is_callable`` is True,
        the selected value is called and its result is returned.

        :param from_build: The value passed at build time (if any).
        :param kwargs: Keyword arguments merged over the constructor's ``kwargs`` (winning on key
            collisions) when a callable is invoked. Ignored when ``is_callable`` is False.
        :returns: The value.
        :raises MissingParamException: If no value was passed at initialization or at build time.
        :raises TypeError: If ``is_callable`` is True and the value passed at build time is not callable.
        """
        if from_build is not Null:
            # A build-time value always overrides the one given at initialization.
            if not self.is_callable:
                return cast("T", from_build)
            # Check explicitly on every path, so that a non-callable build-time value
            # yields a clear error whether or not a param was passed at initialization.
            if not callable(from_build):
                msg = "The value passed at build time is not callable"
                raise TypeError(msg)
            return cast("Callable[..., T]", from_build)(**{**self.kwargs, **kwargs})

        # Nothing was passed at build time, so the value must come from initialization.
        if self.param is Null:
            msg = (
                "Expected a parameter value to be passed at build time"
                if not self.is_callable
                else "Expected a callable to be passed at build time"
            )
            raise MissingParamException(msg)

        if self.is_callable:
            return cast("Callable[..., T]", self.param)(**{**self.kwargs, **kwargs})
        return cast("T", self.param)