Source code for endf_parserpy.utils.debugging_utils

############################################################
#
# Author(s):       Georg Schnabel
# Email:           g.schnabel@iaea.org
# Creation date:   2022/09/09
# Last modified:   2025/06/02
# License:         MIT
# Copyright (c) 2022-2024 International Atomic Energy Agency (IAEA)
#
############################################################

from collections.abc import MutableMapping, MutableSequence
from endf_parserpy.utils.math_utils import math_allclose
from endf_parserpy.utils.accessories import EndfDict
from .math_utils import EndfFloat


def smart_is_equal(x, y, atol=1e-8, rtol=1e-6):
    """Compare two scalar-like values, using tolerances for floats.

    Parameters
    ----------
    x : object
        First value.
    y : object
        Second value.
    atol : float
        Absolute tolerance forwarded to ``math_allclose`` for floats.
    rtol : float
        Relative tolerance forwarded to ``math_allclose`` for floats.

    Returns
    -------
    bool
        ``False`` if the exact types of ``x`` and ``y`` differ. For two
        floats the result of ``math_allclose``; for everything else the
        result of ``x == y``.
    """
    # Values of different exact type are never considered equal,
    # e.g. the int 1 and the float 1.0.
    if type(x) != type(y):
        return False
    if isinstance(x, float):
        return math_allclose(x, y, atol=atol, rtol=rtol)
    # Integers and all remaining types rely on plain equality.
    return x == y


def compare_objects(
    obj1,
    obj2,
    atol=1e-8,
    rtol=1e-6,
    strlen_only=False,
    do_rstrip=False,
    rstrcut=None,
    fail_on_diff=True,
    diff_log=None,
):
    """Recursively compare two objects.

    The objects may be, or may contain, :class:`dict` objects and
    iterable array-like objects. A typical use is to confirm or reject
    the equality of two nested dictionaries obtained by parsing ENDF-6
    files with the :func:`~endf_parserpy.EndfParserPy.parsefile` method
    of the :class:`~endf_parserpy.EndfParserPy` class. The location of
    each discrepancy inside the nested structure is reported.

    Parameters
    ----------
    obj1 : object
        Any kind of object, usually a nested :class:`dict`-like
        structure. An ``EndfDict`` is unwrapped before the comparison.
    obj2 : object
        Any kind of object, usually a nested :class:`dict`-like
        structure. An ``EndfDict`` is unwrapped before the comparison.
    atol : float
        Absolute tolerance for the comparison of two :class:`float`
        values.
    rtol : float
        Relative tolerance for the comparison of two :class:`float`
        values.
    strlen_only : bool
        If ``True``, only the lengths of strings are compared,
        otherwise their content is compared as well.
    do_rstrip : bool
        If ``True``, trailing whitespace is stripped from strings
        before they are compared.
    rstrcut : Union[None, int]
        If an integer is given, only the first ``rstrcut`` characters
        of strings are retained for the comparison. If ``None``,
        strings are compared as they are.
    fail_on_diff : bool
        If ``True``, an exception is raised at the first difference
        encountered. Otherwise the objects are compared completely,
        each difference is printed, and the function returns ``True``
        if the objects are equal and ``False`` if they differ. The
        latter mode is mostly useful together with a list passed as
        ``diff_log``.
    diff_log : Union[None, List]
        A :class:`list` that is filled with strings describing the
        differences found. Only useful with ``fail_on_diff=False``.

    Returns
    -------
    bool
        ``True`` if no difference was found, ``False`` otherwise.
    """
    # The recursive worker operates on the underlying plain objects.
    left = obj1.unwrap() if isinstance(obj1, EndfDict) else obj1
    right = obj2.unwrap() if isinstance(obj2, EndfDict) else obj2
    options = dict(
        atol=atol,
        rtol=rtol,
        strlen_only=strlen_only,
        do_rstrip=do_rstrip,
        rstrcut=rstrcut,
        fail_on_diff=fail_on_diff,
        diff_log=diff_log,
    )
    return _compare_objects(left, right, curpath="", **options)
def _compare_objects(
    obj1,
    obj2,
    curpath="",
    atol=1e-8,
    rtol=1e-6,
    strlen_only=False,
    do_rstrip=False,
    rstrcut=None,
    fail_on_diff=True,
    diff_log=None,
):
    """Recursive worker that compares two objects element by element.

    Parameters
    ----------
    obj1, obj2 : object
        The objects to compare. ``EndfFloat`` instances are converted
        to :class:`float` before the comparison.
    curpath : str
        Path of the current position inside the nested structure; it is
        only used to build the difference messages.
    atol, rtol : float
        Absolute and relative tolerance for :class:`float` comparisons.
    strlen_only : bool
        If ``True``, strings are compared by length only.
    do_rstrip : bool
        If ``True``, trailing whitespace of strings is stripped first.
    rstrcut : Union[None, int]
        If not ``None``, strings are cut to their first ``rstrcut``
        characters (after the optional stripping) before comparison.
    fail_on_diff : bool
        If ``True``, raise at the first difference. Otherwise record
        the message in ``diff_log``, print it and continue.
    diff_log : Union[None, List]
        List collecting the difference messages. A fresh list is
        created if ``None`` is passed.

    Returns
    -------
    bool
        ``True`` if no difference was found at or below this position.

    Raises
    ------
    TypeError
        On a type mismatch, if ``fail_on_diff`` is ``True``.
    IndexError
        If a key exists in only one dict and ``fail_on_diff`` is ``True``.
    ValueError
        On differing strings, lengths or values, if ``fail_on_diff``
        is ``True``.
    """
    if diff_log is None:
        diff_log = []
    found_diff = False

    def treat_diff(msg, exc):
        # Either abort on the first difference or record it and go on.
        nonlocal found_diff
        found_diff = True
        if fail_on_diff:
            raise exc(msg)
        else:
            diff_log.append(msg)
            print(msg)

    if isinstance(obj1, EndfFloat):
        obj1 = float(obj1)
    if isinstance(obj2, EndfFloat):
        obj2 = float(obj2)

    # Options handed down unchanged to every recursive call.
    options = dict(
        atol=atol,
        rtol=rtol,
        strlen_only=strlen_only,
        do_rstrip=do_rstrip,
        rstrcut=rstrcut,
        fail_on_diff=fail_on_diff,
        diff_log=diff_log,
    )

    if type(obj1) != type(obj2):
        treat_diff(
            f"at path {curpath}: "
            + f"type mismatch found, obj1: {obj1}, obj2: {obj2}",
            TypeError,
        )
    elif isinstance(obj1, dict):
        only_in_obj1 = set(obj1).difference(obj2)
        if len(only_in_obj1) > 0:
            treat_diff(
                f"at path {curpath}: only obj1 contains {only_in_obj1}", IndexError
            )
        only_in_obj2 = set(obj2).difference(obj1)
        if len(only_in_obj2) > 0:
            treat_diff(
                f"at path {curpath}: only obj2 contains {only_in_obj2}", IndexError
            )
        # Visit the shared keys in a deterministic order: the sorted
        # non-integer keys first, then the sorted integer keys.
        common_keys = set(obj1).intersection(set(obj2))
        common_int_keys = [k for k in common_keys if isinstance(k, int)]
        common_nonint_keys = [k for k in common_keys if not isinstance(k, int)]
        common_int_keys.sort()
        common_nonint_keys.sort()
        common_keys = common_nonint_keys + common_int_keys
        for key in common_keys:
            ret = _compare_objects(
                obj1[key], obj2[key], "/".join((curpath, str(key))), **options
            )
            found_diff = found_diff or not ret
    else:
        if isinstance(obj1, str):
            if do_rstrip:
                obj1 = obj1.rstrip()
                obj2 = obj2.rstrip()
            if rstrcut is not None:
                obj1 = obj1[:rstrcut]
                obj2 = obj2[:rstrcut]
            if strlen_only:
                if len(obj1) != len(obj2):
                    treat_diff(
                        f"at path {curpath}: string lengths differ "
                        f"({obj1} != {obj2})",
                        ValueError,
                    )
            elif obj1 != obj2:
                treat_diff(
                    f"at path {curpath}: strings differ " f"({obj1} != {obj2})",
                    ValueError,
                )
        elif hasattr(obj1, "__iter__"):
            # Materialise both iterables exactly once. Taking the length
            # via a throw-away tuple and iterating the original object
            # afterwards would find one-shot iterators (e.g. generators)
            # already exhausted, so their elements would never be compared.
            elements1 = tuple(obj1)
            elements2 = tuple(obj2)
            len_obj1 = len(elements1)
            len_obj2 = len(elements2)
            if len_obj1 != len_obj2:
                treat_diff(
                    f"Length mismatch at {curpath} " f"({len_obj1} vs {len_obj2})",
                    ValueError,
                )
            # On a length mismatch only the common leading part is compared.
            for i, (subel1, subel2) in enumerate(zip(elements1, elements2)):
                ret = _compare_objects(
                    subel1, subel2, f"{curpath}[{str(i)}]", **options
                )
                found_diff = found_diff or not ret
        else:
            if not smart_is_equal(obj1, obj2, atol=atol, rtol=rtol):
                treat_diff(
                    f"Value mismatch at {curpath} " f"({obj1} vs {obj2})", ValueError
                )
    # return True if equivalent
    return not found_diff
class TrackingDict(MutableMapping):
    """Class for tracking read access of elements in :class:`dict`-like objects.

    This class implements an interface to :class:`dict`-like objects
    for the purpose of tracking keys whose associated elements were
    retrieved. This tracking is applied recursively, hence also elements
    of :class:`dict`-like and :class:`list`-like objects stored within the
    root :class:`dict`-like object are potentially tracked.

    Not all keys are tracked, though. The following rules apply:

    - Read access to a key is only recorded if the key is an integer,
      i.e. of type :class:`int`.
    - A :class:`dict`-like or :class:`list`-like object stored under a
      key is wrapped in a tracking object of its own when it is
      retrieved, unless the key starts with two underscores (``__``).
      Elements within objects stored under such keys are never tracked.

    These rules are owed to the mode of operation of the
    :class:`~endf_parserpy.EndfParserPy` class. The methods
    :func:`~endf_parserpy.EndfParserPy.parsefile` and
    :func:`~endf_parserpy.EndfParserPy.writefile` of the
    :class:`~endf_parserpy.EndfParserPy` class will temporarily
    create auxiliary variables stored under keys starting with two
    underscores. It is not pertinent to track read access to those
    ephemeral objects.

    The purpose of the
    :class:`~endf_parserpy.utils.debugging_utils.TrackingDict` class---
    when it comes to writing ENDF-6 formatted data---is to ensure that
    all elements in arrays (emulated with :class:`dict`-like objects
    containing only integer keys) are accessed. Otherwise, it means that
    some elements have not been written to the ENDF-6 file and this
    situation indicates an inconsistency between counter variables and
    the index ranges of arrays.
    """

    def __init__(self, dict_like):
        """Initialize a ``TrackingDict`` object.

        Parameters
        ----------
        dict_like : dict
            The :class:`dict`-like object for which read access
            should be tracked.
        """
        self._basedict = dict_like
        # key -> tracking wrapper created for the nested container
        self._trackingobjs = {}
        # integer keys that have been read via __getitem__
        self._accessed = set()

    def _should_track(self, key, obj):
        # Nested containers are wrapped unless stored under a "__" key.
        return not str(key).startswith("__") and isinstance(
            obj, (MutableMapping, MutableSequence)
        )

    def _is_tracked(self, key):
        return key in self._trackingobjs

    def _create_trackobj(self, obj):
        if isinstance(obj, MutableMapping):
            return TrackingDict(obj)
        elif isinstance(obj, MutableSequence):
            return TrackingList(obj)
        else:
            raise TypeError(f"This object of type {type(obj)} cannot be tracked.")

    def __getitem__(self, key):
        retval = self._basedict.__getitem__(key)
        if isinstance(key, int):
            self._accessed.add(key)
        if self._should_track(key, retval) and not self._is_tracked(key):
            self._trackingobjs[key] = self._create_trackobj(retval)
        # Hand out the wrapper (if any) so nested reads are tracked too.
        retval = self._trackingobjs.get(key, retval)
        return retval

    def __setitem__(self, key, value):
        # A newly assigned value has not been read yet: forget both the
        # access mark and any wrapper created for the previous value.
        self._accessed.discard(key)
        self._trackingobjs.pop(key, None)
        return self._basedict.__setitem__(key, value)

    def __delitem__(self, key):
        # ``_accessed`` is a set, which offers ``discard`` but has no
        # ``__delitem__`` method.
        self._accessed.discard(key)
        self._trackingobjs.pop(key, None)
        return self._basedict.__delitem__(key)

    def __iter__(self):
        return self._basedict.__iter__()

    def __len__(self):
        return self._basedict.__len__()

    def _verify_complete_retrieval(self, path=""):
        # Completeness is only enforced once at least one integer key
        # has been read.
        if self._accessed:
            for k in self._basedict:
                if isinstance(k, int) and k not in self._accessed:
                    indexpath = path + "/" + str(k)
                    raise IndexError(f"The content of {indexpath} was not accessed")
        # Descend into the nested containers that were handed out.
        for k in self._basedict:
            if self._is_tracked(k):
                indexpath = path + "/" + str(k)
                curval = self._trackingobjs[k]
                curval._verify_complete_retrieval(indexpath)

    def verify_complete_retrieval(self):
        """Verify that all array elements have been accessed.

        This function will raise an ``IndexError`` exception
        if there are :class:`dict`-like objects where at
        least one key of type :class:`int` has been accessed
        but more keys exist that have not been accessed.
        """
        self._verify_complete_retrieval()

    def unwrap(self):
        """Return the underlying :class:`dict`-like object."""
        return self._basedict
class TrackingList(MutableSequence):
    """Class for tracking read access of elements in :class:`list`-like objects.

    Every element retrieved by index is marked as accessed. Elements
    that are themselves :class:`list`-like or :class:`dict`-like are
    handed out wrapped in a ``TrackingList`` or ``TrackingDict``, so
    that read access is tracked recursively. Assigning or inserting an
    element resets the tracking state of that position.

    Access is recorded by ``__getitem__`` only; iterating over the
    object yields the elements of the underlying list directly.
    """

    def __init__(self, list_like):
        """Initialize a ``TrackingList`` object.

        Parameters
        ----------
        list_like : list
            The :class:`list`-like object for which read access
            should be tracked.
        """
        self._baselist = list_like
        # Per position: the wrapper handed out for a nested container,
        # or None if no wrapper has been created yet.
        self._trackingobjs = [None] * len(list_like)
        # Per position: whether the element has been read.
        self._accessed = [False] * len(list_like)

    def __getitem__(self, key):
        retval = self._baselist.__getitem__(key)
        self._accessed[key] = True
        # Create the wrapper for a nested container only once and reuse
        # it afterwards; otherwise its tracking state would be lost on
        # every access.
        if isinstance(retval, MutableSequence):
            if self._trackingobjs[key] is None:
                self._trackingobjs[key] = TrackingList(retval)
            retval = self._trackingobjs[key]
        elif isinstance(retval, MutableMapping):
            if self._trackingobjs[key] is None:
                self._trackingobjs[key] = TrackingDict(retval)
            retval = self._trackingobjs[key]
        return retval

    def __setitem__(self, key, value):
        retval = self._baselist.__setitem__(key, value)
        # The new value has not been read yet.
        self._accessed[key] = False
        self._trackingobjs[key] = None
        return retval

    def __delitem__(self, key):
        retval = self._baselist.__delitem__(key)
        # Keep the bookkeeping lists aligned with the base list.
        self._accessed.__delitem__(key)
        self._trackingobjs.__delitem__(key)
        return retval

    def __iter__(self):
        return self._baselist.__iter__()

    def __len__(self):
        return self._baselist.__len__()

    def insert(self, key, value):
        """Insert ``value`` before position ``key`` as a not yet accessed element."""
        retval = self._baselist.insert(key, value)
        self._accessed.insert(key, False)
        self._trackingobjs.insert(key, None)
        return retval

    def _verify_complete_retrieval(self, path=""):
        # Completeness is only enforced once at least one element
        # has been read.
        if any(self._accessed):
            for k, was_accessed in enumerate(self._accessed):
                if not was_accessed:
                    indexpath = path + "/" + str(k)
                    raise IndexError(f"The content of {indexpath} was not accessed")
        # Descend into the wrappers handed out for nested containers;
        # the base list itself only holds the unwrapped objects.
        for k, curval in enumerate(self._trackingobjs):
            if curval is not None:
                indexpath = path + "/" + str(k)
                curval._verify_complete_retrieval(indexpath)

    def verify_complete_retrieval(self):
        """Verify that all array elements have been accessed.

        This function will raise an ``IndexError`` exception
        if not all elements of :class:`list`-like objects have
        been accessed.
        """
        self._verify_complete_retrieval()

    def unwrap(self):
        """Return the underlying :class:`list`-like object."""
        return self._baselist