Source code for mosaic_orchestrator.cachable_task

"""this module contains the cachable task implementation"""
import hashlib
import inspect
import pickle
from abc import ABC
from dataclasses import dataclass
from typing import Dict, Optional

from mosaic_orchestrator.cache import Cache
from mosaic_orchestrator.result_container import ResultContainer
from mosaic_orchestrator.task import (
    Task,
    Input,
    Output,
    Hashable,
    is_calculated_or_mutable_or_const,
    _get_cache_dir,
)
from mosaic_orchestrator.tree import Path
from mosaic_orchestrator.tool import Tool
from mosaic_orchestrator.pdk import PdkItem
from mosaic_orchestrator.work_products import WorkProduct
from mosaic_orchestrator.asset import Asset


@dataclass
class CacheResult:
    """Cache entry of a task: its own result plus the cache keys of its subtasks.

    Attributes:
        result: the cached result of the task; this is the object handed to
            `CachableTask.check_is_cache_valid` when the entry is evaluated.
        sub_results: maps the path of each cachable subtask, relative to the
            owning task, to the cache key under which that subtask's own
            entry is stored.
    """

    # the cached result object of the task itself
    result: object
    # relative subtask path -> cache key of the subtask's entry
    sub_results: Dict[Path, str]
class CachableTask(Task, ABC):
    """Task whose results are cached persistently and reused when all inputs match.

    The return value of the run() method is considered to be the result of a
    task. The complete state of a task, including all member variables, is
    considered when looking up previous results. In case of a cache hit the
    run() method is not executed and the cached result is returned instead.

    To implement custom invalidation logic (e.g. when the result of a cache
    represents a file that is susceptible to change between executions) the
    check_is_cache_valid() method can be overridden.

    When fine-grained control of the state representation of a class is
    needed, the `Hashable` protocol can be implemented. It is used whenever it
    is present in a member variable of a task.

    To ignore an input for caching, set its `transient` flag to `True`.
    """

    def __init__(self, inputs: Optional[Dict[str, object]] = None):
        """Constructor method

        Args:
            inputs (:obj:`dict`, optional): a dictionary of inputs that are set
                from a parent task.
        """
        super().__init__(inputs)
        self.is_invalid = False  # by default the cache is considered valid
        self.task_result = None
        self.task_hash = None
        self.cache_used = False

    def check_is_cache_valid(self, result: ResultContainer) -> bool:
        """Decide whether a cached result may be reused.

        This method is called when evaluating cached results. Override it to
        implement custom result invalidation.

        Args:
            result: the cached result object.

        Returns:
            True if the result is still valid; in this case it is reused and
            the task is not executed again. When False is returned the cached
            result is not reused and the task is executed again.

            The default implementation only inspects `WorkProduct` results: it
            recomputes the work product's hash via ``calc_hash(self)`` and
            compares it with the stored ``_hash``. If recomputing the hash
            raises, a warning is logged and False is returned. Any other
            result is considered valid.
        """
        if isinstance(result, WorkProduct):
            try:
                return result._hash == result.calc_hash(self)  # pylint: disable=protected-access
            except Exception as e:  # pylint: disable=broad-except
                self.log.warn(
                    f"Cache is invalid because an exception occured during cache validation: {e}")
                return False
        return True

    def _calc_hash_of_input(self, name: str, input: object) -> object:  # pylint: disable=redefined-builtin
        """Return the hashable representation of one task member, or None to skip it.

        Args:
            name: attribute name of the member (or a synthetic name such as
                "PDK" or "_relevant_inputs[3]").
            input: the member value.

        Returns:
            None when the member must not influence the cache key; otherwise a
            picklable value (a ``(label, value)`` tuple) that becomes part of
            the task hash.

        Raises:
            AttributeError: if a `Hashable` member fails to compute its
                task_hash, or if the value to hash cannot be pickled.
        """
        if isinstance(input, Tool):
            return None  # no need to cache the tools
        if isinstance(input, PdkItem):
            # The original check looked only at the *instance* __dict__, where a
            # method defined on the class never appears (and which does not exist
            # for slotted objects). Look for a class-level override instead,
            # ignoring the default object.__getstate__ that Python 3.11+ provides
            # for every object; keep the instance-level check for compatibility.
            class_getstate = getattr(type(input), "__getstate__", None)
            overrides_getstate = class_getstate is not None and class_getstate is not getattr(
                object, "__getstate__", None
            )
            if overrides_getstate or "__getstate__" in getattr(input, "__dict__", {}):
                return input.name, input.__getstate__()
            # only use name and version as fallback if __getstate__ is not implemented
            return input.name, input.version
        if isinstance(input, Output):
            return None  # outputs are not used for hashing
        if isinstance(input, Input) and input.transient:
            return None  # no need to cache transient inputs
        # Bookkeeping members of the task itself never influence the cache key.
        if name in {
            "_all_inputs",
            "_special_inputs",
            "_normal_inputs",
            "run",
            "_old_run",
            "_old_root_run",
            "_path",
            "cache_used",
            "_mosaic_logger",
            "_mosaic",
            "_calculated_default_values",
            "task_result",
            "task_hash",
            "is_invalid",
            "run_id",
            "_calculated_inputs",
            "measure_execution_time",
            "track_input",
            "_runtime_evaluation_functions",
        }:
            return None
        if isinstance(input, Task) or is_calculated_or_mutable_or_const(input):
            return None
        # Before running a task we need to update the hashes of all (input) workproducts
        if hasattr(input, "value") and isinstance(input.value, WorkProduct):
            wp_hash = input.value.calc_hash(self)
            input.value._hash = wp_hash  # pylint: disable=protected-access
            return f"work_product: {name}", wp_hash
        if isinstance(input, Hashable):
            try:
                input_hash = input.task_hash()
            except Exception as error:
                raise AttributeError(
                    f"Failed to calculated task_hash of '{name}' in task '{self.path}' - {error}"
                ) from error
            _try_pickle(input_hash, self.path, name)
            return f"hashable: {name}", input_hash
        # default: hash the raw value
        _try_pickle(input, self.path, name)
        return f"field: {name}", input

    def calc_hash(self, inputs: Optional, *args, **kwargs) -> str:
        """Calculate the cache key of this task from its state and runtime arguments.

        The key is built from, in this order: the PDK (``self._mosaic.pdk``),
        every instance member accepted by ``_calc_hash_of_input`` (elements of
        ``_relevant_inputs`` individually), every `Asset` found in the task
        class namespace, ``inputs``, ``kwargs`` and ``args`` (each only when
        non-empty), and, unless the class is defined in ``__main__``, the
        source code of the task class.

        Overriding this method is discouraged as it can have unexpected
        results when not done properly.

        Args:
            inputs: optional extra inputs included in the key.
            *args: custom positional arguments included in the key.
            **kwargs: custom keyword arguments included in the key.

        Returns:
            ``"<ClassName>_<md5 hex digest>"``. md5 is used as a cache-key
            fingerprint here, not for security.

        Raises:
            AttributeError: if any hashed value cannot be pickled, or if the
                source code of the task class cannot be retrieved.
        """
        # Always include the pdk hash
        obj = [self._calc_hash_of_input("PDK", self._mosaic.pdk)]
        for name, value in vars(self).items():
            if name == "_relevant_inputs":
                for i, item in enumerate(value):
                    entry = self._calc_hash_of_input(f"{name}[{i}]", item)
                    if entry is not None:
                        obj.append(entry)
            else:
                entry = self._calc_hash_of_input(name, value)
                if entry is not None:
                    obj.append(entry)

        # Assets are not instance members! Make sure to include them in hash
        # calculation.
        # NOTE(review): vars(type(self)) only contains attributes defined
        # directly on the concrete class; Assets inherited from a base class
        # are not included - confirm this is intended.
        for name, value in vars(type(self)).items():
            if isinstance(value, Asset):
                entry = self._calc_hash_of_input(name, value)
                if entry is not None:
                    obj.append(entry)

        if inputs:
            _try_pickle(
                inputs,
                message=f"Caching failed: could not pickle inputs of task {self.path}",
            )
            obj.append(inputs)
        if kwargs:
            _try_pickle(
                kwargs,
                message=f"Caching failed: could not pickle custom arguments {kwargs} of task {self.path}",
            )
            obj.append(kwargs)
        if args:
            _try_pickle(
                args,
                message=f"Caching failed: could not pickle custom arguments {args} of task {self.path}",
            )
            obj.append(args)

        # The class source is part of the key so code changes invalidate the cache.
        # A failure here is unrelated to pickling, so report it accurately.
        source = ""
        if type(self).__module__ != "__main__":
            try:
                source = inspect.getsource(type(self))
            except Exception as error:
                raise AttributeError(
                    f"Caching failed: could not read the source code of task class "
                    f"'{type(self).__name__}' - {error}"
                ) from error

        try:
            object_string = f"{pickle.dumps(obj)}_{source}"
            calculated_hash = hashlib.md5(object_string.encode())
        except Exception as error:
            raise AttributeError(
                f"{error} - Task inputs and results must be serializable using pickle package"
            ) from error
        return f"{self.__class__.__name__}_{calculated_hash.hexdigest()}"

    def _is_task_invalid(self, result: CacheResult, cache: Cache) -> bool:
        """Check if the cache of this task is invalid.

        It is invalid if the task itself is flagged invalid, if any cachable
        subtask is invalid or missing from the cache, or if
        check_is_cache_valid() rejects the cached result.

        Args:
            result (CacheResult): Result from the cache
            cache (Cache): Cache

        Returns:
            bool: True if it is invalid (``self.is_invalid`` is then set too)
        """
        if self.is_invalid:
            return True
        task_invalid = False
        subtasks = self.get_fields_of_type(CachableTask)
        for subtask in subtasks:
            if subtask.is_invalid:
                task_invalid = True
                break
            relative_path = subtask.path.remove(self.path)
            if relative_path not in result.sub_results:
                task_invalid = True
                break
            subtask_hash = result.sub_results[relative_path]
            subtask_result = cache.get(subtask_hash)
            if not subtask_result:
                task_invalid = True
                break
            if subtask._is_task_invalid(subtask_result, cache):  # pylint: disable=protected-access
                # No break here: the remaining subtasks are still evaluated, so
                # their is_invalid flags may be set as a side effect.
                task_invalid = True
        if task_invalid:
            self.is_invalid = True
            return True
        if not self.check_is_cache_valid(result.result):
            self.is_invalid = True
            return True
        return False

    def _cache(self, value):
        """Store ``value`` in the task cache under ``self.task_hash``.

        Raises:
            AttributeError: if storing the value fails (e.g. it cannot be
                serialized).
        """
        with Cache(_get_cache_dir(self)) as cache:
            try:
                cache.put(self.task_hash, value)
            except Exception as error:
                raise AttributeError(
                    f"{error} - Task inputs and results must be serializable using pickle package"
                ) from error
def _try_pickle(value, path=None, field=None, message=None): """try pickle given value, raises a AttributeError if it not pickleable. optional path, field and message can be used to specify the error further""" try: pickle.dumps(value) except Exception as exc: if message: raise AttributeError( f"Caching failed: could not pickle field '{field}' of task '{path}'" ) from exc raise AttributeError( f"Caching failed: could not pickle field '{field}' of task '{path}'" ) from exc