"""Container class for results of tasks.
A Task can return zero, one, or more outputs. They are stored in an instance
of this container class.
"""
from __future__ import annotations
from typing import Any, Union
from mosaic_orchestrator.task import ATask, Output
from mosaic_orchestrator.work_products import WorkProduct
class ResultContainer(WorkProduct):
    """Work product that holds the outputs produced by a task run.

    After ``update_container`` has been called, each entry of the task's
    output mapping is available as an attribute of this container, and the
    raw value returned by the task's ``run()`` method is kept separately.
    """

    def __init__(self):
        # Not read anywhere in this class.
        # NOTE(review): presumably consumed by WorkProduct — confirm.
        self._hash = None
        # Whatever the task's run() returned; None is treated as
        # "nothing returned" by get_return_value() and outputs.
        self._run_returned_value = None

    def update_container(self, task: ATask, task_outputs: dict, run_return: Any):
        """Update this result container with the outputs of the task.

        Every entry of ``task_outputs`` is attached to this container under
        its own name, and its ``value`` is then filled in. The special name
        ``"__return__"`` receives ``run_return`` wrapped in an ``Output``.
        Every other name receives the ``value`` of the same-named attribute
        on ``task``.

        Args:
            task (ATask): The corresponding task.
            task_outputs (dict): Mapping of output name to output object.
            run_return (Any): The return value of the task's run() method.
        """
        self._run_returned_value = run_return
        for name, output_type in task_outputs.items():
            # Copy over the output object (and with it its declared type).
            # NOTE(review): this attaches the very object stored in
            # task_outputs rather than a copy, so assigning ``.value`` below
            # appears to mutate the dict entry as well — confirm intended.
            setattr(self, name, output_type)
            if name == "__return__":
                # NOTE(review): the stored value is itself an Output wrapper,
                # not the raw return value — confirm this is intended.
                self.__return__.value = Output(
                    type(self._run_returned_value), self._run_returned_value)
            else:
                # Set it through ``.value`` so that the type is also checked.
                getattr(self, name).value = getattr(task, name).value

    def get_return_value(self) -> Union[ResultContainer, Any]:
        """Get whatever the run() method has returned.

        If run() returned nothing (``None``), this result container itself is
        returned instead. An explicit ``return None`` therefore cannot be told
        apart from having no return statement.

        Returns:
            Union[ResultContainer, Any]: Returned value of run(), or ``self``.
        """
        if self._run_returned_value is not None:
            return self._run_returned_value
        return self

    @property
    def outputs(self):
        """Get all outputs in this result container.

        Every instance attribute whose value is an ``Output`` is yielded
        under its attribute name. If run() returned a non-``None`` value, it
        is yielded last under the name ``"__run_return__"``, wrapped in a
        freshly built ``Output``.

        Yields:
            tuple: ``(name, output)`` pairs.
        """
        for name, output in self.__dict__.items():
            if isinstance(output, Output):
                yield name, output
        if self._run_returned_value is not None:
            yield "__run_return__", Output(
                type(self._run_returned_value), self._run_returned_value)

    def apply_to_obj(self, obj: ATask):
        """Set the same attributes on a Task to restore a Task from the cache.

        Every ``(name, output)`` pair from ``outputs`` is set as an attribute
        on ``obj``. This includes the synthetic ``"__run_return__"`` entry,
        if present.

        Args:
            obj (ATask): Task object to restore.
        """
        for name, output in self.outputs:
            setattr(obj, name, output)

    def calc_hash(self, task: ATask) -> str:
        """Get a new hash of this work product.

        Needed to check if the inputs have changed. The hash is the
        concatenation, in ``outputs`` order, of the hashes of every output
        whose value is itself a ``WorkProduct``. Other outputs contribute
        nothing, so the result is ``""`` when there are none.

        Args:
            task (ATask): Passed through to each nested ``calc_hash`` call.

        Returns:
            str: The concatenated hash string.
        """
        # Single pass over ``outputs``; join avoids repeated string copies.
        return "".join(
            output.value.calc_hash(task)
            for _, output in self.outputs
            if isinstance(output.value, WorkProduct)
        )