"""This module contains the cachable task implementation."""
import hashlib
import inspect
import pickle
from abc import ABC
from dataclasses import dataclass
from typing import Dict, Optional
from mosaic_orchestrator.cache import Cache
from mosaic_orchestrator.result_container import ResultContainer
from mosaic_orchestrator.task import (
Task,
Input,
Output,
Hashable,
is_calculated_or_mutable_or_const,
_get_cache_dir,
)
from mosaic_orchestrator.tree import Path
from mosaic_orchestrator.tool import Tool
from mosaic_orchestrator.pdk import PdkItem
from mosaic_orchestrator.work_products import WorkProduct
from mosaic_orchestrator.asset import Asset
@dataclass
class CacheResult:
    """Cached entry of a single task, consulted when evaluating the task cache.

    Attributes:
        result: the cached result object of the task.
        sub_results: cache keys (task hashes) of the sub-task results, keyed by the
            sub-task path with this task's path removed (presumably a relative path).
    """

    result: object
    sub_results: Dict[Path, str]
# Attribute names that never contribute to a task's cache hash: internal framework
# bookkeeping and the task's own cache state (see CachableTask.__init__).
# A frozenset gives O(1) membership checks for every attribute hashed.
_UNHASHED_ATTRIBUTE_NAMES = frozenset(
    {
        "_all_inputs",
        "_special_inputs",
        "_normal_inputs",
        "run",
        "_old_run",
        "_old_root_run",
        "_path",
        "cache_used",
        "_mosaic_logger",
        "_mosaic",
        "_calculated_default_values",
        "task_result",
        "task_hash",
        "is_invalid",
        "run_id",
        "_calculated_inputs",
        "measure_execution_time",
        "track_input",
        "_runtime_evaluation_functions",
    }
)


class CachableTask(Task, ABC):
    """Results of `CachableTasks` are cached persistently and reused when all inputs match.

    The return value of the run() method is considered to be the result of a task. The complete
    state of a task, including all member variables, is considered when looking up previous
    results. In case of a cache-hit the run() method is not executed and the cached result is
    returned instead. To implement custom invalidation logic (e.g. when a cached result
    represents a file that is susceptible to change between executions) the
    check_is_cache_valid() method can be overridden. When fine-grained control of the state
    representation of a class is needed the `Hashable` protocol can be implemented. It is used
    whenever it is present in a member variable of a task. To ignore an input for caching set
    its `transient` flag to `True`.
    """

    def __init__(self, inputs: Optional[Dict[str, object]] = None):
        """Constructor method

        Args:
            inputs (:obj:`dict`, optional): a dictionary of inputs that are set from a
                parent task.
        """
        super().__init__(inputs)
        self.is_invalid = False  # by default the cache is considered valid
        self.task_result = None
        self.task_hash = None
        self.cache_used = False

    def check_is_cache_valid(self, result: ResultContainer) -> bool:
        """Decide whether a cached result may be reused.

        This method is called when evaluating cached results. Override it to implement
        custom result invalidation.

        Args:
            result: the cached result object.

        Returns:
            True if the result is still valid; it is then reused and the task is not
            executed again. When False is returned the cached result is not reused and the
            task is executed again. The default implementation re-computes the hash of a
            :class:`WorkProduct` result and compares it with the stored hash; any other
            result is considered valid.
        """
        if isinstance(result, WorkProduct):
            try:
                return result._hash == result.calc_hash(self)  # pylint: disable=protected-access
            except Exception as e:  # pylint: disable=broad-except
                # Validation is best-effort: any failure is treated as an invalid cache.
                # NOTE(review): ``warn`` is deprecated on ``logging.Logger``; switch to
                # ``warning`` if ``self.log`` is a standard logger — confirm its type.
                self.log.warn(
                    f"Cache is invalid because an exception occured during cache validation: {e}")
                return False
        return True

    def _calc_hash_of_input(self, name: str, input: object) -> object:  # pylint: disable=redefined-builtin
        """Return the hash contribution of one task attribute, or None to skip it.

        Args:
            name: attribute name; used for filtering and to label the contribution.
            input: the attribute value.

        Returns:
            None if the attribute must not influence the hash. Otherwise a value that is
            appended to the list hashed by :meth:`calc_hash`, usually a ``(label, value)``
            tuple.

        Raises:
            AttributeError: if a :class:`Hashable` fails to compute its ``task_hash`` or a
                contribution cannot be pickled.
        """
        if isinstance(input, Tool):
            return None  # no need to cache the tools
        if isinstance(input, PdkItem):
            # NOTE(review): this checks the *instance* ``__dict__``; a ``__getstate__``
            # defined as a regular method on the class is not stored there, so the
            # name/version fallback is used for such items. Changing the lookup would
            # change existing cache keys — confirm intent before altering.
            if "__getstate__" in input.__dict__:
                return input.name, input.__getstate__()
            # only use name and version as fallback if __getstate__ is not implemented
            return input.name, input.version
        if isinstance(input, Output):
            return None  # outputs are not used for hashing
        if isinstance(input, Input) and input.transient:
            return None  # transient inputs are explicitly excluded from caching
        if name in _UNHASHED_ATTRIBUTE_NAMES:
            return None
        if isinstance(input, Task) or is_calculated_or_mutable_or_const(input):
            return None
        # Before running a task we need to update the hashes of all (input) workproducts
        if hasattr(input, "value") and isinstance(input.value, WorkProduct):
            wp_hash = input.value.calc_hash(self)
            input.value._hash = wp_hash  # pylint: disable=protected-access
            return f"work_product: {name}", wp_hash
        if isinstance(input, Hashable):
            try:
                input_hash = input.task_hash()
            except Exception as error:
                raise AttributeError(
                    f"Failed to calculated task_hash of '{name}' in task '{self.path}' - {error}"
                ) from error
            _try_pickle(input_hash, self.path, name)
            return f"hashable: {name}", input_hash
        # default: the raw attribute value itself takes part in the hash
        _try_pickle(input, self.path, name)
        return f"field: {name}", input

    def calc_hash(self, inputs: Optional[object], *args, **kwargs) -> str:
        """Calculates a hash for this task based on the runtime arguments.

        Overriding this method is discouraged as it can have unexpected results when not done
        properly.

        Args:
            inputs: optional runtime inputs; included in the hash when truthy.
            *args: extra positional arguments; included in the hash when present.
            **kwargs: extra keyword arguments; included in the hash when present.

        Returns:
            A string ``"<ClassName>_<md5 hex digest>"``.

        Raises:
            AttributeError: if any hash contribution cannot be pickled or the class
                source cannot be retrieved.
        """
        # Always include the pdk hash (appended even when it is None).
        obj = [self._calc_hash_of_input("PDK", self._mosaic.pdk)]

        def _append(name, value):
            # Add the contribution of one attribute unless it is excluded (None).
            entry = self._calc_hash_of_input(name, value)
            if entry is not None:
                obj.append(entry)

        for name, value in vars(self).items():
            if name == "_relevant_inputs":
                for i, val in enumerate(value):
                    _append(f"{name}[{i}]", val)
            else:
                _append(name, value)
        # Assets are not instance members! Make sure to include them in hash
        # calculation
        for name, value in vars(type(self)).items():
            if isinstance(value, Asset):
                _append(name, value)
        # Pass path/field too, so the error stays informative even with a
        # _try_pickle that ignores ``message``.
        if inputs:
            _try_pickle(
                inputs,
                self.path,
                "inputs",
                message=f"Caching failed: could not pickle inputs of task {self.path}",
            )
            obj.append(inputs)
        if kwargs:
            _try_pickle(
                kwargs,
                self.path,
                "kwargs",
                message=f"Caching failed: could not pickle custom arguments {kwargs} of task {self.path}",
            )
            obj.append(kwargs)
        if args:
            _try_pickle(
                args,
                self.path,
                "args",
                message=f"Caching failed: could not pickle custom arguments {args} of task {self.path}",
            )
            obj.append(args)
        try:
            # The class source is part of the hash, so a change to the task's code
            # yields a different key. Classes defined in __main__ are skipped
            # (presumably because their source may not be retrievable).
            source = ""
            if type(self).__module__ != "__main__":
                source = inspect.getsource(type(self))
            # The repr of the pickled bytes is hashed; keep this format unchanged,
            # otherwise every existing cache key changes.
            object_string = f"{pickle.dumps(obj)}_{source}"
            calculated_hash = hashlib.md5(object_string.encode())
            return f"{self.__class__.__name__}_{calculated_hash.hexdigest()}"
        except Exception as error:
            raise AttributeError(
                f"{error} - Task inputs and results must be serializable using pickle package"
            ) from error

    def _is_task_invalid(self, result: CacheResult, cache: Cache) -> bool:
        """Check if the cache of this task is invalid.

        It is invalid if this task is already marked invalid, if any CachableTask field
        (sub-task) is invalid or has no usable cached entry, or if
        :meth:`check_is_cache_valid` rejects the cached result. On invalidity
        ``self.is_invalid`` is set to True.

        Args:
            result (CacheResult): Result from the cache
            cache (Cache): Cache

        Returns:
            bool: True if it is invalid
        """
        if self.is_invalid:
            return True
        task_invalid = False
        for subtask in self.get_fields_of_type(CachableTask):
            if subtask.is_invalid:
                task_invalid = True
                break
            relative_path = subtask.path.remove(self.path)
            if relative_path not in result.sub_results:
                # No cached entry is recorded for this sub-task.
                task_invalid = True
                break
            subtask_result = cache.get(result.sub_results[relative_path])
            if not subtask_result:
                # The sub-task's cache key does not resolve to a (truthy) cached result.
                task_invalid = True
                break
            # NOTE(review): unlike the branches above, this does not ``break``, so the
            # remaining sub-tasks are still evaluated (and may get marked invalid).
            # Kept as-is in case callers rely on those flags — confirm intent.
            if subtask._is_task_invalid(subtask_result, cache):  # pylint: disable=protected-access
                task_invalid = True
        # Short-circuit: the custom check only runs when all sub-tasks are valid.
        if task_invalid or not self.check_is_cache_valid(result.result):
            self.is_invalid = True
            return True
        return False

    def _cache(self, value):
        """Store ``value`` in the persistent cache under ``self.task_hash``.

        Raises:
            AttributeError: if ``cache.put`` fails (the message assumes a pickling error).
        """
        with Cache(_get_cache_dir(self)) as cache:
            try:
                cache.put(self.task_hash, value)
            except Exception as error:
                raise AttributeError(
                    f"{error} - Task inputs and results must be serializable using pickle package"
                ) from error
def _try_pickle(value, path=None, field=None, message=None):
"""try pickle given value, raises a AttributeError if it not pickleable.
optional path, field and message can be used to specify the error further"""
try:
pickle.dumps(value)
except Exception as exc:
if message:
raise AttributeError(
f"Caching failed: could not pickle field '{field}' of task '{path}'"
) from exc
raise AttributeError(
f"Caching failed: could not pickle field '{field}' of task '{path}'"
) from exc