Got db saving working

This commit is contained in:
Jonas Linter
2025-09-29 13:56:34 +02:00
parent 384fb2b558
commit 06739ebea9
21 changed files with 1188 additions and 830 deletions

View File

@@ -1,4 +1,5 @@
"""Entry point for alpine_bits_python package.""" """Entry point for alpine_bits_python package."""
from .main import main from .main import main
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -23,13 +23,13 @@ from xsdata_pydantic.bindings import XmlParser
class HttpStatusCode(IntEnum): class HttpStatusCode(IntEnum):
"""Allowed HTTP status codes for AlpineBits responses.""" """Allowed HTTP status codes for AlpineBits responses."""
OK = 200 OK = 200
BAD_REQUEST = 400 BAD_REQUEST = 400
UNAUTHORIZED = 401 UNAUTHORIZED = 401
INTERNAL_SERVER_ERROR = 500 INTERNAL_SERVER_ERROR = 500
class AlpineBitsActionName(Enum): class AlpineBitsActionName(Enum):
"""Enum for AlpineBits action names with capability and request name mappings.""" """Enum for AlpineBits action names with capability and request name mappings."""
@@ -37,27 +37,43 @@ class AlpineBitsActionName(Enum):
OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking") OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking")
OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests") OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests")
OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif") OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif")
OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = ("action_OTA_HotelResNotif_GuestRequests", OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = (
"OTA_HotelResNotif:GuestRequests") "action_OTA_HotelResNotif_GuestRequests",
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = ("action_OTA_HotelDescriptiveContentNotif_Inventory", "OTA_HotelResNotif:GuestRequests",
"OTA_HotelDescriptiveContentNotif:Inventory") )
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = ("action_OTA_HotelDescriptiveContentNotif_Info", OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = (
"OTA_HotelDescriptiveContentNotif:Info") "action_OTA_HotelDescriptiveContentNotif_Inventory",
OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = ("action_OTA_HotelDescriptiveInfo_Inventory", "OTA_HotelDescriptiveContentNotif:Inventory",
"OTA_HotelDescriptiveInfo:Inventory") )
OTA_HOTEL_DESCRIPTIVE_INFO_INFO = ("action_OTA_HotelDescriptiveInfo_Info", OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = (
"OTA_HotelDescriptiveInfo:Info") "action_OTA_HotelDescriptiveContentNotif_Info",
OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = ("action_OTA_HotelRatePlanNotif_RatePlans", "OTA_HotelDescriptiveContentNotif:Info",
"OTA_HotelRatePlanNotif:RatePlans") )
OTA_HOTEL_RATE_PLAN_BASE_RATES = ("action_OTA_HotelRatePlan_BaseRates", OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = (
"OTA_HotelRatePlan:BaseRates") "action_OTA_HotelDescriptiveInfo_Inventory",
"OTA_HotelDescriptiveInfo:Inventory",
)
OTA_HOTEL_DESCRIPTIVE_INFO_INFO = (
"action_OTA_HotelDescriptiveInfo_Info",
"OTA_HotelDescriptiveInfo:Info",
)
OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = (
"action_OTA_HotelRatePlanNotif_RatePlans",
"OTA_HotelRatePlanNotif:RatePlans",
)
OTA_HOTEL_RATE_PLAN_BASE_RATES = (
"action_OTA_HotelRatePlan_BaseRates",
"OTA_HotelRatePlan:BaseRates",
)
def __init__(self, capability_name: str, request_name: str): def __init__(self, capability_name: str, request_name: str):
self.capability_name = capability_name self.capability_name = capability_name
self.request_name = request_name self.request_name = request_name
@classmethod @classmethod
def get_by_capability_name(cls, capability_name: str) -> Optional['AlpineBitsActionName']: def get_by_capability_name(
cls, capability_name: str
) -> Optional["AlpineBitsActionName"]:
"""Get action enum by capability name.""" """Get action enum by capability name."""
for action in cls: for action in cls:
if action.capability_name == capability_name: if action.capability_name == capability_name:
@@ -65,7 +81,7 @@ class AlpineBitsActionName(Enum):
return None return None
@classmethod @classmethod
def get_by_request_name(cls, request_name: str) -> Optional['AlpineBitsActionName']: def get_by_request_name(cls, request_name: str) -> Optional["AlpineBitsActionName"]:
"""Get action enum by request name.""" """Get action enum by request name."""
for action in cls: for action in cls:
if action.request_name == request_name: if action.request_name == request_name:
@@ -75,22 +91,25 @@ class AlpineBitsActionName(Enum):
class Version(str, Enum): class Version(str, Enum):
"""Enum for AlpineBits versions.""" """Enum for AlpineBits versions."""
V2024_10 = "2024-10" V2024_10 = "2024-10"
V2022_10 = "2022-10" V2022_10 = "2022-10"
# Add other versions as needed # Add other versions as needed
@dataclass @dataclass
class AlpineBitsResponse: class AlpineBitsResponse:
"""Response data structure for AlpineBits actions.""" """Response data structure for AlpineBits actions."""
xml_content: str xml_content: str
status_code: HttpStatusCode = HttpStatusCode.OK status_code: HttpStatusCode = HttpStatusCode.OK
def __post_init__(self): def __post_init__(self):
"""Validate that status code is one of the allowed values.""" """Validate that status code is one of the allowed values."""
if self.status_code not in [200, 400, 401, 500]: if self.status_code not in [200, 400, 401, 500]:
raise ValueError(f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500") raise ValueError(
f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500"
)
# Abstract base class for AlpineBits Action # Abstract base class for AlpineBits Action
@@ -98,9 +117,13 @@ class AlpineBitsAction(ABC):
"""Abstract base class for handling AlpineBits actions.""" """Abstract base class for handling AlpineBits actions."""
name: AlpineBitsActionName name: AlpineBitsActionName
version: Version | list[Version] # list of versions in case action supports multiple versions version: (
Version | list[Version]
) # list of versions in case action supports multiple versions
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
""" """
Handle the incoming request XML and return response XML. Handle the incoming request XML and return response XML.
@@ -132,10 +155,6 @@ class AlpineBitsAction(ABC):
return version == self.version return version == self.version
class ServerCapabilities: class ServerCapabilities:
""" """
Automatically discovers AlpineBitsAction implementations and generates capabilities. Automatically discovers AlpineBitsAction implementations and generates capabilities.
@@ -151,14 +170,15 @@ class ServerCapabilities:
current_module = inspect.getmodule(self) current_module = inspect.getmodule(self)
for name, obj in inspect.getmembers(current_module): for name, obj in inspect.getmembers(current_module):
if (inspect.isclass(obj) and if (
issubclass(obj, AlpineBitsAction) and inspect.isclass(obj)
obj != AlpineBitsAction): and issubclass(obj, AlpineBitsAction)
and obj != AlpineBitsAction
):
# Check if this action is actually implemented (not just returning default) # Check if this action is actually implemented (not just returning default)
if self._is_action_implemented(obj): if self._is_action_implemented(obj):
action_instance = obj() action_instance = obj()
if hasattr(action_instance, 'name'): if hasattr(action_instance, "name"):
# Use capability name for the registry key # Use capability name for the registry key
self.action_registry[action_instance.name.capability_name] = obj self.action_registry[action_instance.name.capability_name] = obj
@@ -168,11 +188,10 @@ class ServerCapabilities:
This is a simple check - in practice, you might want more sophisticated detection. This is a simple check - in practice, you might want more sophisticated detection.
""" """
# Check if the class has overridden the handle method # Check if the class has overridden the handle method
if 'handle' in action_class.__dict__: if "handle" in action_class.__dict__:
return True return True
return False return False
def create_capabilities_dict(self) -> None: def create_capabilities_dict(self) -> None:
""" """
Generate the capabilities dictionary based on discovered actions. Generate the capabilities dictionary based on discovered actions.
@@ -194,26 +213,20 @@ class ServerCapabilities:
version_str = version.value version_str = version.value
if version_str not in versions_dict: if version_str not in versions_dict:
versions_dict[version_str] = { versions_dict[version_str] = {"version": version_str, "actions": []}
"version": version_str,
"actions": []
}
action_dict = {"action": action_name} action_dict = {"action": action_name}
# Add supports field if the action has custom supports # Add supports field if the action has custom supports
if hasattr(action_instance, 'supports') and action_instance.supports: if hasattr(action_instance, "supports") and action_instance.supports:
action_dict["supports"] = action_instance.supports action_dict["supports"] = action_instance.supports
versions_dict[version_str]["actions"].append(action_dict) versions_dict[version_str]["actions"].append(action_dict)
self.capability_dict = {"versions": list(versions_dict.values())} self.capability_dict = {"versions": list(versions_dict.values())}
return None return None
def get_capabilities_dict(self) -> Dict: def get_capabilities_dict(self) -> Dict:
""" """
Get capabilities as a dictionary. Generates if not already created. Get capabilities as a dictionary. Generates if not already created.
@@ -234,22 +247,35 @@ class ServerCapabilities:
# Sample Action Implementations for demonstration # Sample Action Implementations for demonstration
class PingAction(AlpineBitsAction): class PingAction(AlpineBitsAction):
"""Implementation for OTA_Ping action (handshaking).""" """Implementation for OTA_Ping action (handshaking)."""
def __init__(self): def __init__(self):
self.name = AlpineBitsActionName.OTA_PING self.name = AlpineBitsActionName.OTA_PING
self.version = [Version.V2024_10, Version.V2022_10] # Supports multiple versions self.version = [
Version.V2024_10,
Version.V2022_10,
] # Supports multiple versions
async def handle(self, action: str, request_xml: str, version: Version, server_capabilities: None | ServerCapabilities = None) -> AlpineBitsResponse: async def handle(
self,
action: str,
request_xml: str,
version: Version,
server_capabilities: None | ServerCapabilities = None,
) -> AlpineBitsResponse:
"""Handle ping requests.""" """Handle ping requests."""
if request_xml is None: if request_xml is None:
return AlpineBitsResponse(f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST) return AlpineBitsResponse(
f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST
)
if server_capabilities is None: if server_capabilities is None:
return AlpineBitsResponse("Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR) return AlpineBitsResponse(
"Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR
)
# Parse the incoming request XML and extract EchoData # Parse the incoming request XML and extract EchoData
parser = XmlParser() parser = XmlParser()
@@ -259,7 +285,9 @@ class PingAction(AlpineBitsAction):
echo_data = json.loads(parsed_request.echo_data) echo_data = json.loads(parsed_request.echo_data)
except Exception as e: except Exception as e:
return AlpineBitsResponse(f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST) return AlpineBitsResponse(
f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST
)
# compare echo data with capabilities, create a dictionary containing the matching capabilities # compare echo data with capabilities, create a dictionary containing the matching capabilities
capabilities_dict = server_capabilities.get_capabilities_dict() capabilities_dict = server_capabilities.get_capabilities_dict()
@@ -273,20 +301,25 @@ class PingAction(AlpineBitsAction):
for server_version in capabilities_dict["versions"]: for server_version in capabilities_dict["versions"]:
if server_version["version"] == client_version_str: if server_version["version"] == client_version_str:
# Found a matching version, now find common actions # Found a matching version, now find common actions
matching_version = { matching_version = {"version": client_version_str, "actions": []}
"version": client_version_str,
"actions": []
}
# Get client's requested actions for this version # Get client's requested actions for this version
client_actions = {action.get("action", ""): action for action in client_version.get("actions", [])} client_actions = {
server_actions = {action.get("action", ""): action for action in server_version.get("actions", [])} action.get("action", ""): action
for action in client_version.get("actions", [])
}
server_actions = {
action.get("action", ""): action
for action in server_version.get("actions", [])
}
# Find common actions # Find common actions
for action_name in client_actions: for action_name in client_actions:
if action_name in server_actions: if action_name in server_actions:
# Use server's action definition (includes our supports) # Use server's action definition (includes our supports)
matching_version["actions"].append(server_actions[action_name]) matching_version["actions"].append(
server_actions[action_name]
)
# Only add version if there are common actions # Only add version if there are common actions
if matching_version["actions"]: if matching_version["actions"]:
@@ -298,15 +331,20 @@ class PingAction(AlpineBitsAction):
# Create successful ping response with matched capabilities # Create successful ping response with matched capabilities
capabilities_json = json.dumps(matching_capabilities, indent=2) capabilities_json = json.dumps(matching_capabilities, indent=2)
warning = OtaPingRs.Warnings.Warning(status=WarningStatus.ALPINEBITS_HANDSHAKE.value, type_value="11", content=[capabilities_json]) warning = OtaPingRs.Warnings.Warning(
status=WarningStatus.ALPINEBITS_HANDSHAKE.value,
type_value="11",
content=[capabilities_json],
)
warning_response = OtaPingRs.Warnings(warning=[warning]) warning_response = OtaPingRs.Warnings(warning=[warning])
response_ota_ping = OtaPingRs(version= "7.000", warnings=warning_response, echo_data=capabilities_json, success="") response_ota_ping = OtaPingRs(
version="7.000",
warnings=warning_response,
echo_data=capabilities_json,
success="",
)
config = SerializerConfig( config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8" pretty_print=True, xml_declaration=True, encoding="UTF-8"
@@ -314,10 +352,9 @@ class PingAction(AlpineBitsAction):
serializer = XmlSerializer(config=config) serializer = XmlSerializer(config=config)
response_xml = serializer.render(response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}) response_xml = serializer.render(
response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}
)
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
@@ -329,13 +366,15 @@ class ReadAction(AlpineBitsAction):
self.name = AlpineBitsActionName.OTA_READ self.name = AlpineBitsActionName.OTA_READ
self.version = [Version.V2024_10, Version.V2022_10] self.version = [Version.V2024_10, Version.V2022_10]
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""Handle read requests.""" """Handle read requests."""
response_xml = f'''<?xml version="1.0" encoding="UTF-8"?> response_xml = f"""<?xml version="1.0" encoding="UTF-8"?>
<OTA_ReadRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000"> <OTA_ReadRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000">
<Success/> <Success/>
<Data>Read operation successful for {version.value}</Data> <Data>Read operation successful for {version.value}</Data>
</OTA_ReadRS>''' </OTA_ReadRS>"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
@@ -349,15 +388,17 @@ class HotelAvailNotifAction(AlpineBitsAction):
"OTA_HotelAvailNotif_accept_rooms", "OTA_HotelAvailNotif_accept_rooms",
"OTA_HotelAvailNotif_accept_categories", "OTA_HotelAvailNotif_accept_categories",
"OTA_HotelAvailNotif_accept_deltas", "OTA_HotelAvailNotif_accept_deltas",
"OTA_HotelAvailNotif_accept_BookingThreshold" "OTA_HotelAvailNotif_accept_BookingThreshold",
] ]
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""Handle hotel availability notifications.""" """Handle hotel availability notifications."""
response_xml = '''<?xml version="1.0" encoding="UTF-8"?> response_xml = """<?xml version="1.0" encoding="UTF-8"?>
<OTA_HotelAvailNotifRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000"> <OTA_HotelAvailNotifRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000">
<Success/> <Success/>
</OTA_HotelAvailNotifRS>''' </OTA_HotelAvailNotifRS>"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
@@ -371,10 +412,6 @@ class GuestRequestsAction(AlpineBitsAction):
# Note: This class doesn't override the handle method, so it won't be discovered # Note: This class doesn't override the handle method, so it won't be discovered
class AlpineBitsServer: class AlpineBitsServer:
""" """
Asynchronous AlpineBits server for handling hotel data exchange requests. Asynchronous AlpineBits server for handling hotel data exchange requests.
@@ -402,7 +439,9 @@ class AlpineBitsServer:
"""Get server capabilities as JSON.""" """Get server capabilities as JSON."""
return self.capabilities.get_capabilities_json() return self.capabilities.get_capabilities_json()
async def handle_request(self, request_action_name: str, request_xml: str, version: str = "2024-10") -> AlpineBitsResponse: async def handle_request(
self, request_action_name: str, request_xml: str, version: str = "2024-10"
) -> AlpineBitsResponse:
""" """
Handle an incoming AlpineBits request by routing to appropriate action handler. Handle an incoming AlpineBits request by routing to appropriate action handler.
@@ -419,8 +458,7 @@ class AlpineBitsServer:
version_enum = Version(version) version_enum = Version(version)
except ValueError: except ValueError:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unsupported version {version}", f"Error: Unsupported version {version}", HttpStatusCode.BAD_REQUEST
HttpStatusCode.BAD_REQUEST
) )
# Find the action by request name # Find the action by request name
@@ -428,7 +466,7 @@ class AlpineBitsServer:
if not action_enum: if not action_enum:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unknown action {request_action_name}", f"Error: Unknown action {request_action_name}",
HttpStatusCode.BAD_REQUEST HttpStatusCode.BAD_REQUEST,
) )
# Check if we have an implementation for this action # Check if we have an implementation for this action
@@ -436,7 +474,7 @@ class AlpineBitsServer:
if capability_name not in self._action_instances: if capability_name not in self._action_instances:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Action {request_action_name} is not implemented", f"Error: Action {request_action_name} is not implemented",
HttpStatusCode.BAD_REQUEST HttpStatusCode.BAD_REQUEST,
) )
action_instance: AlpineBitsAction = self._action_instances[capability_name] action_instance: AlpineBitsAction = self._action_instances[capability_name]
@@ -445,24 +483,29 @@ class AlpineBitsServer:
if not await action_instance.check_version_supported(version_enum): if not await action_instance.check_version_supported(version_enum):
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Action {request_action_name} does not support version {version}", f"Error: Action {request_action_name} does not support version {version}",
HttpStatusCode.BAD_REQUEST HttpStatusCode.BAD_REQUEST,
) )
# Handle the request # Handle the request
try: try:
# Special case for ping action - pass server capabilities # Special case for ping action - pass server capabilities
if capability_name == "action_OTA_Ping": if capability_name == "action_OTA_Ping":
return await action_instance.handle(request_action_name, request_xml, version_enum, self.capabilities) return await action_instance.handle(
request_action_name, request_xml, version_enum, self.capabilities
)
else: else:
return await action_instance.handle(request_action_name, request_xml, version_enum) return await action_instance.handle(
request_action_name, request_xml, version_enum
)
except Exception as e: except Exception as e:
print(f"Error handling request {request_action_name}: {str(e)}") print(f"Error handling request {request_action_name}: {str(e)}")
# print stack trace for debugging # print stack trace for debugging
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Internal server error while processing {request_action_name}: {str(e)}", f"Error: Internal server error while processing {request_action_name}: {str(e)}",
HttpStatusCode.INTERNAL_SERVER_ERROR HttpStatusCode.INTERNAL_SERVER_ERROR,
) )
def get_supported_request_names(self) -> List[str]: def get_supported_request_names(self) -> List[str]:
@@ -474,7 +517,9 @@ class AlpineBitsServer:
request_names.append(action_enum.request_name) request_names.append(action_enum.request_name)
return sorted(request_names) return sorted(request_names)
def is_action_supported(self, request_action_name: str, version: str = None) -> bool: def is_action_supported(
self, request_action_name: str, version: str = None
) -> bool:
""" """
Check if a request action is supported. Check if a request action is supported.
@@ -524,7 +569,9 @@ async def main():
print(f"{capability_name} -> {action_class.__name__}") print(f"{capability_name} -> {action_class.__name__}")
print(f" Request name: {request_name}") print(f" Request name: {request_name}")
print(f"\n📊 Total Implemented Actions: {len(server.capabilities.get_supported_actions())}") print(
f"\n📊 Total Implemented Actions: {len(server.capabilities.get_supported_actions())}"
)
print("\n🔍 Generated Capabilities JSON:") print("\n🔍 Generated Capabilities JSON:")
print("-" * 30) print("-" * 30)
@@ -548,7 +595,7 @@ async def main():
("OTA_Read:GuestRequests", "2022-10"), ("OTA_Read:GuestRequests", "2022-10"),
("OTA_HotelAvailNotif", "2024-10"), ("OTA_HotelAvailNotif", "2024-10"),
("UnknownAction", "2024-10"), ("UnknownAction", "2024-10"),
("OTA_Ping:Handshaking", "unsupported-version") ("OTA_Ping:Handshaking", "unsupported-version"),
] ]
for request_name, version in test_cases: for request_name, version in test_cases:

View File

@@ -1,4 +1,15 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, APIRouter, Form, File, UploadFile from fastapi import (
FastAPI,
HTTPException,
BackgroundTasks,
Request,
Depends,
APIRouter,
Form,
File,
UploadFile,
)
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic
from .config_loader import load_config from .config_loader import load_config
@@ -12,7 +23,7 @@ from .rate_limit import (
custom_rate_limit_handler, custom_rate_limit_handler,
DEFAULT_RATE_LIMIT, DEFAULT_RATE_LIMIT,
WEBHOOK_RATE_LIMIT, WEBHOOK_RATE_LIMIT,
BURST_RATE_LIMIT BURST_RATE_LIMIT,
) )
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
import logging import logging
@@ -24,8 +35,14 @@ import gzip
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from .alpinebits_server import AlpineBitsServer, Version from .alpinebits_server import AlpineBitsServer, Version
import urllib.parse import urllib.parse
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .db import get_async_session, Customer as DBCustomer, Reservation as DBReservation from .db import (
Base,
Customer as DBCustomer,
Reservation as DBReservation,
get_database_url,
)
# Configure logging # Configure logging
@@ -42,12 +59,36 @@ except Exception as e:
_LOGGER.error(f"Failed to load config: {str(e)}") _LOGGER.error(f"Failed to load config: {str(e)}")
config = {} config = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# Setup DB
DATABASE_URL = get_database_url(config)
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
app.state.engine = engine
app.state.async_sessionmaker = AsyncSessionLocal
# Create tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
_LOGGER.info("Database tables checked/created at startup.")
yield
# Optional: Dispose engine on shutdown
await engine.dispose()
async def get_async_session(request: Request):
async_sessionmaker = request.app.state.async_sessionmaker
async with async_sessionmaker() as session:
yield session
app = FastAPI( app = FastAPI(
title="Wix Form Handler API", title="Wix Form Handler API",
description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting", description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting",
version="1.0.0" version="1.0.0",
lifespan=lifespan
) )
# Create API router with /api prefix # Create API router with /api prefix
@@ -64,7 +105,7 @@ app.add_middleware(
"https://*.wix.com", "https://*.wix.com",
"https://*.wixstatic.com", "https://*.wixstatic.com",
"http://localhost:3000", # For development "http://localhost:3000", # For development
"http://localhost:8000" # For local testing "http://localhost:8000", # For local testing
], ],
allow_credentials=True, allow_credentials=True,
allow_methods=["GET", "POST"], allow_methods=["GET", "POST"],
@@ -78,16 +119,26 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None:
Add your business logic here. Add your business logic here.
""" """
try: try:
_LOGGER.info(f"Processing form submission: {submission_data.get('submissionId')}") _LOGGER.info(
f"Processing form submission: {submission_data.get('submissionId')}"
)
# Example processing - you can replace this with your actual logic # Example processing - you can replace this with your actual logic
form_name = submission_data.get('formName') form_name = submission_data.get("formName")
contact_email = submission_data.get('contact', {}).get('email') if submission_data.get('contact') else None contact_email = (
submission_data.get("contact", {}).get("email")
if submission_data.get("contact")
else None
)
# Extract form fields # Extract form fields
form_fields = {k: v for k, v in submission_data.items() if k.startswith('field:')} form_fields = {
k: v for k, v in submission_data.items() if k.startswith("field:")
}
_LOGGER.info(f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}") _LOGGER.info(
f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}"
)
# Here you could: # Here you could:
# - Save to database # - Save to database
@@ -99,6 +150,8 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None:
_LOGGER.error(f"Error processing form submission: {str(e)}") _LOGGER.error(f"Error processing form submission: {str(e)}")
@api_router.get("/") @api_router.get("/")
@limiter.limit(DEFAULT_RATE_LIMIT) @limiter.limit(DEFAULT_RATE_LIMIT)
async def root(request: Request): async def root(request: Request):
@@ -111,8 +164,8 @@ async def root(request: Request):
"rate_limits": { "rate_limits": {
"default": DEFAULT_RATE_LIMIT, "default": DEFAULT_RATE_LIMIT,
"webhook": WEBHOOK_RATE_LIMIT, "webhook": WEBHOOK_RATE_LIMIT,
"burst": BURST_RATE_LIMIT "burst": BURST_RATE_LIMIT,
} },
} }
@@ -126,11 +179,10 @@ async def health_check(request: Request):
"service": "wix-form-handler", "service": "wix-form-handler",
"version": "1.0.0", "version": "1.0.0",
"authentication": "enabled", "authentication": "enabled",
"rate_limiting": "enabled" "rate_limiting": "enabled",
} }
# Extracted business logic for handling Wix form submissions # Extracted business logic for handling Wix form submissions
async def process_wix_form_submission(request: Request, data: Dict[str, Any], db): async def process_wix_form_submission(request: Request, data: Dict[str, Any], db):
""" """
@@ -138,10 +190,9 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
""" """
timestamp = datetime.now().isoformat() timestamp = datetime.now().isoformat()
_LOGGER.info(f"Received Wix form data at {timestamp}") _LOGGER.info(f"Received Wix form data at {timestamp}")
#_LOGGER.info(f"Data keys: {list(data.keys())}") # _LOGGER.info(f"Data keys: {list(data.keys())}")
#_LOGGER.info(f"Full data: {json.dumps(data, indent=2)}") # _LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
log_entry = { log_entry = {
"timestamp": timestamp, "timestamp": timestamp,
"client_ip": request.client.host if request.client else "unknown", "client_ip": request.client.host if request.client else "unknown",
@@ -154,9 +205,13 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
if not os.path.exists(logs_dir): if not os.path.exists(logs_dir):
os.makedirs(logs_dir, mode=0o755, exist_ok=True) os.makedirs(logs_dir, mode=0o755, exist_ok=True)
stat_info = os.stat(logs_dir) stat_info = os.stat(logs_dir)
_LOGGER.info(f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}") _LOGGER.info(
f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}"
)
_LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}") _LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}")
log_filename = f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" log_filename = (
f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
)
with open(log_filename, "w", encoding="utf-8") as f: with open(log_filename, "w", encoding="utf-8") as f:
json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False) json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False)
file_stat = os.stat(log_filename) file_stat = os.stat(log_filename)
@@ -164,16 +219,10 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
_LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}") _LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}")
_LOGGER.info(f"Data logged to: {log_filename}") _LOGGER.info(f"Data logged to: {log_filename}")
data = data.get("data") # Handle nested "data" key if present data = data.get("data") # Handle nested "data" key if present
# save customer and reservation to DB # save customer and reservation to DB
contact_info = data.get("contact", {}) contact_info = data.get("contact", {})
first_name = contact_info.get("name", {}).get("first") first_name = contact_info.get("name", {}).get("first")
last_name = contact_info.get("name", {}).get("last") last_name = contact_info.get("name", {}).get("last")
@@ -193,8 +242,16 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
language = data.get("contact", {}).get("locale", "en")[:2] language = data.get("contact", {}).get("locale", "en")[:2]
# Dates # Dates
start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value") start_date = (
end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value") data.get("field:date_picker_a7c8")
or data.get("Anreisedatum")
or data.get("submissions", [{}])[1].get("value")
)
end_date = (
data.get("field:date_picker_7e65")
or data.get("Abreisedatum")
or data.get("submissions", [{}])[2].get("value")
)
# Room/guest info # Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2) num_adults = int(data.get("field:number_7cf5") or 2)
@@ -258,7 +315,7 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
end_date=date.fromisoformat(end_date) if end_date else None, end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults, num_adults=num_adults,
num_children=num_children, num_children=num_children,
children_ages=','.join(str(a) for a in children_ages), children_ages=",".join(str(a) for a in children_ages),
offer=offer, offer=offer,
utm_comment=utm_comment, utm_comment=utm_comment,
created_at=datetime.now(timezone.utc), created_at=datetime.now(timezone.utc),
@@ -277,23 +334,21 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
await db.commit() await db.commit()
await db.refresh(db_reservation) await db.refresh(db_reservation)
return { return {
"status": "success", "status": "success",
"message": "Wix form data received successfully", "message": "Wix form data received successfully",
"received_keys": list(data.keys()), "received_keys": list(data.keys()),
"data_logged_to": log_filename, "data_logged_to": log_filename,
"timestamp": timestamp, "timestamp": timestamp,
"process_info": log_entry["process_info"], "note": "No authentication required for this endpoint",
"note": "No authentication required for this endpoint"
} }
@api_router.post("/webhook/wix-form") @api_router.post("/webhook/wix-form")
@webhook_limiter.limit(WEBHOOK_RATE_LIMIT) @webhook_limiter.limit(WEBHOOK_RATE_LIMIT)
async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)): async def handle_wix_form(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
""" """
Unified endpoint to handle Wix form submissions (test and production). Unified endpoint to handle Wix form submissions (test and production).
No authentication required for this endpoint. No authentication required for this endpoint.
@@ -304,16 +359,19 @@ async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Dep
_LOGGER.error(f"Error in handle_wix_form: {str(e)}") _LOGGER.error(f"Error in handle_wix_form: {str(e)}")
# log stacktrace # log stacktrace
import traceback import traceback
traceback_str = traceback.format_exc() traceback_str = traceback.format_exc()
_LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}") _LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Error processing Wix form data: {str(e)}"
detail=f"Error processing Wix form data: {str(e)}"
) )
@api_router.post("/webhook/wix-form/test") @api_router.post("/webhook/wix-form/test")
@limiter.limit(DEFAULT_RATE_LIMIT) @limiter.limit(DEFAULT_RATE_LIMIT)
async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session=Depends(get_async_session)): async def handle_wix_form_test(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
""" """
Test endpoint to verify the API is working with raw JSON data. Test endpoint to verify the API is working with raw JSON data.
No authentication required for testing purposes. No authentication required for testing purposes.
@@ -323,26 +381,21 @@ async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session
except Exception as e: except Exception as e:
_LOGGER.error(f"Error in handle_wix_form_test: {str(e)}") _LOGGER.error(f"Error in handle_wix_form_test: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail=f"Error processing test data: {str(e)}"
detail=f"Error processing test data: {str(e)}"
) )
@api_router.post("/admin/generate-api-key") @api_router.post("/admin/generate-api-key")
@limiter.limit("5/hour") # Very restrictive for admin operations @limiter.limit("5/hour") # Very restrictive for admin operations
async def generate_new_api_key( async def generate_new_api_key(
request: Request, request: Request, admin_key: str = Depends(validate_api_key)
admin_key: str = Depends(validate_api_key)
): ):
""" """
Admin endpoint to generate new API keys. Admin endpoint to generate new API keys.
Requires admin API key and is heavily rate limited. Requires admin API key and is heavily rate limited.
""" """
if admin_key != "admin-key": if admin_key != "admin-key":
raise HTTPException( raise HTTPException(status_code=403, detail="Admin access required")
status_code=403,
detail="Admin access required"
)
new_key = generate_api_key() new_key = generate_api_key()
_LOGGER.info(f"Generated new API key (requested by: {admin_key})") _LOGGER.info(f"Generated new API key (requested by: {admin_key})")
@@ -352,11 +405,13 @@ async def generate_new_api_key(
"message": "New API key generated", "message": "New API key generated",
"api_key": new_key, "api_key": new_key,
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"note": "Store this key securely - it won't be shown again" "note": "Store this key securely - it won't be shown again",
} }
async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(security_basic)) -> str: async def validate_basic_auth(
credentials: HTTPBasicCredentials = Depends(security_basic),
) -> str:
""" """
Validate basic authentication for AlpineBits protocol. Validate basic authentication for AlpineBits protocol.
Returns username if valid, raises HTTPException if not. Returns username if valid, raises HTTPException if not.
@@ -369,8 +424,11 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
headers={"WWW-Authenticate": "Basic"}, headers={"WWW-Authenticate": "Basic"},
) )
valid = False valid = False
for entry in config['alpine_bits_auth']: for entry in config["alpine_bits_auth"]:
if credentials.username == entry['username'] and credentials.password == entry['password']: if (
credentials.username == entry["username"]
and credentials.password == entry["password"]
):
valid = True valid = True
break break
if not valid: if not valid:
@@ -379,7 +437,9 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
detail="ERROR: Invalid credentials", detail="ERROR: Invalid credentials",
headers={"WWW-Authenticate": "Basic"}, headers={"WWW-Authenticate": "Basic"},
) )
_LOGGER.info(f"AlpineBits authentication successful for user: {credentials.username} (from config)") _LOGGER.info(
f"AlpineBits authentication successful for user: {credentials.username} (from config)"
)
return credentials.username return credentials.username
@@ -390,8 +450,7 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
""" """
if "multipart/form-data" not in content_type: if "multipart/form-data" not in content_type:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="ERROR: Content-Type must be multipart/form-data"
detail="ERROR: Content-Type must be multipart/form-data"
) )
# Extract boundary # Extract boundary
@@ -404,8 +463,7 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
if not boundary: if not boundary:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="ERROR: Missing boundary in multipart/form-data"
detail="ERROR: Missing boundary in multipart/form-data"
) )
# Simple multipart parsing # Simple multipart parsing
@@ -422,37 +480,32 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
content = content.rstrip(b"\r\n") content = content.rstrip(b"\r\n")
# Parse Content-Disposition header # Parse Content-Disposition header
headers = headers_section.decode('utf-8', errors='ignore') headers = headers_section.decode("utf-8", errors="ignore")
name = None name = None
for line in headers.split('\n'): for line in headers.split("\n"):
if 'Content-Disposition' in line and 'name=' in line: if "Content-Disposition" in line and "name=" in line:
# Extract name parameter # Extract name parameter
for param in line.split(';'): for param in line.split(";"):
param = param.strip() param = param.strip()
if param.startswith('name='): if param.startswith("name="):
name = param.split('=', 1)[1].strip('"') name = param.split("=", 1)[1].strip('"')
break break
if name: if name:
# Handle file uploads or text content # Handle file uploads or text content
if content.startswith(b'<'): if content.startswith(b"<"):
# Likely XML content # Likely XML content
data[name] = content.decode('utf-8', errors='ignore') data[name] = content.decode("utf-8", errors="ignore")
else: else:
data[name] = content.decode('utf-8', errors='ignore') data[name] = content.decode("utf-8", errors="ignore")
return data return data
@api_router.post("/alpinebits/server-2024-10") @api_router.post("/alpinebits/server-2024-10")
@limiter.limit("60/minute") @limiter.limit("60/minute")
async def alpinebits_server_handshake( async def alpinebits_server_handshake(
request: Request, request: Request, username: str = Depends(validate_basic_auth)
username: str = Depends(validate_basic_auth)
): ):
""" """
AlpineBits server endpoint implementing the handshake protocol. AlpineBits server endpoint implementing the handshake protocol.
@@ -471,11 +524,15 @@ async def alpinebits_server_handshake(
""" """
try: try:
# Check required headers # Check required headers
client_protocol_version = request.headers.get("X-AlpineBits-ClientProtocolVersion") client_protocol_version = request.headers.get(
"X-AlpineBits-ClientProtocolVersion"
)
if not client_protocol_version: if not client_protocol_version:
# Server concludes client speaks a protocol version preceding 2013-04 # Server concludes client speaks a protocol version preceding 2013-04
client_protocol_version = "pre-2013-04" client_protocol_version = "pre-2013-04"
_LOGGER.info("No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04") _LOGGER.info(
"No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04"
)
else: else:
_LOGGER.info(f"Client protocol version: {client_protocol_version}") _LOGGER.info(f"Client protocol version: {client_protocol_version}")
@@ -503,21 +560,22 @@ async def alpinebits_server_handshake(
# Decompress if needed # Decompress if needed
if is_compressed: if is_compressed:
try: try:
body = gzip.decompress(body) body = gzip.decompress(body)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"ERROR: Failed to decompress gzip content: {str(e)}" detail=f"ERROR: Failed to decompress gzip content: {str(e)}",
) )
# Check content type (after decompression) # Check content type (after decompression)
if "multipart/form-data" not in content_type and "application/x-www-form-urlencoded" not in content_type: if (
"multipart/form-data" not in content_type
and "application/x-www-form-urlencoded" not in content_type
):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded" detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
) )
# Parse multipart data # Parse multipart data
@@ -527,7 +585,7 @@ async def alpinebits_server_handshake(
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}" detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}",
) )
elif "application/x-www-form-urlencoded" in content_type: elif "application/x-www-form-urlencoded" in content_type:
# Parse as urlencoded # Parse as urlencoded
@@ -535,29 +593,25 @@ async def alpinebits_server_handshake(
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded" detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
) )
# Check for required action parameter # Check for required action parameter
action = form_data.get("action") action = form_data.get("action")
if not action: if not action:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="ERROR: Missing required 'action' parameter"
detail="ERROR: Missing required 'action' parameter") )
_LOGGER.info(f"AlpineBits action: {action}") _LOGGER.info(f"AlpineBits action: {action}")
# Get optional request XML # Get optional request XML
request_xml = form_data.get("request") request_xml = form_data.get("request")
server = AlpineBitsServer() server = AlpineBitsServer()
version = Version.V2024_10 version = Version.V2024_10
# Create successful handshake response # Create successful handshake response
response = await server.handle_request(action, request_xml, version) response = await server.handle_request(action, request_xml, version)
@@ -567,42 +621,30 @@ async def alpinebits_server_handshake(
headers = { headers = {
"Content-Type": "application/xml; charset=utf-8", "Content-Type": "application/xml; charset=utf-8",
"X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support "X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support
"X-AlpineBits-Server-Version": "2024-10" "X-AlpineBits-Server-Version": "2024-10",
} }
return Response( return Response(
content=response_xml, content=response_xml, status_code=response.status_code, headers=headers
status_code=response.status_code,
headers=headers
) )
except HTTPException: except HTTPException:
# Re-raise HTTP exceptions (auth errors, etc.) # Re-raise HTTP exceptions (auth errors, etc.)
raise raise
except Exception as e: except Exception as e:
_LOGGER.error(f"Error in AlpineBits handshake: {str(e)}") _LOGGER.error(f"Error in AlpineBits handshake: {str(e)}")
raise HTTPException( raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
status_code=500,
detail=f"Internal server error: {str(e)}"
)
@api_router.get("/admin/stats") @api_router.get("/admin/stats")
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def get_api_stats( async def get_api_stats(request: Request, admin_key: str = Depends(validate_api_key)):
request: Request,
admin_key: str = Depends(validate_api_key)
):
""" """
Admin endpoint to get API usage statistics. Admin endpoint to get API usage statistics.
Requires admin API key. Requires admin API key.
""" """
if admin_key != "admin-key": if admin_key != "admin-key":
raise HTTPException( raise HTTPException(status_code=403, detail="Admin access required")
status_code=403,
detail="Admin access required"
)
# In a real application, you'd fetch this from your database/monitoring system # In a real application, you'd fetch this from your database/monitoring system
return { return {
@@ -611,9 +653,9 @@ async def get_api_stats(
"uptime": "Available in production deployment", "uptime": "Available in production deployment",
"total_requests": "Available with monitoring setup", "total_requests": "Available with monitoring setup",
"active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]), "active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]),
"rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory" "rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory",
}, },
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat(),
} }
@@ -629,6 +671,7 @@ async def landing_page():
try: try:
# Get the path to the HTML file # Get the path to the HTML file
import os import os
html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html") html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html")
with open(html_path, "r", encoding="utf-8") as f: with open(html_path, "r", encoding="utf-8") as f:
@@ -660,4 +703,5 @@ async def landing_page():
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -21,7 +21,7 @@ security = HTTPBearer()
API_KEYS = { API_KEYS = {
# Example API keys - replace with your own secure keys # Example API keys - replace with your own secure keys
"wix-webhook-key": "sk_live_your_secure_api_key_here", "wix-webhook-key": "sk_live_your_secure_api_key_here",
"admin-key": "sk_admin_your_admin_key_here" "admin-key": "sk_admin_your_admin_key_here",
} }
# Load API keys from environment if available # Load API keys from environment if available
@@ -36,7 +36,9 @@ def generate_api_key() -> str:
return f"sk_live_{secrets.token_urlsafe(32)}" return f"sk_live_{secrets.token_urlsafe(32)}"
def validate_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str: def validate_api_key(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> str:
""" """
Validate API key from Authorization header. Validate API key from Authorization header.
Expected format: Authorization: Bearer your_api_key_here Expected format: Authorization: Bearer your_api_key_here
@@ -67,14 +69,12 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool:
try: try:
# Remove 'sha256=' prefix if present # Remove 'sha256=' prefix if present
if signature.startswith('sha256='): if signature.startswith("sha256="):
signature = signature[7:] signature = signature[7:]
# Calculate expected signature # Calculate expected signature
expected_signature = hmac.new( expected_signature = hmac.new(
secret.encode('utf-8'), secret.encode("utf-8"), payload, hashlib.sha256
payload,
hashlib.sha256
).hexdigest() ).hexdigest()
# Compare signatures securely # Compare signatures securely

View File

@@ -1,4 +1,3 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@@ -16,37 +15,45 @@ from annotatedyaml.loader import (
from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid
# --- Voluptuous schemas --- # --- Voluptuous schemas ---
database_schema = Schema({ database_schema = Schema({Required("url"): str}, extra=PREVENT_EXTRA)
Required('url'): str
}, extra=PREVENT_EXTRA)
hotel_auth_schema = Schema(
hotel_auth_schema = Schema({ {
Required("hotel_id"): str, Required("hotel_id"): str,
Required("hotel_name"): str, Required("hotel_name"): str,
Required("username"): str, Required("username"): str,
Required("password"): str Required("password"): str,
}, extra=PREVENT_EXTRA) },
extra=PREVENT_EXTRA,
basic_auth_schema = Schema(
All([hotel_auth_schema], Length(min=1))
) )
config_schema = Schema({ basic_auth_schema = Schema(All([hotel_auth_schema], Length(min=1)))
Required('database'): database_schema,
Required('alpine_bits_auth'): basic_auth_schema
}, extra=PREVENT_EXTRA)
DEFAULT_CONFIG_FILE = 'config.yaml' config_schema = Schema(
{
Required("database"): database_schema,
Required("alpine_bits_auth"): basic_auth_schema,
},
extra=PREVENT_EXTRA,
)
DEFAULT_CONFIG_FILE = "config.yaml"
class Config: class Config:
def __init__(self, config_folder: str | Path = None, config_name: str = DEFAULT_CONFIG_FILE, testing_mode: bool = False): def __init__(
self,
config_folder: str | Path = None,
config_name: str = DEFAULT_CONFIG_FILE,
testing_mode: bool = False,
):
if config_folder is None: if config_folder is None:
config_folder = os.environ.get('ALPINEBITS_CONFIG_DIR') config_folder = os.environ.get("ALPINEBITS_CONFIG_DIR")
if not config_folder: if not config_folder:
config_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../config')) config_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../config")
)
if isinstance(config_folder, str): if isinstance(config_folder, str):
config_folder = Path(config_folder) config_folder = Path(config_folder)
self.config_folder = config_folder self.config_folder = config_folder
@@ -61,8 +68,8 @@ class Config:
validated = config_schema(stuff) validated = config_schema(stuff)
except MultipleInvalid as e: except MultipleInvalid as e:
raise ValueError(f"Config validation error: {e}") raise ValueError(f"Config validation error: {e}")
self.database = validated['database'] self.database = validated["database"]
self.basic_auth = validated['alpine_bits_auth'] self.basic_auth = validated["alpine_bits_auth"]
self.config = validated self.config = validated
def get(self, key, default=None): def get(self, key, default=None):
@@ -70,19 +77,20 @@ class Config:
@property @property
def db_url(self) -> str: def db_url(self) -> str:
return self.database['url'] return self.database["url"]
@property @property
def hotel_id(self) -> str: def hotel_id(self) -> str:
return self.basic_auth['hotel_id'] return self.basic_auth["hotel_id"]
@property @property
def hotel_name(self) -> str: def hotel_name(self) -> str:
return self.basic_auth['hotel_name'] return self.basic_auth["hotel_name"]
@property @property
def users(self) -> List[Dict[str, str]]: def users(self) -> List[Dict[str, str]]:
return self.basic_auth['users'] return self.basic_auth["users"]
# For backward compatibility # For backward compatibility
def load_config(): def load_config():

View File

@@ -5,27 +5,24 @@ import os
Base = declarative_base() Base = declarative_base()
# Async SQLAlchemy setup # Async SQLAlchemy setup
def get_database_url(config=None): def get_database_url(config=None):
db_url = None db_url = None
if config and 'database' in config and 'url' in config['database']: if config and "database" in config and "url" in config["database"]:
db_url = config['database']['url'] db_url = config["database"]["url"]
if not db_url: if not db_url:
db_url = os.environ.get('DATABASE_URL') db_url = os.environ.get("DATABASE_URL")
if not db_url: if not db_url:
db_url = 'sqlite+aiosqlite:///alpinebits.db' db_url = "sqlite+aiosqlite:///alpinebits.db"
return db_url return db_url
DATABASE_URL = get_database_url()
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
async def get_async_session():
async with AsyncSessionLocal() as session:
yield session
class Customer(Base): class Customer(Base):
__tablename__ = 'customers' __tablename__ = "customers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
given_name = Column(String) given_name = Column(String)
contact_id = Column(String, unique=True) contact_id = Column(String, unique=True)
@@ -43,12 +40,13 @@ class Customer(Base):
language = Column(String) language = Column(String)
address_catalog = Column(Boolean) # Added for XML address_catalog = Column(Boolean) # Added for XML
name_title = Column(String) # Added for XML name_title = Column(String) # Added for XML
reservations = relationship('Reservation', back_populates='customer') reservations = relationship("Reservation", back_populates="customer")
class Reservation(Base): class Reservation(Base):
__tablename__ = 'reservations' __tablename__ = "reservations"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
customer_id = Column(Integer, ForeignKey('customers.id')) customer_id = Column(Integer, ForeignKey("customers.id"))
form_id = Column(String, unique=True) form_id = Column(String, unique=True)
start_date = Column(Date) start_date = Column(Date)
end_date = Column(Date) end_date = Column(Date)
@@ -70,16 +68,14 @@ class Reservation(Base):
# Add hotel_code and hotel_name for XML # Add hotel_code and hotel_name for XML
hotel_code = Column(String) hotel_code = Column(String)
hotel_name = Column(String) hotel_name = Column(String)
customer = relationship('Customer', back_populates='reservations') customer = relationship("Customer", back_populates="reservations")
class HashedCustomer(Base): class HashedCustomer(Base):
__tablename__ = 'hashed_customers' __tablename__ = "hashed_customers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
customer_id = Column(Integer) customer_id = Column(Integer)
hashed_email = Column(String) hashed_email = Column(String)
hashed_phone = Column(String) hashed_phone = Column(String)
hashed_name = Column(String) hashed_name = Column(String)
redacted_at = Column(DateTime) redacted_at = Column(DateTime)

View File

@@ -15,11 +15,16 @@ from .simplified_access import (
HotelReservationIdData, HotelReservationIdData,
PhoneTechType, PhoneTechType,
AlpineBitsFactory, AlpineBitsFactory,
OtaMessageType OtaMessageType,
) )
# DB and config # DB and config
from .db import Customer as DBCustomer, Reservation as DBReservation, HashedCustomer, get_async_session from .db import (
Customer as DBCustomer,
Reservation as DBReservation,
HashedCustomer,
get_async_session,
)
from .config_loader import load_config from .config_loader import load_config
import hashlib import hashlib
import json import json
@@ -29,8 +34,8 @@ import asyncio
from alpine_bits_python import db from alpine_bits_python import db
async def main():
async def main():
print("🚀 Starting AlpineBits XML generation script...") print("🚀 Starting AlpineBits XML generation script...")
# Load config (yaml, annotatedyaml) # Load config (yaml, annotatedyaml)
config = load_config() config = load_config()
@@ -40,9 +45,9 @@ async def main():
print(json.dumps(config, indent=2)) print(json.dumps(config, indent=2))
# Ensure SQLite DB file exists if using SQLite # Ensure SQLite DB file exists if using SQLite
db_url = config.get('database', {}).get('url', '') db_url = config.get("database", {}).get("url", "")
if db_url.startswith('sqlite+aiosqlite:///'): if db_url.startswith("sqlite+aiosqlite:///"):
db_path = db_url.replace('sqlite+aiosqlite:///', '') db_path = db_url.replace("sqlite+aiosqlite:///", "")
db_path = os.path.abspath(db_path) db_path = os.path.abspath(db_path)
db_dir = os.path.dirname(db_path) db_dir = os.path.dirname(db_path)
if not os.path.exists(db_dir): if not os.path.exists(db_dir):
@@ -54,15 +59,17 @@ async def main():
# # Ensure DB schema is created (async) # # Ensure DB schema is created (async)
from .db import engine, Base from .db import engine, Base
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
async for db in get_async_session(): async for db in get_async_session():
# Load data from JSON file # Load data from JSON file
json_path = os.path.join(os.path.dirname(__file__), '../../test_data/wix_test_data_20250928_132611.json') json_path = os.path.join(
with open(json_path, 'r', encoding='utf-8') as f: os.path.dirname(__file__),
"../../test_data/wix_test_data_20250928_132611.json",
)
with open(json_path, "r", encoding="utf-8") as f:
wix_data = json.load(f) wix_data = json.load(f)
data = wix_data["data"]["data"] data = wix_data["data"]["data"]
@@ -85,8 +92,16 @@ async def main():
language = data.get("contact", {}).get("locale", "en")[:2] language = data.get("contact", {}).get("locale", "en")[:2]
# Dates # Dates
start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value") start_date = (
end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value") data.get("field:date_picker_a7c8")
or data.get("Anreisedatum")
or data.get("submissions", [{}])[1].get("value")
)
end_date = (
data.get("field:date_picker_7e65")
or data.get("Abreisedatum")
or data.get("submissions", [{}])[2].get("value")
)
# Room/guest info # Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2) num_adults = int(data.get("field:number_7cf5") or 2)
@@ -147,7 +162,7 @@ async def main():
end_date=date.fromisoformat(end_date) if end_date else None, end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults, num_adults=num_adults,
num_children=num_children, num_children=num_children,
children_ages=','.join(str(a) for a in children_ages), children_ages=",".join(str(a) for a in children_ages),
offer=offer, offer=offer,
utm_comment=utm_comment, utm_comment=utm_comment,
created_at=datetime.now(timezone.utc), created_at=datetime.now(timezone.utc),
@@ -177,9 +192,19 @@ async def main():
def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
from .simplified_access import CustomerData, GuestCountsFactory, HotelReservationIdData, AlpineBitsFactory, OtaMessageType, CommentData, CommentsData, CommentListItemData from .simplified_access import (
CustomerData,
GuestCountsFactory,
HotelReservationIdData,
AlpineBitsFactory,
OtaMessageType,
CommentData,
CommentsData,
CommentListItemData,
)
from .generated import alpinebits as ab from .generated import alpinebits as ab
from datetime import datetime, timezone from datetime import datetime, timezone
# Prepare data for XML # Prepare data for XML
phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else [] phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else []
customer_data = CustomerData( customer_data = CustomerData(
@@ -200,11 +225,15 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
language=customer.language, language=customer.language,
) )
alpine_bits_factory = AlpineBitsFactory() alpine_bits_factory = AlpineBitsFactory()
res_guests = alpine_bits_factory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) res_guests = alpine_bits_factory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
# Guest counts # Guest counts
children_ages = [int(a) for a in reservation.children_ages.split(",") if a] children_ages = [int(a) for a in reservation.children_ages.split(",") if a]
guest_counts = GuestCountsFactory.create_retrieve_guest_counts(reservation.num_adults, children_ages) guest_counts = GuestCountsFactory.create_retrieve_guest_counts(
reservation.num_adults, children_ages
)
# UniqueID # UniqueID
unique_id = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId( unique_id = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId(
@@ -214,12 +243,14 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
# TimeSpan # TimeSpan
time_span = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan( time_span = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan(
start=reservation.start_date.isoformat() if reservation.start_date else None, start=reservation.start_date.isoformat() if reservation.start_date else None,
end=reservation.end_date.isoformat() if reservation.end_date else None end=reservation.end_date.isoformat() if reservation.end_date else None,
) )
room_stay = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay( room_stay = (
ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay(
time_span=time_span, time_span=time_span,
guest_counts=guest_counts, guest_counts=guest_counts,
) )
)
room_stays = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays( room_stays = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays(
room_stay=[room_stay], room_stay=[room_stay],
) )
@@ -231,7 +262,9 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
res_id_source=None, res_id_source=None,
res_id_source_context="99tales", res_id_source_context="99tales",
) )
hotel_res_id = alpine_bits_factory.create(hotel_res_id_data, OtaMessageType.RETRIEVE) hotel_res_id = alpine_bits_factory.create(
hotel_res_id_data, OtaMessageType.RETRIEVE
)
hotel_res_ids = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds( hotel_res_ids = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds(
hotel_reservation_id=[hotel_res_id] hotel_reservation_id=[hotel_res_id]
) )
@@ -244,32 +277,38 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
offer_comment = CommentData( offer_comment = CommentData(
name=ab.CommentName2.ADDITIONAL_INFO, name=ab.CommentName2.ADDITIONAL_INFO,
text="Angebot/Offerta", text="Angebot/Offerta",
list_items=[CommentListItemData( list_items=[
CommentListItemData(
value=reservation.offer, value=reservation.offer,
language=customer.language, language=customer.language,
list_item="1", list_item="1",
)], )
],
) )
comment = None comment = None
if reservation.user_comment: if reservation.user_comment:
comment = CommentData( comment = CommentData(
name=ab.CommentName2.CUSTOMER_COMMENT, name=ab.CommentName2.CUSTOMER_COMMENT,
text=reservation.user_comment, text=reservation.user_comment,
list_items=[CommentListItemData( list_items=[
CommentListItemData(
value="Landing page comment", value="Landing page comment",
language=customer.language, language=customer.language,
list_item="1", list_item="1",
)], )
],
) )
comments = [offer_comment, comment] if comment else [offer_comment] comments = [offer_comment, comment] if comment else [offer_comment]
comments_data = CommentsData(comments=comments) comments_data = CommentsData(comments=comments)
comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE) comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE)
res_global_info = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo( res_global_info = (
ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo(
hotel_reservation_ids=hotel_res_ids, hotel_reservation_ids=hotel_res_ids,
basic_property_info=basic_property_info, basic_property_info=basic_property_info,
comments=comments_xml, comments=comments_xml,
) )
)
hotel_reservation = ab.OtaResRetrieveRs.ReservationsList.HotelReservation( hotel_reservation = ab.OtaResRetrieveRs.ReservationsList.HotelReservation(
create_date_time=datetime.now(timezone.utc).isoformat(), create_date_time=datetime.now(timezone.utc).isoformat(),
@@ -293,6 +332,7 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("✅ Pydantic validation successful!") print("✅ Pydantic validation successful!")
from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.formats.dataclass.serializers.config import SerializerConfig
from xsdata_pydantic.bindings import XmlSerializer from xsdata_pydantic.bindings import XmlSerializer
config = SerializerConfig( config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8" pretty_print=True, xml_declaration=True, encoding="UTF-8"
) )
@@ -306,15 +346,18 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("\n📄 Generated XML:") print("\n📄 Generated XML:")
print(xml_string) print(xml_string)
from xsdata_pydantic.bindings import XmlParser from xsdata_pydantic.bindings import XmlParser
parser = XmlParser() parser = XmlParser()
with open("output.xml", "r", encoding="utf-8") as infile: with open("output.xml", "r", encoding="utf-8") as infile:
xml_content = infile.read() xml_content = infile.read()
parsed_result = parser.from_string(xml_content, ab.OtaResRetrieveRs) parsed_result = parser.from_string(xml_content, ab.OtaResRetrieveRs)
print("✅ Round-trip validation successful!") print("✅ Round-trip validation successful!")
print(f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}") print(
f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}"
)
except Exception as e: except Exception as e:
print(f"❌ Validation/Serialization failed: {e}") print(f"❌ Validation/Serialization failed: {e}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -5,18 +5,23 @@ from datetime import datetime
class AlpineBitsHandshakeRequest(BaseModel): class AlpineBitsHandshakeRequest(BaseModel):
"""Model for AlpineBits handshake request data""" """Model for AlpineBits handshake request data"""
action: str = Field(..., description="Action parameter, typically 'OTA_Ping:Handshaking'")
action: str = Field(
..., description="Action parameter, typically 'OTA_Ping:Handshaking'"
)
request_xml: Optional[str] = Field(None, description="XML request document") request_xml: Optional[str] = Field(None, description="XML request document")
class ContactName(BaseModel): class ContactName(BaseModel):
"""Contact name structure""" """Contact name structure"""
first: Optional[str] = None first: Optional[str] = None
last: Optional[str] = None last: Optional[str] = None
class ContactAddress(BaseModel): class ContactAddress(BaseModel):
"""Contact address structure""" """Contact address structure"""
street: Optional[str] = None street: Optional[str] = None
city: Optional[str] = None city: Optional[str] = None
state: Optional[str] = None state: Optional[str] = None
@@ -26,6 +31,7 @@ class ContactAddress(BaseModel):
class Contact(BaseModel): class Contact(BaseModel):
"""Contact information from Wix form""" """Contact information from Wix form"""
name: Optional[ContactName] = None name: Optional[ContactName] = None
email: Optional[str] = None email: Optional[str] = None
locale: Optional[str] = None locale: Optional[str] = None
@@ -43,12 +49,14 @@ class Contact(BaseModel):
class SubmissionPdf(BaseModel): class SubmissionPdf(BaseModel):
"""PDF submission structure""" """PDF submission structure"""
url: Optional[str] = None url: Optional[str] = None
filename: Optional[str] = None filename: Optional[str] = None
class WixFormSubmission(BaseModel): class WixFormSubmission(BaseModel):
"""Model for Wix form submission data""" """Model for Wix form submission data"""
formName: str formName: str
submissions: List[Dict[str, Any]] = Field(default_factory=list) submissions: List[Dict[str, Any]] = Field(default_factory=list)
submissionTime: str submissionTime: str

View File

@@ -16,6 +16,7 @@ BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP
# Redis configuration for distributed rate limiting (optional) # Redis configuration for distributed rate limiting (optional)
REDIS_URL = os.getenv("REDIS_URL", None) REDIS_URL = os.getenv("REDIS_URL", None)
def get_remote_address_with_forwarded(request: Request): def get_remote_address_with_forwarded(request: Request):
""" """
Get client IP address, considering forwarded headers from proxies/load balancers Get client IP address, considering forwarded headers from proxies/load balancers
@@ -39,14 +40,16 @@ if REDIS_URL:
# Use Redis for distributed rate limiting (recommended for production) # Use Redis for distributed rate limiting (recommended for production)
try: try:
import redis import redis
redis_client = redis.from_url(REDIS_URL) redis_client = redis.from_url(REDIS_URL)
limiter = Limiter( limiter = Limiter(
key_func=get_remote_address_with_forwarded, key_func=get_remote_address_with_forwarded, storage_uri=REDIS_URL
storage_uri=REDIS_URL
) )
logger.info("Rate limiting initialized with Redis backend") logger.info("Rate limiting initialized with Redis backend")
except Exception as e: except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}. Using in-memory rate limiting.") logger.warning(
f"Failed to connect to Redis: {e}. Using in-memory rate limiting."
)
limiter = Limiter(key_func=get_remote_address_with_forwarded) limiter = Limiter(key_func=get_remote_address_with_forwarded)
else: else:
# Use in-memory rate limiting (fine for single instance) # Use in-memory rate limiting (fine for single instance)
@@ -77,10 +80,10 @@ def api_key_rate_limit_key(request: Request):
# Rate limiting decorators for different endpoint types # Rate limiting decorators for different endpoint types
webhook_limiter = Limiter( webhook_limiter = Limiter(
key_func=api_key_rate_limit_key, key_func=api_key_rate_limit_key, storage_uri=REDIS_URL if REDIS_URL else None
storage_uri=REDIS_URL if REDIS_URL else None
) )
# Custom rate limit exceeded handler # Custom rate limit exceeded handler
def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded): def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Custom handler for rate limit exceeded""" """Custom handler for rate limit exceeded"""

View File

@@ -1,7 +1,2 @@
def parse_form(form: dict): def parse_form(form: dict):
pass pass

View File

@@ -2,14 +2,21 @@
""" """
Startup script for the Wix Form Handler API Startup script for the Wix Form Handler API
""" """
import os
import uvicorn import uvicorn
from .api import app from .api import app
if __name__ == "__main__": if __name__ == "__main__":
db_path = "alpinebits.db" # Adjust path if needed
if os.path.exists(db_path):
os.remove(db_path)
print(f"Deleted database file: {db_path}")
uvicorn.run( uvicorn.run(
"alpine_bits_python.api:app", "alpine_bits_python.api:app",
host="0.0.0.0", host="0.0.0.0",
port=8080, port=8080,
reload=True, # Enable auto-reload during development reload=True, # Enable auto-reload during development
log_level="info" log_level="info",
) )

View File

@@ -2,6 +2,7 @@
""" """
Configuration and setup script for the Wix Form Handler API Configuration and setup script for the Wix Form Handler API
""" """
import os import os
import sys import sys
import secrets import secrets
@@ -11,6 +12,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from alpine_bits_python.auth import generate_api_key from alpine_bits_python.auth import generate_api_key
def generate_secure_keys(): def generate_secure_keys():
"""Generate secure API keys for the application""" """Generate secure API keys for the application"""
@@ -42,10 +44,12 @@ def generate_secure_keys():
# Optionally write to .env file # Optionally write to .env file
create_env = input("\n❓ Create .env file? (y/n): ").lower().strip() create_env = input("\n❓ Create .env file? (y/n): ").lower().strip()
if create_env == 'y': if create_env == "y":
# Create .env in the project root (two levels up from scripts) # Create .env in the project root (two levels up from scripts)
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env') env_path = os.path.join(
with open(env_path, 'w') as f: os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
)
with open(env_path, "w") as f:
f.write(f"WIX_API_KEY={wix_api_key}\n") f.write(f"WIX_API_KEY={wix_api_key}\n")
f.write(f"ADMIN_API_KEY={admin_api_key}\n") f.write(f"ADMIN_API_KEY={admin_api_key}\n")
f.write(f"WIX_WEBHOOK_SECRET={webhook_secret}\n") f.write(f"WIX_WEBHOOK_SECRET={webhook_secret}\n")
@@ -61,9 +65,9 @@ def generate_secure_keys():
print("4. Optionally configure webhook signature with the secret above") print("4. Optionally configure webhook signature with the secret above")
return { return {
'wix_api_key': wix_api_key, "wix_api_key": wix_api_key,
'admin_api_key': admin_api_key, "admin_api_key": admin_api_key,
'webhook_secret': webhook_secret "webhook_secret": webhook_secret,
} }
@@ -74,10 +78,10 @@ def check_security_setup():
print("=" * 40) print("=" * 40)
# Check environment variables # Check environment variables
wix_key = os.getenv('WIX_API_KEY') wix_key = os.getenv("WIX_API_KEY")
admin_key = os.getenv('ADMIN_API_KEY') admin_key = os.getenv("ADMIN_API_KEY")
webhook_secret = os.getenv('WIX_WEBHOOK_SECRET') webhook_secret = os.getenv("WIX_WEBHOOK_SECRET")
redis_url = os.getenv('REDIS_URL') redis_url = os.getenv("REDIS_URL")
print("Environment Variables:") print("Environment Variables:")
print(f" WIX_API_KEY: {'✅ Set' if wix_key else '❌ Not set'}") print(f" WIX_API_KEY: {'✅ Set' if wix_key else '❌ Not set'}")
@@ -119,7 +123,9 @@ if __name__ == "__main__":
print("🔐 Wix Form Handler API - Security Setup") print("🔐 Wix Form Handler API - Security Setup")
print("=" * 50) print("=" * 50)
choice = input("Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): ").strip() choice = input(
"Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): "
).strip()
if choice == "1": if choice == "1":
generate_secure_keys() generate_secure_keys()

View File

@@ -2,6 +2,7 @@
""" """
Test script for the Secure Wix Form Handler API Test script for the Secure Wix Form Handler API
""" """
import asyncio import asyncio
import aiohttp import aiohttp
import json import json
@@ -30,7 +31,7 @@ SAMPLE_WIX_DATA = {
"submissionsLink": "https://www.wix.app/forms/test-form/submissions", "submissionsLink": "https://www.wix.app/forms/test-form/submissions",
"submissionPdf": { "submissionPdf": {
"url": "https://example.com/submission.pdf", "url": "https://example.com/submission.pdf",
"filename": "submission.pdf" "filename": "submission.pdf",
}, },
"formId": "test-form-789", "formId": "test-form-789",
"field:email_5139": "test@example.com", "field:email_5139": "test@example.com",
@@ -43,10 +44,7 @@ SAMPLE_WIX_DATA = {
"field:alter_kind_4": "12", "field:alter_kind_4": "12",
"field:long_answer_3524": "This is a long answer field with more details about the inquiry.", "field:long_answer_3524": "This is a long answer field with more details about the inquiry.",
"contact": { "contact": {
"name": { "name": {"first": "John", "last": "Doe"},
"first": "John",
"last": "Doe"
},
"email": "test@example.com", "email": "test@example.com",
"locale": "de", "locale": "de",
"company": "Test Company", "company": "Test Company",
@@ -57,13 +55,13 @@ SAMPLE_WIX_DATA = {
"street": "Test Street 123", "street": "Test Street 123",
"city": "Test City", "city": "Test City",
"country": "Germany", "country": "Germany",
"postalCode": "12345" "postalCode": "12345",
}, },
"jobTitle": "Manager", "jobTitle": "Manager",
"phone": "+1234567890", "phone": "+1234567890",
"createdDate": "2024-03-20T10:00:00.000Z", "createdDate": "2024-03-20T10:00:00.000Z",
"updatedDate": "2024-03-20T10:30:00.000Z" "updatedDate": "2024-03-20T10:30:00.000Z",
} },
} }
@@ -72,12 +70,12 @@ async def test_api():
headers_with_auth = { headers_with_auth = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {TEST_API_KEY}" "Authorization": f"Bearer {TEST_API_KEY}",
} }
admin_headers = { admin_headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {ADMIN_API_KEY}" "Authorization": f"Bearer {ADMIN_API_KEY}",
} }
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -105,11 +103,13 @@ async def test_api():
async with session.post( async with session.post(
f"{BASE_URL}/api/webhook/wix-form", f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA, json=SAMPLE_WIX_DATA,
headers={"Content-Type": "application/json"} headers={"Content-Type": "application/json"},
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 401: if response.status == 401:
print(f" ✅ Correctly rejected: {response.status} - {result.get('detail')}") print(
f" ✅ Correctly rejected: {response.status} - {result.get('detail')}"
)
else: else:
print(f" ❌ Unexpected response: {response.status} - {result}") print(f" ❌ Unexpected response: {response.status} - {result}")
except Exception as e: except Exception as e:
@@ -121,11 +121,13 @@ async def test_api():
async with session.post( async with session.post(
f"{BASE_URL}/api/webhook/wix-form", f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA, json=SAMPLE_WIX_DATA,
headers=headers_with_auth headers=headers_with_auth,
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 200: if response.status == 200:
print(f" ✅ Webhook success: {response.status} - {result.get('status')}") print(
f" ✅ Webhook success: {response.status} - {result.get('status')}"
)
else: else:
print(f" ❌ Webhook failed: {response.status} - {result}") print(f" ❌ Webhook failed: {response.status} - {result}")
except Exception as e: except Exception as e:
@@ -137,11 +139,13 @@ async def test_api():
async with session.post( async with session.post(
f"{BASE_URL}/api/webhook/wix-form/test", f"{BASE_URL}/api/webhook/wix-form/test",
json={"test": "data", "timestamp": datetime.now().isoformat()}, json={"test": "data", "timestamp": datetime.now().isoformat()},
headers=headers_with_auth headers=headers_with_auth,
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 200: if response.status == 200:
print(f" ✅ Test endpoint: {response.status} - {result.get('status')}") print(
f" ✅ Test endpoint: {response.status} - {result.get('status')}"
)
else: else:
print(f" ❌ Test endpoint failed: {response.status} - {result}") print(f" ❌ Test endpoint failed: {response.status} - {result}")
except Exception as e: except Exception as e:
@@ -152,13 +156,11 @@ async def test_api():
rate_limit_test_count = 0 rate_limit_test_count = 0
for i in range(5): for i in range(5):
try: try:
async with session.get( async with session.get(f"{BASE_URL}/api/health") as response:
f"{BASE_URL}/api/health"
) as response:
if response.status == 200: if response.status == 200:
rate_limit_test_count += 1 rate_limit_test_count += 1
elif response.status == 429: elif response.status == 429:
print(f" ✅ Rate limit triggered on request {i+1}") print(f" ✅ Rate limit triggered on request {i + 1}")
break break
except Exception as e: except Exception as e:
print(f" ❌ Rate limit test failed: {e}") print(f" ❌ Rate limit test failed: {e}")
@@ -171,14 +173,17 @@ async def test_api():
print("\n7. Testing admin stats endpoint...") print("\n7. Testing admin stats endpoint...")
try: try:
async with session.get( async with session.get(
f"{BASE_URL}/api/admin/stats", f"{BASE_URL}/api/admin/stats", headers=admin_headers
headers=admin_headers
) as response: ) as response:
result = await response.json() result = await response.json()
if response.status == 200: if response.status == 200:
print(f" ✅ Admin stats: {response.status} - {result.get('status')}") print(
f" ✅ Admin stats: {response.status} - {result.get('status')}"
)
elif response.status == 401: elif response.status == 401:
print(f" ⚠️ Admin access denied (API key not configured): {result.get('detail')}") print(
f" ⚠️ Admin access denied (API key not configured): {result.get('detail')}"
)
else: else:
print(f" ❌ Admin endpoint failed: {response.status} - {result}") print(f" ❌ Admin endpoint failed: {response.status} - {result}")
except Exception as e: except Exception as e:
@@ -189,8 +194,14 @@ if __name__ == "__main__":
print("🔒 Testing Secure Wix Form Handler API...") print("🔒 Testing Secure Wix Form Handler API...")
print("=" * 60) print("=" * 60)
print("📍 API URL:", BASE_URL) print("📍 API URL:", BASE_URL)
print("🔑 Using API Key:", TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY) print(
print("🔐 Using Admin Key:", ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY) "🔑 Using API Key:",
TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY,
)
print(
"🔐 Using Admin Key:",
ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY,
)
print("=" * 60) print("=" * 60)
print("Make sure the API is running with: python3 run_api.py") print("Make sure the API is running with: python3 run_api.py")
print("-" * 60) print("-" * 60)

View File

@@ -15,15 +15,26 @@ NotifHotelReservationId = OtaHotelResNotifRq.HotelReservations.HotelReservation.
RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId
# Define type aliases for Comments types # Define type aliases for Comments types
NotifComments = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments NotifComments = (
RetrieveComments = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments
NotifComment = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment )
RetrieveComment = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment RetrieveComments = (
OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments
)
NotifComment = (
OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment
)
RetrieveComment = (
OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment
)
# type aliases for GuestCounts # type aliases for GuestCounts
NotifGuestCounts = OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts NotifGuestCounts = (
RetrieveGuestCounts = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts
)
RetrieveGuestCounts = (
OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts
)
# phonetechtype enum 1,3,5 voice, fax, mobile # phonetechtype enum 1,3,5 voice, fax, mobile
@@ -42,6 +53,7 @@ class OtaMessageType(Enum):
@dataclass @dataclass
class KidsAgeData: class KidsAgeData:
"""Data class to hold information about children's ages.""" """Data class to hold information about children's ages."""
ages: list[int] ages: list[int]
@@ -77,9 +89,10 @@ class CustomerData:
class GuestCountsFactory: class GuestCountsFactory:
@staticmethod @staticmethod
def create_notif_guest_counts(adults: int, kids: Optional[list[int]] = None) -> NotifGuestCounts: def create_notif_guest_counts(
adults: int, kids: Optional[list[int]] = None
) -> NotifGuestCounts:
""" """
Create a GuestCounts object for OtaHotelResNotifRq. Create a GuestCounts object for OtaHotelResNotifRq.
:param adults: Number of adults :param adults: Number of adults
@@ -89,18 +102,23 @@ class GuestCountsFactory:
return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts) return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts)
@staticmethod @staticmethod
def create_retrieve_guest_counts(adults: int, kids: Optional[list[int]] = None) -> RetrieveGuestCounts: def create_retrieve_guest_counts(
adults: int, kids: Optional[list[int]] = None
) -> RetrieveGuestCounts:
""" """
Create a GuestCounts object for OtaResRetrieveRs. Create a GuestCounts object for OtaResRetrieveRs.
:param adults: Number of adults :param adults: Number of adults
:param kids: List of ages for each kid (optional) :param kids: List of ages for each kid (optional)
:return: GuestCounts instance :return: GuestCounts instance
""" """
return GuestCountsFactory._create_guest_counts(adults, kids, RetrieveGuestCounts) return GuestCountsFactory._create_guest_counts(
adults, kids, RetrieveGuestCounts
)
@staticmethod @staticmethod
def _create_guest_counts(adults: int, kids: Optional[list[int]], guest_counts_class: type) -> Any: def _create_guest_counts(
adults: int, kids: Optional[list[int]], guest_counts_class: type
) -> Any:
""" """
Internal method to create a GuestCounts object of the specified type. Internal method to create a GuestCounts object of the specified type.
:param adults: Number of adults :param adults: Number of adults
@@ -359,6 +377,7 @@ class HotelReservationIdFactory:
@dataclass @dataclass
class CommentListItemData: class CommentListItemData:
"""Simple data class to hold comment list item information.""" """Simple data class to hold comment list item information."""
value: str # The text content of the list item value: str # The text content of the list item
list_item: str # Numeric identifier (pattern: [0-9]+) list_item: str # Numeric identifier (pattern: [0-9]+)
language: str # Two-letter language code (pattern: [a-z][a-z]) language: str # Two-letter language code (pattern: [a-z][a-z])
@@ -367,6 +386,7 @@ class CommentListItemData:
@dataclass @dataclass
class CommentData: class CommentData:
"""Simple data class to hold comment information without nested type constraints.""" """Simple data class to hold comment information without nested type constraints."""
name: CommentName2 # Required: "included services", "customer comment", "additional info" name: CommentName2 # Required: "included services", "customer comment", "additional info"
text: Optional[str] = None # Optional text content text: Optional[str] = None # Optional text content
list_items: list[CommentListItemData] = None # Optional list items list_items: list[CommentListItemData] = None # Optional list items
@@ -379,6 +399,7 @@ class CommentData:
@dataclass @dataclass
class CommentsData: class CommentsData:
"""Simple data class to hold multiple comments (1-3 max).""" """Simple data class to hold multiple comments (1-3 max)."""
comments: list[CommentData] = None # 1-3 comments maximum comments: list[CommentData] = None # 1-3 comments maximum
def __post_init__(self): def __post_init__(self):
@@ -400,7 +421,9 @@ class CommentFactory:
return CommentFactory._create_comments(RetrieveComments, RetrieveComment, data) return CommentFactory._create_comments(RetrieveComments, RetrieveComment, data)
@staticmethod @staticmethod
def _create_comments(comments_class: type, comment_class: type, data: CommentsData) -> Any: def _create_comments(
comments_class: type, comment_class: type, data: CommentsData
) -> Any:
"""Internal method to create comments of the specified type.""" """Internal method to create comments of the specified type."""
comments_list = [] comments_list = []
@@ -411,15 +434,13 @@ class CommentFactory:
list_item = comment_class.ListItem( list_item = comment_class.ListItem(
value=item_data.value, value=item_data.value,
list_item=item_data.list_item, list_item=item_data.list_item,
language=item_data.language language=item_data.language,
) )
list_items.append(list_item) list_items.append(list_item)
# Create comment # Create comment
comment = comment_class( comment = comment_class(
name=comment_data.name, name=comment_data.name, text=comment_data.text, list_item=list_items
text=comment_data.text,
list_item=list_items
) )
comments_list.append(comment) comments_list.append(comment)
@@ -446,17 +467,17 @@ class CommentFactory:
list_items_data = [] list_items_data = []
if comment.list_item: if comment.list_item:
for list_item in comment.list_item: for list_item in comment.list_item:
list_items_data.append(CommentListItemData( list_items_data.append(
CommentListItemData(
value=list_item.value, value=list_item.value,
list_item=list_item.list_item, list_item=list_item.list_item,
language=list_item.language language=list_item.language,
)) )
)
# Extract comment data # Extract comment data
comment_data = CommentData( comment_data = CommentData(
name=comment.name, name=comment.name, text=comment.text, list_items=list_items_data
text=comment.text,
list_items=list_items_data
) )
comments_data_list.append(comment_data) comments_data_list.append(comment_data)
@@ -531,7 +552,10 @@ class AlpineBitsFactory:
"""Unified factory class for creating AlpineBits objects with a simple interface.""" """Unified factory class for creating AlpineBits objects with a simple interface."""
@staticmethod @staticmethod
def create(data: Union[CustomerData, HotelReservationIdData, CommentsData], message_type: OtaMessageType) -> Any: def create(
data: Union[CustomerData, HotelReservationIdData, CommentsData],
message_type: OtaMessageType,
) -> Any:
""" """
Create an AlpineBits object based on the data type and message type. Create an AlpineBits object based on the data type and message type.
@@ -552,7 +576,9 @@ class AlpineBitsFactory:
if message_type == OtaMessageType.NOTIF: if message_type == OtaMessageType.NOTIF:
return HotelReservationIdFactory.create_notif_hotel_reservation_id(data) return HotelReservationIdFactory.create_notif_hotel_reservation_id(data)
else: else:
return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(data) return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
data
)
elif isinstance(data, CommentsData): elif isinstance(data, CommentsData):
if message_type == OtaMessageType.NOTIF: if message_type == OtaMessageType.NOTIF:
@@ -564,7 +590,9 @@ class AlpineBitsFactory:
raise ValueError(f"Unsupported data type: {type(data)}") raise ValueError(f"Unsupported data type: {type(data)}")
@staticmethod @staticmethod
def create_res_guests(customer_data: CustomerData, message_type: OtaMessageType) -> Union[NotifResGuests, RetrieveResGuests]: def create_res_guests(
customer_data: CustomerData, message_type: OtaMessageType
) -> Union[NotifResGuests, RetrieveResGuests]:
""" """
Create a complete ResGuests structure with a primary customer. Create a complete ResGuests structure with a primary customer.
@@ -581,7 +609,9 @@ class AlpineBitsFactory:
return ResGuestFactory.create_retrieve_res_guests(customer_data) return ResGuestFactory.create_retrieve_res_guests(customer_data)
@staticmethod @staticmethod
def extract_data(obj: Any) -> Union[CustomerData, HotelReservationIdData, CommentsData]: def extract_data(
obj: Any,
) -> Union[CustomerData, HotelReservationIdData, CommentsData]:
""" """
Extract data from an AlpineBits object back to a simple data class. Extract data from an AlpineBits object back to a simple data class.
@@ -592,28 +622,28 @@ class AlpineBitsFactory:
The appropriate data object The appropriate data object
""" """
# Check if it's a Customer object # Check if it's a Customer object
if hasattr(obj, 'person_name') and hasattr(obj.person_name, 'given_name'): if hasattr(obj, "person_name") and hasattr(obj.person_name, "given_name"):
if isinstance(obj, NotifCustomer): if isinstance(obj, NotifCustomer):
return CustomerFactory.from_notif_customer(obj) return CustomerFactory.from_notif_customer(obj)
elif isinstance(obj, RetrieveCustomer): elif isinstance(obj, RetrieveCustomer):
return CustomerFactory.from_retrieve_customer(obj) return CustomerFactory.from_retrieve_customer(obj)
# Check if it's a HotelReservationId object # Check if it's a HotelReservationId object
elif hasattr(obj, 'res_id_type'): elif hasattr(obj, "res_id_type"):
if isinstance(obj, NotifHotelReservationId): if isinstance(obj, NotifHotelReservationId):
return HotelReservationIdFactory.from_notif_hotel_reservation_id(obj) return HotelReservationIdFactory.from_notif_hotel_reservation_id(obj)
elif isinstance(obj, RetrieveHotelReservationId): elif isinstance(obj, RetrieveHotelReservationId):
return HotelReservationIdFactory.from_retrieve_hotel_reservation_id(obj) return HotelReservationIdFactory.from_retrieve_hotel_reservation_id(obj)
# Check if it's a Comments object # Check if it's a Comments object
elif hasattr(obj, 'comment'): elif hasattr(obj, "comment"):
if isinstance(obj, NotifComments): if isinstance(obj, NotifComments):
return CommentFactory.from_notif_comments(obj) return CommentFactory.from_notif_comments(obj)
elif isinstance(obj, RetrieveComments): elif isinstance(obj, RetrieveComments):
return CommentFactory.from_retrieve_comments(obj) return CommentFactory.from_retrieve_comments(obj)
# Check if it's a ResGuests object # Check if it's a ResGuests object
elif hasattr(obj, 'res_guest'): elif hasattr(obj, "res_guest"):
return ResGuestFactory.extract_primary_customer(obj) return ResGuestFactory.extract_primary_customer(obj)
else: else:
@@ -744,16 +774,17 @@ if __name__ == "__main__":
print("=== HotelReservationId Creation ===") print("=== HotelReservationId Creation ===")
reservation_id_data = HotelReservationIdData( reservation_id_data = HotelReservationIdData(
res_id_type="123", res_id_type="123", res_id_value="RESERVATION-456", res_id_source="HOTEL_SYSTEM"
res_id_value="RESERVATION-456",
res_id_source="HOTEL_SYSTEM"
) )
notif_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.NOTIF) notif_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.NOTIF)
retrieve_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.RETRIEVE) retrieve_res_id = AlpineBitsFactory.create(
reservation_id_data, OtaMessageType.RETRIEVE
)
print("Created reservation IDs using unified factory") print("Created reservation IDs using unified factory")
print("=== Comments Creation ===") print("=== Comments Creation ===")
comments_data = CommentsData(comments=[ comments_data = CommentsData(
comments=[
CommentData( CommentData(
name=CommentName2.CUSTOMER_COMMENT, name=CommentName2.CUSTOMER_COMMENT,
text="This is a customer comment about the reservation", text="This is a customer comment about the reservation",
@@ -761,27 +792,30 @@ if __name__ == "__main__":
CommentListItemData( CommentListItemData(
value="Special dietary requirements: vegetarian", value="Special dietary requirements: vegetarian",
list_item="1", list_item="1",
language="en" language="en",
), ),
CommentListItemData( CommentListItemData(
value="Late arrival expected", value="Late arrival expected", list_item="2", language="en"
list_item="2", ),
language="en" ],
)
]
), ),
CommentData( CommentData(
name=CommentName2.ADDITIONAL_INFO, name=CommentName2.ADDITIONAL_INFO,
text="Additional information about the stay" text="Additional information about the stay",
),
]
) )
])
notif_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.NOTIF) notif_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.NOTIF)
retrieve_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.RETRIEVE) retrieve_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.RETRIEVE)
print("Created comments using unified factory") print("Created comments using unified factory")
print("=== ResGuests Creation ===") print("=== ResGuests Creation ===")
notif_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF) notif_res_guests = AlpineBitsFactory.create_res_guests(
retrieve_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) customer_data, OtaMessageType.NOTIF
)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
print("Created ResGuests using unified factory") print("Created ResGuests using unified factory")
print("=== Data Extraction ===") print("=== Data Extraction ===")

View File

@@ -1,4 +1,5 @@
"""Entry point for util package.""" """Entry point for util package."""
from .handshake_util import main from .handshake_util import main
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -2,26 +2,22 @@ from ..generated.alpinebits import OtaPingRq, OtaPingRs
from xsdata_pydantic.bindings import XmlParser from xsdata_pydantic.bindings import XmlParser
def main(): def main():
# test parsing a ping request sample # test parsing a ping request sample
path = "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml" path = (
"AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml"
)
with open( with open(path, "r", encoding="utf-8") as f:
path, "r", encoding="utf-8") as f:
xml = f.read() xml = f.read()
# Parse the XML into the request object # Parse the XML into the request object
# Test parsing back # Test parsing back
parser = XmlParser() parser = XmlParser()
parsed_result = parser.from_string(xml, OtaPingRs) parsed_result = parser.from_string(xml, OtaPingRs)
print(parsed_result.echo_data) print(parsed_result.echo_data)
@@ -34,19 +30,14 @@ def main():
print(warning.content[0]) print(warning.content[0])
# save json in echo_data to file with indents # save json in echo_data to file with indents
output_path = "echo_data_response.json" output_path = "echo_data_response.json"
with open(output_path, "w", encoding="utf-8") as out_f: with open(output_path, "w", encoding="utf-8") as out_f:
import json import json
json.dump(json.loads(parsed_result.echo_data), out_f, indent=4) json.dump(json.loads(parsed_result.echo_data), out_f, indent=4)
print(f"Saved echo_data json to {output_path}") print(f"Saved echo_data json to {output_path}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -2,11 +2,12 @@
""" """
Convenience launcher for the Wix Form Handler API Convenience launcher for the Wix Form Handler API
""" """
import os import os
import subprocess import subprocess
# Change to src directory # Change to src directory
src_dir = os.path.join(os.path.dirname(__file__), 'src/alpine_bits_python') src_dir = os.path.join(os.path.dirname(__file__), "src/alpine_bits_python")
# Run the API using uv # Run the API using uv
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -10,10 +10,11 @@ from alpine_bits_python.alpinebits_server import (
AlpineBitsActionName, AlpineBitsActionName,
Version, Version,
AlpineBitsResponse, AlpineBitsResponse,
HttpStatusCode HttpStatusCode,
) )
import asyncio import asyncio
class NewImplementedAction(AlpineBitsAction): class NewImplementedAction(AlpineBitsAction):
"""A new action that IS implemented.""" """A new action that IS implemented."""
@@ -21,10 +22,13 @@ class NewImplementedAction(AlpineBitsAction):
self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_INFO_INFO self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_INFO_INFO
self.version = Version.V2024_10 self.version = Version.V2024_10
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""This action is implemented.""" """This action is implemented."""
return AlpineBitsResponse("Implemented!", HttpStatusCode.OK) return AlpineBitsResponse("Implemented!", HttpStatusCode.OK)
class NewUnimplementedAction(AlpineBitsAction): class NewUnimplementedAction(AlpineBitsAction):
"""A new action that is NOT implemented (no handle override).""" """A new action that is NOT implemented (no handle override)."""
@@ -34,6 +38,7 @@ class NewUnimplementedAction(AlpineBitsAction):
# Notice: No handle method override - will use default "not implemented" # Notice: No handle method override - will use default "not implemented"
async def main(): async def main():
print("🔍 Testing Action Discovery Logic") print("🔍 Testing Action Discovery Logic")
print("=" * 50) print("=" * 50)
@@ -57,5 +62,6 @@ async def main():
result = await unimplemented_action.handle("test", "<xml/>", Version.V2024_10) result = await unimplemented_action.handle("test", "<xml/>", Version.V2024_10)
print(f"🔴 NewUnimplementedAction result: {result.xml_content}") print(f"🔴 NewUnimplementedAction result: {result.xml_content}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -4,7 +4,7 @@ import sys
import os import os
# Add the src directory to the path so we can import our modules # Add the src directory to the path so we can import our modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from simplified_access import ( from simplified_access import (
CustomerData, CustomerData,
@@ -20,7 +20,7 @@ from simplified_access import (
NotifResGuests, NotifResGuests,
RetrieveResGuests, RetrieveResGuests,
NotifHotelReservationId, NotifHotelReservationId,
RetrieveHotelReservationId RetrieveHotelReservationId,
) )
@@ -35,7 +35,7 @@ def sample_customer_data():
phone_numbers=[ phone_numbers=[
("+1234567890", PhoneTechType.MOBILE), ("+1234567890", PhoneTechType.MOBILE),
("+0987654321", PhoneTechType.VOICE), ("+0987654321", PhoneTechType.VOICE),
("+1111111111", None) ("+1111111111", None),
], ],
email_address="john.doe@example.com", email_address="john.doe@example.com",
email_newsletter=True, email_newsletter=True,
@@ -46,17 +46,14 @@ def sample_customer_data():
address_catalog=False, address_catalog=False,
gender="Male", gender="Male",
birth_date="1980-01-01", birth_date="1980-01-01",
language="en" language="en",
) )
@pytest.fixture @pytest.fixture
def minimal_customer_data(): def minimal_customer_data():
"""Fixture providing minimal customer data (only required fields).""" """Fixture providing minimal customer data (only required fields)."""
return CustomerData( return CustomerData(given_name="Jane", surname="Smith")
given_name="Jane",
surname="Smith"
)
@pytest.fixture @pytest.fixture
@@ -66,16 +63,14 @@ def sample_hotel_reservation_id_data():
res_id_type="123", res_id_type="123",
res_id_value="RESERVATION-456", res_id_value="RESERVATION-456",
res_id_source="HOTEL_SYSTEM", res_id_source="HOTEL_SYSTEM",
res_id_source_context="BOOKING_ENGINE" res_id_source_context="BOOKING_ENGINE",
) )
@pytest.fixture @pytest.fixture
def minimal_hotel_reservation_id_data(): def minimal_hotel_reservation_id_data():
"""Fixture providing minimal hotel reservation ID data (only required fields).""" """Fixture providing minimal hotel reservation ID data (only required fields)."""
return HotelReservationIdData( return HotelReservationIdData(res_id_type="999")
res_id_type="999"
)
class TestCustomerData: class TestCustomerData:
@@ -152,7 +147,9 @@ class TestCustomerFactory:
def test_create_customer_minimal(self, minimal_customer_data): def test_create_customer_minimal(self, minimal_customer_data):
"""Test creating customers with minimal data.""" """Test creating customers with minimal data."""
notif_customer = CustomerFactory.create_notif_customer(minimal_customer_data) notif_customer = CustomerFactory.create_notif_customer(minimal_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(minimal_customer_data) retrieve_customer = CustomerFactory.create_retrieve_customer(
minimal_customer_data
)
for customer in [notif_customer, retrieve_customer]: for customer in [notif_customer, retrieve_customer]:
assert customer.person_name.given_name == "Jane" assert customer.person_name.given_name == "Jane"
@@ -169,40 +166,64 @@ class TestCustomerFactory:
def test_email_newsletter_options(self): def test_email_newsletter_options(self):
"""Test different email newsletter options.""" """Test different email newsletter options."""
# Newsletter yes # Newsletter yes
data_yes = CustomerData(given_name="Test", surname="User", data_yes = CustomerData(
email_address="test@example.com", email_newsletter=True) given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=True,
)
customer = CustomerFactory.create_notif_customer(data_yes) customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.email.remark == "newsletter:yes" assert customer.email.remark == "newsletter:yes"
# Newsletter no # Newsletter no
data_no = CustomerData(given_name="Test", surname="User", data_no = CustomerData(
email_address="test@example.com", email_newsletter=False) given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=False,
)
customer = CustomerFactory.create_notif_customer(data_no) customer = CustomerFactory.create_notif_customer(data_no)
assert customer.email.remark == "newsletter:no" assert customer.email.remark == "newsletter:no"
# Newsletter not specified # Newsletter not specified
data_none = CustomerData(given_name="Test", surname="User", data_none = CustomerData(
email_address="test@example.com", email_newsletter=None) given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=None,
)
customer = CustomerFactory.create_notif_customer(data_none) customer = CustomerFactory.create_notif_customer(data_none)
assert customer.email.remark is None assert customer.email.remark is None
def test_address_catalog_options(self): def test_address_catalog_options(self):
"""Test different address catalog options.""" """Test different address catalog options."""
# Catalog no # Catalog no
data_no = CustomerData(given_name="Test", surname="User", data_no = CustomerData(
address_line="123 Street", address_catalog=False) given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=False,
)
customer = CustomerFactory.create_notif_customer(data_no) customer = CustomerFactory.create_notif_customer(data_no)
assert customer.address.remark == "catalog:no" assert customer.address.remark == "catalog:no"
# Catalog yes # Catalog yes
data_yes = CustomerData(given_name="Test", surname="User", data_yes = CustomerData(
address_line="123 Street", address_catalog=True) given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=True,
)
customer = CustomerFactory.create_notif_customer(data_yes) customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.address.remark == "catalog:yes" assert customer.address.remark == "catalog:yes"
# Catalog not specified # Catalog not specified
data_none = CustomerData(given_name="Test", surname="User", data_none = CustomerData(
address_line="123 Street", address_catalog=None) given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=None,
)
customer = CustomerFactory.create_notif_customer(data_none) customer = CustomerFactory.create_notif_customer(data_none)
assert customer.address.remark is None assert customer.address.remark is None
@@ -228,8 +249,8 @@ class TestCustomerFactory:
phone_numbers=[ phone_numbers=[
("+1111111111", PhoneTechType.VOICE), ("+1111111111", PhoneTechType.VOICE),
("+2222222222", PhoneTechType.FAX), ("+2222222222", PhoneTechType.FAX),
("+3333333333", PhoneTechType.MOBILE) ("+3333333333", PhoneTechType.MOBILE),
] ],
) )
customer = CustomerFactory.create_notif_customer(data) customer = CustomerFactory.create_notif_customer(data)
@@ -241,14 +262,20 @@ class TestCustomerFactory:
class TestHotelReservationIdData: class TestHotelReservationIdData:
"""Test the HotelReservationIdData dataclass.""" """Test the HotelReservationIdData dataclass."""
def test_hotel_reservation_id_data_creation_full(self, sample_hotel_reservation_id_data): def test_hotel_reservation_id_data_creation_full(
self, sample_hotel_reservation_id_data
):
"""Test creating HotelReservationIdData with all fields.""" """Test creating HotelReservationIdData with all fields."""
assert sample_hotel_reservation_id_data.res_id_type == "123" assert sample_hotel_reservation_id_data.res_id_type == "123"
assert sample_hotel_reservation_id_data.res_id_value == "RESERVATION-456" assert sample_hotel_reservation_id_data.res_id_value == "RESERVATION-456"
assert sample_hotel_reservation_id_data.res_id_source == "HOTEL_SYSTEM" assert sample_hotel_reservation_id_data.res_id_source == "HOTEL_SYSTEM"
assert sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE" assert (
sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE"
)
def test_hotel_reservation_id_data_creation_minimal(self, minimal_hotel_reservation_id_data): def test_hotel_reservation_id_data_creation_minimal(
self, minimal_hotel_reservation_id_data
):
"""Test creating HotelReservationIdData with only required fields.""" """Test creating HotelReservationIdData with only required fields."""
assert minimal_hotel_reservation_id_data.res_id_type == "999" assert minimal_hotel_reservation_id_data.res_id_type == "999"
assert minimal_hotel_reservation_id_data.res_id_value is None assert minimal_hotel_reservation_id_data.res_id_value is None
@@ -259,9 +286,13 @@ class TestHotelReservationIdData:
class TestHotelReservationIdFactory: class TestHotelReservationIdFactory:
"""Test the HotelReservationIdFactory class.""" """Test the HotelReservationIdFactory class."""
def test_create_notif_hotel_reservation_id_full(self, sample_hotel_reservation_id_data): def test_create_notif_hotel_reservation_id_full(
self, sample_hotel_reservation_id_data
):
"""Test creating a NotifHotelReservationId with full data.""" """Test creating a NotifHotelReservationId with full data."""
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
sample_hotel_reservation_id_data
)
assert isinstance(reservation_id, NotifHotelReservationId) assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
@@ -269,9 +300,13 @@ class TestHotelReservationIdFactory:
assert reservation_id.res_id_source == "HOTEL_SYSTEM" assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE" assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
def test_create_retrieve_hotel_reservation_id_full(self, sample_hotel_reservation_id_data): def test_create_retrieve_hotel_reservation_id_full(
self, sample_hotel_reservation_id_data
):
"""Test creating a RetrieveHotelReservationId with full data.""" """Test creating a RetrieveHotelReservationId with full data."""
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
sample_hotel_reservation_id_data
)
assert isinstance(reservation_id, RetrieveHotelReservationId) assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
@@ -279,10 +314,20 @@ class TestHotelReservationIdFactory:
assert reservation_id.res_id_source == "HOTEL_SYSTEM" assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE" assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
def test_create_hotel_reservation_id_minimal(self, minimal_hotel_reservation_id_data): def test_create_hotel_reservation_id_minimal(
self, minimal_hotel_reservation_id_data
):
"""Test creating hotel reservation IDs with minimal data.""" """Test creating hotel reservation IDs with minimal data."""
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(minimal_hotel_reservation_id_data) notif_reservation_id = (
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(minimal_hotel_reservation_id_data) HotelReservationIdFactory.create_notif_hotel_reservation_id(
minimal_hotel_reservation_id_data
)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
minimal_hotel_reservation_id_data
)
)
for reservation_id in [notif_reservation_id, retrieve_reservation_id]: for reservation_id in [notif_reservation_id, retrieve_reservation_id]:
assert reservation_id.res_id_type == "999" assert reservation_id.res_id_type == "999"
@@ -290,17 +335,29 @@ class TestHotelReservationIdFactory:
assert reservation_id.res_id_source is None assert reservation_id.res_id_source is None
assert reservation_id.res_id_source_context is None assert reservation_id.res_id_source_context is None
def test_from_notif_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data): def test_from_notif_hotel_reservation_id_roundtrip(
self, sample_hotel_reservation_id_data
):
"""Test converting NotifHotelReservationId back to HotelReservationIdData.""" """Test converting NotifHotelReservationId back to HotelReservationIdData."""
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(reservation_id) sample_hotel_reservation_id_data
)
converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(
reservation_id
)
assert converted_data == sample_hotel_reservation_id_data assert converted_data == sample_hotel_reservation_id_data
def test_from_retrieve_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data): def test_from_retrieve_hotel_reservation_id_roundtrip(
self, sample_hotel_reservation_id_data
):
"""Test converting RetrieveHotelReservationId back to HotelReservationIdData.""" """Test converting RetrieveHotelReservationId back to HotelReservationIdData."""
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(reservation_id) sample_hotel_reservation_id_data
)
converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
reservation_id
)
assert converted_data == sample_hotel_reservation_id_data assert converted_data == sample_hotel_reservation_id_data
@@ -334,8 +391,12 @@ class TestResGuestFactory:
def test_create_res_guests_minimal(self, minimal_customer_data): def test_create_res_guests_minimal(self, minimal_customer_data):
"""Test creating ResGuests with minimal customer data.""" """Test creating ResGuests with minimal customer data."""
notif_res_guests = ResGuestFactory.create_notif_res_guests(minimal_customer_data) notif_res_guests = ResGuestFactory.create_notif_res_guests(
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(minimal_customer_data) minimal_customer_data
)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
minimal_customer_data
)
for res_guests in [notif_res_guests, retrieve_res_guests]: for res_guests in [notif_res_guests, retrieve_res_guests]:
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
@@ -395,35 +456,47 @@ class TestAlpineBitsFactory:
def test_create_customer_retrieve(self, sample_customer_data): def test_create_customer_retrieve(self, sample_customer_data):
"""Test creating customer using unified factory for RETRIEVE.""" """Test creating customer using unified factory for RETRIEVE."""
customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE) customer = AlpineBitsFactory.create(
sample_customer_data, OtaMessageType.RETRIEVE
)
assert isinstance(customer, RetrieveCustomer) assert isinstance(customer, RetrieveCustomer)
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe" assert customer.person_name.surname == "Doe"
def test_create_hotel_reservation_id_notif(self, sample_hotel_reservation_id_data): def test_create_hotel_reservation_id_notif(self, sample_hotel_reservation_id_data):
"""Test creating hotel reservation ID using unified factory for NOTIF.""" """Test creating hotel reservation ID using unified factory for NOTIF."""
reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF) reservation_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.NOTIF
)
assert isinstance(reservation_id, NotifHotelReservationId) assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_value == "RESERVATION-456"
def test_create_hotel_reservation_id_retrieve(self, sample_hotel_reservation_id_data): def test_create_hotel_reservation_id_retrieve(
self, sample_hotel_reservation_id_data
):
"""Test creating hotel reservation ID using unified factory for RETRIEVE.""" """Test creating hotel reservation ID using unified factory for RETRIEVE."""
reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE) reservation_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
)
assert isinstance(reservation_id, RetrieveHotelReservationId) assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123" assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_value == "RESERVATION-456"
def test_create_res_guests_notif(self, sample_customer_data): def test_create_res_guests_notif(self, sample_customer_data):
"""Test creating ResGuests using unified factory for NOTIF.""" """Test creating ResGuests using unified factory for NOTIF."""
res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF) res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.NOTIF
)
assert isinstance(res_guests, NotifResGuests) assert isinstance(res_guests, NotifResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
def test_create_res_guests_retrieve(self, sample_customer_data): def test_create_res_guests_retrieve(self, sample_customer_data):
"""Test creating ResGuests using unified factory for RETRIEVE.""" """Test creating ResGuests using unified factory for RETRIEVE."""
res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE) res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.RETRIEVE
)
assert isinstance(res_guests, RetrieveResGuests) assert isinstance(res_guests, RetrieveResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John" assert customer.person_name.given_name == "John"
@@ -431,8 +504,12 @@ class TestAlpineBitsFactory:
def test_extract_data_from_customer(self, sample_customer_data): def test_extract_data_from_customer(self, sample_customer_data):
"""Test extracting data from customer objects.""" """Test extracting data from customer objects."""
# Create both types and extract data back # Create both types and extract data back
notif_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF) notif_customer = AlpineBitsFactory.create(
retrieve_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE) sample_customer_data, OtaMessageType.NOTIF
)
retrieve_customer = AlpineBitsFactory.create(
sample_customer_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_customer) notif_extracted = AlpineBitsFactory.extract_data(notif_customer)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_customer) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_customer)
@@ -440,11 +517,17 @@ class TestAlpineBitsFactory:
assert notif_extracted == sample_customer_data assert notif_extracted == sample_customer_data
assert retrieve_extracted == sample_customer_data assert retrieve_extracted == sample_customer_data
def test_extract_data_from_hotel_reservation_id(self, sample_hotel_reservation_id_data): def test_extract_data_from_hotel_reservation_id(
self, sample_hotel_reservation_id_data
):
"""Test extracting data from hotel reservation ID objects.""" """Test extracting data from hotel reservation ID objects."""
# Create both types and extract data back # Create both types and extract data back
notif_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF) notif_res_id = AlpineBitsFactory.create(
retrieve_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE) sample_hotel_reservation_id_data, OtaMessageType.NOTIF
)
retrieve_res_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_res_id) notif_extracted = AlpineBitsFactory.extract_data(notif_res_id)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_id) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_id)
@@ -455,8 +538,12 @@ class TestAlpineBitsFactory:
def test_extract_data_from_res_guests(self, sample_customer_data): def test_extract_data_from_res_guests(self, sample_customer_data):
"""Test extracting data from ResGuests objects.""" """Test extracting data from ResGuests objects."""
# Create both types and extract data back # Create both types and extract data back
notif_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF) notif_res_guests = AlpineBitsFactory.create_res_guests(
retrieve_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE) sample_customer_data, OtaMessageType.NOTIF
)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_res_guests) notif_extracted = AlpineBitsFactory.extract_data(notif_res_guests)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_guests) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_guests)
@@ -481,33 +568,46 @@ class TestAlpineBitsFactory:
given_name="Unified", given_name="Unified",
surname="Factory", surname="Factory",
email_address="unified@factory.com", email_address="unified@factory.com",
phone_numbers=[("+1234567890", PhoneTechType.MOBILE)] phone_numbers=[("+1234567890", PhoneTechType.MOBILE)],
) )
reservation_data = HotelReservationIdData( reservation_data = HotelReservationIdData(
res_id_type="999", res_id_type="999", res_id_value="UNIFIED-TEST"
res_id_value="UNIFIED-TEST"
) )
# Create using unified factory # Create using unified factory
customer_notif = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF) customer_notif = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF)
customer_retrieve = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE) customer_retrieve = AlpineBitsFactory.create(
customer_data, OtaMessageType.RETRIEVE
)
res_id_notif = AlpineBitsFactory.create(reservation_data, OtaMessageType.NOTIF) res_id_notif = AlpineBitsFactory.create(reservation_data, OtaMessageType.NOTIF)
res_id_retrieve = AlpineBitsFactory.create(reservation_data, OtaMessageType.RETRIEVE) res_id_retrieve = AlpineBitsFactory.create(
reservation_data, OtaMessageType.RETRIEVE
)
res_guests_notif = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF) res_guests_notif = AlpineBitsFactory.create_res_guests(
res_guests_retrieve = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) customer_data, OtaMessageType.NOTIF
)
res_guests_retrieve = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
# Extract everything back # Extract everything back
extracted_customer_from_notif = AlpineBitsFactory.extract_data(customer_notif) extracted_customer_from_notif = AlpineBitsFactory.extract_data(customer_notif)
extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(customer_retrieve) extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(
customer_retrieve
)
extracted_res_id_from_notif = AlpineBitsFactory.extract_data(res_id_notif) extracted_res_id_from_notif = AlpineBitsFactory.extract_data(res_id_notif)
extracted_res_id_from_retrieve = AlpineBitsFactory.extract_data(res_id_retrieve) extracted_res_id_from_retrieve = AlpineBitsFactory.extract_data(res_id_retrieve)
extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(res_guests_notif) extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(
extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(res_guests_retrieve) res_guests_notif
)
extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(
res_guests_retrieve
)
# Verify everything matches # Verify everything matches
assert extracted_customer_from_notif == customer_data assert extracted_customer_from_notif == customer_data
@@ -525,31 +625,66 @@ class TestIntegration:
"""Test that both factories can work with the same customer data.""" """Test that both factories can work with the same customer data."""
# Create using CustomerFactory # Create using CustomerFactory
notif_customer = CustomerFactory.create_notif_customer(sample_customer_data) notif_customer = CustomerFactory.create_notif_customer(sample_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(sample_customer_data) retrieve_customer = CustomerFactory.create_retrieve_customer(
sample_customer_data
)
# Create using ResGuestFactory and extract customers # Create using ResGuestFactory and extract customers
notif_res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) notif_res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
sample_customer_data
)
notif_from_res_guests = notif_res_guests.res_guest.profiles.profile_info.profile.customer notif_from_res_guests = (
retrieve_from_res_guests = retrieve_res_guests.res_guest.profiles.profile_info.profile.customer notif_res_guests.res_guest.profiles.profile_info.profile.customer
)
retrieve_from_res_guests = (
retrieve_res_guests.res_guest.profiles.profile_info.profile.customer
)
# Compare customer names (structure should be identical) # Compare customer names (structure should be identical)
assert notif_customer.person_name.given_name == notif_from_res_guests.person_name.given_name assert (
assert notif_customer.person_name.surname == notif_from_res_guests.person_name.surname notif_customer.person_name.given_name
assert retrieve_customer.person_name.given_name == retrieve_from_res_guests.person_name.given_name == notif_from_res_guests.person_name.given_name
assert retrieve_customer.person_name.surname == retrieve_from_res_guests.person_name.surname )
assert (
notif_customer.person_name.surname
== notif_from_res_guests.person_name.surname
)
assert (
retrieve_customer.person_name.given_name
== retrieve_from_res_guests.person_name.given_name
)
assert (
retrieve_customer.person_name.surname
== retrieve_from_res_guests.person_name.surname
)
def test_hotel_reservation_id_factories_produce_same_data(self, sample_hotel_reservation_id_data): def test_hotel_reservation_id_factories_produce_same_data(
self, sample_hotel_reservation_id_data
):
"""Test that both HotelReservationId factories produce equivalent results.""" """Test that both HotelReservationId factories produce equivalent results."""
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) notif_reservation_id = (
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) HotelReservationIdFactory.create_notif_hotel_reservation_id(
sample_hotel_reservation_id_data
)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
sample_hotel_reservation_id_data
)
)
# Both should have the same field values # Both should have the same field values
assert notif_reservation_id.res_id_type == retrieve_reservation_id.res_id_type assert notif_reservation_id.res_id_type == retrieve_reservation_id.res_id_type
assert notif_reservation_id.res_id_value == retrieve_reservation_id.res_id_value assert notif_reservation_id.res_id_value == retrieve_reservation_id.res_id_value
assert notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source assert (
assert notif_reservation_id.res_id_source_context == retrieve_reservation_id.res_id_source_context notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source
)
assert (
notif_reservation_id.res_id_source_context
== retrieve_reservation_id.res_id_source_context
)
def test_complex_customer_workflow(self): def test_complex_customer_workflow(self):
"""Test a complex workflow with multiple operations.""" """Test a complex workflow with multiple operations."""
@@ -559,7 +694,7 @@ class TestIntegration:
surname="Johnson", surname="Johnson",
phone_numbers=[ phone_numbers=[
("+1555123456", PhoneTechType.MOBILE), ("+1555123456", PhoneTechType.MOBILE),
("+1555654321", PhoneTechType.VOICE) ("+1555654321", PhoneTechType.VOICE),
], ],
email_address="alice.johnson@company.com", email_address="alice.johnson@company.com",
email_newsletter=False, email_newsletter=False,
@@ -569,7 +704,7 @@ class TestIntegration:
country_code="CA", country_code="CA",
address_catalog=True, address_catalog=True,
gender="Female", gender="Female",
language="fr" language="fr",
) )
# Create ResGuests for both types # Create ResGuests for both types
@@ -578,7 +713,9 @@ class TestIntegration:
# Extract data back from both # Extract data back from both
notif_extracted = ResGuestFactory.extract_primary_customer(notif_res_guests) notif_extracted = ResGuestFactory.extract_primary_customer(notif_res_guests)
retrieve_extracted = ResGuestFactory.extract_primary_customer(retrieve_res_guests) retrieve_extracted = ResGuestFactory.extract_primary_customer(
retrieve_res_guests
)
# All should be equal # All should be equal
assert original_data == notif_extracted assert original_data == notif_extracted
@@ -592,16 +729,28 @@ class TestIntegration:
res_id_type="456", res_id_type="456",
res_id_value="COMPLEX-RESERVATION-789", res_id_value="COMPLEX-RESERVATION-789",
res_id_source="INTEGRATION_SYSTEM", res_id_source="INTEGRATION_SYSTEM",
res_id_source_context="API_CALL" res_id_source_context="API_CALL",
) )
# Create HotelReservationId for both types # Create HotelReservationId for both types
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data) notif_reservation_id = (
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(original_data) HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
original_data
)
)
# Extract data back from both # Extract data back from both
notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(notif_reservation_id) notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(
retrieve_extracted = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(retrieve_reservation_id) notif_reservation_id
)
retrieve_extracted = (
HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
retrieve_reservation_id
)
)
# All should be equal # All should be equal
assert original_data == notif_extracted assert original_data == notif_extracted

View File

@@ -6,6 +6,7 @@ Test the handshake functionality with the real AlpineBits sample file.
import asyncio import asyncio
from alpine_bits_python.alpinebits_server import AlpineBitsServer from alpine_bits_python.alpinebits_server import AlpineBitsServer
async def main(): async def main():
print("🔄 Testing AlpineBits Handshake with Sample File") print("🔄 Testing AlpineBits Handshake with Sample File")
print("=" * 60) print("=" * 60)
@@ -14,16 +15,22 @@ async def main():
server = AlpineBitsServer() server = AlpineBitsServer()
# Read the sample handshake request # Read the sample handshake request
with open("AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml", "r") as f: with open(
"AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml",
"r",
) as f:
ping_request_xml = f.read() ping_request_xml = f.read()
print("📤 Sending handshake request...") print("📤 Sending handshake request...")
# Handle the ping request # Handle the ping request
response = await server.handle_request("OTA_Ping:Handshaking", ping_request_xml, "2024-10") response = await server.handle_request(
"OTA_Ping:Handshaking", ping_request_xml, "2024-10"
)
print(f"\n📥 Response Status: {response.status_code}") print(f"\n📥 Response Status: {response.status_code}")
print(f"📄 Response XML:\n{response.xml_content}") print(f"📄 Response XML:\n{response.xml_content}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())