Source code for mosaic_orchestrator.tool_store

"""Tool store of the mosaic-orchestrator framework.

Registers the tools of a task hierarchy and tracks which of them are used.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Type, Union, Dict, List
from mosaic_orchestrator.mosaic_utils import MosaicException
from mosaic_orchestrator.tool import Tool


@dataclass
class ToolHolder:
    """Simple holder class that marks used tools.

    Pairs a registered tool with the key it was registered under and a flag
    recording whether the tool has been requested.
    """

    # Key the tool was registered under: a tool class or a string name.
    base_class: Union[Type[Tool], str]
    # The registered tool instance.
    tool: Tool
    # When True the tool is used by at least one of the tasks.
    used: bool
class ToolStore:
    """Holds all tools of a task hierarchy.

    Tools are registered with ``put`` and retrieved with ``get_tool``; every
    retrieval marks the tool as used, so ``used_tools`` can report which of
    the registered tools were actually requested.
    """

    def __init__(self):
        # Registry keyed by tool class or string name (``put`` accepts both).
        self._store: Dict[Union[Type[Tool], str], ToolHolder] = {}

    def put(
        self,
        base_class: Union[Type[Tool], str],
        tool: Tool,
        overwrite: bool = False,
    ):
        """Register ``tool`` under ``base_class``.

        Args:
            base_class: Key to register the tool under (a tool class or a name).
            tool: The tool instance to store.
            overwrite: When False (the default) an existing registration for
                ``base_class`` is kept and ``tool`` is ignored; when True the
                existing entry is replaced.

        A newly created or replacing entry always starts out as not used.
        """
        if overwrite or base_class not in self._store:
            self._store[base_class] = ToolHolder(base_class, tool, used=False)

    def set_used(self, tool_type: Union[Type[Tool], str]):
        """Mark the tool registered under ``tool_type`` as used.

        Raises:
            KeyError: If no tool is registered under ``tool_type``.
        """
        self._store[tool_type].used = True

    def used_tools(self) -> List[Tool]:
        """Return the tools marked as used, in the order their keys were first registered."""
        return [holder.tool for holder in self._store.values() if holder.used]

    def get_tool(self, tool_type: Union[Type[Tool], str]) -> Tool:
        """Return the tool registered under ``tool_type`` and mark it as used.

        Raises:
            MosaicException: If no tool is registered under ``tool_type``.
        """
        try:
            holder = self._store[tool_type]
        except KeyError:
            # ``from None`` keeps the traceback free of the internal KeyError,
            # so callers see only the MosaicException.
            raise MosaicException(
                f"Tool {tool_type} has not been registered or installed."
            ) from None
        # Go through ``set_used`` so the marking logic lives in one place.
        self.set_used(tool_type)
        return holder.tool