#####################################################################
# protocol.py
#
# (c) Copyright 2023, Benjamin Parzella. All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#####################################################################
"""protocol base class."""
from __future__ import annotations
import abc
import logging
import queue
import random
import typing
from .block_send_info import BlockSendInfo
from .byte_queue import ByteQueue
from .events import EventProducer
from .protocol_dispatcher import ProtocolDispatcher
if typing.TYPE_CHECKING:
from secsgem.secs.functions.base import SecsStreamFunction
from .connection import Connection
from .message import Block, Message
from .settings import Settings
# Generic parameter for the message class a concrete protocol works with.
# The bound is a forward reference because ``Message`` is only imported
# under ``typing.TYPE_CHECKING``.
MessageT = typing.TypeVar(
    "MessageT",
    bound="Message",
)

# Generic parameter for the block class a concrete protocol works with
# (``Block`` is likewise only imported for type checking).
BlockT = typing.TypeVar(
    "BlockT",
    bound="Block",
)
class Protocol(abc.ABC, typing.Generic[MessageT, BlockT]):  # pylint: disable=too-many-instance-attributes
    """Abstract base class for a protocol.

    This base class owns the lazily created connection object, the receive
    buffer, the send queue, the reassembly of multi-block messages and the
    per-system-id response queues used by ``send_and_waitfor_response``.
    Subclasses implement the connection event handlers, the send/receive
    queue processing, message creation and serialization.
    """

    # Message class used to build a message from its first received block
    # (see ``_add_message_block``); expected to be set by subclasses.
    message_type: type[MessageT]

    def __init__(self, settings: Settings) -> None:
        """Initialize protocol base object.

        Args:
            settings: settings object

        """
        super().__init__()

        self._settings = settings

        self._event_producer = EventProducer()
        self._event_producer.targets += self

        # Start the system counter at a random value (S311 is suppressed: a
        # non-cryptographic RNG is deliberately used here).
        self._system_counter = random.randint(0, (2**32) - 1)  # noqa: S311

        self._logger = logging.getLogger(self.__module__ + "." + self.__class__.__name__)
        self._communication_logger = logging.getLogger("communication")

        # Created on first access of the ``_connection`` property.
        self.__connection: Connection | None = None

        # system id -> queue that receives the response for that system id
        self._response_queues: dict[int, queue.Queue[MessageT]] = {}

        self._receive_buffer = ByteQueue()
        self._send_queue: queue.Queue[BlockSendInfo] = queue.Queue()

        # system id -> message that is still missing blocks
        self._incomplete_messages: dict[int, MessageT] = {}

        self._thread = ProtocolDispatcher(
            self._process_data,
            self._dispatch_block,
            self._settings,
        )

    @property
    def _connection(self) -> Connection:
        """Get the connection object.

        If no connection object is available, a new one is created from the
        settings and this protocol's handlers are registered for its
        connected, data, disconnecting and disconnected events.

        Returns:
            connection object

        """
        if self.__connection is None:
            self.__connection = self._settings.create_connection()
            self.__connection.on_connected.register(self._on_connected)
            self.__connection.on_data.register(self._on_connection_data_received)
            self.__connection.on_disconnecting.register(self._on_disconnecting)
            self.__connection.on_disconnected.register(self._on_disconnected)

        return self.__connection

    @abc.abstractmethod
    def _on_connected(self, _: dict[str, typing.Any]):
        """Abstract method called when connection is established.

        The argument is a dictionary with the following keys
            - source: connection object that triggered the event
        """
        raise NotImplementedError("Protocol._on_connected missing implementation")

    @abc.abstractmethod
    def _on_disconnecting(self, _: dict[str, typing.Any]):
        """Abstract method called when connection is closing.

        The argument is a dictionary with the following keys
            - source: connection object that triggered the event
        """
        raise NotImplementedError("Protocol._on_disconnecting missing implementation")

    @abc.abstractmethod
    def _on_disconnected(self, _: dict[str, typing.Any]):
        """Abstract method called when connection is closed.

        The argument is a dictionary with the following keys
            - source: connection object that triggered the event
        """
        raise NotImplementedError("Protocol._on_disconnected missing implementation")

    def _on_connection_data_received(self, data: dict[str, typing.Any]):
        """Handle data received by the connection.

        The `data` argument is a dictionary with the following keys
            - source: connection object that triggered the event
            - data: received data

        The received data is appended to the receive buffer and the dispatcher
        is triggered so that the buffer gets processed.

        Args:
            data: event data dictionary (see above)

        """
        self._receive_buffer.append(data["data"])
        self._thread.trigger_receiver()

    def _process_data(self):
        """Process input and output data.

        First, the send queue is processed, then the received data is processed.
        """
        self._process_send_queue()
        self._process_received_data()

    @abc.abstractmethod
    def _process_send_queue(self):
        """Process the send to communication queue."""
        raise NotImplementedError("Protocol._process_send_queue missing implementation")

    @abc.abstractmethod
    def _process_received_data(self):
        """Process the receive from communication queue."""
        raise NotImplementedError("Protocol._process_received_data missing implementation")

    def _dispatch_block(self, source: Protocol, block: BlockT):
        """Add a received block and dispatch the message once it is complete.

        Args:
            source: source of event
            block: block to dispatch

        """
        message = self._add_message_block(block)
        if message is None:
            # Message is not complete yet; more blocks are expected.
            return

        try:
            self._on_connection_message_received(source, message)
        except Exception:  # pylint: disable=broad-except
            # Deliberately best-effort: a failing handler is logged instead of
            # propagating to the caller (presumably the ProtocolDispatcher,
            # which this method is passed to).
            self._logger.exception("ignoring exception for on_connection_message_received handler")

    @property
    def events(self):
        """Property for event handling.

        Returns:
            the event producer of this protocol

        """
        return self._event_producer

    def get_next_system_counter(self) -> int:
        """Return the next System.

        The counter is incremented and wraps around to 0 after 2**32 - 1.

        Returns:
            System for the next command

        """
        self._system_counter += 1

        if self._system_counter > ((2**32) - 1):
            self._system_counter = 0

        return self._system_counter

    def enable(self):
        """Enable the connection (creating it first if necessary)."""
        self._connection.enable()

    def disable(self):
        """Disable the connection (creating it first if necessary)."""
        self._connection.disable()

    @abc.abstractmethod
    def serialize_data(self) -> dict[str, typing.Any]:
        """Get protocol serialized data for debugging.

        Returns:
            data to serialize for this object

        """
        raise NotImplementedError("Protocol.serialize_data missing implementation")

    @abc.abstractmethod
    def _on_connection_message_received(self, source: Protocol, message: MessageT):
        """Message received by connection.

        Args:
            source: source of event
            message: received data message

        """
        raise NotImplementedError("Protocol._on_connection_message_received missing implementation")

    def _get_queue_for_system(self, system_id: int) -> queue.Queue[MessageT]:
        """Create a new queue to receive responses for a certain system.

        An existing queue for the same system id is replaced.

        Args:
            system_id: system id to watch

        Returns:
            queue to receive responses with

        """
        response_queue: queue.Queue[MessageT] = queue.Queue()
        self._response_queues[system_id] = response_queue
        return response_queue

    def _remove_queue(self, system_id: int):
        """Remove queue for system id from list.

        Removing a system id that has no queue is a no-op.

        Args:
            system_id: system id to remove

        """
        self._response_queues.pop(system_id, None)

    def _add_message_block(self, block: BlockT) -> MessageT | None:
        """Add a block, and get completed message if available.

        The first block for a system id creates a new message via
        ``message_type.from_block``; later blocks are appended to it.

        Args:
            block: block to add

        Returns:
            completed message or None if packet not complete

        """
        system_id = block.header.system

        message = self._incomplete_messages.get(system_id)
        if message is None:
            message = self.message_type.from_block(block)
            self._incomplete_messages[system_id] = message
        else:
            message.blocks.append(block)

        if not message.complete:
            return None

        # Message is complete; stop tracking it as incomplete.
        del self._incomplete_messages[system_id]
        return message

    @abc.abstractmethod
    def _create_message_for_function(
        self,
        function: SecsStreamFunction,
        system_id: int,
    ) -> Message:
        """Create a protocol specific message for a function.

        Args:
            function: function to create message for
            system_id: system

        Returns:
            created message

        """
        raise NotImplementedError("Protocol._create_message_for_function missing implementation")

    def send_message(self, message: Message) -> bool:
        """Send a message to the remote host.

        Blocks are queued one at a time; each block must be reported as sent
        before the next one is queued.

        Args:
            message: message to be transmitted

        Returns:
            True if sending was successful

        """
        for block in message.blocks:
            block_send_info = BlockSendInfo(block.encode())
            self._send_queue.put(block_send_info)
            # Wake the dispatcher so the send queue gets processed.
            self._thread.trigger_receiver()

            if not block_send_info.wait():
                return False

        return True

    def send_and_waitfor_response(self, function: SecsStreamFunction) -> Message | None:
        """Send the message and wait for the response.

        The response queue for the new system id is registered before sending
        so that a fast reply is not missed, and is always removed afterwards,
        also when creating, logging or sending the message raises.

        Args:
            function: message to be sent

        Returns:
            Message that was received, or None if sending failed or no
            response arrived within the T3 timeout

        """
        system_id = self.get_next_system_counter()

        response_queue = self._get_queue_for_system(system_id)
        try:
            out_message = self._create_message_for_function(function, system_id)

            self._communication_logger.info("> %s\n%s", out_message, function, extra=self._get_log_extra())

            if not self.send_message(out_message):
                self._logger.error("Sending message failed")
                return None

            try:
                return response_queue.get(True, self._settings.timeouts.t3)
            except queue.Empty:
                return None
        finally:
            self._remove_queue(system_id)

    def send_response(self, function: SecsStreamFunction, system: int) -> bool:
        """Send response function for system.

        Args:
            function: function to be sent
            system: system to reply to

        Returns:
            True if sending was successful

        """
        out_message = self._create_message_for_function(function, system)

        self._communication_logger.info("> %s\n%s", out_message, function, extra=self._get_log_extra())

        return self.send_message(out_message)

    def send_stream_function(self, function: SecsStreamFunction) -> bool:
        """Send the message with a new system id without waiting for a response.

        Args:
            function: message to be sent

        Returns:
            True if successful

        """
        out_message = self._create_message_for_function(function, self.get_next_system_counter())

        self._communication_logger.info("> %s\n%s", out_message, function, extra=self._get_log_extra())

        return self.send_message(out_message)

    def __repr__(self) -> str:
        """Generate textual representation for an object of this class."""
        return f"{self.__class__.__name__} {self.serialize_data()}"

    @abc.abstractmethod
    def _get_log_extra(self) -> dict[str, typing.Any]:
        """Get extra fields for logging."""
        raise NotImplementedError("Protocol._get_log_extra missing implementation")