Got db saving working

This commit is contained in:
Jonas Linter
2025-09-29 13:56:34 +02:00
parent 384fb2b558
commit 06739ebea9
21 changed files with 1188 additions and 830 deletions

View File

@@ -1,6 +1,7 @@
"""Entry point for alpine_bits_python package."""
from .main import main
if __name__ == "__main__":
print("running test main")
main()
main()

View File

@@ -23,49 +23,65 @@ from xsdata_pydantic.bindings import XmlParser
class HttpStatusCode(IntEnum):
"""Allowed HTTP status codes for AlpineBits responses."""
OK = 200
BAD_REQUEST = 400
UNAUTHORIZED = 401
INTERNAL_SERVER_ERROR = 500
class AlpineBitsActionName(Enum):
"""Enum for AlpineBits action names with capability and request name mappings."""
# Format: (capability_name, actual_request_name)
OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking")
OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests")
OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif")
OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = ("action_OTA_HotelResNotif_GuestRequests",
"OTA_HotelResNotif:GuestRequests")
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = ("action_OTA_HotelDescriptiveContentNotif_Inventory",
"OTA_HotelDescriptiveContentNotif:Inventory")
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = ("action_OTA_HotelDescriptiveContentNotif_Info",
"OTA_HotelDescriptiveContentNotif:Info")
OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = ("action_OTA_HotelDescriptiveInfo_Inventory",
"OTA_HotelDescriptiveInfo:Inventory")
OTA_HOTEL_DESCRIPTIVE_INFO_INFO = ("action_OTA_HotelDescriptiveInfo_Info",
"OTA_HotelDescriptiveInfo:Info")
OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = ("action_OTA_HotelRatePlanNotif_RatePlans",
"OTA_HotelRatePlanNotif:RatePlans")
OTA_HOTEL_RATE_PLAN_BASE_RATES = ("action_OTA_HotelRatePlan_BaseRates",
"OTA_HotelRatePlan:BaseRates")
OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = (
"action_OTA_HotelResNotif_GuestRequests",
"OTA_HotelResNotif:GuestRequests",
)
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = (
"action_OTA_HotelDescriptiveContentNotif_Inventory",
"OTA_HotelDescriptiveContentNotif:Inventory",
)
OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = (
"action_OTA_HotelDescriptiveContentNotif_Info",
"OTA_HotelDescriptiveContentNotif:Info",
)
OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = (
"action_OTA_HotelDescriptiveInfo_Inventory",
"OTA_HotelDescriptiveInfo:Inventory",
)
OTA_HOTEL_DESCRIPTIVE_INFO_INFO = (
"action_OTA_HotelDescriptiveInfo_Info",
"OTA_HotelDescriptiveInfo:Info",
)
OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = (
"action_OTA_HotelRatePlanNotif_RatePlans",
"OTA_HotelRatePlanNotif:RatePlans",
)
OTA_HOTEL_RATE_PLAN_BASE_RATES = (
"action_OTA_HotelRatePlan_BaseRates",
"OTA_HotelRatePlan:BaseRates",
)
def __init__(self, capability_name: str, request_name: str):
self.capability_name = capability_name
self.request_name = request_name
@classmethod
def get_by_capability_name(cls, capability_name: str) -> Optional['AlpineBitsActionName']:
def get_by_capability_name(
cls, capability_name: str
) -> Optional["AlpineBitsActionName"]:
"""Get action enum by capability name."""
for action in cls:
if action.capability_name == capability_name:
return action
return None
@classmethod
def get_by_request_name(cls, request_name: str) -> Optional['AlpineBitsActionName']:
def get_by_request_name(cls, request_name: str) -> Optional["AlpineBitsActionName"]:
"""Get action enum by request name."""
for action in cls:
if action.request_name == request_name:
@@ -75,22 +91,25 @@ class AlpineBitsActionName(Enum):
class Version(str, Enum):
"""Enum for AlpineBits versions."""
V2024_10 = "2024-10"
V2022_10 = "2022-10"
# Add other versions as needed
@dataclass
class AlpineBitsResponse:
"""Response data structure for AlpineBits actions."""
xml_content: str
status_code: HttpStatusCode = HttpStatusCode.OK
def __post_init__(self):
"""Validate that status code is one of the allowed values."""
if self.status_code not in [200, 400, 401, 500]:
raise ValueError(f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500")
raise ValueError(
f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500"
)
# Abstract base class for AlpineBits Action
@@ -98,20 +117,24 @@ class AlpineBitsAction(ABC):
"""Abstract base class for handling AlpineBits actions."""
name: AlpineBitsActionName
version: Version | list[Version] # list of versions in case action supports multiple versions
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
version: (
Version | list[Version]
) # list of versions in case action supports multiple versions
async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""
Handle the incoming request XML and return response XML.
Default implementation returns "not implemented" error.
Override this method in subclasses to provide actual functionality.
Args:
action: The action to perform (e.g., "OTA_PingRQ")
request_xml: The XML request body as string
version: The AlpineBits version
Returns:
AlpineBitsResponse with error or actual response
"""
@@ -121,7 +144,7 @@ class AlpineBitsAction(ABC):
async def check_version_supported(self, version: Version) -> bool:
"""
Check if the action supports the given version.
Args:
version: The AlpineBits version to check
Returns:
@@ -130,103 +153,93 @@ class AlpineBitsAction(ABC):
if isinstance(self.version, list):
return version in self.version
return version == self.version
class ServerCapabilities:
"""
Automatically discovers AlpineBitsAction implementations and generates capabilities.
"""
def __init__(self):
self.action_registry: Dict[str, Type[AlpineBitsAction]] = {}
self._discover_actions()
self.capability_dict = None
def _discover_actions(self):
"""Discover all AlpineBitsAction implementations in the current module."""
current_module = inspect.getmodule(self)
for name, obj in inspect.getmembers(current_module):
if (inspect.isclass(obj) and
issubclass(obj, AlpineBitsAction) and
obj != AlpineBitsAction):
if (
inspect.isclass(obj)
and issubclass(obj, AlpineBitsAction)
and obj != AlpineBitsAction
):
# Check if this action is actually implemented (not just returning default)
if self._is_action_implemented(obj):
action_instance = obj()
if hasattr(action_instance, 'name'):
if hasattr(action_instance, "name"):
# Use capability name for the registry key
self.action_registry[action_instance.name.capability_name] = obj
def _is_action_implemented(self, action_class: Type[AlpineBitsAction]) -> bool:
"""
Check if an action is actually implemented or just uses the default behavior.
This is a simple check - in practice, you might want more sophisticated detection.
"""
# Check if the class has overridden the handle method
if 'handle' in action_class.__dict__:
if "handle" in action_class.__dict__:
return True
return False
def create_capabilities_dict(self) -> None:
"""
Generate the capabilities dictionary based on discovered actions.
"""
versions_dict = {}
for action_name, action_class in self.action_registry.items():
action_instance = action_class()
# Get supported versions for this action
if isinstance(action_instance.version, list):
supported_versions = action_instance.version
else:
supported_versions = [action_instance.version]
# Add action to each supported version
for version in supported_versions:
version_str = version.value
if version_str not in versions_dict:
versions_dict[version_str] = {
"version": version_str,
"actions": []
}
versions_dict[version_str] = {"version": version_str, "actions": []}
action_dict = {"action": action_name}
# Add supports field if the action has custom supports
if hasattr(action_instance, 'supports') and action_instance.supports:
if hasattr(action_instance, "supports") and action_instance.supports:
action_dict["supports"] = action_instance.supports
versions_dict[version_str]["actions"].append(action_dict)
self.capability_dict = {"versions": list(versions_dict.values())}
return None
def get_capabilities_dict(self) -> Dict:
"""
Get capabilities as a dictionary. Generates if not already created.
"""
if self.capability_dict is None:
self.create_capabilities_dict()
return self.capability_dict
def get_capabilities_json(self) -> str:
"""Get capabilities as formatted JSON string."""
return json.dumps(self.get_capabilities_dict(), indent=2)
def get_supported_actions(self) -> List[str]:
"""Get list of all supported action names."""
return list(self.action_registry.keys())
@@ -234,22 +247,35 @@ class ServerCapabilities:
# Sample Action Implementations for demonstration
class PingAction(AlpineBitsAction):
"""Implementation for OTA_Ping action (handshaking)."""
def __init__(self):
self.name = AlpineBitsActionName.OTA_PING
self.version = [Version.V2024_10, Version.V2022_10] # Supports multiple versions
async def handle(self, action: str, request_xml: str, version: Version, server_capabilities: None | ServerCapabilities = None) -> AlpineBitsResponse:
self.version = [
Version.V2024_10,
Version.V2022_10,
] # Supports multiple versions
async def handle(
self,
action: str,
request_xml: str,
version: Version,
server_capabilities: None | ServerCapabilities = None,
) -> AlpineBitsResponse:
"""Handle ping requests."""
if request_xml is None:
return AlpineBitsResponse(f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST)
return AlpineBitsResponse(
f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST
)
if server_capabilities is None:
return AlpineBitsResponse("Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR)
return AlpineBitsResponse(
"Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR
)
# Parse the incoming request XML and extract EchoData
parser = XmlParser()
@@ -259,54 +285,66 @@ class PingAction(AlpineBitsAction):
echo_data = json.loads(parsed_request.echo_data)
except Exception as e:
return AlpineBitsResponse(f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST)
return AlpineBitsResponse(
f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST
)
# compare echo data with capabilities, create a dictionary containing the matching capabilities
capabilities_dict = server_capabilities.get_capabilities_dict()
matching_capabilities = {"versions": []}
# Iterate through client's requested versions
for client_version in echo_data.get("versions", []):
client_version_str = client_version.get("version", "")
# Find matching server version
for server_version in capabilities_dict["versions"]:
if server_version["version"] == client_version_str:
# Found a matching version, now find common actions
matching_version = {
"version": client_version_str,
"actions": []
}
matching_version = {"version": client_version_str, "actions": []}
# Get client's requested actions for this version
client_actions = {action.get("action", ""): action for action in client_version.get("actions", [])}
server_actions = {action.get("action", ""): action for action in server_version.get("actions", [])}
client_actions = {
action.get("action", ""): action
for action in client_version.get("actions", [])
}
server_actions = {
action.get("action", ""): action
for action in server_version.get("actions", [])
}
# Find common actions
for action_name in client_actions:
if action_name in server_actions:
# Use server's action definition (includes our supports)
matching_version["actions"].append(server_actions[action_name])
matching_version["actions"].append(
server_actions[action_name]
)
# Only add version if there are common actions
if matching_version["actions"]:
matching_capabilities["versions"].append(matching_version)
break
# Debug print to see what we matched
# Create successful ping response with matched capabilities
capabilities_json = json.dumps(matching_capabilities, indent=2)
warning = OtaPingRs.Warnings.Warning(status=WarningStatus.ALPINEBITS_HANDSHAKE.value, type_value="11", content=[capabilities_json])
warning = OtaPingRs.Warnings.Warning(
status=WarningStatus.ALPINEBITS_HANDSHAKE.value,
type_value="11",
content=[capabilities_json],
)
warning_response = OtaPingRs.Warnings(warning=[warning])
response_ota_ping = OtaPingRs(version= "7.000", warnings=warning_response, echo_data=capabilities_json, success="")
response_ota_ping = OtaPingRs(
version="7.000",
warnings=warning_response,
echo_data=capabilities_json,
success="",
)
config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8"
@@ -314,34 +352,35 @@ class PingAction(AlpineBitsAction):
serializer = XmlSerializer(config=config)
response_xml = serializer.render(response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"})
response_xml = serializer.render(
response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}
)
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class ReadAction(AlpineBitsAction):
"""Implementation for OTA_Read action."""
def __init__(self):
self.name = AlpineBitsActionName.OTA_READ
self.version = [Version.V2024_10, Version.V2022_10]
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""Handle read requests."""
response_xml = f'''<?xml version="1.0" encoding="UTF-8"?>
response_xml = f"""<?xml version="1.0" encoding="UTF-8"?>
<OTA_ReadRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000">
<Success/>
<Data>Read operation successful for {version.value}</Data>
</OTA_ReadRS>'''
</OTA_ReadRS>"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class HotelAvailNotifAction(AlpineBitsAction):
"""Implementation for Hotel Availability Notification action with supports."""
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_AVAIL_NOTIF
self.version = Version.V2022_10
@@ -349,68 +388,68 @@ class HotelAvailNotifAction(AlpineBitsAction):
"OTA_HotelAvailNotif_accept_rooms",
"OTA_HotelAvailNotif_accept_categories",
"OTA_HotelAvailNotif_accept_deltas",
"OTA_HotelAvailNotif_accept_BookingThreshold"
"OTA_HotelAvailNotif_accept_BookingThreshold",
]
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""Handle hotel availability notifications."""
response_xml = '''<?xml version="1.0" encoding="UTF-8"?>
response_xml = """<?xml version="1.0" encoding="UTF-8"?>
<OTA_HotelAvailNotifRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000">
<Success/>
</OTA_HotelAvailNotifRS>'''
</OTA_HotelAvailNotifRS>"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class GuestRequestsAction(AlpineBitsAction):
"""Unimplemented action - will not appear in capabilities."""
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS
self.version = Version.V2024_10
# Note: This class doesn't override the handle method, so it won't be discovered
class AlpineBitsServer:
"""
Asynchronous AlpineBits server for handling hotel data exchange requests.
This server handles various OTA actions and implements the AlpineBits protocol
for hotel data exchange. It maintains a registry of supported actions and
their capabilities, and can respond to handshake requests with its capabilities.
"""
def __init__(self):
self.capabilities = ServerCapabilities()
self._action_instances = {}
self._initialize_action_instances()
def _initialize_action_instances(self):
"""Initialize instances of all discovered action classes."""
for capability_name, action_class in self.capabilities.action_registry.items():
self._action_instances[capability_name] = action_class()
def get_capabilities(self) -> Dict:
"""Get server capabilities."""
return self.capabilities.get_capabilities_dict()
def get_capabilities_json(self) -> str:
"""Get server capabilities as JSON."""
return self.capabilities.get_capabilities_json()
async def handle_request(self, request_action_name: str, request_xml: str, version: str = "2024-10") -> AlpineBitsResponse:
async def handle_request(
self, request_action_name: str, request_xml: str, version: str = "2024-10"
) -> AlpineBitsResponse:
"""
Handle an incoming AlpineBits request by routing to appropriate action handler.
Args:
request_action_name: The action name from the request (e.g., "OTA_Read:GuestRequests")
request_xml: The XML request body
version: The AlpineBits version (defaults to "2024-10")
Returns:
AlpineBitsResponse with the result
"""
@@ -419,52 +458,56 @@ class AlpineBitsServer:
version_enum = Version(version)
except ValueError:
return AlpineBitsResponse(
f"Error: Unsupported version {version}",
HttpStatusCode.BAD_REQUEST
f"Error: Unsupported version {version}", HttpStatusCode.BAD_REQUEST
)
# Find the action by request name
action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
if not action_enum:
return AlpineBitsResponse(
f"Error: Unknown action {request_action_name}",
HttpStatusCode.BAD_REQUEST
HttpStatusCode.BAD_REQUEST,
)
# Check if we have an implementation for this action
capability_name = action_enum.capability_name
if capability_name not in self._action_instances:
return AlpineBitsResponse(
f"Error: Action {request_action_name} is not implemented",
HttpStatusCode.BAD_REQUEST
HttpStatusCode.BAD_REQUEST,
)
action_instance: AlpineBitsAction = self._action_instances[capability_name]
# Check if the action supports the requested version
if not await action_instance.check_version_supported(version_enum):
return AlpineBitsResponse(
f"Error: Action {request_action_name} does not support version {version}",
HttpStatusCode.BAD_REQUEST
HttpStatusCode.BAD_REQUEST,
)
# Handle the request
try:
# Special case for ping action - pass server capabilities
if capability_name == "action_OTA_Ping":
return await action_instance.handle(request_action_name, request_xml, version_enum, self.capabilities)
return await action_instance.handle(
request_action_name, request_xml, version_enum, self.capabilities
)
else:
return await action_instance.handle(request_action_name, request_xml, version_enum)
return await action_instance.handle(
request_action_name, request_xml, version_enum
)
except Exception as e:
print(f"Error handling request {request_action_name}: {str(e)}")
# print stack trace for debugging
import traceback
traceback.print_exc()
return AlpineBitsResponse(
f"Error: Internal server error while processing {request_action_name}: {str(e)}",
HttpStatusCode.INTERNAL_SERVER_ERROR
HttpStatusCode.INTERNAL_SERVER_ERROR,
)
def get_supported_request_names(self) -> List[str]:
"""Get all supported request names (not capability names)."""
request_names = []
@@ -473,26 +516,28 @@ class AlpineBitsServer:
if action_enum:
request_names.append(action_enum.request_name)
return sorted(request_names)
def is_action_supported(self, request_action_name: str, version: str = None) -> bool:
def is_action_supported(
self, request_action_name: str, version: str = None
) -> bool:
"""
Check if a request action is supported.
Args:
request_action_name: The request action name (e.g., "OTA_Read:GuestRequests")
version: Optional version to check
Returns:
True if supported, False otherwise
"""
action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
if not action_enum:
return False
capability_name = action_enum.capability_name
if capability_name not in self._action_instances:
return False
if version:
try:
version_enum = Version(version)
@@ -504,7 +549,7 @@ class AlpineBitsServer:
return action_instance.version == version_enum
except ValueError:
return False
return True
@@ -512,10 +557,10 @@ async def main():
"""Demonstrate the automatic capabilities discovery and request handling."""
print("🚀 AlpineBits Server Capabilities Discovery & Request Handling Demo")
print("=" * 70)
# Create server instance
server = AlpineBitsServer()
print("\n📋 Discovered Action Classes:")
print("-" * 30)
for capability_name, action_class in server.capabilities.action_registry.items():
@@ -523,24 +568,26 @@ async def main():
request_name = action_enum.request_name if action_enum else "unknown"
print(f"{capability_name} -> {action_class.__name__}")
print(f" Request name: {request_name}")
print(f"\n📊 Total Implemented Actions: {len(server.capabilities.get_supported_actions())}")
print(
f"\n📊 Total Implemented Actions: {len(server.capabilities.get_supported_actions())}"
)
print("\n🔍 Generated Capabilities JSON:")
print("-" * 30)
capabilities_json = server.get_capabilities_json()
print(capabilities_json)
print("\n🎯 Supported Request Names:")
print("-" * 30)
for request_name in server.get_supported_request_names():
print(f"{request_name}")
print("\n🧪 Testing Request Handling:")
print("-" * 30)
test_xml = "<test>sample request</test>"
# Test different request formats
test_cases = [
("OTA_Ping:Handshaking", "2024-10"),
@@ -548,16 +595,16 @@ async def main():
("OTA_Read:GuestRequests", "2022-10"),
("OTA_HotelAvailNotif", "2024-10"),
("UnknownAction", "2024-10"),
("OTA_Ping:Handshaking", "unsupported-version")
("OTA_Ping:Handshaking", "unsupported-version"),
]
for request_name, version in test_cases:
print(f"\n<EFBFBD> Testing: {request_name} (v{version})")
# Check if supported first
is_supported = server.is_action_supported(request_name, version)
print(f" Supported: {is_supported}")
# Handle the request
response = await server.handle_request(request_name, test_xml, version)
print(f" Status: {response.status_code}")
@@ -565,9 +612,9 @@ async def main():
print(f" Response: {response.xml_content[:100]}...")
else:
print(f" Response: {response.xml_content}")
print("\n✅ Demo completed successfully!")
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -1,18 +1,29 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, APIRouter, Form, File, UploadFile
from fastapi import (
FastAPI,
HTTPException,
BackgroundTasks,
Request,
Depends,
APIRouter,
Form,
File,
UploadFile,
)
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic
from .config_loader import load_config
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from .models import WixFormSubmission
from datetime import datetime, date, timezone
from datetime import datetime, date, timezone
from .auth import validate_api_key, validate_wix_signature, generate_api_key
from .rate_limit import (
limiter,
webhook_limiter,
limiter,
webhook_limiter,
custom_rate_limit_handler,
DEFAULT_RATE_LIMIT,
WEBHOOK_RATE_LIMIT,
BURST_RATE_LIMIT
BURST_RATE_LIMIT,
)
from slowapi.errors import RateLimitExceeded
import logging
@@ -24,8 +35,14 @@ import gzip
import xml.etree.ElementTree as ET
from .alpinebits_server import AlpineBitsServer, Version
import urllib.parse
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .db import get_async_session, Customer as DBCustomer, Reservation as DBReservation
from .db import (
Base,
Customer as DBCustomer,
Reservation as DBReservation,
get_database_url,
)
# Configure logging
@@ -42,12 +59,36 @@ except Exception as e:
_LOGGER.error(f"Failed to load config: {str(e)}")
config = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# Setup DB
DATABASE_URL = get_database_url(config)
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
app.state.engine = engine
app.state.async_sessionmaker = AsyncSessionLocal
# Create tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
_LOGGER.info("Database tables checked/created at startup.")
yield
# Optional: Dispose engine on shutdown
await engine.dispose()
async def get_async_session(request: Request):
async_sessionmaker = request.app.state.async_sessionmaker
async with async_sessionmaker() as session:
yield session
app = FastAPI(
title="Wix Form Handler API",
description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting",
version="1.0.0"
version="1.0.0",
lifespan=lifespan
)
# Create API router with /api prefix
@@ -62,9 +103,9 @@ app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://*.wix.com",
"https://*.wixstatic.com",
"https://*.wixstatic.com",
"http://localhost:3000", # For development
"http://localhost:8000" # For local testing
"http://localhost:8000", # For local testing
],
allow_credentials=True,
allow_methods=["GET", "POST"],
@@ -78,27 +119,39 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None:
Add your business logic here.
"""
try:
_LOGGER.info(f"Processing form submission: {submission_data.get('submissionId')}")
_LOGGER.info(
f"Processing form submission: {submission_data.get('submissionId')}"
)
# Example processing - you can replace this with your actual logic
form_name = submission_data.get('formName')
contact_email = submission_data.get('contact', {}).get('email') if submission_data.get('contact') else None
form_name = submission_data.get("formName")
contact_email = (
submission_data.get("contact", {}).get("email")
if submission_data.get("contact")
else None
)
# Extract form fields
form_fields = {k: v for k, v in submission_data.items() if k.startswith('field:')}
_LOGGER.info(f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}")
form_fields = {
k: v for k, v in submission_data.items() if k.startswith("field:")
}
_LOGGER.info(
f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}"
)
# Here you could:
# - Save to database
# - Send emails
# - Call external APIs
# - Process the data further
except Exception as e:
_LOGGER.error(f"Error processing form submission: {str(e)}")
@api_router.get("/")
@limiter.limit(DEFAULT_RATE_LIMIT)
async def root(request: Request):
@@ -111,8 +164,8 @@ async def root(request: Request):
"rate_limits": {
"default": DEFAULT_RATE_LIMIT,
"webhook": WEBHOOK_RATE_LIMIT,
"burst": BURST_RATE_LIMIT
}
"burst": BURST_RATE_LIMIT,
},
}
@@ -126,11 +179,10 @@ async def health_check(request: Request):
"service": "wix-form-handler",
"version": "1.0.0",
"authentication": "enabled",
"rate_limiting": "enabled"
"rate_limiting": "enabled",
}
# Extracted business logic for handling Wix form submissions
async def process_wix_form_submission(request: Request, data: Dict[str, Any], db):
"""
@@ -138,10 +190,9 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
"""
timestamp = datetime.now().isoformat()
_LOGGER.info(f"Received Wix form data at {timestamp}")
#_LOGGER.info(f"Data keys: {list(data.keys())}")
#_LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
# _LOGGER.info(f"Data keys: {list(data.keys())}")
# _LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
log_entry = {
"timestamp": timestamp,
"client_ip": request.client.host if request.client else "unknown",
@@ -154,9 +205,13 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
if not os.path.exists(logs_dir):
os.makedirs(logs_dir, mode=0o755, exist_ok=True)
stat_info = os.stat(logs_dir)
_LOGGER.info(f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}")
_LOGGER.info(
f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}"
)
_LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}")
log_filename = f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
log_filename = (
f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
)
with open(log_filename, "w", encoding="utf-8") as f:
json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False)
file_stat = os.stat(log_filename)
@@ -164,16 +219,10 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
_LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}")
_LOGGER.info(f"Data logged to: {log_filename}")
data = data.get("data") # Handle nested "data" key if present
# save customer and reservation to DB
contact_info = data.get("contact", {})
first_name = contact_info.get("name", {}).get("first")
last_name = contact_info.get("name", {}).get("last")
@@ -193,10 +242,18 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
language = data.get("contact", {}).get("locale", "en")[:2]
# Dates
start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value")
end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value")
start_date = (
data.get("field:date_picker_a7c8")
or data.get("Anreisedatum")
or data.get("submissions", [{}])[1].get("value")
)
end_date = (
data.get("field:date_picker_7e65")
or data.get("Abreisedatum")
or data.get("submissions", [{}])[2].get("value")
)
# Room/guest info
# Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2)
num_children = int(data.get("field:anzahl_kinder") or 0)
children_ages = []
@@ -258,7 +315,7 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults,
num_children=num_children,
children_ages=','.join(str(a) for a in children_ages),
children_ages=",".join(str(a) for a in children_ages),
offer=offer,
utm_comment=utm_comment,
created_at=datetime.now(timezone.utc),
@@ -277,23 +334,21 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
await db.commit()
await db.refresh(db_reservation)
return {
"status": "success",
"message": "Wix form data received successfully",
"received_keys": list(data.keys()),
"data_logged_to": log_filename,
"timestamp": timestamp,
"process_info": log_entry["process_info"],
"note": "No authentication required for this endpoint"
"note": "No authentication required for this endpoint",
}
@api_router.post("/webhook/wix-form")
@webhook_limiter.limit(WEBHOOK_RATE_LIMIT)
async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)):
async def handle_wix_form(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
"""
Unified endpoint to handle Wix form submissions (test and production).
No authentication required for this endpoint.
@@ -304,16 +359,19 @@ async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Dep
_LOGGER.error(f"Error in handle_wix_form: {str(e)}")
# log stacktrace
import traceback
traceback_str = traceback.format_exc()
_LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}")
raise HTTPException(
status_code=500,
detail=f"Error processing Wix form data: {str(e)}"
status_code=500, detail=f"Error processing Wix form data: {str(e)}"
)
@api_router.post("/webhook/wix-form/test")
@limiter.limit(DEFAULT_RATE_LIMIT)
async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session=Depends(get_async_session)):
async def handle_wix_form_test(
request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
):
"""
Test endpoint to verify the API is working with raw JSON data.
No authentication required for testing purposes.
@@ -323,40 +381,37 @@ async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session
except Exception as e:
_LOGGER.error(f"Error in handle_wix_form_test: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Error processing test data: {str(e)}"
status_code=500, detail=f"Error processing test data: {str(e)}"
)
@api_router.post("/admin/generate-api-key")
@limiter.limit("5/hour") # Very restrictive for admin operations
async def generate_new_api_key(
request: Request,
admin_key: str = Depends(validate_api_key)
request: Request, admin_key: str = Depends(validate_api_key)
):
"""
Admin endpoint to generate new API keys.
Requires admin API key and is heavily rate limited.
"""
if admin_key != "admin-key":
raise HTTPException(
status_code=403,
detail="Admin access required"
)
raise HTTPException(status_code=403, detail="Admin access required")
new_key = generate_api_key()
_LOGGER.info(f"Generated new API key (requested by: {admin_key})")
return {
"status": "success",
"message": "New API key generated",
"api_key": new_key,
"timestamp": datetime.now().isoformat(),
"note": "Store this key securely - it won't be shown again"
"note": "Store this key securely - it won't be shown again",
}
async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(security_basic)) -> str:
async def validate_basic_auth(
credentials: HTTPBasicCredentials = Depends(security_basic),
) -> str:
"""
Validate basic authentication for AlpineBits protocol.
Returns username if valid, raises HTTPException if not.
@@ -369,8 +424,11 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
headers={"WWW-Authenticate": "Basic"},
)
valid = False
for entry in config['alpine_bits_auth']:
if credentials.username == entry['username'] and credentials.password == entry['password']:
for entry in config["alpine_bits_auth"]:
if (
credentials.username == entry["username"]
and credentials.password == entry["password"]
):
valid = True
break
if not valid:
@@ -379,7 +437,9 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
detail="ERROR: Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)
_LOGGER.info(f"AlpineBits authentication successful for user: {credentials.username} (from config)")
_LOGGER.info(
f"AlpineBits authentication successful for user: {credentials.username} (from config)"
)
return credentials.username
@@ -390,10 +450,9 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
"""
if "multipart/form-data" not in content_type:
raise HTTPException(
status_code=400,
detail="ERROR: Content-Type must be multipart/form-data"
status_code=400, detail="ERROR: Content-Type must be multipart/form-data"
)
# Extract boundary
boundary = None
for part in content_type.split(";"):
@@ -401,62 +460,56 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
if part.startswith("boundary="):
boundary = part.split("=", 1)[1].strip('"')
break
if not boundary:
raise HTTPException(
status_code=400,
detail="ERROR: Missing boundary in multipart/form-data"
status_code=400, detail="ERROR: Missing boundary in multipart/form-data"
)
# Simple multipart parsing
parts = body.split(f"--{boundary}".encode())
data = {}
for part in parts:
if not part.strip() or part.strip() == b"--":
continue
# Split headers and content
if b"\r\n\r\n" in part:
headers_section, content = part.split(b"\r\n\r\n", 1)
content = content.rstrip(b"\r\n")
# Parse Content-Disposition header
headers = headers_section.decode('utf-8', errors='ignore')
headers = headers_section.decode("utf-8", errors="ignore")
name = None
for line in headers.split('\n'):
if 'Content-Disposition' in line and 'name=' in line:
for line in headers.split("\n"):
if "Content-Disposition" in line and "name=" in line:
# Extract name parameter
for param in line.split(';'):
for param in line.split(";"):
param = param.strip()
if param.startswith('name='):
name = param.split('=', 1)[1].strip('"')
if param.startswith("name="):
name = param.split("=", 1)[1].strip('"')
break
if name:
# Handle file uploads or text content
if content.startswith(b'<'):
if content.startswith(b"<"):
# Likely XML content
data[name] = content.decode('utf-8', errors='ignore')
data[name] = content.decode("utf-8", errors="ignore")
else:
data[name] = content.decode('utf-8', errors='ignore')
data[name] = content.decode("utf-8", errors="ignore")
return data
@api_router.post("/alpinebits/server-2024-10")
@limiter.limit("60/minute")
async def alpinebits_server_handshake(
request: Request,
username: str = Depends(validate_basic_auth)
request: Request, username: str = Depends(validate_basic_auth)
):
"""
AlpineBits server endpoint implementing the handshake protocol.
This endpoint handles:
- Protocol version negotiation via X-AlpineBits-ClientProtocolVersion header
- Client identification via X-AlpineBits-ClientID header (optional)
@@ -464,62 +517,67 @@ async def alpinebits_server_handshake(
- Gzip compression support
- Proper error handling with HTTP status codes
- Handshaking action processing
Authentication: HTTP Basic Auth required
Content-Type: multipart/form-data
Compression: gzip supported (check X-AlpineBits-Server-Accept-Encoding)
"""
try:
# Check required headers
client_protocol_version = request.headers.get("X-AlpineBits-ClientProtocolVersion")
client_protocol_version = request.headers.get(
"X-AlpineBits-ClientProtocolVersion"
)
if not client_protocol_version:
# Server concludes client speaks a protocol version preceding 2013-04
client_protocol_version = "pre-2013-04"
_LOGGER.info("No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04")
_LOGGER.info(
"No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04"
)
else:
_LOGGER.info(f"Client protocol version: {client_protocol_version}")
# Optional client ID
client_id = request.headers.get("X-AlpineBits-ClientID")
if client_id:
_LOGGER.info(f"Client ID: {client_id}")
# Check content encoding
content_encoding = request.headers.get("Content-Encoding")
is_compressed = content_encoding == "gzip"
if is_compressed:
_LOGGER.info("Request is gzip compressed")
# Get content type before processing
content_type = request.headers.get("Content-Type", "")
_LOGGER.info(f"Content-Type: {content_type}")
_LOGGER.info(f"Content-Encoding: {content_encoding}")
# Get request body
body = await request.body()
# Decompress if needed
if is_compressed:
try:
body = gzip.decompress(body)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"ERROR: Failed to decompress gzip content: {str(e)}"
detail=f"ERROR: Failed to decompress gzip content: {str(e)}",
)
# Check content type (after decompression)
if "multipart/form-data" not in content_type and "application/x-www-form-urlencoded" not in content_type:
if (
"multipart/form-data" not in content_type
and "application/x-www-form-urlencoded" not in content_type
):
raise HTTPException(
status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded"
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
)
# Parse multipart data
if "multipart/form-data" in content_type:
try:
@@ -527,7 +585,7 @@ async def alpinebits_server_handshake(
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}"
detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}",
)
elif "application/x-www-form-urlencoded" in content_type:
# Parse as urlencoded
@@ -535,75 +593,59 @@ async def alpinebits_server_handshake(
else:
raise HTTPException(
status_code=400,
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded"
detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
)
# Check for required action parameter
action = form_data.get("action")
if not action:
raise HTTPException(
status_code=400,
detail="ERROR: Missing required 'action' parameter")
status_code=400, detail="ERROR: Missing required 'action' parameter"
)
_LOGGER.info(f"AlpineBits action: {action}")
# Get optional request XML
request_xml = form_data.get("request")
request_xml = form_data.get("request")
server = AlpineBitsServer()
version = Version.V2024_10
# Create successful handshake response
response = await server.handle_request(action, request_xml, version)
response = await server.handle_request(action, request_xml, version)
response_xml = response.xml_content
# Set response headers indicating server capabilities
headers = {
"Content-Type": "application/xml; charset=utf-8",
"X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support
"X-AlpineBits-Server-Version": "2024-10"
"X-AlpineBits-Server-Version": "2024-10",
}
return Response(
content=response_xml,
status_code=response.status_code,
headers=headers
)
return Response(
content=response_xml, status_code=response.status_code, headers=headers
)
except HTTPException:
# Re-raise HTTP exceptions (auth errors, etc.)
raise
except Exception as e:
_LOGGER.error(f"Error in AlpineBits handshake: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Internal server error: {str(e)}"
)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@api_router.get("/admin/stats")
@limiter.limit("10/minute")
async def get_api_stats(
request: Request,
admin_key: str = Depends(validate_api_key)
):
async def get_api_stats(request: Request, admin_key: str = Depends(validate_api_key)):
"""
Admin endpoint to get API usage statistics.
Requires admin API key.
"""
if admin_key != "admin-key":
raise HTTPException(
status_code=403,
detail="Admin access required"
)
raise HTTPException(status_code=403, detail="Admin access required")
# In a real application, you'd fetch this from your database/monitoring system
return {
"status": "success",
@@ -611,9 +653,9 @@ async def get_api_stats(
"uptime": "Available in production deployment",
"total_requests": "Available with monitoring setup",
"active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]),
"rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory"
"rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory",
},
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now().isoformat(),
}
@@ -629,11 +671,12 @@ async def landing_page():
try:
# Get the path to the HTML file
import os
html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html")
with open(html_path, "r", encoding="utf-8") as f:
html_content = f.read()
return HTMLResponse(content=html_content, status_code=200)
except FileNotFoundError:
# Fallback if HTML file is not found
@@ -660,4 +703,5 @@ async def landing_page():
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -21,7 +21,7 @@ security = HTTPBearer()
API_KEYS = {
# Example API keys - replace with your own secure keys
"wix-webhook-key": "sk_live_your_secure_api_key_here",
"admin-key": "sk_admin_your_admin_key_here"
"admin-key": "sk_admin_your_admin_key_here",
}
# Load API keys from environment if available
@@ -36,19 +36,21 @@ def generate_api_key() -> str:
return f"sk_live_{secrets.token_urlsafe(32)}"
def validate_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
def validate_api_key(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> str:
"""
Validate API key from Authorization header.
Expected format: Authorization: Bearer your_api_key_here
"""
token = credentials.credentials
# Check if the token is in our valid API keys
for key_name, valid_key in API_KEYS.items():
if secrets.compare_digest(token, valid_key):
logger.info(f"Valid API key used: {key_name}")
return key_name
logger.warning(f"Invalid API key attempted: {token[:10]}...")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -64,19 +66,17 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool:
"""
if not signature or not secret:
return False
try:
# Remove 'sha256=' prefix if present
if signature.startswith('sha256='):
if signature.startswith("sha256="):
signature = signature[7:]
# Calculate expected signature
expected_signature = hmac.new(
secret.encode('utf-8'),
payload,
hashlib.sha256
secret.encode("utf-8"), payload, hashlib.sha256
).hexdigest()
# Compare signatures securely
return secrets.compare_digest(signature, expected_signature)
except Exception as e:
@@ -86,21 +86,21 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool:
class APIKeyAuth:
"""Simple API key authentication class"""
def __init__(self, api_keys: dict):
self.api_keys = api_keys
def authenticate(self, api_key: str) -> Optional[str]:
"""Authenticate an API key and return the key name if valid"""
for key_name, valid_key in self.api_keys.items():
if secrets.compare_digest(api_key, valid_key):
return key_name
return None
def add_key(self, name: str, key: str):
"""Add a new API key"""
self.api_keys[name] = key
def remove_key(self, name: str):
"""Remove an API key"""
if name in self.api_keys:
@@ -108,4 +108,4 @@ class APIKeyAuth:
# Initialize auth system
auth_system = APIKeyAuth(API_KEYS)
auth_system = APIKeyAuth(API_KEYS)

View File

@@ -1,4 +1,3 @@
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
@@ -16,37 +15,45 @@ from annotatedyaml.loader import (
from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid
# --- Voluptuous schemas ---
database_schema = Schema({
Required('url'): str
}, extra=PREVENT_EXTRA)
database_schema = Schema({Required("url"): str}, extra=PREVENT_EXTRA)
hotel_auth_schema = Schema({
Required("hotel_id"): str,
Required("hotel_name"): str,
Required("username"): str,
Required("password"): str
}, extra=PREVENT_EXTRA)
basic_auth_schema = Schema(
All([hotel_auth_schema], Length(min=1))
hotel_auth_schema = Schema(
{
Required("hotel_id"): str,
Required("hotel_name"): str,
Required("username"): str,
Required("password"): str,
},
extra=PREVENT_EXTRA,
)
config_schema = Schema({
Required('database'): database_schema,
Required('alpine_bits_auth'): basic_auth_schema
}, extra=PREVENT_EXTRA)
basic_auth_schema = Schema(All([hotel_auth_schema], Length(min=1)))
DEFAULT_CONFIG_FILE = 'config.yaml'
config_schema = Schema(
{
Required("database"): database_schema,
Required("alpine_bits_auth"): basic_auth_schema,
},
extra=PREVENT_EXTRA,
)
DEFAULT_CONFIG_FILE = "config.yaml"
class Config:
def __init__(self, config_folder: str | Path = None, config_name: str = DEFAULT_CONFIG_FILE, testing_mode: bool = False):
def __init__(
self,
config_folder: str | Path = None,
config_name: str = DEFAULT_CONFIG_FILE,
testing_mode: bool = False,
):
if config_folder is None:
config_folder = os.environ.get('ALPINEBITS_CONFIG_DIR')
config_folder = os.environ.get("ALPINEBITS_CONFIG_DIR")
if not config_folder:
config_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../config'))
config_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../config")
)
if isinstance(config_folder, str):
config_folder = Path(config_folder)
self.config_folder = config_folder
@@ -61,8 +68,8 @@ class Config:
validated = config_schema(stuff)
except MultipleInvalid as e:
raise ValueError(f"Config validation error: {e}")
self.database = validated['database']
self.basic_auth = validated['alpine_bits_auth']
self.database = validated["database"]
self.basic_auth = validated["alpine_bits_auth"]
self.config = validated
def get(self, key, default=None):
@@ -70,19 +77,20 @@ class Config:
@property
def db_url(self) -> str:
return self.database['url']
return self.database["url"]
@property
def hotel_id(self) -> str:
return self.basic_auth['hotel_id']
return self.basic_auth["hotel_id"]
@property
def hotel_name(self) -> str:
return self.basic_auth['hotel_name']
return self.basic_auth["hotel_name"]
@property
def users(self) -> List[Dict[str, str]]:
return self.basic_auth['users']
return self.basic_auth["users"]
# For backward compatibility
def load_config():

View File

@@ -5,27 +5,24 @@ import os
Base = declarative_base()
# Async SQLAlchemy setup
def get_database_url(config=None):
db_url = None
if config and 'database' in config and 'url' in config['database']:
db_url = config['database']['url']
if config and "database" in config and "url" in config["database"]:
db_url = config["database"]["url"]
if not db_url:
db_url = os.environ.get('DATABASE_URL')
db_url = os.environ.get("DATABASE_URL")
if not db_url:
db_url = 'sqlite+aiosqlite:///alpinebits.db'
db_url = "sqlite+aiosqlite:///alpinebits.db"
return db_url
DATABASE_URL = get_database_url()
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
async def get_async_session():
async with AsyncSessionLocal() as session:
yield session
class Customer(Base):
__tablename__ = 'customers'
__tablename__ = "customers"
id = Column(Integer, primary_key=True)
given_name = Column(String)
contact_id = Column(String, unique=True)
@@ -42,13 +39,14 @@ class Customer(Base):
birth_date = Column(String)
language = Column(String)
address_catalog = Column(Boolean) # Added for XML
name_title = Column(String) # Added for XML
reservations = relationship('Reservation', back_populates='customer')
name_title = Column(String) # Added for XML
reservations = relationship("Reservation", back_populates="customer")
class Reservation(Base):
__tablename__ = 'reservations'
__tablename__ = "reservations"
id = Column(Integer, primary_key=True)
customer_id = Column(Integer, ForeignKey('customers.id'))
customer_id = Column(Integer, ForeignKey("customers.id"))
form_id = Column(String, unique=True)
start_date = Column(Date)
end_date = Column(Date)
@@ -70,16 +68,14 @@ class Reservation(Base):
# Add hotel_code and hotel_name for XML
hotel_code = Column(String)
hotel_name = Column(String)
customer = relationship('Customer', back_populates='reservations')
customer = relationship("Customer", back_populates="reservations")
class HashedCustomer(Base):
__tablename__ = 'hashed_customers'
__tablename__ = "hashed_customers"
id = Column(Integer, primary_key=True)
customer_id = Column(Integer)
hashed_email = Column(String)
hashed_phone = Column(String)
hashed_name = Column(String)
redacted_at = Column(DateTime)

View File

@@ -15,11 +15,16 @@ from .simplified_access import (
HotelReservationIdData,
PhoneTechType,
AlpineBitsFactory,
OtaMessageType
OtaMessageType,
)
# DB and config
from .db import Customer as DBCustomer, Reservation as DBReservation, HashedCustomer, get_async_session
from .db import (
Customer as DBCustomer,
Reservation as DBReservation,
HashedCustomer,
get_async_session,
)
from .config_loader import load_config
import hashlib
import json
@@ -29,8 +34,8 @@ import asyncio
from alpine_bits_python import db
async def main():
async def main():
print("🚀 Starting AlpineBits XML generation script...")
# Load config (yaml, annotatedyaml)
config = load_config()
@@ -40,9 +45,9 @@ async def main():
print(json.dumps(config, indent=2))
# Ensure SQLite DB file exists if using SQLite
db_url = config.get('database', {}).get('url', '')
if db_url.startswith('sqlite+aiosqlite:///'):
db_path = db_url.replace('sqlite+aiosqlite:///', '')
db_url = config.get("database", {}).get("url", "")
if db_url.startswith("sqlite+aiosqlite:///"):
db_path = db_url.replace("sqlite+aiosqlite:///", "")
db_path = os.path.abspath(db_path)
db_dir = os.path.dirname(db_path)
if not os.path.exists(db_dir):
@@ -54,15 +59,17 @@ async def main():
# # Ensure DB schema is created (async)
from .db import engine, Base
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async for db in get_async_session():
# Load data from JSON file
json_path = os.path.join(os.path.dirname(__file__), '../../test_data/wix_test_data_20250928_132611.json')
with open(json_path, 'r', encoding='utf-8') as f:
json_path = os.path.join(
os.path.dirname(__file__),
"../../test_data/wix_test_data_20250928_132611.json",
)
with open(json_path, "r", encoding="utf-8") as f:
wix_data = json.load(f)
data = wix_data["data"]["data"]
@@ -85,8 +92,16 @@ async def main():
language = data.get("contact", {}).get("locale", "en")[:2]
# Dates
start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value")
end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value")
start_date = (
data.get("field:date_picker_a7c8")
or data.get("Anreisedatum")
or data.get("submissions", [{}])[1].get("value")
)
end_date = (
data.get("field:date_picker_7e65")
or data.get("Abreisedatum")
or data.get("submissions", [{}])[2].get("value")
)
# Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2)
@@ -100,7 +115,7 @@ async def main():
children_ages.append(age)
except ValueError:
logging.warning(f"Invalid age value for {k}: {data[k]}")
# UTM and offer
utm_fields = [
("utm_Source", "utm_source"),
@@ -147,7 +162,7 @@ async def main():
end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults,
num_children=num_children,
children_ages=','.join(str(a) for a in children_ages),
children_ages=",".join(str(a) for a in children_ages),
offer=offer,
utm_comment=utm_comment,
created_at=datetime.now(timezone.utc),
@@ -177,9 +192,19 @@ async def main():
def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
from .simplified_access import CustomerData, GuestCountsFactory, HotelReservationIdData, AlpineBitsFactory, OtaMessageType, CommentData, CommentsData, CommentListItemData
from .simplified_access import (
CustomerData,
GuestCountsFactory,
HotelReservationIdData,
AlpineBitsFactory,
OtaMessageType,
CommentData,
CommentsData,
CommentListItemData,
)
from .generated import alpinebits as ab
from datetime import datetime, timezone
# Prepare data for XML
phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else []
customer_data = CustomerData(
@@ -200,11 +225,15 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
language=customer.language,
)
alpine_bits_factory = AlpineBitsFactory()
res_guests = alpine_bits_factory.create_res_guests(customer_data, OtaMessageType.RETRIEVE)
res_guests = alpine_bits_factory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
# Guest counts
children_ages = [int(a) for a in reservation.children_ages.split(",") if a]
guest_counts = GuestCountsFactory.create_retrieve_guest_counts(reservation.num_adults, children_ages)
guest_counts = GuestCountsFactory.create_retrieve_guest_counts(
reservation.num_adults, children_ages
)
# UniqueID
unique_id = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId(
@@ -214,11 +243,13 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
# TimeSpan
time_span = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan(
start=reservation.start_date.isoformat() if reservation.start_date else None,
end=reservation.end_date.isoformat() if reservation.end_date else None
end=reservation.end_date.isoformat() if reservation.end_date else None,
)
room_stay = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay(
time_span=time_span,
guest_counts=guest_counts,
room_stay = (
ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay(
time_span=time_span,
guest_counts=guest_counts,
)
)
room_stays = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays(
room_stay=[room_stay],
@@ -231,7 +262,9 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
res_id_source=None,
res_id_source_context="99tales",
)
hotel_res_id = alpine_bits_factory.create(hotel_res_id_data, OtaMessageType.RETRIEVE)
hotel_res_id = alpine_bits_factory.create(
hotel_res_id_data, OtaMessageType.RETRIEVE
)
hotel_res_ids = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds(
hotel_reservation_id=[hotel_res_id]
)
@@ -244,31 +277,37 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
offer_comment = CommentData(
name=ab.CommentName2.ADDITIONAL_INFO,
text="Angebot/Offerta",
list_items=[CommentListItemData(
value=reservation.offer,
language=customer.language,
list_item="1",
)],
list_items=[
CommentListItemData(
value=reservation.offer,
language=customer.language,
list_item="1",
)
],
)
comment = None
if reservation.user_comment:
comment = CommentData(
name=ab.CommentName2.CUSTOMER_COMMENT,
text=reservation.user_comment,
list_items=[CommentListItemData(
value="Landing page comment",
language=customer.language,
list_item="1",
)],
list_items=[
CommentListItemData(
value="Landing page comment",
language=customer.language,
list_item="1",
)
],
)
comments = [offer_comment, comment] if comment else [offer_comment]
comments_data = CommentsData(comments=comments)
comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE)
res_global_info = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo(
hotel_reservation_ids=hotel_res_ids,
basic_property_info=basic_property_info,
comments=comments_xml,
res_global_info = (
ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo(
hotel_reservation_ids=hotel_res_ids,
basic_property_info=basic_property_info,
comments=comments_xml,
)
)
hotel_reservation = ab.OtaResRetrieveRs.ReservationsList.HotelReservation(
@@ -293,6 +332,7 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("✅ Pydantic validation successful!")
from xsdata.formats.dataclass.serializers.config import SerializerConfig
from xsdata_pydantic.bindings import XmlSerializer
config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8"
)
@@ -306,15 +346,18 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("\n📄 Generated XML:")
print(xml_string)
from xsdata_pydantic.bindings import XmlParser
parser = XmlParser()
with open("output.xml", "r", encoding="utf-8") as infile:
xml_content = infile.read()
parsed_result = parser.from_string(xml_content, ab.OtaResRetrieveRs)
print("✅ Round-trip validation successful!")
print(f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}")
print(
f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}"
)
except Exception as e:
print(f"❌ Validation/Serialization failed: {e}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -5,18 +5,23 @@ from datetime import datetime
class AlpineBitsHandshakeRequest(BaseModel):
"""Model for AlpineBits handshake request data"""
action: str = Field(..., description="Action parameter, typically 'OTA_Ping:Handshaking'")
action: str = Field(
..., description="Action parameter, typically 'OTA_Ping:Handshaking'"
)
request_xml: Optional[str] = Field(None, description="XML request document")
class ContactName(BaseModel):
"""Contact name structure"""
first: Optional[str] = None
last: Optional[str] = None
class ContactAddress(BaseModel):
"""Contact address structure"""
street: Optional[str] = None
city: Optional[str] = None
state: Optional[str] = None
@@ -26,6 +31,7 @@ class ContactAddress(BaseModel):
class Contact(BaseModel):
"""Contact information from Wix form"""
name: Optional[ContactName] = None
email: Optional[str] = None
locale: Optional[str] = None
@@ -43,12 +49,14 @@ class Contact(BaseModel):
class SubmissionPdf(BaseModel):
"""PDF submission structure"""
url: Optional[str] = None
filename: Optional[str] = None
class WixFormSubmission(BaseModel):
"""Model for Wix form submission data"""
formName: str
submissions: List[Dict[str, Any]] = Field(default_factory=list)
submissionTime: str
@@ -59,7 +67,7 @@ class WixFormSubmission(BaseModel):
submissionPdf: Optional[SubmissionPdf] = None
formId: str
contact: Optional[Contact] = None
# Dynamic form fields - these will capture all field:* entries
class Config:
extra = "allow" # Allow additional fields not defined in the model
extra = "allow" # Allow additional fields not defined in the model

View File

@@ -11,11 +11,12 @@ logger = logging.getLogger(__name__)
# Rate limiting configuration
DEFAULT_RATE_LIMIT = "10/minute" # 10 requests per minute per IP
WEBHOOK_RATE_LIMIT = "60/minute" # 60 webhook requests per minute per IP
BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP
BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP
# Redis configuration for distributed rate limiting (optional)
REDIS_URL = os.getenv("REDIS_URL", None)
def get_remote_address_with_forwarded(request: Request):
"""
Get client IP address, considering forwarded headers from proxies/load balancers
@@ -25,11 +26,11 @@ def get_remote_address_with_forwarded(request: Request):
if forwarded_for:
# Take the first IP in the chain
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fallback to direct connection IP
return get_remote_address(request)
@@ -39,14 +40,16 @@ if REDIS_URL:
# Use Redis for distributed rate limiting (recommended for production)
try:
import redis
redis_client = redis.from_url(REDIS_URL)
limiter = Limiter(
key_func=get_remote_address_with_forwarded,
storage_uri=REDIS_URL
key_func=get_remote_address_with_forwarded, storage_uri=REDIS_URL
)
logger.info("Rate limiting initialized with Redis backend")
except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}. Using in-memory rate limiting.")
logger.warning(
f"Failed to connect to Redis: {e}. Using in-memory rate limiting."
)
limiter = Limiter(key_func=get_remote_address_with_forwarded)
else:
# Use in-memory rate limiting (fine for single instance)
@@ -65,7 +68,7 @@ def get_api_key_identifier(request: Request) -> str:
api_key = auth_header[7:] # Remove "Bearer " prefix
# Use first 10 chars of API key as identifier (don't log full key)
return f"api_key:{api_key[:10]}"
# Fallback to IP address
return f"ip:{get_remote_address_with_forwarded(request)}"
@@ -77,10 +80,10 @@ def api_key_rate_limit_key(request: Request):
# Rate limiting decorators for different endpoint types
webhook_limiter = Limiter(
key_func=api_key_rate_limit_key,
storage_uri=REDIS_URL if REDIS_URL else None
key_func=api_key_rate_limit_key, storage_uri=REDIS_URL if REDIS_URL else None
)
# Custom rate limit exceeded handler
def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Custom handler for rate limit exceeded"""
@@ -88,11 +91,11 @@ def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
f"Rate limit exceeded for {get_remote_address_with_forwarded(request)}: "
f"{exc.detail}"
)
response = _rate_limit_exceeded_handler(request, exc)
# Add custom headers
response.headers["X-RateLimit-Limit"] = str(exc.retry_after)
response.headers["X-RateLimit-Retry-After"] = str(exc.retry_after)
return response
return response

View File

@@ -1,7 +1,2 @@
def parse_form(form: dict):
pass

View File

@@ -2,14 +2,21 @@
"""
Startup script for the Wix Form Handler API
"""
import os
import uvicorn
from .api import app
if __name__ == "__main__":
db_path = "alpinebits.db" # Adjust path if needed
if os.path.exists(db_path):
os.remove(db_path)
print(f"Deleted database file: {db_path}")
uvicorn.run(
"alpine_bits_python.api:app",
host="0.0.0.0",
port=8080,
reload=True, # Enable auto-reload during development
log_level="info"
)
log_level="info",
)

View File

@@ -2,6 +2,7 @@
"""
Configuration and setup script for the Wix Form Handler API
"""
import os
import sys
import secrets
@@ -11,80 +12,83 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from alpine_bits_python.auth import generate_api_key
def generate_secure_keys():
"""Generate secure API keys for the application"""
print("🔐 Generating Secure API Keys")
print("=" * 50)
# Generate API keys
wix_api_key = generate_api_key()
admin_api_key = generate_api_key()
webhook_secret = secrets.token_urlsafe(32)
print(f"🔑 Wix Webhook API Key: {wix_api_key}")
print(f"🔐 Admin API Key: {admin_api_key}")
print(f"🔒 Webhook Secret: {webhook_secret}")
print("\n📋 Environment Variables")
print("-" * 30)
print(f"export WIX_API_KEY='{wix_api_key}'")
print(f"export ADMIN_API_KEY='{admin_api_key}'")
print(f"export WIX_WEBHOOK_SECRET='{webhook_secret}'")
print(f"export REDIS_URL='redis://localhost:6379' # Optional for production")
print("\n🔧 .env File Content")
print("-" * 20)
print(f"WIX_API_KEY={wix_api_key}")
print(f"ADMIN_API_KEY={admin_api_key}")
print(f"WIX_WEBHOOK_SECRET={webhook_secret}")
print("REDIS_URL=redis://localhost:6379")
# Optionally write to .env file
create_env = input("\n❓ Create .env file? (y/n): ").lower().strip()
if create_env == 'y':
if create_env == "y":
# Create .env in the project root (two levels up from scripts)
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
with open(env_path, 'w') as f:
env_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
)
with open(env_path, "w") as f:
f.write(f"WIX_API_KEY={wix_api_key}\n")
f.write(f"ADMIN_API_KEY={admin_api_key}\n")
f.write(f"WIX_WEBHOOK_SECRET={webhook_secret}\n")
f.write("REDIS_URL=redis://localhost:6379\n")
print(f"✅ .env file created at {env_path}!")
print("⚠️ Add .env to your .gitignore file!")
print("\n🌐 Wix Configuration")
print("-" * 20)
print("1. In your Wix site, go to Settings > Webhooks")
print("2. Add webhook URL: https://yourdomain.com/webhook/wix-form")
print("3. Add custom header: Authorization: Bearer " + wix_api_key)
print("4. Optionally configure webhook signature with the secret above")
return {
'wix_api_key': wix_api_key,
'admin_api_key': admin_api_key,
'webhook_secret': webhook_secret
"wix_api_key": wix_api_key,
"admin_api_key": admin_api_key,
"webhook_secret": webhook_secret,
}
def check_security_setup():
"""Check current security configuration"""
print("🔍 Security Configuration Check")
print("=" * 40)
# Check environment variables
wix_key = os.getenv('WIX_API_KEY')
admin_key = os.getenv('ADMIN_API_KEY')
webhook_secret = os.getenv('WIX_WEBHOOK_SECRET')
redis_url = os.getenv('REDIS_URL')
wix_key = os.getenv("WIX_API_KEY")
admin_key = os.getenv("ADMIN_API_KEY")
webhook_secret = os.getenv("WIX_WEBHOOK_SECRET")
redis_url = os.getenv("REDIS_URL")
print("Environment Variables:")
print(f" WIX_API_KEY: {'✅ Set' if wix_key else '❌ Not set'}")
print(f" ADMIN_API_KEY: {'✅ Set' if admin_key else '❌ Not set'}")
print(f" WIX_WEBHOOK_SECRET: {'✅ Set' if webhook_secret else '❌ Not set'}")
print(f" REDIS_URL: {'✅ Set' if redis_url else '⚠️ Optional (using in-memory)'}")
# Security recommendations
print("\n🛡️ Security Recommendations:")
if not wix_key:
@@ -94,19 +98,19 @@ def check_security_setup():
print(" ⚠️ WIX_API_KEY should be longer for better security")
else:
print(" ✅ WIX_API_KEY looks secure")
if not admin_key:
print(" ❌ Set ADMIN_API_KEY environment variable")
elif wix_key and admin_key == wix_key:
print(" ❌ Admin and Wix keys should be different")
else:
print(" ✅ ADMIN_API_KEY configured")
if not webhook_secret:
print(" ⚠️ Consider setting WIX_WEBHOOK_SECRET for signature validation")
else:
print(" ✅ Webhook signature validation enabled")
print("\n🚀 Production Checklist:")
print(" - Use HTTPS in production")
print(" - Set up Redis for distributed rate limiting")
@@ -118,12 +122,14 @@ def check_security_setup():
if __name__ == "__main__":
print("🔐 Wix Form Handler API - Security Setup")
print("=" * 50)
choice = input("Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): ").strip()
choice = input(
"Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): "
).strip()
if choice == "1":
generate_secure_keys()
elif choice == "2":
check_security_setup()
else:
print("Invalid choice. Please run again and choose 1 or 2.")
print("Invalid choice. Please run again and choose 1 or 2.")

View File

@@ -2,6 +2,7 @@
"""
Test script for the Secure Wix Form Handler API
"""
import asyncio
import aiohttp
import json
@@ -30,7 +31,7 @@ SAMPLE_WIX_DATA = {
"submissionsLink": "https://www.wix.app/forms/test-form/submissions",
"submissionPdf": {
"url": "https://example.com/submission.pdf",
"filename": "submission.pdf"
"filename": "submission.pdf",
},
"formId": "test-form-789",
"field:email_5139": "test@example.com",
@@ -43,10 +44,7 @@ SAMPLE_WIX_DATA = {
"field:alter_kind_4": "12",
"field:long_answer_3524": "This is a long answer field with more details about the inquiry.",
"contact": {
"name": {
"first": "John",
"last": "Doe"
},
"name": {"first": "John", "last": "Doe"},
"email": "test@example.com",
"locale": "de",
"company": "Test Company",
@@ -57,29 +55,29 @@ SAMPLE_WIX_DATA = {
"street": "Test Street 123",
"city": "Test City",
"country": "Germany",
"postalCode": "12345"
"postalCode": "12345",
},
"jobTitle": "Manager",
"phone": "+1234567890",
"createdDate": "2024-03-20T10:00:00.000Z",
"updatedDate": "2024-03-20T10:30:00.000Z"
}
"updatedDate": "2024-03-20T10:30:00.000Z",
},
}
async def test_api():
"""Test the API endpoints with authentication"""
headers_with_auth = {
"Content-Type": "application/json",
"Authorization": f"Bearer {TEST_API_KEY}"
"Authorization": f"Bearer {TEST_API_KEY}",
}
admin_headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {ADMIN_API_KEY}"
"Content-Type": "application/json",
"Authorization": f"Bearer {ADMIN_API_KEY}",
}
async with aiohttp.ClientSession() as session:
# Test health endpoint (no auth required)
print("1. Testing health endpoint (no auth)...")
@@ -89,7 +87,7 @@ async def test_api():
print(f" ✅ Health check: {response.status} - {result.get('status')}")
except Exception as e:
print(f" ❌ Health check failed: {e}")
# Test root endpoint (no auth required)
print("\n2. Testing root endpoint (no auth)...")
try:
@@ -98,87 +96,94 @@ async def test_api():
print(f" ✅ Root: {response.status} - {result.get('message')}")
except Exception as e:
print(f" ❌ Root endpoint failed: {e}")
# Test webhook endpoint without auth (should fail)
print("\n3. Testing webhook endpoint WITHOUT auth (should fail)...")
try:
async with session.post(
f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA,
headers={"Content-Type": "application/json"}
headers={"Content-Type": "application/json"},
) as response:
result = await response.json()
if response.status == 401:
print(f" ✅ Correctly rejected: {response.status} - {result.get('detail')}")
print(
f" ✅ Correctly rejected: {response.status} - {result.get('detail')}"
)
else:
print(f" ❌ Unexpected response: {response.status} - {result}")
except Exception as e:
print(f" ❌ Test failed: {e}")
# Test webhook endpoint with valid auth
print("\n4. Testing webhook endpoint WITH valid auth...")
try:
async with session.post(
f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA,
headers=headers_with_auth
headers=headers_with_auth,
) as response:
result = await response.json()
if response.status == 200:
print(f" ✅ Webhook success: {response.status} - {result.get('status')}")
print(
f" ✅ Webhook success: {response.status} - {result.get('status')}"
)
else:
print(f" ❌ Webhook failed: {response.status} - {result}")
except Exception as e:
print(f" ❌ Webhook test failed: {e}")
# Test test endpoint with auth
print("\n5. Testing simple test endpoint WITH auth...")
try:
async with session.post(
f"{BASE_URL}/api/webhook/wix-form/test",
json={"test": "data", "timestamp": datetime.now().isoformat()},
headers=headers_with_auth
headers=headers_with_auth,
) as response:
result = await response.json()
if response.status == 200:
print(f" ✅ Test endpoint: {response.status} - {result.get('status')}")
print(
f" ✅ Test endpoint: {response.status} - {result.get('status')}"
)
else:
print(f" ❌ Test endpoint failed: {response.status} - {result}")
except Exception as e:
print(f" ❌ Test endpoint failed: {e}")
# Test rate limiting by making multiple rapid requests
print("\n6. Testing rate limiting (making 5 rapid requests)...")
rate_limit_test_count = 0
for i in range(5):
try:
async with session.get(
f"{BASE_URL}/api/health"
) as response:
async with session.get(f"{BASE_URL}/api/health") as response:
if response.status == 200:
rate_limit_test_count += 1
elif response.status == 429:
print(f" ✅ Rate limit triggered on request {i+1}")
print(f" ✅ Rate limit triggered on request {i + 1}")
break
except Exception as e:
print(f" ❌ Rate limit test failed: {e}")
break
if rate_limit_test_count == 5:
print(" No rate limit reached (normal for low request volume)")
# Test admin endpoint (if admin key is configured)
print("\n7. Testing admin stats endpoint...")
try:
async with session.get(
f"{BASE_URL}/api/admin/stats",
headers=admin_headers
f"{BASE_URL}/api/admin/stats", headers=admin_headers
) as response:
result = await response.json()
if response.status == 200:
print(f" ✅ Admin stats: {response.status} - {result.get('status')}")
print(
f" ✅ Admin stats: {response.status} - {result.get('status')}"
)
elif response.status == 401:
print(f" ⚠️ Admin access denied (API key not configured): {result.get('detail')}")
print(
f" ⚠️ Admin access denied (API key not configured): {result.get('detail')}"
)
else:
print(f" ❌ Admin endpoint failed: {response.status} - {result}")
except Exception as e:
@@ -189,12 +194,18 @@ if __name__ == "__main__":
print("🔒 Testing Secure Wix Form Handler API...")
print("=" * 60)
print("📍 API URL:", BASE_URL)
print("🔑 Using API Key:", TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY)
print("🔐 Using Admin Key:", ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY)
print(
"🔑 Using API Key:",
TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY,
)
print(
"🔐 Using Admin Key:",
ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY,
)
print("=" * 60)
print("Make sure the API is running with: python3 run_api.py")
print("-" * 60)
try:
asyncio.run(test_api())
print("\n" + "=" * 60)
@@ -207,4 +218,4 @@ if __name__ == "__main__":
print("3. Add Authorization header: Bearer your_api_key")
except Exception as e:
print(f"\n❌ Error testing API: {e}")
print("Make sure the API server is running!")
print("Make sure the API server is running!")

View File

@@ -15,15 +15,26 @@ NotifHotelReservationId = OtaHotelResNotifRq.HotelReservations.HotelReservation.
RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId
# Define type aliases for Comments types
NotifComments = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments
RetrieveComments = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments
NotifComment = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment
RetrieveComment = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment
NotifComments = (
OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments
)
RetrieveComments = (
OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments
)
NotifComment = (
OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment
)
RetrieveComment = (
OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment
)
# type aliases for GuestCounts
NotifGuestCounts = OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts
RetrieveGuestCounts = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts
NotifGuestCounts = (
OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts
)
RetrieveGuestCounts = (
OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts
)
# phonetechtype enum 1,3,5 voice, fax, mobile
@@ -36,12 +47,13 @@ class PhoneTechType(Enum):
# Enum to specify which OTA message type to use
class OtaMessageType(Enum):
NOTIF = "notification" # For OtaHotelResNotifRq
RETRIEVE = "retrieve" # For OtaResRetrieveRs
RETRIEVE = "retrieve" # For OtaResRetrieveRs
@dataclass
class KidsAgeData:
"""Data class to hold information about children's ages."""
ages: list[int]
@@ -77,9 +89,10 @@ class CustomerData:
class GuestCountsFactory:
@staticmethod
def create_notif_guest_counts(adults: int, kids: Optional[list[int]] = None) -> NotifGuestCounts:
def create_notif_guest_counts(
adults: int, kids: Optional[list[int]] = None
) -> NotifGuestCounts:
"""
Create a GuestCounts object for OtaHotelResNotifRq.
:param adults: Number of adults
@@ -89,18 +102,23 @@ class GuestCountsFactory:
return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts)
@staticmethod
def create_retrieve_guest_counts(adults: int, kids: Optional[list[int]] = None) -> RetrieveGuestCounts:
def create_retrieve_guest_counts(
adults: int, kids: Optional[list[int]] = None
) -> RetrieveGuestCounts:
"""
Create a GuestCounts object for OtaResRetrieveRs.
:param adults: Number of adults
:param kids: List of ages for each kid (optional)
:return: GuestCounts instance
"""
return GuestCountsFactory._create_guest_counts(adults, kids, RetrieveGuestCounts)
return GuestCountsFactory._create_guest_counts(
adults, kids, RetrieveGuestCounts
)
@staticmethod
def _create_guest_counts(adults: int, kids: Optional[list[int]], guest_counts_class: type) -> Any:
def _create_guest_counts(
adults: int, kids: Optional[list[int]], guest_counts_class: type
) -> Any:
"""
Internal method to create a GuestCounts object of the specified type.
:param adults: Number of adults
@@ -356,9 +374,10 @@ class HotelReservationIdFactory:
)
@dataclass
@dataclass
class CommentListItemData:
"""Simple data class to hold comment list item information."""
value: str # The text content of the list item
list_item: str # Numeric identifier (pattern: [0-9]+)
language: str # Two-letter language code (pattern: [a-z][a-z])
@@ -367,6 +386,7 @@ class CommentListItemData:
@dataclass
class CommentData:
"""Simple data class to hold comment information without nested type constraints."""
name: CommentName2 # Required: "included services", "customer comment", "additional info"
text: Optional[str] = None # Optional text content
list_items: list[CommentListItemData] = None # Optional list items
@@ -379,6 +399,7 @@ class CommentData:
@dataclass
class CommentsData:
"""Simple data class to hold multiple comments (1-3 max)."""
comments: list[CommentData] = None # 1-3 comments maximum
def __post_init__(self):
@@ -388,21 +409,23 @@ class CommentsData:
class CommentFactory:
"""Factory class to create Comment instances for both OtaHotelResNotifRq and OtaResRetrieveRs."""
@staticmethod
def create_notif_comments(data: CommentsData) -> NotifComments:
"""Create Comments for OtaHotelResNotifRq."""
return CommentFactory._create_comments(NotifComments, NotifComment, data)
@staticmethod
def create_retrieve_comments(data: CommentsData) -> RetrieveComments:
"""Create Comments for OtaResRetrieveRs."""
return CommentFactory._create_comments(RetrieveComments, RetrieveComment, data)
@staticmethod
def _create_comments(comments_class: type, comment_class: type, data: CommentsData) -> Any:
def _create_comments(
comments_class: type, comment_class: type, data: CommentsData
) -> Any:
"""Internal method to create comments of the specified type."""
comments_list = []
for comment_data in data.comments:
# Create list items
@@ -411,55 +434,53 @@ class CommentFactory:
list_item = comment_class.ListItem(
value=item_data.value,
list_item=item_data.list_item,
language=item_data.language
language=item_data.language,
)
list_items.append(list_item)
# Create comment
comment = comment_class(
name=comment_data.name,
text=comment_data.text,
list_item=list_items
name=comment_data.name, text=comment_data.text, list_item=list_items
)
comments_list.append(comment)
# Create comments container
return comments_class(comment=comments_list)
@staticmethod
def from_notif_comments(comments: NotifComments) -> CommentsData:
"""Convert NotifComments back to CommentsData."""
return CommentFactory._comments_to_data(comments)
@staticmethod
def from_retrieve_comments(comments: RetrieveComments) -> CommentsData:
"""Convert RetrieveComments back to CommentsData."""
return CommentFactory._comments_to_data(comments)
@staticmethod
def _comments_to_data(comments: Any) -> CommentsData:
"""Internal method to convert any comments type to CommentsData."""
comments_data_list = []
for comment in comments.comment:
# Extract list items
list_items_data = []
if comment.list_item:
for list_item in comment.list_item:
list_items_data.append(CommentListItemData(
value=list_item.value,
list_item=list_item.list_item,
language=list_item.language
))
list_items_data.append(
CommentListItemData(
value=list_item.value,
list_item=list_item.list_item,
language=list_item.language,
)
)
# Extract comment data
comment_data = CommentData(
name=comment.name,
text=comment.text,
list_items=list_items_data
name=comment.name, text=comment.text, list_items=list_items_data
)
comments_data_list.append(comment_data)
return CommentsData(comments=comments_data_list)
@@ -529,16 +550,19 @@ class ResGuestFactory:
class AlpineBitsFactory:
"""Unified factory class for creating AlpineBits objects with a simple interface."""
@staticmethod
def create(data: Union[CustomerData, HotelReservationIdData, CommentsData], message_type: OtaMessageType) -> Any:
def create(
data: Union[CustomerData, HotelReservationIdData, CommentsData],
message_type: OtaMessageType,
) -> Any:
"""
Create an AlpineBits object based on the data type and message type.
Args:
data: The data object (CustomerData, HotelReservationIdData, CommentsData, etc.)
message_type: Whether to create for NOTIF or RETRIEVE message types
Returns:
The appropriate AlpineBits object based on the data type and message type
"""
@@ -547,31 +571,35 @@ class AlpineBitsFactory:
return CustomerFactory.create_notif_customer(data)
else:
return CustomerFactory.create_retrieve_customer(data)
elif isinstance(data, HotelReservationIdData):
if message_type == OtaMessageType.NOTIF:
return HotelReservationIdFactory.create_notif_hotel_reservation_id(data)
else:
return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(data)
return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
data
)
elif isinstance(data, CommentsData):
if message_type == OtaMessageType.NOTIF:
return CommentFactory.create_notif_comments(data)
else:
return CommentFactory.create_retrieve_comments(data)
else:
raise ValueError(f"Unsupported data type: {type(data)}")
@staticmethod
def create_res_guests(customer_data: CustomerData, message_type: OtaMessageType) -> Union[NotifResGuests, RetrieveResGuests]:
def create_res_guests(
customer_data: CustomerData, message_type: OtaMessageType
) -> Union[NotifResGuests, RetrieveResGuests]:
"""
Create a complete ResGuests structure with a primary customer.
Args:
customer_data: The customer data
message_type: Whether to create for NOTIF or RETRIEVE message types
Returns:
The appropriate ResGuests object
"""
@@ -579,43 +607,45 @@ class AlpineBitsFactory:
return ResGuestFactory.create_notif_res_guests(customer_data)
else:
return ResGuestFactory.create_retrieve_res_guests(customer_data)
@staticmethod
def extract_data(obj: Any) -> Union[CustomerData, HotelReservationIdData, CommentsData]:
def extract_data(
obj: Any,
) -> Union[CustomerData, HotelReservationIdData, CommentsData]:
"""
Extract data from an AlpineBits object back to a simple data class.
Args:
obj: The AlpineBits object to extract data from
Returns:
The appropriate data object
"""
# Check if it's a Customer object
if hasattr(obj, 'person_name') and hasattr(obj.person_name, 'given_name'):
if hasattr(obj, "person_name") and hasattr(obj.person_name, "given_name"):
if isinstance(obj, NotifCustomer):
return CustomerFactory.from_notif_customer(obj)
elif isinstance(obj, RetrieveCustomer):
return CustomerFactory.from_retrieve_customer(obj)
# Check if it's a HotelReservationId object
elif hasattr(obj, 'res_id_type'):
elif hasattr(obj, "res_id_type"):
if isinstance(obj, NotifHotelReservationId):
return HotelReservationIdFactory.from_notif_hotel_reservation_id(obj)
elif isinstance(obj, RetrieveHotelReservationId):
return HotelReservationIdFactory.from_retrieve_hotel_reservation_id(obj)
# Check if it's a Comments object
elif hasattr(obj, 'comment'):
elif hasattr(obj, "comment"):
if isinstance(obj, NotifComments):
return CommentFactory.from_notif_comments(obj)
elif isinstance(obj, RetrieveComments):
return CommentFactory.from_retrieve_comments(obj)
# Check if it's a ResGuests object
elif hasattr(obj, 'res_guest'):
elif hasattr(obj, "res_guest"):
return ResGuestFactory.extract_primary_customer(obj)
else:
raise ValueError(f"Unsupported object type: {type(obj)}")
@@ -733,70 +763,74 @@ if __name__ == "__main__":
# Verify roundtrip conversion
print("Roundtrip conversion successful:", customer_data == extracted_data)
print("\n--- Unified AlpineBitsFactory Examples ---")
# Much simpler approach - single factory with enum parameter!
print("=== Customer Creation ===")
notif_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF)
retrieve_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE)
print("Created customers using unified factory")
print("=== HotelReservationId Creation ===")
reservation_id_data = HotelReservationIdData(
res_id_type="123",
res_id_value="RESERVATION-456",
res_id_source="HOTEL_SYSTEM"
res_id_type="123", res_id_value="RESERVATION-456", res_id_source="HOTEL_SYSTEM"
)
notif_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.NOTIF)
retrieve_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.RETRIEVE)
retrieve_res_id = AlpineBitsFactory.create(
reservation_id_data, OtaMessageType.RETRIEVE
)
print("Created reservation IDs using unified factory")
print("=== Comments Creation ===")
comments_data = CommentsData(comments=[
CommentData(
name=CommentName2.CUSTOMER_COMMENT,
text="This is a customer comment about the reservation",
list_items=[
CommentListItemData(
value="Special dietary requirements: vegetarian",
list_item="1",
language="en"
),
CommentListItemData(
value="Late arrival expected",
list_item="2",
language="en"
)
]
),
CommentData(
name=CommentName2.ADDITIONAL_INFO,
text="Additional information about the stay"
)
])
comments_data = CommentsData(
comments=[
CommentData(
name=CommentName2.CUSTOMER_COMMENT,
text="This is a customer comment about the reservation",
list_items=[
CommentListItemData(
value="Special dietary requirements: vegetarian",
list_item="1",
language="en",
),
CommentListItemData(
value="Late arrival expected", list_item="2", language="en"
),
],
),
CommentData(
name=CommentName2.ADDITIONAL_INFO,
text="Additional information about the stay",
),
]
)
notif_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.NOTIF)
retrieve_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.RETRIEVE)
print("Created comments using unified factory")
print("=== ResGuests Creation ===")
notif_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE)
notif_res_guests = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.NOTIF
)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
print("Created ResGuests using unified factory")
print("=== Data Extraction ===")
# Extract data back using unified interface
extracted_customer_data = AlpineBitsFactory.extract_data(notif_customer)
extracted_res_id_data = AlpineBitsFactory.extract_data(notif_res_id)
extracted_comments_data = AlpineBitsFactory.extract_data(retrieve_comments)
extracted_from_res_guests = AlpineBitsFactory.extract_data(retrieve_res_guests)
print("Data extraction successful:")
print("- Customer roundtrip:", customer_data == extracted_customer_data)
print("- ReservationId roundtrip:", reservation_id_data == extracted_res_id_data)
print("- Comments roundtrip:", comments_data == extracted_comments_data)
print("- ResGuests roundtrip:", customer_data == extracted_from_res_guests)
print("\n--- Comparison with old approach ---")
print("Old way required multiple imports and knowing specific factory methods")
print("New way: single import, single factory, enum parameter to specify type!")

View File

@@ -1 +1 @@
"""Utility functions for alpine_bits_python."""
"""Utility functions for alpine_bits_python."""

View File

@@ -1,5 +1,6 @@
"""Entry point for util package."""
from .handshake_util import main
if __name__ == "__main__":
main()
main()

View File

@@ -2,26 +2,22 @@ from ..generated.alpinebits import OtaPingRq, OtaPingRs
from xsdata_pydantic.bindings import XmlParser
def main():
# test parsing a ping request sample
path = "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml"
path = (
"AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml"
)
with open(
path, "r", encoding="utf-8") as f:
with open(path, "r", encoding="utf-8") as f:
xml = f.read()
# Parse the XML into the request object
# Test parsing back
# Test parsing back
parser = XmlParser()
parsed_result = parser.from_string(xml, OtaPingRs)
print(parsed_result.echo_data)
@@ -34,19 +30,14 @@ def main():
print(warning.content[0])
# save json in echo_data to file with indents
output_path = "echo_data_response.json"
with open(output_path, "w", encoding="utf-8") as out_f:
import json
json.dump(json.loads(parsed_result.echo_data), out_f, indent=4)
print(f"Saved echo_data json to {output_path}")
if __name__ == "__main__":
main()
main()

View File

@@ -2,12 +2,13 @@
"""
Convenience launcher for the Wix Form Handler API
"""
import os
import subprocess
# Change to src directory
src_dir = os.path.join(os.path.dirname(__file__), 'src/alpine_bits_python')
src_dir = os.path.join(os.path.dirname(__file__), "src/alpine_bits_python")
# Run the API using uv
if __name__ == "__main__":
subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")])
subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")])

View File

@@ -5,57 +5,63 @@ discovers implemented vs unimplemented actions.
"""
from alpine_bits_python.alpinebits_server import (
ServerCapabilities,
AlpineBitsAction,
AlpineBitsActionName,
Version,
ServerCapabilities,
AlpineBitsAction,
AlpineBitsActionName,
Version,
AlpineBitsResponse,
HttpStatusCode
HttpStatusCode,
)
import asyncio
class NewImplementedAction(AlpineBitsAction):
"""A new action that IS implemented."""
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_INFO_INFO
self.version = Version.V2024_10
async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
async def handle(
self, action: str, request_xml: str, version: Version
) -> AlpineBitsResponse:
"""This action is implemented."""
return AlpineBitsResponse("Implemented!", HttpStatusCode.OK)
class NewUnimplementedAction(AlpineBitsAction):
"""A new action that is NOT implemented (no handle override)."""
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO
self.version = Version.V2024_10
# Notice: No handle method override - will use default "not implemented"
async def main():
print("🔍 Testing Action Discovery Logic")
print("=" * 50)
# Create capabilities and see what gets discovered
capabilities = ServerCapabilities()
print("📋 Actions found by discovery:")
for action_name in capabilities.get_supported_actions():
print(f"{action_name}")
print(f"\n📊 Total discovered: {len(capabilities.get_supported_actions())}")
# Test the new implemented action
implemented_action = NewImplementedAction()
result = await implemented_action.handle("test", "<xml/>", Version.V2024_10)
print(f"\n🟢 NewImplementedAction result: {result.xml_content}")
# Test the unimplemented action (should use default behavior)
unimplemented_action = NewUnimplementedAction()
result = await unimplemented_action.handle("test", "<xml/>", Version.V2024_10)
print(f"🔴 NewUnimplementedAction result: {result.xml_content}")
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -4,11 +4,11 @@ import sys
import os
# Add the src directory to the path so we can import our modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from simplified_access import (
CustomerData,
CustomerFactory,
CustomerData,
CustomerFactory,
ResGuestFactory,
HotelReservationIdData,
HotelReservationIdFactory,
@@ -20,7 +20,7 @@ from simplified_access import (
NotifResGuests,
RetrieveResGuests,
NotifHotelReservationId,
RetrieveHotelReservationId
RetrieveHotelReservationId,
)
@@ -35,7 +35,7 @@ def sample_customer_data():
phone_numbers=[
("+1234567890", PhoneTechType.MOBILE),
("+0987654321", PhoneTechType.VOICE),
("+1111111111", None)
("+1111111111", None),
],
email_address="john.doe@example.com",
email_newsletter=True,
@@ -46,17 +46,14 @@ def sample_customer_data():
address_catalog=False,
gender="Male",
birth_date="1980-01-01",
language="en"
language="en",
)
@pytest.fixture
def minimal_customer_data():
"""Fixture providing minimal customer data (only required fields)."""
return CustomerData(
given_name="Jane",
surname="Smith"
)
return CustomerData(given_name="Jane", surname="Smith")
@pytest.fixture
@@ -66,21 +63,19 @@ def sample_hotel_reservation_id_data():
res_id_type="123",
res_id_value="RESERVATION-456",
res_id_source="HOTEL_SYSTEM",
res_id_source_context="BOOKING_ENGINE"
res_id_source_context="BOOKING_ENGINE",
)
@pytest.fixture
def minimal_hotel_reservation_id_data():
"""Fixture providing minimal hotel reservation ID data (only required fields)."""
return HotelReservationIdData(
res_id_type="999"
)
return HotelReservationIdData(res_id_type="999")
class TestCustomerData:
"""Test the CustomerData dataclass."""
def test_customer_data_creation_full(self, sample_customer_data):
"""Test creating CustomerData with all fields."""
assert sample_customer_data.given_name == "John"
@@ -89,7 +84,7 @@ class TestCustomerData:
assert sample_customer_data.email_address == "john.doe@example.com"
assert sample_customer_data.email_newsletter is True
assert len(sample_customer_data.phone_numbers) == 3
def test_customer_data_creation_minimal(self, minimal_customer_data):
"""Test creating CustomerData with only required fields."""
assert minimal_customer_data.given_name == "Jane"
@@ -97,7 +92,7 @@ class TestCustomerData:
assert minimal_customer_data.phone_numbers == []
assert minimal_customer_data.email_address is None
assert minimal_customer_data.address_line is None
def test_phone_numbers_default_initialization(self):
"""Test that phone_numbers gets initialized to empty list."""
customer_data = CustomerData(given_name="Test", surname="User")
@@ -106,54 +101,56 @@ class TestCustomerData:
class TestCustomerFactory:
"""Test the CustomerFactory class."""
def test_create_notif_customer_full(self, sample_customer_data):
"""Test creating a NotifCustomer with full data."""
customer = CustomerFactory.create_notif_customer(sample_customer_data)
assert isinstance(customer, NotifCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
assert customer.person_name.name_prefix == "Mr."
assert customer.person_name.name_title == "Jr."
# Check telephone
assert len(customer.telephone) == 3
assert customer.telephone[0].phone_number == "+1234567890"
assert customer.telephone[0].phone_tech_type == "5" # MOBILE
assert customer.telephone[1].phone_tech_type == "1" # VOICE
assert customer.telephone[2].phone_tech_type is None
# Check email
assert customer.email.value == "john.doe@example.com"
assert customer.email.remark == "newsletter:yes"
# Check address
assert customer.address.address_line == "123 Main Street"
assert customer.address.city_name == "Anytown"
assert customer.address.postal_code == "12345"
assert customer.address.country_name.code == "US"
assert customer.address.remark == "catalog:no"
# Check other attributes
assert customer.gender == "Male"
assert customer.birth_date == "1980-01-01"
assert customer.language == "en"
def test_create_retrieve_customer_full(self, sample_customer_data):
"""Test creating a RetrieveCustomer with full data."""
customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
assert isinstance(customer, RetrieveCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
# Same structure as NotifCustomer, so we don't need to test all fields again
def test_create_customer_minimal(self, minimal_customer_data):
"""Test creating customers with minimal data."""
notif_customer = CustomerFactory.create_notif_customer(minimal_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(minimal_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(
minimal_customer_data
)
for customer in [notif_customer, retrieve_customer]:
assert customer.person_name.given_name == "Jane"
assert customer.person_name.surname == "Smith"
@@ -165,73 +162,97 @@ class TestCustomerFactory:
assert customer.gender is None
assert customer.birth_date is None
assert customer.language is None
def test_email_newsletter_options(self):
"""Test different email newsletter options."""
# Newsletter yes
data_yes = CustomerData(given_name="Test", surname="User",
email_address="test@example.com", email_newsletter=True)
data_yes = CustomerData(
given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=True,
)
customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.email.remark == "newsletter:yes"
# Newsletter no
data_no = CustomerData(given_name="Test", surname="User",
email_address="test@example.com", email_newsletter=False)
data_no = CustomerData(
given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=False,
)
customer = CustomerFactory.create_notif_customer(data_no)
assert customer.email.remark == "newsletter:no"
# Newsletter not specified
data_none = CustomerData(given_name="Test", surname="User",
email_address="test@example.com", email_newsletter=None)
data_none = CustomerData(
given_name="Test",
surname="User",
email_address="test@example.com",
email_newsletter=None,
)
customer = CustomerFactory.create_notif_customer(data_none)
assert customer.email.remark is None
def test_address_catalog_options(self):
"""Test different address catalog options."""
# Catalog no
data_no = CustomerData(given_name="Test", surname="User",
address_line="123 Street", address_catalog=False)
data_no = CustomerData(
given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=False,
)
customer = CustomerFactory.create_notif_customer(data_no)
assert customer.address.remark == "catalog:no"
# Catalog yes
data_yes = CustomerData(given_name="Test", surname="User",
address_line="123 Street", address_catalog=True)
data_yes = CustomerData(
given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=True,
)
customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.address.remark == "catalog:yes"
# Catalog not specified
data_none = CustomerData(given_name="Test", surname="User",
address_line="123 Street", address_catalog=None)
data_none = CustomerData(
given_name="Test",
surname="User",
address_line="123 Street",
address_catalog=None,
)
customer = CustomerFactory.create_notif_customer(data_none)
assert customer.address.remark is None
def test_from_notif_customer_roundtrip(self, sample_customer_data):
"""Test converting NotifCustomer back to CustomerData."""
customer = CustomerFactory.create_notif_customer(sample_customer_data)
converted_data = CustomerFactory.from_notif_customer(customer)
assert converted_data == sample_customer_data
def test_from_retrieve_customer_roundtrip(self, sample_customer_data):
"""Test converting RetrieveCustomer back to CustomerData."""
customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
converted_data = CustomerFactory.from_retrieve_customer(customer)
assert converted_data == sample_customer_data
def test_phone_tech_type_conversion(self):
"""Test that PhoneTechType enum values are properly converted."""
data = CustomerData(
given_name="Test",
given_name="Test",
surname="User",
phone_numbers=[
("+1111111111", PhoneTechType.VOICE),
("+2222222222", PhoneTechType.FAX),
("+3333333333", PhoneTechType.MOBILE)
]
("+3333333333", PhoneTechType.MOBILE),
],
)
customer = CustomerFactory.create_notif_customer(data)
assert customer.telephone[0].phone_tech_type == "1" # VOICE
assert customer.telephone[1].phone_tech_type == "3" # FAX
@@ -240,15 +261,21 @@ class TestCustomerFactory:
class TestHotelReservationIdData:
"""Test the HotelReservationIdData dataclass."""
def test_hotel_reservation_id_data_creation_full(self, sample_hotel_reservation_id_data):
def test_hotel_reservation_id_data_creation_full(
self, sample_hotel_reservation_id_data
):
"""Test creating HotelReservationIdData with all fields."""
assert sample_hotel_reservation_id_data.res_id_type == "123"
assert sample_hotel_reservation_id_data.res_id_value == "RESERVATION-456"
assert sample_hotel_reservation_id_data.res_id_source == "HOTEL_SYSTEM"
assert sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE"
def test_hotel_reservation_id_data_creation_minimal(self, minimal_hotel_reservation_id_data):
assert (
sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE"
)
def test_hotel_reservation_id_data_creation_minimal(
self, minimal_hotel_reservation_id_data
):
"""Test creating HotelReservationIdData with only required fields."""
assert minimal_hotel_reservation_id_data.res_id_type == "999"
assert minimal_hotel_reservation_id_data.res_id_value is None
@@ -258,124 +285,158 @@ class TestHotelReservationIdData:
class TestHotelReservationIdFactory:
"""Test the HotelReservationIdFactory class."""
def test_create_notif_hotel_reservation_id_full(self, sample_hotel_reservation_id_data):
def test_create_notif_hotel_reservation_id_full(
self, sample_hotel_reservation_id_data
):
"""Test creating a NotifHotelReservationId with full data."""
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data)
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
sample_hotel_reservation_id_data
)
assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
def test_create_retrieve_hotel_reservation_id_full(self, sample_hotel_reservation_id_data):
def test_create_retrieve_hotel_reservation_id_full(
self, sample_hotel_reservation_id_data
):
"""Test creating a RetrieveHotelReservationId with full data."""
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data)
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
sample_hotel_reservation_id_data
)
assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
def test_create_hotel_reservation_id_minimal(self, minimal_hotel_reservation_id_data):
def test_create_hotel_reservation_id_minimal(
self, minimal_hotel_reservation_id_data
):
"""Test creating hotel reservation IDs with minimal data."""
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(minimal_hotel_reservation_id_data)
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(minimal_hotel_reservation_id_data)
notif_reservation_id = (
HotelReservationIdFactory.create_notif_hotel_reservation_id(
minimal_hotel_reservation_id_data
)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
minimal_hotel_reservation_id_data
)
)
for reservation_id in [notif_reservation_id, retrieve_reservation_id]:
assert reservation_id.res_id_type == "999"
assert reservation_id.res_id_value is None
assert reservation_id.res_id_source is None
assert reservation_id.res_id_source_context is None
def test_from_notif_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data):
def test_from_notif_hotel_reservation_id_roundtrip(
self, sample_hotel_reservation_id_data
):
"""Test converting NotifHotelReservationId back to HotelReservationIdData."""
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data)
converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(reservation_id)
reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
sample_hotel_reservation_id_data
)
converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(
reservation_id
)
assert converted_data == sample_hotel_reservation_id_data
def test_from_retrieve_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data):
def test_from_retrieve_hotel_reservation_id_roundtrip(
self, sample_hotel_reservation_id_data
):
"""Test converting RetrieveHotelReservationId back to HotelReservationIdData."""
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data)
converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(reservation_id)
reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
sample_hotel_reservation_id_data
)
converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
reservation_id
)
assert converted_data == sample_hotel_reservation_id_data
class TestResGuestFactory:
"""Test the ResGuestFactory class."""
def test_create_notif_res_guests(self, sample_customer_data):
"""Test creating NotifResGuests structure."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
assert isinstance(res_guests, NotifResGuests)
# Navigate down the nested structure
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
assert customer.email.value == "john.doe@example.com"
def test_create_retrieve_res_guests(self, sample_customer_data):
"""Test creating RetrieveResGuests structure."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
assert isinstance(res_guests, RetrieveResGuests)
# Navigate down the nested structure
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
assert customer.email.value == "john.doe@example.com"
def test_create_res_guests_minimal(self, minimal_customer_data):
"""Test creating ResGuests with minimal customer data."""
notif_res_guests = ResGuestFactory.create_notif_res_guests(minimal_customer_data)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(minimal_customer_data)
notif_res_guests = ResGuestFactory.create_notif_res_guests(
minimal_customer_data
)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
minimal_customer_data
)
for res_guests in [notif_res_guests, retrieve_res_guests]:
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "Jane"
assert customer.person_name.surname == "Smith"
assert customer.email is None
assert customer.address is None
def test_extract_primary_customer_notif(self, sample_customer_data):
"""Test extracting primary customer from NotifResGuests."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data
def test_extract_primary_customer_retrieve(self, sample_customer_data):
"""Test extracting primary customer from RetrieveResGuests."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data
def test_roundtrip_conversion_notif(self, sample_customer_data):
"""Test complete roundtrip: CustomerData -> NotifResGuests -> CustomerData."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data
def test_roundtrip_conversion_retrieve(self, sample_customer_data):
"""Test complete roundtrip: CustomerData -> RetrieveResGuests -> CustomerData."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
assert extracted_data == sample_customer_data
class TestPhoneTechType:
"""Test the PhoneTechType enum."""
def test_enum_values(self):
"""Test that enum values are correct."""
assert PhoneTechType.VOICE.value == "1"
@@ -385,95 +446,121 @@ class TestPhoneTechType:
class TestAlpineBitsFactory:
"""Test the unified AlpineBitsFactory class."""
def test_create_customer_notif(self, sample_customer_data):
"""Test creating customer using unified factory for NOTIF."""
customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF)
assert isinstance(customer, NotifCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
def test_create_customer_retrieve(self, sample_customer_data):
"""Test creating customer using unified factory for RETRIEVE."""
customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE)
customer = AlpineBitsFactory.create(
sample_customer_data, OtaMessageType.RETRIEVE
)
assert isinstance(customer, RetrieveCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
def test_create_hotel_reservation_id_notif(self, sample_hotel_reservation_id_data):
"""Test creating hotel reservation ID using unified factory for NOTIF."""
reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF)
reservation_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.NOTIF
)
assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
def test_create_hotel_reservation_id_retrieve(self, sample_hotel_reservation_id_data):
def test_create_hotel_reservation_id_retrieve(
self, sample_hotel_reservation_id_data
):
"""Test creating hotel reservation ID using unified factory for RETRIEVE."""
reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE)
reservation_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
)
assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
def test_create_res_guests_notif(self, sample_customer_data):
"""Test creating ResGuests using unified factory for NOTIF."""
res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF)
res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.NOTIF
)
assert isinstance(res_guests, NotifResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
def test_create_res_guests_retrieve(self, sample_customer_data):
"""Test creating ResGuests using unified factory for RETRIEVE."""
res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE)
res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.RETRIEVE
)
assert isinstance(res_guests, RetrieveResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
def test_extract_data_from_customer(self, sample_customer_data):
"""Test extracting data from customer objects."""
# Create both types and extract data back
notif_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF)
retrieve_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE)
notif_customer = AlpineBitsFactory.create(
sample_customer_data, OtaMessageType.NOTIF
)
retrieve_customer = AlpineBitsFactory.create(
sample_customer_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_customer)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_customer)
assert notif_extracted == sample_customer_data
assert retrieve_extracted == sample_customer_data
def test_extract_data_from_hotel_reservation_id(self, sample_hotel_reservation_id_data):
def test_extract_data_from_hotel_reservation_id(
self, sample_hotel_reservation_id_data
):
"""Test extracting data from hotel reservation ID objects."""
# Create both types and extract data back
notif_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF)
retrieve_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE)
notif_res_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.NOTIF
)
retrieve_res_id = AlpineBitsFactory.create(
sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_res_id)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_id)
assert notif_extracted == sample_hotel_reservation_id_data
assert retrieve_extracted == sample_hotel_reservation_id_data
def test_extract_data_from_res_guests(self, sample_customer_data):
"""Test extracting data from ResGuests objects."""
# Create both types and extract data back
notif_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE)
notif_res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.NOTIF
)
retrieve_res_guests = AlpineBitsFactory.create_res_guests(
sample_customer_data, OtaMessageType.RETRIEVE
)
notif_extracted = AlpineBitsFactory.extract_data(notif_res_guests)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_guests)
assert notif_extracted == sample_customer_data
assert retrieve_extracted == sample_customer_data
def test_unsupported_data_type_error(self):
"""Test that unsupported data types raise ValueError."""
with pytest.raises(ValueError, match="Unsupported data type"):
AlpineBitsFactory.create("invalid_data", OtaMessageType.NOTIF)
def test_unsupported_object_type_error(self):
"""Test that unsupported object types raise ValueError in extract_data."""
with pytest.raises(ValueError, match="Unsupported object type"):
AlpineBitsFactory.extract_data("invalid_object")
def test_complete_workflow_with_unified_factory(self):
"""Test a complete workflow using only the unified factory."""
# Original data
@@ -481,34 +568,47 @@ class TestAlpineBitsFactory:
given_name="Unified",
surname="Factory",
email_address="unified@factory.com",
phone_numbers=[("+1234567890", PhoneTechType.MOBILE)]
phone_numbers=[("+1234567890", PhoneTechType.MOBILE)],
)
reservation_data = HotelReservationIdData(
res_id_type="999",
res_id_value="UNIFIED-TEST"
res_id_type="999", res_id_value="UNIFIED-TEST"
)
# Create using unified factory
customer_notif = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF)
customer_retrieve = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE)
customer_retrieve = AlpineBitsFactory.create(
customer_data, OtaMessageType.RETRIEVE
)
res_id_notif = AlpineBitsFactory.create(reservation_data, OtaMessageType.NOTIF)
res_id_retrieve = AlpineBitsFactory.create(reservation_data, OtaMessageType.RETRIEVE)
res_guests_notif = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF)
res_guests_retrieve = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE)
res_id_retrieve = AlpineBitsFactory.create(
reservation_data, OtaMessageType.RETRIEVE
)
res_guests_notif = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.NOTIF
)
res_guests_retrieve = AlpineBitsFactory.create_res_guests(
customer_data, OtaMessageType.RETRIEVE
)
# Extract everything back
extracted_customer_from_notif = AlpineBitsFactory.extract_data(customer_notif)
extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(customer_retrieve)
extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(
customer_retrieve
)
extracted_res_id_from_notif = AlpineBitsFactory.extract_data(res_id_notif)
extracted_res_id_from_retrieve = AlpineBitsFactory.extract_data(res_id_retrieve)
extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(res_guests_notif)
extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(res_guests_retrieve)
extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(
res_guests_notif
)
extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(
res_guests_retrieve
)
# Verify everything matches
assert extracted_customer_from_notif == customer_data
assert extracted_customer_from_retrieve == customer_data
@@ -520,37 +620,72 @@ class TestAlpineBitsFactory:
class TestIntegration:
"""Integration tests combining both factories."""
def test_both_factories_produce_same_customer_data(self, sample_customer_data):
"""Test that both factories can work with the same customer data."""
# Create using CustomerFactory
notif_customer = CustomerFactory.create_notif_customer(sample_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
retrieve_customer = CustomerFactory.create_retrieve_customer(
sample_customer_data
)
# Create using ResGuestFactory and extract customers
notif_res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
notif_from_res_guests = notif_res_guests.res_guest.profiles.profile_info.profile.customer
retrieve_from_res_guests = retrieve_res_guests.res_guest.profiles.profile_info.profile.customer
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
sample_customer_data
)
notif_from_res_guests = (
notif_res_guests.res_guest.profiles.profile_info.profile.customer
)
retrieve_from_res_guests = (
retrieve_res_guests.res_guest.profiles.profile_info.profile.customer
)
# Compare customer names (structure should be identical)
assert notif_customer.person_name.given_name == notif_from_res_guests.person_name.given_name
assert notif_customer.person_name.surname == notif_from_res_guests.person_name.surname
assert retrieve_customer.person_name.given_name == retrieve_from_res_guests.person_name.given_name
assert retrieve_customer.person_name.surname == retrieve_from_res_guests.person_name.surname
def test_hotel_reservation_id_factories_produce_same_data(self, sample_hotel_reservation_id_data):
assert (
notif_customer.person_name.given_name
== notif_from_res_guests.person_name.given_name
)
assert (
notif_customer.person_name.surname
== notif_from_res_guests.person_name.surname
)
assert (
retrieve_customer.person_name.given_name
== retrieve_from_res_guests.person_name.given_name
)
assert (
retrieve_customer.person_name.surname
== retrieve_from_res_guests.person_name.surname
)
def test_hotel_reservation_id_factories_produce_same_data(
self, sample_hotel_reservation_id_data
):
"""Test that both HotelReservationId factories produce equivalent results."""
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data)
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data)
notif_reservation_id = (
HotelReservationIdFactory.create_notif_hotel_reservation_id(
sample_hotel_reservation_id_data
)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
sample_hotel_reservation_id_data
)
)
# Both should have the same field values
assert notif_reservation_id.res_id_type == retrieve_reservation_id.res_id_type
assert notif_reservation_id.res_id_value == retrieve_reservation_id.res_id_value
assert notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source
assert notif_reservation_id.res_id_source_context == retrieve_reservation_id.res_id_source_context
assert (
notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source
)
assert (
notif_reservation_id.res_id_source_context
== retrieve_reservation_id.res_id_source_context
)
def test_complex_customer_workflow(self):
"""Test a complex workflow with multiple operations."""
# Create original data
@@ -559,7 +694,7 @@ class TestIntegration:
surname="Johnson",
phone_numbers=[
("+1555123456", PhoneTechType.MOBILE),
("+1555654321", PhoneTechType.VOICE)
("+1555654321", PhoneTechType.VOICE),
],
email_address="alice.johnson@company.com",
email_newsletter=False,
@@ -569,22 +704,24 @@ class TestIntegration:
country_code="CA",
address_catalog=True,
gender="Female",
language="fr"
language="fr",
)
# Create ResGuests for both types
notif_res_guests = ResGuestFactory.create_notif_res_guests(original_data)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(original_data)
# Extract data back from both
notif_extracted = ResGuestFactory.extract_primary_customer(notif_res_guests)
retrieve_extracted = ResGuestFactory.extract_primary_customer(retrieve_res_guests)
retrieve_extracted = ResGuestFactory.extract_primary_customer(
retrieve_res_guests
)
# All should be equal
assert original_data == notif_extracted
assert original_data == retrieve_extracted
assert notif_extracted == retrieve_extracted
def test_complex_hotel_reservation_id_workflow(self):
"""Test a complex workflow with HotelReservationId operations."""
# Create original reservation ID data
@@ -592,18 +729,30 @@ class TestIntegration:
res_id_type="456",
res_id_value="COMPLEX-RESERVATION-789",
res_id_source="INTEGRATION_SYSTEM",
res_id_source_context="API_CALL"
res_id_source_context="API_CALL",
)
# Create HotelReservationId for both types
notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data)
retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(original_data)
notif_reservation_id = (
HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data)
)
retrieve_reservation_id = (
HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
original_data
)
)
# Extract data back from both
notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(notif_reservation_id)
retrieve_extracted = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(retrieve_reservation_id)
notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(
notif_reservation_id
)
retrieve_extracted = (
HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
retrieve_reservation_id
)
)
# All should be equal
assert original_data == notif_extracted
assert original_data == retrieve_extracted
assert notif_extracted == retrieve_extracted
assert notif_extracted == retrieve_extracted

View File

@@ -6,24 +6,31 @@ Test the handshake functionality with the real AlpineBits sample file.
import asyncio
from alpine_bits_python.alpinebits_server import AlpineBitsServer
async def main():
print("🔄 Testing AlpineBits Handshake with Sample File")
print("=" * 60)
# Create server instance
server = AlpineBitsServer()
# Read the sample handshake request
with open("AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml", "r") as f:
with open(
"AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml",
"r",
) as f:
ping_request_xml = f.read()
print("📤 Sending handshake request...")
# Handle the ping request
response = await server.handle_request("OTA_Ping:Handshaking", ping_request_xml, "2024-10")
response = await server.handle_request(
"OTA_Ping:Handshaking", ping_request_xml, "2024-10"
)
print(f"\n📥 Response Status: {response.status_code}")
print(f"📄 Response XML:\n{response.xml_content}")
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())