Source code for ampel.core.UnitLoader

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File:                Ampel-core/ampel/core/UnitLoader.py
# License:             BSD-3-Clause
# Author:              valery brinnel <firstname.lastname@gmail.com>
# Date:                07.10.2019
# Last Modified Date:  04.04.2023
# Last Modified By:    valery brinnel <firstname.lastname@gmail.com>

import os, sys
from importlib import import_module
from pathlib import Path
from hashlib import blake2b
from contextlib import contextmanager
from types import UnionType
from typing import Any, TypeVar, Union, overload, get_args, get_origin
from collections.abc import Iterator, Mapping
from copy import deepcopy

from ampel.types import ChannelId, check_class
from ampel.util.collections import ampel_iter
from ampel.util.freeze import recursive_unfreeze
from ampel.util.mappings import merge_dicts
from ampel.view.ReadOnlyDict import ReadOnlyDict
from ampel.base.AmpelUnit import AmpelUnit
from ampel.base.AuxUnitRegister import AuxUnitRegister
from ampel.base.LogicalUnit import LogicalUnit
from ampel.core.AmpelContext import AmpelContext
from ampel.core.ContextUnit import ContextUnit
from ampel.core.AmpelDB import AmpelDB
from ampel.model.UnitModel import UnitModel
from ampel.secret.Secret import Secret
from ampel.secret.AmpelVault import AmpelVault
from ampel.model.t3.AliasableModel import AliasableModel
from ampel.config.AmpelConfig import AmpelConfig
from ampel.log.AmpelLogger import AmpelLogger, LogFlag, VERBOSE
from ampel.log.handlers.ChanRecordBufHandler import ChanRecordBufHandler
from ampel.log.handlers.DefaultRecordBufferingHandler import DefaultRecordBufferingHandler
from ampel.util.hash import build_unsafe_dict_id


# Type variables used by the overloads of UnitLoader to tie the requested
# unit (sub)type to the type of the returned instance.
T = TypeVar('T', bound=AmpelUnit)
LT = TypeVar('LT', bound=LogicalUnit)
CT = TypeVar('CT', bound=ContextUnit)

# Version of the running interpreter formatted as "major.minor.micro"
pyv = ".".join(str(part) for part in sys.version_info[:3])

# Environment label: 'conda_<name>' when CONDA_DEFAULT_ENV is set, 'default' otherwise
env = (
    'default' if 'CONDA_DEFAULT_ENV' not in os.environ
    else 'conda_' + os.environ["CONDA_DEFAULT_ENV"]
)

class UnitLoader:
    """
    Instantiates ampel units from :class:`UnitModel` definitions, using the
    unit definitions ('unit') and the aliases ('alias') of an AmpelConfig.
    """

    def __init__(self,
        config: AmpelConfig,
        db: None | AmpelDB,
        provenance: bool = True,
        vault: None | AmpelVault = None
    ) -> None:
        """
        :param config: ampel configuration; its 'unit' and 'alias' sections are read here
        :param db: database handle, required if provenance is True
        :param provenance: default for trace id / conf id registration (see :meth:`new`)
        :param vault: used by :meth:`resolve_secrets` to resolve Secret fields
        :raises: ValueError in case bad arguments are provided
        """

        if not isinstance(config, AmpelConfig):
            raise ValueError(
                f"First parameter must be an instance of "
                f"AmpelConfig (provided: {type(config)})"
            )

        if provenance and not db:
            raise ValueError("Provenance tracking requires a database connection")

        self.db = db
        self.vault = vault
        self.config = config
        self.provenance = provenance
        self.unit_defs: dict = config._config['unit']

        # Alias dicts are searched in this order (t0, t3, t1, t2) by resolve_aliases
        self.aliases: list[dict] = [config._config['alias'][f"t{el}"] for el in (0, 3, 1, 2)]

        # potentially updated by DevAmpelContext
        self._dyn_register: None | dict[str, type[LogicalUnit] | type[ContextUnit]] = None


    @overload
    def new_logical_unit(self,
        model: UnitModel, logger: AmpelLogger, *, sub_type: type[LT], **kwargs
    ) -> LT:
        ...

    @overload
    def new_logical_unit(self,
        model: UnitModel, logger: AmpelLogger, *, sub_type: None = ..., **kwargs
    ) -> LogicalUnit:
        ...

    def new_logical_unit(self,
        model: UnitModel, logger: AmpelLogger, *, sub_type: None | type[LT] = None, **kwargs
    ) -> LT | LogicalUnit:
        """
        Logical units require logger and resource as init parameters, additionally
        to the potentially defined custom parameters which will be provided as a union
        of the model config and the kwargs provided to this method (the latter having precedence)

        :param sub_type: expected unit class (defaults to LogicalUnit)
        :raises: ValueError if the unit defined in the model is unknown
        """

        return self.new(
            model,
            unit_type = sub_type or LogicalUnit,
            logger = logger,
            resource = self.get_resources(model),
            **kwargs
        )


    def new_safe_logical_unit(self,
        um: UnitModel, unit_type: type[LT], logger: AmpelLogger, _chan: None | ChannelId = None
    ) -> LT:
        """
        Returns a logical unit with dedicated logger containing no db handler.
        The buffering handler of that logger is attached to the returned
        unit as attribute '_buf_hdlr'.
        """

        if logger.verbose:
            logger.log(VERBOSE, f"Instantiating unit {um.unit}")

        # Channel-specific buffering handler if a channel is provided, default one otherwise
        if _chan:
            buf_hdlr = ChanRecordBufHandler(logger.level, _chan, {'unit': um.unit})
        else:
            buf_hdlr = DefaultRecordBufferingHandler(logger.level, {'unit': um.unit})

        # Spawn unit instance
        unit = self.new_logical_unit(
            model = um,
            logger = AmpelLogger.get_logger(
                # Replace the CORE flag of the parent logger with UNIT
                base_flag = (getattr(logger, 'base_flag', 0) & ~LogFlag.CORE) | LogFlag.UNIT,
                console = len(logger.handlers) == 1, # to be improved later
                handlers = [buf_hdlr]
            ),
            sub_type = unit_type
        )

        setattr(unit, '_buf_hdlr', buf_hdlr) # Shortcut
        return unit


    @overload
    def new_context_unit(self,
        model: UnitModel, context: AmpelContext, *, sub_type: type[CT], **kwargs
    ) -> CT:
        ...

    @overload
    def new_context_unit(self,
        model: UnitModel, context: AmpelContext, *, sub_type: None = ..., **kwargs
    ) -> ContextUnit:
        ...

    def new_context_unit(self,
        model: UnitModel, context: AmpelContext, *, sub_type: None | type[CT] = None, **kwargs
    ) -> CT | ContextUnit:
        """
        Context units require an AmpelContext instance as init parameters,
        additionally to potentially defined dedicated custom parameters.

        :param sub_type: expected unit class (defaults to ContextUnit)
        :raises: ValueError if the unit defined in the model is unknown
        """

        return self.new(
            model,
            unit_type = sub_type or ContextUnit,
            context = context,
            **kwargs
        )


    @overload
    def new(self, model: UnitModel, *, unit_type: type[T], **kwargs) -> T:
        ...

    @overload
    def new(self, model: UnitModel, *, unit_type: None = ..., **kwargs) -> AmpelUnit:
        ...

    def new(self, model: UnitModel, *, unit_type: None | type[T] = None, **kwargs) -> AmpelUnit | T:
        """
        Instantiate new object based on provided model and kwargs.

        For LogicalUnit / ContextUnit instances, the attribute '_trace_id' is set on the unit:
        None if provenance is disabled, 0 if the trace id could not be computed or saved,
        the computed trace id otherwise. If the unit defines 'post_init', it is called
        before the unit is returned.

        :param unit_type: performs isinstance check and raise error on mismatch.
          Enables mypy/other static checks.
        :param kwargs: forwarded to the unit constructor, except '_provenance'
          which (if provided) overrides self.provenance for this call
        :raises: ValueError if model is not a UnitModel instance
        :returns: unit instance
        """

        if not isinstance(model, UnitModel):
            raise ValueError(f"Unexpected model: '{type(model)}'")

        provenance = kwargs.pop('_provenance', self.provenance)
        Klass = self.get_class_by_name(model.unit, unit_type) # type: ignore
        if unit_type:
            check_class(Klass, unit_type)

        init_config = self.get_init_config(model.config, model.override)
        unit = Klass(
            **self.resolve_secrets(
                Klass, init_config | kwargs | (model.secrets or {})
            )
        )

        if isinstance(unit, (LogicalUnit, ContextUnit)):

            trace_id = None

            # potentially sync trace_ids with DB (Ampel_ext)
            if provenance:

                assert self.db
                trace_dict = {
                    'py': pyv,
                    'unit': model.unit,
                    'digest': self.get_digest(Klass),
                    'version': self.config.get(f"unit.{model.unit}.version", str, raise_exc=True)
                }

                if c := unit._get_trace_content():
                    trace_dict['config'] = c

                if deps := self.config.get(f"unit.{model.unit}.dependencies"):
                    if not isinstance(deps, (list, tuple)):
                        raise ValueError(f"Retrieved environment is not a list/tuple: {type(deps)}")
                    envd = self.config.get(f"environment.{env}", dict, raise_exc=True)
                    trace_dict['env'] = {k: envd[k] for k in deps}

                try:
                    # Note: we could implement a hash collision detection mechanism here
                    trace_id = build_unsafe_dict_id(trace_dict, ret=int)

                    # Save trace id to external collection
                    if trace_id not in self.db.trace_ids:
                        trace_dict['_id'] = trace_id
                        self.db.add_trace_id(trace_id, trace_dict)

                # Non-serializable content.
                # Deliberately best-effort: any failure results in trace id 0
                except Exception:
                    trace_id = 0
                    # raise e

            unit._trace_id = trace_id # type: ignore[union-attr]

        if hasattr(unit, "post_init"):
            unit.post_init() # type: ignore[union-attr]

        return unit


    @staticmethod
    def get_digest(Klass: type) -> str:
        """
        :returns: the first 7 hex characters of the blake2b hash of the file
          defining the module of Klass, or "unspecified" if that file cannot be read
        """
        try:
            return blake2b(
                Path(sys.modules[Klass.__module__].__file__).read_bytes() # type: ignore
            ).hexdigest()[:7]
        except Exception:
            return "unspecified"


    @overload
    def get_class_by_name(self, name: str, unit_type: type[T]) -> type[T]:
        ...

    @overload
    def get_class_by_name(self, name: str, unit_type: None = ...) -> type[AmpelUnit]:
        ...

    def get_class_by_name(self, name: str, unit_type: None | type[T] = None) -> type[T | AmpelUnit]:
        """
        Matches the parameter 'name' with the unit definitions defined in the ampel_config.
        This allows to retrieve the corresponding fully qualified name of the class and to load it.

        Lookup order: AuxUnitRegister, dynamic register (if set), unit definitions of the config.

        :param unit_type:
          - LogicalUnit or any subclass of LogicalUnit
          - ContextUnit or any subclass of ContextUnit
          - If None (auxiliary class), returned object will have type[Any]

        :raises: ValueError if unit cannot be found or loaded or if parent class is unrecognized
        """

        if name in AuxUnitRegister._defs:
            return AuxUnitRegister.get_aux_class(name, sub_type=unit_type)

        if self._dyn_register and name in self._dyn_register:
            return self._dyn_register[name]

        if name not in self.unit_defs:
            raise ValueError(f"Ampel unit not found: {name}")

        fqn = self.unit_defs[name]['fqn']

        # Note: importlib.import_module caches internally imported modules
        return getattr(import_module(fqn), name)


    def get_init_config(self,
        config: None | int | str | dict[str, Any] = None,
        override: None | dict[str, Any] = None,
        kwargs: None | dict[str, Any] = None,
        unfreeze: bool = True
    ) -> dict[str, Any]:
        """
        Build the init config of a unit out of 'config', 'override' and 'kwargs'
        (combined, in that order, using merge_dicts).

        :param config: dict (aliases contained in it are resolved), alias (str) or conf id (int).
          Conf ids unknown to the ampel config are looked up in the 'conf_ids' collection of the db.
        :param unfreeze: whether ReadOnlyDict configs loaded by conf id should be unfrozen
        :raises: ValueError if config alias is not found
        """

        base_conf: dict[str, Any] = {}

        if isinstance(config, (dict, str)):
            base_conf = self.resolve_aliases(config)
            # resolve_aliases returns unknown strings unchanged:
            # a value that is still a string is thus an alias that could not be resolved
            if base_conf is None or isinstance(base_conf, str):
                raise ValueError(f"Config alias {config} not found")

        # Hashed t2 unit configs
        elif isinstance(config, int):

            try:
                d = self.config.get_conf_id(config)
            # confid not found (obsolete or dynamically generated by isolated process)
            except Exception:
                assert self.db
                docs = list(self.db.col_conf_ids.find({"_id": config}))
                if not docs:
                    raise
                del docs[0]['_id']
                d = docs[0]

            base_conf = recursive_unfreeze(d) if (unfreeze and isinstance(d, ReadOnlyDict)) else d

            # save un-registered (in ampel conf but not in db) confid to external collection for posterity
            if self.provenance:
                assert self.db
                if config not in self.db.conf_ids:
                    self.db.add_conf_id(config, base_conf)

        return merge_dicts([base_conf, override, kwargs]) or {}


    def resolve_aliases(self, value):
        """
        Recursively resolve aliases from config.

        Strings matching a key of one of the alias dicts are replaced by the
        (recursively resolved) aliased value, lists and dicts are traversed,
        any other value (unknown strings included) is returned unchanged.
        """
        if isinstance(value, str):
            for adict in self.aliases:
                if value in adict:
                    return self.resolve_aliases(adict[value])
            return value
        elif isinstance(value, list):
            return [self.resolve_aliases(v) for v in value]
        elif isinstance(value, dict):
            return {k: self.resolve_aliases(v) for k, v in value.items()}
        return value


    def resolve_secrets(self, unit_type: type[AmpelUnit], init_kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Add a resolved Secret instance to init_kwargs for every Secret field of unit_type.

        Note: init_kwargs is updated in place (and returned).

        :raises: ValueError if a secret must be resolved but no vault is configured,
          or if the vault cannot resolve a secret
        """

        for k, annotation in unit_type._annots.items():

            # for unions, consider the first member that is not NoneType
            if get_origin(annotation) in (Union, UnionType):
                annotation = next((f for f in get_args(annotation) if f is not type(None)), type(None))

            field_type = get_origin(annotation) or annotation

            if issubclass(type(field_type), type) and issubclass(field_type, Secret):

                default = False
                if isinstance(kwargs := init_kwargs.get(k), Mapping):
                    v = field_type(**kwargs)
                elif k in unit_type._defaults:
                    default = True
                    v = deepcopy(unit_type._defaults[k])
                else:
                    # missing required field; will be caught in validation later
                    continue

                ValueType = args[0] if (args := annotation.get_model_args()) else object
                if args:
                    assert ValueType is not object

                if not self.vault:
                    raise ValueError("No vault configured")

                if not self.vault.resolve_secret(v, ValueType):
                    raise ValueError(
                        f"Could not resolve {unit_type.__name__}.{k} as {getattr(ValueType, '__name__', '<untyped>')}"
                        f" using {'default' if default else 'configured'} value {repr(v)}"
                    )

                init_kwargs[k] = v

        return init_kwargs


    def get_resources(self, model: UnitModel) -> dict[str, Any]:
        """
        Resources are defined using the static variable 'require' in ampel units
        -> example: catsHTM.default

        :returns: dict mapping each required resource name to the value
          of 'resource.<name>' in the ampel config
        :raises: ValueError if a required resource is not defined in the config
        """

        resources: dict[str, Any] = {}
        Klass = self.get_class_by_name(model.unit)

        # Load possibly required global resources
        for k in ampel_iter(getattr(Klass, 'require', [])):

            if k is None:
                continue

            # Global resource example: extcat
            if (resource := self.config.get(f'resource.{k}')) is None:
                raise ValueError(f"Global resource not available: {k}")

            resources[k] = resource

        return resources


    @contextmanager
    def validate_unit_models(self) -> Iterator[None]:
        """
        Enable validation for UnitModel instances.

        Within the context, UnitModel.post_validate_hook validates the init config
        of the referenced unit and AliasableModel._config is set to the config of
        this loader. Both are reset to None on exit.
        """

        from ampel.abstract.AbsProcessController import AbsProcessController

        @staticmethod # type: ignore[misc]
        def validate_unit(value: UnitModel) -> UnitModel:
            Unit = self.get_class_by_name(value.unit)
            # Process controllers are excluded from validation
            if issubclass(Unit, AmpelUnit) and not issubclass(Unit, AbsProcessController):
                Unit.validate(self.get_init_config(value.config, value.override))
            return value

        UnitModel.post_validate_hook = validate_unit
        AliasableModel._config = self.config
        try:
            yield
        finally:
            UnitModel.post_validate_hook = None
            AliasableModel._config = None


    """
    def internal_mypy_tests_uncomment_only_in_your_editor(self,
        model: UnitModel, context: AmpelContext, logger: AmpelLogger,
        sub_type: None | type[CT] = None, **kwargs
    ) -> None:

        # Interal: uncomment to check if mypy works adequately
        from ampel.abstract.AbsEventUnit import AbsEventUnit
        from ampel.abstract.AbsLightCurveT2Unit import AbsLightCurveT2Unit

        reveal_type(self.new(model))
        reveal_type(self.new(model, bla=12))
        reveal_type(self.new(model, unit_type = None))
        reveal_type(self.new(model, unit_type=AbsLightCurveT2Unit))
        reveal_type(self.new(model, unit_type=AbsLightCurveT2Unit, bla=12))
        reveal_type(self.new(model, unit_type=AbsEventUnit))
        reveal_type(self.new(model, unit_type=AbsEventUnit, bla=12))

        reveal_type(self.new_logical_unit(model, logger))
        reveal_type(self.new_logical_unit(model, logger, bla=12))
        reveal_type(self.new_logical_unit(model, logger, sub_type = None))
        reveal_type(self.new_logical_unit(model, logger, sub_type=AbsLightCurveT2Unit))
        reveal_type(self.new_logical_unit(model, logger, sub_type = AbsLightCurveT2Unit, bla=12))

        # Next two lines *should* fail
        reveal_type(self.new_logical_unit(model, logger, sub_type=AbsEventUnit))
        reveal_type(self.new_logical_unit(model, logger, sub_type = AbsEventUnit, bla=12))

        reveal_type(self.new_context_unit(model, context))
        reveal_type(self.new_context_unit(model, context, bla=12))
        reveal_type(self.new_context_unit(model, context, sub_type = None))
        reveal_type(self.new_context_unit(model, context, sub_type = AbsEventUnit))
        reveal_type(self.new_context_unit(model, context, sub_type = AbsEventUnit, bla=12))

        # Next two lines *should* fail
        reveal_type(self.new_context_unit(model, context, sub_type = AbsLightCurveT2Unit))
        reveal_type(self.new_context_unit(model, context, sub_type = AbsLightCurveT2Unit, bla=12))
    """