Source code for secsgem.common.state_machine

#####################################################################
# state_machine.py
#
# (c) Copyright 2023, Benjamin Parzella. All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#####################################################################
"""State machine for connection state."""

from __future__ import annotations

import logging
import typing
from typing import Any

from .events import EventProducer

if typing.TYPE_CHECKING:
    import enum


[docs]class UnknownTransitionError(Exception): """Exception for unknown transition.""" def __init__(self, transition: str) -> None: """Initialize unknown transition exception. Args: transition: name of the transition """ super().__init__(f"Invalid transition: {transition}")
[docs]class WrongSourceStateError(Exception): """Exception for wrong source state for transition.""" def __init__(self, transition: str, expected: str, actual: str) -> None: """Initialize wrong source state exception. Args: transition: name of the transition expected: expected source state actual: actual source state """ super().__init__(f"Invalid source state for transition '{transition}': {actual} (expected {expected})")
[docs]class State: """State machine state class.""" def __init__( self, state: enum.Enum, name: str, parent: State | None = None, initial: bool = False, ) -> None: """Initialize state object. Args: state: state name: state name parent: parent state initial: is initial state """ self._state = state self._name = name self._parent = parent self._active = initial self._event_producer = EventProducer() @property def state(self) -> enum.Enum: """Get the connection state for this state.""" return self._state @property def name(self) -> str: """Get the name for this state.""" return self._name @property def events(self) -> EventProducer: """Property for event handling.""" return self._event_producer @property def parent(self) -> State | None: """Get parent state if available.""" return self._parent @property def active(self) -> bool: """Get if the state is active.""" return self._active
[docs] def enter(self, source: State | None): """Enter the state. Args: source: state to enter from. """ self.events.fire("enter", {}) self._active = True if self.parent is not None and (source is None or source.parent != self.parent): self.parent.enter(source.parent if source is not None else None)
[docs] def leave(self, destination: State | None): """Leave the state. Args: destination: state to enter from. """ self.events.fire("leave", {}) self._active = False if self.parent is not None and (destination is None or destination.parent != self.parent): self.parent.leave(destination.parent if destination is not None else None)
[docs]class Transition: """State machine transition class.""" def __init__( self, name: str, sources: State | list[State], destination: State, ) -> None: """Initialize transition object. Args: name: name of the transition sources: source states allowed for transition destination: destination state for transition """ self._name = name self._sources = sources if isinstance(sources, list) else [sources] self._destination = destination self._event_producer = EventProducer() @property def name(self) -> str: """Get transition name.""" return self._name @property def events(self) -> EventProducer: """Property for event handling.""" return self._event_producer @property def sources(self) -> list[State]: """Get the allowed source states for the transition.""" return self._sources @property def destination(self) -> State: """Get the destination state for the transition.""" return self._destination def __call__(self) -> Any: """Call the transition.""" self.events.fire("called", {})
[docs]class StateMachine: """Base state machine.""" def __init__(self) -> None: """Initialize state machine.""" self._current_state: State self._transitions: list[Transition] self._logger = logging.getLogger(self.__module__ + "." + self.__class__.__name__) @property def current(self) -> enum.Enum: """Get the current state enum.""" return self._current_state.state @property def current_state(self) -> State: """Get the current state.""" return self._current_state
[docs] def transition(self, name: str) -> Transition: """Get the object for a specific transition. Args: name: transition name """ value = next((transition for transition in self._transitions if transition.name == name), None) if value is None: raise UnknownTransitionError(name) return value
def _perform_transition(self, name: str) -> None: """Perform a transition. Args: name: transition name """ transition = self.transition(name) if self._current_state not in transition.sources: raise WrongSourceStateError( name, "/".join([state.name for state in transition.sources]), self._current_state.name, ) self._logger.debug("State change: %s >> %s", self._current_state.name, transition.destination.name) self._current_state.leave(transition.destination) old_state = self._current_state self._current_state = transition.destination transition.destination.enter(old_state) transition()