"""This module contains the mosaic-orchestrator framework core functionality"""
from __future__ import annotations
import copy
from dataclasses import asdict, dataclass, fields, is_dataclass
import importlib
import logging
import os
import pathlib
import time
import types
import typing
from logging import Logger
from typing import (
Any,
Type,
Union,
Dict,
List,
Tuple,
Optional,
)
import uuid
from mosaic_orchestrator.execution_tracker import ExecutionTracker
import pluggy
from mosaic_orchestrator import tree, hookspecs
from mosaic_orchestrator.asset import Asset
from mosaic_orchestrator.cachable_task import CachableTask
from mosaic_orchestrator.config import Config
from mosaic_orchestrator.mosaic_utils import (
_set_input_or_wrap,
find_node,
_validate_input,
_check_if_injectable,
_check_input,
_evaluate_callable_inputs,
_add_prefix_to_all_inputs,
_wrap_root,
MosaicException,
_inject_special_inputs,
_evaluate_runtime_lambdas,
_get_cached_result_and_logs,
_inject_log_to_task_tool,
_collect_results,
_cache_log
)
from mosaic_orchestrator.pdk import PDK, PdkItem
from mosaic_orchestrator.result import (
ValidationError,
ErrorType,
ErrorDefinition,
Validation,
)
from mosaic_orchestrator.task import (
ATask,
TaskNode,
Task,
Calculated,
Input,
Output,
is_calculated,
CheckResult,
is_const,
Const,
get_task_inputs,
)
from mosaic_orchestrator.tool import Tool
from mosaic_orchestrator.tree import is_path_string, Path, concat, prepend
from mosaic_orchestrator.utils import (
get_type_annotations,
get_all_inputs,
ListHandler,
get_type_of_class_field,
)
from mosaic_orchestrator.task_hierarchy_builder import TaskHierarchyBuilder, T
from mosaic_orchestrator.tool_store import ToolStore
from mosaic_orchestrator.protocols import (
Validate,
LifeCycleListener,
Hashable,
)
from mosaic_orchestrator.result_container import ResultContainer
from mosaic_orchestrator.work_products import WorkProduct
[docs]class Mosaic(TaskHierarchyBuilder[T]):
    """This is the main entry point of the framework. Use `create` to start building an executable
    task tree.

    The builder collects a PDK, tools, registered objects and inputs. `check` validates the
    resulting configuration and `build` returns the instantiated root task.
    """
    # Logging configuration shared by every `Mosaic` instance (class attributes).
    log_formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    )
    """the formatter used for logging. It is used for stdout as well as log files. Can be set globally"""
    log_level = logging.INFO
    """the log level of all loggers. Can be set globally"""
    # Default directory names; `__init__` joins both onto the current working directory.
    default_cache_directory = ".mosaic_orchestrator"
    default_work_directory = "mosaic_orchestrator.work"
@staticmethod
def create(
    root: Type[T], caching: bool = True, plugin_enabled: bool = True
) -> TaskHierarchyBuilder[T]:
    """Create a builder for a task hierarchy with the given root task type.

    Args:
        root: the type of the desired root task. Notice that the actual instantiation is
            done by the framework.
        caching: when False all caching is disabled for hierarchies built with this builder.
        plugin_enabled: enables the plugin discovery for PDKs and tools.

    Returns:
        a `TaskHierarchyBuilder` for the given root type.

    Example:
        This is a simple example that shows how to build a simple mosaic_orchestrator
        task tree::

            builder = (
                Mosaic.create(root=MyTask)
                .with_pdk(MyPdkImpl())
                .register(AtoolInterface, MyToolImpl())
                .register(AnotherDependency, AnotherDependency())
                .with_inputs({
                    "x1": 5,
                    "x2": 3,
                    "t1.p5": "REF",
                })
            )
            if builder.check().success:
                instance = builder.build()
                instance.run()
    """
    builder: TaskHierarchyBuilder[T] = Mosaic[T](root, caching, plugin_enabled)
    return builder
@staticmethod
def get_installed_pdks() -> List[PDK]:
    """Return the result of the ``get_pdk`` hook of all ``mosaic_orchestrator`` entry points."""
    project_name = "mosaic_orchestrator"
    manager = pluggy.PluginManager(project_name)
    manager.add_hookspecs(hookspecs)
    manager.load_setuptools_entrypoints(project_name)
    installed_pdks = manager.hook.get_pdk()
    return installed_pdks
def __init__(self, root: Type[T], caching, plugin_enabled):
    """Set up an empty builder for the given root task type.

    Args:
        root: type of the root task; it is instantiated later by `check`.
        caching: stored as ``self.caching``.
        plugin_enabled: when truthy, hook specifications and setuptools entry points
            are loaded into the plugin manager.
    """
    self._configure_console_logging()

    # configuration handed in by the caller
    self.root_type = root
    self.pdk = None
    self.execution_trackers = []
    self.caching = caching
    self.plugin_enabled = plugin_enabled

    # registries and state that is filled by the builder methods and by `check`
    self._store = {}
    self.tools = ToolStore()
    self._callable_inputs: List[Tuple[Input, Task]] = []
    self._check_result = CheckResult()
    self._task_tree = None
    self.root = None
    self._loggers = {}

    # both directories default to a sub directory of the current working directory
    current_directory = pathlib.Path.cwd()
    self.run_directory = current_directory / Mosaic.default_work_directory
    self.cache_directory = pathlib.Path.cwd() / Mosaic.default_cache_directory

    self._life_cycle_listener: List[LifeCycleListener] = []
    self._validate_list: List[Validate] = []

    project_name = "mosaic_orchestrator"
    self.plugin_manager = pluggy.PluginManager(project_name)
    if plugin_enabled:
        self.plugin_manager.add_hookspecs(hookspecs)
        self.plugin_manager.load_setuptools_entrypoints(project_name)
def with_cfg(self, config: Union[Config, str]) -> TaskHierarchyBuilder[T]:
    """Apply a configuration object, or a python configuration file, to this builder.

    Args:
        config: a configuration object, or the path (``str`` or ``os.PathLike``) of a
            ``.py`` file that defines a class named ``MosaicConfig``. The class is
            instantiated without arguments.

    Returns:
        this builder.

    Raises:
        MosaicException: if the file extension is not ``.py`` or the file does not
            define ``MosaicConfig``.
    """
    # Explicit submodule import: ``import importlib`` by itself does not promise that
    # ``importlib.machinery`` is available as an attribute.
    import importlib.machinery

    if isinstance(config, (str, os.PathLike)):
        config = os.fspath(config)
        mod_name, file_ext = os.path.splitext(os.path.split(config)[-1])
        if file_ext.lower() != ".py":
            raise MosaicException(
                ErrorDefinition(
                    ErrorType.GENERAL_ERROR,
                    f"Configuration file '{config}' cannot " f"be loaded.",
                    config,
                )
            )
        loader = importlib.machinery.SourceFileLoader(mod_name, config)
        mod = types.ModuleType(loader.name)
        loader.exec_module(mod)
        # Only a missing class is reported as a configuration problem. An AttributeError
        # raised inside the MosaicConfig constructor propagates unchanged instead of
        # being mislabelled.
        config_class = getattr(mod, "MosaicConfig", None)
        if config_class is None:
            raise MosaicException(
                ErrorDefinition(
                    ErrorType.GENERAL_ERROR,
                    f"Configuration file '{config}' has "
                    f"not a Config class named "
                    f"'MosaicConfig'.",
                    config,
                )
            )
        config = config_class()

    # Each getter is queried once so that getters creating new objects are not run twice.
    pdk = config.get_pdk()
    if pdk:
        self.with_pdk(pdk)
    self.with_inputs(config.get_inputs())
    run_directory = config.get_run_directory()
    if run_directory:
        self.with_run_directory(run_directory)
    cache_directory = config.get_cache_directory()
    if cache_directory:
        self.with_cache_directory(cache_directory)
    for tool in config.get_tools():
        self.register(key=type(tool), value=tool, overwrite=True)
    for key, value in config.get_objects_to_register():
        self.register(key, value, overwrite=True)
    return self
def with_pdk(self, pdk: PDK) -> TaskHierarchyBuilder[T]:
    """Use ``pdk`` for this builder.

    The PDK is stored, handed to `_build_protocol_lists` and its name is written to the
    ``"pdk"`` entry of every registered execution tracker.

    Returns:
        this builder.
    """
    self.pdk = pdk
    self._build_protocol_lists(pdk)
    for execution_tracker in self.execution_trackers:
        execution_tracker.set_entry("pdk", pdk.name)
    return self
def with_run_directory(self, path: pathlib.Path) -> TaskHierarchyBuilder[T]:
    """Set the directory used as run directory.

    Args:
        path: the run directory. Strings and other path-like objects are converted to
            ``pathlib.Path`` because `_get_or_create_logger` reads ``run_directory.parts``.

    Returns:
        this builder.
    """
    self.run_directory = pathlib.Path(path)
    return self
def with_cache_directory(self, path: pathlib.Path) -> TaskHierarchyBuilder[T]:
    """Set the directory used as cache directory.

    Args:
        path: the cache directory. Strings and other path-like objects are converted to
            ``pathlib.Path`` so the attribute always has the same type as the default
            assigned in ``__init__``.

    Returns:
        this builder.
    """
    self.cache_directory = pathlib.Path(path)
    return self
def register(
    self, key: Union[Type, str], value: Any, overwrite: bool = False
) -> TaskHierarchyBuilder[T]:
    """Register ``value`` under ``key`` so that it can be injected later.

    `Tool` instances are put into the tool store, every other value into the internal
    store of the builder. For values that are not an `Input`, every attribute annotated
    with a `PDK` subclass is set to the PDK currently known to the builder.

    Args:
        key: type or name under which the value is registered.
        value: the object to register.
        overwrite: replace an existing registration instead of failing.

    Returns:
        this builder.

    Raises:
        MosaicException: if a non-tool ``key`` is already registered and ``overwrite``
            is False.
    """
    for tracker in self.execution_trackers:
        tracker.append_to_entry("installed_tools", str(value))
    # Inject PDK in registered "things"
    if not isinstance(value, Input):
        for name, input_type in get_type_annotations(type(value)).items():
            # issubclass() raises TypeError for annotations that are not classes
            # (e.g. generic aliases), so those are skipped.
            if isinstance(input_type, type) and issubclass(input_type, PDK):
                setattr(value, name, self.pdk)
    if isinstance(value, Tool):
        self.tools.put(key, value, overwrite)
        return self
    if key in self._store and not overwrite:
        raise MosaicException(
            ErrorDefinition(
                ErrorType.GENERAL_ERROR,
                f"{key} has already been registered! Maybe you have installed multiple variants in your "
                f"python environment (e.g. {type(value)} and {type(self._store[key])}). "
                "Please uninstall one or more to reduce it to a single tool installation.",
                value,
            )
        )
    self._store[key] = value
    return self
def build(self) -> T:
    """Run `check` and return the root task instance.

    >>> instance = Mosaic.create(root=ExampleTask).build()
    >>> instance.run()
    0.0

    Raises:
        MosaicException: if `check` reports a failure; it carries the first reported error.

    Returns:
        T: the instance of the root task
    """
    check_result = self.check()
    if check_result.success is not False:
        return self.root
    raise MosaicException(check_result.errors[0])
def check(self) -> CheckResult:
    """Instantiate the task hierarchy and check it without running any task.

    Performed checks are:
        - Missing parameters
        - Validation of input parameters
        - Checking of tools availability
        - Checking of PDK completeness (are all PDKItems available?)

    Returns:
        CheckResult: container with the results of all checks
    """
    self._check_result = CheckResult()

    # plugins only fill in what has not been configured explicitly
    if self.plugin_enabled:
        if not self.execution_trackers:
            self.__load_exection_trackers_from_plugins()
        if not self.pdk:
            self._load_pdk_from_plugins()
        self._load_tools_from_plugins()
    self._inject_by_type(self.pdk)

    # instantiate the root task and build the tree below it
    self.root = self.root_type()
    root_path = Path(self.root_type.__name__)
    self._task_tree = self._build_task_tree(task=self.root, path=root_path)
    self._check_store_inputs()

    for visitor in (self._inject, self._set_defaults):
        tree.traverse(tree=self._task_tree, on_each=visitor)
    for tool in self.tools.used_tools():
        self._inject_pdk_items(tool)
    tree.traverse(tree=self._task_tree, on_each=self._evaluate_and_validate)

    # run the validation of everything that has been collected in the validate list
    for obj in self._validate_list:
        result = obj.validate()
        if not isinstance(result, ValidationError):
            continue
        self._check_result.append_error(
            ErrorType.VALIDATION_ERROR,
            f"'{type(obj).__name__}' could not be validated - {result.message}",
            type(obj),
        )

    _wrap_root(self.root)
    return self._check_result
def _load_pdk_from_plugins(self):
"""try to load a pdk from installed plugin, adds error if more than one PDK is found"""
pdks = self.plugin_manager.hook.get_pdk()
if len(pdks) == 1:
self.with_pdk(pdks[0])
elif len(pdks) > 1:
self._check_result.append_error(
ErrorType.PDK_ERROR,
f"More then one installed PDKs found in virtual environment: {pdks}",
self,
)
def register_execution_tracker(self, tracker: ExecutionTracker) -> TaskHierarchyBuilder[T]:
    """Add an execution tracker to this builder.

    The tracker is appended to ``execution_trackers`` and handed to
    `_build_protocol_lists`.

    Args:
        tracker (ExecutionTracker): execution tracker to add

    Returns:
        this builder.
    """
    trackers = self.execution_trackers
    trackers.append(tracker)
    self._build_protocol_lists(tracker)
    return self
def __load_exection_trackers_from_plugins(self):
"""try to load the execution tracker from the installed plugins"""
execution_trackers : list[ExecutionTracker] = self.plugin_manager.hook.get_execution_tracker()
for tracker in execution_trackers:
self.register_execution_tracker(tracker)
def _load_tools_from_plugins(self):
"""try to load tools from installed plugins"""
tools = self.plugin_manager.hook.get_tool()
for tool in tools:
self.register(tool[0], tool[1], overwrite=False)
def _build_task_tree(
    self, task: Task, path: Path, parent: Optional[Task] = None
) -> TaskNode:
    """Recursively build the `TaskNode` tree below ``task``.

    The task gets its path, logger, PDK items and type based injections first. Dynamic
    inputs and dynamic sub tasks are set as attributes on the task. Sub tasks are the
    dynamic tasks plus every attribute of the task's class that is a `Task`.

    Args:
        task: the task the node is created for.
        path: path of ``task`` inside the hierarchy.
        parent: the parent task, None for the root.

    Returns:
        the node of ``task`` with the nodes of all sub tasks as children.
    """
    task._path = path  # pylint: disable=protected-access
    self._set_logger(task)
    self._inject_pdk_items(task)
    self._inject_by_type(task)

    inputs = get_task_inputs(task)
    extra_inputs = task.dynamic_inputs()
    if extra_inputs:
        for input_name, dynamic_input in extra_inputs.items():
            setattr(task, input_name, dynamic_input)
        inputs.update(extra_inputs)

    node = TaskNode(
        parent=parent,
        value=task,
        children=[],
        inputs=inputs,
        injected_inputs=[],
    )

    def attach_child(child_name, child_task):
        # build the sub tree first, then prefix its inputs with the name of the child
        child_node = self._build_task_tree(
            task=child_task, path=Path([*path.steps, child_name]), parent=task
        )
        node.children.append(child_node)
        _add_prefix_to_all_inputs(child_node, child_name)

    extra_tasks = task.dynamic_tasks()
    if extra_tasks:
        for child_name, child_task in extra_tasks.items():
            setattr(task, child_name, child_task)
            attach_child(child_name, child_task)

    for attribute_name in type(task).__dict__.keys():
        candidate = getattr(task, attribute_name)
        if issubclass(type(candidate), Task):
            attach_child(attribute_name, candidate)
    return node
def _inject(self, node: TaskNode):
"""wrap run method, handle calculated inputs and inject inputs into task"""
node.task._mosaic = self # pylint: disable=protected-access
self._wrap_run(node.task)
self._handle_special_inputs(node)
self._inject_inputs(node)
def _set_logger(self, task: Task):
"""crate a new or set existing logger given task"""
task._mosaic_logger = self._get_or_create_logger(task)
# pylint: disable=protected-access
def _inject_pdk_items(self, instance):
    """Replace the `PdkItem` inputs of ``instance`` with the matching items of the PDK.

    Inputs that are an `Asset` are processed recursively. Problems (missing PDK, item not
    available, incompatible item type, failing item initialisation) are appended to the
    check result. An item that was found is set on ``instance`` even if the type or
    initialisation check reported an error.

    Args:
        instance: the object whose inputs are inspected.
    """
    # errors are reported against the path of the instance if it has one, else its type
    source = instance.path if hasattr(instance, "path") else type(instance)
    for name, obj in get_all_inputs(instance).items():
        # Allow assets to use pdk items
        if issubclass(type(obj), Asset):
            self._inject_pdk_items(obj)
        if not issubclass(type(obj), PdkItem):
            continue
        if not self._check_pdk(source):
            continue
        result = self.pdk.get_item(obj.name, obj.version)
        if result.failure:
            # sorted() builds a new list: the list owned by the PDK is not reordered and
            # any iterable is accepted. The f-string also tolerates non-string versions.
            supported_items = sorted(
                self.pdk.supported_items() or [],
                key=lambda item: item.name.lower(),
            )
            supported_items_string = " ".join(
                f"({item.name} {item.version})" for item in supported_items
            )
            self._check_result.append_error(
                ErrorType.MISSING_PDK_ITEM,
                f"PdkItem not available in provided PDK ('{obj.name}',"
                f" {obj.version}). Supported items: [{supported_items_string}]",
                source,
            )
            continue
        item = result.value
        # making item conform to Hashable and dependant on pdk
        item.task_hash = self.pdk.task_hash
        # Check if they are not "related"
        if not issubclass(type(item), type(obj)):
            self._check_result.append_error(
                ErrorType.PDK_ERROR,
                f"Expected pdk item of type '{type(obj).__name__}' not compatible"
                f" with view from PDK '{type(item).__name__}'",
                type(instance),
            )
        if not item._try_init():  # pylint: disable=protected-access
            self._check_result.append_error(
                ErrorType.PDK_ERROR,
                f"Could not validate pdk item '{type(item).__name__}'",
                type(instance),
            )
        setattr(instance, name, result.value)
def _check_pdk(self, source) -> bool:
"""check if pdk is available"""
if self.pdk is None:
self._check_result.append_error(
ErrorType.PDK_ERROR,
"PDK implementation is missing, "
"please provide a valid implementation",
source,
)
return False
return True
def _inject_by_type(self, obj):
    """Inject values into ``obj`` based on the type annotations of its class.

    Depending on the annotation (generic aliases are reduced to their origin) an attribute
    receives the logger of the task, a registered tool, a new `Output`, the PDK or an
    object from the internal store. Injected tools and store objects are processed
    recursively. Missing tools, a missing or manually instantiated PDK and annotations
    without an implementation are appended to the check result.

    Args:
        obj: the object that receives the injected attributes.
    """

    def is_subclass_of(candidate, base) -> bool:
        # issubclass() raises TypeError if the first argument is not a class, which
        # happens for annotations such as typing.Union.
        return isinstance(candidate, type) and issubclass(candidate, base)

    annotations = get_type_annotations(type(obj))
    for input_name, annotation in annotations.items():
        origin = typing.get_origin(annotation)
        if origin is not None:
            annotation = origin
        if annotation is Logger:
            # a logger is only injected into tasks that do not already define the attribute
            if issubclass(type(obj), Task) and not hasattr(obj, input_name):
                setattr(
                    obj, input_name, obj._mosaic_logger
                )  # pylint: disable=protected-access
        elif is_subclass_of(annotation, Tool):
            # pylint: disable=protected-access
            if annotation not in self.tools._store:
                self._check_result.append_error(
                    ErrorType.MISSING_TOOL,
                    f"Tool is missing for {annotation}",
                    type(obj),
                )
                continue
            tool = self.tools._store[annotation].tool
            setattr(obj, input_name, tool)
            self._inject_by_type(tool)
            self.tools.set_used(annotation)
        elif annotation is Output:
            # If an output is just annotated -> create an instance
            setattr(obj, input_name, annotation())
        elif annotation is PDK:
            if (
                hasattr(obj, input_name)
                and getattr(obj, input_name) is not self.pdk
            ):
                self._check_result.append_error(
                    ErrorType.PDK_ERROR,
                    f"Instantiated PDK found in {type(obj)}"
                    " - this is not the expected usage, please use "
                    "builder function instead",
                    type(obj),
                )
                continue
            if self.pdk is None:
                self._check_result.append_error(
                    ErrorType.PDK_ERROR,
                    "PDK implementation is missing, please provide a valid "
                    "implementation",
                    type(obj),
                )
                continue
            setattr(obj, input_name, self.pdk)
        elif annotation in self._store:
            registered_object = self._store[annotation]
            setattr(obj, input_name, registered_object)
            self._inject_by_type(registered_object)
        elif (
            not hasattr(obj, input_name)
            and not is_subclass_of(annotation, Task)
            and not is_subclass_of(annotation, Input)
        ):
            self._check_result.append_error(
                ErrorType.INPUT_ERROR,
                f"{annotation} implementation is missing, "
                f"please register a valid implementation",
                type(obj),
            )
def _build_input_list(self, node: TaskNode):
    """Add the inputs of the node's task to the check result.

    An input that is not present on the task, or that is neither calculated nor const and
    has no value, is recorded as missing. Inputs that are neither const nor calculated
    are appended to the input list of the check result.
    """
    task = node.task
    for path_str in node.inputs.keys():
        path = prepend(task.path.root, Path(path_str))
        if hasattr(task, path.target):
            task_input = getattr(task, path.target)
            # the calculated check comes first so that `.value` of a calculated input is
            # never read
            value_missing = (
                not is_calculated(task_input)
                and task_input.value is None
                and not is_const(task_input)
            )
            if value_missing:
                self._check_result.append_error(
                    ErrorType.MISSING_INPUT, f"Input missing: {path}", path
                )
            if not is_const(task_input) and not is_calculated(task_input):
                self._check_result.append_input(str(path), task_input, task)
        else:
            self._check_result.append_error(
                ErrorType.MISSING_INPUT, f"Input missing: {path}", path
            )
def _get_or_create_logger(self, task: Task) -> Logger:
"""create a new logger or create the existing one"""
if task.path in self._loggers:
return self._loggers[task.path]
logger = logging.getLogger(name=str(task.path))
logger.setLevel(Mosaic.log_level)
filehandler = logging.FileHandler(
filename=os.path.join(
*self.run_directory.parts, *[*task.path.steps, "log"]
),
delay=True,
)
filehandler.setFormatter(Mosaic.log_formatter)
logger.addHandler(filehandler)
self._loggers[task.path] = logger
return logger
def _check_store_inputs(self):
    """Check the string keys of the store against the task tree.

    Keys that are path strings must lead to an existing node and are checked against the
    task of that node; all other string keys are checked against the root task. Keys
    that are not strings are ignored.
    """
    for name in self._store:
        if not isinstance(name, str):
            continue
        if not is_path_string(name):
            self._check_result.append_if_error(
                ErrorType.INPUT_ERROR,
                _check_input(name, self.root),
                str(self.root.path),
            )
            continue
        path = Path(name)
        node_result = find_node(task_tree=self._task_tree, path=path)
        if node_result.failure:
            self._check_result.append_if_error(
                ErrorType.INPUT_ERROR, node_result, str(path)
            )
        else:
            self._check_result.append_if_error(
                ErrorType.INPUT_ERROR,
                _check_input(path.target, node_result.value.task),
                str(path),
            )
def _inject_inputs(self, node: TaskNode):
"""inject inputs form store in given task"""
node.injected_inputs = []
for path_string in node.inputs.keys():
if path_string in self._store:
self._inject_input_from_store(node, path_string)
def _inject_input_from_store(self, node: TaskNode, path_string):
    """Inject the store entry named ``path_string`` into the task of ``node``.

    If the task already has the target attribute it must be injectable. A failing check
    or a failing assignment is recorded as input error and stops the injection. After a
    successful assignment the path is remembered in the injected inputs of the node and
    the result is passed on to the ascendants.
    """
    path = Path(path_string)
    task = node.task
    if hasattr(task, path.target):
        injectable = _check_if_injectable(task, path)
        if injectable.failure:
            self._check_result.append_error(
                ErrorType.INPUT_ERROR, injectable.error, str(path)
            )
            return

    assignment = _set_input_or_wrap(
        instance=task, name=path.target, input=self._store[path_string]
    )
    if assignment.failure:
        self._check_result.append_error(
            ErrorType.INPUT_ERROR, assignment.error, str(node.task.path)
        )
        return

    node.injected_inputs.append((path, None))
    self._add_to_relevant_inputs_of_ascendants(path, assignment)
def _add_to_relevant_inputs_of_ascendants(self, path, result):
    """Append ``result.value`` to the relevant inputs of every task above ``path``.

    Starting at the parent of ``path`` the tree is walked upwards for as long as the
    current path still has steps. A node that cannot be found raises through
    ``raise_on_fail``.
    """
    ancestor_path = path.parent()
    while ancestor_path.steps:
        lookup = find_node(self._task_tree, ancestor_path, check_input=False)
        ancestor_task = lookup.raise_on_fail().task
        # pylint: disable=protected-access
        ancestor_task._relevant_inputs.append(result.value)
        ancestor_path = ancestor_path.parent()
def _evaluate_and_validate(self, node: TaskNode):
    """Run the per node checks in a fixed order: validate the PDK, evaluate callable
    inputs, validate the inputs, build the input list and build the protocol lists."""
    self._validate_pdk(node)
    _evaluate_callable_inputs(node)
    for step in (self._validate_inputs, self._build_input_list):
        step(node)
    self._build_protocol_lists(node.task)
def _validate_pdk(self, node: TaskNode):
    """Let the task of ``node`` validate the PDK; a `ValidationError` result is recorded
    as PDK error."""
    validation = node.task.validate_pdk(self.pdk)
    if not isinstance(validation, ValidationError):
        return
    self._check_result.append_error(
        ErrorType.PDK_ERROR,
        f"Task {node.task.path} could not validate PDK - {validation.message}",
        node.task.path,
    )
def _validate_inputs(self, node: TaskNode):
    """Validate all inputs of the node's task and record failing validations.

    Falsy inputs and calculated inputs are skipped; for a calculated input no validation
    is possible as long as nothing has been set as value. Assumes that all inputs have
    been evaluated.
    """
    for name, input in get_task_inputs(node.task).items():
        if not input or is_calculated(input):
            continue
        validation = _validate_input(node.task, input)
        if not isinstance(validation, ValidationError):
            continue
        path = concat(node.task.path, name)
        self._check_result.append_error(
            ErrorType.INPUT_ERROR,
            f"Validation failed for '{path}' with value "
            f"'{input.value}' reason: {validation.message}",
            str(path),
        )
def _build_protocol_lists(self, obj):
    """Add ``obj`` and all of its inputs to the validate list and the life cycle list.

    An object is added to the validate list if it is a `Validate` but not an `Input`, and
    to the life cycle list if it is a `LifeCycleListener`.
    """

    def track(candidate):
        if isinstance(candidate, Validate) and not isinstance(candidate, Input):
            self._validate_list.append(candidate)
        if isinstance(candidate, LifeCycleListener):
            self._life_cycle_listener.append(candidate)

    track(obj)
    for item in get_all_inputs(obj).values():
        track(item)
def _startup(self):
"""calls the on_start method of each registered object implements LifeCycleListener"""
for obj in self._life_cycle_listener:
obj._on_start() # pylint: disable=protected-access
def _cleanup(self, exception=None):
"""calls the on_end method of each registered object implements LifeCycleListener"""
for obj in self._life_cycle_listener:
obj._on_end(exception) # pylint: disable=protected-access
def _handle_special_inputs(self, node: TaskNode):
    """Apply the special inputs (e.g. Calculated, Mutable or Const) of the node's task.

    Each special input is addressed by a path below ``node``. If the target task does not
    have the attribute yet, a new input is created from the declared type of the field.
    Otherwise the existing input is marked with the special input and, unless the special
    input is calculated, takes over its value. Targets that are already calculated or
    const are reported as input errors.
    """
    # pylint: disable=protected-access
    for input_name, special_input in node.task._special_inputs.items():
        path = Path(input_name)
        result = find_node(node, path)
        if result.failure:
            self._check_result.append_if_error(
                ErrorType.INPUT_ERROR, result, path)
            continue
        target_node = result.value

        if not hasattr(target_node.task, path.target):
            # only a type annotation exists so far, so the input has to be created on
            # the instance from the declared type (reduced to its base type)
            declared_type = get_type_of_class_field(
                instance=target_node.task, name=path.target
            )
            base_type = typing.get_origin(declared_type)
            if base_type is None:
                base_type = declared_type
            created_input = base_type(default=None)
            if not is_calculated(special_input):
                created_input.value = special_input.value
            created_input._special_input = special_input
            setattr(target_node.task, path.target, created_input)
            continue

        existing_input = getattr(target_node.task, path.target)
        if is_calculated(existing_input):
            self._check_result.append_error(
                ErrorType.INPUT_ERROR,
                f"input {target_node.task.path}.{path.target} cannot be altered since it is set to Calculated",
                f"{target_node.task.path}.{path.target}",
            )
            continue
        if is_const(existing_input):
            self._check_result.append_error(
                ErrorType.INPUT_ERROR,
                f"{target_node.task.path}.{path.target} is already set to Const",
                f"{target_node.task.path}.{path.target}",
            )
            continue

        # The normal case is:
        # Const(), Mutable() is set from above -> just extract the value and set it
        existing_input._special_input = special_input
        # the `.value` property of a calculated input must not be touched, therefore the
        # calculated check has to come before the hasattr check
        if is_calculated(special_input):
            continue
        if hasattr(special_input, "value"):
            existing_input.value = special_input.value
        elif not special_input == Const:
            self._check_result.append_error(
                ErrorType.INPUT_ERROR,
                f"{target_node.task.path}.{path.target} has some unknown type",
                f"{target_node.task.path}.{path.target}",
            )
def _wrap_run(self, task: Task):
    """Keep the current run method of ``task`` as ``_old_run`` and bind `_run_wrapper` to
    the task as its new run method."""
    original_run = task.run
    task._old_run = original_run  # pylint: disable=protected-access
    task.run = types.MethodType(_run_wrapper, task)
def _set_defaults(self, node: TaskNode):
    """Apply the inputs that were passed to the constructor of the node's task.

    A default is applied to its target task if nothing has been injected for that target
    yet, or if the path of the injected input is shorter than the path of the default.
    The applied default replaces the previous entry in the injected inputs of the target
    node.
    """
    # pylint: disable=protected-access
    for name, default_input in node.task._normal_inputs.items():
        path = Path(name)
        lookup = find_node(task_tree=node, path=path)
        if lookup.failure:
            self._check_result.append_if_error(
                ErrorType.INPUT_ERROR, lookup, str(path)
            )
            continue
        current_node = lookup.value

        already_injected = [
            entry
            for entry in current_node.injected_inputs
            if entry[0].target == path.target
        ]
        if already_injected and not already_injected[0][0].length < path.length:
            continue

        assignment = _set_input_or_wrap(
            instance=current_node.task, name=path.target, input=default_input
        )
        if assignment.failure:
            self._check_result.append_error(
                ErrorType.INPUT_ERROR,
                assignment.error,
                str(current_node.task.path),
            )
            continue
        if already_injected:
            current_node.injected_inputs.remove(already_injected[0])
        current_node.injected_inputs.append((path, node.parent))
def _copy(self, new_id: str, task: ATask) -> ATask:
    """Create a new task of the same type as ``task`` and copy its inputs to it.

    The path of the new task is the path of ``task`` with the last step replaced by
    ``new_id``. The new task gets a logger, PDK items, type based injections and a
    wrapped run method. Inputs and the internal input bookkeeping are shallow copies.
    """
    # pylint: disable=protected-access
    duplicate = type(task)()
    duplicate._path = Path([*task.path.steps[:-1], new_id])
    duplicate._mosaic = self
    self._set_logger(duplicate)
    self._inject_pdk_items(duplicate)
    self._inject_by_type(duplicate)
    self._wrap_run(duplicate)

    for name, task_input in get_task_inputs(task).items():
        setattr(duplicate, name, copy.copy(task_input))

    bookkeeping_attributes = (
        "_injected_mutables",
        "_runtime_evaluation_functions",
        "_normal_inputs",
        "_special_inputs",
        "_all_inputs",
    )
    for attribute in bookkeeping_attributes:
        setattr(duplicate, attribute, copy.copy(getattr(task, attribute)))
    return duplicate
def _configure_console_logging(self):
    """Make sure the root logger has exactly one console handler installed by Mosaic.

    This method runs for every new `Mosaic` instance. Adding a new `StreamHandler` each
    time would print every log message once per created instance, so the handler that was
    installed before is reused and only its formatter is refreshed from
    ``Mosaic.log_formatter``.
    """
    root_logger = logging.getLogger("")
    for handler in root_logger.handlers:
        # the marker attribute identifies the handler installed by an earlier call
        if getattr(handler, "_mosaic_console_handler", False):
            handler.setFormatter(Mosaic.log_formatter)
            return
    console_handler = logging.StreamHandler()
    console_handler._mosaic_console_handler = True  # pylint: disable=protected-access
    console_handler.setFormatter(Mosaic.log_formatter)
    root_logger.addHandler(console_handler)
def _run_wrapper(
    task, inputs: Dict[str, object] = None, *args, **kwargs
) -> Union[Any, WorkProduct]:
    """Replacement for the run method of a task, installed by `Mosaic._wrap_run`.

    The wrapper changes into the working directory of the task (created if missing),
    notifies the execution trackers, returns the cached result of a cachable task if
    there is one and otherwise calls the original run method and wraps its outputs in a
    `ResultContainer`, which is cached for cachable tasks.

    The working directory is left again on every path: if anything raises while the
    process is inside the working directory, the parent directory is restored before the
    exception propagates.

    Args:
        task: the task the wrapper is bound to.
        inputs: special inputs for this run, or the first custom positional argument of
            the original run method.
        *args: further positional arguments for the original run method.
        **kwargs: keyword arguments for the original run method.

    Returns:
        the return value stored in the cached or newly created result container.
    """
    # pylint: disable=protected-access
    # the inputs argument can be filled with a custom positional run argument
    # such as : run(my_arg: int)
    # when calling such a run method without keyword - run(1) - it would appear here as inputs
    is_inputs_real = inputs is None or isinstance(inputs, dict)
    if is_inputs_real:
        _inject_special_inputs(task, inputs)
    _evaluate_runtime_lambdas(task)

    working_dir = task.path.target
    if not os.path.exists(working_dir):
        os.makedirs(working_dir)
    os.chdir(working_dir)
    # tracks whether the parent directory still has to be restored by the finally block
    inside_working_dir = True
    try:
        task.run_id = str(uuid.uuid4())
        for tracker in task.mosaic.execution_trackers:
            tracker.start_task(task, is_root_task=task == task.mosaic.root)

        if task._mosaic.caching and isinstance(task, CachableTask):
            task.task_hash = task.calc_hash(inputs, *args, **kwargs)
            cached_result, logs = _get_cached_result_and_logs(task)
            if cached_result is not None:
                _reproduce_cached_logs(logs, task)
                os.chdir("..")
                inside_working_dir = False
                cached_result.apply_to_obj(task)
                task.task_result = cached_result
                task.cache_used = True
                # Track execution
                for tracker in task.mosaic.execution_trackers:
                    tracker.end_task(
                        task, is_root_task=(task == task.mosaic.root), cache_used=True
                    )
                return cached_result.get_return_value()

        task.cache_used = False
        _inject_log_to_task_tool(task)
        list_handler = ListHandler()
        if isinstance(task, CachableTask):
            task._mosaic_logger.addHandler(list_handler)

        if is_inputs_real:
            result = task._old_run(*args, **kwargs)
        else:
            # inputs contains custom positional argument
            result = task._old_run(inputs, *args, **kwargs)

        for tracker in task.mosaic.execution_trackers:
            tracker.end_task(
                task, is_root_task=(task == task.mosaic.root), cache_used=False
            )

        # Wrap in container
        # This container is saved in the cache and contains all outputs
        task_outputs = task.get_outputs()
        result_container = ResultContainer()
        result_container.update_container(task, task_outputs, result)
        if task._mosaic.caching and isinstance(task, CachableTask):
            # Update hashes of outputs inside of the result container
            # since the task produced new outputs
            result_container._hash = result_container.calc_hash(task)
            task.task_result = result_container
            results = _collect_results(task)
            task._cache(results)
            task.is_invalid = False
            _cache_log(list_handler, task)

        os.chdir("..")
        inside_working_dir = False
        return result_container.get_return_value()
    finally:
        if inside_working_dir:
            os.chdir("..")
def _reproduce_cached_logs(logs, task: Task):
    """Replay cached log records through the handlers of the task logger.

    Each record is renamed to the task logger, stamped with the current time and its
    message is prefixed with ``[CACHED <original timestamp>]``. Nothing happens if
    ``logs`` is None.
    """
    if logs is None:
        return
    logger: Logger = task._mosaic_logger  # pylint: disable=protected-access
    for record in logs:
        record.name = logger.name
        # format the original timestamp before the creation time is overwritten
        formatted_cached_timestamp = Mosaic.log_formatter.formatTime(record)
        time_created = time.time()
        record.created = time_created
        record.msecs = (time_created - int(time_created)) * 1000
        record.msg = f"[CACHED {formatted_cached_timestamp}] {record.msg}"
        logger.callHandlers(record)