# Released under the MIT License. See LICENSE for details.
#
"""Functionality for prepping types for use with dataclassio."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
#
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
import logging
from enum import Enum
import dataclasses
import typing
import types
import datetime
from typing import TYPE_CHECKING, get_type_hints
# noinspection PyProtectedMember
from efro.dataclassio._base import (
_parse_annotated,
_get_origin,
SIMPLE_TYPES,
IOMultiType,
)
if TYPE_CHECKING:
from typing import Any
from efro.dataclassio._base import IOAttrs
# How deep we go when prepping nested types (basically for detecting
# recursive types)
MAX_RECURSION = 10
# Attr name for data we store on dataclass types that have been prepped.
PREP_ATTR = '_DCIOPREP'
# We also store the prep-session while the prep is in progress.
# (necessary to support recursive types).
PREP_SESSION_ATTR = '_DCIOPREPSESSION'
def ioprep(cls: type, globalns: dict | None = None) -> None:
    """Prep a dataclass type for use with this module's functionality.

    Prepping ensures that all types contained in a data class as well
    as the usage of said types are supported by this module and
    pre-builds necessary constructs needed for encoding/decoding/etc.

    Prepping will happen on-the-fly as needed, but a warning will be
    emitted in such cases, as it is better to explicitly prep all used
    types early in a process to ensure any invalid types or
    configuration are caught immediately.

    Prepping a dataclass involves evaluating its type annotations,
    which, as of PEP 563, are stored simply as strings. This evaluation
    is done with localns set to the class dict (so that types defined
    in the class can be used) and globalns set to the containing
    module's class. It is possible to override globalns for special
    cases such as when prepping happens as part of an execed string
    instead of within a module.
    """
    # An explicit prep always starts a fresh session at recursion
    # depth zero.
    session = PrepSession(explicit=True, globalns=globalns)
    session.prep_dataclass(cls, recursion_level=0)
def ioprepped[T](cls: type[T]) -> type[T]:
"""Class decorator for easily prepping a dataclass at definition time.
Note that in some cases it may not be possible to prep a dataclass
immediately (such as when its type annotations refer to forward-declared
types). In these cases, dataclass_prep() should be explicitly called for
the class as soon as possible; ideally at module import time to expose any
errors as early as possible in execution.
"""
ioprep(cls)
return cls
def will_ioprep[T](cls: type[T]) -> type[T]:
"""Class decorator hinting that we will prep a class later.
In some cases (such as recursive types) we cannot use the @ioprepped
decorator and must instead call ioprep() explicitly later. However,
some of our custom pylint checking behaves differently when the
@ioprepped decorator is present, in that case requiring type annotations
to be present and not simply forward declared under an "if TYPE_CHECKING"
block. (since they are used at runtime).
The @will_ioprep decorator triggers the same pylint behavior
differences as @ioprepped (which are necessary for the later ioprep() call
to work correctly) but without actually running any prep itself.
"""
return cls
def is_ioprepped_dataclass(obj: Any) -> bool:
    """Return whether the obj is an ioprepped dataclass type or instance."""
    # Accept either a type or an instance; normalize to the type.
    if isinstance(obj, type):
        cls = obj
    else:
        cls = type(obj)
    # Must be a dataclass to begin with.
    if not dataclasses.is_dataclass(cls):
        return False
    # Prepped dataclasses have prep-data stashed on them.
    return hasattr(cls, PREP_ATTR)
@dataclasses.dataclass
class PrepData:
    """Data we prepare and cache for a class during prep.

    This data is used as part of the encoding/decoding/validating process.
    """

    # The class' field annotations, resolved from (string) annotations
    # into 'live' type objects.
    annotations: dict[str, Any]

    # Map of storage names (the keys used in serialized output) to the
    # dataclass attr names they correspond to.
    storage_names_to_attr_names: dict[str, str]
class PrepSession:
    """Context for a prep.

    A single session is shared through the full recursive prep of a
    dataclass and any types it contains; this is what lets us detect
    and support recursive type references.
    """

    def __init__(self, explicit: bool, globalns: dict | None = None):
        # Whether the prep was requested directly by the user; implicit
        # preps emit a warning.
        self.explicit = explicit
        # Optional globals override used when resolving string
        # annotations (see ioprep()).
        self.globalns = globalns

    def prep_dataclass(
        self, cls: type, recursion_level: int
    ) -> PrepData | None:
        """Run prep on a dataclass if necessary and return its prep data.

        The only case where this will return None is for recursive types
        if the type is already being prepped higher in the call order.
        """
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches

        # We should only need to do this once per dataclass.
        existing_data = getattr(cls, PREP_ATTR, None)
        if existing_data is not None:
            assert isinstance(existing_data, PrepData)
            return existing_data

        # Sanity check.
        #
        # Note that we now support recursive types via the
        # PREP_SESSION_ATTR, so we theoretically shouldn't run into
        # this.
        if recursion_level > MAX_RECURSION:
            raise RuntimeError('Max recursion exceeded.')

        # We should only be passed classes which are dataclasses.
        cls_any: Any = cls
        if not isinstance(cls_any, type) or not dataclasses.is_dataclass(cls):
            raise TypeError(f'Passed arg {cls} is not a dataclass type.')

        # Add a pointer to the prep-session while doing the prep. This
        # way we can ignore types that we're already in the process of
        # prepping and can support recursive types.
        existing_prep = getattr(cls, PREP_SESSION_ATTR, None)
        if existing_prep is not None:
            if existing_prep is self:
                return None
            # We shouldn't need to support failed preps or preps from
            # multiple threads at once.
            raise RuntimeError('Found existing in-progress prep.')
        setattr(cls, PREP_SESSION_ATTR, self)

        # Generate a warning on non-explicit preps; we prefer prep to
        # happen explicitly at runtime so errors can be detected early
        # on.
        if not self.explicit:
            logging.warning(
                'efro.dataclassio: implicitly prepping dataclass: %s.'
                ' It is highly recommended to explicitly prep dataclasses'
                ' as soon as possible after definition (via'
                ' efro.dataclassio.ioprep() or the'
                ' @efro.dataclassio.ioprepped decorator).',
                cls,
            )

        try:
            # NOTE: Passing the class' __dict__ (vars()) as locals
            # allows us to pick up nested classes, etc.
            resolved_annotations = get_type_hints(
                cls,
                localns=vars(cls),
                globalns=self.globalns,
                include_extras=True,
            )
        except Exception as exc:
            raise TypeError(
                f'dataclassio prep for {cls} failed with error: {exc}.'
                f' Make sure all types used in annotations are defined'
                f' at the module or class level or add them as part of an'
                f' explicit prep call.'
            ) from exc

        # noinspection PyDataclass
        fields = dataclasses.fields(cls)
        fields_by_name = {f.name: f for f in fields}

        all_storage_names: set[str] = set()
        storage_names_to_attr_names: dict[str, str] = {}

        # Ok; we've resolved actual types for this dataclass. now
        # recurse through them, verifying that we support all contained
        # types and prepping any contained dataclass types.
        for attrname, anntype in resolved_annotations.items():
            anntype, ioattrs = _parse_annotated(anntype)

            # If we found attached IOAttrs data, make sure it contains
            # valid values for the field it is attached to.
            if ioattrs is not None:
                ioattrs.validate_for_field(cls, fields_by_name[attrname])
                if ioattrs.storagename is not None:
                    storagename = ioattrs.storagename
                    storage_names_to_attr_names[ioattrs.storagename] = attrname
                else:
                    storagename = attrname
            else:
                storagename = attrname

            # Make sure we don't have any clashes in our storage names.
            if storagename in all_storage_names:
                raise TypeError(
                    f'Multiple attrs on {cls} are using'
                    f' storage-name \'{storagename}\''
                )
            all_storage_names.add(storagename)

            self.prep_type(
                cls,
                attrname,
                anntype,
                ioattrs=ioattrs,
                recursion_level=recursion_level + 1,
            )

        # Success! Store our resolved stuff with the class and we're
        # done.
        prepdata = PrepData(
            annotations=resolved_annotations,
            storage_names_to_attr_names=storage_names_to_attr_names,
        )
        setattr(cls, PREP_ATTR, prepdata)

        # Clear our prep-session tag.
        assert getattr(cls, PREP_SESSION_ATTR, None) is self
        delattr(cls, PREP_SESSION_ATTR)
        return prepdata

    def prep_type(
        self,
        cls: type,
        attrname: str,
        anntype: Any,
        ioattrs: IOAttrs | None,
        recursion_level: int,
    ) -> None:
        """Run prep on a single annotated type found on a dataclass attr.

        Verifies the type is supported by dataclassio, recursively
        prepping contained types (and dataclasses) as needed. Raises
        TypeError for unsupported types.
        """
        # pylint: disable=too-many-positional-arguments
        # pylint: disable=too-many-return-statements
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements

        if recursion_level > MAX_RECURSION:
            raise RuntimeError('Max recursion exceeded.')

        origin = _get_origin(anntype)

        # If we inherit from IOMultiType, we use its type map to
        # determine which type we're going to instead of the annotation.
        # And we can't really check those types because they are
        # lazy-loaded. So I guess we're done here.
        if issubclass(origin, IOMultiType):
            return

        # noinspection PyPep8
        if origin is typing.Union or origin is types.UnionType:
            self.prep_union(
                cls, attrname, anntype, recursion_level=recursion_level + 1
            )
            return

        if anntype is typing.Any:
            return

        # Everything below this point assumes the annotation type
        # resolves to a concrete type.
        if not isinstance(origin, type):
            raise TypeError(
                f'Unsupported type found for \'{attrname}\' on {cls}:'
                f' {anntype}'
            )

        # If a soft_default value/factory was passed, we do some basic
        # type checking on the top-level value here. We also run full
        # recursive validation on values later during inputting, but
        # this should catch at least some errors early on, which can be
        # useful since soft_defaults are not static type checked.
        if ioattrs is not None:
            have_soft_default = False
            soft_default: Any = None
            if ioattrs.soft_default is not ioattrs.MISSING:
                have_soft_default = True
                soft_default = ioattrs.soft_default
            elif ioattrs.soft_default_factory is not ioattrs.MISSING:
                assert callable(ioattrs.soft_default_factory)
                have_soft_default = True
                soft_default = ioattrs.soft_default_factory()

            # Do a simple type check for the top level to catch basic
            # soft_default mismatches early; full check will happen at
            # input time.
            if have_soft_default:
                if not isinstance(soft_default, origin):
                    raise TypeError(
                        f'{cls} attr {attrname} has type {origin}'
                        f' but soft_default value is type {type(soft_default)}'
                    )

        if origin in SIMPLE_TYPES:
            return

        # For sets and lists, check out their single contained type (if
        # any).
        if origin in (list, set):
            childtypes = typing.get_args(anntype)
            if len(childtypes) == 0:
                # This is equivalent to Any; nothing else needs
                # checking.
                return
            if len(childtypes) > 1:
                raise TypeError(
                    f'Unrecognized typing arg count {len(childtypes)}'
                    f" for {anntype} attr '{attrname}' on {cls}"
                )
            self.prep_type(
                cls,
                attrname,
                childtypes[0],
                ioattrs=None,
                recursion_level=recursion_level + 1,
            )
            return

        if origin is dict:
            childtypes = typing.get_args(anntype)
            assert len(childtypes) in (0, 2)

            # For key types we support Any, str, int,
            # and Enums with uniform str/int values.
            if not childtypes or childtypes[0] is typing.Any:
                # 'Any' needs no further checks (just checked
                # per-instance).
                pass
            elif childtypes[0] in (str, int):
                # str and int are all good as keys.
                pass
            elif issubclass(childtypes[0], Enum):
                # Allow our usual str or int enum types as keys.
                self.prep_enum(childtypes[0], ioattrs=None)
            else:
                raise TypeError(
                    f'Dict key type {childtypes[0]} for \'{attrname}\''
                    f' on {cls.__name__} is not supported by dataclassio.'
                )

            # For value types we support any of our normal types.
            if not childtypes or _get_origin(childtypes[1]) is typing.Any:
                # 'Any' needs no further checks (just checked
                # per-instance).
                pass
            else:
                self.prep_type(
                    cls,
                    attrname,
                    childtypes[1],
                    ioattrs=None,
                    recursion_level=recursion_level + 1,
                )
            return

        # For Tuples, simply check individual member types. (and, for
        # now, explicitly disallow zero member types or usage of
        # ellipsis)
        if origin is tuple:
            childtypes = typing.get_args(anntype)
            if not childtypes:
                raise TypeError(
                    f'Tuple at \'{attrname}\''
                    f' has no type args; dataclassio requires type args.'
                )
            if childtypes[-1] is ...:
                raise TypeError(
                    f'Found ellipsis as part of type for'
                    f' \'{attrname}\' on {cls.__name__};'
                    f' these are not'
                    f' supported by dataclassio.'
                )
            for childtype in childtypes:
                self.prep_type(
                    cls,
                    attrname,
                    childtype,
                    ioattrs=None,
                    recursion_level=recursion_level + 1,
                )
            return

        if issubclass(origin, Enum):
            self.prep_enum(origin, ioattrs=ioattrs)
            return

        # We allow datetime objects (and google's extended subclass of
        # them used in firestore, which is why we don't look for exact
        # type here).
        if issubclass(origin, datetime.datetime):
            return

        # We support datetime.timedelta.
        if issubclass(origin, datetime.timedelta):
            return

        if dataclasses.is_dataclass(origin):
            self.prep_dataclass(origin, recursion_level=recursion_level + 1)
            return

        if origin is bytes:
            return

        raise TypeError(
            f"Attr '{attrname}' on {cls.__name__} contains"
            f" type '{anntype}'"
            f' which is not supported by dataclassio.'
        )

    def prep_union(
        self, cls: type, attrname: str, anntype: Any, recursion_level: int
    ) -> None:
        """Run prep on a Union type.

        Only two-member unions with exactly one None member (i.e.
        Optional types) are supported; anything else raises TypeError.
        """
        typeargs = typing.get_args(anntype)
        if (
            len(typeargs) != 2
            or len([c for c in typeargs if c is type(None)]) != 1
        ):  # noqa
            raise TypeError(
                f'Union {anntype} for attr \'{attrname}\' on'
                f' {cls.__name__} is not supported by dataclassio;'
                f' only 2 member Unions with one type being None'
                f' are supported.'
            )
        for childtype in typeargs:
            self.prep_type(
                cls,
                attrname,
                childtype,
                None,
                recursion_level=recursion_level + 1,
            )

    def prep_enum(
        self,
        enumtype: type[Enum],
        ioattrs: IOAttrs | None,
    ) -> None:
        """Run prep on an enum type.

        Verifies that all enum values are str or int and of a single
        uniform type, and that any provided enum_fallback matches the
        enum type; raises TypeError otherwise.
        """
        valtype: Any = None

        # We currently support enums with str or int values; fail if we
        # find any others.
        for enumval in enumtype:
            if not isinstance(enumval.value, (str, int)):
                raise TypeError(
                    f'Enum value {enumval} has value type'
                    f' {type(enumval.value)}; only str and int is'
                    f' supported by dataclassio.'
                )
            if valtype is None:
                valtype = type(enumval.value)
            else:
                if type(enumval.value) is not valtype:
                    raise TypeError(
                        f'Enum type {enumtype} has multiple'
                        f' value types; dataclassio requires'
                        f' them to be uniform.'
                    )

        if ioattrs is not None:
            # If they provided a fallback enum value, make sure it
            # is the correct type.
            if ioattrs.enum_fallback is not None:
                if type(ioattrs.enum_fallback) is not enumtype:
                    raise TypeError(
                        f'enum_fallback {ioattrs.enum_fallback} does not'
                        f' match the field type ({enumtype}).'
                    )
# Docs-generation hack; import some stuff that we likely only forward-declared
# in our actual source code so that docs tools can find it.
from typing import (Coroutine, Any, Literal, Callable,
Generator, Awaitable, Sequence, Self)
import asyncio
from concurrent.futures import Future