Files
Eric F d398a6ced2 Add extracted tools: CitrineOS, OpenOCPP, ShapeShifter
- CitrineOS core extracted (CSMS OCPP 2.0.1)
- OpenOCPP extracted (firmware OCPP 1.6J/2.0.1)
- ShapeShifter library installed (pip install -e)
- ShapeShifter specification extracted
- EVerest extracted

TODO updated with progress
2026-06-08 00:38:27 -04:00

311 lines
12 KiB
Python

import re
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from time import sleep
import uvicorn
from fastapi import FastAPI, Response
from fastapi.exceptions import HTTPException
from fastapi_xml import XmlAppResponse, XmlRoute
from .. import transport
from ..client import client_map
from ..exceptions import (
FunctionalException,
InvalidMessageException,
InvalidSenderException,
TransportException,
)
from ..logging import logger
from ..uftp import (
AcceptedRejected,
PayloadMessage,
SignedMessage,
TestMessage,
TestMessageResponse,
UsefRole,
request_response_map,
)
class ShapeshifterService():
"""
Basis for all Shapeshifter Services. Defines the web service, the
message ingestion, the post-processing queue mechanics, and
threading and context options.
"""
sender_domain = None
sender_role = None
acceptable_messages = []
num_inbound_threads = 10
num_outbound_threads = 10
def __init__(
self,
sender_domain,
signing_key,
key_lookup_function=None,
endpoint_lookup_function=None,
oauth_lookup_function=None,
host: str = "0.0.0.0",
port: int = 8080,
path: str = "/shapeshifter/api/v3/message",
version: str = "3.1.0"
):
"""
:param sender_domain: our sender domain (FQDN) that the recipient uses to look us up.
:param signing_key: the private singing key that we use to sign outgoing messages.
:param key_lookup_function: A callable that takes a (sender_domain, sender_role)
pair and returns a verify_key (str or bytes).
Omit parameter to use DNS for key lookup.
:param key_lookup_function: A callable that takes a (sender_domain, sender_role)
pair and returns a full endpoint URL (str).
Omit parameter to use DNS for endpoint lookup.
:param oauth_lookup_function: A callable that takes a (sender_domain, sender_role)
pair and returns in instance of shapeshifter_uftp.OAuthClient
if OAuth authentication is required.
:param host: the host to bind the server to (usually 127.0.0.1 or 0.0.0.0)
:param port: the port to bind the server to (default: 8080)
:param path: the URL path that the server listens on (default: /shapeshifter/api/v3/message)
"""
if version not in ("3.0.0", "3.1.0"):
raise ValueError(f"'version' must be one of '3.0.0' or '3.1.0', not {version}")
self.version = version
# Set the sender domain, which is used
# to identify us to the other party.
self.sender_domain = sender_domain
# The signing key is used to sign outgoing messages. The
# corresponding public key should be published via DNS or
# given to the recipient out-of-band.
self.signing_key = signing_key
# The key lookup method is used to look up keys for the other
# party. If omitted, use DNS lookups using well-known DNS names.
self.key_lookup_function = key_lookup_function or transport.get_key
# The endpoint lookup method is used to look up the endpoint
# that we send all messages to. If omitted, use DNS lookups
# using well-known DNS names.
self.endpoint_lookup_function = endpoint_lookup_function or transport.get_endpoint
# The OAuth lookup function is used to get the OAuth instance
# used to authenticate outgoing requests.
self.oauth_lookup_function = oauth_lookup_function
# The FastAPI web app takes care of routing messages to the
# (one) endpoint, and by virtue of FastAPI-XML convert the
# python-friendly objects into XML and vice versa.
self.app = FastAPI(default_response_class=XmlAppResponse)
self.app.router.route_class = XmlRoute
self.app.router.add_api_route(
path,
endpoint=self._receive_message,
response_model=None,
methods=["POST"],
status_code=200,
)
# The web server hosts the FastAPI application and takes care
# of the HTTP transport.
config = uvicorn.Config(app=self.app, host=host, port=port)
self.server = uvicorn.Server(config)
self.server_thread = None
# Create an inbound executor worker
self.inbound_executor = ThreadPoolExecutor(max_workers=self.num_inbound_threads)
self.outbound_executor = ThreadPoolExecutor(max_workers=self.num_outbound_threads)
def run(self):
"""
Start the web server that hosts the FastAPI application. Other
participants can now send messages to us.
"""
# Start the service and start accepting incoming requests.
self.server.run()
def run_in_thread(self):
"""
Run the service in a background thread.
"""
self.server_thread = Thread(target=self.run)
self.server_thread.start()
while not self.server.started:
sleep(0.1)
def stop(self):
"""
Stop the service if it was running in a separate thread.
"""
self.server.should_exit = True
if self.server_thread and self.server_thread.is_alive():
self.server_thread.join()
self.server_thread = None
# ------------------------------------------------------------ #
# Message handling and processing methods, internal to the #
# Shapeshifter UFTP implementation. #
# ------------------------------------------------------------ #
def _receive_message(self, message: SignedMessage) -> Response:
"""
The default entrypoint for the route. This will unpack the
message and validate the signature. It will thes pass the
PayloadMessago to the pre processing function for a
response.
"""
logger.info(f"Got a request: {message}")
# Get the public key that is used to decrypt the message
signing_key = self.key_lookup_function(
message.sender_domain, message.sender_role.value
)
logger.debug(f"The signing key is {signing_key}")
# Unseal the message, returning an error if required
try:
unsealed_message = transport.unseal_message(message.body, signing_key)
# Verify that the sender_domain inside the message is the
# same as the sender_domain of the SignedMessage
# envelope
if unsealed_message.sender_domain != message.sender_domain:
logger.warning(
"Received a message with mismatching sender_domain in the "
f"SignedMessage envelope ({message.sender_domain}) and the "
f"inner PayloadMessage ({unsealed_message.sender_domain})"
)
raise InvalidSenderException()
if type(unsealed_message) not in self.acceptable_messages:
logger.warning(
f"Received a misdirected message of type {unsealed_message.__class__.__name__} "
f"from {unsealed_message.sender_domain}.")
raise InvalidMessageException(unsealed_message)
except TransportException as err:
logger.warning(f"The original transport error is {err.__class__.__name__}: {err}")
raise HTTPException(err.http_status_code) from err
except FunctionalException as err:
self.outbound_executor.submit(self._reject_message, message, unsealed_message, err.rejection_reason)
else:
# If the initial checks passed, process the message in the
# user-defined pipeline.
self.inbound_executor.submit(self._process_message, unsealed_message, message.sender_role)
return Response(status_code=200)
def _process_message(self, message: PayloadMessage, sender_role: UsefRole):
"""
Find the relevant post-processing method to handle the message
outside of the request context, and run it.
"""
process_method_name = f"process_{snake_case(message.__class__.__name__)}"
process_method = getattr(self, process_method_name)
try:
if isinstance(message, TestMessage):
process_method(message, sender_role)
else:
process_method(message)
except Exception as err: # pylint: disable=broad-exception-caught
logger.error(
f"An error occurred during the post-processing of a {message.__class__.__name__} message."
f"{err.__class__.__name__}: {err}"
)
def _get_client(self, recipient_domain: str, recipient_role: UsefRole, version: str = "3.1.0"):
"""
Method to get a relevant client to communicate to the
indicated participant.
"""
client_cls = client_map[(self.sender_role, recipient_role)]
recipient_endpoint = self.endpoint_lookup_function(recipient_domain, recipient_role)
recipient_signing_key = self.key_lookup_function(recipient_domain, recipient_role)
oauth_client = self.oauth_lookup_function(recipient_domain, recipient_role) if self.oauth_lookup_function else None
return client_cls(
sender_domain = self.sender_domain,
signing_key = self.signing_key,
recipient_domain = recipient_domain,
recipient_endpoint = recipient_endpoint,
recipient_signing_key = recipient_signing_key,
oauth_client = oauth_client,
version=version,
)
def _reject_message(self, message, unsealed_message, reason):
"""
Send a rejection to the sending party.
"""
if type(unsealed_message) not in request_response_map:
return
client = self._get_client(message.sender_domain, message.sender_role, unsealed_message.version)
response_type = request_response_map[type(unsealed_message)]
response_id_field = snake_case(type(unsealed_message).__name__) + "_message_id"
message_contents = {
"recipient_domain": message.sender_domain,
"conversation_id": unsealed_message.conversation_id,
"result": AcceptedRejected.REJECTED,
"rejection_reason": reason,
response_id_field: unsealed_message.message_id
}
response_message = response_type(**message_contents)
client._send_message(response_message)
def __enter__(self):
"""
Context-manager method that allows an instance of this class to be
used in a temporary context.
Starts the uvicorn server in a separate thread and waits for it
to be started. Useful in testing scenarios.
Usage:
with MyShapeshifterServiceSubClass(...) as server:
# your test code here, server exists cleanly when leaving
# the 'with' block.
"""
self.run_in_thread()
return self
def __exit__(self, *args, **kwargs):
"""
Tell uvicorn we should exit and wait for the thread to finish.
"""
self.stop()
# ------------------------------------------------------------ #
# Common messages to all parties #
# ------------------------------------------------------------ #
def process_test_message(self, message: TestMessage, sender_role: UsefRole):
logger.info(
"Received a TestMessage, will respond with TestMessageResponse. "
f"Implement the process_test_message method on your {self.__class__.__qualname__} "
f"to implement custom behavior. The message was: {message}"
)
response = TestMessageResponse(conversation_id=message.conversation_id)
client = self._get_client(message.sender_domain, sender_role, version=message.version or self.version)
client.send_test_message_response(response)
def process_test_message_response(self, message: TestMessage):
logger.info(f"Received a TestMessageResponse: {message}")
def snake_case(text):
"""
Convert text from CamelCase to snake_case.
"""
return re.sub(r"(.)([A-Z][a-z])", r"\1_\2", text).lower()