# Source code for mosaic_orchestrator.result_container

"""Container class for results of tasks.

A task can return zero, one, or more outputs. They are stored in an
instance of this container class.
"""

from __future__ import annotations
from typing import Any, Union
from mosaic_orchestrator.task import ATask, Output
from mosaic_orchestrator.work_products import WorkProduct


class ResultContainer(WorkProduct):
    """Hold the outputs produced by a task run.

    Every output passed to :meth:`update_container` becomes an attribute of
    this container, named after the output. The raw value returned by the
    task's ``run()`` method is kept separately in ``_run_returned_value``.
    """

    def __init__(self):
        # NOTE(review): WorkProduct.__init__ is not invoked here; confirm the
        # base class requires no initialisation of its own.
        self._hash = None
        self._run_returned_value = None

    def update_container(self, task: ATask, task_outputs: dict, run_return: Any):
        """Update this result container with the outputs of the task.

        Args:
            task (ATask): The corresponding task; the value of each regular
                output is read from the attribute of the same name on it.
            task_outputs (dict): Mapping of output name to output object. Each
                entry is attached to this container as an attribute.
            run_return (Any): The return value of the task's ``run()`` method.
        """
        self._run_returned_value = run_return
        for name, output_type in task_outputs.items():
            # copy over type
            setattr(self, name, output_type)
            if name == "__return__":
                # The special "__return__" output is filled from run()'s
                # return value instead of from an attribute of the task.
                # NOTE(review): the stored value is itself an Output wrapping
                # the return value, not the raw value -- confirm intended.
                self.__return__.value = Output(
                    type(self._run_returned_value), self._run_returned_value)
            else:
                # set it so that type is also checked
                getattr(self, name).value = getattr(task, name).value

    def get_return_value(self) -> Union[ResultContainer, Any]:
        """Get whatever the run() method has returned.

        If it returned nothing (``None``), return this result container
        instead.

        Returns:
            Union[ResultContainer, Any]: Returned value of run(), or ``self``.
        """
        if self._run_returned_value is not None:
            return self._run_returned_value
        return self

    @property
    def outputs(self):
        """Get all outputs in this result container.

        Attributes holding an ``Output`` are yielded in attribute insertion
        order. If run() returned something other than ``None``, a freshly
        built ``Output`` wrapping it is yielded last under the name
        ``"__run_return__"``.

        Yields:
            tuple: name, output
        """
        for name, output in self.__dict__.items():
            if isinstance(output, Output):
                yield name, output
        if self._run_returned_value is not None:
            yield "__run_return__", Output(type(self._run_returned_value),
                                           self._run_returned_value)

    def apply_to_obj(self, obj: ATask):
        """Set the same attributes on a task to restore it from the cache.

        Args:
            obj (ATask): Task object that receives every output of this
                container as an attribute of the same name.
        """
        for name, output in self.outputs:
            setattr(obj, name, output)

    def calc_hash(self, task: ATask) -> str:
        """Get a new hash of this work product.

        Needed to check if the inputs have changed. Only outputs whose value
        is itself a ``WorkProduct`` contribute; their hashes are concatenated
        in the order the outputs are yielded by :attr:`outputs`.

        Args:
            task (ATask): Passed through to each nested ``calc_hash`` call.

        Returns:
            str: The concatenated hashes, or ``""`` if no output holds a
            ``WorkProduct``.
        """
        # Single pass over ``outputs``; join avoids repeated string copies.
        return "".join(
            output.value.calc_hash(task)
            for _, output in self.outputs
            if isinstance(output.value, WorkProduct)
        )