Source code for mosaic_orchestrator.mosaic_utils

"""This module contains helper functions for the mosaic framework"""

from contextlib import contextmanager
from datetime import datetime
import functools
import os
import pathlib
import types
import typing
from logging import Logger
import uuid

from mosaic_orchestrator import tree
from mosaic_orchestrator.cachable_task import CachableTask, CacheResult
from mosaic_orchestrator.cache import Cache
from mosaic_orchestrator.result import (
    ValidationError,
    Result,
    ErrorType,
    ErrorDefinition,
)
from mosaic_orchestrator.mosaic_exception import MosaicException
from mosaic_orchestrator.task import (
    TaskNode,
    Task,
    Input,
    InputException,
    Calculated,
    TaskState,
    is_calculated,
    is_calculated_or_mutable_or_const,
    is_mutable,
    CalculatedAccessException,
    is_const,
    Const,
    _is_task_input,
    get_task_inputs,
    _get_cache_dir,
)
from mosaic_orchestrator.tool import Tool
from mosaic_orchestrator.tree import Path, concat
from mosaic_orchestrator.utils import (
    get_type_annotations,
    get_all_inputs,
    get_type_of_class_field,
    is_function,
    ListHandler,
)


class TaskTreeException(Exception):
    """Error signalling that the task tree can not be constructed."""
def find_node(
    task_tree: TaskNode, path: Path, check_input: bool = True
) -> Result[TaskNode]:
    """Find a node identified by its path in a task tree.

    Every step of ``path`` except the last one is matched against the
    ``value.path.target`` of the children of the current node, starting at
    ``task_tree``. The last step is not traversed.

    Args:
        task_tree: the node at which the search starts.
        path: the path to resolve.
        check_input: when True and the path consists of a single step, that
            step is checked to be a valid Input of the root node's task
            before the tree is searched.

    Returns:
        A result containing the TaskNode if it was found or an error if it
        was not found (or if the input check failed).
    """
    if check_input and path.length == 1:
        input_check = _check_input(path.target, task_tree.value)
        if input_check.failure:
            return input_check

    node = task_tree
    for step in path.steps[:-1]:
        matches = [
            child for child in node.children if child.value.path.target == step
        ]
        if not matches:
            return Result.Fail(f"'{step}' is no child of {type(node.value)}")
        # the first matching child wins
        node = matches[0]

    return Result.Ok(node)
def find_task(task: Task, path: Path):
    """Get a subtask of the given task identified by a relative path.

    Starting at ``task``, every step of ``path`` except the last one is
    resolved as an attribute of the object reached so far. The last step is
    not traversed; the object that owns it is returned.

    Args:
        task: the task that is used as a root.
        path: the path to resolve, relative to the given task.

    Returns:
        the object reached after following all steps but the last one. This
        is ``task`` itself when the path has a single step.

    Raises:
        TaskTreeException: if a step is not an attribute of the object
            reached so far.
    """
    missing = object()  # sentinel, so that attributes set to None still count
    owner = task
    for step in path.steps[:-1]:
        child = getattr(owner, step, missing)
        if child is missing:
            raise TaskTreeException(f"'{step}' is no child of {type(owner)}")
        owner = child
    return owner
def _validate_input(task: Task, input: Input):
    """Call the validate method and the validation lambda of an Input.

    ``input.validate()`` is called first. If ``input.validate_func`` is set
    it is called afterwards with the task and the current input value.

    Args:
        task: the task that owns the input.
        input: the input to validate.

    Returns:
        The first ``ValidationError`` returned by one of the validators, a
        new ``ValidationError`` if a validator raised an exception, otherwise
        the return value of ``input.validate()``.
    """
    try:
        class_validate = input.validate()
        if isinstance(class_validate, ValidationError):
            return class_validate
        if input.validate_func is not None:
            constructor_validate = input.validate_func(task, input.value)
            if isinstance(constructor_validate, ValidationError):
                return constructor_validate
    except Exception as error:  # pylint: disable=broad-except
        # a crashing validator is reported as a validation error instead of
        # propagating the exception
        return ValidationError(f"Validation fatal error of {task} {input}: {error}")
    return class_validate


def _check_if_injectable(task: Task, path: Path) -> Result:
    """Check whether the input addressed by ``path.target`` may be injected.

    Calculated and Const inputs can not be altered. Mutable inputs can be
    altered; their name is remembered in ``task._injected_mutables``.

    Returns:
        A failed Result for Calculated or Const inputs, otherwise an ok Result.
    """
    task_input = getattr(task, path.target)
    if is_calculated(task_input):
        return Result.Fail(
            f"input {path} cannot be altered since it is set to Calculated"
        )
    if is_const(task_input):
        return Result.Fail(f"input {path} cannot be altered since it is set to Const")
    if is_mutable(task_input):
        task._injected_mutables.append(path.target)  # pylint: disable=protected-access
    return Result.Ok()


def _set_input_or_wrap(instance: object, name: str, input) -> Result:
    """Set an input on an instance; a value that is no Input gets wrapped.

    Cases::

        Input from above        Definition in class    What to do
        1) something with value  type annotation        create new input of type with default
        2) something with value  no type annotation     copy value into existing Input
        3) Calculated            type annotation        create new input of type and
                                                        mark _special_input
        4) Calculated            no type annotation     mark _special_input in existing Input
        5) Input with same type  either case            replace completely
        6) Input different type  either case            Error

    Args:
        instance: the object that receives the input.
        name: the attribute name of the input.
        input: an Input, a special marker (Calculated, Mutable, Const), a
            callable or a literal value.

    Returns:
        An ok Result carrying the Input that was set, or a failed Result if
        the Input could not be created or has an incompatible type.
    """
    # get type from type annotation or object
    input_type = get_type_of_class_field(instance=instance, name=name)
    origin = typing.get_origin(input_type)  # get base type
    if origin is not None:
        input_type = origin

    if isinstance(input, Input):
        if not isinstance(input, input_type):
            return Result.Fail(
                f"input {input_type} not instantiable with {input} = {input.value} in {instance}"
            )
        task_input = input
    elif not hasattr(instance, name):
        # We encountered just a type annotation
        try:
            task_input = input_type(default=input)
        except TypeError as error:
            return Result.Fail(
                f"input {input_type} not instantiable with {input} of type {type(instance)} - {error}"
            )
        if is_calculated_or_mutable_or_const(input):
            task_input._special_input = input  # pylint: disable=protected-access
    else:
        task_input = getattr(instance, name)
        if is_calculated_or_mutable_or_const(input):
            task_input._special_input = input  # pylint: disable=protected-access
        else:  # callable or literal
            task_input.value = input

    setattr(instance, name, task_input)
    return Result.Ok(value=task_input)


def _evaluate_callable_inputs(node: TaskNode):
    """Evaluate callable inputs: constructor lambdas, default lambdas, task states."""
    _evaluate_constructor_lambdas(node)
    _evaluate_default_lambdas(node)
    _evaluate_task_states(node)


def _evaluate_task_states(node):
    """Evaluate the callable (lambda) value of every TaskState of ``node.task``.

    If the evaluation raises ``CalculatedAccessException`` it is deferred: a
    partial is stored in ``task._runtime_evaluation_functions`` and the state
    is marked as Calculated with value ``None``.
    """
    for name, var in vars(node.task).items():
        if isinstance(var, TaskState) and is_function(var.value):
            try:
                var.value = var.value(node.task)
            except CalculatedAccessException:
                # pylint: disable=protected-access
                node.task._runtime_evaluation_functions[name] = functools.partial(
                    var.value, node.task
                )
                var._special_input = Calculated
                var.value = None


def _evaluate_default_lambdas(node):
    """Evaluate the callable (lambda) value of the Inputs of ``node.task``.

    Inputs that are already Calculated are skipped. Const inputs are
    evaluated with ``node.parent``, all other inputs with ``node.task``.
    """
    for name, task_input in get_task_inputs(node.task).items():
        if is_calculated(task_input):
            continue
        try:
            if task_input and is_function(task_input.value):
                if is_const(task_input):
                    task_input.value = task_input.value(node.parent)
                else:
                    task_input.value = task_input.value(node.task)
        except CalculatedAccessException:
            # Evaluation not possible because a lambda is accessing a calculated input
            # This means, we cannot calculate and validate it!
            # This dependent input is therefore a calculated input as well
            # NOTE(review): the deferred call always receives node.task, even for
            # Const inputs that are evaluated with node.parent above - confirm.
            # pylint: disable=protected-access
            node.task._runtime_evaluation_functions[name] = functools.partial(
                task_input.value, node.task
            )
            # turn it into a calculated input
            task_input._special_input = Calculated
            task_input.value = None


def _evaluate_constructor_lambdas(node):
    """Evaluate the callable (lambda) value of inputs injected via a constructor.

    Each lambda is called with the task stored next to its path in
    ``node.injected_inputs``. If the evaluation raises
    ``CalculatedAccessException`` it is deferred like in
    ``_evaluate_default_lambdas``.
    """
    for path, task in node.injected_inputs:
        task_input = getattr(node.task, path.target)
        if is_function(task_input.value):
            try:
                task_input.value = task_input.value(task)
            except CalculatedAccessException:
                # pylint: disable=protected-access
                node.task._runtime_evaluation_functions[
                    path.target
                ] = functools.partial(task_input.value, task)
                task_input._special_input = Calculated
                task_input.value = None


def _check_input(name, task: Task) -> Result:
    """Check if a given name is a valid input of a Task.

    Returns:
        A failed Result if the name is neither annotated nor an attribute of
        the task, if it refers to a TaskState, or if it is not an Input
        (subclass); otherwise an ok Result.
    """
    annotations = get_type_annotations(type(task))
    is_annotated = name in annotations
    is_attribute = hasattr(task, name)

    if not is_annotated and not is_attribute:
        return Result.Fail(f"'{name}' is no input of {type(task)}")

    if is_attribute and isinstance(getattr(task, name), TaskState):
        return Result.Fail(f"'{name}' of {type(task)} is a TaskState and not settable")

    if (is_attribute and not isinstance(getattr(task, name), Input)) or (
        is_annotated and not _is_task_input(annotations[name])
    ):
        return Result.Fail(f"'{name}' of {type(task)} is not a Input or subclassing it")

    return Result.Ok()


def _wrap_root(root: Task):
    """Wrap the run method of a task with ``_root_run_wrapper``.

    The original bound method is kept in ``root._old_root_run``.
    """
    root._old_root_run = root.run  # pylint: disable=protected-access
    root.run = types.MethodType(_root_run_wrapper, root)


def _root_run_wrapper(self, *args, **kwargs):
    """Wrapped run method of the root task.

    Calls ``mosaic._startup()``, creates the run directory and changes into
    it, calls the original run method and finally restores the working
    directory and calls ``mosaic._cleanup``.

    Returns:
        whatever the original run method returns.

    Raises:
        Exception: any exception raised while preparing or running is passed
            to ``mosaic._cleanup`` and then re-raised unchanged.
    """
    # pylint: disable=protected-access
    self.mosaic._startup()
    cwd = pathlib.Path.cwd()
    try:
        os.makedirs(self.mosaic.run_directory, exist_ok=True)
        os.chdir(self.mosaic.run_directory)
        result = self._old_root_run(*args, **kwargs)
        os.chdir(cwd)
        self.mosaic._cleanup()
        return result
    except Exception as error:
        os.chdir(cwd)
        self.mosaic._cleanup(error)
        # bare raise keeps the traceback; "raise error from error" would make
        # the exception its own __cause__
        raise


def _evaluate_runtime_lambdas(task):
    """Evaluate the runtime lambdas of a task and validate the resulting inputs.

    Raises:
        InputException: if the evaluated value can not be set on the task.
        MosaicException: if the validation of the new value fails.
    """
    # pylint: disable=protected-access
    for name, function in task._runtime_evaluation_functions.items():
        result = _set_input_or_wrap(task, name, function())
        if result.failure:
            # do not validate a stale value when setting the new one failed
            raise InputException(result.error)
        _try_validate_input(name, task)


def _try_validate_input(name, task: Task):
    """Try to validate an input of a Task.

    Nothing is validated unless the task's class has an attribute ``name``
    that has a ``validate`` attribute. If the class-level input has a
    ``validate_func`` it is used, otherwise ``validate()`` of the instance
    input is called.

    Raises:
        MosaicException: if the validation returns a ``ValidationError``.
    """
    task_input = getattr(task, name)
    task_type = type(task)
    if not hasattr(task_type, name):
        return

    class_input = getattr(task_type, name)
    if not hasattr(class_input, "validate"):
        return

    if class_input.validate_func is not None:
        validation = class_input.validate_func(task, task_input.value)
    else:
        validation = task_input.validate()

    if isinstance(validation, ValidationError):
        path = concat(task.path, name)
        raise MosaicException(
            ErrorDefinition(
                error_type=ErrorType.INPUT_ERROR,
                message=f"Validation failed for '{path}' with value "
                f"'{task_input.value}' reason: {validation.message}",
                source=str(path),
            )
        )


def _collect_results(task: CachableTask) -> CacheResult:
    """Collect the result of a cachable task and the hashes of its sub tasks.

    All CachableTask fields are visited transitively. Sub tasks without a
    ``task_hash`` are skipped, but their own CachableTask fields are still
    visited.

    Returns:
        A CacheResult holding ``task.task_result`` and a mapping from the sub
        task path (with ``task.path`` removed) to the sub task hash.
    """
    result = CacheResult(result=task.task_result, sub_results={})
    stack = list(task.get_fields_of_type(CachableTask))
    while stack:
        current_task = stack.pop()
        stack.extend(current_task.get_fields_of_type(CachableTask))
        if not current_task.task_hash:
            continue
        relative_path = current_task.path.remove(task.path)
        result.sub_results[relative_path] = current_task.task_hash
    return result


def _inject_log_to_task_tool(task):
    """Inject the task logger into every Tool attribute of the task.

    During execution of a task the used tool should use the logger of the task.
    """
    for var in vars(task).values():
        if isinstance(var, Tool):
            _inject_log_to_tools(
                tool=var, log=task._mosaic_logger  # pylint: disable=protected-access
            )


def _inject_log_to_tools(tool: Tool, log: Logger):
    """Inject a logger into a tool and, recursively, into its nested tools.

    Attributes annotated with ``Logger`` are set to ``log``. Attributes
    annotated with a ``Tool`` subclass are processed recursively.
    """
    for name, annotation_type in get_type_annotations(type(tool)).items():
        if annotation_type is Logger:
            setattr(tool, name, log)
        # issubclass() raises TypeError for annotations that are not classes
        # (e.g. typing generics), so only real classes are tested
        if isinstance(annotation_type, type) and issubclass(annotation_type, Tool):
            _inject_log_to_tools(getattr(tool, name), log)


def _get_cached_result_and_logs(task: CachableTask):
    """Get the cached result and logs of a task.

    If the cached entry is reported invalid by ``task._is_task_invalid`` the
    result and log entries are cleared from the cache and
    ``task.task_result`` is reset to ``None``.

    Returns:
        A tuple ``(result, logs)``, or ``(None, None)`` if nothing is cached
        or the cached entry is invalid.
    """
    with Cache(_get_cache_dir(task)) as cache:
        result: CacheResult = cache.get(task.task_hash)
        if result is None:
            return None, None

        logs_key = task.task_hash + "_logs"
        logs = cache.get(logs_key)
        if task._is_task_invalid(result, cache):  # pylint: disable=protected-access
            cache.clear(task.task_hash)
            cache.clear(logs_key)
            task.task_result = None
            return None, None

        return result.result, logs


def _cache_log(handler: ListHandler, task: Task):
    """Put the logs of a task into the cache; does nothing without log entries.

    Raises:
        AttributeError: if storing the logs in the cache fails.
    """
    if not handler.log_list:
        return

    with Cache(_get_cache_dir(task)) as cache:
        key = task.task_hash + "_logs"
        try:
            cache.put(key, handler.log_list)
        except Exception as error:
            raise AttributeError(
                f"{error} - Task inputs and results must be serializable using pickle package"
            ) from error


def _apply_run_input(task, name, value):
    """Set ``value`` as input ``name`` of ``task`` and validate it."""
    result = _set_input_or_wrap(task, name, value)
    if result.failure:
        raise InputException(result.error)
    _try_validate_input(name, task)


def _inject_special_inputs(task, inputs):
    """Inject inputs passed to the run method via ``inputs``.

    Inputs that were already injected from top level are ignored.

    Args:
        task: the task whose ``_special_inputs`` are processed.
        inputs: mapping from input path to value, may be ``None``.

    Raises:
        InputException: if ``inputs`` names an unknown input, if a value can
            not be set, or if a required input is missing.
        MosaicException: if the validation of a value fails.
    """
    inputs = inputs or {}

    # Check if all given inputs actually exists
    for name in inputs:
        path = Path(name)
        current_task = find_task(task, path)
        if not hasattr(current_task, path.target):
            raise InputException(
                f"Unexpected run inputs ({name} in {inputs}) for task {type(task)}"
            )

    # Check if all calculated or mutable has been set in run()
    # Const can be set but it is not a must
    for name in task._special_inputs:  # pylint: disable=protected-access
        path = Path(name)
        current_task = find_task(task, path)

        # ignore inputs that were injected from top level
        if path.target in current_task._injected_mutables:  # pylint: disable=protected-access
            continue

        if name in inputs:
            _apply_run_input(current_task, path.target, inputs[name])
            continue

        if hasattr(current_task, path.target):
            task_input = getattr(current_task, path.target)
            # only const or calculated can be missing
            if is_const(task_input):
                if task_input.value is not None:
                    fallback = task_input.value
                else:
                    fallback = task_input.default
                _apply_run_input(current_task, path.target, fallback)
                continue

        raise InputException(
            f"Missing input {name} in inputs for '{current_task.path}.{path.target}'. "
            "Maybe it has been set to Calculated but has not been set when calling the run() method."
        )


def _add_prefix_to_all_inputs(root: TaskNode, prefix):
    """Add a prefix to all inputs of every node of the task tree."""
    tree.traverse(root, lambda node: _add_prefix_to_inputs(node, prefix))


def _add_prefix_to_inputs(node: TaskNode, prefix):
    """Rename every key of ``node.inputs`` to ``"<prefix>.<key>"`` in place."""
    # iterate over a snapshot because the dict is modified in the loop
    for input_name in list(node.inputs.keys()):
        node.inputs[f"{prefix}.{input_name}"] = node.inputs.pop(input_name)