diff --git a/src/alpine_bits_python/__main__.py b/src/alpine_bits_python/__main__.py
index 9a948c0..a9e3250 100644
--- a/src/alpine_bits_python/__main__.py
+++ b/src/alpine_bits_python/__main__.py
@@ -1,6 +1,7 @@
"""Entry point for alpine_bits_python package."""
+
from .main import main
if __name__ == "__main__":
print("running test main")
- main()
\ No newline at end of file
+ main()
diff --git a/src/alpine_bits_python/alpinebits_server.py b/src/alpine_bits_python/alpinebits_server.py
index 63e27c9..a6bcc41 100644
--- a/src/alpine_bits_python/alpinebits_server.py
+++ b/src/alpine_bits_python/alpinebits_server.py
@@ -23,49 +23,65 @@ from xsdata_pydantic.bindings import XmlParser
class HttpStatusCode(IntEnum):
"""Allowed HTTP status codes for AlpineBits responses."""
+
OK = 200
BAD_REQUEST = 400
UNAUTHORIZED = 401
INTERNAL_SERVER_ERROR = 500
-
class AlpineBitsActionName(Enum):
"""Enum for AlpineBits action names with capability and request name mappings."""
-
+
# Format: (capability_name, actual_request_name)
OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking")
OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests")
OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif")
- OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = ("action_OTA_HotelResNotif_GuestRequests",
- "OTA_HotelResNotif:GuestRequests")
- OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = ("action_OTA_HotelDescriptiveContentNotif_Inventory",
- "OTA_HotelDescriptiveContentNotif:Inventory")
- OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = ("action_OTA_HotelDescriptiveContentNotif_Info",
- "OTA_HotelDescriptiveContentNotif:Info")
- OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = ("action_OTA_HotelDescriptiveInfo_Inventory",
- "OTA_HotelDescriptiveInfo:Inventory")
- OTA_HOTEL_DESCRIPTIVE_INFO_INFO = ("action_OTA_HotelDescriptiveInfo_Info",
- "OTA_HotelDescriptiveInfo:Info")
- OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = ("action_OTA_HotelRatePlanNotif_RatePlans",
- "OTA_HotelRatePlanNotif:RatePlans")
- OTA_HOTEL_RATE_PLAN_BASE_RATES = ("action_OTA_HotelRatePlan_BaseRates",
- "OTA_HotelRatePlan:BaseRates")
-
+ OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = (
+ "action_OTA_HotelResNotif_GuestRequests",
+ "OTA_HotelResNotif:GuestRequests",
+ )
+ OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = (
+ "action_OTA_HotelDescriptiveContentNotif_Inventory",
+ "OTA_HotelDescriptiveContentNotif:Inventory",
+ )
+ OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = (
+ "action_OTA_HotelDescriptiveContentNotif_Info",
+ "OTA_HotelDescriptiveContentNotif:Info",
+ )
+ OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = (
+ "action_OTA_HotelDescriptiveInfo_Inventory",
+ "OTA_HotelDescriptiveInfo:Inventory",
+ )
+ OTA_HOTEL_DESCRIPTIVE_INFO_INFO = (
+ "action_OTA_HotelDescriptiveInfo_Info",
+ "OTA_HotelDescriptiveInfo:Info",
+ )
+ OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = (
+ "action_OTA_HotelRatePlanNotif_RatePlans",
+ "OTA_HotelRatePlanNotif:RatePlans",
+ )
+ OTA_HOTEL_RATE_PLAN_BASE_RATES = (
+ "action_OTA_HotelRatePlan_BaseRates",
+ "OTA_HotelRatePlan:BaseRates",
+ )
+
def __init__(self, capability_name: str, request_name: str):
self.capability_name = capability_name
self.request_name = request_name
-
+
@classmethod
- def get_by_capability_name(cls, capability_name: str) -> Optional['AlpineBitsActionName']:
+ def get_by_capability_name(
+ cls, capability_name: str
+ ) -> Optional["AlpineBitsActionName"]:
"""Get action enum by capability name."""
for action in cls:
if action.capability_name == capability_name:
return action
return None
-
+
@classmethod
- def get_by_request_name(cls, request_name: str) -> Optional['AlpineBitsActionName']:
+ def get_by_request_name(cls, request_name: str) -> Optional["AlpineBitsActionName"]:
"""Get action enum by request name."""
for action in cls:
if action.request_name == request_name:
@@ -75,22 +91,25 @@ class AlpineBitsActionName(Enum):
class Version(str, Enum):
"""Enum for AlpineBits versions."""
+
V2024_10 = "2024-10"
V2022_10 = "2022-10"
# Add other versions as needed
-
@dataclass
class AlpineBitsResponse:
"""Response data structure for AlpineBits actions."""
+
xml_content: str
status_code: HttpStatusCode = HttpStatusCode.OK
-
+
def __post_init__(self):
"""Validate that status code is one of the allowed values."""
if self.status_code not in [200, 400, 401, 500]:
- raise ValueError(f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500")
+ raise ValueError(
+ f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500"
+ )
# Abstract base class for AlpineBits Action
@@ -98,20 +117,24 @@ class AlpineBitsAction(ABC):
"""Abstract base class for handling AlpineBits actions."""
name: AlpineBitsActionName
- version: Version | list[Version] # list of versions in case action supports multiple versions
-
- async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
+ version: (
+ Version | list[Version]
+ ) # list of versions in case action supports multiple versions
+
+ async def handle(
+ self, action: str, request_xml: str, version: Version
+ ) -> AlpineBitsResponse:
"""
Handle the incoming request XML and return response XML.
-
+
Default implementation returns "not implemented" error.
Override this method in subclasses to provide actual functionality.
-
+
Args:
action: The action to perform (e.g., "OTA_PingRQ")
request_xml: The XML request body as string
version: The AlpineBits version
-
+
Returns:
AlpineBitsResponse with error or actual response
"""
@@ -121,7 +144,7 @@ class AlpineBitsAction(ABC):
async def check_version_supported(self, version: Version) -> bool:
"""
Check if the action supports the given version.
-
+
Args:
version: The AlpineBits version to check
Returns:
@@ -130,103 +153,93 @@ class AlpineBitsAction(ABC):
if isinstance(self.version, list):
return version in self.version
return version == self.version
-
-
-
-
class ServerCapabilities:
"""
Automatically discovers AlpineBitsAction implementations and generates capabilities.
"""
-
+
def __init__(self):
self.action_registry: Dict[str, Type[AlpineBitsAction]] = {}
self._discover_actions()
self.capability_dict = None
-
+
def _discover_actions(self):
"""Discover all AlpineBitsAction implementations in the current module."""
current_module = inspect.getmodule(self)
-
+
for name, obj in inspect.getmembers(current_module):
- if (inspect.isclass(obj) and
- issubclass(obj, AlpineBitsAction) and
- obj != AlpineBitsAction):
-
+ if (
+ inspect.isclass(obj)
+ and issubclass(obj, AlpineBitsAction)
+ and obj != AlpineBitsAction
+ ):
# Check if this action is actually implemented (not just returning default)
if self._is_action_implemented(obj):
action_instance = obj()
- if hasattr(action_instance, 'name'):
+ if hasattr(action_instance, "name"):
# Use capability name for the registry key
self.action_registry[action_instance.name.capability_name] = obj
-
+
def _is_action_implemented(self, action_class: Type[AlpineBitsAction]) -> bool:
"""
Check if an action is actually implemented or just uses the default behavior.
This is a simple check - in practice, you might want more sophisticated detection.
"""
# Check if the class has overridden the handle method
- if 'handle' in action_class.__dict__:
+ if "handle" in action_class.__dict__:
return True
return False
-
def create_capabilities_dict(self) -> None:
"""
Generate the capabilities dictionary based on discovered actions.
-
+
"""
versions_dict = {}
-
+
for action_name, action_class in self.action_registry.items():
action_instance = action_class()
-
+
# Get supported versions for this action
if isinstance(action_instance.version, list):
supported_versions = action_instance.version
else:
supported_versions = [action_instance.version]
-
+
# Add action to each supported version
for version in supported_versions:
version_str = version.value
-
+
if version_str not in versions_dict:
- versions_dict[version_str] = {
- "version": version_str,
- "actions": []
- }
-
+ versions_dict[version_str] = {"version": version_str, "actions": []}
+
action_dict = {"action": action_name}
-
+
# Add supports field if the action has custom supports
- if hasattr(action_instance, 'supports') and action_instance.supports:
+ if hasattr(action_instance, "supports") and action_instance.supports:
action_dict["supports"] = action_instance.supports
-
+
versions_dict[version_str]["actions"].append(action_dict)
self.capability_dict = {"versions": list(versions_dict.values())}
-
-
return None
-
def get_capabilities_dict(self) -> Dict:
"""
Get capabilities as a dictionary. Generates if not already created.
"""
-
+
if self.capability_dict is None:
self.create_capabilities_dict()
return self.capability_dict
-
+
def get_capabilities_json(self) -> str:
"""Get capabilities as formatted JSON string."""
return json.dumps(self.get_capabilities_dict(), indent=2)
-
+
def get_supported_actions(self) -> List[str]:
"""Get list of all supported action names."""
return list(self.action_registry.keys())
@@ -234,22 +247,35 @@ class ServerCapabilities:
# Sample Action Implementations for demonstration
+
class PingAction(AlpineBitsAction):
"""Implementation for OTA_Ping action (handshaking)."""
-
+
def __init__(self):
self.name = AlpineBitsActionName.OTA_PING
- self.version = [Version.V2024_10, Version.V2022_10] # Supports multiple versions
-
- async def handle(self, action: str, request_xml: str, version: Version, server_capabilities: None | ServerCapabilities = None) -> AlpineBitsResponse:
+ self.version = [
+ Version.V2024_10,
+ Version.V2022_10,
+ ] # Supports multiple versions
+
+ async def handle(
+ self,
+ action: str,
+ request_xml: str,
+ version: Version,
+ server_capabilities: None | ServerCapabilities = None,
+ ) -> AlpineBitsResponse:
"""Handle ping requests."""
if request_xml is None:
- return AlpineBitsResponse(f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST)
+ return AlpineBitsResponse(
+ f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST
+ )
if server_capabilities is None:
- return AlpineBitsResponse("Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR)
-
+ return AlpineBitsResponse(
+ "Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR
+ )
# Parse the incoming request XML and extract EchoData
parser = XmlParser()
@@ -259,54 +285,66 @@ class PingAction(AlpineBitsAction):
echo_data = json.loads(parsed_request.echo_data)
except Exception as e:
- return AlpineBitsResponse(f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST)
+ return AlpineBitsResponse(
+ f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST
+ )
# compare echo data with capabilities, create a dictionary containing the matching capabilities
capabilities_dict = server_capabilities.get_capabilities_dict()
matching_capabilities = {"versions": []}
-
+
# Iterate through client's requested versions
for client_version in echo_data.get("versions", []):
client_version_str = client_version.get("version", "")
-
+
# Find matching server version
for server_version in capabilities_dict["versions"]:
if server_version["version"] == client_version_str:
# Found a matching version, now find common actions
- matching_version = {
- "version": client_version_str,
- "actions": []
- }
-
+ matching_version = {"version": client_version_str, "actions": []}
+
# Get client's requested actions for this version
- client_actions = {action.get("action", ""): action for action in client_version.get("actions", [])}
- server_actions = {action.get("action", ""): action for action in server_version.get("actions", [])}
-
+ client_actions = {
+ action.get("action", ""): action
+ for action in client_version.get("actions", [])
+ }
+ server_actions = {
+ action.get("action", ""): action
+ for action in server_version.get("actions", [])
+ }
+
# Find common actions
for action_name in client_actions:
if action_name in server_actions:
# Use server's action definition (includes our supports)
- matching_version["actions"].append(server_actions[action_name])
-
+ matching_version["actions"].append(
+ server_actions[action_name]
+ )
+
# Only add version if there are common actions
if matching_version["actions"]:
matching_capabilities["versions"].append(matching_version)
break
-
+
# Debug print to see what we matched
-
+
# Create successful ping response with matched capabilities
capabilities_json = json.dumps(matching_capabilities, indent=2)
- warning = OtaPingRs.Warnings.Warning(status=WarningStatus.ALPINEBITS_HANDSHAKE.value, type_value="11", content=[capabilities_json])
+ warning = OtaPingRs.Warnings.Warning(
+ status=WarningStatus.ALPINEBITS_HANDSHAKE.value,
+ type_value="11",
+ content=[capabilities_json],
+ )
warning_response = OtaPingRs.Warnings(warning=[warning])
- response_ota_ping = OtaPingRs(version= "7.000", warnings=warning_response, echo_data=capabilities_json, success="")
-
-
-
-
+ response_ota_ping = OtaPingRs(
+ version="7.000",
+ warnings=warning_response,
+ echo_data=capabilities_json,
+ success="",
+ )
config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8"
@@ -314,34 +352,35 @@ class PingAction(AlpineBitsAction):
serializer = XmlSerializer(config=config)
- response_xml = serializer.render(response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"})
-
-
-
+ response_xml = serializer.render(
+ response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}
+ )
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class ReadAction(AlpineBitsAction):
"""Implementation for OTA_Read action."""
-
+
def __init__(self):
self.name = AlpineBitsActionName.OTA_READ
self.version = [Version.V2024_10, Version.V2022_10]
-
- async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
+
+ async def handle(
+ self, action: str, request_xml: str, version: Version
+ ) -> AlpineBitsResponse:
"""Handle read requests."""
- response_xml = f'''
+ response_xml = f"""
Read operation successful for {version.value}
-'''
+"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class HotelAvailNotifAction(AlpineBitsAction):
"""Implementation for Hotel Availability Notification action with supports."""
-
+
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_AVAIL_NOTIF
self.version = Version.V2022_10
@@ -349,68 +388,68 @@ class HotelAvailNotifAction(AlpineBitsAction):
"OTA_HotelAvailNotif_accept_rooms",
"OTA_HotelAvailNotif_accept_categories",
"OTA_HotelAvailNotif_accept_deltas",
- "OTA_HotelAvailNotif_accept_BookingThreshold"
+ "OTA_HotelAvailNotif_accept_BookingThreshold",
]
-
- async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
+
+ async def handle(
+ self, action: str, request_xml: str, version: Version
+ ) -> AlpineBitsResponse:
"""Handle hotel availability notifications."""
- response_xml = '''
+ response_xml = """
-'''
+"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class GuestRequestsAction(AlpineBitsAction):
"""Unimplemented action - will not appear in capabilities."""
-
+
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS
self.version = Version.V2024_10
-
+
# Note: This class doesn't override the handle method, so it won't be discovered
-
-
-
-
class AlpineBitsServer:
"""
Asynchronous AlpineBits server for handling hotel data exchange requests.
-
+
This server handles various OTA actions and implements the AlpineBits protocol
for hotel data exchange. It maintains a registry of supported actions and
their capabilities, and can respond to handshake requests with its capabilities.
"""
-
+
def __init__(self):
self.capabilities = ServerCapabilities()
self._action_instances = {}
self._initialize_action_instances()
-
+
def _initialize_action_instances(self):
"""Initialize instances of all discovered action classes."""
for capability_name, action_class in self.capabilities.action_registry.items():
self._action_instances[capability_name] = action_class()
-
+
def get_capabilities(self) -> Dict:
"""Get server capabilities."""
return self.capabilities.get_capabilities_dict()
-
+
def get_capabilities_json(self) -> str:
"""Get server capabilities as JSON."""
return self.capabilities.get_capabilities_json()
-
- async def handle_request(self, request_action_name: str, request_xml: str, version: str = "2024-10") -> AlpineBitsResponse:
+
+ async def handle_request(
+ self, request_action_name: str, request_xml: str, version: str = "2024-10"
+ ) -> AlpineBitsResponse:
"""
Handle an incoming AlpineBits request by routing to appropriate action handler.
-
+
Args:
request_action_name: The action name from the request (e.g., "OTA_Read:GuestRequests")
request_xml: The XML request body
version: The AlpineBits version (defaults to "2024-10")
-
+
Returns:
AlpineBitsResponse with the result
"""
@@ -419,52 +458,56 @@ class AlpineBitsServer:
version_enum = Version(version)
except ValueError:
return AlpineBitsResponse(
- f"Error: Unsupported version {version}",
- HttpStatusCode.BAD_REQUEST
+ f"Error: Unsupported version {version}", HttpStatusCode.BAD_REQUEST
)
-
+
# Find the action by request name
action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
if not action_enum:
return AlpineBitsResponse(
f"Error: Unknown action {request_action_name}",
- HttpStatusCode.BAD_REQUEST
+ HttpStatusCode.BAD_REQUEST,
)
-
+
# Check if we have an implementation for this action
capability_name = action_enum.capability_name
if capability_name not in self._action_instances:
return AlpineBitsResponse(
f"Error: Action {request_action_name} is not implemented",
- HttpStatusCode.BAD_REQUEST
+ HttpStatusCode.BAD_REQUEST,
)
-
+
action_instance: AlpineBitsAction = self._action_instances[capability_name]
-
+
# Check if the action supports the requested version
if not await action_instance.check_version_supported(version_enum):
return AlpineBitsResponse(
f"Error: Action {request_action_name} does not support version {version}",
- HttpStatusCode.BAD_REQUEST
+ HttpStatusCode.BAD_REQUEST,
)
-
+
# Handle the request
try:
# Special case for ping action - pass server capabilities
if capability_name == "action_OTA_Ping":
- return await action_instance.handle(request_action_name, request_xml, version_enum, self.capabilities)
+ return await action_instance.handle(
+ request_action_name, request_xml, version_enum, self.capabilities
+ )
else:
- return await action_instance.handle(request_action_name, request_xml, version_enum)
+ return await action_instance.handle(
+ request_action_name, request_xml, version_enum
+ )
except Exception as e:
print(f"Error handling request {request_action_name}: {str(e)}")
# print stack trace for debugging
import traceback
+
traceback.print_exc()
return AlpineBitsResponse(
f"Error: Internal server error while processing {request_action_name}: {str(e)}",
- HttpStatusCode.INTERNAL_SERVER_ERROR
+ HttpStatusCode.INTERNAL_SERVER_ERROR,
)
-
+
def get_supported_request_names(self) -> List[str]:
"""Get all supported request names (not capability names)."""
request_names = []
@@ -473,26 +516,28 @@ class AlpineBitsServer:
if action_enum:
request_names.append(action_enum.request_name)
return sorted(request_names)
-
- def is_action_supported(self, request_action_name: str, version: str = None) -> bool:
+
+ def is_action_supported(
+ self, request_action_name: str, version: str = None
+ ) -> bool:
"""
Check if a request action is supported.
-
+
Args:
request_action_name: The request action name (e.g., "OTA_Read:GuestRequests")
version: Optional version to check
-
+
Returns:
True if supported, False otherwise
"""
action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
if not action_enum:
return False
-
+
capability_name = action_enum.capability_name
if capability_name not in self._action_instances:
return False
-
+
if version:
try:
version_enum = Version(version)
@@ -504,7 +549,7 @@ class AlpineBitsServer:
return action_instance.version == version_enum
except ValueError:
return False
-
+
return True
@@ -512,10 +557,10 @@ async def main():
"""Demonstrate the automatic capabilities discovery and request handling."""
print("๐ AlpineBits Server Capabilities Discovery & Request Handling Demo")
print("=" * 70)
-
+
# Create server instance
server = AlpineBitsServer()
-
+
print("\n๐ Discovered Action Classes:")
print("-" * 30)
for capability_name, action_class in server.capabilities.action_registry.items():
@@ -523,24 +568,26 @@ async def main():
request_name = action_enum.request_name if action_enum else "unknown"
print(f"โ
{capability_name} -> {action_class.__name__}")
print(f" Request name: {request_name}")
-
- print(f"\n๐ Total Implemented Actions: {len(server.capabilities.get_supported_actions())}")
-
+
+ print(
+ f"\n๐ Total Implemented Actions: {len(server.capabilities.get_supported_actions())}"
+ )
+
print("\n๐ Generated Capabilities JSON:")
print("-" * 30)
capabilities_json = server.get_capabilities_json()
print(capabilities_json)
-
+
print("\n๐ฏ Supported Request Names:")
print("-" * 30)
for request_name in server.get_supported_request_names():
print(f" โข {request_name}")
-
+
print("\n๐งช Testing Request Handling:")
print("-" * 30)
-
+
test_xml = "sample request"
-
+
# Test different request formats
test_cases = [
("OTA_Ping:Handshaking", "2024-10"),
@@ -548,16 +595,16 @@ async def main():
("OTA_Read:GuestRequests", "2022-10"),
("OTA_HotelAvailNotif", "2024-10"),
("UnknownAction", "2024-10"),
- ("OTA_Ping:Handshaking", "unsupported-version")
+ ("OTA_Ping:Handshaking", "unsupported-version"),
]
-
+
for request_name, version in test_cases:
print(f"\n๏ฟฝ Testing: {request_name} (v{version})")
-
+
# Check if supported first
is_supported = server.is_action_supported(request_name, version)
print(f" Supported: {is_supported}")
-
+
# Handle the request
response = await server.handle_request(request_name, test_xml, version)
print(f" Status: {response.status_code}")
@@ -565,9 +612,9 @@ async def main():
print(f" Response: {response.xml_content[:100]}...")
else:
print(f" Response: {response.xml_content}")
-
+
print("\nโ
Demo completed successfully!")
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/src/alpine_bits_python/api.py b/src/alpine_bits_python/api.py
index 7051914..f11874c 100644
--- a/src/alpine_bits_python/api.py
+++ b/src/alpine_bits_python/api.py
@@ -1,18 +1,29 @@
-from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, APIRouter, Form, File, UploadFile
+from fastapi import (
+ FastAPI,
+ HTTPException,
+ BackgroundTasks,
+ Request,
+ Depends,
+ APIRouter,
+ Form,
+ File,
+ UploadFile,
+)
+from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic
from .config_loader import load_config
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from .models import WixFormSubmission
-from datetime import datetime, date, timezone
+from datetime import datetime, date, timezone
from .auth import validate_api_key, validate_wix_signature, generate_api_key
from .rate_limit import (
- limiter,
- webhook_limiter,
+ limiter,
+ webhook_limiter,
custom_rate_limit_handler,
DEFAULT_RATE_LIMIT,
WEBHOOK_RATE_LIMIT,
- BURST_RATE_LIMIT
+ BURST_RATE_LIMIT,
)
from slowapi.errors import RateLimitExceeded
import logging
@@ -24,8 +35,14 @@ import gzip
import xml.etree.ElementTree as ET
from .alpinebits_server import AlpineBitsServer, Version
import urllib.parse
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
-from .db import get_async_session, Customer as DBCustomer, Reservation as DBReservation
+from .db import (
+ Base,
+ Customer as DBCustomer,
+ Reservation as DBReservation,
+ get_database_url,
+)
# Configure logging
@@ -42,12 +59,36 @@ except Exception as e:
_LOGGER.error(f"Failed to load config: {str(e)}")
config = {}
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ # Setup DB
+ DATABASE_URL = get_database_url(config)
+ engine = create_async_engine(DATABASE_URL, echo=True)
+ AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
+ app.state.engine = engine
+ app.state.async_sessionmaker = AsyncSessionLocal
+
+ # Create tables
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+ _LOGGER.info("Database tables checked/created at startup.")
+
+ yield
+
+ # Optional: Dispose engine on shutdown
+ await engine.dispose()
+
+async def get_async_session(request: Request):
+ async_sessionmaker = request.app.state.async_sessionmaker
+ async with async_sessionmaker() as session:
+ yield session
app = FastAPI(
title="Wix Form Handler API",
description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting",
- version="1.0.0"
+ version="1.0.0",
+ lifespan=lifespan
)
# Create API router with /api prefix
@@ -62,9 +103,9 @@ app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://*.wix.com",
- "https://*.wixstatic.com",
+ "https://*.wixstatic.com",
"http://localhost:3000", # For development
- "http://localhost:8000" # For local testing
+ "http://localhost:8000", # For local testing
],
allow_credentials=True,
allow_methods=["GET", "POST"],
@@ -78,27 +119,39 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None:
Add your business logic here.
"""
try:
- _LOGGER.info(f"Processing form submission: {submission_data.get('submissionId')}")
-
+ _LOGGER.info(
+ f"Processing form submission: {submission_data.get('submissionId')}"
+ )
+
# Example processing - you can replace this with your actual logic
- form_name = submission_data.get('formName')
- contact_email = submission_data.get('contact', {}).get('email') if submission_data.get('contact') else None
-
+ form_name = submission_data.get("formName")
+ contact_email = (
+ submission_data.get("contact", {}).get("email")
+ if submission_data.get("contact")
+ else None
+ )
+
# Extract form fields
- form_fields = {k: v for k, v in submission_data.items() if k.startswith('field:')}
-
- _LOGGER.info(f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}")
-
+ form_fields = {
+ k: v for k, v in submission_data.items() if k.startswith("field:")
+ }
+
+ _LOGGER.info(
+ f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}"
+ )
+
# Here you could:
# - Save to database
# - Send emails
# - Call external APIs
# - Process the data further
-
+
except Exception as e:
_LOGGER.error(f"Error processing form submission: {str(e)}")
+
+
@api_router.get("/")
@limiter.limit(DEFAULT_RATE_LIMIT)
async def root(request: Request):
@@ -111,8 +164,8 @@ async def root(request: Request):
"rate_limits": {
"default": DEFAULT_RATE_LIMIT,
"webhook": WEBHOOK_RATE_LIMIT,
- "burst": BURST_RATE_LIMIT
- }
+ "burst": BURST_RATE_LIMIT,
+ },
}
@@ -126,11 +179,10 @@ async def health_check(request: Request):
"service": "wix-form-handler",
"version": "1.0.0",
"authentication": "enabled",
- "rate_limiting": "enabled"
+ "rate_limiting": "enabled",
}
-
# Extracted business logic for handling Wix form submissions
async def process_wix_form_submission(request: Request, data: Dict[str, Any], db):
"""
@@ -138,10 +190,9 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
"""
timestamp = datetime.now().isoformat()
-
_LOGGER.info(f"Received Wix form data at {timestamp}")
- #_LOGGER.info(f"Data keys: {list(data.keys())}")
- #_LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
+ # _LOGGER.info(f"Data keys: {list(data.keys())}")
+ # _LOGGER.info(f"Full data: {json.dumps(data, indent=2)}")
log_entry = {
"timestamp": timestamp,
"client_ip": request.client.host if request.client else "unknown",
@@ -154,9 +205,13 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
if not os.path.exists(logs_dir):
os.makedirs(logs_dir, mode=0o755, exist_ok=True)
stat_info = os.stat(logs_dir)
- _LOGGER.info(f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}")
+ _LOGGER.info(
+ f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}"
+ )
_LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}")
- log_filename = f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+ log_filename = (
+ f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+ )
with open(log_filename, "w", encoding="utf-8") as f:
json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False)
file_stat = os.stat(log_filename)
@@ -164,16 +219,10 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
_LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}")
_LOGGER.info(f"Data logged to: {log_filename}")
-
-
data = data.get("data") # Handle nested "data" key if present
-
-
-
# save customer and reservation to DB
-
contact_info = data.get("contact", {})
first_name = contact_info.get("name", {}).get("first")
last_name = contact_info.get("name", {}).get("last")
@@ -193,10 +242,18 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
language = data.get("contact", {}).get("locale", "en")[:2]
# Dates
- start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value")
- end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value")
+ start_date = (
+ data.get("field:date_picker_a7c8")
+ or data.get("Anreisedatum")
+ or data.get("submissions", [{}])[1].get("value")
+ )
+ end_date = (
+ data.get("field:date_picker_7e65")
+ or data.get("Abreisedatum")
+ or data.get("submissions", [{}])[2].get("value")
+ )
- # Room/guest info
+ # Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2)
num_children = int(data.get("field:anzahl_kinder") or 0)
children_ages = []
@@ -258,7 +315,7 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults,
num_children=num_children,
- children_ages=','.join(str(a) for a in children_ages),
+ children_ages=",".join(str(a) for a in children_ages),
offer=offer,
utm_comment=utm_comment,
created_at=datetime.now(timezone.utc),
@@ -277,23 +334,21 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
await db.commit()
await db.refresh(db_reservation)
-
-
-
return {
"status": "success",
"message": "Wix form data received successfully",
"received_keys": list(data.keys()),
"data_logged_to": log_filename,
"timestamp": timestamp,
- "process_info": log_entry["process_info"],
- "note": "No authentication required for this endpoint"
+ "note": "No authentication required for this endpoint",
}
@api_router.post("/webhook/wix-form")
@webhook_limiter.limit(WEBHOOK_RATE_LIMIT)
-async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)):
+async def handle_wix_form(
+ request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
+):
"""
Unified endpoint to handle Wix form submissions (test and production).
No authentication required for this endpoint.
@@ -304,16 +359,19 @@ async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Dep
_LOGGER.error(f"Error in handle_wix_form: {str(e)}")
# log stacktrace
import traceback
+
traceback_str = traceback.format_exc()
_LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}")
raise HTTPException(
- status_code=500,
- detail=f"Error processing Wix form data: {str(e)}"
+ status_code=500, detail=f"Error processing Wix form data: {str(e)}"
)
+
@api_router.post("/webhook/wix-form/test")
@limiter.limit(DEFAULT_RATE_LIMIT)
-async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session=Depends(get_async_session)):
+async def handle_wix_form_test(
+ request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)
+):
"""
Test endpoint to verify the API is working with raw JSON data.
No authentication required for testing purposes.
@@ -323,40 +381,37 @@ async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session
except Exception as e:
_LOGGER.error(f"Error in handle_wix_form_test: {str(e)}")
raise HTTPException(
- status_code=500,
- detail=f"Error processing test data: {str(e)}"
+ status_code=500, detail=f"Error processing test data: {str(e)}"
)
@api_router.post("/admin/generate-api-key")
@limiter.limit("5/hour") # Very restrictive for admin operations
async def generate_new_api_key(
- request: Request,
- admin_key: str = Depends(validate_api_key)
+ request: Request, admin_key: str = Depends(validate_api_key)
):
"""
Admin endpoint to generate new API keys.
Requires admin API key and is heavily rate limited.
"""
if admin_key != "admin-key":
- raise HTTPException(
- status_code=403,
- detail="Admin access required"
- )
-
+ raise HTTPException(status_code=403, detail="Admin access required")
+
new_key = generate_api_key()
_LOGGER.info(f"Generated new API key (requested by: {admin_key})")
-
+
return {
"status": "success",
"message": "New API key generated",
"api_key": new_key,
"timestamp": datetime.now().isoformat(),
- "note": "Store this key securely - it won't be shown again"
+ "note": "Store this key securely - it won't be shown again",
}
-async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(security_basic)) -> str:
+async def validate_basic_auth(
+ credentials: HTTPBasicCredentials = Depends(security_basic),
+) -> str:
"""
Validate basic authentication for AlpineBits protocol.
Returns username if valid, raises HTTPException if not.
@@ -369,8 +424,11 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
headers={"WWW-Authenticate": "Basic"},
)
valid = False
- for entry in config['alpine_bits_auth']:
- if credentials.username == entry['username'] and credentials.password == entry['password']:
+ for entry in config["alpine_bits_auth"]:
+ if (
+ credentials.username == entry["username"]
+ and credentials.password == entry["password"]
+ ):
valid = True
break
if not valid:
@@ -379,7 +437,9 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi
detail="ERROR: Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)
- _LOGGER.info(f"AlpineBits authentication successful for user: {credentials.username} (from config)")
+ _LOGGER.info(
+ f"AlpineBits authentication successful for user: {credentials.username} (from config)"
+ )
return credentials.username
@@ -390,10 +450,9 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
"""
if "multipart/form-data" not in content_type:
raise HTTPException(
- status_code=400,
- detail="ERROR: Content-Type must be multipart/form-data"
+ status_code=400, detail="ERROR: Content-Type must be multipart/form-data"
)
-
+
# Extract boundary
boundary = None
for part in content_type.split(";"):
@@ -401,62 +460,56 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
if part.startswith("boundary="):
boundary = part.split("=", 1)[1].strip('"')
break
-
+
if not boundary:
raise HTTPException(
- status_code=400,
- detail="ERROR: Missing boundary in multipart/form-data"
+ status_code=400, detail="ERROR: Missing boundary in multipart/form-data"
)
-
+
# Simple multipart parsing
parts = body.split(f"--{boundary}".encode())
data = {}
-
+
for part in parts:
if not part.strip() or part.strip() == b"--":
continue
-
+
# Split headers and content
if b"\r\n\r\n" in part:
headers_section, content = part.split(b"\r\n\r\n", 1)
content = content.rstrip(b"\r\n")
-
+
# Parse Content-Disposition header
- headers = headers_section.decode('utf-8', errors='ignore')
+ headers = headers_section.decode("utf-8", errors="ignore")
name = None
- for line in headers.split('\n'):
- if 'Content-Disposition' in line and 'name=' in line:
+ for line in headers.split("\n"):
+ if "Content-Disposition" in line and "name=" in line:
# Extract name parameter
- for param in line.split(';'):
+ for param in line.split(";"):
param = param.strip()
- if param.startswith('name='):
- name = param.split('=', 1)[1].strip('"')
+ if param.startswith("name="):
+ name = param.split("=", 1)[1].strip('"')
break
-
+
if name:
# Handle file uploads or text content
- if content.startswith(b'<'):
+ if content.startswith(b"<"):
# Likely XML content
- data[name] = content.decode('utf-8', errors='ignore')
+ data[name] = content.decode("utf-8", errors="ignore")
else:
- data[name] = content.decode('utf-8', errors='ignore')
-
+ data[name] = content.decode("utf-8", errors="ignore")
+
return data
-
-
-
-
@api_router.post("/alpinebits/server-2024-10")
@limiter.limit("60/minute")
async def alpinebits_server_handshake(
- request: Request,
- username: str = Depends(validate_basic_auth)
+ request: Request, username: str = Depends(validate_basic_auth)
):
"""
AlpineBits server endpoint implementing the handshake protocol.
-
+
This endpoint handles:
- Protocol version negotiation via X-AlpineBits-ClientProtocolVersion header
- Client identification via X-AlpineBits-ClientID header (optional)
@@ -464,62 +517,67 @@ async def alpinebits_server_handshake(
- Gzip compression support
- Proper error handling with HTTP status codes
- Handshaking action processing
-
+
Authentication: HTTP Basic Auth required
Content-Type: multipart/form-data
Compression: gzip supported (check X-AlpineBits-Server-Accept-Encoding)
"""
try:
# Check required headers
- client_protocol_version = request.headers.get("X-AlpineBits-ClientProtocolVersion")
+ client_protocol_version = request.headers.get(
+ "X-AlpineBits-ClientProtocolVersion"
+ )
if not client_protocol_version:
# Server concludes client speaks a protocol version preceding 2013-04
client_protocol_version = "pre-2013-04"
- _LOGGER.info("No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04")
+ _LOGGER.info(
+ "No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04"
+ )
else:
_LOGGER.info(f"Client protocol version: {client_protocol_version}")
-
+
# Optional client ID
client_id = request.headers.get("X-AlpineBits-ClientID")
if client_id:
_LOGGER.info(f"Client ID: {client_id}")
-
+
# Check content encoding
content_encoding = request.headers.get("Content-Encoding")
is_compressed = content_encoding == "gzip"
-
+
if is_compressed:
_LOGGER.info("Request is gzip compressed")
-
+
# Get content type before processing
content_type = request.headers.get("Content-Type", "")
_LOGGER.info(f"Content-Type: {content_type}")
_LOGGER.info(f"Content-Encoding: {content_encoding}")
-
+
# Get request body
body = await request.body()
-
+
# Decompress if needed
if is_compressed:
try:
-
body = gzip.decompress(body)
-
except Exception as e:
raise HTTPException(
status_code=400,
- detail=f"ERROR: Failed to decompress gzip content: {str(e)}"
+ detail=f"ERROR: Failed to decompress gzip content: {str(e)}",
)
-
+
# Check content type (after decompression)
- if "multipart/form-data" not in content_type and "application/x-www-form-urlencoded" not in content_type:
+ if (
+ "multipart/form-data" not in content_type
+ and "application/x-www-form-urlencoded" not in content_type
+ ):
raise HTTPException(
status_code=400,
- detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded"
+ detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
)
-
+
# Parse multipart data
if "multipart/form-data" in content_type:
try:
@@ -527,7 +585,7 @@ async def alpinebits_server_handshake(
except Exception as e:
raise HTTPException(
status_code=400,
- detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}"
+ detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}",
)
elif "application/x-www-form-urlencoded" in content_type:
# Parse as urlencoded
@@ -535,75 +593,59 @@ async def alpinebits_server_handshake(
else:
raise HTTPException(
status_code=400,
- detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded"
+ detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded",
)
-
+
# Check for required action parameter
action = form_data.get("action")
if not action:
raise HTTPException(
- status_code=400,
- detail="ERROR: Missing required 'action' parameter")
-
+ status_code=400, detail="ERROR: Missing required 'action' parameter"
+ )
+
_LOGGER.info(f"AlpineBits action: {action}")
-
# Get optional request XML
- request_xml = form_data.get("request")
-
+ request_xml = form_data.get("request")
server = AlpineBitsServer()
version = Version.V2024_10
-
-
# Create successful handshake response
- response = await server.handle_request(action, request_xml, version)
+ response = await server.handle_request(action, request_xml, version)
response_xml = response.xml_content
-
+
# Set response headers indicating server capabilities
headers = {
"Content-Type": "application/xml; charset=utf-8",
"X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support
- "X-AlpineBits-Server-Version": "2024-10"
+ "X-AlpineBits-Server-Version": "2024-10",
}
-
- return Response(
- content=response_xml,
- status_code=response.status_code,
- headers=headers
- )
-
-
+ return Response(
+ content=response_xml, status_code=response.status_code, headers=headers
+ )
+
except HTTPException:
# Re-raise HTTP exceptions (auth errors, etc.)
raise
except Exception as e:
_LOGGER.error(f"Error in AlpineBits handshake: {str(e)}")
- raise HTTPException(
- status_code=500,
- detail=f"Internal server error: {str(e)}"
- )
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
@api_router.get("/admin/stats")
@limiter.limit("10/minute")
-async def get_api_stats(
- request: Request,
- admin_key: str = Depends(validate_api_key)
-):
+async def get_api_stats(request: Request, admin_key: str = Depends(validate_api_key)):
"""
Admin endpoint to get API usage statistics.
Requires admin API key.
"""
if admin_key != "admin-key":
- raise HTTPException(
- status_code=403,
- detail="Admin access required"
- )
-
+ raise HTTPException(status_code=403, detail="Admin access required")
+
# In a real application, you'd fetch this from your database/monitoring system
return {
"status": "success",
@@ -611,9 +653,9 @@ async def get_api_stats(
"uptime": "Available in production deployment",
"total_requests": "Available with monitoring setup",
"active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]),
- "rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory"
+ "rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory",
},
- "timestamp": datetime.now().isoformat()
+ "timestamp": datetime.now().isoformat(),
}
@@ -629,11 +671,12 @@ async def landing_page():
try:
# Get the path to the HTML file
import os
+
html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html")
-
+
with open(html_path, "r", encoding="utf-8") as f:
html_content = f.read()
-
+
return HTMLResponse(content=html_content, status_code=200)
except FileNotFoundError:
# Fallback if HTML file is not found
@@ -660,4 +703,5 @@ async def landing_page():
if __name__ == "__main__":
import uvicorn
- uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
+
+ uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/src/alpine_bits_python/auth.py b/src/alpine_bits_python/auth.py
index 21d29bc..5a7632e 100644
--- a/src/alpine_bits_python/auth.py
+++ b/src/alpine_bits_python/auth.py
@@ -21,7 +21,7 @@ security = HTTPBearer()
API_KEYS = {
# Example API keys - replace with your own secure keys
"wix-webhook-key": "sk_live_your_secure_api_key_here",
- "admin-key": "sk_admin_your_admin_key_here"
+ "admin-key": "sk_admin_your_admin_key_here",
}
# Load API keys from environment if available
@@ -36,19 +36,21 @@ def generate_api_key() -> str:
return f"sk_live_{secrets.token_urlsafe(32)}"
-def validate_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
+def validate_api_key(
+ credentials: HTTPAuthorizationCredentials = Security(security),
+) -> str:
"""
Validate API key from Authorization header.
Expected format: Authorization: Bearer your_api_key_here
"""
token = credentials.credentials
-
+
# Check if the token is in our valid API keys
for key_name, valid_key in API_KEYS.items():
if secrets.compare_digest(token, valid_key):
logger.info(f"Valid API key used: {key_name}")
return key_name
-
+
logger.warning(f"Invalid API key attempted: {token[:10]}...")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -64,19 +66,17 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool:
"""
if not signature or not secret:
return False
-
+
try:
# Remove 'sha256=' prefix if present
- if signature.startswith('sha256='):
+ if signature.startswith("sha256="):
signature = signature[7:]
-
+
# Calculate expected signature
expected_signature = hmac.new(
- secret.encode('utf-8'),
- payload,
- hashlib.sha256
+ secret.encode("utf-8"), payload, hashlib.sha256
).hexdigest()
-
+
# Compare signatures securely
return secrets.compare_digest(signature, expected_signature)
except Exception as e:
@@ -86,21 +86,21 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool:
class APIKeyAuth:
"""Simple API key authentication class"""
-
+
def __init__(self, api_keys: dict):
self.api_keys = api_keys
-
+
def authenticate(self, api_key: str) -> Optional[str]:
"""Authenticate an API key and return the key name if valid"""
for key_name, valid_key in self.api_keys.items():
if secrets.compare_digest(api_key, valid_key):
return key_name
return None
-
+
def add_key(self, name: str, key: str):
"""Add a new API key"""
self.api_keys[name] = key
-
+
def remove_key(self, name: str):
"""Remove an API key"""
if name in self.api_keys:
@@ -108,4 +108,4 @@ class APIKeyAuth:
# Initialize auth system
-auth_system = APIKeyAuth(API_KEYS)
\ No newline at end of file
+auth_system = APIKeyAuth(API_KEYS)
diff --git a/src/alpine_bits_python/config_loader.py b/src/alpine_bits_python/config_loader.py
index 59fcaeb..b207b4d 100644
--- a/src/alpine_bits_python/config_loader.py
+++ b/src/alpine_bits_python/config_loader.py
@@ -1,4 +1,3 @@
-
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
@@ -16,37 +15,45 @@ from annotatedyaml.loader import (
from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid
# --- Voluptuous schemas ---
-database_schema = Schema({
- Required('url'): str
-}, extra=PREVENT_EXTRA)
+database_schema = Schema({Required("url"): str}, extra=PREVENT_EXTRA)
-
-hotel_auth_schema = Schema({
- Required("hotel_id"): str,
- Required("hotel_name"): str,
- Required("username"): str,
- Required("password"): str
-}, extra=PREVENT_EXTRA)
-
-basic_auth_schema = Schema(
- All([hotel_auth_schema], Length(min=1))
+hotel_auth_schema = Schema(
+ {
+ Required("hotel_id"): str,
+ Required("hotel_name"): str,
+ Required("username"): str,
+ Required("password"): str,
+ },
+ extra=PREVENT_EXTRA,
)
-config_schema = Schema({
- Required('database'): database_schema,
- Required('alpine_bits_auth'): basic_auth_schema
-}, extra=PREVENT_EXTRA)
+basic_auth_schema = Schema(All([hotel_auth_schema], Length(min=1)))
-DEFAULT_CONFIG_FILE = 'config.yaml'
+config_schema = Schema(
+ {
+ Required("database"): database_schema,
+ Required("alpine_bits_auth"): basic_auth_schema,
+ },
+ extra=PREVENT_EXTRA,
+)
+
+DEFAULT_CONFIG_FILE = "config.yaml"
class Config:
- def __init__(self, config_folder: str | Path = None, config_name: str = DEFAULT_CONFIG_FILE, testing_mode: bool = False):
+ def __init__(
+ self,
+ config_folder: str | Path = None,
+ config_name: str = DEFAULT_CONFIG_FILE,
+ testing_mode: bool = False,
+ ):
if config_folder is None:
- config_folder = os.environ.get('ALPINEBITS_CONFIG_DIR')
+ config_folder = os.environ.get("ALPINEBITS_CONFIG_DIR")
if not config_folder:
- config_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../config'))
+ config_folder = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "../../config")
+ )
if isinstance(config_folder, str):
config_folder = Path(config_folder)
self.config_folder = config_folder
@@ -61,8 +68,8 @@ class Config:
validated = config_schema(stuff)
except MultipleInvalid as e:
raise ValueError(f"Config validation error: {e}")
- self.database = validated['database']
- self.basic_auth = validated['alpine_bits_auth']
+ self.database = validated["database"]
+ self.basic_auth = validated["alpine_bits_auth"]
self.config = validated
def get(self, key, default=None):
@@ -70,19 +77,20 @@ class Config:
@property
def db_url(self) -> str:
- return self.database['url']
+ return self.database["url"]
@property
def hotel_id(self) -> str:
- return self.basic_auth['hotel_id']
+ return self.basic_auth["hotel_id"]
@property
def hotel_name(self) -> str:
- return self.basic_auth['hotel_name']
+ return self.basic_auth["hotel_name"]
@property
def users(self) -> List[Dict[str, str]]:
- return self.basic_auth['users']
+ return self.basic_auth["users"]
+
# For backward compatibility
def load_config():
diff --git a/src/alpine_bits_python/db.py b/src/alpine_bits_python/db.py
index 09d039b..32a82d0 100644
--- a/src/alpine_bits_python/db.py
+++ b/src/alpine_bits_python/db.py
@@ -5,27 +5,24 @@ import os
Base = declarative_base()
+
# Async SQLAlchemy setup
def get_database_url(config=None):
db_url = None
- if config and 'database' in config and 'url' in config['database']:
- db_url = config['database']['url']
+ if config and "database" in config and "url" in config["database"]:
+ db_url = config["database"]["url"]
if not db_url:
- db_url = os.environ.get('DATABASE_URL')
+ db_url = os.environ.get("DATABASE_URL")
if not db_url:
- db_url = 'sqlite+aiosqlite:///alpinebits.db'
+ db_url = "sqlite+aiosqlite:///alpinebits.db"
return db_url
-DATABASE_URL = get_database_url()
-engine = create_async_engine(DATABASE_URL, echo=True)
-AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
-async def get_async_session():
- async with AsyncSessionLocal() as session:
- yield session
+
+
class Customer(Base):
- __tablename__ = 'customers'
+ __tablename__ = "customers"
id = Column(Integer, primary_key=True)
given_name = Column(String)
contact_id = Column(String, unique=True)
@@ -42,13 +39,14 @@ class Customer(Base):
birth_date = Column(String)
language = Column(String)
address_catalog = Column(Boolean) # Added for XML
- name_title = Column(String) # Added for XML
- reservations = relationship('Reservation', back_populates='customer')
+ name_title = Column(String) # Added for XML
+ reservations = relationship("Reservation", back_populates="customer")
+
class Reservation(Base):
- __tablename__ = 'reservations'
+ __tablename__ = "reservations"
id = Column(Integer, primary_key=True)
- customer_id = Column(Integer, ForeignKey('customers.id'))
+ customer_id = Column(Integer, ForeignKey("customers.id"))
form_id = Column(String, unique=True)
start_date = Column(Date)
end_date = Column(Date)
@@ -70,16 +68,14 @@ class Reservation(Base):
# Add hotel_code and hotel_name for XML
hotel_code = Column(String)
hotel_name = Column(String)
- customer = relationship('Customer', back_populates='reservations')
+ customer = relationship("Customer", back_populates="reservations")
+
class HashedCustomer(Base):
- __tablename__ = 'hashed_customers'
+ __tablename__ = "hashed_customers"
id = Column(Integer, primary_key=True)
customer_id = Column(Integer)
hashed_email = Column(String)
hashed_phone = Column(String)
hashed_name = Column(String)
redacted_at = Column(DateTime)
-
-
-
diff --git a/src/alpine_bits_python/main.py b/src/alpine_bits_python/main.py
index b6cc021..c5f203d 100644
--- a/src/alpine_bits_python/main.py
+++ b/src/alpine_bits_python/main.py
@@ -15,11 +15,16 @@ from .simplified_access import (
HotelReservationIdData,
PhoneTechType,
AlpineBitsFactory,
- OtaMessageType
+ OtaMessageType,
)
# DB and config
-from .db import Customer as DBCustomer, Reservation as DBReservation, HashedCustomer, get_async_session
+from .db import (
+ Customer as DBCustomer,
+ Reservation as DBReservation,
+ HashedCustomer,
+ get_async_session,
+)
from .config_loader import load_config
import hashlib
import json
@@ -29,8 +34,8 @@ import asyncio
from alpine_bits_python import db
-async def main():
+async def main():
print("๐ Starting AlpineBits XML generation script...")
# Load config (yaml, annotatedyaml)
config = load_config()
@@ -40,9 +45,9 @@ async def main():
print(json.dumps(config, indent=2))
# Ensure SQLite DB file exists if using SQLite
- db_url = config.get('database', {}).get('url', '')
- if db_url.startswith('sqlite+aiosqlite:///'):
- db_path = db_url.replace('sqlite+aiosqlite:///', '')
+ db_url = config.get("database", {}).get("url", "")
+ if db_url.startswith("sqlite+aiosqlite:///"):
+ db_path = db_url.replace("sqlite+aiosqlite:///", "")
db_path = os.path.abspath(db_path)
db_dir = os.path.dirname(db_path)
if not os.path.exists(db_dir):
@@ -54,15 +59,17 @@ async def main():
# # Ensure DB schema is created (async)
from .db import engine, Base
+
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async for db in get_async_session():
-
-
# Load data from JSON file
- json_path = os.path.join(os.path.dirname(__file__), '../../test_data/wix_test_data_20250928_132611.json')
- with open(json_path, 'r', encoding='utf-8') as f:
+ json_path = os.path.join(
+ os.path.dirname(__file__),
+ "../../test_data/wix_test_data_20250928_132611.json",
+ )
+ with open(json_path, "r", encoding="utf-8") as f:
wix_data = json.load(f)
data = wix_data["data"]["data"]
@@ -85,8 +92,16 @@ async def main():
language = data.get("contact", {}).get("locale", "en")[:2]
# Dates
- start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value")
- end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value")
+ start_date = (
+ data.get("field:date_picker_a7c8")
+ or data.get("Anreisedatum")
+ or data.get("submissions", [{}])[1].get("value")
+ )
+ end_date = (
+ data.get("field:date_picker_7e65")
+ or data.get("Abreisedatum")
+ or data.get("submissions", [{}])[2].get("value")
+ )
# Room/guest info
num_adults = int(data.get("field:number_7cf5") or 2)
@@ -100,7 +115,7 @@ async def main():
children_ages.append(age)
except ValueError:
logging.warning(f"Invalid age value for {k}: {data[k]}")
-
+
# UTM and offer
utm_fields = [
("utm_Source", "utm_source"),
@@ -147,7 +162,7 @@ async def main():
end_date=date.fromisoformat(end_date) if end_date else None,
num_adults=num_adults,
num_children=num_children,
- children_ages=','.join(str(a) for a in children_ages),
+ children_ages=",".join(str(a) for a in children_ages),
offer=offer,
utm_comment=utm_comment,
created_at=datetime.now(timezone.utc),
@@ -177,9 +192,19 @@ async def main():
def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
- from .simplified_access import CustomerData, GuestCountsFactory, HotelReservationIdData, AlpineBitsFactory, OtaMessageType, CommentData, CommentsData, CommentListItemData
+ from .simplified_access import (
+ CustomerData,
+ GuestCountsFactory,
+ HotelReservationIdData,
+ AlpineBitsFactory,
+ OtaMessageType,
+ CommentData,
+ CommentsData,
+ CommentListItemData,
+ )
from .generated import alpinebits as ab
from datetime import datetime, timezone
+
# Prepare data for XML
phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else []
customer_data = CustomerData(
@@ -200,11 +225,15 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
language=customer.language,
)
alpine_bits_factory = AlpineBitsFactory()
- res_guests = alpine_bits_factory.create_res_guests(customer_data, OtaMessageType.RETRIEVE)
+ res_guests = alpine_bits_factory.create_res_guests(
+ customer_data, OtaMessageType.RETRIEVE
+ )
# Guest counts
children_ages = [int(a) for a in reservation.children_ages.split(",") if a]
- guest_counts = GuestCountsFactory.create_retrieve_guest_counts(reservation.num_adults, children_ages)
+ guest_counts = GuestCountsFactory.create_retrieve_guest_counts(
+ reservation.num_adults, children_ages
+ )
# UniqueID
unique_id = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId(
@@ -214,11 +243,13 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
# TimeSpan
time_span = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan(
start=reservation.start_date.isoformat() if reservation.start_date else None,
- end=reservation.end_date.isoformat() if reservation.end_date else None
+ end=reservation.end_date.isoformat() if reservation.end_date else None,
)
- room_stay = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay(
- time_span=time_span,
- guest_counts=guest_counts,
+ room_stay = (
+ ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay(
+ time_span=time_span,
+ guest_counts=guest_counts,
+ )
)
room_stays = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays(
room_stay=[room_stay],
@@ -231,7 +262,9 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
res_id_source=None,
res_id_source_context="99tales",
)
- hotel_res_id = alpine_bits_factory.create(hotel_res_id_data, OtaMessageType.RETRIEVE)
+ hotel_res_id = alpine_bits_factory.create(
+ hotel_res_id_data, OtaMessageType.RETRIEVE
+ )
hotel_res_ids = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds(
hotel_reservation_id=[hotel_res_id]
)
@@ -244,31 +277,37 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
offer_comment = CommentData(
name=ab.CommentName2.ADDITIONAL_INFO,
text="Angebot/Offerta",
- list_items=[CommentListItemData(
- value=reservation.offer,
- language=customer.language,
- list_item="1",
- )],
+ list_items=[
+ CommentListItemData(
+ value=reservation.offer,
+ language=customer.language,
+ list_item="1",
+ )
+ ],
)
comment = None
if reservation.user_comment:
comment = CommentData(
name=ab.CommentName2.CUSTOMER_COMMENT,
text=reservation.user_comment,
- list_items=[CommentListItemData(
- value="Landing page comment",
- language=customer.language,
- list_item="1",
- )],
+ list_items=[
+ CommentListItemData(
+ value="Landing page comment",
+ language=customer.language,
+ list_item="1",
+ )
+ ],
)
comments = [offer_comment, comment] if comment else [offer_comment]
comments_data = CommentsData(comments=comments)
comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE)
- res_global_info = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo(
- hotel_reservation_ids=hotel_res_ids,
- basic_property_info=basic_property_info,
- comments=comments_xml,
+ res_global_info = (
+ ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo(
+ hotel_reservation_ids=hotel_res_ids,
+ basic_property_info=basic_property_info,
+ comments=comments_xml,
+ )
)
hotel_reservation = ab.OtaResRetrieveRs.ReservationsList.HotelReservation(
@@ -293,6 +332,7 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("โ
Pydantic validation successful!")
from xsdata.formats.dataclass.serializers.config import SerializerConfig
from xsdata_pydantic.bindings import XmlSerializer
+
config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8"
)
@@ -306,15 +346,18 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
print("\n๐ Generated XML:")
print(xml_string)
from xsdata_pydantic.bindings import XmlParser
+
parser = XmlParser()
with open("output.xml", "r", encoding="utf-8") as infile:
xml_content = infile.read()
parsed_result = parser.from_string(xml_content, ab.OtaResRetrieveRs)
print("โ
Round-trip validation successful!")
- print(f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}")
+ print(
+ f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}"
+ )
except Exception as e:
print(f"โ Validation/Serialization failed: {e}")
+
if __name__ == "__main__":
asyncio.run(main())
-
diff --git a/src/alpine_bits_python/models.py b/src/alpine_bits_python/models.py
index ea56a92..27b6f31 100644
--- a/src/alpine_bits_python/models.py
+++ b/src/alpine_bits_python/models.py
@@ -5,18 +5,23 @@ from datetime import datetime
class AlpineBitsHandshakeRequest(BaseModel):
"""Model for AlpineBits handshake request data"""
- action: str = Field(..., description="Action parameter, typically 'OTA_Ping:Handshaking'")
+
+ action: str = Field(
+ ..., description="Action parameter, typically 'OTA_Ping:Handshaking'"
+ )
request_xml: Optional[str] = Field(None, description="XML request document")
class ContactName(BaseModel):
"""Contact name structure"""
+
first: Optional[str] = None
last: Optional[str] = None
class ContactAddress(BaseModel):
"""Contact address structure"""
+
street: Optional[str] = None
city: Optional[str] = None
state: Optional[str] = None
@@ -26,6 +31,7 @@ class ContactAddress(BaseModel):
class Contact(BaseModel):
"""Contact information from Wix form"""
+
name: Optional[ContactName] = None
email: Optional[str] = None
locale: Optional[str] = None
@@ -43,12 +49,14 @@ class Contact(BaseModel):
class SubmissionPdf(BaseModel):
"""PDF submission structure"""
+
url: Optional[str] = None
filename: Optional[str] = None
class WixFormSubmission(BaseModel):
"""Model for Wix form submission data"""
+
formName: str
submissions: List[Dict[str, Any]] = Field(default_factory=list)
submissionTime: str
@@ -59,7 +67,7 @@ class WixFormSubmission(BaseModel):
submissionPdf: Optional[SubmissionPdf] = None
formId: str
contact: Optional[Contact] = None
-
+
# Dynamic form fields - these will capture all field:* entries
class Config:
- extra = "allow" # Allow additional fields not defined in the model
\ No newline at end of file
+ extra = "allow" # Allow additional fields not defined in the model
diff --git a/src/alpine_bits_python/rate_limit.py b/src/alpine_bits_python/rate_limit.py
index 958e062..638ea59 100644
--- a/src/alpine_bits_python/rate_limit.py
+++ b/src/alpine_bits_python/rate_limit.py
@@ -11,11 +11,12 @@ logger = logging.getLogger(__name__)
# Rate limiting configuration
DEFAULT_RATE_LIMIT = "10/minute" # 10 requests per minute per IP
WEBHOOK_RATE_LIMIT = "60/minute" # 60 webhook requests per minute per IP
-BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP
+BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP
# Redis configuration for distributed rate limiting (optional)
REDIS_URL = os.getenv("REDIS_URL", None)
+
def get_remote_address_with_forwarded(request: Request):
"""
Get client IP address, considering forwarded headers from proxies/load balancers
@@ -25,11 +26,11 @@ def get_remote_address_with_forwarded(request: Request):
if forwarded_for:
# Take the first IP in the chain
return forwarded_for.split(",")[0].strip()
-
+
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
-
+
# Fallback to direct connection IP
return get_remote_address(request)
@@ -39,14 +40,16 @@ if REDIS_URL:
# Use Redis for distributed rate limiting (recommended for production)
try:
import redis
+
redis_client = redis.from_url(REDIS_URL)
limiter = Limiter(
- key_func=get_remote_address_with_forwarded,
- storage_uri=REDIS_URL
+ key_func=get_remote_address_with_forwarded, storage_uri=REDIS_URL
)
logger.info("Rate limiting initialized with Redis backend")
except Exception as e:
- logger.warning(f"Failed to connect to Redis: {e}. Using in-memory rate limiting.")
+ logger.warning(
+ f"Failed to connect to Redis: {e}. Using in-memory rate limiting."
+ )
limiter = Limiter(key_func=get_remote_address_with_forwarded)
else:
# Use in-memory rate limiting (fine for single instance)
@@ -65,7 +68,7 @@ def get_api_key_identifier(request: Request) -> str:
api_key = auth_header[7:] # Remove "Bearer " prefix
# Use first 10 chars of API key as identifier (don't log full key)
return f"api_key:{api_key[:10]}"
-
+
# Fallback to IP address
return f"ip:{get_remote_address_with_forwarded(request)}"
@@ -77,10 +80,10 @@ def api_key_rate_limit_key(request: Request):
# Rate limiting decorators for different endpoint types
webhook_limiter = Limiter(
- key_func=api_key_rate_limit_key,
- storage_uri=REDIS_URL if REDIS_URL else None
+ key_func=api_key_rate_limit_key, storage_uri=REDIS_URL if REDIS_URL else None
)
+
# Custom rate limit exceeded handler
def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Custom handler for rate limit exceeded"""
@@ -88,11 +91,11 @@ def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
f"Rate limit exceeded for {get_remote_address_with_forwarded(request)}: "
f"{exc.detail}"
)
-
+
response = _rate_limit_exceeded_handler(request, exc)
-
+
# Add custom headers
response.headers["X-RateLimit-Limit"] = str(exc.retry_after)
response.headers["X-RateLimit-Retry-After"] = str(exc.retry_after)
-
- return response
\ No newline at end of file
+
+ return response
diff --git a/src/alpine_bits_python/reservations.py b/src/alpine_bits_python/reservations.py
index 70d483a..5c8f238 100644
--- a/src/alpine_bits_python/reservations.py
+++ b/src/alpine_bits_python/reservations.py
@@ -1,7 +1,2 @@
-
-
-
def parse_form(form: dict):
-
pass
-
\ No newline at end of file
diff --git a/src/alpine_bits_python/run_api.py b/src/alpine_bits_python/run_api.py
index 9234936..28d921b 100644
--- a/src/alpine_bits_python/run_api.py
+++ b/src/alpine_bits_python/run_api.py
@@ -2,14 +2,21 @@
"""
Startup script for the Wix Form Handler API
"""
+
+import os
import uvicorn
from .api import app
if __name__ == "__main__":
+ db_path = "alpinebits.db" # Adjust path if needed
+ if os.path.exists(db_path):
+ os.remove(db_path)
+ print(f"Deleted database file: {db_path}")
+
uvicorn.run(
"alpine_bits_python.api:app",
host="0.0.0.0",
port=8080,
reload=True, # Enable auto-reload during development
- log_level="info"
- )
\ No newline at end of file
+ log_level="info",
+ )
diff --git a/src/alpine_bits_python/scripts/setup_security.py b/src/alpine_bits_python/scripts/setup_security.py
index 38ebc15..e565a85 100644
--- a/src/alpine_bits_python/scripts/setup_security.py
+++ b/src/alpine_bits_python/scripts/setup_security.py
@@ -2,6 +2,7 @@
"""
Configuration and setup script for the Wix Form Handler API
"""
+
import os
import sys
import secrets
@@ -11,80 +12,83 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from alpine_bits_python.auth import generate_api_key
+
def generate_secure_keys():
"""Generate secure API keys for the application"""
-
+
print("๐ Generating Secure API Keys")
print("=" * 50)
-
+
# Generate API keys
wix_api_key = generate_api_key()
admin_api_key = generate_api_key()
webhook_secret = secrets.token_urlsafe(32)
-
+
print(f"๐ Wix Webhook API Key: {wix_api_key}")
print(f"๐ Admin API Key: {admin_api_key}")
print(f"๐ Webhook Secret: {webhook_secret}")
-
+
print("\n๐ Environment Variables")
print("-" * 30)
print(f"export WIX_API_KEY='{wix_api_key}'")
print(f"export ADMIN_API_KEY='{admin_api_key}'")
print(f"export WIX_WEBHOOK_SECRET='{webhook_secret}'")
print(f"export REDIS_URL='redis://localhost:6379' # Optional for production")
-
+
print("\n๐ง .env File Content")
print("-" * 20)
print(f"WIX_API_KEY={wix_api_key}")
print(f"ADMIN_API_KEY={admin_api_key}")
print(f"WIX_WEBHOOK_SECRET={webhook_secret}")
print("REDIS_URL=redis://localhost:6379")
-
+
# Optionally write to .env file
create_env = input("\nโ Create .env file? (y/n): ").lower().strip()
- if create_env == 'y':
+ if create_env == "y":
# Create .env in the project root (two levels up from scripts)
- env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
- with open(env_path, 'w') as f:
+ env_path = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
+ )
+ with open(env_path, "w") as f:
f.write(f"WIX_API_KEY={wix_api_key}\n")
f.write(f"ADMIN_API_KEY={admin_api_key}\n")
f.write(f"WIX_WEBHOOK_SECRET={webhook_secret}\n")
f.write("REDIS_URL=redis://localhost:6379\n")
print(f"โ
.env file created at {env_path}!")
print("โ ๏ธ Add .env to your .gitignore file!")
-
+
print("\n๐ Wix Configuration")
print("-" * 20)
print("1. In your Wix site, go to Settings > Webhooks")
print("2. Add webhook URL: https://yourdomain.com/webhook/wix-form")
print("3. Add custom header: Authorization: Bearer " + wix_api_key)
print("4. Optionally configure webhook signature with the secret above")
-
+
return {
- 'wix_api_key': wix_api_key,
- 'admin_api_key': admin_api_key,
- 'webhook_secret': webhook_secret
+ "wix_api_key": wix_api_key,
+ "admin_api_key": admin_api_key,
+ "webhook_secret": webhook_secret,
}
def check_security_setup():
"""Check current security configuration"""
-
+
print("๐ Security Configuration Check")
print("=" * 40)
-
+
# Check environment variables
- wix_key = os.getenv('WIX_API_KEY')
- admin_key = os.getenv('ADMIN_API_KEY')
- webhook_secret = os.getenv('WIX_WEBHOOK_SECRET')
- redis_url = os.getenv('REDIS_URL')
-
+ wix_key = os.getenv("WIX_API_KEY")
+ admin_key = os.getenv("ADMIN_API_KEY")
+ webhook_secret = os.getenv("WIX_WEBHOOK_SECRET")
+ redis_url = os.getenv("REDIS_URL")
+
print("Environment Variables:")
print(f" WIX_API_KEY: {'โ
Set' if wix_key else 'โ Not set'}")
print(f" ADMIN_API_KEY: {'โ
Set' if admin_key else 'โ Not set'}")
print(f" WIX_WEBHOOK_SECRET: {'โ
Set' if webhook_secret else 'โ Not set'}")
print(f" REDIS_URL: {'โ
Set' if redis_url else 'โ ๏ธ Optional (using in-memory)'}")
-
+
# Security recommendations
print("\n๐ก๏ธ Security Recommendations:")
if not wix_key:
@@ -94,19 +98,19 @@ def check_security_setup():
print(" โ ๏ธ WIX_API_KEY should be longer for better security")
else:
print(" โ
WIX_API_KEY looks secure")
-
+
if not admin_key:
print(" โ Set ADMIN_API_KEY environment variable")
elif wix_key and admin_key == wix_key:
print(" โ Admin and Wix keys should be different")
else:
print(" โ
ADMIN_API_KEY configured")
-
+
if not webhook_secret:
print(" โ ๏ธ Consider setting WIX_WEBHOOK_SECRET for signature validation")
else:
print(" โ
Webhook signature validation enabled")
-
+
print("\n๐ Production Checklist:")
print(" - Use HTTPS in production")
print(" - Set up Redis for distributed rate limiting")
@@ -118,12 +122,14 @@ def check_security_setup():
if __name__ == "__main__":
print("๐ Wix Form Handler API - Security Setup")
print("=" * 50)
-
- choice = input("Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): ").strip()
-
+
+ choice = input(
+ "Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): "
+ ).strip()
+
if choice == "1":
generate_secure_keys()
elif choice == "2":
check_security_setup()
else:
- print("Invalid choice. Please run again and choose 1 or 2.")
\ No newline at end of file
+ print("Invalid choice. Please run again and choose 1 or 2.")
diff --git a/src/alpine_bits_python/scripts/test_api.py b/src/alpine_bits_python/scripts/test_api.py
index 76ed30d..021f621 100644
--- a/src/alpine_bits_python/scripts/test_api.py
+++ b/src/alpine_bits_python/scripts/test_api.py
@@ -2,6 +2,7 @@
"""
Test script for the Secure Wix Form Handler API
"""
+
import asyncio
import aiohttp
import json
@@ -30,7 +31,7 @@ SAMPLE_WIX_DATA = {
"submissionsLink": "https://www.wix.app/forms/test-form/submissions",
"submissionPdf": {
"url": "https://example.com/submission.pdf",
- "filename": "submission.pdf"
+ "filename": "submission.pdf",
},
"formId": "test-form-789",
"field:email_5139": "test@example.com",
@@ -43,10 +44,7 @@ SAMPLE_WIX_DATA = {
"field:alter_kind_4": "12",
"field:long_answer_3524": "This is a long answer field with more details about the inquiry.",
"contact": {
- "name": {
- "first": "John",
- "last": "Doe"
- },
+ "name": {"first": "John", "last": "Doe"},
"email": "test@example.com",
"locale": "de",
"company": "Test Company",
@@ -57,29 +55,29 @@ SAMPLE_WIX_DATA = {
"street": "Test Street 123",
"city": "Test City",
"country": "Germany",
- "postalCode": "12345"
+ "postalCode": "12345",
},
"jobTitle": "Manager",
"phone": "+1234567890",
"createdDate": "2024-03-20T10:00:00.000Z",
- "updatedDate": "2024-03-20T10:30:00.000Z"
- }
+ "updatedDate": "2024-03-20T10:30:00.000Z",
+ },
}
async def test_api():
"""Test the API endpoints with authentication"""
-
+
headers_with_auth = {
"Content-Type": "application/json",
- "Authorization": f"Bearer {TEST_API_KEY}"
+ "Authorization": f"Bearer {TEST_API_KEY}",
}
-
+
admin_headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {ADMIN_API_KEY}"
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {ADMIN_API_KEY}",
}
-
+
async with aiohttp.ClientSession() as session:
# Test health endpoint (no auth required)
print("1. Testing health endpoint (no auth)...")
@@ -89,7 +87,7 @@ async def test_api():
print(f" โ
Health check: {response.status} - {result.get('status')}")
except Exception as e:
print(f" โ Health check failed: {e}")
-
+
# Test root endpoint (no auth required)
print("\n2. Testing root endpoint (no auth)...")
try:
@@ -98,87 +96,94 @@ async def test_api():
print(f" โ
Root: {response.status} - {result.get('message')}")
except Exception as e:
print(f" โ Root endpoint failed: {e}")
-
+
# Test webhook endpoint without auth (should fail)
print("\n3. Testing webhook endpoint WITHOUT auth (should fail)...")
try:
async with session.post(
f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA,
- headers={"Content-Type": "application/json"}
+ headers={"Content-Type": "application/json"},
) as response:
result = await response.json()
if response.status == 401:
- print(f" โ
Correctly rejected: {response.status} - {result.get('detail')}")
+ print(
+ f" โ
Correctly rejected: {response.status} - {result.get('detail')}"
+ )
else:
print(f" โ Unexpected response: {response.status} - {result}")
except Exception as e:
print(f" โ Test failed: {e}")
-
+
# Test webhook endpoint with valid auth
print("\n4. Testing webhook endpoint WITH valid auth...")
try:
async with session.post(
f"{BASE_URL}/api/webhook/wix-form",
json=SAMPLE_WIX_DATA,
- headers=headers_with_auth
+ headers=headers_with_auth,
) as response:
result = await response.json()
if response.status == 200:
- print(f" โ
Webhook success: {response.status} - {result.get('status')}")
+ print(
+ f" โ
Webhook success: {response.status} - {result.get('status')}"
+ )
else:
print(f" โ Webhook failed: {response.status} - {result}")
except Exception as e:
print(f" โ Webhook test failed: {e}")
-
+
# Test test endpoint with auth
print("\n5. Testing simple test endpoint WITH auth...")
try:
async with session.post(
f"{BASE_URL}/api/webhook/wix-form/test",
json={"test": "data", "timestamp": datetime.now().isoformat()},
- headers=headers_with_auth
+ headers=headers_with_auth,
) as response:
result = await response.json()
if response.status == 200:
- print(f" โ
Test endpoint: {response.status} - {result.get('status')}")
+ print(
+ f" โ
Test endpoint: {response.status} - {result.get('status')}"
+ )
else:
print(f" โ Test endpoint failed: {response.status} - {result}")
except Exception as e:
print(f" โ Test endpoint failed: {e}")
-
+
# Test rate limiting by making multiple rapid requests
print("\n6. Testing rate limiting (making 5 rapid requests)...")
rate_limit_test_count = 0
for i in range(5):
try:
- async with session.get(
- f"{BASE_URL}/api/health"
- ) as response:
+ async with session.get(f"{BASE_URL}/api/health") as response:
if response.status == 200:
rate_limit_test_count += 1
elif response.status == 429:
- print(f" โ
Rate limit triggered on request {i+1}")
+ print(f" โ
Rate limit triggered on request {i + 1}")
break
except Exception as e:
print(f" โ Rate limit test failed: {e}")
break
-
+
if rate_limit_test_count == 5:
print(" โน๏ธ No rate limit reached (normal for low request volume)")
-
+
# Test admin endpoint (if admin key is configured)
print("\n7. Testing admin stats endpoint...")
try:
async with session.get(
- f"{BASE_URL}/api/admin/stats",
- headers=admin_headers
+ f"{BASE_URL}/api/admin/stats", headers=admin_headers
) as response:
result = await response.json()
if response.status == 200:
- print(f" โ
Admin stats: {response.status} - {result.get('status')}")
+ print(
+ f" โ
Admin stats: {response.status} - {result.get('status')}"
+ )
elif response.status == 401:
- print(f" โ ๏ธ Admin access denied (API key not configured): {result.get('detail')}")
+ print(
+ f" โ ๏ธ Admin access denied (API key not configured): {result.get('detail')}"
+ )
else:
print(f" โ Admin endpoint failed: {response.status} - {result}")
except Exception as e:
@@ -189,12 +194,18 @@ if __name__ == "__main__":
print("๐ Testing Secure Wix Form Handler API...")
print("=" * 60)
print("๐ API URL:", BASE_URL)
- print("๐ Using API Key:", TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY)
- print("๐ Using Admin Key:", ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY)
+ print(
+ "๐ Using API Key:",
+ TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY,
+ )
+ print(
+ "๐ Using Admin Key:",
+ ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY,
+ )
print("=" * 60)
print("Make sure the API is running with: python3 run_api.py")
print("-" * 60)
-
+
try:
asyncio.run(test_api())
print("\n" + "=" * 60)
@@ -207,4 +218,4 @@ if __name__ == "__main__":
print("3. Add Authorization header: Bearer your_api_key")
except Exception as e:
print(f"\nโ Error testing API: {e}")
- print("Make sure the API server is running!")
\ No newline at end of file
+ print("Make sure the API server is running!")
diff --git a/src/alpine_bits_python/simplified_access.py b/src/alpine_bits_python/simplified_access.py
index 4b26d42..b0ecd28 100644
--- a/src/alpine_bits_python/simplified_access.py
+++ b/src/alpine_bits_python/simplified_access.py
@@ -15,15 +15,26 @@ NotifHotelReservationId = OtaHotelResNotifRq.HotelReservations.HotelReservation.
RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId
# Define type aliases for Comments types
-NotifComments = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments
-RetrieveComments = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments
-NotifComment = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment
-RetrieveComment = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment
+NotifComments = (
+ OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments
+)
+RetrieveComments = (
+ OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments
+)
+NotifComment = (
+ OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment
+)
+RetrieveComment = (
+ OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment
+)
# type aliases for GuestCounts
-NotifGuestCounts = OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts
-RetrieveGuestCounts = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts
-
+NotifGuestCounts = (
+ OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts
+)
+RetrieveGuestCounts = (
+ OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts
+)
# phonetechtype enum 1,3,5 voice, fax, mobile
@@ -36,12 +47,13 @@ class PhoneTechType(Enum):
# Enum to specify which OTA message type to use
class OtaMessageType(Enum):
NOTIF = "notification" # For OtaHotelResNotifRq
- RETRIEVE = "retrieve" # For OtaResRetrieveRs
+ RETRIEVE = "retrieve" # For OtaResRetrieveRs
@dataclass
class KidsAgeData:
"""Data class to hold information about children's ages."""
+
ages: list[int]
@@ -77,9 +89,10 @@ class CustomerData:
class GuestCountsFactory:
-
@staticmethod
- def create_notif_guest_counts(adults: int, kids: Optional[list[int]] = None) -> NotifGuestCounts:
+ def create_notif_guest_counts(
+ adults: int, kids: Optional[list[int]] = None
+ ) -> NotifGuestCounts:
"""
Create a GuestCounts object for OtaHotelResNotifRq.
:param adults: Number of adults
@@ -89,18 +102,23 @@ class GuestCountsFactory:
return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts)
@staticmethod
- def create_retrieve_guest_counts(adults: int, kids: Optional[list[int]] = None) -> RetrieveGuestCounts:
+ def create_retrieve_guest_counts(
+ adults: int, kids: Optional[list[int]] = None
+ ) -> RetrieveGuestCounts:
"""
Create a GuestCounts object for OtaResRetrieveRs.
:param adults: Number of adults
:param kids: List of ages for each kid (optional)
:return: GuestCounts instance
"""
- return GuestCountsFactory._create_guest_counts(adults, kids, RetrieveGuestCounts)
+ return GuestCountsFactory._create_guest_counts(
+ adults, kids, RetrieveGuestCounts
+ )
-
@staticmethod
- def _create_guest_counts(adults: int, kids: Optional[list[int]], guest_counts_class: type) -> Any:
+ def _create_guest_counts(
+ adults: int, kids: Optional[list[int]], guest_counts_class: type
+ ) -> Any:
"""
Internal method to create a GuestCounts object of the specified type.
:param adults: Number of adults
@@ -356,9 +374,10 @@ class HotelReservationIdFactory:
)
-@dataclass
+@dataclass
class CommentListItemData:
"""Simple data class to hold comment list item information."""
+
value: str # The text content of the list item
list_item: str # Numeric identifier (pattern: [0-9]+)
language: str # Two-letter language code (pattern: [a-z][a-z])
@@ -367,6 +386,7 @@ class CommentListItemData:
@dataclass
class CommentData:
"""Simple data class to hold comment information without nested type constraints."""
+
name: CommentName2 # Required: "included services", "customer comment", "additional info"
text: Optional[str] = None # Optional text content
list_items: list[CommentListItemData] = None # Optional list items
@@ -379,6 +399,7 @@ class CommentData:
@dataclass
class CommentsData:
"""Simple data class to hold multiple comments (1-3 max)."""
+
comments: list[CommentData] = None # 1-3 comments maximum
def __post_init__(self):
@@ -388,21 +409,23 @@ class CommentsData:
class CommentFactory:
"""Factory class to create Comment instances for both OtaHotelResNotifRq and OtaResRetrieveRs."""
-
+
@staticmethod
def create_notif_comments(data: CommentsData) -> NotifComments:
"""Create Comments for OtaHotelResNotifRq."""
return CommentFactory._create_comments(NotifComments, NotifComment, data)
-
+
@staticmethod
def create_retrieve_comments(data: CommentsData) -> RetrieveComments:
"""Create Comments for OtaResRetrieveRs."""
return CommentFactory._create_comments(RetrieveComments, RetrieveComment, data)
-
+
@staticmethod
- def _create_comments(comments_class: type, comment_class: type, data: CommentsData) -> Any:
+ def _create_comments(
+ comments_class: type, comment_class: type, data: CommentsData
+ ) -> Any:
"""Internal method to create comments of the specified type."""
-
+
comments_list = []
for comment_data in data.comments:
# Create list items
@@ -411,55 +434,53 @@ class CommentFactory:
list_item = comment_class.ListItem(
value=item_data.value,
list_item=item_data.list_item,
- language=item_data.language
+ language=item_data.language,
)
list_items.append(list_item)
-
+
# Create comment
comment = comment_class(
- name=comment_data.name,
- text=comment_data.text,
- list_item=list_items
+ name=comment_data.name, text=comment_data.text, list_item=list_items
)
comments_list.append(comment)
-
+
# Create comments container
return comments_class(comment=comments_list)
-
+
@staticmethod
def from_notif_comments(comments: NotifComments) -> CommentsData:
"""Convert NotifComments back to CommentsData."""
return CommentFactory._comments_to_data(comments)
-
+
@staticmethod
def from_retrieve_comments(comments: RetrieveComments) -> CommentsData:
"""Convert RetrieveComments back to CommentsData."""
return CommentFactory._comments_to_data(comments)
-
+
@staticmethod
def _comments_to_data(comments: Any) -> CommentsData:
"""Internal method to convert any comments type to CommentsData."""
-
+
comments_data_list = []
for comment in comments.comment:
# Extract list items
list_items_data = []
if comment.list_item:
for list_item in comment.list_item:
- list_items_data.append(CommentListItemData(
- value=list_item.value,
- list_item=list_item.list_item,
- language=list_item.language
- ))
-
+ list_items_data.append(
+ CommentListItemData(
+ value=list_item.value,
+ list_item=list_item.list_item,
+ language=list_item.language,
+ )
+ )
+
# Extract comment data
comment_data = CommentData(
- name=comment.name,
- text=comment.text,
- list_items=list_items_data
+ name=comment.name, text=comment.text, list_items=list_items_data
)
comments_data_list.append(comment_data)
-
+
return CommentsData(comments=comments_data_list)
@@ -529,16 +550,19 @@ class ResGuestFactory:
class AlpineBitsFactory:
"""Unified factory class for creating AlpineBits objects with a simple interface."""
-
+
@staticmethod
- def create(data: Union[CustomerData, HotelReservationIdData, CommentsData], message_type: OtaMessageType) -> Any:
+ def create(
+ data: Union[CustomerData, HotelReservationIdData, CommentsData],
+ message_type: OtaMessageType,
+ ) -> Any:
"""
Create an AlpineBits object based on the data type and message type.
-
+
Args:
data: The data object (CustomerData, HotelReservationIdData, CommentsData, etc.)
message_type: Whether to create for NOTIF or RETRIEVE message types
-
+
Returns:
The appropriate AlpineBits object based on the data type and message type
"""
@@ -547,31 +571,35 @@ class AlpineBitsFactory:
return CustomerFactory.create_notif_customer(data)
else:
return CustomerFactory.create_retrieve_customer(data)
-
+
elif isinstance(data, HotelReservationIdData):
if message_type == OtaMessageType.NOTIF:
return HotelReservationIdFactory.create_notif_hotel_reservation_id(data)
else:
- return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(data)
-
+ return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
+ data
+ )
+
elif isinstance(data, CommentsData):
if message_type == OtaMessageType.NOTIF:
return CommentFactory.create_notif_comments(data)
else:
return CommentFactory.create_retrieve_comments(data)
-
+
else:
raise ValueError(f"Unsupported data type: {type(data)}")
-
+
@staticmethod
- def create_res_guests(customer_data: CustomerData, message_type: OtaMessageType) -> Union[NotifResGuests, RetrieveResGuests]:
+ def create_res_guests(
+ customer_data: CustomerData, message_type: OtaMessageType
+ ) -> Union[NotifResGuests, RetrieveResGuests]:
"""
Create a complete ResGuests structure with a primary customer.
-
+
Args:
customer_data: The customer data
message_type: Whether to create for NOTIF or RETRIEVE message types
-
+
Returns:
The appropriate ResGuests object
"""
@@ -579,43 +607,45 @@ class AlpineBitsFactory:
return ResGuestFactory.create_notif_res_guests(customer_data)
else:
return ResGuestFactory.create_retrieve_res_guests(customer_data)
-
+
@staticmethod
- def extract_data(obj: Any) -> Union[CustomerData, HotelReservationIdData, CommentsData]:
+ def extract_data(
+ obj: Any,
+ ) -> Union[CustomerData, HotelReservationIdData, CommentsData]:
"""
Extract data from an AlpineBits object back to a simple data class.
-
+
Args:
obj: The AlpineBits object to extract data from
-
+
Returns:
The appropriate data object
"""
# Check if it's a Customer object
- if hasattr(obj, 'person_name') and hasattr(obj.person_name, 'given_name'):
+ if hasattr(obj, "person_name") and hasattr(obj.person_name, "given_name"):
if isinstance(obj, NotifCustomer):
return CustomerFactory.from_notif_customer(obj)
elif isinstance(obj, RetrieveCustomer):
return CustomerFactory.from_retrieve_customer(obj)
-
+
# Check if it's a HotelReservationId object
- elif hasattr(obj, 'res_id_type'):
+ elif hasattr(obj, "res_id_type"):
if isinstance(obj, NotifHotelReservationId):
return HotelReservationIdFactory.from_notif_hotel_reservation_id(obj)
elif isinstance(obj, RetrieveHotelReservationId):
return HotelReservationIdFactory.from_retrieve_hotel_reservation_id(obj)
-
+
# Check if it's a Comments object
- elif hasattr(obj, 'comment'):
+ elif hasattr(obj, "comment"):
if isinstance(obj, NotifComments):
return CommentFactory.from_notif_comments(obj)
elif isinstance(obj, RetrieveComments):
return CommentFactory.from_retrieve_comments(obj)
-
+
# Check if it's a ResGuests object
- elif hasattr(obj, 'res_guest'):
+ elif hasattr(obj, "res_guest"):
return ResGuestFactory.extract_primary_customer(obj)
-
+
else:
raise ValueError(f"Unsupported object type: {type(obj)}")
@@ -733,70 +763,74 @@ if __name__ == "__main__":
# Verify roundtrip conversion
print("Roundtrip conversion successful:", customer_data == extracted_data)
-
+
print("\n--- Unified AlpineBitsFactory Examples ---")
-
+
# Much simpler approach - single factory with enum parameter!
print("=== Customer Creation ===")
notif_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF)
retrieve_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE)
print("Created customers using unified factory")
-
+
print("=== HotelReservationId Creation ===")
reservation_id_data = HotelReservationIdData(
- res_id_type="123",
- res_id_value="RESERVATION-456",
- res_id_source="HOTEL_SYSTEM"
+ res_id_type="123", res_id_value="RESERVATION-456", res_id_source="HOTEL_SYSTEM"
)
notif_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.NOTIF)
- retrieve_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.RETRIEVE)
+ retrieve_res_id = AlpineBitsFactory.create(
+ reservation_id_data, OtaMessageType.RETRIEVE
+ )
print("Created reservation IDs using unified factory")
-
+
print("=== Comments Creation ===")
- comments_data = CommentsData(comments=[
- CommentData(
- name=CommentName2.CUSTOMER_COMMENT,
- text="This is a customer comment about the reservation",
- list_items=[
- CommentListItemData(
- value="Special dietary requirements: vegetarian",
- list_item="1",
- language="en"
- ),
- CommentListItemData(
- value="Late arrival expected",
- list_item="2",
- language="en"
- )
- ]
- ),
- CommentData(
- name=CommentName2.ADDITIONAL_INFO,
- text="Additional information about the stay"
- )
- ])
+ comments_data = CommentsData(
+ comments=[
+ CommentData(
+ name=CommentName2.CUSTOMER_COMMENT,
+ text="This is a customer comment about the reservation",
+ list_items=[
+ CommentListItemData(
+ value="Special dietary requirements: vegetarian",
+ list_item="1",
+ language="en",
+ ),
+ CommentListItemData(
+ value="Late arrival expected", list_item="2", language="en"
+ ),
+ ],
+ ),
+ CommentData(
+ name=CommentName2.ADDITIONAL_INFO,
+ text="Additional information about the stay",
+ ),
+ ]
+ )
notif_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.NOTIF)
retrieve_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.RETRIEVE)
print("Created comments using unified factory")
-
+
print("=== ResGuests Creation ===")
- notif_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF)
- retrieve_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE)
+ notif_res_guests = AlpineBitsFactory.create_res_guests(
+ customer_data, OtaMessageType.NOTIF
+ )
+ retrieve_res_guests = AlpineBitsFactory.create_res_guests(
+ customer_data, OtaMessageType.RETRIEVE
+ )
print("Created ResGuests using unified factory")
-
+
print("=== Data Extraction ===")
# Extract data back using unified interface
extracted_customer_data = AlpineBitsFactory.extract_data(notif_customer)
extracted_res_id_data = AlpineBitsFactory.extract_data(notif_res_id)
extracted_comments_data = AlpineBitsFactory.extract_data(retrieve_comments)
extracted_from_res_guests = AlpineBitsFactory.extract_data(retrieve_res_guests)
-
+
print("Data extraction successful:")
print("- Customer roundtrip:", customer_data == extracted_customer_data)
print("- ReservationId roundtrip:", reservation_id_data == extracted_res_id_data)
print("- Comments roundtrip:", comments_data == extracted_comments_data)
print("- ResGuests roundtrip:", customer_data == extracted_from_res_guests)
-
+
print("\n--- Comparison with old approach ---")
print("Old way required multiple imports and knowing specific factory methods")
print("New way: single import, single factory, enum parameter to specify type!")
diff --git a/src/alpine_bits_python/util/__init__.py b/src/alpine_bits_python/util/__init__.py
index c86dd84..7eff50a 100644
--- a/src/alpine_bits_python/util/__init__.py
+++ b/src/alpine_bits_python/util/__init__.py
@@ -1 +1 @@
-"""Utility functions for alpine_bits_python."""
\ No newline at end of file
+"""Utility functions for alpine_bits_python."""
diff --git a/src/alpine_bits_python/util/__main__.py b/src/alpine_bits_python/util/__main__.py
index beb47d9..16f9496 100644
--- a/src/alpine_bits_python/util/__main__.py
+++ b/src/alpine_bits_python/util/__main__.py
@@ -1,5 +1,6 @@
"""Entry point for util package."""
+
from .handshake_util import main
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/src/alpine_bits_python/util/handshake_util.py b/src/alpine_bits_python/util/handshake_util.py
index f82cbae..74409dd 100644
--- a/src/alpine_bits_python/util/handshake_util.py
+++ b/src/alpine_bits_python/util/handshake_util.py
@@ -2,26 +2,22 @@ from ..generated.alpinebits import OtaPingRq, OtaPingRs
from xsdata_pydantic.bindings import XmlParser
-
-
def main():
# test parsing a ping request sample
- path = "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml"
+ path = (
+ "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml"
+ )
- with open(
- path, "r", encoding="utf-8") as f:
+ with open(path, "r", encoding="utf-8") as f:
xml = f.read()
# Parse the XML into the request object
- # Test parsing back
-
+ # Test parsing back
parser = XmlParser()
-
-
parsed_result = parser.from_string(xml, OtaPingRs)
print(parsed_result.echo_data)
@@ -34,19 +30,14 @@ def main():
print(warning.content[0])
-
-
-
# save json in echo_data to file with indents
output_path = "echo_data_response.json"
with open(output_path, "w", encoding="utf-8") as out_f:
import json
+
json.dump(json.loads(parsed_result.echo_data), out_f, indent=4)
print(f"Saved echo_data json to {output_path}")
-
if __name__ == "__main__":
-
-
- main()
\ No newline at end of file
+ main()
diff --git a/start_api.py b/start_api.py
index 2e84fd3..fb51cd2 100644
--- a/start_api.py
+++ b/start_api.py
@@ -2,12 +2,13 @@
"""
Convenience launcher for the Wix Form Handler API
"""
+
import os
import subprocess
# Change to src directory
-src_dir = os.path.join(os.path.dirname(__file__), 'src/alpine_bits_python')
+src_dir = os.path.join(os.path.dirname(__file__), "src/alpine_bits_python")
# Run the API using uv
if __name__ == "__main__":
- subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")])
\ No newline at end of file
+ subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")])
diff --git a/test/test_discovery.py b/test/test_discovery.py
index 28b347a..2eda866 100644
--- a/test/test_discovery.py
+++ b/test/test_discovery.py
@@ -5,57 +5,63 @@ discovers implemented vs unimplemented actions.
"""
from alpine_bits_python.alpinebits_server import (
- ServerCapabilities,
- AlpineBitsAction,
- AlpineBitsActionName,
- Version,
+ ServerCapabilities,
+ AlpineBitsAction,
+ AlpineBitsActionName,
+ Version,
AlpineBitsResponse,
- HttpStatusCode
+ HttpStatusCode,
)
import asyncio
+
class NewImplementedAction(AlpineBitsAction):
"""A new action that IS implemented."""
-
+
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_INFO_INFO
self.version = Version.V2024_10
-
- async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse:
+
+ async def handle(
+ self, action: str, request_xml: str, version: Version
+ ) -> AlpineBitsResponse:
"""This action is implemented."""
return AlpineBitsResponse("Implemented!", HttpStatusCode.OK)
+
class NewUnimplementedAction(AlpineBitsAction):
"""A new action that is NOT implemented (no handle override)."""
-
+
def __init__(self):
self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO
self.version = Version.V2024_10
-
+
# Notice: No handle method override - will use default "not implemented"
+
async def main():
print("๐ Testing Action Discovery Logic")
print("=" * 50)
-
+
# Create capabilities and see what gets discovered
capabilities = ServerCapabilities()
-
+
print("๐ Actions found by discovery:")
for action_name in capabilities.get_supported_actions():
print(f" โ
{action_name}")
-
+
print(f"\n๐ Total discovered: {len(capabilities.get_supported_actions())}")
-
+
# Test the new implemented action
implemented_action = NewImplementedAction()
result = await implemented_action.handle("test", "", Version.V2024_10)
print(f"\n๐ข NewImplementedAction result: {result.xml_content}")
-
+
# Test the unimplemented action (should use default behavior)
unimplemented_action = NewUnimplementedAction()
result = await unimplemented_action.handle("test", "", Version.V2024_10)
print(f"๐ด NewUnimplementedAction result: {result.xml_content}")
+
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/test/test_simplified_access.py b/test/test_simplified_access.py
index d202337..6b1c96a 100644
--- a/test/test_simplified_access.py
+++ b/test/test_simplified_access.py
@@ -4,11 +4,11 @@ import sys
import os
# Add the src directory to the path so we can import our modules
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from simplified_access import (
- CustomerData,
- CustomerFactory,
+ CustomerData,
+ CustomerFactory,
ResGuestFactory,
HotelReservationIdData,
HotelReservationIdFactory,
@@ -20,7 +20,7 @@ from simplified_access import (
NotifResGuests,
RetrieveResGuests,
NotifHotelReservationId,
- RetrieveHotelReservationId
+ RetrieveHotelReservationId,
)
@@ -35,7 +35,7 @@ def sample_customer_data():
phone_numbers=[
("+1234567890", PhoneTechType.MOBILE),
("+0987654321", PhoneTechType.VOICE),
- ("+1111111111", None)
+ ("+1111111111", None),
],
email_address="john.doe@example.com",
email_newsletter=True,
@@ -46,17 +46,14 @@ def sample_customer_data():
address_catalog=False,
gender="Male",
birth_date="1980-01-01",
- language="en"
+ language="en",
)
@pytest.fixture
def minimal_customer_data():
"""Fixture providing minimal customer data (only required fields)."""
- return CustomerData(
- given_name="Jane",
- surname="Smith"
- )
+ return CustomerData(given_name="Jane", surname="Smith")
@pytest.fixture
@@ -66,21 +63,19 @@ def sample_hotel_reservation_id_data():
res_id_type="123",
res_id_value="RESERVATION-456",
res_id_source="HOTEL_SYSTEM",
- res_id_source_context="BOOKING_ENGINE"
+ res_id_source_context="BOOKING_ENGINE",
)
@pytest.fixture
def minimal_hotel_reservation_id_data():
"""Fixture providing minimal hotel reservation ID data (only required fields)."""
- return HotelReservationIdData(
- res_id_type="999"
- )
+ return HotelReservationIdData(res_id_type="999")
class TestCustomerData:
"""Test the CustomerData dataclass."""
-
+
def test_customer_data_creation_full(self, sample_customer_data):
"""Test creating CustomerData with all fields."""
assert sample_customer_data.given_name == "John"
@@ -89,7 +84,7 @@ class TestCustomerData:
assert sample_customer_data.email_address == "john.doe@example.com"
assert sample_customer_data.email_newsletter is True
assert len(sample_customer_data.phone_numbers) == 3
-
+
def test_customer_data_creation_minimal(self, minimal_customer_data):
"""Test creating CustomerData with only required fields."""
assert minimal_customer_data.given_name == "Jane"
@@ -97,7 +92,7 @@ class TestCustomerData:
assert minimal_customer_data.phone_numbers == []
assert minimal_customer_data.email_address is None
assert minimal_customer_data.address_line is None
-
+
def test_phone_numbers_default_initialization(self):
"""Test that phone_numbers gets initialized to empty list."""
customer_data = CustomerData(given_name="Test", surname="User")
@@ -106,54 +101,56 @@ class TestCustomerData:
class TestCustomerFactory:
"""Test the CustomerFactory class."""
-
+
def test_create_notif_customer_full(self, sample_customer_data):
"""Test creating a NotifCustomer with full data."""
customer = CustomerFactory.create_notif_customer(sample_customer_data)
-
+
assert isinstance(customer, NotifCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
assert customer.person_name.name_prefix == "Mr."
assert customer.person_name.name_title == "Jr."
-
+
# Check telephone
assert len(customer.telephone) == 3
assert customer.telephone[0].phone_number == "+1234567890"
assert customer.telephone[0].phone_tech_type == "5" # MOBILE
assert customer.telephone[1].phone_tech_type == "1" # VOICE
assert customer.telephone[2].phone_tech_type is None
-
+
# Check email
assert customer.email.value == "john.doe@example.com"
assert customer.email.remark == "newsletter:yes"
-
+
# Check address
assert customer.address.address_line == "123 Main Street"
assert customer.address.city_name == "Anytown"
assert customer.address.postal_code == "12345"
assert customer.address.country_name.code == "US"
assert customer.address.remark == "catalog:no"
-
+
# Check other attributes
assert customer.gender == "Male"
assert customer.birth_date == "1980-01-01"
assert customer.language == "en"
-
+
def test_create_retrieve_customer_full(self, sample_customer_data):
"""Test creating a RetrieveCustomer with full data."""
customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
-
+
assert isinstance(customer, RetrieveCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
# Same structure as NotifCustomer, so we don't need to test all fields again
-
+
def test_create_customer_minimal(self, minimal_customer_data):
"""Test creating customers with minimal data."""
notif_customer = CustomerFactory.create_notif_customer(minimal_customer_data)
- retrieve_customer = CustomerFactory.create_retrieve_customer(minimal_customer_data)
-
+ retrieve_customer = CustomerFactory.create_retrieve_customer(
+ minimal_customer_data
+ )
+
for customer in [notif_customer, retrieve_customer]:
assert customer.person_name.given_name == "Jane"
assert customer.person_name.surname == "Smith"
@@ -165,73 +162,97 @@ class TestCustomerFactory:
assert customer.gender is None
assert customer.birth_date is None
assert customer.language is None
-
+
def test_email_newsletter_options(self):
"""Test different email newsletter options."""
# Newsletter yes
- data_yes = CustomerData(given_name="Test", surname="User",
- email_address="test@example.com", email_newsletter=True)
+ data_yes = CustomerData(
+ given_name="Test",
+ surname="User",
+ email_address="test@example.com",
+ email_newsletter=True,
+ )
customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.email.remark == "newsletter:yes"
-
+
# Newsletter no
- data_no = CustomerData(given_name="Test", surname="User",
- email_address="test@example.com", email_newsletter=False)
+ data_no = CustomerData(
+ given_name="Test",
+ surname="User",
+ email_address="test@example.com",
+ email_newsletter=False,
+ )
customer = CustomerFactory.create_notif_customer(data_no)
assert customer.email.remark == "newsletter:no"
-
+
# Newsletter not specified
- data_none = CustomerData(given_name="Test", surname="User",
- email_address="test@example.com", email_newsletter=None)
+ data_none = CustomerData(
+ given_name="Test",
+ surname="User",
+ email_address="test@example.com",
+ email_newsletter=None,
+ )
customer = CustomerFactory.create_notif_customer(data_none)
assert customer.email.remark is None
-
+
def test_address_catalog_options(self):
"""Test different address catalog options."""
# Catalog no
- data_no = CustomerData(given_name="Test", surname="User",
- address_line="123 Street", address_catalog=False)
+ data_no = CustomerData(
+ given_name="Test",
+ surname="User",
+ address_line="123 Street",
+ address_catalog=False,
+ )
customer = CustomerFactory.create_notif_customer(data_no)
assert customer.address.remark == "catalog:no"
-
+
# Catalog yes
- data_yes = CustomerData(given_name="Test", surname="User",
- address_line="123 Street", address_catalog=True)
+ data_yes = CustomerData(
+ given_name="Test",
+ surname="User",
+ address_line="123 Street",
+ address_catalog=True,
+ )
customer = CustomerFactory.create_notif_customer(data_yes)
assert customer.address.remark == "catalog:yes"
-
+
# Catalog not specified
- data_none = CustomerData(given_name="Test", surname="User",
- address_line="123 Street", address_catalog=None)
+ data_none = CustomerData(
+ given_name="Test",
+ surname="User",
+ address_line="123 Street",
+ address_catalog=None,
+ )
customer = CustomerFactory.create_notif_customer(data_none)
assert customer.address.remark is None
-
+
def test_from_notif_customer_roundtrip(self, sample_customer_data):
"""Test converting NotifCustomer back to CustomerData."""
customer = CustomerFactory.create_notif_customer(sample_customer_data)
converted_data = CustomerFactory.from_notif_customer(customer)
-
+
assert converted_data == sample_customer_data
-
+
def test_from_retrieve_customer_roundtrip(self, sample_customer_data):
"""Test converting RetrieveCustomer back to CustomerData."""
customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
converted_data = CustomerFactory.from_retrieve_customer(customer)
-
+
assert converted_data == sample_customer_data
-
+
def test_phone_tech_type_conversion(self):
"""Test that PhoneTechType enum values are properly converted."""
data = CustomerData(
- given_name="Test",
+ given_name="Test",
surname="User",
phone_numbers=[
("+1111111111", PhoneTechType.VOICE),
("+2222222222", PhoneTechType.FAX),
- ("+3333333333", PhoneTechType.MOBILE)
- ]
+ ("+3333333333", PhoneTechType.MOBILE),
+ ],
)
-
+
customer = CustomerFactory.create_notif_customer(data)
assert customer.telephone[0].phone_tech_type == "1" # VOICE
assert customer.telephone[1].phone_tech_type == "3" # FAX
@@ -240,15 +261,21 @@ class TestCustomerFactory:
class TestHotelReservationIdData:
"""Test the HotelReservationIdData dataclass."""
-
- def test_hotel_reservation_id_data_creation_full(self, sample_hotel_reservation_id_data):
+
+ def test_hotel_reservation_id_data_creation_full(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test creating HotelReservationIdData with all fields."""
assert sample_hotel_reservation_id_data.res_id_type == "123"
assert sample_hotel_reservation_id_data.res_id_value == "RESERVATION-456"
assert sample_hotel_reservation_id_data.res_id_source == "HOTEL_SYSTEM"
- assert sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE"
-
- def test_hotel_reservation_id_data_creation_minimal(self, minimal_hotel_reservation_id_data):
+ assert (
+ sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE"
+ )
+
+ def test_hotel_reservation_id_data_creation_minimal(
+ self, minimal_hotel_reservation_id_data
+ ):
"""Test creating HotelReservationIdData with only required fields."""
assert minimal_hotel_reservation_id_data.res_id_type == "999"
assert minimal_hotel_reservation_id_data.res_id_value is None
@@ -258,124 +285,158 @@ class TestHotelReservationIdData:
class TestHotelReservationIdFactory:
"""Test the HotelReservationIdFactory class."""
-
- def test_create_notif_hotel_reservation_id_full(self, sample_hotel_reservation_id_data):
+
+ def test_create_notif_hotel_reservation_id_full(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test creating a NotifHotelReservationId with full data."""
- reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data)
-
+ reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
+ sample_hotel_reservation_id_data
+ )
+
assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
-
- def test_create_retrieve_hotel_reservation_id_full(self, sample_hotel_reservation_id_data):
+
+ def test_create_retrieve_hotel_reservation_id_full(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test creating a RetrieveHotelReservationId with full data."""
- reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data)
-
+ reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
+ sample_hotel_reservation_id_data
+ )
+
assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
assert reservation_id.res_id_source == "HOTEL_SYSTEM"
assert reservation_id.res_id_source_context == "BOOKING_ENGINE"
-
- def test_create_hotel_reservation_id_minimal(self, minimal_hotel_reservation_id_data):
+
+ def test_create_hotel_reservation_id_minimal(
+ self, minimal_hotel_reservation_id_data
+ ):
"""Test creating hotel reservation IDs with minimal data."""
- notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(minimal_hotel_reservation_id_data)
- retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(minimal_hotel_reservation_id_data)
-
+ notif_reservation_id = (
+ HotelReservationIdFactory.create_notif_hotel_reservation_id(
+ minimal_hotel_reservation_id_data
+ )
+ )
+ retrieve_reservation_id = (
+ HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
+ minimal_hotel_reservation_id_data
+ )
+ )
+
for reservation_id in [notif_reservation_id, retrieve_reservation_id]:
assert reservation_id.res_id_type == "999"
assert reservation_id.res_id_value is None
assert reservation_id.res_id_source is None
assert reservation_id.res_id_source_context is None
-
- def test_from_notif_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data):
+
+ def test_from_notif_hotel_reservation_id_roundtrip(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test converting NotifHotelReservationId back to HotelReservationIdData."""
- reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data)
- converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(reservation_id)
-
+ reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(
+ sample_hotel_reservation_id_data
+ )
+ converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(
+ reservation_id
+ )
+
assert converted_data == sample_hotel_reservation_id_data
-
- def test_from_retrieve_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data):
+
+ def test_from_retrieve_hotel_reservation_id_roundtrip(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test converting RetrieveHotelReservationId back to HotelReservationIdData."""
- reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data)
- converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(reservation_id)
-
+ reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
+ sample_hotel_reservation_id_data
+ )
+ converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
+ reservation_id
+ )
+
assert converted_data == sample_hotel_reservation_id_data
class TestResGuestFactory:
"""Test the ResGuestFactory class."""
-
+
def test_create_notif_res_guests(self, sample_customer_data):
"""Test creating NotifResGuests structure."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
-
+
assert isinstance(res_guests, NotifResGuests)
-
+
# Navigate down the nested structure
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
assert customer.email.value == "john.doe@example.com"
-
+
def test_create_retrieve_res_guests(self, sample_customer_data):
"""Test creating RetrieveResGuests structure."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
-
+
assert isinstance(res_guests, RetrieveResGuests)
-
+
# Navigate down the nested structure
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
assert customer.email.value == "john.doe@example.com"
-
+
def test_create_res_guests_minimal(self, minimal_customer_data):
"""Test creating ResGuests with minimal customer data."""
- notif_res_guests = ResGuestFactory.create_notif_res_guests(minimal_customer_data)
- retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(minimal_customer_data)
-
+ notif_res_guests = ResGuestFactory.create_notif_res_guests(
+ minimal_customer_data
+ )
+ retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
+ minimal_customer_data
+ )
+
for res_guests in [notif_res_guests, retrieve_res_guests]:
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "Jane"
assert customer.person_name.surname == "Smith"
assert customer.email is None
assert customer.address is None
-
+
def test_extract_primary_customer_notif(self, sample_customer_data):
"""Test extracting primary customer from NotifResGuests."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
-
+
assert extracted_data == sample_customer_data
-
+
def test_extract_primary_customer_retrieve(self, sample_customer_data):
"""Test extracting primary customer from RetrieveResGuests."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
-
+
assert extracted_data == sample_customer_data
-
+
def test_roundtrip_conversion_notif(self, sample_customer_data):
"""Test complete roundtrip: CustomerData -> NotifResGuests -> CustomerData."""
res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
-
+
assert extracted_data == sample_customer_data
-
+
def test_roundtrip_conversion_retrieve(self, sample_customer_data):
"""Test complete roundtrip: CustomerData -> RetrieveResGuests -> CustomerData."""
res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
extracted_data = ResGuestFactory.extract_primary_customer(res_guests)
-
+
assert extracted_data == sample_customer_data
class TestPhoneTechType:
"""Test the PhoneTechType enum."""
-
+
def test_enum_values(self):
"""Test that enum values are correct."""
assert PhoneTechType.VOICE.value == "1"
@@ -385,95 +446,121 @@ class TestPhoneTechType:
class TestAlpineBitsFactory:
"""Test the unified AlpineBitsFactory class."""
-
+
def test_create_customer_notif(self, sample_customer_data):
"""Test creating customer using unified factory for NOTIF."""
customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF)
assert isinstance(customer, NotifCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
-
+
def test_create_customer_retrieve(self, sample_customer_data):
"""Test creating customer using unified factory for RETRIEVE."""
- customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE)
+ customer = AlpineBitsFactory.create(
+ sample_customer_data, OtaMessageType.RETRIEVE
+ )
assert isinstance(customer, RetrieveCustomer)
assert customer.person_name.given_name == "John"
assert customer.person_name.surname == "Doe"
-
+
def test_create_hotel_reservation_id_notif(self, sample_hotel_reservation_id_data):
"""Test creating hotel reservation ID using unified factory for NOTIF."""
- reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF)
+ reservation_id = AlpineBitsFactory.create(
+ sample_hotel_reservation_id_data, OtaMessageType.NOTIF
+ )
assert isinstance(reservation_id, NotifHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
-
- def test_create_hotel_reservation_id_retrieve(self, sample_hotel_reservation_id_data):
+
+ def test_create_hotel_reservation_id_retrieve(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test creating hotel reservation ID using unified factory for RETRIEVE."""
- reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE)
+ reservation_id = AlpineBitsFactory.create(
+ sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
+ )
assert isinstance(reservation_id, RetrieveHotelReservationId)
assert reservation_id.res_id_type == "123"
assert reservation_id.res_id_value == "RESERVATION-456"
-
+
def test_create_res_guests_notif(self, sample_customer_data):
"""Test creating ResGuests using unified factory for NOTIF."""
- res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF)
+ res_guests = AlpineBitsFactory.create_res_guests(
+ sample_customer_data, OtaMessageType.NOTIF
+ )
assert isinstance(res_guests, NotifResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
-
+
def test_create_res_guests_retrieve(self, sample_customer_data):
"""Test creating ResGuests using unified factory for RETRIEVE."""
- res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE)
+ res_guests = AlpineBitsFactory.create_res_guests(
+ sample_customer_data, OtaMessageType.RETRIEVE
+ )
assert isinstance(res_guests, RetrieveResGuests)
customer = res_guests.res_guest.profiles.profile_info.profile.customer
assert customer.person_name.given_name == "John"
-
+
def test_extract_data_from_customer(self, sample_customer_data):
"""Test extracting data from customer objects."""
# Create both types and extract data back
- notif_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF)
- retrieve_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE)
-
+ notif_customer = AlpineBitsFactory.create(
+ sample_customer_data, OtaMessageType.NOTIF
+ )
+ retrieve_customer = AlpineBitsFactory.create(
+ sample_customer_data, OtaMessageType.RETRIEVE
+ )
+
notif_extracted = AlpineBitsFactory.extract_data(notif_customer)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_customer)
-
+
assert notif_extracted == sample_customer_data
assert retrieve_extracted == sample_customer_data
-
- def test_extract_data_from_hotel_reservation_id(self, sample_hotel_reservation_id_data):
+
+ def test_extract_data_from_hotel_reservation_id(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test extracting data from hotel reservation ID objects."""
# Create both types and extract data back
- notif_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF)
- retrieve_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE)
-
+ notif_res_id = AlpineBitsFactory.create(
+ sample_hotel_reservation_id_data, OtaMessageType.NOTIF
+ )
+ retrieve_res_id = AlpineBitsFactory.create(
+ sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE
+ )
+
notif_extracted = AlpineBitsFactory.extract_data(notif_res_id)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_id)
-
+
assert notif_extracted == sample_hotel_reservation_id_data
assert retrieve_extracted == sample_hotel_reservation_id_data
-
+
def test_extract_data_from_res_guests(self, sample_customer_data):
"""Test extracting data from ResGuests objects."""
# Create both types and extract data back
- notif_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF)
- retrieve_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE)
-
+ notif_res_guests = AlpineBitsFactory.create_res_guests(
+ sample_customer_data, OtaMessageType.NOTIF
+ )
+ retrieve_res_guests = AlpineBitsFactory.create_res_guests(
+ sample_customer_data, OtaMessageType.RETRIEVE
+ )
+
notif_extracted = AlpineBitsFactory.extract_data(notif_res_guests)
retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_guests)
-
+
assert notif_extracted == sample_customer_data
assert retrieve_extracted == sample_customer_data
-
+
def test_unsupported_data_type_error(self):
"""Test that unsupported data types raise ValueError."""
with pytest.raises(ValueError, match="Unsupported data type"):
AlpineBitsFactory.create("invalid_data", OtaMessageType.NOTIF)
-
+
def test_unsupported_object_type_error(self):
"""Test that unsupported object types raise ValueError in extract_data."""
with pytest.raises(ValueError, match="Unsupported object type"):
AlpineBitsFactory.extract_data("invalid_object")
-
+
def test_complete_workflow_with_unified_factory(self):
"""Test a complete workflow using only the unified factory."""
# Original data
@@ -481,34 +568,47 @@ class TestAlpineBitsFactory:
given_name="Unified",
surname="Factory",
email_address="unified@factory.com",
- phone_numbers=[("+1234567890", PhoneTechType.MOBILE)]
+ phone_numbers=[("+1234567890", PhoneTechType.MOBILE)],
)
-
+
reservation_data = HotelReservationIdData(
- res_id_type="999",
- res_id_value="UNIFIED-TEST"
+ res_id_type="999", res_id_value="UNIFIED-TEST"
)
-
+
# Create using unified factory
customer_notif = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF)
- customer_retrieve = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE)
-
+ customer_retrieve = AlpineBitsFactory.create(
+ customer_data, OtaMessageType.RETRIEVE
+ )
+
res_id_notif = AlpineBitsFactory.create(reservation_data, OtaMessageType.NOTIF)
- res_id_retrieve = AlpineBitsFactory.create(reservation_data, OtaMessageType.RETRIEVE)
-
- res_guests_notif = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF)
- res_guests_retrieve = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE)
-
+ res_id_retrieve = AlpineBitsFactory.create(
+ reservation_data, OtaMessageType.RETRIEVE
+ )
+
+ res_guests_notif = AlpineBitsFactory.create_res_guests(
+ customer_data, OtaMessageType.NOTIF
+ )
+ res_guests_retrieve = AlpineBitsFactory.create_res_guests(
+ customer_data, OtaMessageType.RETRIEVE
+ )
+
# Extract everything back
extracted_customer_from_notif = AlpineBitsFactory.extract_data(customer_notif)
- extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(customer_retrieve)
-
+ extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(
+ customer_retrieve
+ )
+
extracted_res_id_from_notif = AlpineBitsFactory.extract_data(res_id_notif)
extracted_res_id_from_retrieve = AlpineBitsFactory.extract_data(res_id_retrieve)
-
- extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(res_guests_notif)
- extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(res_guests_retrieve)
-
+
+ extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(
+ res_guests_notif
+ )
+ extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(
+ res_guests_retrieve
+ )
+
# Verify everything matches
assert extracted_customer_from_notif == customer_data
assert extracted_customer_from_retrieve == customer_data
@@ -520,37 +620,72 @@ class TestAlpineBitsFactory:
class TestIntegration:
"""Integration tests combining both factories."""
-
+
def test_both_factories_produce_same_customer_data(self, sample_customer_data):
"""Test that both factories can work with the same customer data."""
# Create using CustomerFactory
notif_customer = CustomerFactory.create_notif_customer(sample_customer_data)
- retrieve_customer = CustomerFactory.create_retrieve_customer(sample_customer_data)
-
+ retrieve_customer = CustomerFactory.create_retrieve_customer(
+ sample_customer_data
+ )
+
# Create using ResGuestFactory and extract customers
notif_res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data)
- retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data)
-
- notif_from_res_guests = notif_res_guests.res_guest.profiles.profile_info.profile.customer
- retrieve_from_res_guests = retrieve_res_guests.res_guest.profiles.profile_info.profile.customer
-
+ retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(
+ sample_customer_data
+ )
+
+ notif_from_res_guests = (
+ notif_res_guests.res_guest.profiles.profile_info.profile.customer
+ )
+ retrieve_from_res_guests = (
+ retrieve_res_guests.res_guest.profiles.profile_info.profile.customer
+ )
+
# Compare customer names (structure should be identical)
- assert notif_customer.person_name.given_name == notif_from_res_guests.person_name.given_name
- assert notif_customer.person_name.surname == notif_from_res_guests.person_name.surname
- assert retrieve_customer.person_name.given_name == retrieve_from_res_guests.person_name.given_name
- assert retrieve_customer.person_name.surname == retrieve_from_res_guests.person_name.surname
-
- def test_hotel_reservation_id_factories_produce_same_data(self, sample_hotel_reservation_id_data):
+ assert (
+ notif_customer.person_name.given_name
+ == notif_from_res_guests.person_name.given_name
+ )
+ assert (
+ notif_customer.person_name.surname
+ == notif_from_res_guests.person_name.surname
+ )
+ assert (
+ retrieve_customer.person_name.given_name
+ == retrieve_from_res_guests.person_name.given_name
+ )
+ assert (
+ retrieve_customer.person_name.surname
+ == retrieve_from_res_guests.person_name.surname
+ )
+
+ def test_hotel_reservation_id_factories_produce_same_data(
+ self, sample_hotel_reservation_id_data
+ ):
"""Test that both HotelReservationId factories produce equivalent results."""
- notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data)
- retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data)
-
+ notif_reservation_id = (
+ HotelReservationIdFactory.create_notif_hotel_reservation_id(
+ sample_hotel_reservation_id_data
+ )
+ )
+ retrieve_reservation_id = (
+ HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
+ sample_hotel_reservation_id_data
+ )
+ )
+
# Both should have the same field values
assert notif_reservation_id.res_id_type == retrieve_reservation_id.res_id_type
assert notif_reservation_id.res_id_value == retrieve_reservation_id.res_id_value
- assert notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source
- assert notif_reservation_id.res_id_source_context == retrieve_reservation_id.res_id_source_context
-
+ assert (
+ notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source
+ )
+ assert (
+ notif_reservation_id.res_id_source_context
+ == retrieve_reservation_id.res_id_source_context
+ )
+
def test_complex_customer_workflow(self):
"""Test a complex workflow with multiple operations."""
# Create original data
@@ -559,7 +694,7 @@ class TestIntegration:
surname="Johnson",
phone_numbers=[
("+1555123456", PhoneTechType.MOBILE),
- ("+1555654321", PhoneTechType.VOICE)
+ ("+1555654321", PhoneTechType.VOICE),
],
email_address="alice.johnson@company.com",
email_newsletter=False,
@@ -569,22 +704,24 @@ class TestIntegration:
country_code="CA",
address_catalog=True,
gender="Female",
- language="fr"
+ language="fr",
)
-
+
# Create ResGuests for both types
notif_res_guests = ResGuestFactory.create_notif_res_guests(original_data)
retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(original_data)
-
+
# Extract data back from both
notif_extracted = ResGuestFactory.extract_primary_customer(notif_res_guests)
- retrieve_extracted = ResGuestFactory.extract_primary_customer(retrieve_res_guests)
-
+ retrieve_extracted = ResGuestFactory.extract_primary_customer(
+ retrieve_res_guests
+ )
+
# All should be equal
assert original_data == notif_extracted
assert original_data == retrieve_extracted
assert notif_extracted == retrieve_extracted
-
+
def test_complex_hotel_reservation_id_workflow(self):
"""Test a complex workflow with HotelReservationId operations."""
# Create original reservation ID data
@@ -592,18 +729,30 @@ class TestIntegration:
res_id_type="456",
res_id_value="COMPLEX-RESERVATION-789",
res_id_source="INTEGRATION_SYSTEM",
- res_id_source_context="API_CALL"
+ res_id_source_context="API_CALL",
)
-
+
# Create HotelReservationId for both types
- notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data)
- retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(original_data)
-
+ notif_reservation_id = (
+ HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data)
+ )
+ retrieve_reservation_id = (
+ HotelReservationIdFactory.create_retrieve_hotel_reservation_id(
+ original_data
+ )
+ )
+
# Extract data back from both
- notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(notif_reservation_id)
- retrieve_extracted = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(retrieve_reservation_id)
-
+ notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(
+ notif_reservation_id
+ )
+ retrieve_extracted = (
+ HotelReservationIdFactory.from_retrieve_hotel_reservation_id(
+ retrieve_reservation_id
+ )
+ )
+
# All should be equal
assert original_data == notif_extracted
assert original_data == retrieve_extracted
- assert notif_extracted == retrieve_extracted
\ No newline at end of file
+ assert notif_extracted == retrieve_extracted
diff --git a/test_handshake.py b/test_handshake.py
index 47ff199..00c87b0 100644
--- a/test_handshake.py
+++ b/test_handshake.py
@@ -6,24 +6,31 @@ Test the handshake functionality with the real AlpineBits sample file.
import asyncio
from alpine_bits_python.alpinebits_server import AlpineBitsServer
+
async def main():
print("๐ Testing AlpineBits Handshake with Sample File")
print("=" * 60)
-
+
# Create server instance
server = AlpineBitsServer()
-
+
# Read the sample handshake request
- with open("AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml", "r") as f:
+ with open(
+ "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml",
+ "r",
+ ) as f:
ping_request_xml = f.read()
-
+
print("๐ค Sending handshake request...")
-
+
# Handle the ping request
- response = await server.handle_request("OTA_Ping:Handshaking", ping_request_xml, "2024-10")
-
+ response = await server.handle_request(
+ "OTA_Ping:Handshaking", ping_request_xml, "2024-10"
+ )
+
print(f"\n๐ฅ Response Status: {response.status_code}")
print(f"๐ Response XML:\n{response.xml_content}")
+
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())