from typing import Union, List, Tuple
from abc import ABC, abstractmethod
from copy import deepcopy

import numpy as np
import networkx as nx

class UpdateMethod(ABC):
    """Strategy interface describing how one attribute of a LoggerVariable changes.

    Implementations receive the LoggerVariable, the name of the attribute to
    update on it (``"history"`` or ``"best"`` in this module) and the newly
    logged value.
    """

    @abstractmethod
    def update(
        self,
        logger_variable: "LoggerVariable",
        attribute_name: str,
        new_value: Union[list, int, float, str],
    ) -> None:
        """Update ``logger_variable.<attribute_name>`` using ``new_value``."""
        raise NotImplementedError

class LoggerVariable(object):
    """A single logged quantity with a ``history`` list and a ``best`` list.

    How each list changes when a new value arrives is delegated to two
    UpdateMethod strategies: one for ``history`` and one for ``best``.

    Parameters
    ----------
    attribute_name: `str`
        Name of the logged quantity, stored as ``self.name``.
    history_update_method: UpdateMethod
        Strategy applied to ``self.history`` on every update.
    best_update_method: UpdateMethod
        Strategy applied to ``self.best`` on every update.
    """

    def __init__(
        self,
        attribute_name: str,
        history_update_method: "UpdateMethod",
        best_update_method: "UpdateMethod",
    ):
        self.name = attribute_name
        # Both records start empty; the update strategies fill them.
        self.best = []
        self.history = []

        self.history_update_method = history_update_method
        self.best_update_method = best_update_method

    def update(self, new_value):
        """Record ``new_value`` in both the history and the best record."""
        self.update_history(new_value)
        self.update_best(new_value)

    def update_history(self, new_value):
        """Apply the history strategy to ``self.history`` with ``new_value``."""
        return self.history_update_method.update(self, "history", new_value)

    def update_best(self, new_value):
        """Apply the best strategy to ``self.best`` with ``new_value``."""
        return self.best_update_method.update(self, "best", new_value)

class AppendValue(UpdateMethod):
    """Append every new value to the list stored on the attribute."""

    @staticmethod
    def update(
        logger_variable: "LoggerVariable",
        attribute_name: str,
        new_value: Union[list, int, float, str],
    ) -> None:
        """Append ``new_value`` to ``logger_variable.<attribute_name>`` in place.

        Declared static so it works both when invoked on the class
        (``AppendValue.update(var, "history", value)``, as this module does)
        and when invoked on an instance.
        """
        getattr(logger_variable, attribute_name).append(new_value)

class ReplaceValue(UpdateMethod):
    """Keep only the latest value: the attribute becomes a one-element list."""

    @staticmethod
    def update(
        logger_variable: "LoggerVariable",
        attribute_name: str,
        new_value: Union[list, int, float, str],
    ) -> None:
        """Set ``logger_variable.<attribute_name>`` to ``[new_value]``.

        Declared static so it can be invoked on the class or on an instance;
        the first argument is the LoggerVariable being updated.
        """
        setattr(logger_variable, attribute_name, [new_value])

class EmptyValue(UpdateMethod):
    """Discard values: the attribute is reset to an empty list on every update."""

    @staticmethod
    def update(
        logger_variable: "LoggerVariable",
        attribute_name: str,
        new_value: Union[list, int, float, str],
    ) -> None:
        """Set ``logger_variable.<attribute_name>`` to ``[]``; ``new_value`` is ignored.

        Declared static so it can be invoked on the class or on an instance;
        the first argument is the LoggerVariable being updated.
        """
        setattr(logger_variable, attribute_name, [])

class IfLowerDo(UpdateMethod):
    """Apply a wrapped update method only when the new value is lower.

    The new value is compared against the last element currently stored on the
    attribute. If the attribute is still empty, the value is appended.
    """

    def __init__(self, update_method: UpdateMethod):
        # Strategy applied when the comparison succeeds; its ``update`` is
        # called as ``update(logger_variable, attribute_name, new_value)``.
        self._update_method = update_method

    def update(
        self,
        logger_variable: "LoggerVariable",
        attribute_name: str,
        new_value: Union[int, float],
    ) -> None:
        """Update ``logger_variable.<attribute_name>`` if ``new_value`` is lower
        than its last stored element (or if nothing is stored yet)."""
        try:
            old_value = getattr(logger_variable, attribute_name)[-1]
        except IndexError:
            # Nothing recorded yet: the first value is always kept.
            AppendValue.update(logger_variable, attribute_name, new_value)
            return
        # Kept outside the try so errors from the wrapped update are not
        # mistaken for an empty attribute.
        if new_value < old_value:
            self._update_method.update(logger_variable, attribute_name, new_value)

class IfHigherDo(UpdateMethod):
    """Apply a wrapped update method only when the new value is higher.

    The new value is compared against the last element currently stored on the
    attribute. If the attribute is still empty, the value is appended.
    """

    def __init__(self, update_method: UpdateMethod):
        # Strategy applied when the comparison succeeds; its ``update`` is
        # called as ``update(logger_variable, attribute_name, new_value)``.
        self._update_method = update_method

    def update(
        self,
        logger_variable: "LoggerVariable",
        attribute_name: str,
        new_value: Union[int, float],
    ) -> None:
        """Update ``logger_variable.<attribute_name>`` if ``new_value`` is higher
        than its last stored element (or if nothing is stored yet)."""
        try:
            old_value = getattr(logger_variable, attribute_name)[-1]
        except IndexError:
            # Nothing recorded yet: the first value is always kept.
            AppendValue.update(logger_variable, attribute_name, new_value)
            return
        # Kept outside the try so errors from the wrapped update are not
        # mistaken for an empty attribute.
        if new_value > old_value:
            self._update_method.update(logger_variable, attribute_name, new_value)

class LoggerVariableFactory(object):
    """Create LoggerVariable objects from simple configuration values.

    ``history_bool_mapping`` maps ``str(history_update_bool)`` to the strategy
    used for a variable's history ("True" keeps every value, "False" keeps
    none). ``best_string_mapping`` maps a keyword to the strategy used for the
    variable's best record.
    """

    history_bool_mapping = {"True": AppendValue, "False": EmptyValue}
    best_string_mapping = {
        "HighestOnly": IfHigherDo(ReplaceValue),
        "HighestSoFar": IfHigherDo(AppendValue),
        "LowestOnly": IfLowerDo(ReplaceValue),
        "LowestSoFar": IfLowerDo(AppendValue),
        "Replace": ReplaceValue,
        "Append": AppendValue,
    }

    @classmethod
    def create_logger_variable(
        cls, attribute_name: str, history_update_bool: bool, best_update_string: str
    ) -> LoggerVariable:
        """Build a LoggerVariable named ``attribute_name``.

        A class method so it can be called on the class itself
        (``LoggerVariableFactory.create_logger_variable(...)``) as well as on an
        instance.

        Raises
        ------
        KeyError
            If ``history_update_bool`` is not True/False (or "True"/"False"), or
            ``best_update_string`` is not a key of ``best_string_mapping``.
        """
        history_key = str(history_update_bool)
        if history_key not in cls.history_bool_mapping:
            raise KeyError(
                f"history_update_bool must be True or False, got {history_update_bool!r}"
            )
        if best_update_string not in cls.best_string_mapping:
            raise KeyError(
                f"unknown best_update_string {best_update_string!r}; expected one of "
                f"{sorted(cls.best_string_mapping)}"
            )

        history_update_method = cls.history_bool_mapping[history_key]
        best_update_method = cls.best_string_mapping[best_update_string]

        return LoggerVariable(attribute_name, history_update_method, best_update_method)

class Logger(object):
    """Container of LoggerVariables whose best values follow an update graph.

    Each key of ``initialisation_variables`` becomes an attribute of the logger
    holding a LoggerVariable. ``logger_update_structure`` describes, as a
    NetworkX directed graph, the order in which best values are updated.

    Parameters
    ----------
    initialisation_variables: `dict`
        Maps attribute name -> dict with keys ``"history_update_bool"`` and
        ``"best_update_string"``.
    logger_update_structure: `dict`
        Dict with keys ``"root_nodes"`` and ``"best_update_structure"``; see
        ``_create_logger_update_structure``.
    """

    def __init__(self, initialisation_variables: dict, logger_update_structure: dict):
        for attribute_name, attribute_properties in initialisation_variables.items():
            self._create_variable(
                attribute_name=attribute_name,
                history_update_bool=attribute_properties["history_update_bool"],
                best_update_string=attribute_properties["best_update_string"],
            )

        self._create_logger_update_structure(
            logger_update_structure["root_nodes"],
            logger_update_structure["best_update_structure"],
        )

    def _create_variable(
        self, attribute_name: str, history_update_bool: bool, best_update_string: str
    ):
        """Create a LoggerVariable and store it as ``self.<attribute_name>``.

        If a variable is created with this method after initialisation, make
        sure the logger update structure is extended to include it; otherwise
        its best value is never updated by ``log_variables``.
        """
        # Calling through an instance works whether ``create_logger_variable``
        # is defined as an instance method or as a class method.
        factory = LoggerVariableFactory()
        setattr(
            self,
            attribute_name,
            factory.create_logger_variable(
                attribute_name, history_update_bool, best_update_string
            ),
        )

    def _create_logger_update_structure(
        self, root_nodes: List[str], best_update_relations: List[List[str]]
    ):
        """Build the directed graph that drives best-value updates.

        Parameters
        ----------
        root_nodes: `list`
            Starting points of the update graph. The Logger Variables that must
            be updated first are named in this list.
        best_update_relations: `list`
            Pairs of names. If updating B depends on A (A changes, then update
            B), the relationship is written as ``['A', 'B']``.
        """
        self.best_update_structure = nx.DiGraph()
        self.root_nodes = root_nodes
        self.update_edges = best_update_relations

    @property
    def update_edges(self):
        """Edge list the update graph was built from."""
        return self._update_edges

    @update_edges.setter
    def update_edges(self, update_edges: List[List[str]]):
        self.best_update_structure.add_edges_from(update_edges)
        self._update_edges = update_edges

    @property
    def root_nodes(self):
        """Names of the variables where best-value updates start."""
        return self._root_nodes

    @root_nodes.setter
    def root_nodes(self, root_nodes: List[str]):
        self.best_update_structure.add_nodes_from(root_nodes)
        self._root_nodes = root_nodes

    def log_variables(self, input_dict: dict):
        """Log new values into the logger variables.

        Every entry of ``input_dict`` is added to the history of the variable of
        the same name. Best values are then updated by walking the update graph
        level by level from ``self.root_nodes``. A node is updated only when it
        appears in ``input_dict`` and every one of its predecessors changed its
        best value during this call. Variables absent from the graph never get
        their best value updated.

        Parameters
        ----------
        input_dict: `dict`
            Maps attribute name -> new value. Every attribute must already
            exist on the logger.
        """
        for attribute_name, value in input_dict.items():
            self._log_history(attribute_name, value)

        graph = self.best_update_structure
        # node name -> whether its best value changed during this call
        change_dict = {}
        # Each node's best value is updated at most once per call. This also
        # keeps cyclic update graphs from looping forever.
        processed = set()
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        node_list = list(dict.fromkeys(self.root_nodes))

        while node_list:
            next_nodes = []
            for node in node_list:
                if node in processed:
                    continue
                # Wait until all predecessors have changed. The node is queued
                # again whenever one of its predecessors changes.
                if not all(change_dict.get(pred, False) for pred in graph.pred[node]):
                    continue
                processed.add(node)

                changed = False
                if node in input_dict:
                    logged_var = getattr(self, node)
                    # Snapshot before updating: the best list may be mutated
                    # in place by its update strategy.
                    old_best = list(logged_var.best)
                    self._log_best(node, input_dict[node])
                    changed = self._best_changed(old_best, logged_var.best)

                change_dict[node] = changed
                # Only successors of changed nodes are worth visiting next.
                if changed:
                    next_nodes.extend(graph.adj[node])
            node_list = list(dict.fromkeys(next_nodes))

    @staticmethod
    def _best_changed(old_best: list, new_best: list) -> bool:
        """Return True when two best-value lists differ.

        NumPy arrays are compared with ``np.array_equal`` because ``==`` on
        arrays is element-wise and cannot be used directly as a truth value.
        """
        if len(old_best) != len(new_best):
            return True
        for old, new in zip(old_best, new_best):
            if isinstance(old, np.ndarray) or isinstance(new, np.ndarray):
                if not np.array_equal(old, new):
                    return True
            elif old != new:
                return True
        return False

    def _log_history(self, attribute_name: str, attribute_value):
        """Apply the history update of variable ``attribute_name``."""
        attr_to_update = getattr(self, attribute_name)
        attr_to_update.update_history(attribute_value)

    def _log_best(self, attribute_name: str, attribute_value):
        """Apply the best-value update of variable ``attribute_name``."""
        attr_to_update = getattr(self, attribute_name)
        attr_to_update.update_best(attribute_value)