Source code for efro.dataclassio._pathcapture

# Released under the MIT License. See LICENSE for details.
#
"""Functionality related to capturing nested dataclass paths."""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING

from efro.dataclassio._base import parse_annotated, _get_origin
from efro.dataclassio._prep import PrepSession

if TYPE_CHECKING:
    from typing import Any, Callable


class _PathCapture:
    """Utility for obtaining dataclass storage paths in a type safe way.

    An instance stands in for an object while a lookup callback runs.
    Each attribute access on a capture whose class is a dataclass returns
    a new capture for that field's (origin) type, with the field's storage
    name appended to the accumulated path parts. Reading `path` afterwards
    yields those parts joined with '.'.
    """

    def __init__(self, obj: Any, pathparts: list[str] | None = None):
        """Create a capture for `obj` (a class or an instance of one).

        `pathparts` holds the storage names accumulated so far; a fresh
        empty list is used when it is not given.
        """
        self._is_dataclass = dataclasses.is_dataclass(obj)
        if pathparts is None:
            pathparts = []
        # Always track the class, whether we were handed a type or an
        # instance of one.
        self._cls = obj if isinstance(obj, type) else type(obj)
        self._pathparts = pathparts

    def __getattr__(self, name: str) -> _PathCapture:
        """Capture a lookup of field `name` and return a deeper capture.

        Python only calls this for attributes not found through normal
        lookup, so the attributes set in __init__ and the `path` property
        never reach it.

        Raises:
            TypeError: If this capture's class is not a dataclass.
            AttributeError: If the dataclass has no field called `name`.
        """
        # We only allow diving into sub-objects if we are a dataclass.
        if not self._is_dataclass:
            raise TypeError(
                f"Field path cannot include attribute '{name}' "
                f'under parent {self._cls}; parent types must be dataclasses.'
            )

        prep = PrepSession(explicit=False).prep_dataclass(
            self._cls, recursion_level=0
        )
        assert prep is not None
        try:
            anntype = prep.annotations[name]
        except KeyError as exc:
            # Name the dataclass being inspected; type(self) would always
            # just be _PathCapture, which tells the user nothing.
            raise AttributeError(f'{self._cls} has no {name} field.') from exc
        anntype, ioattrs = parse_annotated(anntype)
        # Prefer an explicit storage name from the field's io attrs, if set.
        storagename = (
            name
            if (ioattrs is None or ioattrs.storagename is None)
            else ioattrs.storagename
        )
        # NOTE(review): _get_origin presumably reduces generic aliases to
        # their base class so nested dataclass fields resolve - confirm.
        origin = _get_origin(anntype)
        # Build a new list so sibling captures never share path state.
        return _PathCapture(origin, pathparts=self._pathparts + [storagename])

    @property
    def path(self) -> str:
        """The final output path (storage names joined with '.')."""
        return '.'.join(self._pathparts)


[docs] class DataclassFieldLookup[T]: """Get info about nested dataclass fields in type-safe way.""" def __init__(self, cls: type[T]) -> None: self.cls = cls
[docs] def path(self, callback: Callable[[T], Any]) -> str: """Look up a path on child dataclass fields. example: DataclassFieldLookup(MyType).path(lambda obj: obj.foo.bar) The above example will return the string 'foo.bar' or something like 'f.b' if the dataclasses have custom storage names set. It will also be static-type-checked, triggering an error if MyType.foo.bar is not a valid path. Note, however, that the callback technically allows any return value but only nested dataclasses and their fields will succeed. """ # We tell the type system that we are returning an instance # of our class, which allows it to perform type checking on # member lookups. In reality, however, we are providing a # special object which captures path lookups, so we can build # a string from them. if not TYPE_CHECKING: out = callback(_PathCapture(self.cls)) if not isinstance(out, _PathCapture): raise TypeError( f'Expected a valid path under' f' the provided object; got a {type(out)}.' ) return out.path return ''
[docs] def paths(self, callback: Callable[[T], list[Any]]) -> list[str]: """Look up multiple paths on child dataclass fields. Functionality is identical to path() but for multiple paths at once. example: DataclassFieldLookup(MyType).paths(lambda obj: [obj.foo, obj.bar]) """ outvals: list[str] = [] if not TYPE_CHECKING: outs = callback(_PathCapture(self.cls)) assert isinstance(outs, list) for out in outs: if not isinstance(out, _PathCapture): raise TypeError( f'Expected a valid path under' f' the provided object; got a {type(out)}.' ) outvals.append(out.path) return outvals
# Docs-generation hack; import some stuff that we likely only forward-declared # in our actual source code so that docs tools can find it. from typing import (Coroutine, Any, Literal, Callable, Generator, Awaitable, Sequence, Self) import asyncio from concurrent.futures import Future from pathlib import Path from enum import Enum