import abc
import collections
import copy
import enum
import re
from eth_utils import (
to_dict,
to_set,
to_tuple,
)
from rlp.exceptions import (
ListSerializationError,
ObjectSerializationError,
ListDeserializationError,
ObjectDeserializationError,
)
from .lists import (
List,
)
class MetaBase:
    """
    Default base for the ``Meta`` type that ``SerializableBase`` attaches to
    each class as ``_meta``.

    All four attributes are placeholders here; the metaclass overrides them
    on the generated ``Meta`` subclass:

    * ``fields`` - tuple of ``(name, sedes)`` tuples
    * ``field_names`` - the names taken from ``fields``
    * ``field_attrs`` - private attribute names that hold the field values
    * ``sedes`` - a ``List`` sedes wrapping the per-field sedes
    """
    fields = field_names = field_attrs = sedes = None
def _get_duplicates(values):
counts = collections.Counter(values)
return tuple(
item
for item, num in counts.items()
if num > 1
)
def validate_args_and_kwargs(args, kwargs, arg_names, allow_missing=False):
    """
    Check that ``args`` and ``kwargs`` together form a coherent call against
    the ordered parameter names in ``arg_names``.

    The checks run in this order and the first failure wins:

    1. ``arg_names`` itself must not contain duplicates.
    2. No keyword may name a parameter already filled positionally.
    3. Every keyword must be one of ``arg_names``.
    4. Unless ``allow_missing`` is true, every parameter not filled
       positionally must be present in ``kwargs``.

    :raises TypeError: when any of the checks above fails.
    """
    repeated_names = _get_duplicates(arg_names)
    if repeated_names:
        raise TypeError("Duplicate argument names: {0}".format(sorted(repeated_names)))

    positional_count = len(args)
    supplied = set(kwargs)

    # Parameters consumed by positional values may not be passed again by name.
    clashing = set(arg_names[:positional_count]) & supplied
    if clashing:
        raise TypeError("Duplicate kwargs: {0}".format(sorted(clashing)))

    unexpected = supplied - set(arg_names)
    if unexpected:
        raise TypeError("Unknown kwargs: {0}".format(sorted(unexpected)))

    if allow_missing:
        return

    absent = set(arg_names[positional_count:]) - supplied
    if absent:
        raise TypeError("Missing kwargs: {0}".format(sorted(absent)))
@to_tuple
def merge_kwargs_to_args(args, kwargs, arg_names, allow_missing=False):
    """
    Flatten a mixed positional/keyword call into positional values ordered
    by ``arg_names``.

    The inputs are validated with ``validate_args_and_kwargs`` first.  The
    positional values are emitted unchanged, followed by ``kwargs[name]`` for
    every name in ``arg_names`` that was not covered positionally.  A name
    that is absent from ``kwargs`` (only possible with ``allow_missing``)
    raises ``KeyError`` when it is reached.

    The generator is wrapped by ``to_tuple``, which is expected to collect
    the yielded values into a tuple.
    """
    validate_args_and_kwargs(args, kwargs, arg_names, allow_missing=allow_missing)

    yield from args

    # Whatever the positional values did not cover has to come from kwargs.
    trailing_names = arg_names[len(args):]
    yield from (kwargs[name] for name in trailing_names)
@to_dict
def merge_args_to_kwargs(args, kwargs, arg_names, allow_missing=False):
    """
    Express a mixed positional/keyword call purely as ``name -> value``
    pairs.

    The inputs are validated with ``validate_args_and_kwargs`` first.  All
    entries of ``kwargs`` are emitted, followed by each positional value
    paired with the name at the same position in ``arg_names``.

    The generator is wrapped by ``to_dict``, which is expected to build a
    dict from the yielded pairs.
    """
    validate_args_and_kwargs(args, kwargs, arg_names, allow_missing=allow_missing)

    yield from kwargs.items()
    # zip stops at the shorter input, so only names that actually received a
    # positional value are emitted here.
    yield from zip(arg_names, args)
def _eq(left, right):
"""
Equality comparison that allows for equality between tuple and list types
with equivalent elements.
"""
if isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
return len(left) == len(right) and all(_eq(*pair) for pair in zip(left, right))
else:
return left == right
class ChangesetState(enum.Enum):
    """
    Lifecycle states of a changeset.

    ``BaseChangeset`` moves through them in one direction only:
    ``INITIALIZED`` -> ``OPEN`` -> ``CLOSED``.
    """
    # Assigned by `BaseChangeset.__init__`; the changeset has not been opened.
    INITIALIZED = 'INITIALIZED'
    # Set by `BaseChangeset.open`; the only state in which field access and
    # `build_rlp` are permitted.
    OPEN = 'OPEN'
    # Set by `BaseChangeset.close`.
    CLOSED = 'CLOSED'
class ChangesetField:
    """
    Data descriptor that exposes a single field on a changeset class.

    Reads return the pending value recorded in the changeset's ``__diff__``
    when there is one, and otherwise the value from the wrapped
    ``__original__`` object.  Writes are recorded in ``__diff__``.  Both are
    only permitted while the changeset is in the ``OPEN`` state.
    """
    # Name of the field this descriptor stands for.
    field = None

    def __init__(self, field):
        self.field = field

    def __get__(self, instance, type=None):
        # Access through the class hands back the descriptor itself.
        if instance is None:
            return self

        if instance.__state__ is not ChangesetState.OPEN:
            raise AttributeError("Changeset is not active. Attribute access not allowed")

        pending = instance.__diff__
        if self.field in pending:
            return pending[self.field]
        # No pending change: fall through to the untouched original.
        return getattr(instance.__original__, self.field)

    def __set__(self, instance, value):
        if instance.__state__ is ChangesetState.OPEN:
            instance.__diff__[self.field] = value
        else:
            raise AttributeError("Changeset is not active. Attribute access not allowed")
class BaseChangeset:
    """
    Collects pending field changes against an existing object and builds a
    new object from them.

    A changeset starts ``INITIALIZED``, must be opened (``open()`` or a
    ``with`` block) before it can be used, and is ``CLOSED`` afterwards.
    The wrapped object itself is never modified; ``build_rlp``/``commit``
    construct a fresh instance of its type.
    """
    # reference to the original Serializable instance.
    __original__ = None
    # the state of this changeset. Initialized -> Open -> Closed
    __state__ = None
    # the field changes that have been made in this change
    __diff__ = None

    def __init__(self, obj, changes=None):
        """
        :param obj: the object the changes apply to.
        :param changes: optional mapping of field name to new value.  A falsy
            value (``None`` or an empty mapping) is replaced by a fresh dict;
            a non-empty mapping is stored by reference.
        """
        self.__original__ = obj
        self.__state__ = ChangesetState.INITIALIZED
        self.__diff__ = changes or {}

    def commit(self):
        """
        Build the updated object and close the changeset.

        :raises ValueError: if the changeset is not ``OPEN``.
        """
        obj = self.build_rlp()
        self.close()
        return obj

    def build_rlp(self):
        """
        Return a new instance of the original's type, taking each field from
        the pending changes when present and from the original otherwise.

        :raises ValueError: if the changeset is not ``OPEN``.
        """
        if self.__state__ != ChangesetState.OPEN:
            # This used to reuse the message from `open()`, which misreported
            # the operation that failed.
            raise ValueError("Cannot build rlp from Changeset which is not in the OPEN state")

        field_kwargs = {
            name: self.__diff__.get(name, self.__original__[name])
            for name
            in self.__original__._meta.field_names
        }
        return type(self.__original__)(**field_kwargs)

    def open(self):
        """
        Move from ``INITIALIZED`` to ``OPEN``.

        :raises ValueError: from any other state.
        """
        if self.__state__ == ChangesetState.INITIALIZED:
            self.__state__ = ChangesetState.OPEN
        else:
            raise ValueError("Cannot open Changeset which is not in the INITIALIZED state")

    def close(self):
        """
        Move from ``OPEN`` to ``CLOSED``.

        :raises ValueError: from any other state.
        """
        if self.__state__ == ChangesetState.OPEN:
            self.__state__ = ChangesetState.CLOSED
        else:
            raise ValueError("Cannot close Changeset which is not in the OPEN state")

    def __enter__(self):
        # `open()` performs the state check and raises the same ValueError the
        # context manager is expected to raise when entered in the wrong state.
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Only close if still open: `commit()` inside the block has already
        # closed the changeset.  Returning None lets exceptions propagate.
        if self.__state__ == ChangesetState.OPEN:
            self.close()
def Changeset(obj, changes):
    """
    Create a changeset for ``obj`` pre-populated with ``changes``.

    A dedicated ``BaseChangeset`` subclass named ``<ClassName>Changeset`` is
    generated on every call, carrying one ``ChangesetField`` descriptor per
    entry in ``obj._meta.field_names``; an instance of that class is
    returned.
    """
    descriptors = {
        field_name: ChangesetField(field_name) for field_name in obj._meta.field_names
    }
    changeset_name = "{0}Changeset".format(obj.__class__.__name__)
    changeset_cls = type(changeset_name, (BaseChangeset,), descriptors)
    return changeset_cls(obj, changes)
class BaseSerializable(collections.abc.Sequence):
    """
    Sequence-style container for the field values of a serializable object.

    The field layout is read from ``self._meta`` (``fields``,
    ``field_names``, ``field_attrs`` and ``sedes``), which
    ``SerializableBase`` attaches to classes that declare fields.  Values are
    stored on the instance under the private names in ``field_attrs``.
    """

    # Per-instance cache slot; nothing in this class reads or writes it.
    # NOTE(review): presumably filled in by the encoding code - confirm.
    _cached_rlp = None

    # Memoized result of `__hash__`; `None` means "not computed yet".
    _hash_cache = None

    # Checked by the field property setters: assignment to a field is only
    # allowed while this is truthy.
    _in_mutable_context = False

    def __init__(self, *args, **kwargs):
        """
        Populate the fields from positional and/or keyword arguments.

        Lists (including nested lists) are converted to tuples via
        ``make_immutable`` before being stored.

        :raises TypeError: on duplicate, unknown or missing keyword
            arguments, or when the number of values does not match the
            number of fields.
        """
        if kwargs:
            field_values = merge_kwargs_to_args(args, kwargs, self._meta.field_names)
        else:
            field_values = args

        if len(field_values) != len(self._meta.field_names):
            raise TypeError(
                'Argument count mismatch. expected {0} - got {1} - missing {2}'.format(
                    len(self._meta.field_names),
                    len(field_values),
                    ','.join(self._meta.field_names[len(field_values):]),
                )
            )

        for value, attr in zip(field_values, self._meta.field_attrs):
            setattr(self, attr, make_immutable(value))

    def as_dict(self):
        """Return a new dict mapping each field name to its current value."""
        return dict(zip(self._meta.field_names, self))

    def __iter__(self):
        for attr in self._meta.field_attrs:
            yield getattr(self, attr)

    def __getitem__(self, idx):
        """
        Look up field values.

        * ``int`` - the value of the field at that position
        * ``slice`` - a tuple with the values of the selected fields
        * ``str`` - ``getattr(self, idx)``; note this resolves any attribute
          name, not only field names

        :raises IndexError: for an out-of-range position or an unsupported
            index type.
        """
        if isinstance(idx, int):
            attr = self._meta.field_attrs[idx]
            return getattr(self, attr)
        elif isinstance(idx, slice):
            field_slice = self._meta.field_attrs[idx]
            return tuple(getattr(self, field) for field in field_slice)
        elif isinstance(idx, str):
            return getattr(self, idx)
        else:
            raise IndexError("Unsupported type for __getitem__: {0}".format(type(idx)))

    def __len__(self):
        return len(self._meta.fields)

    def __eq__(self, other):
        """
        Two serializable objects are equal when their field values are equal.

        The (cached) hash is compared first because it is cheap and rejects
        almost every unequal pair.  Equal hashes alone do not prove equality
        though - on 64-bit CPython, for instance, ``hash(0)`` equals
        ``hash(2**61 - 1)`` and ``hash(-1)`` equals ``hash(-2)`` - so a
        matching hash is confirmed by comparing the actual values.
        """
        if not isinstance(other, Serializable):
            return False
        if hash(self) != hash(other):
            return False
        return tuple(self) == tuple(other)

    def __getstate__(self):
        state = self.__dict__.copy()
        # The hash() builtin is not stable across processes
        # (https://docs.python.org/3/reference/datamodel.html#object.__hash__), so we do this here
        # to ensure pickled instances don't carry the cached hash() as that may cause issues like
        # https://github.com/ethereum/py-evm/issues/1318
        state['_hash_cache'] = None
        return state

    def __hash__(self):
        # Computed lazily from the field values and memoized on the instance.
        if self._hash_cache is None:
            self._hash_cache = hash(tuple(self))

        return self._hash_cache

    def __repr__(self):
        keyword_args = tuple("{}={!r}".format(k, v) for k, v in self.as_dict().items())
        return "{}({})".format(
            type(self).__name__,
            ", ".join(keyword_args),
        )

    @classmethod
    def serialize(cls, obj):
        """
        Serialize ``obj`` with this class's field sedes.

        :raises ObjectSerializationError: wrapping the underlying
            ``ListSerializationError``.
        """
        try:
            return cls._meta.sedes.serialize(obj)
        except ListSerializationError as e:
            raise ObjectSerializationError(obj=obj, sedes=cls, list_exception=e) from e

    @classmethod
    def deserialize(cls, serial, **extra_kwargs):
        """
        Build an instance of ``cls`` from its serialized form.

        The deserialized values are passed to the constructor by field name;
        ``extra_kwargs`` are forwarded to the constructor as well.

        :raises ObjectDeserializationError: wrapping the underlying
            ``ListDeserializationError``.
        """
        try:
            values = cls._meta.sedes.deserialize(serial)
        except ListDeserializationError as e:
            raise ObjectDeserializationError(serial=serial, sedes=cls, list_exception=e) from e

        args_as_kwargs = merge_args_to_kwargs(values, {}, cls._meta.field_names)
        return cls(**args_as_kwargs, **extra_kwargs)

    def copy(self, *args, **kwargs):
        """
        Return a new instance of the same type with some fields replaced.

        Replacement values may be given positionally (in field order) or by
        field name.  Every field that is not replaced is deep-copied from
        this instance.
        """
        # Fields that are overridden neither by keyword nor by position.
        missing_overrides = set(
            self._meta.field_names
        ).difference(
            kwargs.keys()
        ).difference(
            self._meta.field_names[:len(args)]
        )
        unchanged_kwargs = {
            key: copy.deepcopy(value)
            for key, value
            in self.as_dict().items()
            if key in missing_overrides
        }
        combined_kwargs = dict(**unchanged_kwargs, **kwargs)
        all_kwargs = merge_args_to_kwargs(args, combined_kwargs, self._meta.field_names)
        return type(self)(**all_kwargs)

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, *args):
        # The memo argument supplied by the `copy` module is accepted but not
        # used; `copy()` already deep-copies every field value.
        return self.copy()

    def build_changeset(self, *args, **kwargs):
        """
        Return a changeset for this object, pre-populated with the given
        field values (positionally in field order and/or by name).  Fields
        may be left out.
        """
        args_as_kwargs = merge_args_to_kwargs(
            args,
            kwargs,
            self._meta.field_names,
            allow_missing=True,
        )
        return Changeset(self, changes=args_as_kwargs)
def make_immutable(value):
    """
    Convert a list, and any lists nested directly inside it, into tuples.

    Only ``list`` instances are converted or descended into; every other
    value, including tuples and their contents, is returned unchanged.
    """
    if not isinstance(value, list):
        return value
    return tuple(map(make_immutable, value))
@to_tuple
def _mk_field_attrs(field_names, extra_namespace):
    """
    Choose a private attribute name for every entry in ``field_names``.

    Each name is formed by prefixing the field name with underscores - one
    to start with, and one more for as long as the candidate collides with a
    field name, an entry of ``extra_namespace``, or a name already chosen by
    this call.  Names are yielded in the order of ``field_names``.
    """
    taken = set(field_names).union(extra_namespace)
    for field_name in field_names:
        candidate = '_' + field_name
        while candidate in taken:
            candidate = '_' + candidate
        taken.add(candidate)
        yield candidate
def _mk_field_property(field, attr):
def field_fn_getter(self):
return getattr(self, attr)
def field_fn_setter(self, value):
if not self._in_mutable_context:
raise AttributeError("can't set attribute")
setattr(self, attr, value)
return property(field_fn_getter, field_fn_setter)
IDENTIFIER_REGEX = re.compile(r"^[^\d\W]\w*\Z", re.UNICODE)
def _is_valid_identifier(value):
# Source: https://stackoverflow.com/questions/5474008/regular-expression-to-confirm-whether-a-string-is-a-valid-identifier-in-python # noqa: E501
if not isinstance(value, str):
return False
return bool(IDENTIFIER_REGEX.match(value))
@to_set
def _get_class_namespace(cls):
    """
    Yield the names defined directly on ``cls``: the keys of its
    ``__dict__`` followed by the entries of its ``__slots__``, for whichever
    of the two attributes exist.

    The generator is wrapped by ``to_set``, which is expected to collect the
    yielded names into a set.
    """
    for holder in ('__dict__', '__slots__'):
        if hasattr(cls, holder):
            # Iterating the class `__dict__` yields its keys.
            yield from getattr(cls, holder)
class SerializableBase(abc.ABCMeta):
    """
    Metaclass that turns a ``fields`` declaration into a ``_meta`` object
    and one property per field.

    For every class that has at least one base created by this metaclass it:

    * resolves the field list (declared in the class body, or inherited from
      the single serializable parent),
    * validates the field names (unique, valid identifiers, superset of all
      parent fields),
    * picks private attribute names for storing the values,
    * attaches a ``Meta`` type as ``_meta``, and
    * installs a property for each field name.
    """
    def __new__(cls, name, bases, attrs):
        """
        Create the class ``name``.

        :raises TypeError: if there are several serializable parents but no
            ``fields`` declaration, if field names are duplicated or are not
            valid identifiers, or if fields of a parent class are missing.
        """
        super_new = super(SerializableBase, cls).__new__

        serializable_bases = tuple(b for b in bases if isinstance(b, SerializableBase))
        has_multiple_serializable_parents = len(serializable_bases) > 1
        is_serializable_subclass = any(serializable_bases)
        declares_fields = 'fields' in attrs

        if not is_serializable_subclass:
            # If this is the original creation of the `Serializable` class,
            # just create the class.
            return super_new(cls, name, bases, attrs)
        elif not declares_fields:
            if has_multiple_serializable_parents:
                raise TypeError(
                    "Cannot create subclass from multiple parent `Serializable` "
                    "classes without explicit `fields` declaration."
                )
            else:
                # This is just a vanilla subclass of a `Serializable` parent class.
                parent_serializable = serializable_bases[0]

                if hasattr(parent_serializable, '_meta'):
                    fields = parent_serializable._meta.fields
                else:
                    # This is a subclass of `Serializable` which has no
                    # `fields`, likely intended for further subclassing.
                    fields = ()
        else:
            # ensure that the `fields` property is a tuple of tuples to ensure
            # immutability.  `fields` is removed from the class namespace; it
            # remains available as `_meta.fields`.
            fields = tuple(tuple(field) for field in attrs.pop('fields'))

        # split the fields into names and sedes
        if fields:
            field_names, sedes = zip(*fields)
        else:
            field_names, sedes = (), ()

        # check that field names are unique
        duplicate_field_names = _get_duplicates(field_names)
        if duplicate_field_names:
            raise TypeError(
                "The following fields are duplicated in the `fields` "
                "declaration: "
                "{0}".format(",".join(sorted(duplicate_field_names)))
            )

        # check that field names are valid identifiers
        invalid_field_names = {
            field_name
            for field_name
            in field_names
            if not _is_valid_identifier(field_name)
        }
        if invalid_field_names:
            raise TypeError(
                "The following field names are not valid python identifiers: {0}".format(
                    ",".join("`{0}`".format(item) for item in sorted(invalid_field_names))
                )
            )

        # extract all of the fields from parent `Serializable` classes.
        parent_field_names = {
            field_name
            for base in serializable_bases if hasattr(base, '_meta')
            for field_name in base._meta.field_names
        }

        # check that all fields from parent serializable classes are
        # represented on this class.
        missing_fields = parent_field_names.difference(field_names)
        if missing_fields:
            raise TypeError(
                "Subclasses of `Serializable` **must** contain a full superset "
                "of the fields defined in their parent classes. The following "
                "fields are missing: "
                "{0}".format(",".join(sorted(missing_fields)))
            )

        # the actual field values are stored in separate *private* attributes.
        # This computes attribute names that don't conflict with other
        # attributes already present on the class: everything in the new
        # class body plus every name defined anywhere in the MRO of each base.
        reserved_namespace = set(attrs.keys()).union(
            attr
            for base in bases
            for parent_cls in base.__mro__
            for attr in _get_class_namespace(parent_cls)
        )
        field_attrs = _mk_field_attrs(field_names, reserved_namespace)

        # construct the Meta object to store field information for the class
        meta_namespace = {
            'fields': fields,
            'field_attrs': field_attrs,
            'field_names': field_names,
            'sedes': List(sedes),
        }

        # a `_meta` defined in the class body becomes the base of the
        # generated `Meta` type; otherwise `MetaBase` is used.
        meta_base = attrs.pop('_meta', MetaBase)
        meta = type(
            'Meta',
            (meta_base,),
            meta_namespace,
        )
        attrs['_meta'] = meta

        # construct `property` attributes for read only access to the fields.
        field_props = tuple(
            (field, _mk_field_property(field, attr))
            for field, attr
            in zip(meta.field_names, meta.field_attrs)
        )

        # the class body's own attributes come last, so an attribute with the
        # same name as a field replaces the generated property.
        return super_new(
            cls,
            name,
            bases,
            dict(
                field_props +
                tuple(attrs.items())
            ),
        )
class Serializable(BaseSerializable, metaclass=SerializableBase):
    """
    The base class for serializable objects.

    Subclasses declare a ``fields`` attribute: an iterable of
    ``(name, sedes)`` pairs.  The metaclass validates the declaration,
    exposes every field as a property, and records the layout on
    ``_meta``.  Instances behave as sequences of their field values.
    """