# Source code for mosaic_orchestrator.task

"""This module contains the different types of tasks and some helper functions."""
from __future__ import annotations

import copy
import os
import os.path
import pickle
import shutil
import types
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass, fields
from enum import Enum
from logging import Logger
from typing import (
    Dict,
    List,
    Tuple,
    Optional,
    Callable,
    Type,
    TypeVar,
    Generic,
    Protocol,
    Union,
    runtime_checkable,
    Any,
    get_origin,
)

from mosaic_orchestrator import tree
from mosaic_orchestrator.protocols import Hashable
from mosaic_orchestrator.work_products import WorkProduct
from mosaic_orchestrator.pdk import PDK, PdkItem
from mosaic_orchestrator.result import (
    Validation,
    ValidationError,
    ValidationSuccess,
    ErrorDefinition,
    ErrorType,
    Result,
)
from mosaic_orchestrator.tree import Path
from mosaic_orchestrator.utils import get_all_inputs, get_type_annotations


class InputException(Exception):
    """Signals that a task input was handled incorrectly.

    Typical example: unexpected inputs passed to run().
    """
TaskStateType = TypeVar("TaskStateType")


@dataclass
class TaskState(Generic[TaskStateType]):
    """Generic container for the internal state of a task.

    It is used much like an `Input`, but exists for caching and is meant to be
    used only from within the owning task. It can therefore be modified neither
    by a parent task nor from the top level.
    """

    #: The value wrapped by this instance (``None`` until set).
    value: Optional[TaskStateType] = None
class CalculatedAccessException(Exception):
    """Raised when the value of a calculated input is read before it is available.

    `Input.value` raises it for a calculated input whose value is still
    ``None``. The ``value`` property of the `Mutable` marker always raises it.
    """
class Calculated:
    """Marks an input of a subtask as calculated by the current (parent) task.

    A calculated input can no longer be set from a higher hierarchy. Its value
    has to be passed to the run() method of the subtask instead.

    Example:
        This example shows how to set calculated inputs of a child task::

            class Addition(Task):
                summand1 = Input()
                summand2 = Input()

                def run(self, *args, **kwargs):
                    return self.summand1.value + self.summand2.value

            class MyTask(Task):
                t1 = Addition(inputs={"summand1": Calculated, "summand2": Calculated})

                def run(self, *args, **kwargs):
                    t1_result = self.t1.run(inputs={"summand1": 2, "summand2": 3})  # will be 5
                    ...
    """
class Const:
    """Marks an input of a subtask as constant, i.e. not settable from a higher hierarchy.

    Example:
        A value passed to the constructor is used for the target input. The
        `Const` class itself can also be used as a marker. In that case the
        default value of the target input is used as the constant::

            class Addition(Task):
                summand1 = Input()
                summand2 = Input(default=0)

                def run(self, *args, **kwargs):
                    return self.summand1.value + self.summand2.value

            class MyTask(Task):
                t1 = Addition(inputs={
                    "summand1": Const(2),
                    "summand2": Const,  # will be fixed at default value 0
                })

                def run(self, *args, **kwargs):
                    t1_result = self.t1.run()  # will be 2 = (2 + 0)
                    ...
    """

    # pylint: disable=too-few-public-methods

    def __init__(self, value):
        self.value = value  # the constant to use for the target input
def is_calculated(obj) -> bool:
    """Return whether *obj* marks a calculated input.

    This is true for the `Calculated` marker class itself, for an instance of
    it, and for an `Input` whose special marker is calculated.
    """
    if obj is Calculated or isinstance(obj, Calculated):
        return True
    return isinstance(obj, Input) and obj._is_calculated
def is_mutable(obj) -> bool:
    """Return whether *obj* marks a mutable input.

    A mutable input must be provided when run() is called, but it can still be
    overwritten from above. This is true for the `Mutable` marker class, an
    instance of it, or an `Input` whose special marker is mutable.
    """
    if obj is Mutable or isinstance(obj, Mutable):
        return True
    return isinstance(obj, Input) and obj._is_mutable
def is_const(obj) -> bool:
    """Return whether *obj* marks a constant input.

    This is true for the `Const` marker class, a `Const(value)` instance, or an
    `Input` whose special marker is constant.
    """
    if obj is Const or isinstance(obj, Const):
        return True
    return isinstance(obj, Input) and obj._is_const
def is_calculated_or_mutable_or_const(obj) -> bool:
    """Return whether *obj* carries any special marker: calculated, mutable or constant."""
    # The checks are evaluated lazily and in this order, stopping at the first match.
    return any(check(obj) for check in (is_calculated, is_mutable, is_const))
class Mutable:
    """Marks an input of a subtask as calculated by default but still mutable from above.

    The default value of such an input is calculated at runtime (i.e. during
    task execution) and must be passed to the run method of the subtask. It
    behaves like a `Calculated` input. The only difference is that a `Mutable`
    input can still be overridden from a higher hierarchy.
    """

    @property
    def value(self):
        """Always raises `CalculatedAccessException`; the marker itself holds no value."""
        raise CalculatedAccessException
"""Task type"""


class Task(ABC):
    """Abstract class that represents a unit of work within the Mosaic framework.

    The task is the main abstraction of the Mosaic framework. It represents a
    (single-threaded) unit of work that can be reused in different places within
    an analog generator, much like a simple function in Python.

    To define a task, subclass this class and implement the run() method. All
    work a task does is expected to take place within this method.

    Tasks are intended to be used hierarchically, i.e. a parent task can consist
    of multiple child tasks. This is done by declaring a (child) task as a class
    variable. The framework automatically creates instance variables of any
    child task. Therefore, child tasks should be accessed like other properties,
    through the self reference. Inputs of child tasks can be set using the
    constructor.

    Example:
        This simple example shows how to define and relate a parent task to a
        child task::

            class Addition(Task):
                summand1 = Input()
                summand2 = Input()

                def run(self, *args, **kwargs):
                    return self.summand1.value + self.summand2.value

            class MyTask(Task):
                t1 = Addition(inputs={"summand1": 2, "summand2": 3})
                t2 = Addition(inputs={"summand1": 1, "summand2": 1})

                def run(self, *args, **kwargs):
                    t1_result = self.t1.run()  # will be 5
                    t2_result = self.t2.run()  # will be 2
                    ...
    """

    def __init__(self, inputs: Optional[Dict[str, object]] = None):
        """Constructor method

        Args:
            inputs (:obj:`dict`, optional): a dictionary of inputs that are set
                from a parent task. Entries marked as calculated, mutable or
                constant are kept in ``_special_inputs``; all others end up in
                ``_normal_inputs``.
        """
        self._all_inputs = inputs or {}
        self._special_inputs = self._get_special_inputs()
        self._normal_inputs = {
            name: value
            for name, value in self._all_inputs.items()
            if name not in self._special_inputs
        }
        self._runtime_evaluation_functions = {}
        self._path = None  # is calculated during task tree creation
        self._mosaic_logger = None  # injected
        self._mosaic = None  # injected
        self._injected_mutables = []
        self._relevant_inputs = []
        self.task_result = None
        self.run_id = None  # unique id that is generated before every run()
        # Class-level Outputs, Inputs, child tasks etc. are shared by all
        # instances of the class; give this instance its own deep copies.
        self._create_members_from_static_tasks_and_inputs()

    @property
    def mosaic(self):
        """Get the mosaic instance

        Returns:
            TaskHierarchyBuilder: mosaic instance (``None`` until injected)
        """
        return self._mosaic

    def _get_special_inputs(self):
        """Return the inputs set from above that are calculated, mutable or constant."""
        return {
            name: value
            for name, value in self._all_inputs.items()
            if is_calculated_or_mutable_or_const(value)
        }

    def _declared_outputs(self) -> Dict[str, Output]:
        """Return every attribute of this task that is an `Output`, keyed by name."""
        outputs = {}
        for name in dir(self):
            member = getattr(self, name)  # evaluate each attribute only once
            if isinstance(member, Output):
                outputs[name] = member
        return outputs

    def _create_members_from_static_tasks_and_inputs(self):
        """Create per-instance deep copies of the statically declared members."""
        for name, value in self._declared_outputs().items():
            setattr(self, name, copy.deepcopy(value))
        for name, value in get_all_inputs(self).items():
            if isinstance(value, (PdkItem, Task, Input, TaskState)):
                setattr(self, name, copy.deepcopy(value))

    @abstractmethod
    def run(self, *args, **kwargs) -> Union[Any, WorkProduct]:
        """This method needs to be implemented when a new task is defined.

        The current working directory of this task depends on its position in
        the task hierarchy, i.e. its path. The directory is created
        automatically as soon as this method is called. Custom runtime
        arguments are allowed.
        """
        raise NotImplementedError("run method not implemented for this task")

    @property
    def path(self) -> Path:
        """:obj: `Path`: the unique path of this specific task instance starting from the root task."""
        return self._path

    @property
    def log(self) -> Logger:
        """:obj: `Logger`: a task specific logger instance (``None`` until injected).

        Logs are saved in the run directory of each task as well as printed to stdout.
        """
        return self._mosaic_logger

    @property
    def cwd(self) -> str:
        """Root of the current working directory.

        This is the run directory of the injected mosaic instance, or
        ``os.getcwd()`` when no mosaic instance is injected.
        NOTE(review): ``delete_run_dirs`` uses ``run_directory.parts``, so the
        returned object is presumably path-like rather than ``str``.
        """
        if self._mosaic is not None:
            return self._mosaic.run_directory
        return os.getcwd()

    def _on_end(self, exception: Exception = None):
        """called after run method has finished, custom cleanup can be done here"""

    def copy(self, new_id: str) -> ATask:
        """Allows to copy a task dynamically during task execution.

        The created instance is added as a child of this task using the passed
        id. All inputs are copied from their current state.

        Args:
            new_id: the id of the new task. Must be unique within this task.

        Returns:
            The new task instance.
        """
        return self._mosaic._copy(new_id, self)  # pylint: disable=protected-access

    def delete_run_dirs(self):
        """Deletes the run directory of this task, including everything below it.

        The directory is ``os.getcwd()`` joined with the parts of the mosaic run
        directory and the steps of this task's path. If the run directory is
        absolute, ``os.path.join`` discards the ``os.getcwd()`` prefix. Errors
        (e.g. a missing directory) are ignored. Requires the mosaic instance to
        be injected.
        """
        taskdir = os.path.join(
            os.getcwd(), *self._mosaic.run_directory.parts, *self.path.steps
        )
        # Previously two debug print() calls; report via the task logger instead.
        if self.log is not None:
            self.log.debug("deleting run directory %s", taskdir)
        shutil.rmtree(taskdir, ignore_errors=True)

    def validate_pdk(self, pdk: PDK) -> Validation:
        """Can be overridden to write custom validation logic for the PDK.

        This method is called during the checking-phase, i.e. before the task
        tree is executed. Its intended use is to search the PDK for a specific
        functionality instead of a specific view. The default implementation
        returns `ValidationSuccess`.

        Args:
            pdk: The PDK object to validate.

        Returns:
            A Validation object indicating the result of the validation.
            `ValidationSuccess` in the default implementation.
        """
        return ValidationSuccess

    def get_fields_of_type(self, field_type: Type) -> List:
        """Return all fields of this task that are of type `field_type` or a subclass thereof."""
        return [
            member
            for member in (getattr(self, name) for name in dir(self))
            if issubclass(type(member), field_type)
        ]

    def dynamic_inputs(self) -> Optional[Dict[str, Input]]:
        """Can be overridden to add dynamic Inputs (e.g. from tools) to the Task.

        PDKs and tools are already available at that point; other statically
        defined Inputs are not.

        Returns:
            A dict mapping input names to the Inputs to be added to the Task
            before run() is called. The default implementation returns None.
        """

    def dynamic_tasks(self) -> Optional[Dict[str, Task]]:
        """Can be overridden to add dynamic child Tasks (e.g. from tools) to the Task.

        PDKs and tools are already available at that point; other statically
        defined Inputs are not.

        Returns:
            A dict mapping task names to the Tasks to be added to the Task
            before run() is called. The default implementation returns None.
        """

    def get_outputs(self) -> Dict[str, Output]:
        """Get the outputs of this task.

        Returns every `Output` attribute, keyed by name. If run() has a return
        type annotation, an additional ``"__return__"`` entry with an `Output`
        of that type is included. For this to work, the type annotation has to
        be set correctly.
        """
        outputs = self._declared_outputs()
        result_type = get_type_annotations(self.run).get("return", None)
        if result_type is None:
            return outputs
        return {**outputs, "__return__": Output(result_type)}
#: Type variable for any `Task` subclass (e.g. the return type of `Task.copy`).
ATask = TypeVar("ATask", bound=Task)

#: Type definition of Input
TaskInputType = TypeVar("TaskInputType")

#: Type definition of Output
TaskOutputType = TypeVar("TaskOutputType")
class Output(Generic[TaskOutputType]):
    """Declares an output of a task.

    Every non-``None`` assignment to ``value`` is checked against ``type``. The
    check is skipped for ``Any`` and for union types.
    """

    def __init__(
        self,
        type: Optional[Type[TaskOutputType]] = Any,
        default: Optional[TaskOutputType] = None,
        validate: Optional[Callable[[TaskOutputType], Validation]] = None,
        description: Optional[str] = None,
        prompt: Optional[str] = None,
        transient: Optional[bool] = False,
    ) -> None:
        """Output declaration

        Args:
            type (Optional[Type[TaskOutputType]], optional): Type of the output.
                Defaults to Any. Parametrized generics are reduced to their
                origin, e.g. ``List[int]`` is stored (and checked) as ``list``.
            default (Optional[TaskOutputType], optional): Default value. Defaults to None.
            validate (optional): NOTE(review): accepted, but currently neither
                stored nor used by this class.
            description (Optional[str], optional): Description of the output. Defaults to None.
            prompt (Optional[str], optional): Prompt string (used e.g. in the GUI). Defaults to None.
            transient (Optional[bool], optional): If True, ignore it to check the
                cache validity. Defaults to False.

        Raises:
            TypeError: if ``default`` does not match ``type``.
        """
        super().__init__()
        self.transient = transient
        self.prompt = prompt
        self.description = description
        self.default = default
        # isinstance() rejects subscripted generics, so only the origin is kept.
        self.type = typing.get_origin(type) or type
        # Goes through __setattr__, i.e. the default is type-checked as well.
        self.value = default

    def _is_type_checked(self) -> bool:
        """Return False when ``self.type`` cannot be used with isinstance() (Any, unions)."""
        # ``X | Y`` unions (PEP 604) have origin types.UnionType instead of
        # typing.Union; fall back to Union on Python versions without it.
        union_type = getattr(types, "UnionType", Union)
        return self.type != Any and self.type != Union and self.type != union_type

    def __setattr__(self, name, value):
        """Store any attribute, type-checking non-``None`` assignments to ``value``."""
        if name == "value" and value is not None and self._is_type_checked():
            if not isinstance(value, self.type):
                raise TypeError(
                    f"Cannot assign {type(value)} to the output {self} of type {self.type}"
                )
        self.__dict__[name] = value
class Input(Generic[TaskInputType]):
    """Represents an input of a task.

    Input values can be set from a higher hierarchy. This class may be
    subclassed to implement custom validation logic.

    All class variables in `Task` objects that are instances of this class are
    automatically converted to member variables by the framework. It is also
    possible to just annotate a variable (without assigning an actual value).
    In that case the input needs to be set from a higher hierarchy, i.e. it has
    no default value.

    Some subclasses for basic type checking are already implemented; see
    `FileInput`, `StringInput`, `ListInput`, ...

    Attributes:
        type: the expected type of the value. Parametrized generics such as
            ``List[int]`` are checked against their origin (``list``).
        default: default value of this input. Defaults to None, which is
            interpreted as a missing value.
        validate_func: the optional validation function passed as ``validate``.
            NOTE(review): it is stored here but not called by `validate`.
        description: an optional description of the input.
        transient: when True this input is ignored during caching.
        hidden: when True the input is not shown on the GUI.

    Example:
        A simple example showing different types of inputs::

            class MyTask(CachableTask):
                # must be set from outside, i.e. has no default value
                p1: Input

                # an input with a default value and a description
                p2 = Input(default="a string", description="this is a string input")

                # use a lambda to access other inputs or dependencies of a task.
                # The passed task is a self reference.
                p3 = StringInput(default=lambda task: str(task.p1.value))

                # inputs can provide custom validate logic via the constructor.
                # The validate function is called after all input values have
                # been set.
                p4 = IntInput(
                    default=0,
                    validate=lambda task, value: Validation.of(value < 10, "value too large")
                )

                # inputs marked as transient will be ignored when the state of a
                # task is determined in the caching process.
                p5 = Input(default=0, transient=True)
    """

    def __init__(
        self,
        type: Optional[Type[TaskInputType]] = Any,
        default: Optional[TaskInputType] = None,
        validate: Optional[Callable[[ATask, TaskInputType], Validation]] = None,
        description: Optional[str] = None,
        prompt: Optional[str] = None,
        transient: Optional[bool] = False,
        hidden: Optional[bool] = False,
    ):
        """Declare an input.

        Raises:
            TypeError: if a non-callable, non-``None`` ``default`` does not
                match ``type``. Callable defaults are evaluated later and are
                not checked here.
        """
        self.type = type
        self._default = default
        self.value: Optional[TaskInputType] = default
        self.validate_func = validate
        self.description = description
        self.transient = transient
        self.prompt = prompt
        self.hidden = hidden  # do not show on GUI
        self._special_input = None  # set if e.g. calculated, etc. is used from top
        if (
            self.type != Any
            and self.default is not None
            and not callable(self.default)
            and not self._matches_type(self.default, self.type)
        ):
            raise TypeError(f"Default value {default} does not match {type} in {self}.")

    @staticmethod
    def _matches_type(value, expected) -> bool:
        """isinstance() that tolerates typing constructs.

        Subscripted generics (e.g. ``List[int]``) make isinstance() raise
        TypeError. In that case the check falls back to the unsubscripted origin
        (``list``). Forms that cannot be checked even then (e.g. ``Literal[...]``)
        are accepted. A TypeError for something without a typing origin is
        re-raised unchanged.
        """
        try:
            return isinstance(value, expected)
        except TypeError:
            origin = typing.get_origin(expected)
            if origin is None:
                raise
        try:
            return isinstance(value, origin)
        except TypeError:
            return True

    def __getattribute__(self, __name: str) -> Any:
        if (
            __name == "value"
            and self._is_calculated
            and super().__getattribute__(__name) is None
        ):
            # If there is an unset calculated input -> raise an access exception
            raise CalculatedAccessException()
        return super().__getattribute__(__name)

    @property
    def _is_const(self):
        return is_const(self._special_input)

    @property
    def _is_calculated(self):
        return is_calculated(self._special_input)

    @property
    def _is_mutable(self):
        return is_mutable(self._special_input)

    @property
    def _is_special(self):
        return self._special_input is not None

    @property
    def default(self) -> Optional[TaskInputType]:
        """default value"""
        return self._default

    def task_hash(self) -> bytes:
        """default implementation of task hash: the pickled current value"""
        return pickle.dumps(self.value)

    # NOTE: defining __eq__ without __hash__ makes Input instances unhashable.
    def __eq__(self, other):
        return isinstance(other, Input) and self.value == other.value

    def __ne__(self, other):
        return not self.__eq__(other)

    def validate(self) -> Validation:
        """overwrite this method to implement custom validation logic.

        The default implementation only checks that a non-``None`` value
        matches ``type`` (generics via their origin, see ``_matches_type``).
        """
        if (
            self.type != Any
            and self.value is not None
            and not self._matches_type(self.value, self.type)
        ):
            return ValidationError(
                f"Input validation of {self} failed. {self.value} is not of type {self.type}."
            )
        return ValidationSuccess
@dataclass
class TaskNode(tree.Node):
    """Helper structure holding a Task and additional inputs while the task tree is built."""

    value: Task
    inputs: dict
    #: Paths (in the hierarchy) of the default inputs that were injected, each
    #: paired with the task they originate from. An origin of ``None`` means the
    #: input was injected from the top level.
    injected_inputs: List[Tuple[Path, Optional[Task]]]
    parent: Optional[Task] = None

    @property
    def task(self) -> Task:
        """The task held by this node (same object as ``value``)."""
        node_task = self.value
        return node_task
class CheckResult:
    """Collects the outcome of checking the dependencies and validations of a task hierarchy.

    Attributes:
        success: True as long as no error has been recorded, i.e. the task
            tree is ready to build and run.
        inputs: all settable inputs that were found, as `InputDescription`s.
        errors: the recorded errors. See `ErrorType` for the different kinds.
    """

    def __init__(self, success: bool = True):
        self.success = success
        self.inputs: List[InputDescription] = []
        self.errors: List[ErrorDefinition] = []

    def append_if_error(self, error_type: ErrorType, result: Result, source):
        """Record ``result.error`` as an error, but only if ``result`` is a failure."""
        if not result.failure:
            return
        self.success = False
        self.errors.append(ErrorDefinition(error_type, result.error, source))

    def append_error(self, error_type: ErrorType, message, source):
        """Unconditionally record an error and mark the check as failed."""
        self.success = False
        self.errors.append(ErrorDefinition(error_type, message, source))

    def append_input(self, path: str, input: Input, parent_task: Task):
        """Record a settable input together with its path and parent task."""
        entry = InputDescription(path=path, input=input, parent_task=parent_task)
        self.inputs.append(entry)
@dataclass
class InputDescription:
    """Pairs an `Input` with its path in the task tree and its parent task."""

    path: str  # location of the input within the task tree
    input: Input  # the input object itself
    parent_task: Task  # the task the input was reported for
def get_task_inputs(task: Task) -> Dict[str, Input]:
    """Collect the task inputs of *task*, keyed by name.

    Inputs that have an `Input` value are returned as-is. Inputs that are only
    declared through a type annotation on the task class (and not set on the
    task) map to ``None``.
    """
    collected = _filter_task_inputs(get_all_inputs(task))
    declared_only = _get_not_set_inputs(get_type_annotations(type(task)), task)
    collected.update(declared_only)
    return collected
def _filter_task_inputs(inputs):
    """Return only those entries of *inputs* whose value is an `Input`."""
    return {name: value for name, value in inputs.items() if isinstance(value, Input)}


def _get_not_set_inputs(annotations, task):
    """Return ``{name: None}`` for every input that is annotated but not set on *task*."""
    return {
        name: None
        for name, ann_type in annotations.items()
        if _is_task_input(ann_type) and not hasattr(task, name)
    }


def _is_task_input(annotation) -> bool:
    """Return True if *annotation* is `Input`, a subclass, or a parametrization of either.

    Annotations that are not classes are never task inputs. This covers typing
    special forms such as ``ClassVar[...]`` or ``Literal[...]``, and string
    forward references that were not resolved by the caller. Passing such an
    annotation to ``issubclass`` would raise TypeError.
    """
    candidate = typing.get_origin(annotation) or annotation
    return isinstance(candidate, type) and issubclass(candidate, Input)


def _get_cache_dir(task):
    """Return the cache directory as a path relative to the task's directory.

    One ``..`` is prepended per step of the task path.
    NOTE(review): this presumably climbs from the task's own run directory back
    to the run root; confirm against the caller.
    """
    ups = [".."] * task.path.length
    return os.path.join(*ups, task._mosaic.cache_directory)  # pylint: disable=protected-access