"""This module contains helper functions for the mosaic framework"""
from contextlib import contextmanager
from datetime import datetime
import functools
import os
import pathlib
import types
import typing
from logging import Logger
import uuid
from mosaic_orchestrator import tree
from mosaic_orchestrator.cachable_task import CachableTask, CacheResult
from mosaic_orchestrator.cache import Cache
from mosaic_orchestrator.result import (
ValidationError,
Result,
ErrorType,
ErrorDefinition,
)
from mosaic_orchestrator.mosaic_exception import MosaicException
from mosaic_orchestrator.task import (
TaskNode,
Task,
Input,
InputException,
Calculated,
TaskState,
is_calculated,
is_calculated_or_mutable_or_const,
is_mutable,
CalculatedAccessException,
is_const,
Const,
_is_task_input,
get_task_inputs,
_get_cache_dir,
)
from mosaic_orchestrator.tool import Tool
from mosaic_orchestrator.tree import Path, concat
from mosaic_orchestrator.utils import (
get_type_annotations,
get_all_inputs,
get_type_of_class_field,
is_function,
ListHandler,
)
class TaskTreeException(Exception):
    """Raised if the task tree cannot be constructed or a path cannot be followed in it."""
def find_node(
    task_tree: TaskNode, path: Path, check_input: bool = True
) -> Result[TaskNode]:
    """Find the node of a task tree that owns the target of ``path``.

    Every step of ``path`` except the last one is followed through the children of the
    current node, matching each child's ``value.path.target``. The node reached this way is
    returned; the last step (typically an input name) is not resolved here.

    Args:
        task_tree: root node to start the search from.
        path: path whose steps (all but the last) name the child nodes to follow.
        check_input: when True and ``path`` has exactly one step, that step is checked to be
            a valid, settable Input of the root task (see ``_check_input``). Longer paths are
            not input-checked.

    Returns:
        A Result containing the TaskNode if it was found, or a failed Result if a step is
        not a child of the current node or the input check fails.
    """
    current_node = task_tree
    if path.length == 1 and check_input:
        result = _check_input(path.target, task_tree.value)
        if result.failure:
            return result
    for step in path.steps[:-1]:
        # next() over a generator picks the first matching child without defining a
        # closure inside the loop (pylint cell-var-from-loop).
        next_node = next(
            (
                child
                for child in current_node.children
                if child.value.path.target == step
            ),
            None,
        )
        if next_node is None:
            return Result.Fail(f"'{step}' is no child of {type(current_node.value)}")
        current_node = next_node
    return Result.Ok(current_node)
def find_task(task: Task, path: Path):
    """Get the (sub)task of ``task`` that owns the target of ``path``.

    Every step of ``path`` except the last one is followed as an attribute name, starting
    at ``task``. The last step (e.g. an input name) is not resolved, so a single-step path
    returns ``task`` itself.

    Args:
        task: the task used as the starting point.
        path: path relative to ``task``.

    Returns:
        The object reached after following all but the last step.

    Raises:
        TaskTreeException: if a followed step is not an attribute of the current task.
    """
    owner = task
    for attribute in path.steps[:-1]:
        if not hasattr(owner, attribute):
            raise TaskTreeException(f"'{attribute}' is no child of {type(owner)}")
        owner = getattr(owner, attribute)
    return owner
def _validate_input(task: Task, input: Input):
    """Run the validation of ``input`` and return its outcome.

    The Input's own ``validate()`` runs first. Only if that does not produce a
    ValidationError and a ``validate_func`` is set is that function called with
    ``(task, input.value)``.

    Returns:
        The first ValidationError produced; a ValidationError wrapping any exception raised
        while validating; otherwise whatever ``input.validate()`` returned.
    """
    try:
        outcome = input.validate()
        if not isinstance(outcome, ValidationError) and input.validate_func is not None:
            custom_outcome = input.validate_func(task, input.value)
            if isinstance(custom_outcome, ValidationError):
                return custom_outcome
    except Exception as exc:  # pylint: disable=broad-except
        # A crashing validator is reported as a validation error instead of propagating.
        return ValidationError(f"Validation fatal error of {task} {input}: {exc}")
    return outcome
def _check_if_injectable(task: Task, path: Path) -> Result:
    """Check whether the input at ``path`` may be injected from above.

    Calculated and Const inputs cannot be altered, so a failed Result is returned for them.
    If the input is Mutable, its name is appended to ``task._injected_mutables`` so later
    steps know it was injected.

    Returns:
        ``Result.Ok()`` if the input may be injected, otherwise a failed Result.
    """
    target_name = path.target
    existing = getattr(task, target_name)
    refusal = None
    if is_calculated(existing):
        refusal = f"input {path} cannot be altered since it is set to Calculated"
    elif is_const(existing):
        refusal = f"input {path} cannot be altered since it is set to Const"
    if refusal is not None:
        return Result.Fail(refusal)
    if is_mutable(existing):
        task._injected_mutables.append(target_name)  # pylint: disable=protected-access
    return Result.Ok()
def _set_input_or_wrap(instance: object, name: str, input) -> Result:
    """Set ``input`` as field ``name`` of ``instance``, wrapping it into an Input if needed.

    Cases (value passed in -> how the field exists on ``instance`` => action):
      1) plain value -> only a type annotation => create a new Input of the annotated type
         with ``default=<value>``
      2) plain value (literal or callable) -> existing Input attribute => copy the value
         into the existing Input
      3) Calculated/Mutable/Const marker -> only a type annotation => create a new Input of
         the annotated type with ``default=<marker>`` and store the marker in
         ``_special_input``
      4) Calculated/Mutable/Const marker -> existing Input attribute => store the marker in
         ``_special_input`` of the existing Input (its value is left untouched)
      5) Input that is an instance of the declared type -> either case => replace the field
         completely
      6) Input of a different type -> either case => failed Result

    Args:
        instance: the object (task) whose field is set.
        name: name of the field.
        input: the value, marker or Input to set.

    Returns:
        ``Result.Ok(value=<the Input now stored on instance>)`` on success, otherwise a failed
        Result describing why the value could not be used.
    """
    # Get the type from the type annotation or from the existing attribute.
    input_type = get_type_of_class_field(instance=instance, name=name)
    # Reduce parameterized generics to their origin class so isinstance() and
    # instantiation below operate on a plain class.
    origin = typing.get_origin(input_type)
    if origin is not None:
        input_type = origin
    if isinstance(input, Input):
        # Cases 5/6: an Input replaces the field only if it has the declared type.
        if not isinstance(input, input_type):
            return Result.Fail(
                f"input {input_type} not instantiable with {input} = {input.value} in {instance}"
            )
        task_input = input
    elif not hasattr(instance, name):
        # Cases 1/3: we encountered just a type annotation, so build a new Input from it.
        try:
            task_input = input_type(default=input)
        except TypeError as error:
            return Result.Fail(
                f"input {input_type} not instantiable with {input} of type {type(input)} "
                f"in {type(instance)} - {error}"
            )
        if is_calculated_or_mutable_or_const(input):
            task_input._special_input = input  # pylint: disable=protected-access
    else:
        # Cases 2/4: reuse the Input that already exists on the instance.
        task_input = getattr(instance, name)
        if is_calculated_or_mutable_or_const(input):
            task_input._special_input = input  # pylint: disable=protected-access
        else:  # callable or literal
            task_input.value = input
    setattr(instance, name, task_input)
    return Result.Ok(value=task_input)
def _evaluate_callable_inputs(node: TaskNode):
    """Resolve all lambda-valued inputs and states of ``node``'s task.

    The evaluators run in this order: values injected through the constructor, then Input
    defaults, then TaskStates.
    """
    evaluators = (
        _evaluate_constructor_lambdas,
        _evaluate_default_lambdas,
        _evaluate_task_states,
    )
    for evaluate in evaluators:
        evaluate(node)
def _evaluate_task_states(node):
    """Replace lambda-valued TaskStates of ``node``'s task by the lambda's result.

    Each lambda is called with the task. If it raises CalculatedAccessException (it reads a
    Calculated input whose value is not known yet), the call is deferred instead. A partial
    bound to the task is registered in ``_runtime_evaluation_functions`` under the state's
    attribute name, the state is marked Calculated, and its value is cleared.
    """
    owner = node.task
    for attr_name, state in vars(owner).items():
        if not isinstance(state, TaskState) or not is_function(state.value):
            continue
        try:
            state.value = state.value(owner)
        except CalculatedAccessException:
            # pylint: disable=protected-access
            owner._runtime_evaluation_functions[attr_name] = functools.partial(
                state.value, owner
            )
            state._special_input = Calculated
            state.value = None
def _evaluate_default_lambdas(node):
    """Evaluate lambda-valued Inputs of ``node``'s task and store their results.

    Calculated inputs are skipped. Const inputs are evaluated with ``node.parent``; all other
    inputs are evaluated with ``node.task``.

    If a lambda raises CalculatedAccessException, it depends on a Calculated input, so it can
    neither be calculated nor validated now. Its evaluation is deferred instead: a partial
    bound to the same argument is registered in ``_runtime_evaluation_functions``, and the
    input itself becomes a Calculated input with value None.
    """
    task = node.task
    for name, task_input in get_task_inputs(task).items():
        if is_calculated(task_input):
            continue
        if not (task_input and is_function(task_input.value)):
            continue
        # Const defaults are resolved against the parent; every other input is resolved
        # against its own task.
        argument = node.parent if is_const(task_input) else task
        try:
            task_input.value = task_input.value(argument)
        except CalculatedAccessException:
            # Defer to run time with the SAME argument the lambda would have received here.
            # Previously Const lambdas were deferred with node.task instead of node.parent.
            # pylint: disable=protected-access
            task._runtime_evaluation_functions[name] = functools.partial(
                task_input.value, argument
            )
            # Turn it into a calculated input.
            task_input._special_input = Calculated
            task_input.value = None
def _evaluate_constructor_lambdas(node):
    """Resolve lambda values that were injected into ``node``'s task through its constructor.

    ``node.injected_inputs`` is iterated as ``(path, task)`` pairs. If the input named
    ``path.target`` holds a lambda, that lambda is called with the paired ``task``
    (presumably the task that injected it — confirm against TaskNode). On
    CalculatedAccessException the call is deferred: a partial bound to the same task is
    registered in ``_runtime_evaluation_functions``, and the input is marked Calculated with
    value None.
    """
    owner = node.task
    for injected_path, source_task in node.injected_inputs:
        input_name = injected_path.target
        task_input = getattr(owner, input_name)
        if not is_function(task_input.value):
            continue
        try:
            task_input.value = task_input.value(source_task)
        except CalculatedAccessException:
            # pylint: disable=protected-access
            owner._runtime_evaluation_functions[input_name] = functools.partial(
                task_input.value, source_task
            )
            task_input._special_input = Calculated
            task_input.value = None
def _check_input(name, task: Task) -> Result:
    """Check that ``name`` is a valid, settable Input of ``task``.

    ``name`` must be declared on the task, either as a type annotation of its class or as an
    attribute. It must not be a TaskState. The attribute (if present) and the annotation
    (if present) must both be Input types.

    Returns:
        ``Result.Ok()`` if ``name`` is a valid input, otherwise a failed Result describing why.
    """
    annotations = get_type_annotations(type(task))
    # One attribute lookup instead of repeated hasattr()/getattr() pairs. The sentinel
    # distinguishes "attribute missing" from an attribute whose value is None.
    missing = object()
    attribute = getattr(task, name, missing)
    has_attribute = attribute is not missing
    is_annotated = name in annotations
    if not is_annotated and not has_attribute:
        return Result.Fail(f"'{name}' is no input of {type(task)}")
    if has_attribute and isinstance(attribute, TaskState):
        return Result.Fail(f"'{name}' of {type(task)} is a TaskState and not settable")
    if (has_attribute and not isinstance(attribute, Input)) or (
        is_annotated and not _is_task_input(annotations[name])
    ):
        return Result.Fail(f"'{name}' of {type(task)} is not a Input or subclassing it")
    return Result.Ok()
def _wrap_root(root: Task):
    """Replace ``root.run`` by ``_root_run_wrapper`` bound to ``root``.

    The original bound ``run`` is kept as ``root._old_root_run`` so the wrapper can call it.
    Calling this on an already wrapped root is a no-op. Wrapping twice would store the
    wrapper itself as ``_old_root_run`` and make every ``run()`` call recurse into itself.
    """
    if getattr(root.run, "__func__", None) is _root_run_wrapper:
        return
    root._old_root_run = root.run  # pylint: disable=protected-access
    root.run = types.MethodType(_root_run_wrapper, root)
def _root_run_wrapper(self, *args, **kwargs):
    """Wrapped ``run`` method of the root task.

    The sequence is:
    1. Start mosaic up.
    2. Create ``self.mosaic.run_directory`` and change into it.
    3. Call the original ``run`` and always change back to the previous working directory.
    4. Clean up: ``self.mosaic._cleanup()`` on success, or ``self.mosaic._cleanup(error)``
       if creating the directory or running failed.

    Returns:
        Whatever the original ``run`` returned.

    Raises:
        Any exception raised by the original ``run`` is re-raised unchanged, keeping its
        traceback and its own ``__cause__``.
    """
    self.mosaic._startup()  # pylint: disable=protected-access
    cwd = pathlib.Path.cwd()
    try:
        try:
            os.makedirs(self.mosaic.run_directory, exist_ok=True)
            os.chdir(self.mosaic.run_directory)
            result = self._old_root_run(*args, **kwargs)  # pylint: disable=protected-access
        finally:
            # Restore the caller's working directory even on KeyboardInterrupt, and
            # before cleanup runs (same order as before).
            os.chdir(cwd)
    except Exception as error:
        self.mosaic._cleanup(error)  # pylint: disable=protected-access
        # A bare raise keeps the original traceback and __cause__ instead of making the
        # exception its own cause.
        raise
    # Outside the try: a failing success-path cleanup must not trigger a second cleanup.
    self.mosaic._cleanup()  # pylint: disable=protected-access
    return result
def _evaluate_runtime_lambdas(task):
    """Evaluate the lambdas deferred until run time and store and validate their results.

    Each function in ``task._runtime_evaluation_functions`` is called without arguments.
    Its result is set as the input of the same name, and that input is then validated.

    Raises:
        InputException: if a result cannot be stored as the input (same handling as in
            ``_inject_special_inputs``). Previously such failures were silently ignored.
        MosaicException: if the stored value fails validation (raised by
            ``_try_validate_input``).
    """
    # pylint: disable=protected-access
    for name, function in task._runtime_evaluation_functions.items():
        result = _set_input_or_wrap(task, name, function())
        if result.failure:
            raise InputException(result.error)
        _try_validate_input(name, task)
def _try_validate_input(name, task: Task):
    """Validate input ``name`` of ``task`` and raise if validation fails.

    Validation only happens when the task's class defines an attribute ``name`` that has a
    ``validate`` attribute. If that class-level attribute carries a ``validate_func``, the
    function is called with ``(task, <current value>)``. Otherwise the current instance
    input's own ``validate()`` is used.

    NOTE(review): inputs declared only through a type annotation (no class attribute) are
    not validated here, and a ``validate_func`` replaces rather than complements
    ``validate()``. Confirm both are intended.

    Raises:
        MosaicException: with ``ErrorType.INPUT_ERROR`` when the validation returns a
            ValidationError.
    """
    current = getattr(task, name)
    owner_type = type(task)
    if not hasattr(owner_type, name):
        return
    declared = getattr(owner_type, name)
    if not hasattr(declared, "validate"):
        return
    if declared.validate_func is None:
        outcome = current.validate()
    else:
        outcome = declared.validate_func(task, current.value)
    if not isinstance(outcome, ValidationError):
        return
    path = concat(task.path, name)
    raise MosaicException(
        ErrorDefinition(
            error_type=ErrorType.INPUT_ERROR,
            message=f"Validation failed for '{path}' with value "
            f"'{current.value}' reason: {outcome.message}",
            source=str(path),
        )
    )
def _collect_results(task: CachableTask) -> CacheResult:
    """Bundle ``task``'s result with the hashes of all nested cachable sub-tasks.

    Sub-tasks are discovered depth-first through ``get_fields_of_type(CachableTask)``.
    Sub-tasks without a ``task_hash`` contribute no entry, but their own sub-tasks are still
    visited. Each hash is keyed by the sub-task's path relative to ``task``.
    """
    collected = CacheResult(result=task.task_result, sub_results={})
    pending = list(task.get_fields_of_type(CachableTask))
    while pending:
        sub_task = pending.pop()
        pending.extend(sub_task.get_fields_of_type(CachableTask))
        if sub_task.task_hash:
            relative_path = sub_task.path.remove(task.path)
            collected.sub_results[relative_path] = sub_task.task_hash
    return collected
def _inject_log_to_task_tool(task):
    """Give every Tool held by ``task`` the task's logger.

    While a task runs, the tools it uses should log through the task's own logger. Each
    attribute of ``task`` that is a Tool therefore gets ``task._mosaic_logger`` injected,
    recursively via ``_inject_log_to_tools``.
    """
    task_tools = (value for value in vars(task).values() if isinstance(value, Tool))
    for tool in task_tools:
        _inject_log_to_tools(
            tool=tool, log=task._mosaic_logger  # pylint: disable=protected-access
        )
def _inject_log_to_tools(tool: Tool, log: Logger):
    """Recursively assign ``log`` to the Logger fields of ``tool`` and of its sub-tools.

    Fields annotated with exactly ``logging.Logger`` receive ``log``. Fields whose
    annotation is a ``Tool`` subclass are descended into when a sub-tool is actually
    assigned.

    Annotations that are not plain classes (e.g. ``Optional[...]``, ``list[...]`` or string
    forward references) are skipped. ``issubclass`` would raise ``TypeError`` for them.
    """
    for name, annotation_type in get_type_annotations(type(tool)).items():
        if annotation_type is Logger:
            setattr(tool, name, log)
        elif isinstance(annotation_type, type) and issubclass(annotation_type, Tool):
            sub_tool = getattr(tool, name, None)
            # A sub-tool that is declared but not assigned has nothing to inject into.
            if sub_tool is not None:
                _inject_log_to_tools(sub_tool, log)
def _get_cached_result_and_logs(task: CachableTask):
    """Look up the cached result and logs of ``task``.

    Returns:
        ``(result, logs)`` read from the cache, or ``(None, None)`` in two cases:
        - nothing is cached under ``task.task_hash``;
        - ``task._is_task_invalid`` rejects the cached entry. In that case the result and
          log entries are both cleared from the cache and ``task.task_result`` is reset to
          None.
    """
    with Cache(_get_cache_dir(task)) as cache:
        cached: CacheResult = cache.get(task.task_hash)
        if cached is None:
            return None, None
        logs_key = task.task_hash + "_logs"
        logs = cache.get(logs_key)
        if not task._is_task_invalid(cached, cache):  # pylint: disable=protected-access
            return cached.result, logs
        cache.clear(task.task_hash)
        cache.clear(logs_key)
        task.task_result = None
        return None, None
def _cache_log(handler: ListHandler, task: Task):
    """Store the log records collected by ``handler`` in the cache of ``task``.

    Nothing is written when no records were collected. Records are stored under the key
    ``task.task_hash + "_logs"``.

    Raises:
        AttributeError: if ``cache.put`` fails (presumably because the records cannot be
            pickled); the original error is chained as the cause.
    """
    if handler.log_list:
        with Cache(_get_cache_dir(task)) as cache:
            logs_key = task.task_hash + "_logs"
            try:
                cache.put(logs_key, handler.log_list)
            except Exception as error:  # pylint: disable=broad-except
                raise AttributeError(
                    f"{error} - Task inputs and results must be serializable using pickle package"
                ) from error
def _assign_and_validate_special_input(current_task, target, value):
    """Store ``value`` as input ``target`` of ``current_task``, then validate it.

    Raises:
        InputException: if ``value`` cannot be stored as that input.
        MosaicException: if the stored value fails validation (via ``_try_validate_input``).
    """
    result = _set_input_or_wrap(current_task, target, value)
    if result.failure:
        raise InputException(result.error)
    _try_validate_input(target, current_task)


def _inject_special_inputs(task, inputs):
    """Inject the inputs passed to the run method into the special inputs of the tree.

    Inputs that were already injected from the top level are ignored.

    Args:
        task: the root task.
        inputs: mapping of input path (e.g. ``"sub.x"``) to value; may be None.

    Raises:
        InputException: if a given input does not exist; if a value cannot be stored; or if
            a special input that is not Const was not provided.
        MosaicException: if a stored value fails validation.
    """
    inputs = inputs or {}
    # Check that all given inputs actually exist.
    for name in inputs:
        path = Path(name)
        current_task = find_task(task, path)
        if not hasattr(current_task, path.target):
            raise InputException(
                f"Unexpected run inputs ({name} in {inputs}) for task {type(task)}"
            )
    # Every Calculated or Mutable input must be set in run(). Const inputs may be set, but
    # otherwise fall back to their value, or to their default when the value is None.
    for name in task._special_inputs:  # pylint: disable=protected-access
        path = Path(name)
        current_task = find_task(task, path)
        # Ignore inputs that were injected from the top level.
        if path.target in current_task._injected_mutables:  # pylint: disable=protected-access
            continue
        if name in inputs:
            _assign_and_validate_special_input(current_task, path.target, inputs[name])
            continue
        if hasattr(current_task, path.target):
            current_input = getattr(current_task, path.target)
            # Only Const inputs may be missing from ``inputs``.
            if is_const(current_input):
                fallback = (
                    current_input.value
                    if current_input.value is not None
                    else current_input.default
                )
                _assign_and_validate_special_input(current_task, path.target, fallback)
                continue
        raise InputException(
            f"Missing input {name} in inputs for '{current_task.path}.{path.target}'. "
            "Maybe it has been set to Calculated but has not been set when calling the run() method."
        )
def _add_prefix_to_all_inputs(root: TaskNode, prefix):
    """Prefix the input names of the nodes in the task tree rooted at ``root``.

    ``tree.traverse`` is given a callback that applies ``_add_prefix_to_inputs(node, prefix)``
    to each node it visits.
    """

    def _prefix_node(node):
        _add_prefix_to_inputs(node, prefix)

    tree.traverse(root, _prefix_node)
def _add_prefix_to_inputs(node: TaskNode, prefix):
    """Rename every key of ``node.inputs`` to ``"<prefix>.<key>"`` in place.

    The renamed mapping is built in one pass and written back into the same dict object, so
    key order is kept. A renamed key also cannot overwrite an original key that has not been
    renamed yet: ``{"a": 1, "p.a": 2}`` with prefix ``"p"`` yields
    ``{"p.a": 1, "p.p.a": 2}``.
    """
    renamed = {
        f"{prefix}.{input_name}": value for input_name, value in node.inputs.items()
    }
    node.inputs.clear()
    node.inputs.update(renamed)