Files
alpinebits_python/src/alpine_bits_python/alpinebits_server.py

780 lines
27 KiB
Python
Raw Blame History

"""
AlpineBits Server for handling hotel data exchange.
This module provides an asynchronous AlpineBits server that can handle various
OTA (OpenTravel Alliance) actions for hotel data exchange. Currently implements
handshaking functionality with configurable supported actions and capabilities.
"""
import asyncio
from datetime import datetime
import difflib
import json
import inspect
import re
from typing import Dict, List, Optional, Any, Union, Tuple, Type, override
from xml.etree import ElementTree as ET
from dataclasses import dataclass
from enum import Enum, IntEnum
from alpine_bits_python.alpine_bits_helpers import PhoneTechType, create_xml_from_db
from .generated.alpinebits import OtaPingRq, OtaPingRs, WarningStatus, OtaReadRq
from xsdata_pydantic.bindings import XmlSerializer
from xsdata.formats.dataclass.serializers.config import SerializerConfig
from abc import ABC, abstractmethod
from xsdata_pydantic.bindings import XmlParser
import logging
from .db import Reservation, Customer
from sqlalchemy import select
from sqlalchemy.orm import joinedload
# Configure logging: basicConfig is invoked at import time with level INFO,
# and _LOGGER is the module-level logger used by the classes in this file.
logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)
class HttpStatusCode(IntEnum):
    """HTTP status codes that an AlpineBits response may carry.

    Members are plain integers (IntEnum), so they compare equal to the
    corresponding numeric status code.
    """

    # Request handled successfully.
    OK = 200
    # The request was malformed or asked for something unsupported.
    BAD_REQUEST = 400
    # Credentials were missing or did not match.
    UNAUTHORIZED = 401
    # The server failed while processing the request.
    INTERNAL_SERVER_ERROR = 500
class AlpineBitsActionName(Enum):
    """AlpineBits actions, each pairing a capability name with a request name.

    Every member's value is a ``(capability_name, request_name)`` tuple; the
    two parts are also exposed as the attributes ``capability_name`` and
    ``request_name``.
    """

    # Format: (capability_name, actual_request_name)
    OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking")
    OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests")
    OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif")
    OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = (
        "action_OTA_HotelResNotif_GuestRequests",
        "OTA_HotelResNotif:GuestRequests",
    )
    OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = (
        "action_OTA_HotelDescriptiveContentNotif_Inventory",
        "OTA_HotelDescriptiveContentNotif:Inventory",
    )
    OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = (
        "action_OTA_HotelDescriptiveContentNotif_Info",
        "OTA_HotelDescriptiveContentNotif:Info",
    )
    OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = (
        "action_OTA_HotelDescriptiveInfo_Inventory",
        "OTA_HotelDescriptiveInfo:Inventory",
    )
    OTA_HOTEL_DESCRIPTIVE_INFO_INFO = (
        "action_OTA_HotelDescriptiveInfo_Info",
        "OTA_HotelDescriptiveInfo:Info",
    )
    OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = (
        "action_OTA_HotelRatePlanNotif_RatePlans",
        "OTA_HotelRatePlanNotif:RatePlans",
    )
    OTA_HOTEL_RATE_PLAN_BASE_RATES = (
        "action_OTA_HotelRatePlan_BaseRates",
        "OTA_HotelRatePlan:BaseRates",
    )

    def __init__(self, capability_name: str, request_name: str):
        # Enum unpacks the member's tuple value into these two arguments.
        self.capability_name, self.request_name = capability_name, request_name

    @classmethod
    def get_by_capability_name(
        cls, capability_name: str
    ) -> Optional["AlpineBitsActionName"]:
        """Return the member whose capability name matches, or None."""
        return next(
            (member for member in cls if member.capability_name == capability_name),
            None,
        )

    @classmethod
    def get_by_request_name(cls, request_name: str) -> Optional["AlpineBitsActionName"]:
        """Return the member whose request name matches, or None."""
        return next(
            (member for member in cls if member.request_name == request_name),
            None,
        )
class Version(str, Enum):
    """AlpineBits protocol versions known to this module.

    Members are also ``str`` instances, so they compare equal to their
    version string.
    """

    V2024_10 = "2024-10"
    V2022_10 = "2022-10"
    # Add other versions as needed
class AlpineBitsClientInfo:
"""Wrapper for username, password, client_id"""
def __init__(self, username: str, password: str, client_id: str | None = None):
self.username = username
self.password = password
self.client_id = client_id
@dataclass
class AlpineBitsResponse:
    """Result of an AlpineBits action: a response body plus a status code."""

    # Response body returned to the client (XML or a plain error message).
    xml_content: str
    # Status code of the response; defaults to 200.
    status_code: HttpStatusCode = HttpStatusCode.OK

    def __post_init__(self):
        """Reject any status code other than 200, 400, 401 or 500."""
        permitted_codes = [200, 400, 401, 500]
        if self.status_code in permitted_codes:
            return
        raise ValueError(
            f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500"
        )
# Abstract base class for AlpineBits Action
class AlpineBitsAction(ABC):
    """Base class for AlpineBits action handlers.

    Subclasses set ``name`` and ``version`` and override ``handle``.
    """

    name: AlpineBitsActionName
    version: (
        Version | list[Version]
    )  # list of versions in case action supports multiple versions

    async def handle(
        self,
        action: str,
        request_xml: str,
        version: Version,
        client_info: AlpineBitsClientInfo,
        dbsession=None,
        server_capabilities=None,
    ) -> AlpineBitsResponse:
        """
        Fallback handler that reports the action as not implemented.

        Subclasses override this method to provide real behaviour.

        Args:
            action: The action to perform (e.g., "OTA_PingRQ")
            request_xml: The XML request body as string
            version: The AlpineBits version
            client_info: Credentials of the calling client (unused here)
            dbsession: Optional database session (unused here)
            server_capabilities: Optional capabilities object (unused here)

        Returns:
            AlpineBitsResponse with a "not implemented" message and status 400
        """
        message = f"Error: Action {action} not implemented"
        return AlpineBitsResponse(message, HttpStatusCode.BAD_REQUEST)

    async def check_version_supported(self, version: Version) -> bool:
        """
        Tell whether this action supports the given version.

        Args:
            version: The AlpineBits version to check

        Returns:
            True if supported, False otherwise
        """
        declared = self.version
        if not isinstance(declared, list):
            # Single declared version: compare directly.
            return version == declared
        return version in declared
class ServerCapabilities:
    """
    Finds AlpineBitsAction subclasses in this module and derives the
    capabilities structure from them.
    """

    def __init__(self):
        self.capability_dict = None
        self.action_registry: Dict[str, Type[AlpineBitsAction]] = {}
        self._discover_actions()

    def _discover_actions(self):
        """Register every implemented AlpineBitsAction subclass of this module."""
        module = inspect.getmodule(self)
        for _member_name, candidate in inspect.getmembers(module):
            if not inspect.isclass(candidate):
                continue
            if not issubclass(candidate, AlpineBitsAction):
                continue
            if candidate == AlpineBitsAction:
                continue
            # Skip classes that merely inherit the default handle()
            if not self._is_action_implemented(candidate):
                continue
            instance = candidate()
            if hasattr(instance, "name"):
                # The registry is keyed by capability name
                self.action_registry[instance.name.capability_name] = candidate

    def _is_action_implemented(self, action_class: Type[AlpineBitsAction]) -> bool:
        """
        Report whether the class defines its own ``handle`` method.

        Only the class's own ``__dict__`` is inspected, so a ``handle``
        inherited from a parent class does not count.
        """
        return "handle" in action_class.__dict__

    def create_capabilities_dict(self) -> None:
        """
        Build the capabilities dictionary from the registered actions and
        store it on ``self.capability_dict``.
        """
        per_version = {}
        for capability_name, handler_class in self.action_registry.items():
            handler = handler_class()
            # Normalise the declared version(s) to a list
            declared = handler.version
            version_list = declared if isinstance(declared, list) else [declared]
            # Record the action under each version it supports
            for ver in version_list:
                label = ver.value
                bucket = per_version.setdefault(
                    label, {"version": label, "actions": []}
                )
                entry = {"action": capability_name}
                # Include "supports" only when the action defines a truthy one
                if getattr(handler, "supports", None):
                    entry["supports"] = handler.supports
                bucket["actions"].append(entry)
        self.capability_dict = {"versions": list(per_version.values())}
        return None

    def get_capabilities_dict(self) -> Dict:
        """
        Return the capabilities dictionary, building it on first use.
        """
        if self.capability_dict is None:
            self.create_capabilities_dict()
        return self.capability_dict

    def get_capabilities_json(self) -> str:
        """Return the capabilities dictionary as JSON indented by two spaces."""
        capabilities = self.get_capabilities_dict()
        return json.dumps(capabilities, indent=2)

    def get_supported_actions(self) -> List[str]:
        """Return the capability names of all registered actions."""
        return [capability_name for capability_name in self.action_registry]
# Sample Action Implementations for demonstration
class PingAction(AlpineBitsAction):
    """Implementation for OTA_Ping action (handshaking).

    The client sends a JSON document in EchoData listing the versions and
    actions it wants; the response carries the subset that this server also
    supports.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create the action.

        Args:
            config: Optional configuration mapping. ``None`` is replaced by a
                fresh empty dict so instances never share a mutable default.
        """
        self.name = AlpineBitsActionName.OTA_PING
        self.version = [
            Version.V2024_10,
            Version.V2022_10,
        ]  # Supports multiple versions
        self.config = {} if config is None else config

    @staticmethod
    def _match_capabilities(echo_data: Dict, capabilities_dict: Dict) -> Dict:
        """Return the versions/actions requested by the client that the server offers.

        Malformed entries (non-dict versions or actions, non-list containers)
        are ignored instead of raising.
        """
        matching = {"versions": []}
        client_versions = echo_data.get("versions", [])
        if not isinstance(client_versions, list):
            return matching

        for client_version in client_versions:
            if not isinstance(client_version, dict):
                continue
            client_version_str = client_version.get("version", "")
            # First server entry with the same version string wins
            server_version = next(
                (
                    candidate
                    for candidate in capabilities_dict["versions"]
                    if candidate["version"] == client_version_str
                ),
                None,
            )
            if server_version is None:
                continue

            server_actions = {
                entry.get("action", ""): entry
                for entry in server_version.get("actions", [])
            }
            client_actions = client_version.get("actions", [])
            if not isinstance(client_actions, list):
                client_actions = []
            # De-duplicate requested names while keeping the client's order
            requested_names = dict.fromkeys(
                entry.get("action", "")
                for entry in client_actions
                if isinstance(entry, dict)
            )
            # Use the server's action definition (includes our supports)
            common_actions = [
                server_actions[name]
                for name in requested_names
                if name in server_actions
            ]
            # Only add version if there are common actions
            if common_actions:
                matching["versions"].append(
                    {"version": client_version_str, "actions": common_actions}
                )
        return matching

    @override
    async def handle(
        self,
        action: str,
        request_xml: str,
        version: Version,
        client_info: AlpineBitsClientInfo,
        server_capabilities: None | ServerCapabilities = None,
    ) -> AlpineBitsResponse:
        """Handle ping requests.

        Args:
            action: The request action name (not inspected here).
            request_xml: The OTA_PingRQ document as a string.
            version: The AlpineBits version of the request (not inspected here).
            client_info: Credentials of the caller (not inspected here).
            server_capabilities: Capabilities object used to compute the
                handshake result; required.

        Returns:
            Status 400 when the request is missing, cannot be parsed, or its
            EchoData is not a JSON object; status 500 when no capabilities
            object was supplied; otherwise status 200 with the serialized
            OTA_PingRS document.
        """
        if request_xml is None:
            return AlpineBitsResponse(
                "Error: Xml Request missing", HttpStatusCode.BAD_REQUEST
            )
        if server_capabilities is None:
            return AlpineBitsResponse(
                "Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR
            )

        # Parse the incoming request XML and extract EchoData
        parser = XmlParser()
        try:
            parsed_request = parser.from_string(request_xml, OtaPingRq)
            echo_data = json.loads(parsed_request.echo_data)
        except Exception:
            return AlpineBitsResponse(
                "Error: Invalid XML request", HttpStatusCode.BAD_REQUEST
            )
        # Valid JSON that is not an object (e.g. a list or number) cannot be
        # a handshake payload; reject it rather than failing later on .get()
        if not isinstance(echo_data, dict):
            return AlpineBitsResponse(
                "Error: Invalid XML request", HttpStatusCode.BAD_REQUEST
            )

        # Compare echo data with capabilities and keep only the overlap
        capabilities_dict = server_capabilities.get_capabilities_dict()
        _LOGGER.info("Capabilities Dict: %s", capabilities_dict)
        matching_capabilities = self._match_capabilities(echo_data, capabilities_dict)

        # Create successful ping response with matched capabilities
        capabilities_json = json.dumps(matching_capabilities, indent=2)
        warning = OtaPingRs.Warnings.Warning(
            status=WarningStatus.ALPINEBITS_HANDSHAKE,
            type_value="11",
            content=[capabilities_json],
        )
        warning_response = OtaPingRs.Warnings(warning=[warning])
        # NOTE(review): EchoData is filled with the server's full capability
        # JSON rather than the client's original payload - confirm this is
        # what the AlpineBits handshake specification expects.
        all_capabilities = server_capabilities.get_capabilities_json()
        response_ota_ping = OtaPingRs(
            version="7.000",
            warnings=warning_response,
            echo_data=all_capabilities,
            success="",
        )
        config = SerializerConfig(
            pretty_print=True, xml_declaration=True, encoding="UTF-8"
        )
        serializer = XmlSerializer(config=config)
        response_xml = serializer.render(
            response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}
        )
        return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
def strip_control_chars(s):
    """Return ``s`` without ASCII control characters (0x00-0x1F and 0x7F)."""
    control_chars = re.compile(r"[\x00-\x1F\x7F]")
    return control_chars.sub("", s)
def validate_hotel_authentication(
    username: str, password: str, hotelid: str, config: Dict
) -> bool:
    """Validate hotel authentication based on username, password, and hotel ID.

    Example config

        alpine_bits_auth:
          - hotel_id: "123"
            hotel_name: "Frangart Inn"
            username: "alice"
            password: !secret ALICE_PASSWORD

    Args:
        username: Username supplied by the client.
        password: Password supplied by the client.
        hotelid: Hotel code the client wants to access.
        config: Configuration mapping; entries are read from its
            ``alpine_bits_auth`` list.

    Returns:
        True when an entry matches hotel id, username and password. False
        when the config is empty, has no (or an empty) ``alpine_bits_auth``
        list, or nothing matches. A missing or non-string password never
        matches, so an entry without a password cannot be used to log in.
    """
    import hmac  # local import: only needed for the constant-time comparison

    if not config or "alpine_bits_auth" not in config:
        return False

    # "alpine_bits_auth:" with no value parses to None; treat it as empty
    auth_entries = config["alpine_bits_auth"] or []
    for auth in auth_entries:
        if not isinstance(auth, dict):
            continue
        if auth.get("hotel_id") != hotelid or auth.get("username") != username:
            continue
        expected_password = auth.get("password")
        if not isinstance(expected_password, str) or not isinstance(password, str):
            continue
        # Compare as bytes: compare_digest rejects non-ASCII str arguments,
        # and its running time does not reveal where the strings differ.
        if hmac.compare_digest(
            expected_password.encode("utf-8"), password.encode("utf-8")
        ):
            return True
    return False
# look for hotelid in config
class ReadAction(AlpineBitsAction):
    """Implementation for OTA_Read action.

    Authenticates the caller for the requested hotel, loads that hotel's
    reservations (optionally from a start date on) and returns them as XML.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create the action.

        Args:
            config: Optional configuration mapping, passed to
                ``validate_hotel_authentication``. ``None`` is replaced by a
                fresh empty dict so instances never share a mutable default.
        """
        self.name = AlpineBitsActionName.OTA_READ
        self.version = [Version.V2024_10, Version.V2022_10]
        self.config = {} if config is None else config

    async def handle(
        self,
        action: str,
        request_xml: str,
        version: Version,
        client_info: AlpineBitsClientInfo,
        dbsession=None,
        server_capabilities=None,
    ) -> AlpineBitsResponse:
        """Handle read requests.

        Args:
            action: The request action name; must equal this action's request
                name after control characters and whitespace are stripped.
            request_xml: The OTA_ReadRQ document as a string.
            version: The AlpineBits version of the request (not inspected here).
            client_info: Credentials checked against the configuration.
            dbsession: Session whose ``execute`` is awaited to run the query;
                required.
            server_capabilities: Unused.

        Returns:
            Status 400 for a wrong action name, unparsable XML or an invalid
            start date; status 500 when no database session was supplied;
            status 401 when no hotel is specified or authentication fails;
            otherwise status 200 with the serialized reservations.
        """
        expected_action = self.name.request_name
        clean_action = strip_control_chars(str(action)).strip()
        clean_expected = strip_control_chars(expected_action).strip()
        if clean_action != clean_expected:
            return AlpineBitsResponse(
                f"Error: Invalid action {action}, expected {expected_action}",
                HttpStatusCode.BAD_REQUEST,
            )
        if dbsession is None:
            return AlpineBitsResponse(
                "Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR
            )

        # A malformed body is the client's fault: answer 400 instead of
        # letting the parser error bubble up as an internal server error
        try:
            read_request = XmlParser().from_string(request_xml, OtaReadRq)
        except Exception:
            return AlpineBitsResponse(
                "Error: Invalid XML request", HttpStatusCode.BAD_REQUEST
            )

        read_requests = read_request.read_requests
        hotel_read_request = getattr(read_requests, "hotel_read_request", None)
        hotelid = getattr(hotel_read_request, "hotel_code", None)
        hotelname = getattr(hotel_read_request, "hotel_name", None)
        if hotelname is None:
            hotelname = "unknown"
        if hotelid is None:
            return AlpineBitsResponse(
                "Error: Unauthorized Read Request. No target hotel specified. Check credentials",
                HttpStatusCode.UNAUTHORIZED,
            )
        if not validate_hotel_authentication(
            client_info.username, client_info.password, hotelid, self.config
        ):
            return AlpineBitsResponse(
                f"Error: Unauthorized Read Request for this specific hotel {hotelname}. Check credentials",
                HttpStatusCode.UNAUTHORIZED,
            )

        # The start filter is optional: both the criteria element and its
        # start value may be absent
        start_date = None
        selection_criteria = hotel_read_request.selection_criteria
        if selection_criteria is not None and selection_criteria.start is not None:
            try:
                start_date = datetime.fromisoformat(str(selection_criteria.start))
            except ValueError:
                return AlpineBitsResponse(
                    f"Error: Invalid start date {selection_criteria.start}",
                    HttpStatusCode.BAD_REQUEST,
                )

        # Query all reservations for this hotel, restricted to those starting
        # on or after start_date when one was given
        stmt = (
            select(Reservation, Customer)
            .join(Customer, Reservation.customer_id == Customer.id)
            .filter(Reservation.hotel_code == hotelid)
        )
        if start_date is not None:
            stmt = stmt.filter(Reservation.start_date >= start_date)

        _LOGGER.info(
            "Querying reservations and customers for hotel %s from database", hotelid
        )
        result = await dbsession.execute(stmt)
        reservation_customer_pairs: list[tuple[Reservation, Customer]] = (
            result.all()
        )  # List of (Reservation, Customer) tuples
        for reservation, customer in reservation_customer_pairs:
            _LOGGER.info(
                "Reservation: %s, Customer: %s", reservation.id, customer.given_name
            )

        res_retrive_rs = create_xml_from_db(reservation_customer_pairs)
        config = SerializerConfig(
            pretty_print=True, xml_declaration=True, encoding="UTF-8"
        )
        serializer = XmlSerializer(config=config)
        response_xml = serializer.render(
            res_retrive_rs, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}
        )
        return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class NotifReportReadAction(AlpineBitsAction):
    """Necessary for read action to follow specification. Clients need to report acknowledgements.

    The handler is a placeholder: it overrides ``handle`` but only answers
    "not implemented".
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create the action.

        Args:
            config: Optional configuration mapping. ``None`` is replaced by a
                fresh empty dict so instances never share a mutable default.
        """
        self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS
        self.version = [Version.V2024_10, Version.V2022_10]
        self.config = {} if config is None else config

    async def handle(
        self,
        action: str,
        request_xml: str,
        version: Version,
        dbsession=None,
        username=None,
        password=None,
        client_info=None,
        server_capabilities=None,
    ) -> AlpineBitsResponse:
        """Answer every request with a "not implemented" error.

        ``client_info`` and ``server_capabilities`` are accepted so that the
        keyword arguments used for the other actions' ``handle`` methods do
        not raise ``TypeError`` here; none of the arguments besides
        ``action`` is used.

        Args:
            action: The request action name, echoed in the error message.
            request_xml: The request body (unused).
            version: The AlpineBits version of the request (unused).
            dbsession: Optional database session (unused).
            username: Unused.
            password: Unused.
            client_info: Credentials of the caller (unused).
            server_capabilities: Unused.

        Returns:
            AlpineBitsResponse with status 400.
        """
        return AlpineBitsResponse(
            f"Error: Action {action} not implemented", HttpStatusCode.BAD_REQUEST
        )
class GuestRequestsAction(AlpineBitsAction):
    """Action without its own handler - will not appear in capabilities."""

    # Note: This class doesn't override the handle method, so it won't be discovered
    def __init__(self):
        self.version = Version.V2024_10
        self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS
class AlpineBitsServer:
    """
    Asynchronous AlpineBits server for handling hotel data exchange requests.

    This server handles various OTA actions and implements the AlpineBits protocol
    for hotel data exchange. It maintains a registry of supported actions and
    their capabilities, and can respond to handshake requests with its capabilities.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Discover the available actions and create one instance of each.

        Args:
            config: Optional configuration mapping handed to every action
                instance.
        """
        self.capabilities = ServerCapabilities()
        self._action_instances = {}
        self.config = config
        self._initialize_action_instances()

    def _initialize_action_instances(self):
        """Initialize instances of all discovered action classes."""
        for capability_name, action_class in self.capabilities.action_registry.items():
            self._action_instances[capability_name] = action_class(config=self.config)

    def get_capabilities(self) -> Dict:
        """Get server capabilities."""
        return self.capabilities.get_capabilities_dict()

    def get_capabilities_json(self) -> str:
        """Get server capabilities as JSON."""
        return self.capabilities.get_capabilities_json()

    async def handle_request(
        self,
        request_action_name: str,
        request_xml: str,
        client_info: AlpineBitsClientInfo,
        version: str = "2024-10",
        dbsession=None,
    ) -> AlpineBitsResponse:
        """
        Handle an incoming AlpineBits request by routing to appropriate action handler.

        Args:
            request_action_name: The action name from the request (e.g., "OTA_Read:GuestRequests")
            request_xml: The XML request body
            client_info: Credentials of the calling client, forwarded to the handler
            version: The AlpineBits version (defaults to "2024-10")
            dbsession: Optional database session, forwarded to every handler
                except the ping handler

        Returns:
            AlpineBitsResponse with the result. Unknown versions, unknown or
            unimplemented actions and unsupported version/action pairs yield
            status 400; any exception raised by the handler yields status 500.
        """
        # Convert string version to enum
        try:
            version_enum = Version(version)
        except ValueError:
            return AlpineBitsResponse(
                f"Error: Unsupported version {version}", HttpStatusCode.BAD_REQUEST
            )

        # Find the action by request name
        action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
        if not action_enum:
            return AlpineBitsResponse(
                f"Error: Unknown action {request_action_name}",
                HttpStatusCode.BAD_REQUEST,
            )

        # Check if we have an implementation for this action
        capability_name = action_enum.capability_name
        if capability_name not in self._action_instances:
            return AlpineBitsResponse(
                f"Error: Action {request_action_name} is not implemented",
                HttpStatusCode.BAD_REQUEST,
            )
        action_instance: AlpineBitsAction = self._action_instances[capability_name]

        # Check if the action supports the requested version
        if not await action_instance.check_version_supported(version_enum):
            return AlpineBitsResponse(
                f"Error: Action {request_action_name} does not support version {version}",
                HttpStatusCode.BAD_REQUEST,
            )

        # Handle the request
        try:
            # Special case for ping action - it needs the server capabilities
            # and does not take a database session
            if capability_name == AlpineBitsActionName.OTA_PING.capability_name:
                return await action_instance.handle(
                    action=request_action_name,
                    request_xml=request_xml,
                    version=version_enum,
                    server_capabilities=self.capabilities,
                    client_info=client_info,
                )
            return await action_instance.handle(
                action=request_action_name,
                request_xml=request_xml,
                version=version_enum,
                dbsession=dbsession,
                client_info=client_info,
            )
        except Exception as e:
            # Log message and stack trace through the module logger
            _LOGGER.exception("Error handling request %s", request_action_name)
            # NOTE(review): the exception text is returned to the client;
            # consider whether that exposes internal details.
            return AlpineBitsResponse(
                f"Error: Internal server error while processing {request_action_name}: {str(e)}",
                HttpStatusCode.INTERNAL_SERVER_ERROR,
            )

    def get_supported_request_names(self) -> List[str]:
        """Get all supported request names (not capability names), sorted."""
        action_enums = (
            AlpineBitsActionName.get_by_capability_name(capability_name)
            for capability_name in self._action_instances
        )
        return sorted(
            action_enum.request_name for action_enum in action_enums if action_enum
        )

    def is_action_supported(
        self, request_action_name: str, version: str | None = None
    ) -> bool:
        """
        Check if a request action is supported.

        Args:
            request_action_name: The request action name (e.g., "OTA_Read:GuestRequests")
            version: Optional version to check

        Returns:
            True if supported, False otherwise
        """
        action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
        if not action_enum:
            return False
        capability_name = action_enum.capability_name
        if capability_name not in self._action_instances:
            return False
        if not version:
            return True

        try:
            version_enum = Version(version)
        except ValueError:
            return False
        # Synchronous equivalent of AlpineBitsAction.check_version_supported
        supported = self._action_instances[capability_name].version
        if isinstance(supported, list):
            return version_enum in supported
        return supported == version_enum
async def main():
"""Demonstrate the automatic capabilities discovery and request handling."""
print("🚀 AlpineBits Server Capabilities Discovery & Request Handling Demo")
print("=" * 70)
# Create server instance
server = AlpineBitsServer()
print("\n📋 Discovered Action Classes:")
print("-" * 30)
for capability_name, action_class in server.capabilities.action_registry.items():
action_enum = AlpineBitsActionName.get_by_capability_name(capability_name)
request_name = action_enum.request_name if action_enum else "unknown"
print(f"{capability_name} -> {action_class.__name__}")
print(f" Request name: {request_name}")
print(
f"\n📊 Total Implemented Actions: {len(server.capabilities.get_supported_actions())}"
)
print("\n🔍 Generated Capabilities JSON:")
print("-" * 30)
capabilities_json = server.get_capabilities_json()
print(capabilities_json)
print("\n🎯 Supported Request Names:")
print("-" * 30)
for request_name in server.get_supported_request_names():
print(f"{request_name}")
print("\n🧪 Testing Request Handling:")
print("-" * 30)
test_xml = "<test>sample request</test>"
# Test different request formats
test_cases = [
("OTA_Ping:Handshaking", "2024-10"),
("OTA_Read:GuestRequests", "2024-10"),
("OTA_Read:GuestRequests", "2022-10"),
("OTA_HotelAvailNotif", "2024-10"),
("UnknownAction", "2024-10"),
("OTA_Ping:Handshaking", "unsupported-version"),
]
for request_name, version in test_cases:
print(f"\n<EFBFBD> Testing: {request_name} (v{version})")
# Check if supported first
is_supported = server.is_action_supported(request_name, version)
print(f" Supported: {is_supported}")
# Handle the request
response = await server.handle_request(request_name, test_xml, version)
print(f" Status: {response.status_code}")
if len(response.xml_content) > 100:
print(f" Response: {response.xml_content[:100]}...")
else:
print(f" Response: {response.xml_content}")
print("\n✅ Demo completed successfully!")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())