Got db saving working

This commit is contained in:
Jonas Linter
2025-09-29 13:56:34 +02:00
parent 384fb2b558
commit 06739ebea9
21 changed files with 1188 additions and 830 deletions

View File

@@ -1,6 +1,7 @@
"""Entry point for alpine_bits_python package.""" """Entry point for alpine_bits_python package."""
from .main import main from .main import main
if __name__ == "__main__": if __name__ == "__main__":
print("running test main") print("running test main")
main() main()

View File

@@ -23,49 +23,65 @@ from xsdata_pydantic.bindings import XmlParser
class HttpStatusCode(IntEnum): class HttpStatusCode(IntEnum):
"""Allowed HTTP status codes for AlpineBits responses.""" """Allowed HTTP status codes for AlpineBits responses."""
OK = 200 OK = 200
BAD_REQUEST = 400 BAD_REQUEST = 400
UNAUTHORIZED = 401 UNAUTHORIZED = 401
INTERNAL_SERVER_ERROR = 500 INTERNAL_SERVER_ERROR = 500
class AlpineBitsActionName(Enum): class AlpineBitsActionName(Enum):
"""Enum for AlpineBits action names with capability and request name mappings.""" """Enum for AlpineBits action names with capability and request name mappings."""
# Format: (capability_name, actual_request_name) # Format: (capability_name, actual_request_name)
OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking") OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking")
OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests") OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests")
OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif") OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif")
OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = ("action_OTA_HotelResNotif_GuestRequests", OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = (
"OTA_HotelResNotif:GuestRequests") "action_OTA_HotelResNotif_GuestRequests",
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = ("action_OTA_HotelDescriptiveContentNotif_Inventory", "OTA_HotelResNotif:GuestRequests",
"OTA_HotelDescriptiveContentNotif:Inventory") )
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = ("action_OTA_HotelDescriptiveContentNotif_Info", OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = (
"OTA_HotelDescriptiveContentNotif:Info") "action_OTA_HotelDescriptiveContentNotif_Inventory",
OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = ("action_OTA_HotelDescriptiveInfo_Inventory", "OTA_HotelDescriptiveContentNotif:Inventory",
"OTA_HotelDescriptiveInfo:Inventory") )
OTA_HOTEL_DESCRIPTIVE_INFO_INFO = ("action_OTA_HotelDescriptiveInfo_Info", OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = (
"OTA_HotelDescriptiveInfo:Info") "action_OTA_HotelDescriptiveContentNotif_Info",
OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = ("action_OTA_HotelRatePlanNotif_RatePlans", "OTA_HotelDescriptiveContentNotif:Info",
"OTA_HotelRatePlanNotif:RatePlans") )
OTA_HOTEL_RATE_PLAN_BASE_RATES = ("action_OTA_HotelRatePlan_BaseRates", OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = (
"OTA_HotelRatePlan:BaseRates") "action_OTA_HotelDescriptiveInfo_Inventory",
"OTA_HotelDescriptiveInfo:Inventory",
)
OTA_HOTEL_DESCRIPTIVE_INFO_INFO = (
"action_OTA_HotelDescriptiveInfo_Info",
"OTA_HotelDescriptiveInfo:Info",
)
OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = (
"action_OTA_HotelRatePlanNotif_RatePlans",
"OTA_HotelRatePlanNotif:RatePlans",
)
OTA_HOTEL_RATE_PLAN_BASE_RATES = (
"action_OTA_HotelRatePlan_BaseRates",
"OTA_HotelRatePlan:BaseRates",
)
def __init__(self, capability_name: str, request_name: str): def __init__(self, capability_name: str, request_name: str):
self.capability_name = capability_name self.capability_name = capability_name
self.request_name = request_name self.request_name = request_name
@classmethod @classmethod
def get_by_capability_name(cls, capability_name: str) -> Optional['AlpineBitsActionName']: def get_by_capability_name(
cls, capability_name: str
) -> Optional["AlpineBitsActionName"]:
"""Get action enum by capability name.""" """Get action enum by capability name."""
for action in cls: for action in cls:
if action.capability_name == capability_name: if action.capability_name == capability_name:
return action return action
return None return None
@classmethod @classmethod
def get_by_request_name(cls, request_name: str) -> Optional['AlpineBitsActionName']: def get_by_request_name(cls, request_name: str) -> Optional["AlpineBitsActionName"]:
"""Get action enum by request name.""" """Get action enum by request name."""
for action in cls: for action in cls:
if action.request_name == request_name: if action.request_name == request_name:
@@ -75,22 +91,25 @@ class AlpineBitsActionName(Enum):
class Version(str, Enum): class Version(str, Enum):
"""Enum for AlpineBits versions.""" """Enum for AlpineBits versions."""
V2024_10 = "2024-10" V2024_10 = "2024-10"
V2022_10 = "2022-10" V2022_10 = "2022-10"
# Add other versions as needed # Add other versions as needed
@dataclass @dataclass
class AlpineBitsResponse: class AlpineBitsResponse:
"""Response data structure for AlpineBits actions.""" """Response data structure for AlpineBits actions."""
xml_content: str xml_content: str
status_code: HttpStatusCode = HttpStatusCode.OK status_code: HttpStatusCode = HttpStatusCode.OK
def __post_init__(self): def __post_init__(self):
"""Validate that status code is one of the allowed values.""" """Validate that status code is one of the allowed values."""
if self.status_code not in [200, 400, 401, 500]: if self.status_code not in [200, 400, 401, 500]:
raise ValueError(f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500") raise ValueError(
f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500"
)
# Abstract base class for AlpineBits Action # Abstract base class for AlpineBits Action
@@ -98,20 +117,24 @@ class AlpineBitsAction(ABC):
"""Abstract base class for handling AlpineBits actions.""" """Abstract base class for handling AlpineBits actions."""
name: AlpineBitsActionName name: AlpineBitsActionName
version: Version | list[Version] # list of versions in case action supports multiple versions version: (
Version | list[Version]
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: ) # list of versions in case action supports multiple versions
async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
""" """
Handle the incoming request XML and return response XML. Handle the incoming request XML and return response XML.
Default implementation returns "not implemented" error. Default implementation returns "not implemented" error.
Override this method in subclasses to provide actual functionality. Override this method in subclasses to provide actual functionality.
Args: Args:
action: The action to perform (e.g., "OTA_PingRQ") action: The action to perform (e.g., "OTA_PingRQ")
request_xml: The XML request body as string request_xml: The XML request body as string
version: The AlpineBits version version: The AlpineBits version
Returns: Returns:
AlpineBitsResponse with error or actual response AlpineBitsResponse with error or actual response
""" """
@@ -121,7 +144,7 @@ class AlpineBitsAction(ABC):
async def check_version_supported(self, version: Version) -> bool: async def check_version_supported(self, version: Version) -> bool:
""" """
Check if the action supports the given version. Check if the action supports the given version.
Args: Args:
version: The AlpineBits version to check version: The AlpineBits version to check
Returns: Returns:
@@ -130,103 +153,93 @@ class AlpineBitsAction(ABC):
if isinstance(self.version, list): if isinstance(self.version, list):
return version in self.version return version in self.version
return version == self.version return version == self.version
class ServerCapabilities: class ServerCapabilities:
""" """
Automatically discovers AlpineBitsAction implementations and generates capabilities. Automatically discovers AlpineBitsAction implementations and generates capabilities.
""" """
def __init__(self): def __init__(self):
self.action_registry: Dict[str, Type[AlpineBitsAction]] = {} self.action_registry: Dict[str, Type[AlpineBitsAction]] = {}
self._discover_actions() self._discover_actions()
self.capability_dict = None self.capability_dict = None
def _discover_actions(self): def _discover_actions(self):
"""Discover all AlpineBitsAction implementations in the current module.""" """Discover all AlpineBitsAction implementations in the current module."""
current_module = inspect.getmodule(self) current_module = inspect.getmodule(self)
for name, obj in inspect.getmembers(current_module): for name, obj in inspect.getmembers(current_module):
if (inspect.isclass(obj) and if (
issubclass(obj, AlpineBitsAction) and inspect.isclass(obj)
obj != AlpineBitsAction): and issubclass(obj, AlpineBitsAction)
and obj != AlpineBitsAction
):
# Check if this action is actually implemented (not just returning default) # Check if this action is actually implemented (not just returning default)
if self._is_action_implemented(obj): if self._is_action_implemented(obj):
action_instance = obj() action_instance = obj()
if hasattr(action_instance, 'name'): if hasattr(action_instance, "name"):
# Use capability name for the registry key # Use capability name for the registry key
self.action_registry[action_instance.name.capability_name] = obj self.action_registry[action_instance.name.capability_name] = obj
def _is_action_implemented(self, action_class: Type[AlpineBitsAction]) -> bool: def _is_action_implemented(self, action_class: Type[AlpineBitsAction]) -> bool:
""" """
Check if an action is actually implemented or just uses the default behavior. Check if an action is actually implemented or just uses the default behavior.
This is a simple check - in practice, you might want more sophisticated detection. This is a simple check - in practice, you might want more sophisticated detection.
""" """
# Check if the class has overridden the handle method # Check if the class has overridden the handle method
if 'handle' in action_class.__dict__: if "handle" in action_class.__dict__:
return True return True
return False return False
def create_capabilities_dict(self) -> None: def create_capabilities_dict(self) -> None:
""" """
Generate the capabilities dictionary based on discovered actions. Generate the capabilities dictionary based on discovered actions.
""" """
versions_dict = {} versions_dict = {}
for action_name, action_class in self.action_registry.items(): for action_name, action_class in self.action_registry.items():
action_instance = action_class() action_instance = action_class()
# Get supported versions for this action # Get supported versions for this action
if isinstance(action_instance.version, list): if isinstance(action_instance.version, list):
supported_versions = action_instance.version supported_versions = action_instance.version
else: else:
supported_versions = [action_instance.version] supported_versions = [action_instance.version]
# Add action to each supported version # Add action to each supported version
for version in supported_versions: for version in supported_versions:
version_str = version.value version_str = version.value
if version_str not in versions_dict: if version_str not in versions_dict:
versions_dict[version_str] = { versions_dict[version_str] = {"version": version_str, "actions": []}
"version": version_str,
"actions": []
}
action_dict = {"action": action_name} action_dict = {"action": action_name}
# Add supports field if the action has custom supports # Add supports field if the action has custom supports
if hasattr(action_instance, 'supports') and action_instance.supports: if hasattr(action_instance, "supports") and action_instance.supports:
action_dict["supports"] = action_instance.supports action_dict["supports"] = action_instance.supports
versions_dict[version_str]["actions"].append(action_dict) versions_dict[version_str]["actions"].append(action_dict)
self.capability_dict = {"versions": list(versions_dict.values())} self.capability_dict = {"versions": list(versions_dict.values())}
return None return None
def get_capabilities_dict(self) -> Dict: def get_capabilities_dict(self) -> Dict:
""" """
Get capabilities as a dictionary. Generates if not already created. Get capabilities as a dictionary. Generates if not already created.
""" """
if self.capability_dict is None: if self.capability_dict is None:
self.create_capabilities_dict() self.create_capabilities_dict()
return self.capability_dict return self.capability_dict
def get_capabilities_json(self) -> str: def get_capabilities_json(self) -> str:
"""Get capabilities as formatted JSON string.""" """Get capabilities as formatted JSON string."""
return json.dumps(self.get_capabilities_dict(), indent=2) return json.dumps(self.get_capabilities_dict(), indent=2)
def get_supported_actions(self) -> List[str]: def get_supported_actions(self) -> List[str]:
"""Get list of all supported action names.""" """Get list of all supported action names."""
return list(self.action_registry.keys()) return list(self.action_registry.keys())
@@ -234,22 +247,35 @@ class ServerCapabilities:
# Sample Action Implementations for demonstration # Sample Action Implementations for demonstration
class PingAction(AlpineBitsAction): class PingAction(AlpineBitsAction):
"""Implementation for OTA_Ping action (handshaking).""" """Implementation for OTA_Ping action (handshaking)."""
def __init__(self): def __init__(self):
self.name = AlpineBitsActionName.OTA_PING self.name = AlpineBitsActionName.OTA_PING
self.version = [Version.V2024_10, Version.V2022_10] # Supports multiple versions self.version = [
Version.V2024_10,
async def handle(self, action: str, request_xml: str, version: Version, server_capabilities: None | ServerCapabilities = None) -> AlpineBitsResponse: Version.V2022_10,
] # Supports multiple versions
async def handle(
self,
action: str,
request_xml: str,
version: Version,
server_capabilities: None | ServerCapabilities = None,
) -> AlpineBitsResponse:
"""Handle ping requests.""" """Handle ping requests."""
if request_xml is None: if request_xml is None:
return AlpineBitsResponse(f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST) return AlpineBitsResponse(
f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST
)
if server_capabilities is None: if server_capabilities is None:
return AlpineBitsResponse("Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR) return AlpineBitsResponse(
"Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR
)
# Parse the incoming request XML and extract EchoData # Parse the incoming request XML and extract EchoData
parser = XmlParser() parser = XmlParser()
@@ -259,54 +285,66 @@ class PingAction(AlpineBitsAction):
echo_data = json.loads(parsed_request.echo_data) echo_data = json.loads(parsed_request.echo_data)
except Exception as e: except Exception as e:
return AlpineBitsResponse(f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST) return AlpineBitsResponse(
f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST
)
# compare echo data with capabilities, create a dictionary containing the matching capabilities # compare echo data with capabilities, create a dictionary containing the matching capabilities
capabilities_dict = server_capabilities.get_capabilities_dict() capabilities_dict = server_capabilities.get_capabilities_dict()
matching_capabilities = {"versions": []} matching_capabilities = {"versions": []}
# Iterate through client's requested versions # Iterate through client's requested versions
for client_version in echo_data.get("versions", []): for client_version in echo_data.get("versions", []):
client_version_str = client_version.get("version", "") client_version_str = client_version.get("version", "")
# Find matching server version # Find matching server version
for server_version in capabilities_dict["versions"]: for server_version in capabilities_dict["versions"]:
if server_version["version"] == client_version_str: if server_version["version"] == client_version_str:
# Found a matching version, now find common actions # Found a matching version, now find common actions
matching_version = { matching_version = {"version": client_version_str, "actions": []}
"version": client_version_str,
"actions": []
}
# Get client's requested actions for this version # Get client's requested actions for this version
client_actions = {action.get("action", ""): action for action in client_version.get("actions", [])} client_actions = {
server_actions = {action.get("action", ""): action for action in server_version.get("actions", [])} action.get("action", ""): action
for action in client_version.get("actions", [])
}
server_actions = {
action.get("action", ""): action
for action in server_version.get("actions", [])
}
# Find common actions # Find common actions
for action_name in client_actions: for action_name in client_actions:
if action_name in server_actions: if action_name in server_actions:
# Use server's action definition (includes our supports) # Use server's action definition (includes our supports)
matching_version["actions"].append(server_actions[action_name]) matching_version["actions"].append(
server_actions[action_name]
)
# Only add version if there are common actions # Only add version if there are common actions
if matching_version["actions"]: if matching_version["actions"]:
matching_capabilities["versions"].append(matching_version) matching_capabilities["versions"].append(matching_version)
break break
# Debug print to see what we matched # Debug print to see what we matched
# Create successful ping response with matched capabilities # Create successful ping response with matched capabilities
capabilities_json = json.dumps(matching_capabilities, indent=2) capabilities_json = json.dumps(matching_capabilities, indent=2)
warning = OtaPingRs.Warnings.Warning(status=WarningStatus.ALPINEBITS_HANDSHAKE.value, type_value="11", content=[capabilities_json]) warning = OtaPingRs.Warnings.Warning(
status=WarningStatus.ALPINEBITS_HANDSHAKE.value,
type_value="11",
content=[capabilities_json],
)
warning_response = OtaPingRs.Warnings(warning=[warning]) warning_response = OtaPingRs.Warnings(warning=[warning])
response_ota_ping = OtaPingRs(version= "7.000", warnings=warning_response, echo_data=capabilities_json, success="") response_ota_ping = OtaPingRs(
version="7.000",
warnings=warning_response,
echo_data=capabilities_json,
success="",
)
config = SerializerConfig( config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8" pretty_print=True, xml_declaration=True, encoding="UTF-8"
@@ -314,34 +352,35 @@ class PingAction(AlpineBitsAction):
serializer = XmlSerializer(config=config) serializer = XmlSerializer(config=config)
response_xml = serializer.render(response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}) response_xml = serializer.render(
response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}
)
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class ReadAction(AlpineBitsAction): class ReadAction(AlpineBitsAction):
"""Implementation for OTA_Read action.""" """Implementation for OTA_Read action."""
def __init__(self): def __init__(self):
self.name = AlpineBitsActionName.OTA_READ self.name = AlpineBitsActionName.OTA_READ
self.version = [Version.V2024_10, Version.V2022_10] self.version = [Version.V2024_10, Version.V2022_10]
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""Handle read requests.""" """Handle read requests."""
response_xml = f'''<?xml version="1.0" encoding="UTF-8"?> response_xml = f"""<?xml version="1.0" encoding="UTF-8"?>
<OTA_ReadRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000"> <OTA_ReadRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000">
<Success/> <Success/>
<Data>Read operation successful for {version.value}</Data> <Data>Read operation successful for {version.value}</Data>
</OTA_ReadRS>''' </OTA_ReadRS>"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class HotelAvailNotifAction(AlpineBitsAction): class HotelAvailNotifAction(AlpineBitsAction):
"""Implementation for Hotel Availability Notification action with supports.""" """Implementation for Hotel Availability Notification action with supports."""
def __init__(self): def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_AVAIL_NOTIF self.name = AlpineBitsActionName.OTA_HOTEL_AVAIL_NOTIF
self.version = Version.V2022_10 self.version = Version.V2022_10
@@ -349,68 +388,68 @@ class HotelAvailNotifAction(AlpineBitsAction):
"OTA_HotelAvailNotif_accept_rooms", "OTA_HotelAvailNotif_accept_rooms",
"OTA_HotelAvailNotif_accept_categories", "OTA_HotelAvailNotif_accept_categories",
"OTA_HotelAvailNotif_accept_deltas", "OTA_HotelAvailNotif_accept_deltas",
"OTA_HotelAvailNotif_accept_BookingThreshold" "OTA_HotelAvailNotif_accept_BookingThreshold",
] ]
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""Handle hotel availability notifications.""" """Handle hotel availability notifications."""
response_xml = '''<?xml version="1.0" encoding="UTF-8"?> response_xml = """<?xml version="1.0" encoding="UTF-8"?>
<OTA_HotelAvailNotifRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000"> <OTA_HotelAvailNotifRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000">
<Success/> <Success/>
</OTA_HotelAvailNotifRS>''' </OTA_HotelAvailNotifRS>"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class GuestRequestsAction(AlpineBitsAction): class GuestRequestsAction(AlpineBitsAction):
"""Unimplemented action - will not appear in capabilities.""" """Unimplemented action - will not appear in capabilities."""
def __init__(self): def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS
self.version = Version.V2024_10 self.version = Version.V2024_10
# Note: This class doesn't override the handle method, so it won't be discovered # Note: This class doesn't override the handle method, so it won't be discovered
class AlpineBitsServer: class AlpineBitsServer:
""" """
Asynchronous AlpineBits server for handling hotel data exchange requests. Asynchronous AlpineBits server for handling hotel data exchange requests.
This server handles various OTA actions and implements the AlpineBits protocol This server handles various OTA actions and implements the AlpineBits protocol
for hotel data exchange. It maintains a registry of supported actions and for hotel data exchange. It maintains a registry of supported actions and
their capabilities, and can respond to handshake requests with its capabilities. their capabilities, and can respond to handshake requests with its capabilities.
""" """
def __init__(self): def __init__(self):
self.capabilities = ServerCapabilities() self.capabilities = ServerCapabilities()
self._action_instances = {} self._action_instances = {}
self._initialize_action_instances() self._initialize_action_instances()
def _initialize_action_instances(self): def _initialize_action_instances(self):
"""Initialize instances of all discovered action classes.""" """Initialize instances of all discovered action classes."""
for capability_name, action_class in self.capabilities.action_registry.items(): for capability_name, action_class in self.capabilities.action_registry.items():
self._action_instances[capability_name] = action_class() self._action_instances[capability_name] = action_class()
def get_capabilities(self) -> Dict: def get_capabilities(self) -> Dict:
"""Get server capabilities.""" """Get server capabilities."""
return self.capabilities.get_capabilities_dict() return self.capabilities.get_capabilities_dict()
def get_capabilities_json(self) -> str: def get_capabilities_json(self) -> str:
"""Get server capabilities as JSON.""" """Get server capabilities as JSON."""
return self.capabilities.get_capabilities_json() return self.capabilities.get_capabilities_json()
async def handle_request(self, request_action_name: str, request_xml: str, version: str = "2024-10") -> AlpineBitsResponse: async def handle_request(
self, request_action_name: str, request_xml: str, version: str = "2024-10"
) -> AlpineBitsResponse:
""" """
Handle an incoming AlpineBits request by routing to appropriate action handler. Handle an incoming AlpineBits request by routing to appropriate action handler.
Args: Args:
request_action_name: The action name from the request (e.g., "OTA_Read:GuestRequests") request_action_name: The action name from the request (e.g., "OTA_Read:GuestRequests")
request_xml: The XML request body request_xml: The XML request body
version: The AlpineBits version (defaults to "2024-10") version: The AlpineBits version (defaults to "2024-10")
Returns: Returns:
AlpineBitsResponse with the result AlpineBitsResponse with the result
""" """
@@ -419,52 +458,56 @@ class AlpineBitsServer:
version_enum = Version(version) version_enum = Version(version)
except ValueError: except ValueError:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unsupported version {version}", f"Error: Unsupported version {version}", HttpStatusCode.BAD_REQUEST
HttpStatusCode.BAD_REQUEST
) )
# Find the action by request name # Find the action by request name
action_enum = AlpineBitsActionName.get_by_request_name(request_action_name) action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
if not action_enum: if not action_enum:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unknown action {request_action_name}", f"Error: Unknown action {request_action_name}",
HttpStatusCode.BAD_REQUEST HttpStatusCode.BAD_REQUEST,
) )
# Check if we have an implementation for this action # Check if we have an implementation for this action
capability_name = action_enum.capability_name capability_name = action_enum.capability_name
if capability_name not in self._action_instances: if capability_name not in self._action_instances:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Action {request_action_name} is not implemented", f"Error: Action {request_action_name} is not implemented",
HttpStatusCode.BAD_REQUEST HttpStatusCode.BAD_REQUEST,
) )
action_instance: AlpineBitsAction = self._action_instances[capability_name] action_instance: AlpineBitsAction = self._action_instances[capability_name]
# Check if the action supports the requested version # Check if the action supports the requested version
if not await action_instance.check_version_supported(version_enum): if not await action_instance.check_version_supported(version_enum):
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Action {request_action_name} does not support version {version}", f"Error: Action {request_action_name} does not support version {version}",
HttpStatusCode.BAD_REQUEST HttpStatusCode.BAD_REQUEST,
) )
# Handle the request # Handle the request
try: try:
# Special case for ping action - pass server capabilities # Special case for ping action - pass server capabilities
if capability_name == "action_OTA_Ping": if capability_name == "action_OTA_Ping":
return await action_instance.handle(request_action_name, request_xml, version_enum, self.capabilities) return await action_instance.handle(
request_action_name, request_xml, version_enum, self.capabilities
)
else: else:
return await action_instance.handle(request_action_name, request_xml, version_enum) return await action_instance.handle(
request_action_name, request_xml, version_enum
)
except Exception as e: except Exception as e:
print(f"Error handling request {request_action_name}: {str(e)}") print(f"Error handling request {request_action_name}: {str(e)}")
# print stack trace for debugging # print stack trace for debugging
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Internal server error while processing {request_action_name}: {str(e)}", f"Error: Internal server error while processing {request_action_name}: {str(e)}",
HttpStatusCode.INTERNAL_SERVER_ERROR HttpStatusCode.INTERNAL_SERVER_ERROR,
) )
def get_supported_request_names(self) -> List[str]: def get_supported_request_names(self) -> List[str]:
"""Get all supported request names (not capability names).""" """Get all supported request names (not capability names)."""
request_names = [] request_names = []
@@ -473,26 +516,28 @@ class AlpineBitsServer:
if action_enum: if action_enum:
request_names.append(action_enum.request_name) request_names.append(action_enum.request_name)
return sorted(request_names) return sorted(request_names)
def is_action_supported(self, request_action_name: str, version: str = None) -> bool: def is_action_supported(
self, request_action_name: str, version: str = None
) -> bool:
""" """
Check if a request action is supported. Check if a request action is supported.
Args: Args:
request_action_name: The request action name (e.g., "OTA_Read:GuestRequests") request_action_name: The request action name (e.g., "OTA_Read:GuestRequests")
version: Optional version to check version: Optional version to check
Returns: Returns:
True if supported, False otherwise True if supported, False otherwise
""" """
action_enum = AlpineBitsActionName.get_by_request_name(request_action_name) action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
if not action_enum: if not action_enum:
return False return False
capability_name = action_enum.capability_name capability_name = action_enum.capability_name
if capability_name not in self._action_instances: if capability_name not in self._action_instances:
return False return False
if version: if version:
try: try:
version_enum = Version(version) version_enum = Version(version)
@@ -504,7 +549,7 @@ class AlpineBitsServer:
return action_instance.version == version_enum return action_instance.version == version_enum
except ValueError: except ValueError:
return False return False
return True return True
@@ -512,10 +557,10 @@ async def main():
"""Demonstrate the automatic capabilities discovery and request handling.""" """Demonstrate the automatic capabilities discovery and request handling."""
print("🚀 AlpineBits Server Capabilities Discovery & Request Handling Demo") print("🚀 AlpineBits Server Capabilities Discovery & Request Handling Demo")
print("=" * 70) print("=" * 70)
# Create server instance # Create server instance
server = AlpineBitsServer() server = AlpineBitsServer()
print("\n📋 Discovered Action Classes:") print("\n📋 Discovered Action Classes:")
print("-" * 30) print("-" * 30)
for capability_name, action_class in server.capabilities.action_registry.items(): for capability_name, action_class in server.capabilities.action_registry.items():
@@ -523,24 +568,26 @@ async def main():
request_name = action_enum.request_name if action_enum else "unknown" request_name = action_enum.request_name if action_enum else "unknown"
print(f"{capability_name} -> {action_class.__name__}") print(f"{capability_name} -> {action_class.__name__}")
print(f" Request name: {request_name}") print(f" Request name: {request_name}")
print(f"\n📊 Total Implemented Actions: {len(server.capabilities.get_supported_actions())}") print(
f"\n📊 Total Implemented Actions: {len(server.capabilities.get_supported_actions())}"
)
print("\n🔍 Generated Capabilities JSON:") print("\n🔍 Generated Capabilities JSON:")
print("-" * 30) print("-" * 30)
capabilities_json = server.get_capabilities_json() capabilities_json = server.get_capabilities_json()
print(capabilities_json) print(capabilities_json)
print("\n🎯 Supported Request Names:") print("\n🎯 Supported Request Names:")
print("-" * 30) print("-" * 30)
for request_name in server.get_supported_request_names(): for request_name in server.get_supported_request_names():
print(f"{request_name}") print(f"{request_name}")
print("\n🧪 Testing Request Handling:") print("\n🧪 Testing Request Handling:")
print("-" * 30) print("-" * 30)
test_xml = "<test>sample request</test>" test_xml = "<test>sample request</test>"
# Test different request formats # Test different request formats
test_cases = [ test_cases = [
("OTA_Ping:Handshaking", "2024-10"), ("OTA_Ping:Handshaking", "2024-10"),
@@ -548,16 +595,16 @@ async def main():
("OTA_Read:GuestRequests", "2022-10"), ("OTA_Read:GuestRequests", "2022-10"),
("OTA_HotelAvailNotif", "2024-10"), ("OTA_HotelAvailNotif", "2024-10"),
("UnknownAction", "2024-10"), ("UnknownAction", "2024-10"),
("OTA_Ping:Handshaking", "unsupported-version") ("OTA_Ping:Handshaking", "unsupported-version"),
] ]
for request_name, version in test_cases: for request_name, version in test_cases:
print(f"\n<EFBFBD> Testing: {request_name} (v{version})") print(f"\n<EFBFBD> Testing: {request_name} (v{version})")
# Check if supported first # Check if supported first
is_supported = server.is_action_supported(request_name, version) is_supported = server.is_action_supported(request_name, version)
print(f" Supported: {is_supported}") print(f" Supported: {is_supported}")
# Handle the request # Handle the request
response = await server.handle_request(request_name, test_xml, version) response = await server.handle_request(request_name, test_xml, version)
print(f" Status: {response.status_code}") print(f" Status: {response.status_code}")
@@ -565,9 +612,9 @@ async def main():
print(f" Response: {response.xml_content[:100]}...") print(f" Response: {response.xml_content[:100]}...")
else: else:
print(f" Response: {response.xml_content}") print(f" Response: {response.xml_content}")
print("\n✅ Demo completed successfully!") print("\n✅ Demo completed successfully!")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,18 +1,29 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, APIRouter, Form, File, UploadFile from fastapi import (
FastAPI,
HTTPException,
BackgroundTasks,
Request,
Depends,
APIRouter,
Form,
File,
UploadFile,
)
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic
from .config_loader import load_config from .config_loader import load_config
from fastapi.responses import HTMLResponse, PlainTextResponse, Response from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from .models import WixFormSubmission from .models import WixFormSubmission
from datetime import datetime, date, timezone from datetime import datetime, date, timezone
from .auth import validate_api_key, validate_wix_signature, generate_api_key from .auth import validate_api_key, validate_wix_signature, generate_api_key
from .rate_limit import ( from .rate_limit import (
limiter, limiter,
webhook_limiter, webhook_limiter,
custom_rate_limit_handler, custom_rate_limit_handler,
DEFAULT_RATE_LIMIT, DEFAULT_RATE_LIMIT,
WEBHOOK_RATE_LIMIT, WEBHOOK_RATE_LIMIT,
BURST_RATE_LIMIT BURST_RATE_LIMIT,
) )
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
import logging import logging
@@ -24,8 +35,14 @@ import gzip
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from .alpinebits_server import AlpineBitsServer, Version from .alpinebits_server import AlpineBitsServer, Version
import urllib.parse import urllib.parse
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .db import get_async_session, Customer as DBCustomer, Reservation as DBReservation from .db import (
Base,
Customer as DBCustomer,
Reservation as DBReservation,
get_database_url,
)
# Configure logging # Configure logging
@@ -42,12 +59,36 @@ except Exception as e:
_LOGGER.error(f"Failed to load config: {str(e)}") _LOGGER.error(f"Failed to load config: {str(e)}")
config = {} config = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# Setup DB
DATABASE_URL = get_database_url(config)
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
app.state.engine = engine
app.state.async_sessionmaker = AsyncSessionLocal
# Create tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
_LOGGER.info("Database tables checked/created at startup.")
yield
# Optional: Dispose engine on shutdown
await engine.dispose()
async def get_async_session(request: Request):
async_sessionmaker = request.app.state.async_sessionmaker
async with async_sessionmaker() as session:
yield session
app = FastAPI( app = FastAPI(
title="Wix Form Handler API", title="Wix Form Handler API",
description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting", description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting",
version="1.0.0" version="1.0.0",
lifespan=lifespan
) )
# Create API router with /api prefix # Create API router with /api prefix
@@ -62,9 +103,9 @@ app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=[ allow_origins=[
"https://*.wix.com", "https://*.wix.com",
"https://*.wixstatic.com", "https://*.wixstatic.com",
"http://localhost:3000", # For development "http://localhost:3000", # For development
"http://localhost:8000" # For local testing "http://localhost:8000", # For local testing
], ],
allow_credentials=True, allow_credentials=True,
allow_methods=["GET", "POST"], allow_methods=["GET", "POST"],
@@ -78,27 +119,39 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None:
Add your business logic here. Add your business logic here.
""" """
try: try:
_LOGGER.info(f"Processing form submission: {submission_data.get('submissionId')}") _LOGGER.info(
f"Processing form submission: {submission_data.get('submissionId')}"
)
# Example processing - you can replace this with your actual logic # Example processing - you can replace this with your actual logic
form_name = submission_data.get('formName') form_name = submission_data.get("formName")
contact_email = submission_data.get('contact', {}).get('email') if submission_data.get('contact') else None contact_email = (
submission_data.get("contact", {}).get("email")
if submission_data.get("contact")
else None
)
# Extract form fields # Extract form fields
form_fields = {k: v for k, v in submission_data.items() if k.startswith('field:')} form_fields = {
k: v for k, v in submission_data.items() if k.startswith("field:")
_LOGGER.info(f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}") }
_LOGGER.info(
f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}"
)
# Here you could: # Here you could:
# - Save to database # - Save to database
# - Send emails # - Send emails
# - Call external APIs # - Call external APIs
# - Process the data further # - Process the data further
except Exception as e: except Exception as e:
_LOGGER.error(f"Error processing form submission: {str(e)}") _LOGGER.error(f"Error processing form submission: {str(e)}")
@api_router.get("/") @api_router.get("/")
@limiter.limit(DEFAULT_RATE_LIMIT) @limiter.limit(DEFAULT_RATE_LIMIT)
async def root(request: Request): async def root(request: Request):
@@ -111,8 +164,8 @@ async def root(request: Request):
"rate_limits": { "rate_limits": {
"default": DEFAULT_RATE_LIMIT, "default": DEFAULT_RATE_LIMIT,
"webhook": WEBHOOK_RATE_LIMIT, "webhook": WEBHOOK_RATE_LIMIT,
"burst": BURST_RATE_LIMIT "burst": BURST_RATE_LIMIT,
} },
} }
@@ -126,11 +179,10 @@ async def health_check(request: Request):
"service": "wix-form-handler", "service": "wix-form-handler",
"version": "1.0.0", "version": "1.0.0",
"authentication": "enabled", "authentication": "enabled",
"rate_limiting": "enabled" "rate_limiting": "enabled",
} }
# Extracted business logic for handling Wix form submissions # Extracted business logic for handling Wix form submissions
async def process_wix_form_submission(request: Request, data: Dict[str, Any], db): async def process_wix_form_submission(request: Request, data: Dict[str, Any], db):
""" """
@@ -138,10 +190,9 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
""" """
timestamp = datetime.now().isoformat() timestamp = datetime.now().isoformat()
_LOGGER.info(f"Received Wix form data at {timestamp}") _LOGGER.info(f"Received Wix form data at {timestamp}")
#_LOGGER.info(f"Data keys: {list(data.keys())}") # _LOGGER.info(f"Data keys: {list(data.keys())}")
#_LOGGER.info(f"Full data: {json.dumps(data, indent=2)}") # _LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
log_entry = { log_entry = {
"timestamp": timestamp, "timestamp": timestamp,
"client_ip": request.client.host if request.client else "unknown", "client_ip": request.client.host if request.client else "unknown",
@@ -154,9 +205,13 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
if not os.path.exists(logs_dir): if not os.path.exists(logs_dir):
os.makedirs(logs_dir, mode=0o755, exist_ok=True) os.makedirs(logs_dir, mode=0o755, exist_ok=True)
stat_info = os.stat(logs_dir) stat_info = os.stat(logs_dir)
_LOGGER.info(f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}") _LOGGER.info(
f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}"
)
_LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}") _LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}")
log_filename = f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" log_filename = (
f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
)
with open(log_filename, "w", encoding="utf-8") as f: with open(log_filename, "w", encoding="utf-8") as f:
json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False) json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False)
file_stat = os.stat(log_filename) file_stat = os.stat(log_filename)
@@ -164,16 +219,10 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
_LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}") _LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}")
_LOGGER.info(f"Data logged to: {log_filename}") _LOGGER.info(f"Data logged to: {log_filename}")
data = data.get("data") # Handle nested "data" key if present data = data.get("data") # Handle nested "data" key if present
# save customer and reservation to DB # save customer and reservation to DB
contact_info = data.get("contact", {}) contact_info = data.get("contact", {})
first_name = contact_info.get("name", {}).get("first") first_name = contact_info.get("name", {}).get("first")
last_name = contact_info.get("name", {}).get("last") last_name = contact_info.get("name", {}).get("last")
@@ -193,10 +242,18 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
language = data.get("contact", {}).get("locale", "en")[:2] language = data.get("contact", {}).get("locale", "en")[:2]
# Dates # Dates
start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value") start_date = (
end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value") data.get("field:date_picker_a7c8")
or data.get("Anreisedatum")
or data.get("submissions", [{}])[1].get("value")
)
end_date = (
data.get("field:date_picker_7e65")
or data.get("Abreisedatum")
or data.get("submissions", [{}])[2].get("value")
)
# Room/guest info # Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2) num_adults = int(data.get("field:number_7cf5") or 2)
num_children = int(data.get("field:anzahl_kinder") or 0) num_children = int(data.get("field:anzahl_kinder") or 0)
children_ages = [] children_ages = []
@@ -258,7 +315,7 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
end_date=date.fromisoformat(end_date) if end_date else None, end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults, num_adults=num_adults,
num_children=num_children, num_children=num_children,
children_ages=','.join(str(a) for a in children_ages), children_ages=",".join(str(a) for a in children_ages),
offer=offer, offer=offer,
utm_comment=utm_comment, utm_comment=utm_comment,
created_at=datetime.now(timezone.utc), created_at=datetime.now(timezone.utc),
@@ -277,23 +334,21 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
await db.commit() await db.commit()
await db.refresh(db_reservation) await db.refresh(db_reservation)
return { return {
"status": "success", "status": "success",
"message": "Wix form data received successfully", "message": "Wix form data received successfully",
"received_keys": list(data.keys()), "received_keys": list(data.keys()),
"data_logged_to": log_filename, "data_logged_to": log_filename,
"timestamp": timestamp, "timestamp": timestamp,
"process_info": log_entry["process_info"], "note": "No authentication required for this endpoint",
"note": "No authentication required for this endpoint"
} }
@api_router.post("/webhook/wix-form") @api_router.post("/webhook/wix-form")
@webhook_limiter.limit(WEBHOOK_RATE_LIMIT) @webhook_limiter.limit(WEBHOOK_RATE_LIMIT)
async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)): async def handle_wix_form(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
""" """
Unified endpoint to handle Wix form submissions (test and production). Unified endpoint to handle Wix form submissions (test and production).
No authentication required for this endpoint. No authentication required for this endpoint.
@@ -304,16 +359,19 @@ async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Dep
_LOGGER.error(f"Error in handle_wix_form: {str(e)}") _LOGGER.error(f"Error in handle_wix_form: {str(e)}")
# log stacktrace # log stacktrace
import traceback import traceback
traceback_str = traceback.format_exc() traceback_str = traceback.format_exc()
_LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}") _LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Error processing Wix form data: {str(e)}"
detail=f"Error processing Wix form data: {str(e)}"
) )
@api_router.post("/webhook/wix-form/test") @api_router.post("/webhook/wix-form/test")
@limiter.limit(DEFAULT_RATE_LIMIT) @limiter.limit(DEFAULT_RATE_LIMIT)
async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session=Depends(get_async_session)): async def handle_wix_form_test(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
""" """
Test endpoint to verify the API is working with raw JSON data. Test endpoint to verify the API is working with raw JSON data.
No authentication required for testing purposes. No authentication required for testing purposes.
@@ -323,40 +381,37 @@ async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session
except Exception as e: except Exception as e:
_LOGGER.error(f"Error in handle_wix_form_test: {str(e)}") _LOGGER.error(f"Error in handle_wix_form_test: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Error processing test data: {str(e)}"
detail=f"Error processing test data: {str(e)}"
) )
@api_router.post("/admin/generate-api-key") @api_router.post("/admin/generate-api-key")
@limiter.limit("5/hour") # Very restrictive for admin operations @limiter.limit("5/hour") # Very restrictive for admin operations
async def generate_new_api_key( async def generate_new_api_key(
request: Request, request: Request, admin_key: str = Depends(validate_api_key)
admin_key: str = Depends(validate_api_key)
): ):
""" """
Admin endpoint to generate new API keys. Admin endpoint to generate new API keys.
Requires admin API key and is heavily rate limited. Requires admin API key and is heavily rate limited.
""" """
if admin_key != "admin-key": if admin_key != "admin-key":
raise HTTPException( raise HTTPException(status_code=403, detail="Admin access required")
status_code=403,
detail="Admin access required"
)
new_key = generate_api_key() new_key = generate_api_key()
_LOGGER.info(f"Generated new API key (requested by: {admin_key})") _LOGGER.info(f"Generated new API key (requested by: {admin_key})")
return { return {
"status": "success", "status": "success",
"message": "New API key generated", "message": "New API key generated",
"api_key": new_key, "api_key": new_key,
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"note": "Store this key securely - it won't be shown again" "note": "Store this key securely - it won't be shown again",
} }
async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(security_basic)) -> str: async def validate_basic_auth(
credentials: HTTPBasicCredentials = Depends(security_basic),
) -> str:
""" """
Validate basic authentication for AlpineBits protocol. Validate basic authentication for AlpineBits protocol.
Returns username if valid, raises HTTPException if not. Returns username if valid, raises HTTPException if not.
@@ -369,8 +424,11 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
headers={"WWW-Authenticate": "Basic"}, headers={"WWW-Authenticate": "Basic"},
) )
valid = False valid = False
for entry in config['alpine_bits_auth']: for entry in config["alpine_bits_auth"]:
if credentials.username == entry['username'] and credentials.password == entry['password']: if (
credentials.username == entry["username"]
and credentials.password == entry["password"]
):
valid = True valid = True
break break
if not valid: if not valid:
@@ -379,7 +437,9 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
detail="ERROR: Invalid credentials", detail="ERROR: Invalid credentials",
headers={"WWW-Authenticate": "Basic"}, headers={"WWW-Authenticate": "Basic"},
) )
_LOGGER.info(f"AlpineBits authentication successful for user: {credentials.username} (from config)") _LOGGER.info(
f"AlpineBits authentication successful for user: {credentials.username} (from config)"
)
return credentials.username return credentials.username
@@ -390,10 +450,9 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
""" """
if "multipart/form-data" not in content_type: if "multipart/form-data" not in content_type:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="ERROR: Content-Type must be multipart/form-data"
detail="ERROR: Content-Type must be multipart/form-data"
) )
# Extract boundary # Extract boundary
boundary = None boundary = None
for part in content_type.split(";"): for part in content_type.split(";"):
@@ -401,62 +460,56 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
if part.startswith("boundary="): if part.startswith("boundary="):
boundary = part.split("=", 1)[1].strip('"') boundary = part.split("=", 1)[1].strip('"')
break break
if not boundary: if not boundary:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="ERROR: Missing boundary in multipart/form-data"
detail="ERROR: Missing boundary in multipart/form-data"
) )
# Simple multipart parsing # Simple multipart parsing
parts = body.split(f"--{boundary}".encode()) parts = body.split(f"--{boundary}".encode())
data = {} data = {}
for part in parts: for part in parts:
if not part.strip() or part.strip() == b"--": if not part.strip() or part.strip() == b"--":
continue continue
# Split headers and content # Split headers and content
if b"\r\n\r\n" in part: if b"\r\n\r\n" in part:
headers_section, content = part.split(b"\r\n\r\n", 1) headers_section, content = part.split(b"\r\n\r\n", 1)
content = content.rstrip(b"\r\n") content = content.rstrip(b"\r\n")
# Parse Content-Disposition header # Parse Content-Disposition header
headers = headers_section.decode('utf-8', errors='ignore') headers = headers_section.decode("utf-8", errors="ignore")
name = None name = None
for line in headers.split('\n'): for line in headers.split("\n"):
if 'Content-Disposition' in line and 'name=' in line: if "Content-Disposition" in line and "name=" in line:
# Extract name parameter # Extract name parameter
for param in line.split(';'): for param in line.split(";"):
param = param.strip() param = param.strip()
if param.startswith('name='): if param.startswith("name="):
name = param.split('=', 1)[1].strip('"') name = param.split("=", 1)[1].strip('"')
break break
if name: if name:
# Handle file uploads or text content # Handle file uploads or text content
if content.startswith(b'<'): if content.startswith(b"<"):
# Likely XML content # Likely XML content
data[name] = content.decode('utf-8', errors='ignore') data[name] = content.decode("utf-8", errors="ignore")
else: else:
data[name] = content.decode('utf-8', errors='ignore') data[name] = content.decode("utf-8", errors="ignore")
return data return data
@api_router.post("/alpinebits/server-2024-10") @api_router.post("/alpinebits/server-2024-10")
@limiter.limit("60/minute") @limiter.limit("60/minute")
async def alpinebits_server_handshake( async def alpinebits_server_handshake(
request: Request, request: Request, username: str = Depends(validate_basic_auth)
username: str = Depends(validate_basic_auth)
): ):
""" """
AlpineBits server endpoint implementing the handshake protocol. AlpineBits server endpoint implementing the handshake protocol.
This endpoint handles: This endpoint handles:
- Protocol version negotiation via X-AlpineBits-ClientProtocolVersion header - Protocol version negotiation via X-AlpineBits-ClientProtocolVersion header
- Client identification via X-AlpineBits-ClientID header (optional) - Client identification via X-AlpineBits-ClientID header (optional)
@@ -464,62 +517,67 @@ async def alpinebits_server_handshake(
- Gzip compression support - Gzip compression support
- Proper error handling with HTTP status codes - Proper error handling with HTTP status codes
- Handshaking action processing - Handshaking action processing
Authentication: HTTP Basic Auth required Authentication: HTTP Basic Auth required
Content-Type: multipart/form-data Content-Type: multipart/form-data
Compression: gzip supported (check X-AlpineBits-Server-Accept-Encoding) Compression: gzip supported (check X-AlpineBits-Server-Accept-Encoding)
""" """
try: try:
# Check required headers # Check required headers
client_protocol_version = request.headers.get("X-AlpineBits-ClientProtocolVersion") client_protocol_version = request.headers.get(
"X-AlpineBits-ClientProtocolVersion"
)
if not client_protocol_version: if not client_protocol_version:
# Server concludes client speaks a protocol version preceding 2013-04 # Server concludes client speaks a protocol version preceding 2013-04
client_protocol_version = "pre-2013-04" client_protocol_version = "pre-2013-04"
_LOGGER.info("No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04") _LOGGER.info(
"No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04"
)
else: else:
_LOGGER.info(f"Client protocol version: {client_protocol_version}") _LOGGER.info(f"Client protocol version: {client_protocol_version}")
# Optional client ID # Optional client ID
client_id = request.headers.get("X-AlpineBits-ClientID") client_id = request.headers.get("X-AlpineBits-ClientID")
if client_id: if client_id:
_LOGGER.info(f"Client ID: {client_id}") _LOGGER.info(f"Client ID: {client_id}")
# Check content encoding # Check content encoding
content_encoding = request.headers.get("Content-Encoding") content_encoding = request.headers.get("Content-Encoding")
is_compressed = content_encoding == "gzip" is_compressed = content_encoding == "gzip"
if is_compressed: if is_compressed:
_LOGGER.info("Request is gzip compressed") _LOGGER.info("Request is gzip compressed")
# Get content type before processing # Get content type before processing
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
_LOGGER.info(f"Content-Type: {content_type}") _LOGGER.info(f"Content-Type: {content_type}")
_LOGGER.info(f"Content-Encoding: {content_encoding}") _LOGGER.info(f"Content-Encoding: {content_encoding}")
# Get request body # Get request body
body = await request.body() body = await request.body()
# Decompress if needed # Decompress if needed
if is_compressed: if is_compressed:
try: try:
body = gzip.decompress(body) body = gzip.decompress(body)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"ERROR: Failed to decompress gzip content: {str(e)}" detail=f"ERROR: Failed to decompress gzip content: {str(e)}",
) )
# Check content type (after decompression) # Check content type (after decompression)
if "multipart/form-data" not in content_type and "application/x-www-form-urlencoded" not in content_type: if (
"multipart/form-data" not in content_type
and "application/x-www-form-urlencoded" not in content_type
):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded" detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
) )
# Parse multipart data # Parse multipart data
if "multipart/form-data" in content_type: if "multipart/form-data" in content_type:
try: try:
@@ -527,7 +585,7 @@ async def alpinebits_server_handshake(
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}" detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}",
) )
elif "application/x-www-form-urlencoded" in content_type: elif "application/x-www-form-urlencoded" in content_type:
# Parse as urlencoded # Parse as urlencoded
@@ -535,75 +593,59 @@ async def alpinebits_server_handshake(
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded" detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
) )
# Check for required action parameter # Check for required action parameter
action = form_data.get("action") action = form_data.get("action")
if not action: if not action:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="ERROR: Missing required 'action' parameter"
detail="ERROR: Missing required 'action' parameter") )
_LOGGER.info(f"AlpineBits action: {action}") _LOGGER.info(f"AlpineBits action: {action}")
# Get optional request XML # Get optional request XML
request_xml = form_data.get("request") request_xml = form_data.get("request")
server = AlpineBitsServer() server = AlpineBitsServer()
version = Version.V2024_10 version = Version.V2024_10
# Create successful handshake response # Create successful handshake response
response = await server.handle_request(action, request_xml, version) response = await server.handle_request(action, request_xml, version)
response_xml = response.xml_content response_xml = response.xml_content
# Set response headers indicating server capabilities # Set response headers indicating server capabilities
headers = { headers = {
"Content-Type": "application/xml; charset=utf-8", "Content-Type": "application/xml; charset=utf-8",
"X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support "X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support
"X-AlpineBits-Server-Version": "2024-10" "X-AlpineBits-Server-Version": "2024-10",
} }
return Response(
content=response_xml,
status_code=response.status_code,
headers=headers
)
return Response(
content=response_xml, status_code=response.status_code, headers=headers
)
except HTTPException: except HTTPException:
# Re-raise HTTP exceptions (auth errors, etc.) # Re-raise HTTP exceptions (auth errors, etc.)
raise raise
except Exception as e: except Exception as e:
_LOGGER.error(f"Error in AlpineBits handshake: {str(e)}") _LOGGER.error(f"Error in AlpineBits handshake: {str(e)}")
raise HTTPException( raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
status_code=500,
detail=f"Internal server error: {str(e)}"
)
@api_router.get("/admin/stats") @api_router.get("/admin/stats")
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def get_api_stats( async def get_api_stats(request: Request, admin_key: str = Depends(validate_api_key)):
request: Request,
admin_key: str = Depends(validate_api_key)
):
""" """
Admin endpoint to get API usage statistics. Admin endpoint to get API usage statistics.
Requires admin API key. Requires admin API key.
""" """
if admin_key != "admin-key": if admin_key != "admin-key":
raise HTTPException( raise HTTPException(status_code=403, detail="Admin access required")
status_code=403,
detail="Admin access required"
)
# In a real application, you'd fetch this from your database/monitoring system # In a real application, you'd fetch this from your database/monitoring system
return { return {
"status": "success", "status": "success",
@@ -611,9 +653,9 @@ async def get_api_stats(
"uptime": "Available in production deployment", "uptime": "Available in production deployment",
"total_requests": "Available with monitoring setup", "total_requests": "Available with monitoring setup",
"active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]), "active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]),
"rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory" "rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory",
}, },
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat(),
} }
@@ -629,11 +671,12 @@ async def landing_page():
try: try:
# Get the path to the HTML file # Get the path to the HTML file
import os import os
html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html") html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html")
with open(html_path, "r", encoding="utf-8") as f: with open(html_path, "r", encoding="utf-8") as f:
html_content = f.read() html_content = f.read()
return HTMLResponse(content=html_content, status_code=200) return HTMLResponse(content=html_content, status_code=200)
except FileNotFoundError: except FileNotFoundError:
# Fallback if HTML file is not found # Fallback if HTML file is not found
@@ -660,4 +703,5 @@ async def landing_page():
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -21,7 +21,7 @@ security = HTTPBearer()
API_KEYS = { API_KEYS = {
# Example API keys - replace with your own secure keys # Example API keys - replace with your own secure keys
"wix-webhook-key": "sk_live_your_secure_api_key_here", "wix-webhook-key": "sk_live_your_secure_api_key_here",
"admin-key": "sk_admin_your_admin_key_here" "admin-key": "sk_admin_your_admin_key_here",
} }
# Load API keys from environment if available # Load API keys from environment if available
@@ -36,19 +36,21 @@ def generate_api_key() -> str:
return f"sk_live_{secrets.token_urlsafe(32)}" return f"sk_live_{secrets.token_urlsafe(32)}"
def validate_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str: def validate_api_key(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> str:
""" """
Validate API key from Authorization header. Validate API key from Authorization header.
Expected format: Authorization: Bearer your_api_key_here Expected format: Authorization: Bearer your_api_key_here
""" """
token = credentials.credentials token = credentials.credentials
# Check if the token is in our valid API keys # Check if the token is in our valid API keys
for key_name, valid_key in API_KEYS.items(): for key_name, valid_key in API_KEYS.items():
if secrets.compare_digest(token, valid_key): if secrets.compare_digest(token, valid_key):
logger.info(f"Valid API key used: {key_name}") logger.info(f"Valid API key used: {key_name}")
return key_name return key_name
logger.warning(f"Invalid API key attempted: {token[:10]}...") logger.warning(f"Invalid API key attempted: {token[:10]}...")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -64,19 +66,17 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool:
""" """
if not signature or not secret: if not signature or not secret:
return False return False
try: try:
# Remove 'sha256=' prefix if present # Remove 'sha256=' prefix if present
if signature.startswith('sha256='): if signature.startswith("sha256="):
signature = signature[7:] signature = signature[7:]
# Calculate expected signature # Calculate expected signature
expected_signature = hmac.new( expected_signature = hmac.new(
secret.encode('utf-8'), secret.encode("utf-8"), payload, hashlib.sha256
payload,
hashlib.sha256
).hexdigest() ).hexdigest()
# Compare signatures securely # Compare signatures securely
return secrets.compare_digest(signature, expected_signature) return secrets.compare_digest(signature, expected_signature)
except Exception as e: except Exception as e:
@@ -86,21 +86,21 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool:
class APIKeyAuth: class APIKeyAuth:
"""Simple API key authentication class""" """Simple API key authentication class"""
def __init__(self, api_keys: dict): def __init__(self, api_keys: dict):
self.api_keys = api_keys self.api_keys = api_keys
def authenticate(self, api_key: str) -> Optional[str]: def authenticate(self, api_key: str) -> Optional[str]:
"""Authenticate an API key and return the key name if valid""" """Authenticate an API key and return the key name if valid"""
for key_name, valid_key in self.api_keys.items(): for key_name, valid_key in self.api_keys.items():
if secrets.compare_digest(api_key, valid_key): if secrets.compare_digest(api_key, valid_key):
return key_name return key_name
return None return None
def add_key(self, name: str, key: str): def add_key(self, name: str, key: str):
"""Add a new API key""" """Add a new API key"""
self.api_keys[name] = key self.api_keys[name] = key
def remove_key(self, name: str): def remove_key(self, name: str):
"""Remove an API key""" """Remove an API key"""
if name in self.api_keys: if name in self.api_keys:
@@ -108,4 +108,4 @@ class APIKeyAuth:
# Initialize auth system # Initialize auth system
auth_system = APIKeyAuth(API_KEYS) auth_system = APIKeyAuth(API_KEYS)

View File

@@ -1,4 +1,3 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@@ -16,37 +15,45 @@ from annotatedyaml.loader import (
from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid
# --- Voluptuous schemas --- # --- Voluptuous schemas ---
database_schema = Schema({ database_schema = Schema({Required("url"): str}, extra=PREVENT_EXTRA)
Required('url'): str
}, extra=PREVENT_EXTRA)
hotel_auth_schema = Schema(
hotel_auth_schema = Schema({ {
Required("hotel_id"): str, Required("hotel_id"): str,
Required("hotel_name"): str, Required("hotel_name"): str,
Required("username"): str, Required("username"): str,
Required("password"): str Required("password"): str,
}, extra=PREVENT_EXTRA) },
extra=PREVENT_EXTRA,
basic_auth_schema = Schema(
All([hotel_auth_schema], Length(min=1))
) )
config_schema = Schema({ basic_auth_schema = Schema(All([hotel_auth_schema], Length(min=1)))
Required('database'): database_schema,
Required('alpine_bits_auth'): basic_auth_schema
}, extra=PREVENT_EXTRA)
DEFAULT_CONFIG_FILE = 'config.yaml' config_schema = Schema(
{
Required("database"): database_schema,
Required("alpine_bits_auth"): basic_auth_schema,
},
extra=PREVENT_EXTRA,
)
DEFAULT_CONFIG_FILE = "config.yaml"
class Config: class Config:
def __init__(self, config_folder: str | Path = None, config_name: str = DEFAULT_CONFIG_FILE, testing_mode: bool = False): def __init__(
self,
config_folder: str | Path = None,
config_name: str = DEFAULT_CONFIG_FILE,
testing_mode: bool = False,
):
if config_folder is None: if config_folder is None:
config_folder = os.environ.get('ALPINEBITS_CONFIG_DIR') config_folder = os.environ.get("ALPINEBITS_CONFIG_DIR")
if not config_folder: if not config_folder:
config_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../config')) config_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../config")
)
if isinstance(config_folder, str): if isinstance(config_folder, str):
config_folder = Path(config_folder) config_folder = Path(config_folder)
self.config_folder = config_folder self.config_folder = config_folder
@@ -61,8 +68,8 @@ class Config:
validated = config_schema(stuff) validated = config_schema(stuff)
except MultipleInvalid as e: except MultipleInvalid as e:
raise ValueError(f"Config validation error: {e}") raise ValueError(f"Config validation error: {e}")
self.database = validated['database'] self.database = validated["database"]
self.basic_auth = validated['alpine_bits_auth'] self.basic_auth = validated["alpine_bits_auth"]
self.config = validated self.config = validated
def get(self, key, default=None): def get(self, key, default=None):
@@ -70,19 +77,20 @@ class Config:
@property @property
def db_url(self) -> str: def db_url(self) -> str:
return self.database['url'] return self.database["url"]
@property @property
def hotel_id(self) -> str: def hotel_id(self) -> str:
return self.basic_auth['hotel_id'] return self.basic_auth["hotel_id"]
@property @property
def hotel_name(self) -> str: def hotel_name(self) -> str:
return self.basic_auth['hotel_name'] return self.basic_auth["hotel_name"]
@property @property
def users(self) -> List[Dict[str, str]]: def users(self) -> List[Dict[str, str]]:
return self.basic_auth['users'] return self.basic_auth["users"]
# For backward compatibility # For backward compatibility
def load_config(): def load_config():

View File

@@ -5,27 +5,24 @@ import os
Base = declarative_base() Base = declarative_base()
# Async SQLAlchemy setup # Async SQLAlchemy setup
def get_database_url(config=None): def get_database_url(config=None):
db_url = None db_url = None
if config and 'database' in config and 'url' in config['database']: if config and "database" in config and "url" in config["database"]:
db_url = config['database']['url'] db_url = config["database"]["url"]
if not db_url: if not db_url:
db_url = os.environ.get('DATABASE_URL') db_url = os.environ.get("DATABASE_URL")
if not db_url: if not db_url:
db_url = 'sqlite+aiosqlite:///alpinebits.db' db_url = "sqlite+aiosqlite:///alpinebits.db"
return db_url return db_url
DATABASE_URL = get_database_url()
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
async def get_async_session():
async with AsyncSessionLocal() as session:
yield session
class Customer(Base): class Customer(Base):
__tablename__ = 'customers' __tablename__ = "customers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
given_name = Column(String) given_name = Column(String)
contact_id = Column(String, unique=True) contact_id = Column(String, unique=True)
@@ -42,13 +39,14 @@ class Customer(Base):
birth_date = Column(String) birth_date = Column(String)
language = Column(String) language = Column(String)
address_catalog = Column(Boolean) # Added for XML address_catalog = Column(Boolean) # Added for XML
name_title = Column(String) # Added for XML name_title = Column(String) # Added for XML
reservations = relationship('Reservation', back_populates='customer') reservations = relationship("Reservation", back_populates="customer")
class Reservation(Base): class Reservation(Base):
__tablename__ = 'reservations' __tablename__ = "reservations"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
customer_id = Column(Integer, ForeignKey('customers.id')) customer_id = Column(Integer, ForeignKey("customers.id"))
form_id = Column(String, unique=True) form_id = Column(String, unique=True)
start_date = Column(Date) start_date = Column(Date)
end_date = Column(Date) end_date = Column(Date)
@@ -70,16 +68,14 @@ class Reservation(Base):
# Add hotel_code and hotel_name for XML # Add hotel_code and hotel_name for XML
hotel_code = Column(String) hotel_code = Column(String)
hotel_name = Column(String) hotel_name = Column(String)
customer = relationship('Customer', back_populates='reservations') customer = relationship("Customer", back_populates="reservations")
class HashedCustomer(Base): class HashedCustomer(Base):
__tablename__ = 'hashed_customers' __tablename__ = "hashed_customers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
customer_id = Column(Integer) customer_id = Column(Integer)
hashed_email = Column(String) hashed_email = Column(String)
hashed_phone = Column(String) hashed_phone = Column(String)
hashed_name = Column(String) hashed_name = Column(String)
redacted_at = Column(DateTime) redacted_at = Column(DateTime)

View File

@@ -15,11 +15,16 @@ from .simplified_access import (
HotelReservationIdData, HotelReservationIdData,
PhoneTechType, PhoneTechType,
AlpineBitsFactory, AlpineBitsFactory,
OtaMessageType OtaMessageType,
) )
# DB and config # DB and config
from .db import Customer as DBCustomer, Reservation as DBReservation, HashedCustomer, get_async_session from .db import (
Customer as DBCustomer,
Reservation as DBReservation,
HashedCustomer,
get_async_session,
)
from .config_loader import load_config from .config_loader import load_config
import hashlib import hashlib
import json import json
@@ -29,8 +34,8 @@ import asyncio
from alpine_bits_python import db from alpine_bits_python import db
async def main():
async def main():
print("🚀 Starting AlpineBits XML generation script...") print("🚀 Starting AlpineBits XML generation script...")
# Load config (yaml, annotatedyaml) # Load config (yaml, annotatedyaml)
config = load_config() config = load_config()
@@ -40,9 +45,9 @@ async def main():
print(json.dumps(config, indent=2)) print(json.dumps(config, indent=2))
# Ensure SQLite DB file exists if using SQLite # Ensure SQLite DB file exists if using SQLite
db_url = config.get('database', {}).get('url', '') db_url = config.get("database", {}).get("url", "")
if db_url.startswith('sqlite+aiosqlite:///'): if db_url.startswith("sqlite+aiosqlite:///"):
db_path = db_url.replace('sqlite+aiosqlite:///', '') db_path = db_url.replace("sqlite+aiosqlite:///", "")
db_path = os.path.abspath(db_path) db_path = os.path.abspath(db_path)
db_dir = os.path.dirname(db_path) db_dir = os.path.dirname(db_path)
if not os.path.exists(db_dir): if not os.path.exists(db_dir):
@@ -54,15 +59,17 @@ async def main():
# # Ensure DB schema is created (async) # # Ensure DB schema is created (async)
from .db import engine, Base from .db import engine, Base
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
async for db in get_async_session(): async for db in get_async_session():
# Load data from JSON file # Load data from JSON file
json_path = os.path.join(os.path.dirname(__file__), '../../test_data/wix_test_data_20250928_132611.json') json_path = os.path.join(
with open(json_path, 'r', encoding='utf-8') as f: os.path.dirname(__file__),
"../../test_data/wix_test_data_20250928_132611.json",
)
with open(json_path, "r", encoding="utf-8") as f:
wix_data = json.load(f) wix_data = json.load(f)
data = wix_data["data"]["data"] data = wix_data["data"]["data"]
@@ -85,8 +92,16 @@ async def main():
language = data.get("contact", {}).get("locale", "en")[:2] language = data.get("contact", {}).get("locale", "en")[:2]
# Dates # Dates
start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value") start_date = (
end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value") data.get("field:date_picker_a7c8")
or data.get("Anreisedatum")
or data.get("submissions", [{}])[1].get("value")
)
end_date = (
data.get("field:date_picker_7e65")
or data.get("Abreisedatum")
or data.get("submissions", [{}])[2].get("value")
)
# Room/guest info # Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2) num_adults = int(data.get("field:number_7cf5") or 2)
@@ -100,7 +115,7 @@ async def main():
children_ages.append(age) children_ages.append(age)
except ValueError: except ValueError:
logging.warning(f"Invalid age value for {k}: {data[k]}") logging.warning(f"Invalid age value for {k}: {data[k]}")
# UTM and offer # UTM and offer
utm_fields = [ utm_fields = [
("utm_Source", "utm_source"), ("utm_Source", "utm_source"),
@@ -147,7 +162,7 @@ async def main():
end_date=date.fromisoformat(end_date) if end_date else None, end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults, num_adults=num_adults,
num_children=num_children, num_children=num_children,
children_ages=','.join(str(a) for a in children_ages), children_ages=",".join(str(a) for a in children_ages),
offer=offer, offer=offer,
utm_comment=utm_comment, utm_comment=utm_comment,
created_at=datetime.now(timezone.utc), created_at=datetime.now(timezone.utc),
@@ -177,9 +192,19 @@ async def main():
def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
from .simplified_access import CustomerData, GuestCountsFactory, HotelReservationIdData, AlpineBitsFactory, OtaMessageType, CommentData, CommentsData, CommentListItemData from .simplified_access import (
CustomerData,
GuestCountsFactory,
HotelReservationIdData,
AlpineBitsFactory,
OtaMessageType,
CommentData,
CommentsData,
CommentListItemData,
)
from .generated import alpinebits as ab from .generated import alpinebits as ab
from datetime import datetime, timezone from datetime import datetime, timezone
# Prepare data for XML # Prepare data for XML
phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else [] phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else []
customer_data = CustomerData( customer_data = CustomerData(
@@ -200,11 +225,15 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
language=customer.language, language=customer.language,
) )
alpine_bits_factory = AlpineBitsFactory() alpine_bits_factory = AlpineBitsFactory()
res_guests = alpine_bits_factory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) res_guests = alpine_bits_factory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
# Guest counts # Guest counts
children_ages = [int(a) for a in reservation.children_ages.split(",") if a] children_ages = [int(a) for a in reservation.children_ages.split(",") if a]
guest_counts = GuestCountsFactory.create_retrieve_guest_counts(reservation.num_adults, children_ages) guest_counts = GuestCountsFactory.create_retrieve_guest_counts(
reservation.num_adults, children_ages
)
# UniqueID # UniqueID
unique_id = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId( unique_id = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId(
@@ -214,11 +243,13 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
# TimeSpan # TimeSpan
time_span = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan( time_span = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan(
start=reservation.start_date.isoformat() if reservation.start_date else None, start=reservation.start_date.isoformat() if reservation.start_date else None,
end=reservation.end_date.isoformat() if reservation.end_date else None end=reservation.end_date.isoformat() if reservation.end_date else None,
) )
room_stay = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay( room_stay = (
time_span=time_span, ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay(
guest_counts=guest_counts, time_span=time_span,
guest_counts=guest_counts,
)
) )
room_stays = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays( room_stays = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays(
room_stay=[room_stay], room_stay=[room_stay],
@@ -231,7 +262,9 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
res_id_source=None, res_id_source=None,
res_id_source_context="99tales", res_id_source_context="99tales",
) )
hotel_res_id = alpine_bits_factory.create(hotel_res_id_data, OtaMessageType.RETRIEVE) hotel_res_id = alpine_bits_factory.create(
hotel_res_id_data, OtaMessageType.RETRIEVE
)
hotel_res_ids = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds( hotel_res_ids = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds(
hotel_reservation_id=[hotel_res_id] hotel_reservation_id=[hotel_res_id]
) )
@@ -244,31 +277,37 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
offer_comment = CommentData( offer_comment = CommentData(
name=ab.CommentName2.ADDITIONAL_INFO, name=ab.CommentName2.ADDITIONAL_INFO,
text="Angebot/Offerta", text="Angebot/Offerta",
list_items=[CommentListItemData( list_items=[
value=reservation.offer, CommentListItemData(
language=customer.language, value=reservation.offer,
list_item="1", language=customer.language,
)], list_item="1",
)
],
) )
comment = None comment = None
if reservation.user_comment: if reservation.user_comment:
comment = CommentData( comment = CommentData(
name=ab.CommentName2.CUSTOMER_COMMENT, name=ab.CommentName2.CUSTOMER_COMMENT,
text=reservation.user_comment, text=reservation.user_comment,
list_items=[CommentListItemData( list_items=[
value="Landing page comment", CommentListItemData(
language=customer.language, value="Landing page comment",
list_item="1", language=customer.language,
)], list_item="1",
)
],
) )
comments = [offer_comment, comment] if comment else [offer_comment] comments = [offer_comment, comment] if comment else [offer_comment]
comments_data = CommentsData(comments=comments) comments_data = CommentsData(comments=comments)
comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE) comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE)
res_global_info = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo( res_global_info = (
hotel_reservation_ids=hotel_res_ids, ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo(
basic_property_info=basic_property_info, hotel_reservation_ids=hotel_res_ids,
comments=comments_xml, basic_property_info=basic_property_info,
comments=comments_xml,
)
) )
hotel_reservation = ab.OtaResRetrieveRs.ReservationsList.HotelReservation( hotel_reservation = ab.OtaResRetrieveRs.ReservationsList.HotelReservation(
@@ -293,6 +332,7 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("✅ Pydantic validation successful!") print("✅ Pydantic validation successful!")
from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.formats.dataclass.serializers.config import SerializerConfig
from xsdata_pydantic.bindings import XmlSerializer from xsdata_pydantic.bindings import XmlSerializer
config = SerializerConfig( config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8" pretty_print=True, xml_declaration=True, encoding="UTF-8"
) )
@@ -306,15 +346,18 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("\n📄 Generated XML:") print("\n📄 Generated XML:")
print(xml_string) print(xml_string)
from xsdata_pydantic.bindings import XmlParser from xsdata_pydantic.bindings import XmlParser
parser = XmlParser() parser = XmlParser()
with open("output.xml", "r", encoding="utf-8") as infile: with open("output.xml", "r", encoding="utf-8") as infile:
xml_content = infile.read() xml_content = infile.read()
parsed_result = parser.from_string(xml_content, ab.OtaResRetrieveRs) parsed_result = parser.from_string(xml_content, ab.OtaResRetrieveRs)
print("✅ Round-trip validation successful!") print("✅ Round-trip validation successful!")
print(f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}") print(
f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}"
)
except Exception as e: except Exception as e:
print(f"❌ Validation/Serialization failed: {e}") print(f"❌ Validation/Serialization failed: {e}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -5,18 +5,23 @@ from datetime import datetime
class AlpineBitsHandshakeRequest(BaseModel): class AlpineBitsHandshakeRequest(BaseModel):
"""Model for AlpineBits handshake request data""" """Model for AlpineBits handshake request data"""
action: str = Field(..., description="Action parameter, typically 'OTA_Ping:Handshaking'")
action: str = Field(
..., description="Action parameter, typically 'OTA_Ping:Handshaking'"
)
request_xml: Optional[str] = Field(None, description="XML request document") request_xml: Optional[str] = Field(None, description="XML request document")
class ContactName(BaseModel): class ContactName(BaseModel):
"""Contact name structure""" """Contact name structure"""
first: Optional[str] = None first: Optional[str] = None
last: Optional[str] = None last: Optional[str] = None
class ContactAddress(BaseModel): class ContactAddress(BaseModel):
"""Contact address structure""" """Contact address structure"""
street: Optional[str] = None street: Optional[str] = None
city: Optional[str] = None city: Optional[str] = None
state: Optional[str] = None state: Optional[str] = None
@@ -26,6 +31,7 @@ class ContactAddress(BaseModel):
class Contact(BaseModel): class Contact(BaseModel):
"""Contact information from Wix form""" """Contact information from Wix form"""
name: Optional[ContactName] = None name: Optional[ContactName] = None
email: Optional[str] = None email: Optional[str] = None
locale: Optional[str] = None locale: Optional[str] = None
@@ -43,12 +49,14 @@ class Contact(BaseModel):
class SubmissionPdf(BaseModel): class SubmissionPdf(BaseModel):
"""PDF submission structure""" """PDF submission structure"""
url: Optional[str] = None url: Optional[str] = None
filename: Optional[str] = None filename: Optional[str] = None
class WixFormSubmission(BaseModel): class WixFormSubmission(BaseModel):
"""Model for Wix form submission data""" """Model for Wix form submission data"""
formName: str formName: str
submissions: List[Dict[str, Any]] = Field(default_factory=list) submissions: List[Dict[str, Any]] = Field(default_factory=list)
submissionTime: str submissionTime: str
@@ -59,7 +67,7 @@ class WixFormSubmission(BaseModel):
submissionPdf: Optional[SubmissionPdf] = None submissionPdf: Optional[SubmissionPdf] = None
formId: str formId: str
contact: Optional[Contact] = None contact: Optional[Contact] = None
# Dynamic form fields - these will capture all field:* entries # Dynamic form fields - these will capture all field:* entries
class Config: class Config:
extra = "allow" # Allow additional fields not defined in the model extra = "allow" # Allow additional fields not defined in the model

View File

@@ -11,11 +11,12 @@ logger = logging.getLogger(__name__)
# Rate limiting configuration # Rate limiting configuration
DEFAULT_RATE_LIMIT = "10/minute" # 10 requests per minute per IP DEFAULT_RATE_LIMIT = "10/minute" # 10 requests per minute per IP
WEBHOOK_RATE_LIMIT = "60/minute" # 60 webhook requests per minute per IP WEBHOOK_RATE_LIMIT = "60/minute" # 60 webhook requests per minute per IP
BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP
# Redis configuration for distributed rate limiting (optional) # Redis configuration for distributed rate limiting (optional)
REDIS_URL = os.getenv("REDIS_URL", None) REDIS_URL = os.getenv("REDIS_URL", None)
def get_remote_address_with_forwarded(request: Request): def get_remote_address_with_forwarded(request: Request):
""" """
Get client IP address, considering forwarded headers from proxies/load balancers Get client IP address, considering forwarded headers from proxies/load balancers
@@ -25,11 +26,11 @@ def get_remote_address_with_forwarded(request: Request):
if forwarded_for: if forwarded_for:
# Take the first IP in the chain # Take the first IP in the chain
return forwarded_for.split(",")[0].strip() return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP") real_ip = request.headers.get("X-Real-IP")
if real_ip: if real_ip:
return real_ip return real_ip
# Fallback to direct connection IP # Fallback to direct connection IP
return get_remote_address(request) return get_remote_address(request)
@@ -39,14 +40,16 @@ if REDIS_URL:
# Use Redis for distributed rate limiting (recommended for production) # Use Redis for distributed rate limiting (recommended for production)
try: try:
import redis import redis
redis_client = redis.from_url(REDIS_URL) redis_client = redis.from_url(REDIS_URL)
limiter = Limiter( limiter = Limiter(
key_func=get_remote_address_with_forwarded, key_func=get_remote_address_with_forwarded, storage_uri=REDIS_URL
storage_uri=REDIS_URL
) )
logger.info("Rate limiting initialized with Redis backend") logger.info("Rate limiting initialized with Redis backend")
except Exception as e: except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}. Using in-memory rate limiting.") logger.warning(
f"Failed to connect to Redis: {e}. Using in-memory rate limiting."
)
limiter = Limiter(key_func=get_remote_address_with_forwarded) limiter = Limiter(key_func=get_remote_address_with_forwarded)
else: else:
# Use in-memory rate limiting (fine for single instance) # Use in-memory rate limiting (fine for single instance)
@@ -65,7 +68,7 @@ def get_api_key_identifier(request: Request) -> str:
api_key = auth_header[7:] # Remove "Bearer " prefix api_key = auth_header[7:] # Remove "Bearer " prefix
# Use first 10 chars of API key as identifier (don't log full key) # Use first 10 chars of API key as identifier (don't log full key)
return f"api_key:{api_key[:10]}" return f"api_key:{api_key[:10]}"
# Fallback to IP address # Fallback to IP address
return f"ip:{get_remote_address_with_forwarded(request)}" return f"ip:{get_remote_address_with_forwarded(request)}"
@@ -77,10 +80,10 @@ def api_key_rate_limit_key(request: Request):
# Rate limiting decorators for different endpoint types # Rate limiting decorators for different endpoint types
webhook_limiter = Limiter( webhook_limiter = Limiter(
key_func=api_key_rate_limit_key, key_func=api_key_rate_limit_key, storage_uri=REDIS_URL if REDIS_URL else None
storage_uri=REDIS_URL if REDIS_URL else None
) )
# Custom rate limit exceeded handler # Custom rate limit exceeded handler
def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded): def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Custom handler for rate limit exceeded""" """Custom handler for rate limit exceeded"""
@@ -88,11 +91,11 @@ def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
f"Rate limit exceeded for {get_remote_address_with_forwarded(request)}: " f"Rate limit exceeded for {get_remote_address_with_forwarded(request)}: "
f"{exc.detail}" f"{exc.detail}"
) )
response = _rate_limit_exceeded_handler(request, exc) response = _rate_limit_exceeded_handler(request, exc)
# Add custom headers # Add custom headers
response.headers["X-RateLimit-Limit"] = str(exc.retry_after) response.headers["X-RateLimit-Limit"] = str(exc.retry_after)
response.headers["X-RateLimit-Retry-After"] = str(exc.retry_after) response.headers["X-RateLimit-Retry-After"] = str(exc.retry_after)
return response return response

View File

@@ -1,7 +1,2 @@
def parse_form(form: dict): def parse_form(form: dict):
pass pass

View File

@@ -2,14 +2,21 @@
""" """
Startup script for the Wix Form Handler API Startup script for the Wix Form Handler API
""" """
import os
import uvicorn import uvicorn
from .api import app from .api import app
if __name__ == "__main__": if __name__ == "__main__":
db_path = "alpinebits.db" # Adjust path if needed
if os.path.exists(db_path):
os.remove(db_path)
print(f"Deleted database file: {db_path}")
uvicorn.run( uvicorn.run(
"alpine_bits_python.api:app", "alpine_bits_python.api:app",
host="0.0.0.0", host="0.0.0.0",
port=8080, port=8080,
reload=True, # Enable auto-reload during development reload=True, # Enable auto-reload during development
log_level="info" log_level="info",
) )

View File

@@ -2,6 +2,7 @@
""" """
Configuration and setup script for the Wix Form Handler API Configuration and setup script for the Wix Form Handler API
""" """
import os import os
import sys import sys
import secrets import secrets
@@ -11,80 +12,83 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from alpine_bits_python.auth import generate_api_key from alpine_bits_python.auth import generate_api_key
def generate_secure_keys(): def generate_secure_keys():
"""Generate secure API keys for the application""" """Generate secure API keys for the application"""
print("🔐 Generating Secure API Keys") print("🔐 Generating Secure API Keys")
print("=" * 50) print("=" * 50)
# Generate API keys # Generate API keys
wix_api_key = generate_api_key() wix_api_key = generate_api_key()
admin_api_key = generate_api_key() admin_api_key = generate_api_key()
webhook_secret = secrets.token_urlsafe(32) webhook_secret = secrets.token_urlsafe(32)
print(f"🔑 Wix Webhook API Key: {wix_api_key}") print(f"🔑 Wix Webhook API Key: {wix_api_key}")
print(f"🔐 Admin API Key: {admin_api_key}") print(f"🔐 Admin API Key: {admin_api_key}")
print(f"🔒 Webhook Secret: {webhook_secret}") print(f"🔒 Webhook Secret: {webhook_secret}")
print("\n📋 Environment Variables") print("\n📋 Environment Variables")
print("-" * 30) print("-" * 30)
print(f"export WIX_API_KEY='{wix_api_key}'") print(f"export WIX_API_KEY='{wix_api_key}'")
print(f"export ADMIN_API_KEY='{admin_api_key}'") print(f"export ADMIN_API_KEY='{admin_api_key}'")
print(f"export WIX_WEBHOOK_SECRET='{webhook_secret}'") print(f"export WIX_WEBHOOK_SECRET='{webhook_secret}'")
print(f"export REDIS_URL='redis://localhost:6379' # Optional for production") print(f"export REDIS_URL='redis://localhost:6379' # Optional for production")
print("\n🔧 .env File Content") print("\n🔧 .env File Content")
print("-" * 20) print("-" * 20)
print(f"WIX_API_KEY={wix_api_key}") print(f"WIX_API_KEY={wix_api_key}")
print(f"ADMIN_API_KEY={admin_api_key}") print(f"ADMIN_API_KEY={admin_api_key}")
print(f"WIX_WEBHOOK_SECRET={webhook_secret}") print(f"WIX_WEBHOOK_SECRET={webhook_secret}")
print("REDIS_URL=redis://localhost:6379") print("REDIS_URL=redis://localhost:6379")
# Optionally write to .env file # Optionally write to .env file
create_env = input("\n❓ Create .env file? (y/n): ").lower().strip() create_env = input("\n❓ Create .env file? (y/n): ").lower().strip()
if create_env == 'y': if create_env == "y":
# Create .env in the project root (two levels up from scripts) # Create .env in the project root (two levels up from scripts)
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env') env_path = os.path.join(
with open(env_path, 'w') as f: os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
)
with open(env_path, "w") as f:
f.write(f"WIX_API_KEY={wix_api_key}\n") f.write(f"WIX_API_KEY={wix_api_key}\n")
f.write(f"ADMIN_API_KEY={admin_api_key}\n") f.write(f"ADMIN_API_KEY={admin_api_key}\n")
f.write(f"WIX_WEBHOOK_SECRET={webhook_secret}\n") f.write(f"WIX_WEBHOOK_SECRET={webhook_secret}\n")
f.write("REDIS_URL=redis://localhost:6379\n") f.write("REDIS_URL=redis://localhost:6379\n")
print(f"✅ .env file created at {env_path}!") print(f"✅ .env file created at {env_path}!")
print("⚠️ Add .env to your .gitignore file!") print("⚠️ Add .env to your .gitignore file!")
print("\n🌐 Wix Configuration") print("\n🌐 Wix Configuration")
print("-" * 20) print("-" * 20)
print("1. In your Wix site, go to Settings > Webhooks") print("1. In your Wix site, go to Settings > Webhooks")
print("2. Add webhook URL: https://yourdomain.com/webhook/wix-form") print("2. Add webhook URL: https://yourdomain.com/webhook/wix-form")
print("3. Add custom header: Authorization: Bearer " + wix_api_key) print("3. Add custom header: Authorization: Bearer " + wix_api_key)
print("4. Optionally configure webhook signature with the secret above") print("4. Optionally configure webhook signature with the secret above")
return { return {
'wix_api_key': wix_api_key, "wix_api_key": wix_api_key,
'admin_api_key': admin_api_key, "admin_api_key": admin_api_key,
'webhook_secret': webhook_secret "webhook_secret": webhook_secret,
} }
def check_security_setup(): def check_security_setup():
"""Check current security configuration""" """Check current security configuration"""
print("🔍 Security Configuration Check") print("🔍 Security Configuration Check")
print("=" * 40) print("=" * 40)
# Check environment variables # Check environment variables
wix_key = os.getenv('WIX_API_KEY') wix_key = os.getenv("WIX_API_KEY")
admin_key = os.getenv('ADMIN_API_KEY') admin_key = os.getenv("ADMIN_API_KEY")
webhook_secret = os.getenv('WIX_WEBHOOK_SECRET') webhook_secret = os.getenv("WIX_WEBHOOK_SECRET")
redis_url = os.getenv('REDIS_URL') redis_url = os.getenv("REDIS_URL")
print("Environment Variables:") print("Environment Variables:")
print(f" WIX_API_KEY: {'✅ Set' if wix_key else '❌ Not set'}") print(f" WIX_API_KEY: {'✅ Set' if wix_key else '❌ Not set'}")
print(f" ADMIN_API_KEY: {'✅ Set' if admin_key else '❌ Not set'}") print(f" ADMIN_API_KEY: {'✅ Set' if admin_key else '❌ Not set'}")
print(f" WIX_WEBHOOK_SECRET: {'✅ Set' if webhook_secret else '❌ Not set'}") print(f" WIX_WEBHOOK_SECRET: {'✅ Set' if webhook_secret else '❌ Not set'}")
print(f" REDIS_URL: {'✅ Set' if redis_url else '⚠️ Optional (using in-memory)'}") print(f" REDIS_URL: {'✅ Set' if redis_url else '⚠️ Optional (using in-memory)'}")
# Security recommendations # Security recommendations
print("\n🛡️ Security Recommendations:") print("\n🛡️ Security Recommendations:")
if not wix_key: if not wix_key:
@@ -94,19 +98,19 @@ def check_security_setup():
print(" ⚠️ WIX_API_KEY should be longer for better security") print(" ⚠️ WIX_API_KEY should be longer for better security")
else: else:
print(" ✅ WIX_API_KEY looks secure") print(" ✅ WIX_API_KEY looks secure")
if not admin_key: if not admin_key:
print(" ❌ Set ADMIN_API_KEY environment variable") print(" ❌ Set ADMIN_API_KEY environment variable")
elif wix_key and admin_key == wix_key: elif wix_key and admin_key == wix_key:
print(" ❌ Admin and Wix keys should be different") print(" ❌ Admin and Wix keys should be different")
else: else:
print(" ✅ ADMIN_API_KEY configured") print(" ✅ ADMIN_API_KEY configured")
if not webhook_secret: if not webhook_secret:
print(" ⚠️ Consider setting WIX_WEBHOOK_SECRET for signature validation") print(" ⚠️ Consider setting WIX_WEBHOOK_SECRET for signature validation")
else: else:
print(" ✅ Webhook signature validation enabled") print(" ✅ Webhook signature validation enabled")
print("\n🚀 Production Checklist:") print("\n🚀 Production Checklist:")
print(" - Use HTTPS in production") print(" - Use HTTPS in production")
print(" - Set up Redis for distributed rate limiting") print(" - Set up Redis for distributed rate limiting")
@@ -118,12 +122,14 @@ def check_security_setup():
if __name__ == "__main__": if __name__ == "__main__":
print("🔐 Wix Form Handler API - Security Setup") print("🔐 Wix Form Handler API - Security Setup")
print("=" * 50) print("=" * 50)
choice = input("Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): ").strip() choice = input(
"Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): "
).strip()
if choice == "1": if choice == "1":
generate_secure_keys() generate_secure_keys()
elif choice == "2": elif choice == "2":
check_security_setup() check_security_setup()
else: else:
print("Invalid choice. Please run again and choose 1 or 2.") print("Invalid choice. Please run again and choose 1 or 2.")

View File

@@ -2,6 +2,7 @@
""" """
Test script for the Secure Wix Form Handler API Test script for the Secure Wix Form Handler API
""" """
import asyncio import asyncio
import aiohttp import aiohttp
import json import json
@@ -30,7 +31,7 @@ SAMPLE_WIX_DATA = {
"submissionsLink": "https://www.wix.app/forms/test-form/submissions", "submissionsLink": "https://www.wix.app/forms/test-form/submissions",
"submissionPdf": { "submissionPdf": {
"url": "https://example.com/submission.pdf", "url": "https://example.com/submission.pdf",
"filename": "submission.pdf" "filename": "submission.pdf",
}, },
"formId": "test-form-789", "formId": "test-form-789",
"field:email_5139": "test@example.com", "field:email_5139": "test@example.com",
@@ -43,10 +44,7 @@ SAMPLE_WIX_DATA = {
"field:alter_kind_4": "12", "field:alter_kind_4": "12",
"field:long_answer_3524": "This is a long answer field with more details about the inquiry.", "field:long_answer_3524": "This is a long answer field with more details about the inquiry.",
"contact": { "contact": {
"name": { "name": {"first": "John", "last": "Doe"},
"first": "John",
"last": "Doe"
},
"email": "test@example.com", "email": "test@example.com",
"locale": "de", "locale": "de",
"company": "Test Company", "company": "Test Company",
@@ -57,29 +55,29 @@ SAMPLE_WIX_DATA = {
"street": "Test Street 123", "street": "Test Street 123",
"city": "Test City", "city": "Test City",
"country": "Germany", "country": "Germany",
"postalCode": "12345" "postalCode": "12345",
}, },
"jobTitle": "Manager", "jobTitle": "Manager",
"phone": "+1234567890", "phone": "+1234567890",
"createdDate": "2024-03-20T10:00:00.000Z", "createdDate": "2024-03-20T10:00:00.000Z",
"updatedDate": "2024-03-20T10:30:00.000Z" "updatedDate": "2024-03-20T10:30:00.000Z",
} },
} }
async def test_api(): async def test_api():
"""Test the API endpoints with authentication""" """Test the API endpoints with authentication"""
headers_with_auth = { headers_with_auth = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {TEST_API_KEY}" "Authorization": f"Bearer {TEST_API_KEY}",
} }
admin_headers = { admin_headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {ADMIN_API_KEY}" "Authorization": f"Bearer {ADMIN_API_KEY}",
} }
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Test health endpoint (no auth required) # Test health endpoint (no auth required)
print("1. Testing health endpoint (no auth)...") print("1. Testing health endpoint (no auth)...")
@@ -89,7 +87,7 @@ async def test_api():
print(f" ✅ Health check: {response.status} - {result.get('status')}") print(f" ✅ Health check: {response.status} - {result.get('status')}")
except Exception as e: except Exception as e:
print(f" ❌ Health check failed: {e}") print(f" ❌ Health check failed: {e}")
# Test root endpoint (no auth required) # Test root endpoint (no auth required)
print("\n2. Testing root endpoint (no auth)...") print("\n2. Testing root endpoint (no auth)...")
try: try:
@@ -98,87 +96,94 @@ async def test_api():
print(f" ✅ Root: {response.status} - {result.get('message')}") print(f" ✅ Root: {response.status} - {result.get('message')}")
except Exception as e: except Exception as e:
print(f" ❌ Root endpoint failed: {e}") print(f" ❌ Root endpoint failed: {e}")
# Test webhook endpoint without auth (should fail) # Test webhook endpoint without auth (should fail)
print("\n3. Testing webhook endpoint WITHOUT auth (should fail)...") print("\n3. Testing webhook endpoint WITHOUT auth (should fail)...")
try: try:
async with session.post( async with session.post(
f"{BASE_URL}/api/webhook/wix-form", f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA, json=SAMPLE_WIX_DATA,
headers={"Content-Type": "application/json"} headers={"Content-Type": "application/json"},
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 401: if response.status == 401:
print(f" ✅ Correctly rejected: {response.status} - {result.get('detail')}") print(
f" ✅ Correctly rejected: {response.status} - {result.get('detail')}"
)
else: else:
print(f" ❌ Unexpected response: {response.status} - {result}") print(f" ❌ Unexpected response: {response.status} - {result}")
except Exception as e: except Exception as e:
print(f" ❌ Test failed: {e}") print(f" ❌ Test failed: {e}")
# Test webhook endpoint with valid auth # Test webhook endpoint with valid auth
print("\n4. Testing webhook endpoint WITH valid auth...") print("\n4. Testing webhook endpoint WITH valid auth...")
try: try:
async with session.post( async with session.post(
f"{BASE_URL}/api/webhook/wix-form", f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA, json=SAMPLE_WIX_DATA,
headers=headers_with_auth headers=headers_with_auth,
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 200: if response.status == 200:
print(f" ✅ Webhook success: {response.status} - {result.get('status')}") print(
f" ✅ Webhook success: {response.status} - {result.get('status')}"
)
else: else:
print(f" ❌ Webhook failed: {response.status} - {result}") print(f" ❌ Webhook failed: {response.status} - {result}")
except Exception as e: except Exception as e:
print(f" ❌ Webhook test failed: {e}") print(f" ❌ Webhook test failed: {e}")
# Test test endpoint with auth # Test test endpoint with auth
print("\n5. Testing simple test endpoint WITH auth...") print("\n5. Testing simple test endpoint WITH auth...")
try: try:
async with session.post( async with session.post(
f"{BASE_URL}/api/webhook/wix-form/test", f"{BASE_URL}/api/webhook/wix-form/test",
json={"test": "data", "timestamp": datetime.now().isoformat()}, json={"test": "data", "timestamp": datetime.now().isoformat()},
headers=headers_with_auth headers=headers_with_auth,
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 200: if response.status == 200:
print(f" ✅ Test endpoint: {response.status} - {result.get('status')}") print(
f" ✅ Test endpoint: {response.status} - {result.get('status')}"
)
else: else:
print(f" ❌ Test endpoint failed: {response.status} - {result}") print(f" ❌ Test endpoint failed: {response.status} - {result}")
except Exception as e: except Exception as e:
print(f" ❌ Test endpoint failed: {e}") print(f" ❌ Test endpoint failed: {e}")
# Test rate limiting by making multiple rapid requests # Test rate limiting by making multiple rapid requests
print("\n6. Testing rate limiting (making 5 rapid requests)...") print("\n6. Testing rate limiting (making 5 rapid requests)...")
rate_limit_test_count = 0 rate_limit_test_count = 0
for i in range(5): for i in range(5):
try: try:
async with session.get( async with session.get(f"{BASE_URL}/api/health") as response:
f"{BASE_URL}/api/health"
) as response:
if response.status == 200: if response.status == 200:
rate_limit_test_count += 1 rate_limit_test_count += 1
elif response.status == 429: elif response.status == 429:
print(f" ✅ Rate limit triggered on request {i+1}") print(f" ✅ Rate limit triggered on request {i + 1}")
break break
except Exception as e: except Exception as e:
print(f" ❌ Rate limit test failed: {e}") print(f" ❌ Rate limit test failed: {e}")
break break
if rate_limit_test_count == 5: if rate_limit_test_count == 5:
print(" No rate limit reached (normal for low request volume)") print(" No rate limit reached (normal for low request volume)")
# Test admin endpoint (if admin key is configured) # Test admin endpoint (if admin key is configured)
print("\n7. Testing admin stats endpoint...") print("\n7. Testing admin stats endpoint...")
try: try:
async with session.get( async with session.get(
f"{BASE_URL}/api/admin/stats", f"{BASE_URL}/api/admin/stats", headers=admin_headers
headers=admin_headers
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 200: if response.status == 200:
print(f" ✅ Admin stats: {response.status} - {result.get('status')}") print(
f" ✅ Admin stats: {response.status} - {result.get('status')}"
)
elif response.status == 401: elif response.status == 401:
print(f" ⚠️ Admin access denied (API key not configured): {result.get('detail')}") print(
f" ⚠️ Admin access denied (API key not configured): {result.get('detail')}"
)
else: else:
print(f" ❌ Admin endpoint failed: {response.status} - {result}") print(f" ❌ Admin endpoint failed: {response.status} - {result}")
except Exception as e: except Exception as e:
@@ -189,12 +194,18 @@ if __name__ == "__main__":
print("🔒 Testing Secure Wix Form Handler API...") print("🔒 Testing Secure Wix Form Handler API...")
print("=" * 60) print("=" * 60)
print("📍 API URL:", BASE_URL) print("📍 API URL:", BASE_URL)
print("🔑 Using API Key:", TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY) print(
print("🔐 Using Admin Key:", ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY) "🔑 Using API Key:",
TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY,
)
print(
"🔐 Using Admin Key:",
ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY,
)
print("=" * 60) print("=" * 60)
print("Make sure the API is running with: python3 run_api.py") print("Make sure the API is running with: python3 run_api.py")
print("-" * 60) print("-" * 60)
try: try:
asyncio.run(test_api()) asyncio.run(test_api())
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -207,4 +218,4 @@ if __name__ == "__main__":
print("3. Add Authorization header: Bearer your_api_key") print("3. Add Authorization header: Bearer your_api_key")
except Exception as e: except Exception as e:
print(f"\n❌ Error testing API: {e}") print(f"\n❌ Error testing API: {e}")
print("Make sure the API server is running!") print("Make sure the API server is running!")

View File

@@ -15,15 +15,26 @@ NotifHotelReservationId = OtaHotelResNotifRq.HotelReservations.HotelReservation.
RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId
# Define type aliases for Comments types # Define type aliases for Comments types
NotifComments = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments NotifComments = (
RetrieveComments = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments
NotifComment = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment )
RetrieveComment = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment RetrieveComments = (
OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments
)
NotifComment = (
OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment
)
RetrieveComment = (
OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment
)
# type aliases for GuestCounts # type aliases for GuestCounts
NotifGuestCounts = OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts NotifGuestCounts = (
RetrieveGuestCounts = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts
)
RetrieveGuestCounts = (
OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts
)
# phonetechtype enum 1,3,5 voice, fax, mobile # phonetechtype enum 1,3,5 voice, fax, mobile
@@ -36,12 +47,13 @@ class PhoneTechType(Enum):
# Enum to specify which OTA message type to use # Enum to specify which OTA message type to use
class OtaMessageType(Enum): class OtaMessageType(Enum):
NOTIF = "notification" # For OtaHotelResNotifRq NOTIF = "notification" # For OtaHotelResNotifRq
RETRIEVE = "retrieve" # For OtaResRetrieveRs RETRIEVE = "retrieve" # For OtaResRetrieveRs
@dataclass @dataclass
class KidsAgeData: class KidsAgeData:
"""Data class to hold information about children's ages.""" """Data class to hold information about children's ages."""
ages: list[int] ages: list[int]
@@ -77,9 +89,10 @@ class CustomerData:
class GuestCountsFactory: class GuestCountsFactory:
@staticmethod @staticmethod
def create_notif_guest_counts(adults: int, kids: Optional[list[int]] = None) -> NotifGuestCounts: def create_notif_guest_counts(
adults: int, kids: Optional[list[int]] = None
) -> NotifGuestCounts:
""" """
Create a GuestCounts object for OtaHotelResNotifRq. Create a GuestCounts object for OtaHotelResNotifRq.
:param adults: Number of adults :param adults: Number of adults
@@ -89,18 +102,23 @@ class GuestCountsFactory:
return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts) return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts)
@staticmethod @staticmethod
def create_retrieve_guest_counts(adults: int, kids: Optional[list[int]] = None) -> RetrieveGuestCounts: def create_retrieve_guest_counts(
adults: int, kids: Optional[list[int]] = None
) -> RetrieveGuestCounts:
""" """
Create a GuestCounts object for OtaResRetrieveRs. Create a GuestCounts object for OtaResRetrieveRs.
:param adults: Number of adults :param adults: Number of adults
:param kids: List of ages for each kid (optional) :param kids: List of ages for each kid (optional)
:return: GuestCounts instance :return: GuestCounts instance
""" """
return GuestCountsFactory._create_guest_counts(adults, kids, RetrieveGuestCounts) return GuestCountsFactory._create_guest_counts(
adults, kids, RetrieveGuestCounts
)
@staticmethod @staticmethod
def _create_guest_counts(adults: int, kids: Optional[list[int]], guest_counts_class: type) -> Any: def _create_guest_counts(
adults: int, kids: Optional[list[int]], guest_counts_class: type
) -> Any:
""" """
Internal method to create a GuestCounts object of the specified type. Internal method to create a GuestCounts object of the specified type.
:param adults: Number of adults :param adults: Number of adults
@@ -356,9 +374,10 @@ class HotelReservationIdFactory:
) )
@dataclass @dataclass
class CommentListItemData: class CommentListItemData:
"""Simple data class to hold comment list item information.""" """Simple data class to hold comment list item information."""
value: str # The text content of the list item value: str # The text content of the list item
list_item: str # Numeric identifier (pattern: [0-9]+) list_item: str # Numeric identifier (pattern: [0-9]+)
language: str # Two-letter language code (pattern: [a-z][a-z]) language: str # Two-letter language code (pattern: [a-z][a-z])
@@ -367,6 +386,7 @@ class CommentListItemData:
@dataclass @dataclass
class CommentData: class CommentData:
"""Simple data class to hold comment information without nested type constraints.""" """Simple data class to hold comment information without nested type constraints."""
name: CommentName2 # Required: "included services", "customer comment", "additional info" name: CommentName2 # Required: "included services", "customer comment", "additional info"
text: Optional[str] = None # Optional text content text: Optional[str] = None # Optional text content
list_items: list[CommentListItemData] = None # Optional list items list_items: list[CommentListItemData] = None # Optional list items
@@ -379,6 +399,7 @@ class CommentData:
@dataclass @dataclass
class CommentsData: class CommentsData:
"""Simple data class to hold multiple comments (1-3 max).""" """Simple data class to hold multiple comments (1-3 max)."""
comments: list[CommentData] = None # 1-3 comments maximum comments: list[CommentData] = None # 1-3 comments maximum
def __post_init__(self): def __post_init__(self):
@@ -388,21 +409,23 @@ class CommentsData:
class CommentFactory: class CommentFactory:
"""Factory class to create Comment instances for both OtaHotelResNotifRq and OtaResRetrieveRs.""" """Factory class to create Comment instances for both OtaHotelResNotifRq and OtaResRetrieveRs."""
@staticmethod @staticmethod
def create_notif_comments(data: CommentsData) -> NotifComments: def create_notif_comments(data: CommentsData) -> NotifComments:
"""Create Comments for OtaHotelResNotifRq.""" """Create Comments for OtaHotelResNotifRq."""
return CommentFactory._create_comments(NotifComments, NotifComment, data) return CommentFactory._create_comments(NotifComments, NotifComment, data)
@staticmethod @staticmethod
def create_retrieve_comments(data: CommentsData) -> RetrieveComments: def create_retrieve_comments(data: CommentsData) -> RetrieveComments:
"""Create Comments for OtaResRetrieveRs.""" """Create Comments for OtaResRetrieveRs."""
return CommentFactory._create_comments(RetrieveComments, RetrieveComment, data) return CommentFactory._create_comments(RetrieveComments, RetrieveComment, data)
@staticmethod @staticmethod
def _create_comments(comments_class: type, comment_class: type, data: CommentsData) -> Any: def _create_comments(
comments_class: type, comment_class: type, data: CommentsData
) -> Any:
"""Internal method to create comments of the specified type.""" """Internal method to create comments of the specified type."""
comments_list = [] comments_list = []
for comment_data in data.comments: for comment_data in data.comments:
# Create list items # Create list items
@@ -411,55 +434,53 @@ class CommentFactory:
list_item = comment_class.ListItem( list_item = comment_class.ListItem(
value=item_data.value, value=item_data.value,
list_item=item_data.list_item, list_item=item_data.list_item,
language=item_data.language language=item_data.language,
) )
list_items.append(list_item) list_items.append(list_item)
# Create comment # Create comment
comment = comment_class( comment = comment_class(
name=comment_data.name, name=comment_data.name, text=comment_data.text, list_item=list_items
text=comment_data.text,
list_item=list_items
) )
comments_list.append(comment) comments_list.append(comment)
# Create comments container # Create comments container
return comments_class(comment=comments_list) return comments_class(comment=comments_list)
@staticmethod @staticmethod
def from_notif_comments(comments: NotifComments) -> CommentsData: def from_notif_comments(comments: NotifComments) -> CommentsData:
"""Convert NotifComments back to CommentsData.""" """Convert NotifComments back to CommentsData."""
return CommentFactory._comments_to_data(comments) return CommentFactory._comments_to_data(comments)
@staticmethod @staticmethod
def from_retrieve_comments(comments: RetrieveComments) -> CommentsData: def from_retrieve_comments(comments: RetrieveComments) -> CommentsData:
"""Convert RetrieveComments back to CommentsData.""" """Convert RetrieveComments back to CommentsData."""
return CommentFactory._comments_to_data(comments) return CommentFactory._comments_to_data(comments)
@staticmethod @staticmethod
def _comments_to_data(comments: Any) -> CommentsData: def _comments_to_data(comments: Any) -> CommentsData:
"""Internal method to convert any comments type to CommentsData.""" """Internal method to convert any comments type to CommentsData."""
comments_data_list = [] comments_data_list = []
for comment in comments.comment: for comment in comments.comment:
# Extract list items # Extract list items
list_items_data = [] list_items_data = []
if comment.list_item: if comment.list_item:
for list_item in comment.list_item: for list_item in comment.list_item:
list_items_data.append(CommentListItemData( list_items_data.append(
value=list_item.value, CommentListItemData(
list_item=list_item.list_item, value=list_item.value,
language=list_item.language list_item=list_item.list_item,
)) language=list_item.language,
)
)
# Extract comment data # Extract comment data
comment_data = CommentData( comment_data = CommentData(
name=comment.name, name=comment.name, text=comment.text, list_items=list_items_data
text=comment.text,
list_items=list_items_data
) )
comments_data_list.append(comment_data) comments_data_list.append(comment_data)
return CommentsData(comments=comments_data_list) return CommentsData(comments=comments_data_list)
@@ -529,16 +550,19 @@ class ResGuestFactory:
class AlpineBitsFactory: class AlpineBitsFactory:
"""Unified factory class for creating AlpineBits objects with a simple interface.""" """Unified factory class for creating AlpineBits objects with a simple interface."""
@staticmethod @staticmethod
def create(data: Union[CustomerData, HotelReservationIdData, CommentsData], message_type: OtaMessageType) -> Any: def create(
data: Union[CustomerData, HotelReservationIdData, CommentsData],
message_type: OtaMessageType,
) -> Any:
""" """
Create an AlpineBits object based on the data type and message type. Create an AlpineBits object based on the data type and message type.
Args: Args:
data: The data object (CustomerData, HotelReservationIdData, CommentsData, etc.) data: The data object (CustomerData, HotelReservationIdData, CommentsData, etc.)
message_type: Whether to create for NOTIF or RETRIEVE message types message_type: Whether to create for NOTIF or RETRIEVE message types
Returns: Returns:
The appropriate AlpineBits object based on the data type and message type The appropriate AlpineBits object based on the data type and message type
""" """
@@ -547,31 +571,35 @@ class AlpineBitsFactory:
return CustomerFactory.create_notif_customer(data) return CustomerFactory.create_notif_customer(data)
else: else:
return CustomerFactory.create_retrieve_customer(data) return CustomerFactory.create_retrieve_customer(data)
elif isinstance(data, HotelReservationIdData): elif isinstance(data, HotelReservationIdData):
if message_type == OtaMessageType.NOTIF: if message_type == OtaMessageType.NOTIF:
return HotelReservationIdFactory.create_notif_hotel_reservation_id(data) return HotelReservationIdFactory.create_notif_hotel_reservation_id(data)
else: else:
return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(data) return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
data
)
elif isinstance(data, CommentsData): elif isinstance(data, CommentsData):
if message_type == OtaMessageType.NOTIF: if message_type == OtaMessageType.NOTIF:
return CommentFactory.create_notif_comments(data) return CommentFactory.create_notif_comments(data)
else: else:
return CommentFactory.create_retrieve_comments(data) return CommentFactory.create_retrieve_comments(data)
else: else:
raise ValueError(f"Unsupported data type: {type(data)}") raise ValueError(f"Unsupported data type: {type(data)}")
@staticmethod @staticmethod
def create_res_guests(customer_data: CustomerData, message_type: OtaMessageType) -> Union[NotifResGuests, RetrieveResGuests]: def create_res_guests(
customer_data: CustomerData, message_type: OtaMessageType
) -> Union[NotifResGuests, RetrieveResGuests]:
""" """
Create a complete ResGuests structure with a primary customer. Create a complete ResGuests structure with a primary customer.
Args: Args:
customer_data: The customer data customer_data: The customer data
message_type: Whether to create for NOTIF or RETRIEVE message types message_type: Whether to create for NOTIF or RETRIEVE message types
Returns: Returns:
The appropriate ResGuests object The appropriate ResGuests object
""" """
@@ -579,43 +607,45 @@ class AlpineBitsFactory:
return ResGuestFactory.create_notif_res_guests(customer_data) return ResGuestFactory.create_notif_res_guests(customer_data)
else: else:
return ResGuestFactory.create_retrieve_res_guests(customer_data) return ResGuestFactory.create_retrieve_res_guests(customer_data)
@staticmethod @staticmethod
def extract_data(obj: Any) -> Union[CustomerData, HotelReservationIdData, CommentsData]: def extract_data(
obj: Any,
) -> Union[CustomerData, HotelReservationIdData, CommentsData]:
""" """
Extract data from an AlpineBits object back to a simple data class. Extract data from an AlpineBits object back to a simple data class.
Args: Args:
obj: The AlpineBits object to extract data from obj: The AlpineBits object to extract data from
Returns: Returns:
The appropriate data object The appropriate data object
""" """
# Check if it's a Customer object # Check if it's a Customer object
if hasattr(obj, 'person_name') and hasattr(obj.person_name, 'given_name'): if hasattr(obj, "person_name") and hasattr(obj.person_name, "given_name"):
if isinstance(obj, NotifCustomer): if isinstance(obj, NotifCustomer):
return CustomerFactory.from_notif_customer(obj) return CustomerFactory.from_notif_customer(obj)
elif isinstance(obj, RetrieveCustomer): elif isinstance(obj, RetrieveCustomer):
return CustomerFactory.from_retrieve_customer(obj) return CustomerFactory.from_retrieve_customer(obj)
# Check if it's a HotelReservationId object # Check if it's a HotelReservationId object
elif hasattr(obj, 'res_id_type'): elif hasattr(obj, "res_id_type"):
if isinstance(obj, NotifHotelReservationId): if isinstance(obj, NotifHotelReservationId):
return HotelReservationIdFactory.from_notif_hotel_reservation_id(obj) return HotelReservationIdFactory.from_notif_hotel_reservation_id(obj)
elif isinstance(obj, RetrieveHotelReservationId): elif isinstance(obj, RetrieveHotelReservationId):
return HotelReservationIdFactory.from_retrieve_hotel_reservation_id(obj) return HotelReservationIdFactory.from_retrieve_hotel_reservation_id(obj)
# Check if it's a Comments object # Check if it's a Comments object
elif hasattr(obj, 'comment'): elif hasattr(obj, "comment"):
if isinstance(obj, NotifComments): if isinstance(obj, NotifComments):
return CommentFactory.from_notif_comments(obj) return CommentFactory.from_notif_comments(obj)
elif isinstance(obj, RetrieveComments): elif isinstance(obj, RetrieveComments):
return CommentFactory.from_retrieve_comments(obj) return CommentFactory.from_retrieve_comments(obj)
# Check if it's a ResGuests object # Check if it's a ResGuests object
elif hasattr(obj, 'res_guest'): elif hasattr(obj, "res_guest"):
return ResGuestFactory.extract_primary_customer(obj) return ResGuestFactory.extract_primary_customer(obj)
else: else:
raise ValueError(f"Unsupported object type: {type(obj)}") raise ValueError(f"Unsupported object type: {type(obj)}")
@@ -733,70 +763,74 @@ if __name__ == "__main__":
# Verify roundtrip conversion # Verify roundtrip conversion
print("Roundtrip conversion successful:", customer_data == extracted_data) print("Roundtrip conversion successful:", customer_data == extracted_data)
print("\n--- Unified AlpineBitsFactory Examples ---") print("\n--- Unified AlpineBitsFactory Examples ---")
# Much simpler approach - single factory with enum parameter! # Much simpler approach - single factory with enum parameter!
print("=== Customer Creation ===") print("=== Customer Creation ===")
notif_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF) notif_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF)
retrieve_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE) retrieve_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE)
print("Created customers using unified factory") print("Created customers using unified factory")
print("=== HotelReservationId Creation ===") print("=== HotelReservationId Creation ===")
reservation_id_data = HotelReservationIdData( reservation_id_data = HotelReservationIdData(
res_id_type="123", res_id_type="123", res_id_value="RESERVATION-456", res_id_source="HOTEL_SYSTEM"
res_id_value="RESERVATION-456",
res_id_source="HOTEL_SYSTEM"
) )
notif_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.NOTIF) notif_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.NOTIF)
retrieve_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.RETRIEVE) retrieve_res_id = AlpineBitsFactory.create(
reservation_id_data, OtaMessageType.RETRIEVE
)
print("Created reservation IDs using unified factory") print("Created reservation IDs using unified factory")
print("=== Comments Creation ===") print("=== Comments Creation ===")
comments_data = CommentsData(comments=[ comments_data = CommentsData(
CommentData( comments=[
name=CommentName2.CUSTOMER_COMMENT, CommentData(
text="This is a customer comment about the reservation", name=CommentName2.CUSTOMER_COMMENT,
list_items=[ text="This is a customer comment about the reservation",
CommentListItemData( list_items=[
value="Special dietary requirements: vegetarian", CommentListItemData(
list_item="1", value="Special dietary requirements: vegetarian",
language="en" list_item="1",
), language="en",
CommentListItemData( ),
value="Late arrival expected", CommentListItemData(
list_item="2", value="Late arrival expected", list_item="2", language="en"
language="en" ),
) ],
] ),
), CommentData(
CommentData( name=CommentName2.ADDITIONAL_INFO,
name=CommentName2.ADDITIONAL_INFO, text="Additional information about the stay",
text="Additional information about the stay" ),
) ]
]) )
notif_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.NOTIF) notif_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.NOTIF)
retrieve_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.RETRIEVE) retrieve_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.RETRIEVE)
print("Created comments using unified factory") print("Created comments using unified factory")
print("=== ResGuests Creation ===") print("=== ResGuests Creation ===")
notif_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF) notif_res_guests = AlpineBitsFactory.create_res_guests(
retrieve_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) customer_data, OtaMessageType.NOTIF
)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
print("Created ResGuests using unified factory") print("Created ResGuests using unified factory")
print("=== Data Extraction ===") print("=== Data Extraction ===")
# Extract data back using unified interface # Extract data back using unified interface
extracted_customer_data = AlpineBitsFactory.extract_data(notif_customer) extracted_customer_data = AlpineBitsFactory.extract_data(notif_customer)
extracted_res_id_data = AlpineBitsFactory.extract_data(notif_res_id) extracted_res_id_data = AlpineBitsFactory.extract_data(notif_res_id)
extracted_comments_data = AlpineBitsFactory.extract_data(retrieve_comments) extracted_comments_data = AlpineBitsFactory.extract_data(retrieve_comments)
extracted_from_res_guests = AlpineBitsFactory.extract_data(retrieve_res_guests) extracted_from_res_guests = AlpineBitsFactory.extract_data(retrieve_res_guests)
print("Data extraction successful:") print("Data extraction successful:")
print("- Customer roundtrip:", customer_data == extracted_customer_data) print("- Customer roundtrip:", customer_data == extracted_customer_data)
print("- ReservationId roundtrip:", reservation_id_data == extracted_res_id_data) print("- ReservationId roundtrip:", reservation_id_data == extracted_res_id_data)
print("- Comments roundtrip:", comments_data == extracted_comments_data) print("- Comments roundtrip:", comments_data == extracted_comments_data)
print("- ResGuests roundtrip:", customer_data == extracted_from_res_guests) print("- ResGuests roundtrip:", customer_data == extracted_from_res_guests)
print("\n--- Comparison with old approach ---") print("\n--- Comparison with old approach ---")
print("Old way required multiple imports and knowing specific factory methods") print("Old way required multiple imports and knowing specific factory methods")
print("New way: single import, single factory, enum parameter to specify type!") print("New way: single import, single factory, enum parameter to specify type!")

View File

@@ -1 +1 @@
"""Utility functions for alpine_bits_python.""" """Utility functions for alpine_bits_python."""

View File

@@ -1,5 +1,6 @@
"""Entry point for util package.""" """Entry point for util package."""
from .handshake_util import main from .handshake_util import main
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -2,26 +2,22 @@ from ..generated.alpinebits import OtaPingRq, OtaPingRs
from xsdata_pydantic.bindings import XmlParser from xsdata_pydantic.bindings import XmlParser
def main(): def main():
# test parsing a ping request sample # test parsing a ping request sample
path = "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml" path = (
"AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml"
)
with open( with open(path, "r", encoding="utf-8") as f:
path, "r", encoding="utf-8") as f:
xml = f.read() xml = f.read()
# Parse the XML into the request object # Parse the XML into the request object
# Test parsing back # Test parsing back
parser = XmlParser() parser = XmlParser()
parsed_result = parser.from_string(xml, OtaPingRs) parsed_result = parser.from_string(xml, OtaPingRs)
print(parsed_result.echo_data) print(parsed_result.echo_data)
@@ -34,19 +30,14 @@ def main():
print(warning.content[0]) print(warning.content[0])
# save json in echo_data to file with indents # save json in echo_data to file with indents
output_path = "echo_data_response.json" output_path = "echo_data_response.json"
with open(output_path, "w", encoding="utf-8") as out_f: with open(output_path, "w", encoding="utf-8") as out_f:
import json import json
json.dump(json.loads(parsed_result.echo_data), out_f, indent=4) json.dump(json.loads(parsed_result.echo_data), out_f, indent=4)
print(f"Saved echo_data json to {output_path}") print(f"Saved echo_data json to {output_path}")
if __name__ == "__main__": if __name__ == "__main__":
main()
main()

View File

@@ -2,12 +2,13 @@
""" """
Convenience launcher for the Wix Form Handler API Convenience launcher for the Wix Form Handler API
""" """
import os import os
import subprocess import subprocess
# Change to src directory # Change to src directory
src_dir = os.path.join(os.path.dirname(__file__), 'src/alpine_bits_python') src_dir = os.path.join(os.path.dirname(__file__), "src/alpine_bits_python")
# Run the API using uv # Run the API using uv
if __name__ == "__main__": if __name__ == "__main__":
subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")]) subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")])

View File

@@ -5,57 +5,63 @@ discovers implemented vs unimplemented actions.
""" """
from alpine_bits_python.alpinebits_server import ( from alpine_bits_python.alpinebits_server import (
ServerCapabilities, ServerCapabilities,
AlpineBitsAction, AlpineBitsAction,
AlpineBitsActionName, AlpineBitsActionName,
Version, Version,
AlpineBitsResponse, AlpineBitsResponse,
HttpStatusCode HttpStatusCode,
) )
import asyncio import asyncio
class NewImplementedAction(AlpineBitsAction): class NewImplementedAction(AlpineBitsAction):
"""A new action that IS implemented.""" """A new action that IS implemented."""
def __init__(self): def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_INFO_INFO self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_INFO_INFO
self.version = Version.V2024_10 self.version = Version.V2024_10
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""This action is implemented.""" """This action is implemented."""
return AlpineBitsResponse("Implemented!", HttpStatusCode.OK) return AlpineBitsResponse("Implemented!", HttpStatusCode.OK)
class NewUnimplementedAction(AlpineBitsAction): class NewUnimplementedAction(AlpineBitsAction):
"""A new action that is NOT implemented (no handle override).""" """A new action that is NOT implemented (no handle override)."""
def __init__(self): def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO
self.version = Version.V2024_10 self.version = Version.V2024_10
# Notice: No handle method override - will use default "not implemented" # Notice: No handle method override - will use default "not implemented"
async def main(): async def main():
print("🔍 Testing Action Discovery Logic") print("🔍 Testing Action Discovery Logic")
print("=" * 50) print("=" * 50)
# Create capabilities and see what gets discovered # Create capabilities and see what gets discovered
capabilities = ServerCapabilities() capabilities = ServerCapabilities()
print("📋 Actions found by discovery:") print("📋 Actions found by discovery:")
for action_name in capabilities.get_supported_actions(): for action_name in capabilities.get_supported_actions():
print(f"{action_name}") print(f"{action_name}")
print(f"\n📊 Total discovered: {len(capabilities.get_supported_actions())}") print(f"\n📊 Total discovered: {len(capabilities.get_supported_actions())}")
# Test the new implemented action # Test the new implemented action
implemented_action = NewImplementedAction() implemented_action = NewImplementedAction()
result = await implemented_action.handle("test", "<xml/>", Version.V2024_10) result = await implemented_action.handle("test", "<xml/>", Version.V2024_10)
print(f"\n🟢 NewImplementedAction result: {result.xml_content}") print(f"\n🟢 NewImplementedAction result: {result.xml_content}")
# Test the unimplemented action (should use default behavior) # Test the unimplemented action (should use default behavior)
unimplemented_action = NewUnimplementedAction() unimplemented_action = NewUnimplementedAction()
result = await unimplemented_action.handle("test", "<xml/>", Version.V2024_10) result = await unimplemented_action.handle("test", "<xml/>", Version.V2024_10)
print(f"🔴 NewUnimplementedAction result: {result.xml_content}") print(f"🔴 NewUnimplementedAction result: {result.xml_content}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -4,11 +4,11 @@ import sys
import os import os
# Add the src directory to the path so we can import our modules # Add the src directory to the path so we can import our modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from simplified_access import ( from simplified_access import (
CustomerData, CustomerData,
CustomerFactory, CustomerFactory,
ResGuestFactory, ResGuestFactory,
HotelReservationIdData, HotelReservationIdData,
HotelReservationIdFactory, HotelReservationIdFactory,
@@ -20,7 +20,7 @@ from simplified_access import (
NotifResGuests, NotifResGuests,
RetrieveResGuests, RetrieveResGuests,
NotifHotelReservationId, NotifHotelReservationId,
RetrieveHotelReservationId RetrieveHotelReservationId,
) )
@@ -35,7 +35,7 @@ def sample_customer_data():
phone_numbers=[ phone_numbers=[
("+1234567890", PhoneTechType.MOBILE), ("+1234567890", PhoneTechType.MOBILE),
("+0987654321", PhoneTechType.VOICE), ("+0987654321", PhoneTechType.VOICE),
("+1111111111", None) ("+1111111111", None),
], ],
email_address="john.doe@example.com", email_address="john.doe@example.com",
email_newsletter=True, email_newsletter=True,
@@ -46,17 +46,14 @@ def sample_customer_data():
address_catalog=False, address_catalog=False,
gender="Male", gender="Male",
birth_date="1980-01-01", birth_date="1980-01-01",
language="en" language="en",
) )
@pytest.fixture @pytest.fixture
def minimal_customer_data(): def minimal_customer_data():
"""Fixture providing minimal customer data (only required fields).""" """Fixture providing minimal customer data (only required fields)."""
return CustomerData( return CustomerData(given_name="Jane", surname="Smith")
given_name="Jane",
surname="Smith"
)
@pytest.fixture @pytest.fixture
@@ -66,21 +63,19 @@ def sample_hotel_reservation_id_data():
res_id_type="123", res_id_type="123",
res_id_value="RESERVATION-456", res_id_value="RESERVATION-456",
res_id_source="HOTEL_SYSTEM", res_id_source="HOTEL_SYSTEM",
res_id_source_context="BOOKING_ENGINE" res_id_source_context="BOOKING_ENGINE",
) )
@pytest.fixture @pytest.fixture
def minimal_hotel_reservation_id_data(): def minimal_hotel_reservation_id_data():
"""Fixture providing minimal hotel reservation ID data (only required fields).""" """Fixture providing minimal hotel reservation ID data (only required fields)."""
return HotelReservationIdData( return HotelReservationIdData(res_id_type="999")
res_id_type="999"
)
class TestCustomerData: class TestCustomerData:
"""Test the CustomerData dataclass.""" """Test the CustomerData dataclass."""
def test_customer_data_creation_full(self, sample_customer_data): def test_customer_data_creation_full(self, sample_customer_data):
"""Test creating CustomerData with all fields.""" """Test creating CustomerData with all fields."""
assert sample_customer_data.given_name == "John" assert sample_customer_data.given_name == "John"
@@ -89,7 +84,7 @@ class TestCustomerData:
assert sample_customer_data.email_address == "john.doe@example.com" assert sample_customer_data.email_address == "john.doe@example.com"
assert sample_customer_data.email_newsletter is True assert sample_customer_data.email_newsletter is True
assert len(sample_customer_data.phone_numbers) == 3 assert len(sample_customer_data.phone_numbers) == 3
def test_customer_data_creation_minimal(self, minimal_customer_data): def test_customer_data_creation_minimal(self, minimal_customer_data):
"""Test creating CustomerData with only required fields.""" """Test creating CustomerData with only required fields."""
assert minimal_customer_data.given_name == "Jane" assert minimal_customer_data.given_name == "Jane"
@@ -97,7 +92,7 @@ class TestCustomerData:
assert minimal_customer_data.phone_numbers == [] assert minimal_customer_data.phone_numbers == []
assert minimal_customer_data.email_address is None assert minimal_customer_data.email_address is None
assert minimal_customer_data.address_line is None assert minimal_customer_data.address_line is None
def test_phone_numbers_default_initialization(self): def test_phone_numbers_default_initialization(self):
"""Test that phone_numbers gets initialized to empty list.""" """Test that phone_numbers gets initialized to empty list."""
customer_data = CustomerData(given_name="Test", surname="User") customer_data = CustomerData(given_name="Test", surname="User")
@@ -106,54 +101,56 @@ class TestCustomerData:
class TestCustomerFactory: class TestCustomerFactory:
"""Test the CustomerFactory class.""" """Test the CustomerFactory class."""
def test_create_notif_customer_full(self, sample_customer_data): def test_create_notif_customer_full(self, sample_customer_data):
"""Test creating a NotifCustomer with full data.""" """Test creating a NotifCustomer with full data."""
customer = CustomerFactory.create_notif_customer(sample_customer_data) customer = CustomerFactory.create_notif_customer(sample_customer_data)
assert isinstance(customer, NotifCustomer) assert isinstance(customer, NotifCustomer)
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe" assert customer.person_name.surname == "Doe"
assert customer.person_name.name_prefix == "Mr." assert customer.person_name.name_prefix == "Mr."
assert customer.person_name.name_title == "Jr." assert customer.person_name.name_title == "Jr."
# Check telephone # Check telephone
assert len(customer.telephone) == 3 assert len(customer.telephone) == 3
assert customer.telephone[0].phone_number == "+1234567890" assert customer.telephone[0].phone_number == "+1234567890"
assert customer.telephone[0].phone_tech_type == "5" # MOBILE assert customer.telephone[0].phone_tech_type == "5" # MOBILE
assert customer.telephone[1].phone_tech_type == "1" # VOICE assert customer.telephone[1].phone_tech_type == "1" # VOICE
assert customer.telephone[2].phone_tech_type is None assert customer.telephone[2].phone_tech_type is None
# Check email # Check email
assert customer.email.value == "john.doe@example.com" assert customer.email.value == "john.doe@example.com"
assert customer.email.remark == "newsletter:yes" assert customer.email.remark == "newsletter:yes"
# Check address # Check address
assert customer.address.address_line == "123 Main Street" assert customer.address.address_line == "123 Main Street"
assert customer.address.city_name == "Anytown" assert customer.address.city_name == "Anytown"
assert customer.address.postal_code == "12345" assert customer.address.postal_code == "12345"
assert customer.address.country_name.code == "US" assert customer.address.country_name.code == "US"
assert customer.address.remark == "catalog:no" assert customer.address.remark == "catalog:no"
# Check other attributes # Check other attributes
assert customer.gender == "Male" assert customer.gender == "Male"
assert customer.birth_date == "1980-01-01" assert customer.birth_date == "1980-01-01"
assert customer.language == "en" assert customer.language == "en"
def test_create_retrieve_customer_full(self, sample_customer_data): def test_create_retrieve_customer_full(self, sample_customer_data):
"""Test creating a RetrieveCustomer with full data.""" """Test creating a RetrieveCustomer with full data."""
customer = CustomerFactory.create_retrieve_customer(sample_customer_data) customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
assert isinstance(customer, RetrieveCustomer) assert isinstance(customer, RetrieveCustomer)
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe" assert customer.person_name.surname == "Doe"
# Same structure as NotifCustomer, so we don't need to test all fields again # Same structure as NotifCustomer, so we don't need to test all fields again
def test_create_customer_minimal(self, minimal_customer_data): def test_create_customer_minimal(self, minimal_customer_data):
"""Test creating customers with minimal data.""" """Test creating customers with minimal data."""
notif_customer = CustomerFactory.create_notif_customer(minimal_customer_data) notif_customer = CustomerFactory.create_notif_customer(minimal_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(minimal_customer_data) retrieve_customer = CustomerFactory.create_retrieve_customer(
minimal_customer_data
)
for customer in [notif_customer, retrieve_customer]: for customer in [notif_customer, retrieve_customer]:
assert customer.person_name.given_name == "Jane" assert customer.person_name.given_name == "Jane"
assert customer.person_name.surname == "Smith" assert customer.person_name.surname == "Smith"
@@ -165,73 +162,97 @@ class TestCustomerFactory:
assert customer.gender is None assert customer.gender is None
assert customer.birth_date is None assert customer.birth_date is None
assert customer.language is None assert customer.language is None
def test_email_newsletter_options(self): def test_email_newsletter_options(self):
"""Test different email newsletter options.""" """Test different email newsletter options."""
# Newsletter yes # Newsletter yes
data_yes = CustomerData(given_name="Test", surname="User", data_yes = CustomerData(
email_address="test@example.com", email_newsletter=True) given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=True,
)
customer = CustomerFactory.create_notif_customer(data_yes) customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.email.remark == "newsletter:yes" assert customer.email.remark == "newsletter:yes"
# Newsletter no # Newsletter no
data_no = CustomerData(given_name="Test", surname="User", data_no = CustomerData(
email_address="test@example.com", email_newsletter=False) given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=False,
)
customer = CustomerFactory.create_notif_customer(data_no) customer = CustomerFactory.create_notif_customer(data_no)
assert customer.email.remark == "newsletter:no" assert customer.email.remark == "newsletter:no"
# Newsletter not specified # Newsletter not specified
data_none = CustomerData(given_name="Test", surname="User", data_none = CustomerData(
email_address="test@example.com", email_newsletter=None) given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=None,
)
customer = CustomerFactory.create_notif_customer(data_none) customer = CustomerFactory.create_notif_customer(data_none)
assert customer.email.remark is None assert customer.email.remark is None
def test_address_catalog_options(self): def test_address_catalog_options(self):
"""Test different address catalog options.""" """Test different address catalog options."""
# Catalog no # Catalog no
data_no = CustomerData(given_name="Test", surname="User", data_no = CustomerData(
address_line="123 Street", address_catalog=False) given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=False,
)
customer = CustomerFactory.create_notif_customer(data_no) customer = CustomerFactory.create_notif_customer(data_no)
assert customer.address.remark == "catalog:no" assert customer.address.remark == "catalog:no"
# Catalog yes # Catalog yes
data_yes = CustomerData(given_name="Test", surname="User", data_yes = CustomerData(
address_line="123 Street", address_catalog=True) given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=True,
)
customer = CustomerFactory.create_notif_customer(data_yes) customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.address.remark == "catalog:yes" assert customer.address.remark == "catalog:yes"
# Catalog not specified # Catalog not specified
data_none = CustomerData(given_name="Test", surname="User", data_none = CustomerData(
address_line="123 Street", address_catalog=None) given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=None,
)
customer = CustomerFactory.create_notif_customer(data_none) customer = CustomerFactory.create_notif_customer(data_none)
assert customer.address.remark is None assert customer.address.remark is None
def test_from_notif_customer_roundtrip(self, sample_customer_data): def test_from_notif_customer_roundtrip(self, sample_customer_data):
"""Test converting NotifCustomer back to CustomerData.""" """Test converting NotifCustomer back to CustomerData."""
customer = CustomerFactory.create_notif_customer(sample_customer_data) customer = CustomerFactory.create_notif_customer(sample_customer_data)
converted_data = CustomerFactory.from_notif_customer(customer) converted_data = CustomerFactory.from_notif_customer(customer)
assert converted_data == sample_customer_data assert converted_data == sample_customer_data
def test_from_retrieve_customer_roundtrip(self, sample_customer_data): def test_from_retrieve_customer_roundtrip(self, sample_customer_data):
"""Test converting RetrieveCustomer back to CustomerData.""" """Test converting RetrieveCustomer back to CustomerData."""
customer = CustomerFactory.create_retrieve_customer(sample_customer_data) customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
converted_data = CustomerFactory.from_retrieve_customer(customer) converted_data = CustomerFactory.from_retrieve_customer(customer)
assert converted_data == sample_customer_data assert converted_data == sample_customer_data
def test_phone_tech_type_conversion(self): def test_phone_tech_type_conversion(self):
"""Test that PhoneTechType enum values are properly converted.""" """Test that PhoneTechType enum values are properly converted."""
data = CustomerData( data = CustomerData(
given_name="Test", given_name="Test",
surname="User", surname="User",
phone_numbers=[ phone_numbers=[
("+1111111111", PhoneTechType.VOICE), ("+1111111111", PhoneTechType.VOICE),
("+2222222222", PhoneTechType.FAX), ("+2222222222", PhoneTechType.FAX),
("+3333333333", PhoneTechType.MOBILE) ("+3333333333", PhoneTechType.MOBILE),
] ],
) )
customer = CustomerFactory.create_notif_customer(data) customer = CustomerFactory.create_notif_customer(data)
assert customer.telephone[0].phone_tech_type == "1" # VOICE assert customer.telephone[0].phone_tech_type == "1" # VOICE
assert customer.telephone[1].phone_tech_type == "3" # FAX assert customer.telephone[1].phone_tech_type == "3" # FAX
@@ -240,15 +261,21 @@ class TestCustomerFactory:
class TestHotelReservationIdData: class TestHotelReservationIdData:
"""Test the HotelReservationIdData dataclass.""" """Test the HotelReservationIdData dataclass."""
def test_hotel_reservation_id_data_creation_full(self, sample_hotel_reservation_id_data): def test_hotel_reservation_id_data_creation_full(
self, sample_hotel_reservation_id_data
):
"""Test creating HotelReservationIdData with all fields.""" """Test creating HotelReservationIdData with all fields."""
assert sample_hotel_reservation_id_data.res_id_type == "123" assert sample_hotel_reservation_id_data.res_id_type == "123"
assert sample_hotel_reservation_id_data.res_id_value == "RESERVATION-456" assert sample_hotel_reservation_id_data.res_id_value == "RESERVATION-456"
assert sample_hotel_reservation_id_data.res_id_source == "HOTEL_SYSTEM" assert sample_hotel_reservation_id_data.res_id_source == "HOTEL_SYSTEM"
assert sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE" assert (
sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE"
def test_hotel_reservation_id_data_creation_minimal(self, minimal_hotel_reservation_id_data): )
def test_hotel_reservation_id_data_creation_minimal(
self, minimal_hotel_reservation_id_data
):
"""Test creating HotelReservationIdData with only required fields.""" """Test creating HotelReservationIdData with only required fields."""
assert minimal_hotel_reservation_id_data.res_id_type == "999" assert minimal_hotel_reservation_id_data.res_id_type == "999"
assert minimal_hotel_reservation_id_data.res_id_value is None assert minimal_hotel_reservation_id_data.res_id_value is None
@@ -258,124 +285,158 @@ class TestHotelReservationIdData:
class TestHotelReservationIdFactory: class TestHotelReservationIdFactory:
"""Test the HotelReservationIdFactory class.""" """Test the HotelReservationIdFactory class."""
def test_create_notif_hotel_reservation_id_full(self, sample_hotel_reservation_id_data): def test_create_notif_hotel_reservation_id_full(
self, sample_hotel_reservation_id_data
):
"""Test creating a NotifHotelReservationId with full data.""" """Test creating a NotifHotelReservationId with full data."""
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
sample_hotel_reservation_id_data
)
assert isinstance(reservation_id, NotifHotelReservationId) assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_value == "RESERVATION-456"
assert reservation_id.res_id_source == "HOTEL_SYSTEM" assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE" assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
def test_create_retrieve_hotel_reservation_id_full(self, sample_hotel_reservation_id_data): def test_create_retrieve_hotel_reservation_id_full(
self, sample_hotel_reservation_id_data
):
"""Test creating a RetrieveHotelReservationId with full data.""" """Test creating a RetrieveHotelReservationId with full data."""
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
sample_hotel_reservation_id_data
)
assert isinstance(reservation_id, RetrieveHotelReservationId) assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_value == "RESERVATION-456"
assert reservation_id.res_id_source == "HOTEL_SYSTEM" assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE" assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
def test_create_hotel_reservation_id_minimal(self, minimal_hotel_reservation_id_data): def test_create_hotel_reservation_id_minimal(
self, minimal_hotel_reservation_id_data
):
"""Test creating hotel reservation IDs with minimal data.""" """Test creating hotel reservation IDs with minimal data."""
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(minimal_hotel_reservation_id_data) notif_reservation_id = (
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(minimal_hotel_reservation_id_data) HotelReservationIdFactory.create_notif_hotel_reservation_id(
minimal_hotel_reservation_id_data
)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
minimal_hotel_reservation_id_data
)
)
for reservation_id in [notif_reservation_id, retrieve_reservation_id]: for reservation_id in [notif_reservation_id, retrieve_reservation_id]:
assert reservation_id.res_id_type == "999" assert reservation_id.res_id_type == "999"
assert reservation_id.res_id_value is None assert reservation_id.res_id_value is None
assert reservation_id.res_id_source is None assert reservation_id.res_id_source is None
assert reservation_id.res_id_source_context is None assert reservation_id.res_id_source_context is None
def test_from_notif_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data): def test_from_notif_hotel_reservation_id_roundtrip(
self, sample_hotel_reservation_id_data
):
"""Test converting NotifHotelReservationId back to HotelReservationIdData.""" """Test converting NotifHotelReservationId back to HotelReservationIdData."""
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(reservation_id) sample_hotel_reservation_id_data
)
converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(
reservation_id
)
assert converted_data == sample_hotel_reservation_id_data assert converted_data == sample_hotel_reservation_id_data
def test_from_retrieve_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data): def test_from_retrieve_hotel_reservation_id_roundtrip(
self, sample_hotel_reservation_id_data
):
"""Test converting RetrieveHotelReservationId back to HotelReservationIdData.""" """Test converting RetrieveHotelReservationId back to HotelReservationIdData."""
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(reservation_id) sample_hotel_reservation_id_data
)
converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
reservation_id
)
assert converted_data == sample_hotel_reservation_id_data assert converted_data == sample_hotel_reservation_id_data
class TestResGuestFactory: class TestResGuestFactory:
"""Test the ResGuestFactory class.""" """Test the ResGuestFactory class."""
def test_create_notif_res_guests(self, sample_customer_data): def test_create_notif_res_guests(self, sample_customer_data):
"""Test creating NotifResGuests structure.""" """Test creating NotifResGuests structure."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
assert isinstance(res_guests, NotifResGuests) assert isinstance(res_guests, NotifResGuests)
# Navigate down the nested structure # Navigate down the nested structure
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe" assert customer.person_name.surname == "Doe"
assert customer.email.value == "john.doe@example.com" assert customer.email.value == "john.doe@example.com"
def test_create_retrieve_res_guests(self, sample_customer_data): def test_create_retrieve_res_guests(self, sample_customer_data):
"""Test creating RetrieveResGuests structure.""" """Test creating RetrieveResGuests structure."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
assert isinstance(res_guests, RetrieveResGuests) assert isinstance(res_guests, RetrieveResGuests)
# Navigate down the nested structure # Navigate down the nested structure
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe" assert customer.person_name.surname == "Doe"
assert customer.email.value == "john.doe@example.com" assert customer.email.value == "john.doe@example.com"
def test_create_res_guests_minimal(self, minimal_customer_data): def test_create_res_guests_minimal(self, minimal_customer_data):
"""Test creating ResGuests with minimal customer data.""" """Test creating ResGuests with minimal customer data."""
notif_res_guests = ResGuestFactory.create_notif_res_guests(minimal_customer_data) notif_res_guests = ResGuestFactory.create_notif_res_guests(
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(minimal_customer_data) minimal_customer_data
)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
minimal_customer_data
)
for res_guests in [notif_res_guests, retrieve_res_guests]: for res_guests in [notif_res_guests, retrieve_res_guests]:
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "Jane" assert customer.person_name.given_name == "Jane"
assert customer.person_name.surname == "Smith" assert customer.person_name.surname == "Smith"
assert customer.email is None assert customer.email is None
assert customer.address is None assert customer.address is None
def test_extract_primary_customer_notif(self, sample_customer_data): def test_extract_primary_customer_notif(self, sample_customer_data):
"""Test extracting primary customer from NotifResGuests.""" """Test extracting primary customer from NotifResGuests."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests) extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data assert extracted_data == sample_customer_data
def test_extract_primary_customer_retrieve(self, sample_customer_data): def test_extract_primary_customer_retrieve(self, sample_customer_data):
"""Test extracting primary customer from RetrieveResGuests.""" """Test extracting primary customer from RetrieveResGuests."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests) extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data assert extracted_data == sample_customer_data
def test_roundtrip_conversion_notif(self, sample_customer_data): def test_roundtrip_conversion_notif(self, sample_customer_data):
"""Test complete roundtrip: CustomerData -> NotifResGuests -> CustomerData.""" """Test complete roundtrip: CustomerData -> NotifResGuests -> CustomerData."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests) extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data assert extracted_data == sample_customer_data
def test_roundtrip_conversion_retrieve(self, sample_customer_data): def test_roundtrip_conversion_retrieve(self, sample_customer_data):
"""Test complete roundtrip: CustomerData -> RetrieveResGuests -> CustomerData.""" """Test complete roundtrip: CustomerData -> RetrieveResGuests -> CustomerData."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests) extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data assert extracted_data == sample_customer_data
class TestPhoneTechType: class TestPhoneTechType:
"""Test the PhoneTechType enum.""" """Test the PhoneTechType enum."""
def test_enum_values(self): def test_enum_values(self):
"""Test that enum values are correct.""" """Test that enum values are correct."""
assert PhoneTechType.VOICE.value == "1" assert PhoneTechType.VOICE.value == "1"
@@ -385,95 +446,121 @@ class TestPhoneTechType:
class TestAlpineBitsFactory: class TestAlpineBitsFactory:
"""Test the unified AlpineBitsFactory class.""" """Test the unified AlpineBitsFactory class."""
def test_create_customer_notif(self, sample_customer_data): def test_create_customer_notif(self, sample_customer_data):
"""Test creating customer using unified factory for NOTIF.""" """Test creating customer using unified factory for NOTIF."""
customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF) customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF)
assert isinstance(customer, NotifCustomer) assert isinstance(customer, NotifCustomer)
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe" assert customer.person_name.surname == "Doe"
def test_create_customer_retrieve(self, sample_customer_data): def test_create_customer_retrieve(self, sample_customer_data):
"""Test creating customer using unified factory for RETRIEVE.""" """Test creating customer using unified factory for RETRIEVE."""
customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE) customer = AlpineBitsFactory.create(
sample_customer_data, OtaMessageType.RETRIEVE
)
assert isinstance(customer, RetrieveCustomer) assert isinstance(customer, RetrieveCustomer)
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe" assert customer.person_name.surname == "Doe"
def test_create_hotel_reservation_id_notif(self, sample_hotel_reservation_id_data): def test_create_hotel_reservation_id_notif(self, sample_hotel_reservation_id_data):
"""Test creating hotel reservation ID using unified factory for NOTIF.""" """Test creating hotel reservation ID using unified factory for NOTIF."""
reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF) reservation_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.NOTIF
)
assert isinstance(reservation_id, NotifHotelReservationId) assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_value == "RESERVATION-456"
def test_create_hotel_reservation_id_retrieve(self, sample_hotel_reservation_id_data): def test_create_hotel_reservation_id_retrieve(
self, sample_hotel_reservation_id_data
):
"""Test creating hotel reservation ID using unified factory for RETRIEVE.""" """Test creating hotel reservation ID using unified factory for RETRIEVE."""
reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE) reservation_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
)
assert isinstance(reservation_id, RetrieveHotelReservationId) assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_value == "RESERVATION-456"
def test_create_res_guests_notif(self, sample_customer_data): def test_create_res_guests_notif(self, sample_customer_data):
"""Test creating ResGuests using unified factory for NOTIF.""" """Test creating ResGuests using unified factory for NOTIF."""
res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF) res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.NOTIF
)
assert isinstance(res_guests, NotifResGuests) assert isinstance(res_guests, NotifResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
def test_create_res_guests_retrieve(self, sample_customer_data): def test_create_res_guests_retrieve(self, sample_customer_data):
"""Test creating ResGuests using unified factory for RETRIEVE.""" """Test creating ResGuests using unified factory for RETRIEVE."""
res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE) res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.RETRIEVE
)
assert isinstance(res_guests, RetrieveResGuests) assert isinstance(res_guests, RetrieveResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
def test_extract_data_from_customer(self, sample_customer_data): def test_extract_data_from_customer(self, sample_customer_data):
"""Test extracting data from customer objects.""" """Test extracting data from customer objects."""
# Create both types and extract data back # Create both types and extract data back
notif_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF) notif_customer = AlpineBitsFactory.create(
retrieve_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE) sample_customer_data, OtaMessageType.NOTIF
)
retrieve_customer = AlpineBitsFactory.create(
sample_customer_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_customer) notif_extracted = AlpineBitsFactory.extract_data(notif_customer)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_customer) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_customer)
assert notif_extracted == sample_customer_data assert notif_extracted == sample_customer_data
assert retrieve_extracted == sample_customer_data assert retrieve_extracted == sample_customer_data
def test_extract_data_from_hotel_reservation_id(self, sample_hotel_reservation_id_data): def test_extract_data_from_hotel_reservation_id(
self, sample_hotel_reservation_id_data
):
"""Test extracting data from hotel reservation ID objects.""" """Test extracting data from hotel reservation ID objects."""
# Create both types and extract data back # Create both types and extract data back
notif_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF) notif_res_id = AlpineBitsFactory.create(
retrieve_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE) sample_hotel_reservation_id_data, OtaMessageType.NOTIF
)
retrieve_res_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_res_id) notif_extracted = AlpineBitsFactory.extract_data(notif_res_id)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_id) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_id)
assert notif_extracted == sample_hotel_reservation_id_data assert notif_extracted == sample_hotel_reservation_id_data
assert retrieve_extracted == sample_hotel_reservation_id_data assert retrieve_extracted == sample_hotel_reservation_id_data
def test_extract_data_from_res_guests(self, sample_customer_data): def test_extract_data_from_res_guests(self, sample_customer_data):
"""Test extracting data from ResGuests objects.""" """Test extracting data from ResGuests objects."""
# Create both types and extract data back # Create both types and extract data back
notif_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF) notif_res_guests = AlpineBitsFactory.create_res_guests(
retrieve_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE) sample_customer_data, OtaMessageType.NOTIF
)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_res_guests) notif_extracted = AlpineBitsFactory.extract_data(notif_res_guests)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_guests) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_guests)
assert notif_extracted == sample_customer_data assert notif_extracted == sample_customer_data
assert retrieve_extracted == sample_customer_data assert retrieve_extracted == sample_customer_data
def test_unsupported_data_type_error(self): def test_unsupported_data_type_error(self):
"""Test that unsupported data types raise ValueError.""" """Test that unsupported data types raise ValueError."""
with pytest.raises(ValueError, match="Unsupported data type"): with pytest.raises(ValueError, match="Unsupported data type"):
AlpineBitsFactory.create("invalid_data", OtaMessageType.NOTIF) AlpineBitsFactory.create("invalid_data", OtaMessageType.NOTIF)
def test_unsupported_object_type_error(self): def test_unsupported_object_type_error(self):
"""Test that unsupported object types raise ValueError in extract_data.""" """Test that unsupported object types raise ValueError in extract_data."""
with pytest.raises(ValueError, match="Unsupported object type"): with pytest.raises(ValueError, match="Unsupported object type"):
AlpineBitsFactory.extract_data("invalid_object") AlpineBitsFactory.extract_data("invalid_object")
def test_complete_workflow_with_unified_factory(self): def test_complete_workflow_with_unified_factory(self):
"""Test a complete workflow using only the unified factory.""" """Test a complete workflow using only the unified factory."""
# Original data # Original data
@@ -481,34 +568,47 @@ class TestAlpineBitsFactory:
given_name="Unified", given_name="Unified",
surname="Factory", surname="Factory",
email_address="unified@factory.com", email_address="unified@factory.com",
phone_numbers=[("+1234567890", PhoneTechType.MOBILE)] phone_numbers=[("+1234567890", PhoneTechType.MOBILE)],
) )
reservation_data = HotelReservationIdData( reservation_data = HotelReservationIdData(
res_id_type="999", res_id_type="999", res_id_value="UNIFIED-TEST"
res_id_value="UNIFIED-TEST"
) )
# Create using unified factory # Create using unified factory
customer_notif = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF) customer_notif = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF)
customer_retrieve = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE) customer_retrieve = AlpineBitsFactory.create(
customer_data, OtaMessageType.RETRIEVE
)
res_id_notif = AlpineBitsFactory.create(reservation_data, OtaMessageType.NOTIF) res_id_notif = AlpineBitsFactory.create(reservation_data, OtaMessageType.NOTIF)
res_id_retrieve = AlpineBitsFactory.create(reservation_data, OtaMessageType.RETRIEVE) res_id_retrieve = AlpineBitsFactory.create(
reservation_data, OtaMessageType.RETRIEVE
res_guests_notif = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF) )
res_guests_retrieve = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE)
res_guests_notif = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.NOTIF
)
res_guests_retrieve = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
# Extract everything back # Extract everything back
extracted_customer_from_notif = AlpineBitsFactory.extract_data(customer_notif) extracted_customer_from_notif = AlpineBitsFactory.extract_data(customer_notif)
extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(customer_retrieve) extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(
customer_retrieve
)
extracted_res_id_from_notif = AlpineBitsFactory.extract_data(res_id_notif) extracted_res_id_from_notif = AlpineBitsFactory.extract_data(res_id_notif)
extracted_res_id_from_retrieve = AlpineBitsFactory.extract_data(res_id_retrieve) extracted_res_id_from_retrieve = AlpineBitsFactory.extract_data(res_id_retrieve)
extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(res_guests_notif) extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(
extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(res_guests_retrieve) res_guests_notif
)
extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(
res_guests_retrieve
)
# Verify everything matches # Verify everything matches
assert extracted_customer_from_notif == customer_data assert extracted_customer_from_notif == customer_data
assert extracted_customer_from_retrieve == customer_data assert extracted_customer_from_retrieve == customer_data
@@ -520,37 +620,72 @@ class TestAlpineBitsFactory:
class TestIntegration: class TestIntegration:
"""Integration tests combining both factories.""" """Integration tests combining both factories."""
def test_both_factories_produce_same_customer_data(self, sample_customer_data): def test_both_factories_produce_same_customer_data(self, sample_customer_data):
"""Test that both factories can work with the same customer data.""" """Test that both factories can work with the same customer data."""
# Create using CustomerFactory # Create using CustomerFactory
notif_customer = CustomerFactory.create_notif_customer(sample_customer_data) notif_customer = CustomerFactory.create_notif_customer(sample_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(sample_customer_data) retrieve_customer = CustomerFactory.create_retrieve_customer(
sample_customer_data
)
# Create using ResGuestFactory and extract customers # Create using ResGuestFactory and extract customers
notif_res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) notif_res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
sample_customer_data
notif_from_res_guests = notif_res_guests.res_guest.profiles.profile_info.profile.customer )
retrieve_from_res_guests = retrieve_res_guests.res_guest.profiles.profile_info.profile.customer
notif_from_res_guests = (
notif_res_guests.res_guest.profiles.profile_info.profile.customer
)
retrieve_from_res_guests = (
retrieve_res_guests.res_guest.profiles.profile_info.profile.customer
)
# Compare customer names (structure should be identical) # Compare customer names (structure should be identical)
assert notif_customer.person_name.given_name == notif_from_res_guests.person_name.given_name assert (
assert notif_customer.person_name.surname == notif_from_res_guests.person_name.surname notif_customer.person_name.given_name
assert retrieve_customer.person_name.given_name == retrieve_from_res_guests.person_name.given_name == notif_from_res_guests.person_name.given_name
assert retrieve_customer.person_name.surname == retrieve_from_res_guests.person_name.surname )
assert (
def test_hotel_reservation_id_factories_produce_same_data(self, sample_hotel_reservation_id_data): notif_customer.person_name.surname
== notif_from_res_guests.person_name.surname
)
assert (
retrieve_customer.person_name.given_name
== retrieve_from_res_guests.person_name.given_name
)
assert (
retrieve_customer.person_name.surname
== retrieve_from_res_guests.person_name.surname
)
def test_hotel_reservation_id_factories_produce_same_data(
self, sample_hotel_reservation_id_data
):
"""Test that both HotelReservationId factories produce equivalent results.""" """Test that both HotelReservationId factories produce equivalent results."""
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) notif_reservation_id = (
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) HotelReservationIdFactory.create_notif_hotel_reservation_id(
sample_hotel_reservation_id_data
)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
sample_hotel_reservation_id_data
)
)
# Both should have the same field values # Both should have the same field values
assert notif_reservation_id.res_id_type == retrieve_reservation_id.res_id_type assert notif_reservation_id.res_id_type == retrieve_reservation_id.res_id_type
assert notif_reservation_id.res_id_value == retrieve_reservation_id.res_id_value assert notif_reservation_id.res_id_value == retrieve_reservation_id.res_id_value
assert notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source assert (
assert notif_reservation_id.res_id_source_context == retrieve_reservation_id.res_id_source_context notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source
)
assert (
notif_reservation_id.res_id_source_context
== retrieve_reservation_id.res_id_source_context
)
def test_complex_customer_workflow(self): def test_complex_customer_workflow(self):
"""Test a complex workflow with multiple operations.""" """Test a complex workflow with multiple operations."""
# Create original data # Create original data
@@ -559,7 +694,7 @@ class TestIntegration:
surname="Johnson", surname="Johnson",
phone_numbers=[ phone_numbers=[
("+1555123456", PhoneTechType.MOBILE), ("+1555123456", PhoneTechType.MOBILE),
("+1555654321", PhoneTechType.VOICE) ("+1555654321", PhoneTechType.VOICE),
], ],
email_address="alice.johnson@company.com", email_address="alice.johnson@company.com",
email_newsletter=False, email_newsletter=False,
@@ -569,22 +704,24 @@ class TestIntegration:
country_code="CA", country_code="CA",
address_catalog=True, address_catalog=True,
gender="Female", gender="Female",
language="fr" language="fr",
) )
# Create ResGuests for both types # Create ResGuests for both types
notif_res_guests = ResGuestFactory.create_notif_res_guests(original_data) notif_res_guests = ResGuestFactory.create_notif_res_guests(original_data)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(original_data) retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(original_data)
# Extract data back from both # Extract data back from both
notif_extracted = ResGuestFactory.extract_primary_customer(notif_res_guests) notif_extracted = ResGuestFactory.extract_primary_customer(notif_res_guests)
retrieve_extracted = ResGuestFactory.extract_primary_customer(retrieve_res_guests) retrieve_extracted = ResGuestFactory.extract_primary_customer(
retrieve_res_guests
)
# All should be equal # All should be equal
assert original_data == notif_extracted assert original_data == notif_extracted
assert original_data == retrieve_extracted assert original_data == retrieve_extracted
assert notif_extracted == retrieve_extracted assert notif_extracted == retrieve_extracted
def test_complex_hotel_reservation_id_workflow(self): def test_complex_hotel_reservation_id_workflow(self):
"""Test a complex workflow with HotelReservationId operations.""" """Test a complex workflow with HotelReservationId operations."""
# Create original reservation ID data # Create original reservation ID data
@@ -592,18 +729,30 @@ class TestIntegration:
res_id_type="456", res_id_type="456",
res_id_value="COMPLEX-RESERVATION-789", res_id_value="COMPLEX-RESERVATION-789",
res_id_source="INTEGRATION_SYSTEM", res_id_source="INTEGRATION_SYSTEM",
res_id_source_context="API_CALL" res_id_source_context="API_CALL",
) )
# Create HotelReservationId for both types # Create HotelReservationId for both types
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data) notif_reservation_id = (
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(original_data) HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
original_data
)
)
# Extract data back from both # Extract data back from both
notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(notif_reservation_id) notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(
retrieve_extracted = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(retrieve_reservation_id) notif_reservation_id
)
retrieve_extracted = (
HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
retrieve_reservation_id
)
)
# All should be equal # All should be equal
assert original_data == notif_extracted assert original_data == notif_extracted
assert original_data == retrieve_extracted assert original_data == retrieve_extracted
assert notif_extracted == retrieve_extracted assert notif_extracted == retrieve_extracted

View File

@@ -6,24 +6,31 @@ Test the handshake functionality with the real AlpineBits sample file.
import asyncio import asyncio
from alpine_bits_python.alpinebits_server import AlpineBitsServer from alpine_bits_python.alpinebits_server import AlpineBitsServer
async def main(): async def main():
print("🔄 Testing AlpineBits Handshake with Sample File") print("🔄 Testing AlpineBits Handshake with Sample File")
print("=" * 60) print("=" * 60)
# Create server instance # Create server instance
server = AlpineBitsServer() server = AlpineBitsServer()
# Read the sample handshake request # Read the sample handshake request
with open("AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml", "r") as f: with open(
"AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml",
"r",
) as f:
ping_request_xml = f.read() ping_request_xml = f.read()
print("📤 Sending handshake request...") print("📤 Sending handshake request...")
# Handle the ping request # Handle the ping request
response = await server.handle_request("OTA_Ping:Handshaking", ping_request_xml, "2024-10") response = await server.handle_request(
"OTA_Ping:Handshaking", ping_request_xml, "2024-10"
)
print(f"\n📥 Response Status: {response.status_code}") print(f"\n📥 Response Status: {response.status_code}")
print(f"📄 Response XML:\n{response.xml_content}") print(f"📄 Response XML:\n{response.xml_content}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())