# Source code for mosaic_orchestrator.mosaic

"""This module contains the mosaic-orchestrator framework core functionality"""
from __future__ import annotations

import copy
from dataclasses import asdict, dataclass, fields, is_dataclass
import importlib
import logging
import os
import pathlib
import time
import types
import typing
from logging import Logger
from typing import (
    Any,
    Type,
    Union,
    Dict,
    List,
    Tuple,
    Optional,
)
import uuid

from mosaic_orchestrator.execution_tracker import ExecutionTracker

import pluggy

from mosaic_orchestrator import tree, hookspecs
from mosaic_orchestrator.asset import Asset
from mosaic_orchestrator.cachable_task import CachableTask
from mosaic_orchestrator.config import Config
from mosaic_orchestrator.mosaic_utils import (
    _set_input_or_wrap,
    find_node,
    _validate_input,
    _check_if_injectable,
    _check_input,
    _evaluate_callable_inputs,
    _add_prefix_to_all_inputs,
    _wrap_root,
    MosaicException,
    _inject_special_inputs,
    _evaluate_runtime_lambdas,
    _get_cached_result_and_logs,
    _inject_log_to_task_tool,
    _collect_results,
    _cache_log
)
from mosaic_orchestrator.pdk import PDK, PdkItem
from mosaic_orchestrator.result import (
    ValidationError,
    ErrorType,
    ErrorDefinition,
    Validation,
)
from mosaic_orchestrator.task import (
    ATask,
    TaskNode,
    Task,
    Calculated,
    Input,
    Output,
    is_calculated,
    CheckResult,
    is_const,
    Const,
    get_task_inputs,
)
from mosaic_orchestrator.tool import Tool
from mosaic_orchestrator.tree import is_path_string, Path, concat, prepend
from mosaic_orchestrator.utils import (
    get_type_annotations,
    get_all_inputs,
    ListHandler,
    get_type_of_class_field,
)
from mosaic_orchestrator.task_hierarchy_builder import TaskHierarchyBuilder, T
from mosaic_orchestrator.tool_store import ToolStore
from mosaic_orchestrator.protocols import (
    Validate,
    LifeCycleListener,
    Hashable,
)
from mosaic_orchestrator.result_container import ResultContainer
from mosaic_orchestrator.work_products import WorkProduct


class Mosaic(TaskHierarchyBuilder[T]):
    """This is the main entry point of the framework.

    Use `create` to start building an executable task tree.
    """

    log_formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    )
    """the formatter used for logging. It is used for stdout as well as log files. Can be set globally"""

    log_level = logging.INFO
    """the log level of all loggers. Can be set globally"""

    default_cache_directory = ".mosaic_orchestrator"
    default_work_directory = "mosaic_orchestrator.work"

    # The single console handler shared by all Mosaic instances. It is created
    # lazily by `_configure_console_logging` so that creating several builders
    # does not stack several handlers on the root logger.
    _console_handler = None

    @staticmethod
    def create(
        root: Type[T], caching: bool = True, plugin_enabled: bool = True
    ) -> TaskHierarchyBuilder[T]:
        """creates a builder for a task hierarchy given a root task.

        Args:
            root: the type of the desired root task. Notice that actual
                instantiation is done by the framework.
            caching: when False all caching is disabled for hierarchies build
                with this builder.
            plugin_enabled: Enables the plugin discovery for PDKs and tools.

        Returns:
            a `TaskHierarchyBuilder` for the given root type.

        Example:
            This is a simple example that shows how to build a simple
            mosaic_orchestrator task tree::

                builder = (
                    Mosaic.create(root=MyTask)
                    .with_pdk(MyPdkImpl())
                    .register(AtoolInterface, MyToolImpl())
                    .register(AnotherDependency, AnotherDependency())
                    .with_inputs({
                        "x1": 5,
                        "x2": 3,
                        "t1.p5": "REF",
                    })
                )

                if builder.check().success:
                    instance = builder.build()
                    instance.run()
        """
        return Mosaic[T](root, caching, plugin_enabled)

    @staticmethod
    def _new_plugin_manager() -> pluggy.PluginManager:
        """create a plugin manager with the hook specs and entry points loaded"""
        plugin_manager = pluggy.PluginManager("mosaic_orchestrator")
        plugin_manager.add_hookspecs(hookspecs)
        plugin_manager.load_setuptools_entrypoints("mosaic_orchestrator")
        return plugin_manager

    @staticmethod
    def get_installed_pdks() -> List[PDK]:
        """lists all installed PDKs"""
        return Mosaic._new_plugin_manager().hook.get_pdk()

    @staticmethod
    def get_installed_tools() -> List[Tool]:
        """lists all installed Tools"""
        return Mosaic._new_plugin_manager().hook.get_tool()

    def __init__(self, root: Type[T], caching, plugin_enabled):
        """Initialise an empty builder; prefer `Mosaic.create` over calling this.

        Args:
            root: the type of the root task, instantiated later by `check`.
            caching: stored as `self.caching` and read by the run wrapper.
            plugin_enabled: when truthy the hook specs and the setuptools
                entry points of "mosaic_orchestrator" are loaded.
        """
        self._configure_console_logging()
        self.root_type = root
        self.pdk = None
        self.execution_trackers = []
        self.caching = caching
        self.plugin_enabled = plugin_enabled
        self._store = {}
        self.tools = ToolStore()
        self._callable_inputs: List[Tuple[Input, Task]] = []
        self._check_result = CheckResult()
        self._task_tree = None
        self.root = None
        self._loggers = {}
        self.run_directory = pathlib.Path.joinpath(
            pathlib.Path.cwd(), Mosaic.default_work_directory
        )
        self.cache_directory = pathlib.Path.joinpath(
            pathlib.Path.cwd(), Mosaic.default_cache_directory
        )
        self._life_cycle_listener: List[LifeCycleListener] = []
        self._validate_list: List[Validate] = []
        self.plugin_manager = pluggy.PluginManager("mosaic_orchestrator")
        if plugin_enabled:
            self.plugin_manager.add_hookspecs(hookspecs)
            self.plugin_manager.load_setuptools_entrypoints(
                "mosaic_orchestrator")

    def with_cfg(self, config: Union[Config, str]) -> TaskHierarchyBuilder[T]:
        """Apply a configuration object to this builder.

        Args:
            config: either a `Config` instance or the path of a ``.py`` file
                that defines a class named ``MosaicConfig``.

        Raises:
            MosaicException: if `config` is a path that does not end in
                ``.py`` or the loaded module has no ``MosaicConfig``.

        Returns:
            TaskHierarchyBuilder[T]: The builder instance
        """
        if isinstance(config, str):
            # `import importlib` alone does not guarantee that the
            # `machinery` submodule is available as an attribute.
            from importlib.machinery import SourceFileLoader

            mod_name, file_ext = os.path.splitext(os.path.split(config)[-1])
            if file_ext.lower() != ".py":
                raise MosaicException(
                    ErrorDefinition(
                        ErrorType.GENERAL_ERROR,
                        f"Configuration file '{config}' cannot "
                        f"be loaded.",
                        config,
                    )
                )
            loader = SourceFileLoader(mod_name, config)
            mod = types.ModuleType(loader.name)
            loader.exec_module(mod)
            try:
                config = mod.MosaicConfig()
            except AttributeError as error:
                raise MosaicException(
                    ErrorDefinition(
                        ErrorType.GENERAL_ERROR,
                        f"Configuration file '{config}' has "
                        f"not a Config class named "
                        f"'MosaicConfig'.",
                        config,
                    )
                ) from error

        if config.get_pdk():
            self.with_pdk(config.get_pdk())
        self.with_inputs(config.get_inputs())
        if config.get_run_directory():
            self.with_run_directory(config.get_run_directory())
        if config.get_cache_directory():
            self.with_cache_directory(config.get_cache_directory())
        # Everything coming from a configuration overrides earlier registrations.
        for tool in config.get_tools():
            self.register(key=type(tool), value=tool, overwrite=True)
        for key, value in config.get_objects_to_register():
            self.register(key, value, overwrite=True)
        return self

    def with_pdk(self, pdk: PDK) -> TaskHierarchyBuilder[T]:
        """Set the PDK and report its name to all registered execution trackers."""
        self.pdk = pdk
        self._build_protocol_lists(pdk)
        for tracker in self.execution_trackers:
            tracker.set_entry("pdk", pdk.name)
        return self

    def with_run_directory(self, path: pathlib.Path) -> TaskHierarchyBuilder[T]:
        """Set the directory below which the task log files are created."""
        self.run_directory = path
        return self

    def with_cache_directory(self, path: pathlib.Path) -> TaskHierarchyBuilder[T]:
        """Set the cache directory of this builder."""
        self.cache_directory = path
        return self

    def register(
        self, key: Union[Type, str], value: Any, overwrite: bool = False
    ) -> TaskHierarchyBuilder[T]:
        """Register a tool or an arbitrary object under `key`.

        `Tool` instances go to the tool store, everything else to the
        internal store that is also used for inputs.

        Args:
            key: the type (or name) the value is looked up by.
            value: the object to register.
            overwrite: replace an already registered value instead of failing.

        Raises:
            MosaicException: if a non-tool `key` is already registered and
                `overwrite` is False.

        Returns:
            TaskHierarchyBuilder[T]: The builder instance
        """
        for tracker in self.execution_trackers:
            tracker.append_to_entry("installed_tools", str(value))

        # Inject PDK in registered "things"
        if not isinstance(value, Input):
            for name, input_type in get_type_annotations(type(value)).items():
                # Annotations such as `List[int]` or `Optional[X]` are not
                # classes; `issubclass` would raise a TypeError for them.
                if isinstance(input_type, type) and issubclass(input_type, PDK):
                    setattr(value, name, self.pdk)

        if isinstance(value, Tool):
            self.tools.put(key, value, overwrite)
        elif key not in self._store or overwrite:
            self._store[key] = value
        else:
            raise MosaicException(
                ErrorDefinition(
                    ErrorType.GENERAL_ERROR,
                    f"{key} has already been registered! Maybe you have installed multiple variants in your "
                    f"python environment (e.g. {type(value)} and {type(self._store[key])}). "
                    "Please uninstall one or more to reduce it to a single tool installation.",
                    value,
                )
            )
        return self

    def with_inputs(self, inputs: Dict[str, object]) -> TaskHierarchyBuilder[T]:
        """Add inputs to the task.

        >>> Mosaic.create(root=ExampleTask).with_inputs({"width": 10.0}).build().run()
        10.0

        Args:
            inputs: Inputs to the task

        Returns:
            TaskHierarchyBuilder[T]: The builder instance
        """
        self._store.update(inputs)
        return self

    def build(self) -> T:
        """Build an instance of a task from the builder object.

        >>> instance = Mosaic.create(root=ExampleTask).build()
        >>> instance.run()
        0.0

        Raises:
            MosaicException: If an error occurs during check(); the first
                reported error is raised.

        Returns:
            T: An instance of the task
        """
        result = self.check()
        if result.success is False:
            raise MosaicException(result.errors[0])
        return self.root

    def export_inputs_to_dict(self, skip_defaults: bool = True, skip_hidden: bool = True) -> dict:
        """Run check() and write all inputs into a dictionary that can be used to
        save it in any kind of format of your choice (e.g. JSON/YAML). You can later
        load them again via `import_inputs_from_dict()`.

        >>> builder = Mosaic.create(root=ExampleTask).with_inputs({"width": 20})
        >>> builder.export_inputs_to_dict(skip_defaults=False, skip_hidden=False)
        {'ExampleTask.width': 20}

        Args:
            skip_defaults (bool, optional): Skip inputs that have value==defaults. Defaults to True.
            skip_hidden (bool, optional): Skip hidden inputs. Defaults to True.

        Returns:
            dict: Dictionary with input path as keys and input values as values
        """
        result = self.check()
        inputs = {}
        for single_input in result.inputs:
            value = single_input.input.value
            if skip_defaults and value == single_input.input.default:
                continue
            if skip_hidden and single_input.input.hidden:
                continue
            if not is_dataclass(value):
                inputs[single_input.path] = value
                continue

            # If init has been set to false on any of the fields of the
            # dataclass, then it will not be available in the constructor
            # and therefore should be dismissed in the output
            # NOTE(review): `asdict` also calls the factory for nested
            # dataclasses, which are filtered against the fields of the
            # top-level dataclass here - confirm that this is intended.
            dataclass_fields = value.__dataclass_fields__

            def keep_init_fields(pairs, dataclass_fields=dataclass_fields) -> dict:
                return {
                    name: item
                    for name, item in pairs
                    if name in dataclass_fields and dataclass_fields[name].init
                }

            inputs[single_input.path] = asdict(
                value, dict_factory=keep_init_fields)
        return inputs

    def import_inputs_from_dict(self, input_dict: dict[str, Any]) -> TaskHierarchyBuilder[T]:
        """Import parameters from a dictionary that has been created via
        `export_inputs_to_dict`. It can be used to save and restore inputs of root tasks.

        The passed dictionary (and the dictionaries nested in it) are not modified.

        >>> builder = Mosaic.create(root=ExampleTask).with_inputs({"width": 20.0})
        >>> exported_dict = builder.export_inputs_to_dict(skip_defaults=False, skip_hidden=False)
        >>> builder = Mosaic.create(root=ExampleTask).import_inputs_from_dict(exported_dict)
        >>> result = builder.check()
        >>> result.inputs[0].input.value
        20.0

        Args:
            input_dict (dict[str, Any]): Dictionary in the form of path->obj representation

        Raises:
            MosaicException: If a path is not found
            MosaicException: Dataclass value cannot be initialized

        Returns:
            TaskHierarchyBuilder[T]: The builder instance
        """
        known_inputs = self.check().inputs

        for path, parameter in input_dict.items():
            if parameter is None:
                continue

            matched_inputs = [
                entry.input for entry in known_inputs if entry.path == path]
            if not matched_inputs:
                raise MosaicException(f"Path {path} not found.")
            matched_input = matched_inputs[0]

            new_value = parameter
            if is_dataclass(matched_input.type):
                # Fields declared with init=False are not accepted by the
                # constructor. Filter into a new dict instead of deleting
                # keys from the dictionary owned by the caller.
                non_init_names = {
                    field.name
                    for field in fields(matched_input.type)
                    if not field.init
                }
                init_kwargs = parameter
                try:
                    init_kwargs = {
                        name: item
                        for name, item in parameter.items()
                        if name not in non_init_names
                    }
                    new_value = matched_input.type(**init_kwargs)
                except Exception as error:
                    raise MosaicException(
                        f"Could not read {matched_input.type} from dictionary {init_kwargs}: {error}"
                    ) from error

            # The store is keyed relative to the root task, so the leading
            # root task name is stripped from the exported path.
            store_key = ".".join(path.split(".")[1:])
            self._store.update({store_key: new_value})
        return self

    def check(self) -> CheckResult:
        """Perform checks without running a task.

        Performed checks are:
        - Missing parameters
        - Validation of input parameters
        - Checking of tools availability
        - Checking of PDK completeness (are all PDKItems available?)

        A new root task instance and task tree are created on every call.

        Returns:
            CheckResult: Container of check results
        """
        self._check_result = CheckResult()

        if self.plugin_enabled:
            if not self.execution_trackers:
                self.__load_exection_trackers_from_plugins()
            if not self.pdk:
                self._load_pdk_from_plugins()
            self._load_tools_from_plugins()

        self._inject_by_type(self.pdk)

        self.root = self.root_type()
        self._task_tree = self._build_task_tree(
            task=self.root, path=Path(self.root_type.__name__)
        )

        self._check_store_inputs()
        tree.traverse(tree=self._task_tree, on_each=self._inject)
        tree.traverse(tree=self._task_tree, on_each=self._set_defaults)

        for tool in self.tools.used_tools():
            self._inject_pdk_items(tool)

        tree.traverse(tree=self._task_tree, on_each=self._evaluate_and_validate)

        for obj in self._validate_list:
            result = obj.validate()
            if isinstance(result, ValidationError):
                self._check_result.append_error(
                    ErrorType.VALIDATION_ERROR,
                    f"'{type(obj).__name__}' could not be validated - {result.message}",
                    type(obj),
                )

        _wrap_root(self.root)
        return self._check_result

    def _load_pdk_from_plugins(self):
        """try to load a pdk from installed plugin, adds error if more than one PDK is found"""
        pdks = self.plugin_manager.hook.get_pdk()
        if len(pdks) == 1:
            self.with_pdk(pdks[0])
        elif len(pdks) > 1:
            self._check_result.append_error(
                ErrorType.PDK_ERROR,
                f"More then one installed PDKs found in virtual environment: {pdks}",
                self,
            )

    def register_execution_tracker(self, tracker: ExecutionTracker) -> TaskHierarchyBuilder[T]:
        """Register an execution tracker

        Args:
            tracker (ExecutionTracker): Execution tracker to add

        Returns:
            TaskHierarchyBuilder[T]: The builder instance
        """
        self.execution_trackers.append(tracker)
        self._build_protocol_lists(tracker)
        return self

    def __load_exection_trackers_from_plugins(self):
        """try to load the execution tracker from the installed plugins"""
        execution_trackers: list[ExecutionTracker] = (
            self.plugin_manager.hook.get_execution_tracker()
        )
        for tracker in execution_trackers:
            self.register_execution_tracker(tracker)

    def _load_tools_from_plugins(self):
        """try to load tools from installed plugins"""
        # Each hook result is indexed as (key, tool).
        for tool in self.plugin_manager.hook.get_tool():
            self.register(tool[0], tool[1], overwrite=False)

    def _build_task_tree(
        self, task: Task, path: Path, parent: Optional[Task] = None
    ) -> TaskNode:
        """build the task tree of a given task recursively"""
        task._path = path  # pylint: disable=protected-access
        self._set_logger(task)
        self._inject_pdk_items(task)
        self._inject_by_type(task)

        task_inputs = get_task_inputs(task)
        dynamic_inputs = task.dynamic_inputs()
        if dynamic_inputs:
            for input_name in dynamic_inputs:
                setattr(task, input_name, dynamic_inputs[input_name])
            task_inputs.update(dynamic_inputs)

        root = TaskNode(
            parent=parent,
            value=task,
            children=[],
            inputs=task_inputs,
            injected_inputs=[],
        )

        dynamic_tasks = task.dynamic_tasks()
        if dynamic_tasks:
            for subtask_name, subtask in dynamic_tasks.items():
                setattr(task, subtask_name, subtask)
                sub_tree = self._build_task_tree(
                    task=subtask, path=Path([*path.steps, subtask_name]), parent=task
                )
                root.children.append(sub_tree)
                _add_prefix_to_all_inputs(sub_tree, subtask_name)

        # Only attributes declared directly on the class of the task are
        # inspected for sub tasks.
        for subtask_attr_name in type(task).__dict__.keys():
            subtask_attr = getattr(task, subtask_attr_name)
            if issubclass(type(subtask_attr), Task):
                sub_tree = self._build_task_tree(
                    task=subtask_attr,
                    path=Path([*path.steps, subtask_attr_name]),
                    parent=task,
                )
                root.children.append(sub_tree)
                _add_prefix_to_all_inputs(sub_tree, subtask_attr_name)

        return root

    def _inject(self, node: TaskNode):
        """wrap run method, handle calculated inputs and inject inputs into task"""
        node.task._mosaic = self  # pylint: disable=protected-access
        self._wrap_run(node.task)
        self._handle_special_inputs(node)
        self._inject_inputs(node)

    def _set_logger(self, task: Task):
        """create a new or set existing logger given task"""
        task._mosaic_logger = self._get_or_create_logger(task)  # pylint: disable=protected-access

    def _inject_pdk_items(self, instance):
        """inject pdk items to a given instance"""
        source = type(instance)
        if hasattr(instance, "path"):
            source = instance.path

        for name, obj in get_all_inputs(instance).items():
            # Allow assets to use pdk items
            if issubclass(type(obj), Asset):
                self._inject_pdk_items(obj)

            if not issubclass(type(obj), PdkItem):
                continue
            if not self._check_pdk(source):
                continue

            result = self.pdk.get_item(obj.name, obj.version)
            if result.failure:
                # `sorted` builds a new list, so the collection owned by the
                # PDK is left in its original order.
                supported_items = sorted(
                    self.pdk.supported_items() or [],
                    key=lambda item: item.name.lower(),
                )
                supported_items_string = ' '.join(
                    f"({item.name} {item.version})" for item in supported_items
                )
                self._check_result.append_error(
                    ErrorType.MISSING_PDK_ITEM,
                    f"PdkItem not available in provided PDK ('{obj.name}',"
                    f" {obj.version}). Supported items: [{supported_items_string}]",
                    source,
                )
                continue

            item = result.value
            # making item conform to Hashable and dependant on pdk
            item.task_hash = self.pdk.task_hash

            # Check if they are not "related"
            if not issubclass(type(item), type(obj)):
                self._check_result.append_error(
                    ErrorType.PDK_ERROR,
                    f"Expected pdk item of type '{type(obj).__name__}' not compatible"
                    f" with view from PDK '{type(item).__name__}'",
                    type(instance),
                )

            if not item._try_init():  # pylint: disable=protected-access
                self._check_result.append_error(
                    ErrorType.PDK_ERROR,
                    f"Could not validate pdk item '{type(item).__name__}'",
                    type(instance),
                )

            # The item is set even when one of the two checks above failed;
            # the failure is only reported through the check result.
            setattr(instance, name, result.value)

    def _check_pdk(self, source) -> bool:
        """check if pdk is available"""
        if self.pdk is None:
            self._check_result.append_error(
                ErrorType.PDK_ERROR,
                "PDK implementation is missing, "
                "please provide a valid implementation",
                source,
            )
            return False
        return True

    def _inject_by_type(self, obj):
        """inject Logger, Tools, PDK, registered objects and input to given objects"""
        annotations = get_type_annotations(type(obj))
        for input_name, annotation in annotations.items():
            origin = typing.get_origin(annotation)
            if origin is not None:
                annotation = origin

            # The origin of e.g. `Optional[X]` is `typing.Union`, which is not
            # a class; `issubclass` would raise a TypeError for it.
            is_class = isinstance(annotation, type)

            if annotation is Logger:
                if issubclass(type(obj), Task) and not hasattr(obj, input_name):
                    setattr(obj, input_name, obj._mosaic_logger)  # pylint: disable=protected-access
            elif is_class and issubclass(annotation, Tool):
                if annotation not in self.tools._store:  # pylint: disable=protected-access
                    self._check_result.append_error(
                        ErrorType.MISSING_TOOL,
                        f"Tool is missing for {annotation}",
                        type(obj),
                    )
                    continue
                tool = self.tools._store[annotation].tool  # pylint: disable=protected-access
                setattr(obj, input_name, tool)
                self._inject_by_type(tool)
                self.tools.set_used(annotation)
            elif annotation is Output:
                # If an output is just annotated -> create an instance
                setattr(obj, input_name, annotation())
            elif annotation is PDK:
                if (
                    hasattr(obj, input_name)
                    and getattr(obj, input_name) is not self.pdk
                ):
                    self._check_result.append_error(
                        ErrorType.PDK_ERROR,
                        f"Instantiated PDK found in {type(obj)}"
                        " - this is not the expected usage, please use "
                        "builder function instead",
                        type(obj),
                    )
                    continue
                if self.pdk is None:
                    self._check_result.append_error(
                        ErrorType.PDK_ERROR,
                        "PDK implementation is missing, please provide a valid "
                        "implementation",
                        type(obj),
                    )
                    continue
                setattr(obj, input_name, self.pdk)
            elif annotation in self._store:
                registered = self._store[annotation]
                setattr(obj, input_name, registered)
                self._inject_by_type(registered)
            elif not hasattr(obj, input_name) and not (
                is_class and issubclass(annotation, (Task, Input))
            ):
                self._check_result.append_error(
                    ErrorType.INPUT_ERROR,
                    f"{annotation} implementation is missing, "
                    f"please register a valid implementation",
                    type(obj),
                )

    def _build_input_list(self, node: TaskNode):
        """build input list of a task and append it to the check result"""
        for path_str in node.inputs.keys():
            path = prepend(node.task.path.root, Path(path_str))
            if not hasattr(node.task, path.target):
                self._check_result.append_error(
                    ErrorType.MISSING_INPUT, f"Input missing: {path}", path
                )
                continue

            task_input = getattr(node.task, path.target)
            is_regular_input = not is_const(task_input) and not is_calculated(task_input)

            if (
                not is_calculated(task_input)
                and task_input.value is None
                and not is_const(task_input)
            ):
                self._check_result.append_error(
                    ErrorType.MISSING_INPUT, f"Input missing: {path}", path
                )

            if is_regular_input:
                self._check_result.append_input(
                    str(path), task_input, node.task)

    def _get_or_create_logger(self, task: Task) -> Logger:
        """create a new logger or return the existing one of this builder"""
        if task.path in self._loggers:
            return self._loggers[task.path]

        # NOTE(review): `logging.getLogger` returns one logger per name for
        # the whole process, so another Mosaic instance with the same task
        # path attaches an additional file handler to the same logger.
        logger = logging.getLogger(name=str(task.path))
        logger.setLevel(Mosaic.log_level)
        filehandler = logging.FileHandler(
            filename=os.path.join(
                *self.run_directory.parts, *[*task.path.steps, "log"]
            ),
            delay=True,
        )
        filehandler.setFormatter(Mosaic.log_formatter)
        logger.addHandler(filehandler)
        self._loggers[task.path] = logger
        return logger

    def _check_store_inputs(self):
        """check if inputs given to mosaic are valid and given path is correct"""
        for name in self._store:
            # Non-string keys are registered objects, not inputs.
            if not isinstance(name, str):
                continue

            if not is_path_string(name):
                self._check_result.append_if_error(
                    ErrorType.INPUT_ERROR,
                    _check_input(name, self.root),
                    str(self.root.path),
                )
                continue

            path = Path(name)
            node_result = find_node(
                task_tree=self._task_tree, path=path)
            if node_result.failure:
                self._check_result.append_if_error(
                    ErrorType.INPUT_ERROR, node_result, str(path)
                )
                continue
            self._check_result.append_if_error(
                ErrorType.INPUT_ERROR,
                _check_input(path.target, node_result.value.task),
                str(path),
            )

    def _inject_inputs(self, node: TaskNode):
        """inject inputs from store in given task"""
        node.injected_inputs = []
        for path_string in node.inputs.keys():
            if path_string in self._store:
                self._inject_input_from_store(node, path_string)

    def _inject_input_from_store(self, node: TaskNode, path_string):
        """inject a named input from store"""
        path = Path(path_string)
        if not hasattr(node.task, path.target):
            return

        result = _check_if_injectable(node.task, path)
        if result.failure:
            self._check_result.append_error(
                ErrorType.INPUT_ERROR, result.error, str(path)
            )
            return

        result = _set_input_or_wrap(
            instance=node.task, name=path.target, input=self._store[path_string]
        )
        if result.failure:
            self._check_result.append_error(
                ErrorType.INPUT_ERROR, result.error, str(node.task.path)
            )
            return

        node.injected_inputs.append((path, None))
        self._add_to_relevant_inputs_of_ascendants(path, result)

    def _add_to_relevant_inputs_of_ascendants(self, path, result):
        """add input to the relevant inputs of every task above the given path"""
        current_path = path.parent()
        while current_path.steps:
            current_task = (
                find_node(self._task_tree, current_path, check_input=False)
                .raise_on_fail()
                .task
            )
            current_task._relevant_inputs.append(result.value)  # pylint: disable=protected-access
            current_path = current_path.parent()

    def _evaluate_and_validate(self, node: TaskNode):
        """evaluate and validate PDK, callable inputs, inputs, build input list and build
        protocol list"""
        self._validate_pdk(node)
        _evaluate_callable_inputs(node)
        self._validate_inputs(node)
        self._build_input_list(node)
        self._build_protocol_lists(node.task)

    def _validate_pdk(self, node: TaskNode):
        """calls validate_pdk of given task and handle error"""
        validation = node.task.validate_pdk(self.pdk)
        if isinstance(validation, ValidationError):
            self._check_result.append_error(
                ErrorType.PDK_ERROR,
                f"Task {node.task.path} could not validate PDK - {validation.message}",
                node.task.path,
            )

    def _validate_inputs(self, node: TaskNode):
        """calls validate functions of all Inputs.

        Assumes that all inputs have been evaluated"""
        for name, task_input in get_task_inputs(node.task).items():
            if not task_input:
                continue
            # No validation possible if input is calculated
            # and nothing has been set as value yet
            if is_calculated(task_input):
                continue

            validation = _validate_input(node.task, task_input)
            if isinstance(validation, ValidationError):
                path = concat(node.task.path, name)
                self._check_result.append_error(
                    ErrorType.INPUT_ERROR,
                    f"Validation failed for '{path}' with value "
                    f"'{task_input.value}' reason: {validation.message}",
                    str(path),
                )

    def _build_protocol_lists(self, obj):
        """add obj to validation_list and/or life_cycle_list if necessary"""
        candidates = [obj, *get_all_inputs(obj).values()]
        for candidate in candidates:
            if isinstance(candidate, Validate) and not isinstance(candidate, Input):
                self._validate_list.append(candidate)
            if isinstance(candidate, LifeCycleListener):
                self._life_cycle_listener.append(candidate)

    def _startup(self):
        """calls the on_start method of each registered object implements LifeCycleListener"""
        for obj in self._life_cycle_listener:
            obj._on_start()  # pylint: disable=protected-access

    def _cleanup(self, exception=None):
        """calls the on_end method of each registered object implements LifeCycleListener"""
        for obj in self._life_cycle_listener:
            obj._on_end(exception)  # pylint: disable=protected-access

    def _handle_special_inputs(self, node: TaskNode):
        """handle special inputs like Calculated, Mutable or Const"""
        special_inputs = node.task._special_inputs  # pylint: disable=protected-access
        for input_name, special_input in special_inputs.items():
            path = Path(input_name)
            result = find_node(node, path)
            if result.failure:
                self._check_result.append_if_error(
                    ErrorType.INPUT_ERROR, result, path)
                continue

            target_node = result.value
            target_name = f"{target_node.task.path}.{path.target}"

            if not hasattr(target_node.task, path.target):
                # we have a type annotation that was set to some special thing
                # so we first have to create it on the instance
                # get type from type annotation or object
                input_type = get_type_of_class_field(
                    instance=target_node.task, name=path.target
                )
                origin = typing.get_origin(input_type)  # get base type
                if origin is not None:
                    input_type = origin

                new_input = input_type(default=None)
                if not is_calculated(special_input):
                    new_input.value = special_input.value
                new_input._special_input = special_input  # pylint: disable=protected-access
                setattr(target_node.task, path.target, new_input)
                continue

            existing_input = getattr(target_node.task, path.target)

            if is_calculated(existing_input):
                self._check_result.append_error(
                    ErrorType.INPUT_ERROR,
                    f"input {target_node.task.path}.{path.target} cannot be altered since it is set to Calculated",
                    target_name,
                )
                continue

            if is_const(existing_input):
                self._check_result.append_error(
                    ErrorType.INPUT_ERROR,
                    f"{target_node.task.path}.{path.target} is already set to Const",
                    target_name,
                )
                continue

            # The normal case is:
            # Const(), Mutable() is set from above -> just extract the value and set it
            existing_input._special_input = special_input  # pylint: disable=protected-access

            # Only for calculated inputs -> do not use the .value property!!
            if is_calculated(special_input):
                pass
            elif hasattr(special_input, "value"):
                existing_input.value = special_input.value
            elif special_input == Const:
                pass
            else:
                self._check_result.append_error(
                    ErrorType.INPUT_ERROR,
                    f"{target_node.task.path}.{path.target} has some unknown type",
                    target_name,
                )

    def _wrap_run(self, task: Task):
        """wrap run method of given task with _run_wrapper"""
        task._old_run = task.run  # pylint: disable=protected-access
        task.run = types.MethodType(_run_wrapper, task)

    def _set_defaults(self, node: TaskNode):
        """set the inputs passed to constructor of task (via inputs) if they havent been injected"""
        for name, default_input in node.task._normal_inputs.items():  # pylint: disable=protected-access
            path = Path(name)
            result = find_node(task_tree=node, path=path)
            if result.failure:
                self._check_result.append_if_error(
                    ErrorType.INPUT_ERROR, result, str(path)
                )
                continue

            current_node = result.value
            injected_path_list = [
                injected
                for injected in current_node.injected_inputs
                if injected[0].target == path.target
            ]

            # An already injected value is only replaced when it was injected
            # through a shorter path than this default.
            if injected_path_list and injected_path_list[0][0].length >= path.length:
                continue

            result = _set_input_or_wrap(
                instance=current_node.task, name=path.target, input=default_input
            )
            if result.failure:
                self._check_result.append_error(
                    ErrorType.INPUT_ERROR,
                    result.error,
                    str(current_node.task.path),
                )
                continue

            if injected_path_list:
                current_node.injected_inputs.remove(injected_path_list[0])
            current_node.injected_inputs.append((path, node.parent))

    def _copy(self, new_id: str, task: ATask) -> ATask:
        """create a copy of a task

        The copy is a fresh instance of the same type whose last path step is
        replaced by `new_id`; inputs and bookkeeping dictionaries are shallow
        copies of those of `task`.
        """
        # pylint: disable=protected-access
        new_task = type(task)()
        new_task._path = Path([*task.path.steps[:-1], new_id])
        new_task._mosaic = self
        self._set_logger(new_task)
        self._inject_pdk_items(new_task)
        self._inject_by_type(new_task)
        self._wrap_run(new_task)

        for name, task_input in get_task_inputs(task).items():
            setattr(new_task, name, copy.copy(task_input))

        new_task._injected_mutables = copy.copy(task._injected_mutables)
        new_task._runtime_evaluation_functions = copy.copy(
            task._runtime_evaluation_functions
        )
        new_task._normal_inputs = copy.copy(task._normal_inputs)
        new_task._special_inputs = copy.copy(task._special_inputs)
        new_task._all_inputs = copy.copy(task._all_inputs)
        return new_task

    def _configure_console_logging(self):
        """setup console logging

        One handler is shared by all instances: adding a new stream handler
        to the root logger per instance would print every message once per
        created builder.
        """
        handler = Mosaic._console_handler
        if handler is None:
            handler = logging.StreamHandler()
            Mosaic._console_handler = handler

        # `log_formatter` can be changed globally, so it is re-applied here.
        handler.setFormatter(Mosaic.log_formatter)

        root_logger = logging.getLogger("")
        if handler not in root_logger.handlers:
            root_logger.addHandler(handler)
def _run_wrapper(
    task, inputs: Dict[str, object] = None, *args, **kwargs
) -> Union[Any, WorkProduct]:
    """wrapped run method, creates current working directories, inject logging and
    handles caching

    Args:
        task: the task whose original `run` is stored in `_old_run`.
        inputs: either the inputs dictionary (or None) or, when the task
            defines a custom positional `run` argument, that argument.
        *args: forwarded to the original `run`.
        **kwargs: forwarded to the original `run`.

    Returns:
        the return value of the result container, or of the cached result
        when a cached result is found.
    """
    # the inputs argument can be filled with a custom positional run argument
    # such as : run(my_arg: int)
    # when calling such a run method without keyword - run(1) - it would appear here as inputs
    is_inputs_real = inputs is None or isinstance(inputs, dict)
    if is_inputs_real:
        _inject_special_inputs(task, inputs)
    _evaluate_runtime_lambdas(task)

    working_dir = task.path.target
    previous_dir = os.getcwd()
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(working_dir, exist_ok=True)
    os.chdir(working_dir)
    try:
        return _run_in_working_dir(task, inputs, is_inputs_real, args, kwargs)
    except BaseException:
        # A failing task must not leave the process inside its working
        # directory; the success paths change back on their own.
        os.chdir(previous_dir)
        raise


def _run_in_working_dir(task, inputs, is_inputs_real, args, kwargs):
    """run the task (or replay its cached result) inside its working directory"""
    # pylint: disable=protected-access
    trackers = task.mosaic.execution_trackers
    task.run_id = str(uuid.uuid4())

    for tracker in trackers:
        tracker.start_task(task, is_root_task=task == task.mosaic.root)

    use_cache = task._mosaic.caching and isinstance(task, CachableTask)

    if use_cache:
        task.task_hash = task.calc_hash(inputs, *args, **kwargs)
        cached_result, logs = _get_cached_result_and_logs(task)
        if cached_result is not None:
            _reproduce_cached_logs(logs, task)
            os.chdir("..")
            cached_result.apply_to_obj(task)
            task.task_result = cached_result
            task.cache_used = True
            # Track execution
            for tracker in trackers:
                tracker.end_task(
                    task, is_root_task=(task == task.mosaic.root), cache_used=True)
            return cached_result.get_return_value()

    task.cache_used = False
    _inject_log_to_task_tool(task)

    list_handler = ListHandler()
    if isinstance(task, CachableTask):
        task._mosaic_logger.addHandler(list_handler)

    if is_inputs_real:
        result = task._old_run(*args, **kwargs)
    else:
        # inputs contains custom positional argument
        result = task._old_run(inputs, *args, **kwargs)

    for tracker in trackers:
        tracker.end_task(
            task, is_root_task=(task == task.mosaic.root), cache_used=False)

    # Wrap in container
    # This container is saved in the cache and contains all outputs
    task_outputs = task.get_outputs()
    result_container = ResultContainer()
    result_container.update_container(task, task_outputs, result)

    if use_cache:
        # Update hashes of outputs inside of the result container
        # since the task produced new outputs
        result_container._hash = result_container.calc_hash(task)
        task.task_result = result_container
        results = _collect_results(task)
        task._cache(results)
        task.is_invalid = False
        _cache_log(list_handler, task)

    os.chdir("..")
    return result_container.get_return_value()


def _reproduce_cached_logs(logs, task: Task):
    """add cached logs to current task log

    Every record is emitted as a copy that carries the current time, the name
    of the task logger and a ``[CACHED <original timestamp>]`` prefix. The
    records in `logs` themselves are left untouched.
    """
    if logs is None:
        return

    logger: Logger = task._mosaic_logger  # pylint: disable=protected-access
    for record in logs:
        # Work on a copy so that replaying the same records again does not
        # stack several "[CACHED ...]" prefixes.
        replayed = copy.copy(record)
        replayed.name = logger.name
        # The original timestamp has to be formatted before it is replaced.
        formatted_cached_timestamp = Mosaic.log_formatter.formatTime(replayed)
        time_created = time.time()
        replayed.created = time_created
        replayed.msecs = (time_created - int(time_created)) * 1000
        replayed.msg = f"[CACHED {formatted_cached_timestamp}] {replayed.msg}"
        # callHandlers passes the record straight to the handlers.
        logger.callHandlers(replayed)