#!/usr/bin/env python3
#
# Copyright 2021 Graviti. Licensed under MIT License.
#
"""AttrsMixin and Field class.
:class:`AttrsMixin` provides a list of special methods based on field configs.
:class:`Field` is a class describing the attr related fields.
"""
from sys import version_info
from typing import (
Any,
Callable,
Dict,
List,
NoReturn,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
no_type_check,
)
from typing_extensions import Protocol
from ..exception import AttrError
# Type variable for operators that hand back their input unchanged.
_T = TypeVar("_T")

# Signature of a loader/dumper operator, and of a ``key`` converter such as ``upper``/``camel``.
_Callable = Callable[[Any], Any]
_KeyConverter = Callable[[str], str]

# Module names whose annotations are loaded/dumped as-is. ``None`` is the value
# ``getattr(obj, "__module__", None)`` yields for objects without a module (including a
# missing sequence type).
_BUILTINS = {None, "builtins", "typing"}

# Name of the class attribute holding the base-class field created by ``attr_base``.
_ATTRS_BASE = "_attrs_base"

# Default AttributeError message template, formatted with the class and attr names.
_DEFAULT_ERROR_MESSAGE = "'{class_name}' object has no attribute '{attr_name}'"
class _A(Protocol):  # pylint: disable=too-few-public-methods
    """Structural type of the objects handled by the attr dumpers.

    Any object exposing a ``dumps`` method satisfies this protocol.
    """

    def dumps(self) -> Any:
        """Dumps all the information of the object's attrs into a dict."""
class Field:  # pylint: disable=too-few-public-methods, too-many-instance-attributes
    """A class to identify attr fields.

    ``loader`` and ``dumper`` are only declared here (not assigned); they must be
    set before the field is used to load or dump values.

    Arguments:
        is_dynamic: Whether attr is a dynamic attr.
        key: Display value of the attr in contents. A callable key is stored as
            ``key_converter`` instead of ``key``, so the real key can be derived later.
        default: Default value of the attr.
        error_message: The custom error message of the attr; stored only when truthy.

    """

    # remove __slots__ and add __getattr__ due to: https://github.com/PyCQA/pylint/issues/4341
    # __slots__ = ("is_dynamic", "key", "default", "error_message", "loader", "dumper")

    def __init__(
        self,
        is_dynamic: bool,
        key: Union[str, None, _KeyConverter],
        default: Any,
        error_message: Optional[str],
    ) -> None:
        self.is_dynamic = is_dynamic
        self.default = default

        # Declared only; no value is assigned here.
        self.loader: _Callable
        self.dumper: _Callable

        # Exactly one of ``key`` / ``key_converter`` ends up on the instance.
        if not callable(key):
            self.key = key
        else:
            self.key_converter = key

        # Without a truthy message the attribute stays absent, so lookups fall back.
        if error_message:
            self.error_message = error_message

    def __getattr__(self, name: str) -> NoReturn:
        """Raise AttributeError for any attribute that was never set on the field.

        Arguments:
            name: Name of the missing attribute.

        Raises:
            AttributeError: Always.

        """
        message = _DEFAULT_ERROR_MESSAGE.format(class_name=type(self).__name__, attr_name=name)
        raise AttributeError(message)
class BaseField:  # pylint: disable=too-few-public-methods
    """A class to identify fields of base class.

    Only ``key`` is assigned on construction; ``loader`` and ``dumper`` are
    declared but left unset.

    Arguments:
        key: Display value of the attr.

    """

    def __init__(self, key: Optional[str]) -> None:
        self.key = key
        # Declared only; assigned elsewhere before use.
        self.loader: _Callable
        self.dumper: _Callable
class AttrsMixin:
    """AttrsMixin provides a list of special methods based on attr fields.

    A subclass declares its fields as annotated class attributes whose values come
    from :func:`attr` (and optionally ``_attrs_base`` from :func:`attr_base`). When the
    subclass is defined, ``__init_subclass__`` collects those :class:`Field` objects
    into ``_attrs_fields`` and deletes the placeholder class attributes; ``_loads``
    and ``_dumps`` then use the collected fields.

    Examples:
        box2d: Box2DSubcatalog = attr(is_dynamic=True, key="BOX2D")

    """

    _attrs_fields: Dict[str, Field]
    _attrs_base: Any

    def __init_subclass__(cls) -> None:
        """Collect the attr fields declared on a newly defined subclass.

        Fields of the direct bases are merged first. Then every annotated class
        attribute holding a :class:`Field` gets its loader/dumper from its annotation
        and its key resolved, and is removed from the class namespace.
        """
        # Let other bases in the MRO (e.g. ``typing.Generic``) run their own hooks.
        super().__init_subclass__()

        # Use only the annotations declared on ``cls`` itself. Before Python 3.10,
        # ``cls.__annotations__`` resolves through the MRO when ``cls`` declares none,
        # so the ``pop`` below would read and mutate a base class's annotations
        # (e.g. AttrsMixin's own ``_attrs_base: Any``).
        if version_info < (3, 10):
            annotations = cls.__dict__.get("__annotations__", {})
        else:
            annotations = cls.__annotations__

        # The ``_attrs_base`` annotation names the type whose ``_loads``/``dumps``
        # handle the base-class part of the contents.
        type_ = annotations.pop(_ATTRS_BASE, None)
        if type_:
            cls._attrs_base.loader = type_._loads  # pylint: disable=protected-access
            cls._attrs_base.dumper = type_.dumps

        attrs_fields = {}
        for base_class in cls.__bases__:
            base_fields = getattr(base_class, "_attrs_fields", None)
            if base_fields:
                attrs_fields.update(base_fields)

        for name, type_ in annotations.items():
            field = getattr(cls, name, None)
            if not isinstance(field, Field):
                continue
            field.loader, field.dumper = _get_operators(type_)
            # A callable ``key`` given to attr() derives the key from the attr name.
            if hasattr(field, "key_converter"):
                field.key = field.key_converter(name)
            attrs_fields[name] = field
            delattr(cls, name)

        cls._attrs_fields = attrs_fields

    def __eq__(self, other: object) -> bool:
        """Compare by instance ``__dict__``; ``other`` must be an instance of this class.

        Defining ``__eq__`` without ``__hash__`` makes instances unhashable.
        """
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    @no_type_check
    def __getattr__(self, name: str) -> NoReturn:
        """Raise an AttributeError exception when an attr does not exist.

        If ``name`` is a declared attr field with a custom ``error_message``, that
        message is used; otherwise the default message is used.

        Arguments:
            name: Name of the attr.

        Raises:
            AttributeError: When an attr does not exist.

        """
        # Read the field table from the class. ``self._attrs_fields`` would re-enter
        # __getattr__ and recurse forever on an instance of a class that never went
        # through ``__init_subclass__`` (i.e. AttrsMixin itself).
        attrs_fields = getattr(type(self), "_attrs_fields", {})
        error_message = getattr(attrs_fields.get(name), "error_message", _DEFAULT_ERROR_MESSAGE)
        raise AttributeError(
            error_message.format(class_name=self.__class__.__name__, attr_name=name)
        )

    def _loads(self, contents: Any) -> None:
        """Load an instance from a dict containing the information of the contents.

        The base-class part (if ``_attrs_base`` is set) is loaded first, from
        ``contents[base.key]``, or from the whole ``contents`` when that key is ``None``.
        Then each attr field is set on the instance:

        * a dynamic field whose key is absent from ``contents`` is skipped;
        * a field whose key is ``None`` is loaded from the whole ``contents``;
        * a field with a default falls back to it when its key is absent;
        * otherwise ``contents[field.key]`` is required.

        Arguments:
            contents: A dict containing all data.

        """
        base = getattr(self, _ATTRS_BASE, None)
        if base:
            value = contents if base.key is None else contents[base.key]
            base.loader(self, value)

        for name, field in self._attrs_fields.items():
            if field.is_dynamic and field.key not in contents:
                continue

            if field.key is None:
                value = contents
            elif field.default is not ...:
                value = contents.get(field.key, field.default)
            else:
                value = contents[field.key]

            setattr(self, name, field.loader(value))

    def _dumps(self) -> Dict[str, Any]:
        """Dumps all the information of attrs into a dict.

        Unset dynamic fields and fields whose value equals their default are
        omitted. A ``None`` key merges the dumped value into the top level of the result.

        Returns:
            contents: A dict containing all data of attrs.

        """
        contents: Dict[str, Any] = {}
        base = getattr(self, _ATTRS_BASE, None)
        if base:
            _key_dumper(base.key, contents, base.dumper(self))

        for name, field in self._attrs_fields.items():
            if field.is_dynamic and not hasattr(self, name):
                continue

            value = getattr(self, name)
            if value == field.default:
                continue

            _key_dumper(field.key, contents, field.dumper(value))

        return contents
def attr(
    *,
    is_dynamic: bool = False,
    key: Union[str, None, _KeyConverter] = lambda x: x,
    default: Any = ...,
    error_message: Optional[str] = None,
) -> Any:
    """Return an instance to identify attr fields.

    Arguments:
        is_dynamic: Determine if this is a dynamic attr.
        key: Display value of the attr in contents. The default is an identity
            converter, so the attr name itself becomes the key.
        default: Default value of the attr; ``...`` means no default.
        error_message: The custom error message of the attr.

    Raises:
        AttrError: Dynamic attr cannot have default value.

    Returns:
        A :class:`Field` instance containing all attr fields.

    """
    has_default = default is not ...
    # Dynamic attrs and default values are mutually exclusive.
    if is_dynamic and has_default:
        raise AttrError()

    return Field(
        is_dynamic=is_dynamic,
        key=key,
        default=default,
        error_message=error_message,
    )
def attr_base(key: Optional[str] = None) -> Any:
    """Return an instance to identify base class fields.

    Arguments:
        key: Display value of the attr. With ``None`` (the default), the base-class
            part is read from, and dumped into, the top level of the contents.

    Returns:
        A :class:`BaseField` instance containing all base class fields.

    """
    return BaseField(key=key)
def upper(name: str) -> str:
    """Convert the name value to uppercase, e.g. ``"box2d"`` -> ``"BOX2D"``.

    Intended as a ``key`` converter for :func:`attr`.

    Arguments:
        name: name of the attr.

    Returns:
        The uppercase value.

    """
    return str.upper(name)
def camel(name: str) -> str:
    """Convert the name value to camelcase.

    The name is title-cased, its underscores are removed, and the first character
    is lower-cased, e.g. ``"is_tracking"`` -> ``"isTracking"``. Because
    :meth:`str.title` also upper-cases a letter that follows a digit, ``"box2d"``
    becomes ``"box2D"``.

    Arguments:
        name: name of the attr.

    Returns:
        The camelcase value; an empty string when ``name`` is empty or made only
        of underscores.

    """
    mixed = name.title().replace("_", "")
    # Slicing (instead of ``mixed[0]``) keeps an empty result from raising IndexError.
    return mixed[:1].lower() + mixed[1:]
def _key_dumper(key: Optional[str], contents: Dict[str, Any], value: Any) -> None:
    """Write ``value`` into ``contents`` in place.

    With a key, the value is stored under that key. With ``None``, ``value`` is
    expected to be a mapping and is merged into the top level of ``contents``.
    """
    if key is not None:
        contents[key] = value
        return
    contents.update(value)
def _get_origin_in_3_7(annotation: Any) -> Any:
    """Return ``annotation.__origin__`` (e.g. ``list`` for ``List[int]``), else ``None``."""
    try:
        return annotation.__origin__
    except AttributeError:
        return None
def _get_origin_in_3_6(annotation: Any) -> Any:
    """Return the runtime origin class of ``annotation`` under the Python 3.6 typing module.

    Annotations from ``builtins``/``typing`` (or without a module) are read via
    ``__extra__``; all others via ``__origin__``. ``None`` is returned when the
    chosen attribute is missing.
    """
    from_builtins = getattr(annotation, "__module__", None) in _BUILTINS
    origin_attribute = "__extra__" if from_builtins else "__origin__"
    return getattr(annotation, origin_attribute, None)
# Pick the origin resolver matching the running interpreter's typing implementation.
if version_info < (3, 7):
    _get_origin = _get_origin_in_3_6
else:
    _get_origin = _get_origin_in_3_7
def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
    """Get attr operating methods by annotations.

    AttrsMixin has three operating types which are classified by attr annotation:

    1. builtin types, like str, int, None: values are loaded and dumped unchanged.
    2. tensorbay custom class, like tensorbay.label.Classification: loaded with the
       class's ``loads`` and dumped with the object's ``dumps``.
    3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D]:
       each element is loaded with ``loads`` and the results are wrapped in the
       sequence's origin class; dumping yields a list of ``dumps`` results.

    Only the first type argument of a sequence annotation is considered.

    Arguments:
        annotation: Type of the attr.

    Returns:
        A ``(loader, dumper)`` pair of operating methods for the annotation.

    """
    origin = _get_origin(annotation)
    is_sequence = isinstance(origin, type) and issubclass(origin, Sequence)

    container: Any = origin if is_sequence else None
    element = annotation.__args__[0] if is_sequence else annotation

    # Container (if any) and element both come from builtins/typing or have no module.
    modules = {
        getattr(container, "__module__", None),
        getattr(element, "__module__", None),
    }
    if modules < _BUILTINS:
        return _builtin_operator, _builtin_operator

    if container is None:
        return element.loads, _attr_dumper

    def _load_sequence(contents: Any) -> Any:
        return container(element.loads(content) for content in contents)

    return _load_sequence, _attr_list_dumper
def _builtin_operator(contents: _T) -> _T:
    """Identity operator: hand ``contents`` back untouched (no copy, no conversion)."""
    return contents
def _attr_dumper(attr_: _A) -> Any:
    """Dump a single attr object by calling its ``dumps`` method."""
    dumps = attr_.dumps
    return dumps()
def _attr_list_dumper(attrs: List[_A]) -> List[Any]:
    """Dump every attr object in ``attrs``, returning the results as a new list in order."""
    return [element.dumps() for element in attrs]