diff --git a/src/alpine_bits_python/__main__.py b/src/alpine_bits_python/__main__.py index 9a948c0..a9e3250 100644 --- a/src/alpine_bits_python/__main__.py +++ b/src/alpine_bits_python/__main__.py @@ -1,6 +1,7 @@ """Entry point for alpine_bits_python package.""" + from .main import main if __name__ == "__main__": print("running test main") - main() \ No newline at end of file + main() diff --git a/src/alpine_bits_python/alpinebits_server.py b/src/alpine_bits_python/alpinebits_server.py index 63e27c9..a6bcc41 100644 --- a/src/alpine_bits_python/alpinebits_server.py +++ b/src/alpine_bits_python/alpinebits_server.py @@ -23,49 +23,65 @@ from xsdata_pydantic.bindings import XmlParser class HttpStatusCode(IntEnum): """Allowed HTTP status codes for AlpineBits responses.""" + OK = 200 BAD_REQUEST = 400 UNAUTHORIZED = 401 INTERNAL_SERVER_ERROR = 500 - class AlpineBitsActionName(Enum): """Enum for AlpineBits action names with capability and request name mappings.""" - + # Format: (capability_name, actual_request_name) OTA_PING = ("action_OTA_Ping", "OTA_Ping:Handshaking") OTA_READ = ("action_OTA_Read", "OTA_Read:GuestRequests") OTA_HOTEL_AVAIL_NOTIF = ("action_OTA_HotelAvailNotif", "OTA_HotelAvailNotif") - OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = ("action_OTA_HotelResNotif_GuestRequests", - "OTA_HotelResNotif:GuestRequests") - OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = ("action_OTA_HotelDescriptiveContentNotif_Inventory", - "OTA_HotelDescriptiveContentNotif:Inventory") - OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = ("action_OTA_HotelDescriptiveContentNotif_Info", - "OTA_HotelDescriptiveContentNotif:Info") - OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = ("action_OTA_HotelDescriptiveInfo_Inventory", - "OTA_HotelDescriptiveInfo:Inventory") - OTA_HOTEL_DESCRIPTIVE_INFO_INFO = ("action_OTA_HotelDescriptiveInfo_Info", - "OTA_HotelDescriptiveInfo:Info") - OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = ("action_OTA_HotelRatePlanNotif_RatePlans", - "OTA_HotelRatePlanNotif:RatePlans") - OTA_HOTEL_RATE_PLAN_BASE_RATES = ("action_OTA_HotelRatePlan_BaseRates", - "OTA_HotelRatePlan:BaseRates") - + OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS = ( + "action_OTA_HotelResNotif_GuestRequests", + "OTA_HotelResNotif:GuestRequests", + ) + OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INVENTORY = ( + "action_OTA_HotelDescriptiveContentNotif_Inventory", + "OTA_HotelDescriptiveContentNotif:Inventory", + ) + OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO = ( + "action_OTA_HotelDescriptiveContentNotif_Info", + "OTA_HotelDescriptiveContentNotif:Info", + ) + OTA_HOTEL_DESCRIPTIVE_INFO_INVENTORY = ( + "action_OTA_HotelDescriptiveInfo_Inventory", + "OTA_HotelDescriptiveInfo:Inventory", + ) + OTA_HOTEL_DESCRIPTIVE_INFO_INFO = ( + "action_OTA_HotelDescriptiveInfo_Info", + "OTA_HotelDescriptiveInfo:Info", + ) + OTA_HOTEL_RATE_PLAN_NOTIF_RATE_PLANS = ( + "action_OTA_HotelRatePlanNotif_RatePlans", + "OTA_HotelRatePlanNotif:RatePlans", + ) + OTA_HOTEL_RATE_PLAN_BASE_RATES = ( + "action_OTA_HotelRatePlan_BaseRates", + "OTA_HotelRatePlan:BaseRates", + ) + def __init__(self, capability_name: str, request_name: str): self.capability_name = capability_name self.request_name = request_name - + @classmethod - def get_by_capability_name(cls, capability_name: str) -> Optional['AlpineBitsActionName']: + def get_by_capability_name( + cls, capability_name: str + ) -> Optional["AlpineBitsActionName"]: """Get action enum by capability name.""" for action in cls: if action.capability_name == capability_name: return action return None - + @classmethod - def get_by_request_name(cls, request_name: str) -> Optional['AlpineBitsActionName']: + def get_by_request_name(cls, request_name: str) -> Optional["AlpineBitsActionName"]: """Get action enum by request name.""" for action in cls: if action.request_name == request_name: @@ -75,22 +91,25 @@ class AlpineBitsActionName(Enum): class Version(str, Enum): """Enum for AlpineBits versions.""" + V2024_10 = "2024-10" V2022_10 = "2022-10" # Add other versions as needed - @dataclass class AlpineBitsResponse: """Response data structure for AlpineBits actions.""" + xml_content: str status_code: HttpStatusCode = HttpStatusCode.OK - + def __post_init__(self): """Validate that status code is one of the allowed values.""" if self.status_code not in [200, 400, 401, 500]: - raise ValueError(f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500") + raise ValueError( + f"Invalid status code {self.status_code}. Must be 200, 400, 401, or 500" + ) # Abstract base class for AlpineBits Action @@ -98,20 +117,24 @@ class AlpineBitsAction(ABC): """Abstract base class for handling AlpineBits actions.""" name: AlpineBitsActionName - version: Version | list[Version] # list of versions in case action supports multiple versions - - async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: + version: ( + Version | list[Version] + ) # list of versions in case action supports multiple versions + + async def handle( + self, action: str, request_xml: str, version: Version + ) -> AlpineBitsResponse: """ Handle the incoming request XML and return response XML. - + Default implementation returns "not implemented" error. Override this method in subclasses to provide actual functionality. - + Args: action: The action to perform (e.g., "OTA_PingRQ") request_xml: The XML request body as string version: The AlpineBits version - + Returns: AlpineBitsResponse with error or actual response """ @@ -121,7 +144,7 @@ class AlpineBitsAction(ABC): async def check_version_supported(self, version: Version) -> bool: """ Check if the action supports the given version. - + Args: version: The AlpineBits version to check Returns: @@ -130,103 +153,93 @@ class AlpineBitsAction(ABC): if isinstance(self.version, list): return version in self.version return version == self.version - - - - class ServerCapabilities: """ Automatically discovers AlpineBitsAction implementations and generates capabilities. """ - + def __init__(self): self.action_registry: Dict[str, Type[AlpineBitsAction]] = {} self._discover_actions() self.capability_dict = None - + def _discover_actions(self): """Discover all AlpineBitsAction implementations in the current module.""" current_module = inspect.getmodule(self) - + for name, obj in inspect.getmembers(current_module): - if (inspect.isclass(obj) and - issubclass(obj, AlpineBitsAction) and - obj != AlpineBitsAction): - + if ( + inspect.isclass(obj) + and issubclass(obj, AlpineBitsAction) + and obj != AlpineBitsAction + ): # Check if this action is actually implemented (not just returning default) if self._is_action_implemented(obj): action_instance = obj() - if hasattr(action_instance, 'name'): + if hasattr(action_instance, "name"): # Use capability name for the registry key self.action_registry[action_instance.name.capability_name] = obj - + def _is_action_implemented(self, action_class: Type[AlpineBitsAction]) -> bool: """ Check if an action is actually implemented or just uses the default behavior. This is a simple check - in practice, you might want more sophisticated detection. """ # Check if the class has overridden the handle method - if 'handle' in action_class.__dict__: + if "handle" in action_class.__dict__: return True return False - def create_capabilities_dict(self) -> None: """ Generate the capabilities dictionary based on discovered actions. - + """ versions_dict = {} - + for action_name, action_class in self.action_registry.items(): action_instance = action_class() - + # Get supported versions for this action if isinstance(action_instance.version, list): supported_versions = action_instance.version else: supported_versions = [action_instance.version] - + # Add action to each supported version for version in supported_versions: version_str = version.value - + if version_str not in versions_dict: - versions_dict[version_str] = { - "version": version_str, - "actions": [] - } - + versions_dict[version_str] = {"version": version_str, "actions": []} + action_dict = {"action": action_name} - + # Add supports field if the action has custom supports - if hasattr(action_instance, 'supports') and action_instance.supports: + if hasattr(action_instance, "supports") and action_instance.supports: action_dict["supports"] = action_instance.supports - + versions_dict[version_str]["actions"].append(action_dict) self.capability_dict = {"versions": list(versions_dict.values())} - - return None - def get_capabilities_dict(self) -> Dict: """ Get capabilities as a dictionary. Generates if not already created. """ - + if self.capability_dict is None: self.create_capabilities_dict() return self.capability_dict - + def get_capabilities_json(self) -> str: """Get capabilities as formatted JSON string.""" return json.dumps(self.get_capabilities_dict(), indent=2) - + def get_supported_actions(self) -> List[str]: """Get list of all supported action names.""" return list(self.action_registry.keys()) @@ -234,22 +247,35 @@ class ServerCapabilities: # Sample Action Implementations for demonstration + class PingAction(AlpineBitsAction): """Implementation for OTA_Ping action (handshaking).""" - + def __init__(self): self.name = AlpineBitsActionName.OTA_PING - self.version = [Version.V2024_10, Version.V2022_10] # Supports multiple versions - - async def handle(self, action: str, request_xml: str, version: Version, server_capabilities: None | ServerCapabilities = None) -> AlpineBitsResponse: + self.version = [ + Version.V2024_10, + Version.V2022_10, + ] # Supports multiple versions + + async def handle( + self, + action: str, + request_xml: str, + version: Version, + server_capabilities: None | ServerCapabilities = None, + ) -> AlpineBitsResponse: """Handle ping requests.""" if request_xml is None: - return AlpineBitsResponse(f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST) + return AlpineBitsResponse( + f"Error: Xml Request missing", HttpStatusCode.BAD_REQUEST + ) if server_capabilities is None: - return AlpineBitsResponse("Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR) - + return AlpineBitsResponse( + "Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR + ) # Parse the incoming request XML and extract EchoData parser = XmlParser() @@ -259,54 +285,66 @@ class PingAction(AlpineBitsAction): echo_data = json.loads(parsed_request.echo_data) except Exception as e: - return AlpineBitsResponse(f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST) + return AlpineBitsResponse( + f"Error: Invalid XML request", HttpStatusCode.BAD_REQUEST + ) # compare echo data with capabilities, create a dictionary containing the matching capabilities capabilities_dict = server_capabilities.get_capabilities_dict() matching_capabilities = {"versions": []} - + # Iterate through client's requested versions for client_version in echo_data.get("versions", []): client_version_str = client_version.get("version", "") - + # Find matching server version for server_version in capabilities_dict["versions"]: if server_version["version"] == client_version_str: # Found a matching version, now find common actions - matching_version = { - "version": client_version_str, - "actions": [] - } - + matching_version = {"version": client_version_str, "actions": []} + # Get client's requested actions for this version - client_actions = {action.get("action", ""): action for action in client_version.get("actions", [])} - server_actions = {action.get("action", ""): action for action in server_version.get("actions", [])} - + client_actions = { + action.get("action", ""): action + for action in client_version.get("actions", []) + } + server_actions = { + action.get("action", ""): action + for action in server_version.get("actions", []) + } + # Find common actions for action_name in client_actions: if action_name in server_actions: # Use server's action definition (includes our supports) - matching_version["actions"].append(server_actions[action_name]) - + matching_version["actions"].append( + server_actions[action_name] + ) + # Only add version if there are common actions if matching_version["actions"]: matching_capabilities["versions"].append(matching_version) break - + # Debug print to see what we matched - + # Create successful ping response with matched capabilities capabilities_json = json.dumps(matching_capabilities, indent=2) - warning = OtaPingRs.Warnings.Warning(status=WarningStatus.ALPINEBITS_HANDSHAKE.value, type_value="11", content=[capabilities_json]) + warning = OtaPingRs.Warnings.Warning( + status=WarningStatus.ALPINEBITS_HANDSHAKE.value, + type_value="11", + content=[capabilities_json], + ) warning_response = OtaPingRs.Warnings(warning=[warning]) - response_ota_ping = OtaPingRs(version= "7.000", warnings=warning_response, echo_data=capabilities_json, success="") - - - - + response_ota_ping = OtaPingRs( + version="7.000", + warnings=warning_response, + echo_data=capabilities_json, + success="", + ) config = SerializerConfig( pretty_print=True, xml_declaration=True, encoding="UTF-8" @@ -314,34 +352,35 @@ class PingAction(AlpineBitsAction): serializer = XmlSerializer(config=config) - response_xml = serializer.render(response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"}) - - - + response_xml = serializer.render( + response_ota_ping, ns_map={None: "http://www.opentravel.org/OTA/2003/05"} + ) return AlpineBitsResponse(response_xml, HttpStatusCode.OK) class ReadAction(AlpineBitsAction): """Implementation for OTA_Read action.""" - + def __init__(self): self.name = AlpineBitsActionName.OTA_READ self.version = [Version.V2024_10, Version.V2022_10] - - async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: + + async def handle( + self, action: str, request_xml: str, version: Version + ) -> AlpineBitsResponse: """Handle read requests.""" - response_xml = f''' + response_xml = f""" Read operation successful for {version.value} -''' +""" return AlpineBitsResponse(response_xml, HttpStatusCode.OK) class HotelAvailNotifAction(AlpineBitsAction): """Implementation for Hotel Availability Notification action with supports.""" - + def __init__(self): self.name = AlpineBitsActionName.OTA_HOTEL_AVAIL_NOTIF self.version = Version.V2022_10 @@ -349,68 +388,68 @@ class HotelAvailNotifAction(AlpineBitsAction): "OTA_HotelAvailNotif_accept_rooms", "OTA_HotelAvailNotif_accept_categories", "OTA_HotelAvailNotif_accept_deltas", - "OTA_HotelAvailNotif_accept_BookingThreshold" + "OTA_HotelAvailNotif_accept_BookingThreshold", ] - - async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: + + async def handle( + self, action: str, request_xml: str, version: Version + ) -> AlpineBitsResponse: """Handle hotel availability notifications.""" - response_xml = ''' + response_xml = """ -''' +""" return AlpineBitsResponse(response_xml, HttpStatusCode.OK) class GuestRequestsAction(AlpineBitsAction): """Unimplemented action - will not appear in capabilities.""" - + def __init__(self): self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS self.version = Version.V2024_10 - + # Note: This class doesn't override the handle method, so it won't be discovered - - - - class AlpineBitsServer: """ Asynchronous AlpineBits server for handling hotel data exchange requests. - + This server handles various OTA actions and implements the AlpineBits protocol for hotel data exchange. It maintains a registry of supported actions and their capabilities, and can respond to handshake requests with its capabilities. """ - + def __init__(self): self.capabilities = ServerCapabilities() self._action_instances = {} self._initialize_action_instances() - + def _initialize_action_instances(self): """Initialize instances of all discovered action classes.""" for capability_name, action_class in self.capabilities.action_registry.items(): self._action_instances[capability_name] = action_class() - + def get_capabilities(self) -> Dict: """Get server capabilities.""" return self.capabilities.get_capabilities_dict() - + def get_capabilities_json(self) -> str: """Get server capabilities as JSON.""" return self.capabilities.get_capabilities_json() - - async def handle_request(self, request_action_name: str, request_xml: str, version: str = "2024-10") -> AlpineBitsResponse: + + async def handle_request( + self, request_action_name: str, request_xml: str, version: str = "2024-10" + ) -> AlpineBitsResponse: """ Handle an incoming AlpineBits request by routing to appropriate action handler. - + Args: request_action_name: The action name from the request (e.g., "OTA_Read:GuestRequests") request_xml: The XML request body version: The AlpineBits version (defaults to "2024-10") - + Returns: AlpineBitsResponse with the result """ @@ -419,52 +458,56 @@ class AlpineBitsServer: version_enum = Version(version) except ValueError: return AlpineBitsResponse( - f"Error: Unsupported version {version}", - HttpStatusCode.BAD_REQUEST + f"Error: Unsupported version {version}", HttpStatusCode.BAD_REQUEST ) - + # Find the action by request name action_enum = AlpineBitsActionName.get_by_request_name(request_action_name) if not action_enum: return AlpineBitsResponse( f"Error: Unknown action {request_action_name}", - HttpStatusCode.BAD_REQUEST + HttpStatusCode.BAD_REQUEST, ) - + # Check if we have an implementation for this action capability_name = action_enum.capability_name if capability_name not in self._action_instances: return AlpineBitsResponse( f"Error: Action {request_action_name} is not implemented", - HttpStatusCode.BAD_REQUEST + HttpStatusCode.BAD_REQUEST, ) - + action_instance: AlpineBitsAction = self._action_instances[capability_name] - + # Check if the action supports the requested version if not await action_instance.check_version_supported(version_enum): return AlpineBitsResponse( f"Error: Action {request_action_name} does not support version {version}", - HttpStatusCode.BAD_REQUEST + HttpStatusCode.BAD_REQUEST, ) - + # Handle the request try: # Special case for ping action - pass server capabilities if capability_name == "action_OTA_Ping": - return await action_instance.handle(request_action_name, request_xml, version_enum, self.capabilities) + return await action_instance.handle( + request_action_name, request_xml, version_enum, self.capabilities + ) else: - return await action_instance.handle(request_action_name, request_xml, version_enum) + return await action_instance.handle( + request_action_name, request_xml, version_enum + ) except Exception as e: print(f"Error handling request {request_action_name}: {str(e)}") # print stack trace for debugging import traceback + traceback.print_exc() return AlpineBitsResponse( f"Error: Internal server error while processing {request_action_name}: {str(e)}", - HttpStatusCode.INTERNAL_SERVER_ERROR + HttpStatusCode.INTERNAL_SERVER_ERROR, ) - + def get_supported_request_names(self) -> List[str]: """Get all supported request names (not capability names).""" request_names = [] @@ -473,26 +516,28 @@ class AlpineBitsServer: if action_enum: request_names.append(action_enum.request_name) return sorted(request_names) - - def is_action_supported(self, request_action_name: str, version: str = None) -> bool: + + def is_action_supported( + self, request_action_name: str, version: str = None + ) -> bool: """ Check if a request action is supported. - + Args: request_action_name: The request action name (e.g., "OTA_Read:GuestRequests") version: Optional version to check - + Returns: True if supported, False otherwise """ action_enum = AlpineBitsActionName.get_by_request_name(request_action_name) if not action_enum: return False - + capability_name = action_enum.capability_name if capability_name not in self._action_instances: return False - + if version: try: version_enum = Version(version) @@ -504,7 +549,7 @@ class AlpineBitsServer: return action_instance.version == version_enum except ValueError: return False - + return True @@ -512,10 +557,10 @@ async def main(): """Demonstrate the automatic capabilities discovery and request handling.""" print("๐Ÿš€ AlpineBits Server Capabilities Discovery & Request Handling Demo") print("=" * 70) - + # Create server instance server = AlpineBitsServer() - + print("\n๐Ÿ“‹ Discovered Action Classes:") print("-" * 30) for capability_name, action_class in server.capabilities.action_registry.items(): @@ -523,24 +568,26 @@ async def main(): request_name = action_enum.request_name if action_enum else "unknown" print(f"โœ… {capability_name} -> {action_class.__name__}") print(f" Request name: {request_name}") - - print(f"\n๐Ÿ“Š Total Implemented Actions: {len(server.capabilities.get_supported_actions())}") - + + print( + f"\n๐Ÿ“Š Total Implemented Actions: {len(server.capabilities.get_supported_actions())}" + ) + print("\n๐Ÿ” Generated Capabilities JSON:") print("-" * 30) capabilities_json = server.get_capabilities_json() print(capabilities_json) - + print("\n๐ŸŽฏ Supported Request Names:") print("-" * 30) for request_name in server.get_supported_request_names(): print(f" โ€ข {request_name}") - + print("\n๐Ÿงช Testing Request Handling:") print("-" * 30) - + test_xml = "sample request" - + # Test different request formats test_cases = [ ("OTA_Ping:Handshaking", "2024-10"), @@ -548,16 +595,16 @@ async def main(): ("OTA_Read:GuestRequests", "2022-10"), ("OTA_HotelAvailNotif", "2024-10"), ("UnknownAction", "2024-10"), - ("OTA_Ping:Handshaking", "unsupported-version") + ("OTA_Ping:Handshaking", "unsupported-version"), ] - + for request_name, version in test_cases: print(f"\n๏ฟฝ Testing: {request_name} (v{version})") - + # Check if supported first is_supported = server.is_action_supported(request_name, version) print(f" Supported: {is_supported}") - + # Handle the request response = await server.handle_request(request_name, test_xml, version) print(f" Status: {response.status_code}") @@ -565,9 +612,9 @@ async def main(): print(f" Response: {response.xml_content[:100]}...") else: print(f" Response: {response.xml_content}") - + print("\nโœ… Demo completed successfully!") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/alpine_bits_python/api.py b/src/alpine_bits_python/api.py index 7051914..f11874c 100644 --- a/src/alpine_bits_python/api.py +++ b/src/alpine_bits_python/api.py @@ -1,18 +1,29 @@ -from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, APIRouter, Form, File, UploadFile +from fastapi import ( + FastAPI, + HTTPException, + BackgroundTasks, + Request, + Depends, + APIRouter, + Form, + File, + UploadFile, +) +from fastapi.concurrency import asynccontextmanager from fastapi.middleware.cors import CORSMiddleware from fastapi.security import HTTPBearer, HTTPBasicCredentials, HTTPBasic from .config_loader import load_config from fastapi.responses import HTMLResponse, PlainTextResponse, Response from .models import WixFormSubmission -from datetime import datetime, date, timezone +from datetime import datetime, date, timezone from .auth import validate_api_key, validate_wix_signature, generate_api_key from .rate_limit import ( - limiter, - webhook_limiter, + limiter, + webhook_limiter, custom_rate_limit_handler, DEFAULT_RATE_LIMIT, WEBHOOK_RATE_LIMIT, - BURST_RATE_LIMIT + BURST_RATE_LIMIT, ) from slowapi.errors import RateLimitExceeded import logging @@ -24,8 +35,14 @@ import gzip import xml.etree.ElementTree as ET from .alpinebits_server import AlpineBitsServer, Version import urllib.parse +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -from .db import get_async_session, Customer as DBCustomer, Reservation as DBReservation +from .db import ( + Base, + Customer as DBCustomer, + Reservation as DBReservation, + get_database_url, +) # Configure logging @@ -42,12 +59,36 @@ except Exception as e: _LOGGER.error(f"Failed to load config: {str(e)}") config = {} +@asynccontextmanager +async def lifespan(app: FastAPI): + # Setup DB + DATABASE_URL = get_database_url(config) + engine = create_async_engine(DATABASE_URL, echo=True) + AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False) + app.state.engine = engine + app.state.async_sessionmaker = AsyncSessionLocal + + # Create tables + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + _LOGGER.info("Database tables checked/created at startup.") + + yield + + # Optional: Dispose engine on shutdown + await engine.dispose() + +async def get_async_session(request: Request): + async_sessionmaker = request.app.state.async_sessionmaker + async with async_sessionmaker() as session: + yield session app = FastAPI( title="Wix Form Handler API", description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting", - version="1.0.0" + version="1.0.0", + lifespan=lifespan ) # Create API router with /api prefix @@ -62,9 +103,9 @@ app.add_middleware( CORSMiddleware, allow_origins=[ "https://*.wix.com", - "https://*.wixstatic.com", + "https://*.wixstatic.com", "http://localhost:3000", # For development - "http://localhost:8000" # For local testing + "http://localhost:8000", # For local testing ], allow_credentials=True, allow_methods=["GET", "POST"], @@ -78,27 +119,39 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None: Add your business logic here. """ try: - _LOGGER.info(f"Processing form submission: {submission_data.get('submissionId')}") - + _LOGGER.info( + f"Processing form submission: {submission_data.get('submissionId')}" + ) + # Example processing - you can replace this with your actual logic - form_name = submission_data.get('formName') - contact_email = submission_data.get('contact', {}).get('email') if submission_data.get('contact') else None - + form_name = submission_data.get("formName") + contact_email = ( + submission_data.get("contact", {}).get("email") + if submission_data.get("contact") + else None + ) + # Extract form fields - form_fields = {k: v for k, v in submission_data.items() if k.startswith('field:')} - - _LOGGER.info(f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}") - + form_fields = { + k: v for k, v in submission_data.items() if k.startswith("field:") + } + + _LOGGER.info( + f"Form: {form_name}, Contact: {contact_email}, Fields: {len(form_fields)}" + ) + # Here you could: # - Save to database # - Send emails # - Call external APIs # - Process the data further - + except Exception as e: _LOGGER.error(f"Error processing form submission: {str(e)}") + + @api_router.get("/") @limiter.limit(DEFAULT_RATE_LIMIT) async def root(request: Request): @@ -111,8 +164,8 @@ async def root(request: Request): "rate_limits": { "default": DEFAULT_RATE_LIMIT, "webhook": WEBHOOK_RATE_LIMIT, - "burst": BURST_RATE_LIMIT - } + "burst": BURST_RATE_LIMIT, + }, } @@ -126,11 +179,10 @@ async def health_check(request: Request): "service": "wix-form-handler", "version": "1.0.0", "authentication": "enabled", - "rate_limiting": "enabled" + "rate_limiting": "enabled", } - # Extracted business logic for handling Wix form submissions async def process_wix_form_submission(request: Request, data: Dict[str, Any], db): """ @@ -138,10 +190,9 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db """ timestamp = datetime.now().isoformat() - _LOGGER.info(f"Received Wix form data at {timestamp}") - #_LOGGER.info(f"Data keys: {list(data.keys())}") - #_LOGGER.info(f"Full data: {json.dumps(data, indent=2)}") + # _LOGGER.info(f"Data keys: {list(data.keys())}") + # _LOGGER.info(f"Full data: {json.dumps(data, indent=2)}") log_entry = { "timestamp": timestamp, "client_ip": request.client.host if request.client else "unknown", @@ -154,9 +205,13 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db if not os.path.exists(logs_dir): os.makedirs(logs_dir, mode=0o755, exist_ok=True) stat_info = os.stat(logs_dir) - _LOGGER.info(f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}") + _LOGGER.info( + f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}" + ) _LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}") - log_filename = f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + log_filename = ( + f"{logs_dir}/wix_test_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + ) with open(log_filename, "w", encoding="utf-8") as f: json.dump(log_entry, f, indent=2, default=str, ensure_ascii=False) file_stat = os.stat(log_filename) @@ -164,16 +219,10 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db _LOGGER.info(f"File mode: {oct(file_stat.st_mode)[-3:]}") _LOGGER.info(f"Data logged to: {log_filename}") - - data = data.get("data") # Handle nested "data" key if present - - - # save customer and reservation to DB - contact_info = data.get("contact", {}) first_name = contact_info.get("name", {}).get("first") last_name = contact_info.get("name", {}).get("last") @@ -193,10 +242,18 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db language = data.get("contact", {}).get("locale", "en")[:2] # Dates - start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value") - end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value") + start_date = ( + data.get("field:date_picker_a7c8") + or data.get("Anreisedatum") + or data.get("submissions", [{}])[1].get("value") + ) + end_date = ( + data.get("field:date_picker_7e65") + or data.get("Abreisedatum") + or data.get("submissions", [{}])[2].get("value") + ) - # Room/guest info + # Room/guest info num_adults = int(data.get("field:number_7cf5") or 2) num_children = int(data.get("field:anzahl_kinder") or 0) children_ages = [] @@ -258,7 +315,7 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db end_date=date.fromisoformat(end_date) if end_date else None, num_adults=num_adults, num_children=num_children, - children_ages=','.join(str(a) for a in children_ages), + children_ages=",".join(str(a) for a in children_ages), offer=offer, utm_comment=utm_comment, created_at=datetime.now(timezone.utc), @@ -277,23 +334,21 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db await db.commit() await db.refresh(db_reservation) - - - return { "status": "success", "message": "Wix form data received successfully", "received_keys": list(data.keys()), "data_logged_to": log_filename, "timestamp": timestamp, - "process_info": log_entry["process_info"], - "note": "No authentication required for this endpoint" + "note": "No authentication required for this endpoint", } @api_router.post("/webhook/wix-form") @webhook_limiter.limit(WEBHOOK_RATE_LIMIT) -async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Depends(get_async_session)): +async def handle_wix_form( + request: Request, data: Dict[str, Any], db_session=Depends(get_async_session) +): """ Unified endpoint to handle Wix form submissions (test and production). No authentication required for this endpoint. @@ -304,16 +359,19 @@ async def handle_wix_form(request: Request, data: Dict[str, Any], db_session=Dep _LOGGER.error(f"Error in handle_wix_form: {str(e)}") # log stacktrace import traceback + traceback_str = traceback.format_exc() _LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}") raise HTTPException( - status_code=500, - detail=f"Error processing Wix form data: {str(e)}" + status_code=500, detail=f"Error processing Wix form data: {str(e)}" ) + @api_router.post("/webhook/wix-form/test") @limiter.limit(DEFAULT_RATE_LIMIT) -async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session=Depends(get_async_session)): +async def handle_wix_form_test( + request: Request, data: Dict[str, Any], db_session=Depends(get_async_session) +): """ Test endpoint to verify the API is working with raw JSON data. No authentication required for testing purposes. @@ -323,40 +381,37 @@ async def handle_wix_form_test(request: Request, data: Dict[str, Any],db_session except Exception as e: _LOGGER.error(f"Error in handle_wix_form_test: {str(e)}") raise HTTPException( - status_code=500, - detail=f"Error processing test data: {str(e)}" + status_code=500, detail=f"Error processing test data: {str(e)}" ) @api_router.post("/admin/generate-api-key") @limiter.limit("5/hour") # Very restrictive for admin operations async def generate_new_api_key( - request: Request, - admin_key: str = Depends(validate_api_key) + request: Request, admin_key: str = Depends(validate_api_key) ): """ Admin endpoint to generate new API keys. Requires admin API key and is heavily rate limited. """ if admin_key != "admin-key": - raise HTTPException( - status_code=403, - detail="Admin access required" - ) - + raise HTTPException(status_code=403, detail="Admin access required") + new_key = generate_api_key() _LOGGER.info(f"Generated new API key (requested by: {admin_key})") - + return { "status": "success", "message": "New API key generated", "api_key": new_key, "timestamp": datetime.now().isoformat(), - "note": "Store this key securely - it won't be shown again" + "note": "Store this key securely - it won't be shown again", } -async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(security_basic)) -> str: +async def validate_basic_auth( + credentials: HTTPBasicCredentials = Depends(security_basic), +) -> str: """ Validate basic authentication for AlpineBits protocol. Returns username if valid, raises HTTPException if not. @@ -369,8 +424,11 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi headers={"WWW-Authenticate": "Basic"}, ) valid = False - for entry in config['alpine_bits_auth']: - if credentials.username == entry['username'] and credentials.password == entry['password']: + for entry in config["alpine_bits_auth"]: + if ( + credentials.username == entry["username"] + and credentials.password == entry["password"] + ): valid = True break if not valid: @@ -379,7 +437,9 @@ async def validate_basic_auth(credentials: HTTPBasicCredentials = Depends(securi detail="ERROR: Invalid credentials", headers={"WWW-Authenticate": "Basic"}, ) - _LOGGER.info(f"AlpineBits authentication successful for user: {credentials.username} (from config)") + _LOGGER.info( + f"AlpineBits authentication successful for user: {credentials.username} (from config)" + ) return credentials.username @@ -390,10 +450,9 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]: """ if "multipart/form-data" not in content_type: raise HTTPException( - status_code=400, - detail="ERROR: Content-Type must be multipart/form-data" + status_code=400, detail="ERROR: Content-Type must be multipart/form-data" ) - + # Extract boundary boundary = None for part in content_type.split(";"): @@ -401,62 +460,56 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]: if part.startswith("boundary="): boundary = part.split("=", 1)[1].strip('"') break - + if not boundary: raise HTTPException( - status_code=400, - detail="ERROR: Missing boundary in multipart/form-data" + status_code=400, detail="ERROR: Missing boundary in multipart/form-data" ) - + # Simple multipart parsing parts = body.split(f"--{boundary}".encode()) data = {} - + for part in parts: if not part.strip() or part.strip() == b"--": continue - + # Split headers and content if b"\r\n\r\n" in part: headers_section, content = part.split(b"\r\n\r\n", 1) content = content.rstrip(b"\r\n") - + # Parse Content-Disposition header - headers = headers_section.decode('utf-8', errors='ignore') + headers = headers_section.decode("utf-8", errors="ignore") name = None - for line in headers.split('\n'): - if 'Content-Disposition' in line and 'name=' in line: + for line in headers.split("\n"): + if "Content-Disposition" in line and "name=" in line: # Extract name parameter - for param in line.split(';'): + for param in line.split(";"): param = param.strip() - if param.startswith('name='): - name = param.split('=', 1)[1].strip('"') + if param.startswith("name="): + name = param.split("=", 1)[1].strip('"') break - + if name: # Handle file uploads or text content - if content.startswith(b'<'): + if content.startswith(b"<"): # Likely XML content - data[name] = content.decode('utf-8', errors='ignore') + data[name] = content.decode("utf-8", errors="ignore") else: - data[name] = content.decode('utf-8', errors='ignore') - + data[name] = content.decode("utf-8", errors="ignore") + return data - - - - @api_router.post("/alpinebits/server-2024-10") @limiter.limit("60/minute") async def alpinebits_server_handshake( - request: Request, - username: str = Depends(validate_basic_auth) + request: Request, username: str = Depends(validate_basic_auth) ): """ AlpineBits server endpoint implementing the handshake protocol. - + This endpoint handles: - Protocol version negotiation via X-AlpineBits-ClientProtocolVersion header - Client identification via X-AlpineBits-ClientID header (optional) @@ -464,62 +517,67 @@ async def alpinebits_server_handshake( - Gzip compression support - Proper error handling with HTTP status codes - Handshaking action processing - + Authentication: HTTP Basic Auth required Content-Type: multipart/form-data Compression: gzip supported (check X-AlpineBits-Server-Accept-Encoding) """ try: # Check required headers - client_protocol_version = request.headers.get("X-AlpineBits-ClientProtocolVersion") + client_protocol_version = request.headers.get( + "X-AlpineBits-ClientProtocolVersion" + ) if not client_protocol_version: # Server concludes client speaks a protocol version preceding 2013-04 client_protocol_version = "pre-2013-04" - _LOGGER.info("No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04") + _LOGGER.info( + "No X-AlpineBits-ClientProtocolVersion header found, assuming pre-2013-04" + ) else: _LOGGER.info(f"Client protocol version: {client_protocol_version}") - + # Optional client ID client_id = request.headers.get("X-AlpineBits-ClientID") if client_id: _LOGGER.info(f"Client ID: {client_id}") - + # Check content encoding content_encoding = request.headers.get("Content-Encoding") is_compressed = content_encoding == "gzip" - + if is_compressed: _LOGGER.info("Request is gzip compressed") - + # Get content type before processing content_type = request.headers.get("Content-Type", "") _LOGGER.info(f"Content-Type: {content_type}") _LOGGER.info(f"Content-Encoding: {content_encoding}") - + # Get request body body = await request.body() - + # Decompress if needed if is_compressed: try: - body = gzip.decompress(body) - except Exception as e: raise HTTPException( status_code=400, - detail=f"ERROR: Failed to decompress gzip content: {str(e)}" + detail=f"ERROR: Failed to decompress gzip content: {str(e)}", ) - + # Check content type (after decompression) - if "multipart/form-data" not in content_type and "application/x-www-form-urlencoded" not in content_type: + if ( + "multipart/form-data" not in content_type + and "application/x-www-form-urlencoded" not in content_type + ): raise HTTPException( status_code=400, - detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded" + detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded", ) - + # Parse multipart data if "multipart/form-data" in content_type: try: @@ -527,7 +585,7 @@ async def alpinebits_server_handshake( except Exception as e: raise HTTPException( status_code=400, - detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}" + detail=f"ERROR: Failed to parse multipart/form-data: {str(e)}", ) elif "application/x-www-form-urlencoded" in content_type: # Parse as urlencoded @@ -535,75 +593,59 @@ async def alpinebits_server_handshake( else: raise HTTPException( status_code=400, - detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded" + detail="ERROR: Content-Type must be multipart/form-data or application/x-www-form-urlencoded", ) - + # Check for required action parameter action = form_data.get("action") if not action: raise HTTPException( - status_code=400, - detail="ERROR: Missing required 'action' parameter") - + status_code=400, detail="ERROR: Missing required 'action' parameter" + ) + _LOGGER.info(f"AlpineBits action: {action}") - # Get optional request XML - request_xml = form_data.get("request") - + request_xml = form_data.get("request") server = AlpineBitsServer() version = Version.V2024_10 - - # Create successful handshake response - response = await server.handle_request(action, request_xml, version) + response = await server.handle_request(action, request_xml, version) response_xml = response.xml_content - + # Set response headers indicating server capabilities headers = { "Content-Type": "application/xml; charset=utf-8", "X-AlpineBits-Server-Accept-Encoding": "gzip", # Indicate gzip support - "X-AlpineBits-Server-Version": "2024-10" + "X-AlpineBits-Server-Version": "2024-10", } - - return Response( - content=response_xml, - status_code=response.status_code, - headers=headers - ) - - + return Response( + content=response_xml, status_code=response.status_code, headers=headers + ) + except HTTPException: # Re-raise HTTP exceptions (auth errors, etc.) raise except Exception as e: _LOGGER.error(f"Error in AlpineBits handshake: {str(e)}") - raise HTTPException( - status_code=500, - detail=f"Internal server error: {str(e)}" - ) + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + @api_router.get("/admin/stats") @limiter.limit("10/minute") -async def get_api_stats( - request: Request, - admin_key: str = Depends(validate_api_key) -): +async def get_api_stats(request: Request, admin_key: str = Depends(validate_api_key)): """ Admin endpoint to get API usage statistics. Requires admin API key. """ if admin_key != "admin-key": - raise HTTPException( - status_code=403, - detail="Admin access required" - ) - + raise HTTPException(status_code=403, detail="Admin access required") + # In a real application, you'd fetch this from your database/monitoring system return { "status": "success", @@ -611,9 +653,9 @@ async def get_api_stats( "uptime": "Available in production deployment", "total_requests": "Available with monitoring setup", "active_api_keys": len([k for k in ["wix-webhook-key", "admin-key"] if k]), - "rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory" + "rate_limit_backend": "redis" if os.getenv("REDIS_URL") else "memory", }, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } @@ -629,11 +671,12 @@ async def landing_page(): try: # Get the path to the HTML file import os + html_path = os.path.join(os.path.dirname(__file__), "templates", "index.html") - + with open(html_path, "r", encoding="utf-8") as f: html_content = f.read() - + return HTMLResponse(content=html_content, status_code=200) except FileNotFoundError: # Fallback if HTML file is not found @@ -660,4 +703,5 @@ async def landing_page(): if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/alpine_bits_python/auth.py b/src/alpine_bits_python/auth.py index 21d29bc..5a7632e 100644 --- a/src/alpine_bits_python/auth.py +++ b/src/alpine_bits_python/auth.py @@ -21,7 +21,7 @@ security = HTTPBearer() API_KEYS = { # Example API keys - replace with your own secure keys "wix-webhook-key": "sk_live_your_secure_api_key_here", - "admin-key": "sk_admin_your_admin_key_here" + "admin-key": "sk_admin_your_admin_key_here", } # Load API keys from environment if available @@ -36,19 +36,21 @@ def generate_api_key() -> str: return f"sk_live_{secrets.token_urlsafe(32)}" -def validate_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str: +def validate_api_key( + credentials: HTTPAuthorizationCredentials = Security(security), +) -> str: """ Validate API key from Authorization header. Expected format: Authorization: Bearer your_api_key_here """ token = credentials.credentials - + # Check if the token is in our valid API keys for key_name, valid_key in API_KEYS.items(): if secrets.compare_digest(token, valid_key): logger.info(f"Valid API key used: {key_name}") return key_name - + logger.warning(f"Invalid API key attempted: {token[:10]}...") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -64,19 +66,17 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool: """ if not signature or not secret: return False - + try: # Remove 'sha256=' prefix if present - if signature.startswith('sha256='): + if signature.startswith("sha256="): signature = signature[7:] - + # Calculate expected signature expected_signature = hmac.new( - secret.encode('utf-8'), - payload, - hashlib.sha256 + secret.encode("utf-8"), payload, hashlib.sha256 ).hexdigest() - + # Compare signatures securely return secrets.compare_digest(signature, expected_signature) except Exception as e: @@ -86,21 +86,21 @@ def validate_wix_signature(payload: bytes, signature: str, secret: str) -> bool: class APIKeyAuth: """Simple API key authentication class""" - + def __init__(self, api_keys: dict): self.api_keys = api_keys - + def authenticate(self, api_key: str) -> Optional[str]: """Authenticate an API key and return the key name if valid""" for key_name, valid_key in self.api_keys.items(): if secrets.compare_digest(api_key, valid_key): return key_name return None - + def add_key(self, name: str, key: str): """Add a new API key""" self.api_keys[name] = key - + def remove_key(self, name: str): """Remove an API key""" if name in self.api_keys: @@ -108,4 +108,4 @@ class APIKeyAuth: # Initialize auth system -auth_system = APIKeyAuth(API_KEYS) \ No newline at end of file +auth_system = APIKeyAuth(API_KEYS) diff --git a/src/alpine_bits_python/config_loader.py b/src/alpine_bits_python/config_loader.py index 59fcaeb..b207b4d 100644 --- a/src/alpine_bits_python/config_loader.py +++ b/src/alpine_bits_python/config_loader.py @@ -1,4 +1,3 @@ - import os from pathlib import Path from typing import Any, Dict, List, Optional @@ -16,37 +15,45 @@ from annotatedyaml.loader import ( from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid # --- Voluptuous schemas --- -database_schema = Schema({ - Required('url'): str -}, extra=PREVENT_EXTRA) +database_schema = Schema({Required("url"): str}, extra=PREVENT_EXTRA) - -hotel_auth_schema = Schema({ - Required("hotel_id"): str, - Required("hotel_name"): str, - Required("username"): str, - Required("password"): str -}, extra=PREVENT_EXTRA) - -basic_auth_schema = Schema( - All([hotel_auth_schema], Length(min=1)) +hotel_auth_schema = Schema( + { + Required("hotel_id"): str, + Required("hotel_name"): str, + Required("username"): str, + Required("password"): str, + }, + extra=PREVENT_EXTRA, ) -config_schema = Schema({ - Required('database'): database_schema, - Required('alpine_bits_auth'): basic_auth_schema -}, extra=PREVENT_EXTRA) +basic_auth_schema = Schema(All([hotel_auth_schema], Length(min=1))) -DEFAULT_CONFIG_FILE = 'config.yaml' +config_schema = Schema( + { + Required("database"): database_schema, + Required("alpine_bits_auth"): basic_auth_schema, + }, + extra=PREVENT_EXTRA, +) + +DEFAULT_CONFIG_FILE = "config.yaml" class Config: - def __init__(self, config_folder: str | Path = None, config_name: str = DEFAULT_CONFIG_FILE, testing_mode: bool = False): + def __init__( + self, + config_folder: str | Path = None, + config_name: str = DEFAULT_CONFIG_FILE, + testing_mode: bool = False, + ): if config_folder is None: - config_folder = os.environ.get('ALPINEBITS_CONFIG_DIR') + config_folder = os.environ.get("ALPINEBITS_CONFIG_DIR") if not config_folder: - config_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../config')) + config_folder = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../config") + ) if isinstance(config_folder, str): config_folder = Path(config_folder) self.config_folder = config_folder @@ -61,8 +68,8 @@ class Config: validated = config_schema(stuff) except MultipleInvalid as e: raise ValueError(f"Config validation error: {e}") - self.database = validated['database'] - self.basic_auth = validated['alpine_bits_auth'] + self.database = validated["database"] + self.basic_auth = validated["alpine_bits_auth"] self.config = validated def get(self, key, default=None): @@ -70,19 +77,20 @@ class Config: @property def db_url(self) -> str: - return self.database['url'] + return self.database["url"] @property def hotel_id(self) -> str: - return self.basic_auth['hotel_id'] + return self.basic_auth["hotel_id"] @property def hotel_name(self) -> str: - return self.basic_auth['hotel_name'] + return self.basic_auth["hotel_name"] @property def users(self) -> List[Dict[str, str]]: - return self.basic_auth['users'] + return self.basic_auth["users"] + # For backward compatibility def load_config(): diff --git a/src/alpine_bits_python/db.py b/src/alpine_bits_python/db.py index 09d039b..32a82d0 100644 --- a/src/alpine_bits_python/db.py +++ b/src/alpine_bits_python/db.py @@ -5,27 +5,24 @@ import os Base = declarative_base() + # Async SQLAlchemy setup def get_database_url(config=None): db_url = None - if config and 'database' in config and 'url' in config['database']: - db_url = config['database']['url'] + if config and "database" in config and "url" in config["database"]: + db_url = config["database"]["url"] if not db_url: - db_url = os.environ.get('DATABASE_URL') + db_url = os.environ.get("DATABASE_URL") if not db_url: - db_url = 'sqlite+aiosqlite:///alpinebits.db' + db_url = "sqlite+aiosqlite:///alpinebits.db" return db_url -DATABASE_URL = get_database_url() -engine = create_async_engine(DATABASE_URL, echo=True) -AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False) -async def get_async_session(): - async with AsyncSessionLocal() as session: - yield session + + class Customer(Base): - __tablename__ = 'customers' + __tablename__ = "customers" id = Column(Integer, primary_key=True) given_name = Column(String) contact_id = Column(String, unique=True) @@ -42,13 +39,14 @@ class Customer(Base): birth_date = Column(String) language = Column(String) address_catalog = Column(Boolean) # Added for XML - name_title = Column(String) # Added for XML - reservations = relationship('Reservation', back_populates='customer') + name_title = Column(String) # Added for XML + reservations = relationship("Reservation", back_populates="customer") + class Reservation(Base): - __tablename__ = 'reservations' + __tablename__ = "reservations" id = Column(Integer, primary_key=True) - customer_id = Column(Integer, ForeignKey('customers.id')) + customer_id = Column(Integer, ForeignKey("customers.id")) form_id = Column(String, unique=True) start_date = Column(Date) end_date = Column(Date) @@ -70,16 +68,14 @@ class Reservation(Base): # Add hotel_code and hotel_name for XML hotel_code = Column(String) hotel_name = Column(String) - customer = relationship('Customer', back_populates='reservations') + customer = relationship("Customer", back_populates="reservations") + class HashedCustomer(Base): - __tablename__ = 'hashed_customers' + __tablename__ = "hashed_customers" id = Column(Integer, primary_key=True) customer_id = Column(Integer) hashed_email = Column(String) hashed_phone = Column(String) hashed_name = Column(String) redacted_at = Column(DateTime) - - - diff --git a/src/alpine_bits_python/main.py b/src/alpine_bits_python/main.py index b6cc021..c5f203d 100644 --- a/src/alpine_bits_python/main.py +++ b/src/alpine_bits_python/main.py @@ -15,11 +15,16 @@ from .simplified_access import ( HotelReservationIdData, PhoneTechType, AlpineBitsFactory, - OtaMessageType + OtaMessageType, ) # DB and config -from .db import Customer as DBCustomer, Reservation as DBReservation, HashedCustomer, get_async_session +from .db import ( + Customer as DBCustomer, + Reservation as DBReservation, + HashedCustomer, + get_async_session, +) from .config_loader import load_config import hashlib import json @@ -29,8 +34,8 @@ import asyncio from alpine_bits_python import db -async def main(): +async def main(): print("๐Ÿš€ Starting AlpineBits XML generation script...") # Load config (yaml, annotatedyaml) config = load_config() @@ -40,9 +45,9 @@ async def main(): print(json.dumps(config, indent=2)) # Ensure SQLite DB file exists if using SQLite - db_url = config.get('database', {}).get('url', '') - if db_url.startswith('sqlite+aiosqlite:///'): - db_path = db_url.replace('sqlite+aiosqlite:///', '') + db_url = config.get("database", {}).get("url", "") + if db_url.startswith("sqlite+aiosqlite:///"): + db_path = db_url.replace("sqlite+aiosqlite:///", "") db_path = os.path.abspath(db_path) db_dir = os.path.dirname(db_path) if not os.path.exists(db_dir): @@ -54,15 +59,17 @@ async def main(): # # Ensure DB schema is created (async) from .db import engine, Base + async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) async for db in get_async_session(): - - # Load data from JSON file - json_path = os.path.join(os.path.dirname(__file__), '../../test_data/wix_test_data_20250928_132611.json') - with open(json_path, 'r', encoding='utf-8') as f: + json_path = os.path.join( + os.path.dirname(__file__), + "../../test_data/wix_test_data_20250928_132611.json", + ) + with open(json_path, "r", encoding="utf-8") as f: wix_data = json.load(f) data = wix_data["data"]["data"] @@ -85,8 +92,16 @@ async def main(): language = data.get("contact", {}).get("locale", "en")[:2] # Dates - start_date = data.get("field:date_picker_a7c8") or data.get("Anreisedatum") or data.get("submissions", [{}])[1].get("value") - end_date = data.get("field:date_picker_7e65") or data.get("Abreisedatum") or data.get("submissions", [{}])[2].get("value") + start_date = ( + data.get("field:date_picker_a7c8") + or data.get("Anreisedatum") + or data.get("submissions", [{}])[1].get("value") + ) + end_date = ( + data.get("field:date_picker_7e65") + or data.get("Abreisedatum") + or data.get("submissions", [{}])[2].get("value") + ) # Room/guest info num_adults = int(data.get("field:number_7cf5") or 2) @@ -100,7 +115,7 @@ async def main(): children_ages.append(age) except ValueError: logging.warning(f"Invalid age value for {k}: {data[k]}") - + # UTM and offer utm_fields = [ ("utm_Source", "utm_source"), @@ -147,7 +162,7 @@ async def main(): end_date=date.fromisoformat(end_date) if end_date else None, num_adults=num_adults, num_children=num_children, - children_ages=','.join(str(a) for a in children_ages), + children_ages=",".join(str(a) for a in children_ages), offer=offer, utm_comment=utm_comment, created_at=datetime.now(timezone.utc), @@ -177,9 +192,19 @@ async def main(): def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): - from .simplified_access import CustomerData, GuestCountsFactory, HotelReservationIdData, AlpineBitsFactory, OtaMessageType, CommentData, CommentsData, CommentListItemData + from .simplified_access import ( + CustomerData, + GuestCountsFactory, + HotelReservationIdData, + AlpineBitsFactory, + OtaMessageType, + CommentData, + CommentsData, + CommentListItemData, + ) from .generated import alpinebits as ab from datetime import datetime, timezone + # Prepare data for XML phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else [] customer_data = CustomerData( @@ -200,11 +225,15 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): language=customer.language, ) alpine_bits_factory = AlpineBitsFactory() - res_guests = alpine_bits_factory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) + res_guests = alpine_bits_factory.create_res_guests( + customer_data, OtaMessageType.RETRIEVE + ) # Guest counts children_ages = [int(a) for a in reservation.children_ages.split(",") if a] - guest_counts = GuestCountsFactory.create_retrieve_guest_counts(reservation.num_adults, children_ages) + guest_counts = GuestCountsFactory.create_retrieve_guest_counts( + reservation.num_adults, children_ages + ) # UniqueID unique_id = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId( @@ -214,11 +243,13 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): # TimeSpan time_span = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan( start=reservation.start_date.isoformat() if reservation.start_date else None, - end=reservation.end_date.isoformat() if reservation.end_date else None + end=reservation.end_date.isoformat() if reservation.end_date else None, ) - room_stay = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay( - time_span=time_span, - guest_counts=guest_counts, + room_stay = ( + ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay( + time_span=time_span, + guest_counts=guest_counts, + ) ) room_stays = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays( room_stay=[room_stay], @@ -231,7 +262,9 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): res_id_source=None, res_id_source_context="99tales", ) - hotel_res_id = alpine_bits_factory.create(hotel_res_id_data, OtaMessageType.RETRIEVE) + hotel_res_id = alpine_bits_factory.create( + hotel_res_id_data, OtaMessageType.RETRIEVE + ) hotel_res_ids = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds( hotel_reservation_id=[hotel_res_id] ) @@ -244,31 +277,37 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): offer_comment = CommentData( name=ab.CommentName2.ADDITIONAL_INFO, text="Angebot/Offerta", - list_items=[CommentListItemData( - value=reservation.offer, - language=customer.language, - list_item="1", - )], + list_items=[ + CommentListItemData( + value=reservation.offer, + language=customer.language, + list_item="1", + ) + ], ) comment = None if reservation.user_comment: comment = CommentData( name=ab.CommentName2.CUSTOMER_COMMENT, text=reservation.user_comment, - list_items=[CommentListItemData( - value="Landing page comment", - language=customer.language, - list_item="1", - )], + list_items=[ + CommentListItemData( + value="Landing page comment", + language=customer.language, + list_item="1", + ) + ], ) comments = [offer_comment, comment] if comment else [offer_comment] comments_data = CommentsData(comments=comments) comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE) - res_global_info = ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo( - hotel_reservation_ids=hotel_res_ids, - basic_property_info=basic_property_info, - comments=comments_xml, + res_global_info = ( + ab.OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo( + hotel_reservation_ids=hotel_res_ids, + basic_property_info=basic_property_info, + comments=comments_xml, + ) ) hotel_reservation = ab.OtaResRetrieveRs.ReservationsList.HotelReservation( @@ -293,6 +332,7 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): print("โœ… Pydantic validation successful!") from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata_pydantic.bindings import XmlSerializer + config = SerializerConfig( pretty_print=True, xml_declaration=True, encoding="UTF-8" ) @@ -306,15 +346,18 @@ def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): print("\n๐Ÿ“„ Generated XML:") print(xml_string) from xsdata_pydantic.bindings import XmlParser + parser = XmlParser() with open("output.xml", "r", encoding="utf-8") as infile: xml_content = infile.read() parsed_result = parser.from_string(xml_content, ab.OtaResRetrieveRs) print("โœ… Round-trip validation successful!") - print(f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}") + print( + f"Parsed reservation status: {parsed_result.reservations_list.hotel_reservation[0].res_status}" + ) except Exception as e: print(f"โŒ Validation/Serialization failed: {e}") + if __name__ == "__main__": asyncio.run(main()) - diff --git a/src/alpine_bits_python/models.py b/src/alpine_bits_python/models.py index ea56a92..27b6f31 100644 --- a/src/alpine_bits_python/models.py +++ b/src/alpine_bits_python/models.py @@ -5,18 +5,23 @@ from datetime import datetime class AlpineBitsHandshakeRequest(BaseModel): """Model for AlpineBits handshake request data""" - action: str = Field(..., description="Action parameter, typically 'OTA_Ping:Handshaking'") + + action: str = Field( + ..., description="Action parameter, typically 'OTA_Ping:Handshaking'" + ) request_xml: Optional[str] = Field(None, description="XML request document") class ContactName(BaseModel): """Contact name structure""" + first: Optional[str] = None last: Optional[str] = None class ContactAddress(BaseModel): """Contact address structure""" + street: Optional[str] = None city: Optional[str] = None state: Optional[str] = None @@ -26,6 +31,7 @@ class ContactAddress(BaseModel): class Contact(BaseModel): """Contact information from Wix form""" + name: Optional[ContactName] = None email: Optional[str] = None locale: Optional[str] = None @@ -43,12 +49,14 @@ class Contact(BaseModel): class SubmissionPdf(BaseModel): """PDF submission structure""" + url: Optional[str] = None filename: Optional[str] = None class WixFormSubmission(BaseModel): """Model for Wix form submission data""" + formName: str submissions: List[Dict[str, Any]] = Field(default_factory=list) submissionTime: str @@ -59,7 +67,7 @@ class WixFormSubmission(BaseModel): submissionPdf: Optional[SubmissionPdf] = None formId: str contact: Optional[Contact] = None - + # Dynamic form fields - these will capture all field:* entries class Config: - extra = "allow" # Allow additional fields not defined in the model \ No newline at end of file + extra = "allow" # Allow additional fields not defined in the model diff --git a/src/alpine_bits_python/rate_limit.py b/src/alpine_bits_python/rate_limit.py index 958e062..638ea59 100644 --- a/src/alpine_bits_python/rate_limit.py +++ b/src/alpine_bits_python/rate_limit.py @@ -11,11 +11,12 @@ logger = logging.getLogger(__name__) # Rate limiting configuration DEFAULT_RATE_LIMIT = "10/minute" # 10 requests per minute per IP WEBHOOK_RATE_LIMIT = "60/minute" # 60 webhook requests per minute per IP -BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP +BURST_RATE_LIMIT = "3/second" # Max 3 requests per second per IP # Redis configuration for distributed rate limiting (optional) REDIS_URL = os.getenv("REDIS_URL", None) + def get_remote_address_with_forwarded(request: Request): """ Get client IP address, considering forwarded headers from proxies/load balancers @@ -25,11 +26,11 @@ def get_remote_address_with_forwarded(request: Request): if forwarded_for: # Take the first IP in the chain return forwarded_for.split(",")[0].strip() - + real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip - + # Fallback to direct connection IP return get_remote_address(request) @@ -39,14 +40,16 @@ if REDIS_URL: # Use Redis for distributed rate limiting (recommended for production) try: import redis + redis_client = redis.from_url(REDIS_URL) limiter = Limiter( - key_func=get_remote_address_with_forwarded, - storage_uri=REDIS_URL + key_func=get_remote_address_with_forwarded, storage_uri=REDIS_URL ) logger.info("Rate limiting initialized with Redis backend") except Exception as e: - logger.warning(f"Failed to connect to Redis: {e}. Using in-memory rate limiting.") + logger.warning( + f"Failed to connect to Redis: {e}. Using in-memory rate limiting." + ) limiter = Limiter(key_func=get_remote_address_with_forwarded) else: # Use in-memory rate limiting (fine for single instance) @@ -65,7 +68,7 @@ def get_api_key_identifier(request: Request) -> str: api_key = auth_header[7:] # Remove "Bearer " prefix # Use first 10 chars of API key as identifier (don't log full key) return f"api_key:{api_key[:10]}" - + # Fallback to IP address return f"ip:{get_remote_address_with_forwarded(request)}" @@ -77,10 +80,10 @@ def api_key_rate_limit_key(request: Request): # Rate limiting decorators for different endpoint types webhook_limiter = Limiter( - key_func=api_key_rate_limit_key, - storage_uri=REDIS_URL if REDIS_URL else None + key_func=api_key_rate_limit_key, storage_uri=REDIS_URL if REDIS_URL else None ) + # Custom rate limit exceeded handler def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded): """Custom handler for rate limit exceeded""" @@ -88,11 +91,11 @@ def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded): f"Rate limit exceeded for {get_remote_address_with_forwarded(request)}: " f"{exc.detail}" ) - + response = _rate_limit_exceeded_handler(request, exc) - + # Add custom headers response.headers["X-RateLimit-Limit"] = str(exc.retry_after) response.headers["X-RateLimit-Retry-After"] = str(exc.retry_after) - - return response \ No newline at end of file + + return response diff --git a/src/alpine_bits_python/reservations.py b/src/alpine_bits_python/reservations.py index 70d483a..5c8f238 100644 --- a/src/alpine_bits_python/reservations.py +++ b/src/alpine_bits_python/reservations.py @@ -1,7 +1,2 @@ - - - def parse_form(form: dict): - pass - \ No newline at end of file diff --git a/src/alpine_bits_python/run_api.py b/src/alpine_bits_python/run_api.py index 9234936..28d921b 100644 --- a/src/alpine_bits_python/run_api.py +++ b/src/alpine_bits_python/run_api.py @@ -2,14 +2,21 @@ """ Startup script for the Wix Form Handler API """ + +import os import uvicorn from .api import app if __name__ == "__main__": + db_path = "alpinebits.db" # Adjust path if needed + if os.path.exists(db_path): + os.remove(db_path) + print(f"Deleted database file: {db_path}") + uvicorn.run( "alpine_bits_python.api:app", host="0.0.0.0", port=8080, reload=True, # Enable auto-reload during development - log_level="info" - ) \ No newline at end of file + log_level="info", + ) diff --git a/src/alpine_bits_python/scripts/setup_security.py b/src/alpine_bits_python/scripts/setup_security.py index 38ebc15..e565a85 100644 --- a/src/alpine_bits_python/scripts/setup_security.py +++ b/src/alpine_bits_python/scripts/setup_security.py @@ -2,6 +2,7 @@ """ Configuration and setup script for the Wix Form Handler API """ + import os import sys import secrets @@ -11,80 +12,83 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from alpine_bits_python.auth import generate_api_key + def generate_secure_keys(): """Generate secure API keys for the application""" - + print("๐Ÿ” Generating Secure API Keys") print("=" * 50) - + # Generate API keys wix_api_key = generate_api_key() admin_api_key = generate_api_key() webhook_secret = secrets.token_urlsafe(32) - + print(f"๐Ÿ”‘ Wix Webhook API Key: {wix_api_key}") print(f"๐Ÿ” Admin API Key: {admin_api_key}") print(f"๐Ÿ”’ Webhook Secret: {webhook_secret}") - + print("\n๐Ÿ“‹ Environment Variables") print("-" * 30) print(f"export WIX_API_KEY='{wix_api_key}'") print(f"export ADMIN_API_KEY='{admin_api_key}'") print(f"export WIX_WEBHOOK_SECRET='{webhook_secret}'") print(f"export REDIS_URL='redis://localhost:6379' # Optional for production") - + print("\n๐Ÿ”ง .env File Content") print("-" * 20) print(f"WIX_API_KEY={wix_api_key}") print(f"ADMIN_API_KEY={admin_api_key}") print(f"WIX_WEBHOOK_SECRET={webhook_secret}") print("REDIS_URL=redis://localhost:6379") - + # Optionally write to .env file create_env = input("\nโ“ Create .env file? (y/n): ").lower().strip() - if create_env == 'y': + if create_env == "y": # Create .env in the project root (two levels up from scripts) - env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env') - with open(env_path, 'w') as f: + env_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env" + ) + with open(env_path, "w") as f: f.write(f"WIX_API_KEY={wix_api_key}\n") f.write(f"ADMIN_API_KEY={admin_api_key}\n") f.write(f"WIX_WEBHOOK_SECRET={webhook_secret}\n") f.write("REDIS_URL=redis://localhost:6379\n") print(f"โœ… .env file created at {env_path}!") print("โš ๏ธ Add .env to your .gitignore file!") - + print("\n๐ŸŒ Wix Configuration") print("-" * 20) print("1. In your Wix site, go to Settings > Webhooks") print("2. Add webhook URL: https://yourdomain.com/webhook/wix-form") print("3. Add custom header: Authorization: Bearer " + wix_api_key) print("4. Optionally configure webhook signature with the secret above") - + return { - 'wix_api_key': wix_api_key, - 'admin_api_key': admin_api_key, - 'webhook_secret': webhook_secret + "wix_api_key": wix_api_key, + "admin_api_key": admin_api_key, + "webhook_secret": webhook_secret, } def check_security_setup(): """Check current security configuration""" - + print("๐Ÿ” Security Configuration Check") print("=" * 40) - + # Check environment variables - wix_key = os.getenv('WIX_API_KEY') - admin_key = os.getenv('ADMIN_API_KEY') - webhook_secret = os.getenv('WIX_WEBHOOK_SECRET') - redis_url = os.getenv('REDIS_URL') - + wix_key = os.getenv("WIX_API_KEY") + admin_key = os.getenv("ADMIN_API_KEY") + webhook_secret = os.getenv("WIX_WEBHOOK_SECRET") + redis_url = os.getenv("REDIS_URL") + print("Environment Variables:") print(f" WIX_API_KEY: {'โœ… Set' if wix_key else 'โŒ Not set'}") print(f" ADMIN_API_KEY: {'โœ… Set' if admin_key else 'โŒ Not set'}") print(f" WIX_WEBHOOK_SECRET: {'โœ… Set' if webhook_secret else 'โŒ Not set'}") print(f" REDIS_URL: {'โœ… Set' if redis_url else 'โš ๏ธ Optional (using in-memory)'}") - + # Security recommendations print("\n๐Ÿ›ก๏ธ Security Recommendations:") if not wix_key: @@ -94,19 +98,19 @@ def check_security_setup(): print(" โš ๏ธ WIX_API_KEY should be longer for better security") else: print(" โœ… WIX_API_KEY looks secure") - + if not admin_key: print(" โŒ Set ADMIN_API_KEY environment variable") elif wix_key and admin_key == wix_key: print(" โŒ Admin and Wix keys should be different") else: print(" โœ… ADMIN_API_KEY configured") - + if not webhook_secret: print(" โš ๏ธ Consider setting WIX_WEBHOOK_SECRET for signature validation") else: print(" โœ… Webhook signature validation enabled") - + print("\n๐Ÿš€ Production Checklist:") print(" - Use HTTPS in production") print(" - Set up Redis for distributed rate limiting") @@ -118,12 +122,14 @@ def check_security_setup(): if __name__ == "__main__": print("๐Ÿ” Wix Form Handler API - Security Setup") print("=" * 50) - - choice = input("Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): ").strip() - + + choice = input( + "Choose an option:\n1. Generate new API keys\n2. Check current setup\n\nEnter choice (1 or 2): " + ).strip() + if choice == "1": generate_secure_keys() elif choice == "2": check_security_setup() else: - print("Invalid choice. Please run again and choose 1 or 2.") \ No newline at end of file + print("Invalid choice. Please run again and choose 1 or 2.") diff --git a/src/alpine_bits_python/scripts/test_api.py b/src/alpine_bits_python/scripts/test_api.py index 76ed30d..021f621 100644 --- a/src/alpine_bits_python/scripts/test_api.py +++ b/src/alpine_bits_python/scripts/test_api.py @@ -2,6 +2,7 @@ """ Test script for the Secure Wix Form Handler API """ + import asyncio import aiohttp import json @@ -30,7 +31,7 @@ SAMPLE_WIX_DATA = { "submissionsLink": "https://www.wix.app/forms/test-form/submissions", "submissionPdf": { "url": "https://example.com/submission.pdf", - "filename": "submission.pdf" + "filename": "submission.pdf", }, "formId": "test-form-789", "field:email_5139": "test@example.com", @@ -43,10 +44,7 @@ SAMPLE_WIX_DATA = { "field:alter_kind_4": "12", "field:long_answer_3524": "This is a long answer field with more details about the inquiry.", "contact": { - "name": { - "first": "John", - "last": "Doe" - }, + "name": {"first": "John", "last": "Doe"}, "email": "test@example.com", "locale": "de", "company": "Test Company", @@ -57,29 +55,29 @@ SAMPLE_WIX_DATA = { "street": "Test Street 123", "city": "Test City", "country": "Germany", - "postalCode": "12345" + "postalCode": "12345", }, "jobTitle": "Manager", "phone": "+1234567890", "createdDate": "2024-03-20T10:00:00.000Z", - "updatedDate": "2024-03-20T10:30:00.000Z" - } + "updatedDate": "2024-03-20T10:30:00.000Z", + }, } async def test_api(): """Test the API endpoints with authentication""" - + headers_with_auth = { "Content-Type": "application/json", - "Authorization": f"Bearer {TEST_API_KEY}" + "Authorization": f"Bearer {TEST_API_KEY}", } - + admin_headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {ADMIN_API_KEY}" + "Content-Type": "application/json", + "Authorization": f"Bearer {ADMIN_API_KEY}", } - + async with aiohttp.ClientSession() as session: # Test health endpoint (no auth required) print("1. Testing health endpoint (no auth)...") @@ -89,7 +87,7 @@ async def test_api(): print(f" โœ… Health check: {response.status} - {result.get('status')}") except Exception as e: print(f" โŒ Health check failed: {e}") - + # Test root endpoint (no auth required) print("\n2. Testing root endpoint (no auth)...") try: @@ -98,87 +96,94 @@ async def test_api(): print(f" โœ… Root: {response.status} - {result.get('message')}") except Exception as e: print(f" โŒ Root endpoint failed: {e}") - + # Test webhook endpoint without auth (should fail) print("\n3. Testing webhook endpoint WITHOUT auth (should fail)...") try: async with session.post( f"{BASE_URL}/api/webhook/wix-form", json=SAMPLE_WIX_DATA, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, ) as response: result = await response.json() if response.status == 401: - print(f" โœ… Correctly rejected: {response.status} - {result.get('detail')}") + print( + f" โœ… Correctly rejected: {response.status} - {result.get('detail')}" + ) else: print(f" โŒ Unexpected response: {response.status} - {result}") except Exception as e: print(f" โŒ Test failed: {e}") - + # Test webhook endpoint with valid auth print("\n4. Testing webhook endpoint WITH valid auth...") try: async with session.post( f"{BASE_URL}/api/webhook/wix-form", json=SAMPLE_WIX_DATA, - headers=headers_with_auth + headers=headers_with_auth, ) as response: result = await response.json() if response.status == 200: - print(f" โœ… Webhook success: {response.status} - {result.get('status')}") + print( + f" โœ… Webhook success: {response.status} - {result.get('status')}" + ) else: print(f" โŒ Webhook failed: {response.status} - {result}") except Exception as e: print(f" โŒ Webhook test failed: {e}") - + # Test test endpoint with auth print("\n5. Testing simple test endpoint WITH auth...") try: async with session.post( f"{BASE_URL}/api/webhook/wix-form/test", json={"test": "data", "timestamp": datetime.now().isoformat()}, - headers=headers_with_auth + headers=headers_with_auth, ) as response: result = await response.json() if response.status == 200: - print(f" โœ… Test endpoint: {response.status} - {result.get('status')}") + print( + f" โœ… Test endpoint: {response.status} - {result.get('status')}" + ) else: print(f" โŒ Test endpoint failed: {response.status} - {result}") except Exception as e: print(f" โŒ Test endpoint failed: {e}") - + # Test rate limiting by making multiple rapid requests print("\n6. Testing rate limiting (making 5 rapid requests)...") rate_limit_test_count = 0 for i in range(5): try: - async with session.get( - f"{BASE_URL}/api/health" - ) as response: + async with session.get(f"{BASE_URL}/api/health") as response: if response.status == 200: rate_limit_test_count += 1 elif response.status == 429: - print(f" โœ… Rate limit triggered on request {i+1}") + print(f" โœ… Rate limit triggered on request {i + 1}") break except Exception as e: print(f" โŒ Rate limit test failed: {e}") break - + if rate_limit_test_count == 5: print(" โ„น๏ธ No rate limit reached (normal for low request volume)") - + # Test admin endpoint (if admin key is configured) print("\n7. Testing admin stats endpoint...") try: async with session.get( - f"{BASE_URL}/api/admin/stats", - headers=admin_headers + f"{BASE_URL}/api/admin/stats", headers=admin_headers ) as response: result = await response.json() if response.status == 200: - print(f" โœ… Admin stats: {response.status} - {result.get('status')}") + print( + f" โœ… Admin stats: {response.status} - {result.get('status')}" + ) elif response.status == 401: - print(f" โš ๏ธ Admin access denied (API key not configured): {result.get('detail')}") + print( + f" โš ๏ธ Admin access denied (API key not configured): {result.get('detail')}" + ) else: print(f" โŒ Admin endpoint failed: {response.status} - {result}") except Exception as e: @@ -189,12 +194,18 @@ if __name__ == "__main__": print("๐Ÿ”’ Testing Secure Wix Form Handler API...") print("=" * 60) print("๐Ÿ“ API URL:", BASE_URL) - print("๐Ÿ”‘ Using API Key:", TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY) - print("๐Ÿ” Using Admin Key:", ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY) + print( + "๐Ÿ”‘ Using API Key:", + TEST_API_KEY[:20] + "..." if len(TEST_API_KEY) > 20 else TEST_API_KEY, + ) + print( + "๐Ÿ” Using Admin Key:", + ADMIN_API_KEY[:20] + "..." if len(ADMIN_API_KEY) > 20 else ADMIN_API_KEY, + ) print("=" * 60) print("Make sure the API is running with: python3 run_api.py") print("-" * 60) - + try: asyncio.run(test_api()) print("\n" + "=" * 60) @@ -207,4 +218,4 @@ if __name__ == "__main__": print("3. Add Authorization header: Bearer your_api_key") except Exception as e: print(f"\nโŒ Error testing API: {e}") - print("Make sure the API server is running!") \ No newline at end of file + print("Make sure the API server is running!") diff --git a/src/alpine_bits_python/simplified_access.py b/src/alpine_bits_python/simplified_access.py index 4b26d42..b0ecd28 100644 --- a/src/alpine_bits_python/simplified_access.py +++ b/src/alpine_bits_python/simplified_access.py @@ -15,15 +15,26 @@ NotifHotelReservationId = OtaHotelResNotifRq.HotelReservations.HotelReservation. RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId # Define type aliases for Comments types -NotifComments = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments -RetrieveComments = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments -NotifComment = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment -RetrieveComment = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment +NotifComments = ( + OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments +) +RetrieveComments = ( + OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments +) +NotifComment = ( + OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.Comments.Comment +) +RetrieveComment = ( + OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.Comments.Comment +) # type aliases for GuestCounts -NotifGuestCounts = OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts -RetrieveGuestCounts = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts - +NotifGuestCounts = ( + OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.GuestCounts +) +RetrieveGuestCounts = ( + OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts +) # phonetechtype enum 1,3,5 voice, fax, mobile @@ -36,12 +47,13 @@ class PhoneTechType(Enum): # Enum to specify which OTA message type to use class OtaMessageType(Enum): NOTIF = "notification" # For OtaHotelResNotifRq - RETRIEVE = "retrieve" # For OtaResRetrieveRs + RETRIEVE = "retrieve" # For OtaResRetrieveRs @dataclass class KidsAgeData: """Data class to hold information about children's ages.""" + ages: list[int] @@ -77,9 +89,10 @@ class CustomerData: class GuestCountsFactory: - @staticmethod - def create_notif_guest_counts(adults: int, kids: Optional[list[int]] = None) -> NotifGuestCounts: + def create_notif_guest_counts( + adults: int, kids: Optional[list[int]] = None + ) -> NotifGuestCounts: """ Create a GuestCounts object for OtaHotelResNotifRq. :param adults: Number of adults @@ -89,18 +102,23 @@ class GuestCountsFactory: return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts) @staticmethod - def create_retrieve_guest_counts(adults: int, kids: Optional[list[int]] = None) -> RetrieveGuestCounts: + def create_retrieve_guest_counts( + adults: int, kids: Optional[list[int]] = None + ) -> RetrieveGuestCounts: """ Create a GuestCounts object for OtaResRetrieveRs. :param adults: Number of adults :param kids: List of ages for each kid (optional) :return: GuestCounts instance """ - return GuestCountsFactory._create_guest_counts(adults, kids, RetrieveGuestCounts) + return GuestCountsFactory._create_guest_counts( + adults, kids, RetrieveGuestCounts + ) - @staticmethod - def _create_guest_counts(adults: int, kids: Optional[list[int]], guest_counts_class: type) -> Any: + def _create_guest_counts( + adults: int, kids: Optional[list[int]], guest_counts_class: type + ) -> Any: """ Internal method to create a GuestCounts object of the specified type. :param adults: Number of adults @@ -356,9 +374,10 @@ class HotelReservationIdFactory: ) -@dataclass +@dataclass class CommentListItemData: """Simple data class to hold comment list item information.""" + value: str # The text content of the list item list_item: str # Numeric identifier (pattern: [0-9]+) language: str # Two-letter language code (pattern: [a-z][a-z]) @@ -367,6 +386,7 @@ class CommentListItemData: @dataclass class CommentData: """Simple data class to hold comment information without nested type constraints.""" + name: CommentName2 # Required: "included services", "customer comment", "additional info" text: Optional[str] = None # Optional text content list_items: list[CommentListItemData] = None # Optional list items @@ -379,6 +399,7 @@ class CommentData: @dataclass class CommentsData: """Simple data class to hold multiple comments (1-3 max).""" + comments: list[CommentData] = None # 1-3 comments maximum def __post_init__(self): @@ -388,21 +409,23 @@ class CommentsData: class CommentFactory: """Factory class to create Comment instances for both OtaHotelResNotifRq and OtaResRetrieveRs.""" - + @staticmethod def create_notif_comments(data: CommentsData) -> NotifComments: """Create Comments for OtaHotelResNotifRq.""" return CommentFactory._create_comments(NotifComments, NotifComment, data) - + @staticmethod def create_retrieve_comments(data: CommentsData) -> RetrieveComments: """Create Comments for OtaResRetrieveRs.""" return CommentFactory._create_comments(RetrieveComments, RetrieveComment, data) - + @staticmethod - def _create_comments(comments_class: type, comment_class: type, data: CommentsData) -> Any: + def _create_comments( + comments_class: type, comment_class: type, data: CommentsData + ) -> Any: """Internal method to create comments of the specified type.""" - + comments_list = [] for comment_data in data.comments: # Create list items @@ -411,55 +434,53 @@ class CommentFactory: list_item = comment_class.ListItem( value=item_data.value, list_item=item_data.list_item, - language=item_data.language + language=item_data.language, ) list_items.append(list_item) - + # Create comment comment = comment_class( - name=comment_data.name, - text=comment_data.text, - list_item=list_items + name=comment_data.name, text=comment_data.text, list_item=list_items ) comments_list.append(comment) - + # Create comments container return comments_class(comment=comments_list) - + @staticmethod def from_notif_comments(comments: NotifComments) -> CommentsData: """Convert NotifComments back to CommentsData.""" return CommentFactory._comments_to_data(comments) - + @staticmethod def from_retrieve_comments(comments: RetrieveComments) -> CommentsData: """Convert RetrieveComments back to CommentsData.""" return CommentFactory._comments_to_data(comments) - + @staticmethod def _comments_to_data(comments: Any) -> CommentsData: """Internal method to convert any comments type to CommentsData.""" - + comments_data_list = [] for comment in comments.comment: # Extract list items list_items_data = [] if comment.list_item: for list_item in comment.list_item: - list_items_data.append(CommentListItemData( - value=list_item.value, - list_item=list_item.list_item, - language=list_item.language - )) - + list_items_data.append( + CommentListItemData( + value=list_item.value, + list_item=list_item.list_item, + language=list_item.language, + ) + ) + # Extract comment data comment_data = CommentData( - name=comment.name, - text=comment.text, - list_items=list_items_data + name=comment.name, text=comment.text, list_items=list_items_data ) comments_data_list.append(comment_data) - + return CommentsData(comments=comments_data_list) @@ -529,16 +550,19 @@ class ResGuestFactory: class AlpineBitsFactory: """Unified factory class for creating AlpineBits objects with a simple interface.""" - + @staticmethod - def create(data: Union[CustomerData, HotelReservationIdData, CommentsData], message_type: OtaMessageType) -> Any: + def create( + data: Union[CustomerData, HotelReservationIdData, CommentsData], + message_type: OtaMessageType, + ) -> Any: """ Create an AlpineBits object based on the data type and message type. - + Args: data: The data object (CustomerData, HotelReservationIdData, CommentsData, etc.) message_type: Whether to create for NOTIF or RETRIEVE message types - + Returns: The appropriate AlpineBits object based on the data type and message type """ @@ -547,31 +571,35 @@ class AlpineBitsFactory: return CustomerFactory.create_notif_customer(data) else: return CustomerFactory.create_retrieve_customer(data) - + elif isinstance(data, HotelReservationIdData): if message_type == OtaMessageType.NOTIF: return HotelReservationIdFactory.create_notif_hotel_reservation_id(data) else: - return HotelReservationIdFactory.create_retrieve_hotel_reservation_id(data) - + return HotelReservationIdFactory.create_retrieve_hotel_reservation_id( + data + ) + elif isinstance(data, CommentsData): if message_type == OtaMessageType.NOTIF: return CommentFactory.create_notif_comments(data) else: return CommentFactory.create_retrieve_comments(data) - + else: raise ValueError(f"Unsupported data type: {type(data)}") - + @staticmethod - def create_res_guests(customer_data: CustomerData, message_type: OtaMessageType) -> Union[NotifResGuests, RetrieveResGuests]: + def create_res_guests( + customer_data: CustomerData, message_type: OtaMessageType + ) -> Union[NotifResGuests, RetrieveResGuests]: """ Create a complete ResGuests structure with a primary customer. - + Args: customer_data: The customer data message_type: Whether to create for NOTIF or RETRIEVE message types - + Returns: The appropriate ResGuests object """ @@ -579,43 +607,45 @@ class AlpineBitsFactory: return ResGuestFactory.create_notif_res_guests(customer_data) else: return ResGuestFactory.create_retrieve_res_guests(customer_data) - + @staticmethod - def extract_data(obj: Any) -> Union[CustomerData, HotelReservationIdData, CommentsData]: + def extract_data( + obj: Any, + ) -> Union[CustomerData, HotelReservationIdData, CommentsData]: """ Extract data from an AlpineBits object back to a simple data class. - + Args: obj: The AlpineBits object to extract data from - + Returns: The appropriate data object """ # Check if it's a Customer object - if hasattr(obj, 'person_name') and hasattr(obj.person_name, 'given_name'): + if hasattr(obj, "person_name") and hasattr(obj.person_name, "given_name"): if isinstance(obj, NotifCustomer): return CustomerFactory.from_notif_customer(obj) elif isinstance(obj, RetrieveCustomer): return CustomerFactory.from_retrieve_customer(obj) - + # Check if it's a HotelReservationId object - elif hasattr(obj, 'res_id_type'): + elif hasattr(obj, "res_id_type"): if isinstance(obj, NotifHotelReservationId): return HotelReservationIdFactory.from_notif_hotel_reservation_id(obj) elif isinstance(obj, RetrieveHotelReservationId): return HotelReservationIdFactory.from_retrieve_hotel_reservation_id(obj) - + # Check if it's a Comments object - elif hasattr(obj, 'comment'): + elif hasattr(obj, "comment"): if isinstance(obj, NotifComments): return CommentFactory.from_notif_comments(obj) elif isinstance(obj, RetrieveComments): return CommentFactory.from_retrieve_comments(obj) - + # Check if it's a ResGuests object - elif hasattr(obj, 'res_guest'): + elif hasattr(obj, "res_guest"): return ResGuestFactory.extract_primary_customer(obj) - + else: raise ValueError(f"Unsupported object type: {type(obj)}") @@ -733,70 +763,74 @@ if __name__ == "__main__": # Verify roundtrip conversion print("Roundtrip conversion successful:", customer_data == extracted_data) - + print("\n--- Unified AlpineBitsFactory Examples ---") - + # Much simpler approach - single factory with enum parameter! print("=== Customer Creation ===") notif_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF) retrieve_customer = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE) print("Created customers using unified factory") - + print("=== HotelReservationId Creation ===") reservation_id_data = HotelReservationIdData( - res_id_type="123", - res_id_value="RESERVATION-456", - res_id_source="HOTEL_SYSTEM" + res_id_type="123", res_id_value="RESERVATION-456", res_id_source="HOTEL_SYSTEM" ) notif_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.NOTIF) - retrieve_res_id = AlpineBitsFactory.create(reservation_id_data, OtaMessageType.RETRIEVE) + retrieve_res_id = AlpineBitsFactory.create( + reservation_id_data, OtaMessageType.RETRIEVE + ) print("Created reservation IDs using unified factory") - + print("=== Comments Creation ===") - comments_data = CommentsData(comments=[ - CommentData( - name=CommentName2.CUSTOMER_COMMENT, - text="This is a customer comment about the reservation", - list_items=[ - CommentListItemData( - value="Special dietary requirements: vegetarian", - list_item="1", - language="en" - ), - CommentListItemData( - value="Late arrival expected", - list_item="2", - language="en" - ) - ] - ), - CommentData( - name=CommentName2.ADDITIONAL_INFO, - text="Additional information about the stay" - ) - ]) + comments_data = CommentsData( + comments=[ + CommentData( + name=CommentName2.CUSTOMER_COMMENT, + text="This is a customer comment about the reservation", + list_items=[ + CommentListItemData( + value="Special dietary requirements: vegetarian", + list_item="1", + language="en", + ), + CommentListItemData( + value="Late arrival expected", list_item="2", language="en" + ), + ], + ), + CommentData( + name=CommentName2.ADDITIONAL_INFO, + text="Additional information about the stay", + ), + ] + ) notif_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.NOTIF) retrieve_comments = AlpineBitsFactory.create(comments_data, OtaMessageType.RETRIEVE) print("Created comments using unified factory") - + print("=== ResGuests Creation ===") - notif_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF) - retrieve_res_guests = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) + notif_res_guests = AlpineBitsFactory.create_res_guests( + customer_data, OtaMessageType.NOTIF + ) + retrieve_res_guests = AlpineBitsFactory.create_res_guests( + customer_data, OtaMessageType.RETRIEVE + ) print("Created ResGuests using unified factory") - + print("=== Data Extraction ===") # Extract data back using unified interface extracted_customer_data = AlpineBitsFactory.extract_data(notif_customer) extracted_res_id_data = AlpineBitsFactory.extract_data(notif_res_id) extracted_comments_data = AlpineBitsFactory.extract_data(retrieve_comments) extracted_from_res_guests = AlpineBitsFactory.extract_data(retrieve_res_guests) - + print("Data extraction successful:") print("- Customer roundtrip:", customer_data == extracted_customer_data) print("- ReservationId roundtrip:", reservation_id_data == extracted_res_id_data) print("- Comments roundtrip:", comments_data == extracted_comments_data) print("- ResGuests roundtrip:", customer_data == extracted_from_res_guests) - + print("\n--- Comparison with old approach ---") print("Old way required multiple imports and knowing specific factory methods") print("New way: single import, single factory, enum parameter to specify type!") diff --git a/src/alpine_bits_python/util/__init__.py b/src/alpine_bits_python/util/__init__.py index c86dd84..7eff50a 100644 --- a/src/alpine_bits_python/util/__init__.py +++ b/src/alpine_bits_python/util/__init__.py @@ -1 +1 @@ -"""Utility functions for alpine_bits_python.""" \ No newline at end of file +"""Utility functions for alpine_bits_python.""" diff --git a/src/alpine_bits_python/util/__main__.py b/src/alpine_bits_python/util/__main__.py index beb47d9..16f9496 100644 --- a/src/alpine_bits_python/util/__main__.py +++ b/src/alpine_bits_python/util/__main__.py @@ -1,5 +1,6 @@ """Entry point for util package.""" + from .handshake_util import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/alpine_bits_python/util/handshake_util.py b/src/alpine_bits_python/util/handshake_util.py index f82cbae..74409dd 100644 --- a/src/alpine_bits_python/util/handshake_util.py +++ b/src/alpine_bits_python/util/handshake_util.py @@ -2,26 +2,22 @@ from ..generated.alpinebits import OtaPingRq, OtaPingRs from xsdata_pydantic.bindings import XmlParser - - def main(): # test parsing a ping request sample - path = "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml" + path = ( + "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRS.xml" + ) - with open( - path, "r", encoding="utf-8") as f: + with open(path, "r", encoding="utf-8") as f: xml = f.read() # Parse the XML into the request object - # Test parsing back - + # Test parsing back parser = XmlParser() - - parsed_result = parser.from_string(xml, OtaPingRs) print(parsed_result.echo_data) @@ -34,19 +30,14 @@ def main(): print(warning.content[0]) - - - # save json in echo_data to file with indents output_path = "echo_data_response.json" with open(output_path, "w", encoding="utf-8") as out_f: import json + json.dump(json.loads(parsed_result.echo_data), out_f, indent=4) print(f"Saved echo_data json to {output_path}") - if __name__ == "__main__": - - - main() \ No newline at end of file + main() diff --git a/start_api.py b/start_api.py index 2e84fd3..fb51cd2 100644 --- a/start_api.py +++ b/start_api.py @@ -2,12 +2,13 @@ """ Convenience launcher for the Wix Form Handler API """ + import os import subprocess # Change to src directory -src_dir = os.path.join(os.path.dirname(__file__), 'src/alpine_bits_python') +src_dir = os.path.join(os.path.dirname(__file__), "src/alpine_bits_python") # Run the API using uv if __name__ == "__main__": - subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")]) \ No newline at end of file + subprocess.run(["uv", "run", "python", os.path.join(src_dir, "run_api.py")]) diff --git a/test/test_discovery.py b/test/test_discovery.py index 28b347a..2eda866 100644 --- a/test/test_discovery.py +++ b/test/test_discovery.py @@ -5,57 +5,63 @@ discovers implemented vs unimplemented actions. """ from alpine_bits_python.alpinebits_server import ( - ServerCapabilities, - AlpineBitsAction, - AlpineBitsActionName, - Version, + ServerCapabilities, + AlpineBitsAction, + AlpineBitsActionName, + Version, AlpineBitsResponse, - HttpStatusCode + HttpStatusCode, ) import asyncio + class NewImplementedAction(AlpineBitsAction): """A new action that IS implemented.""" - + def __init__(self): self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_INFO_INFO self.version = Version.V2024_10 - - async def handle(self, action: str, request_xml: str, version: Version) -> AlpineBitsResponse: + + async def handle( + self, action: str, request_xml: str, version: Version + ) -> AlpineBitsResponse: """This action is implemented.""" return AlpineBitsResponse("Implemented!", HttpStatusCode.OK) + class NewUnimplementedAction(AlpineBitsAction): """A new action that is NOT implemented (no handle override).""" - + def __init__(self): self.name = AlpineBitsActionName.OTA_HOTEL_DESCRIPTIVE_CONTENT_NOTIF_INFO self.version = Version.V2024_10 - + # Notice: No handle method override - will use default "not implemented" + async def main(): print("๐Ÿ” Testing Action Discovery Logic") print("=" * 50) - + # Create capabilities and see what gets discovered capabilities = ServerCapabilities() - + print("๐Ÿ“‹ Actions found by discovery:") for action_name in capabilities.get_supported_actions(): print(f" โœ… {action_name}") - + print(f"\n๐Ÿ“Š Total discovered: {len(capabilities.get_supported_actions())}") - + # Test the new implemented action implemented_action = NewImplementedAction() result = await implemented_action.handle("test", "", Version.V2024_10) print(f"\n๐ŸŸข NewImplementedAction result: {result.xml_content}") - + # Test the unimplemented action (should use default behavior) unimplemented_action = NewUnimplementedAction() result = await unimplemented_action.handle("test", "", Version.V2024_10) print(f"๐Ÿ”ด NewUnimplementedAction result: {result.xml_content}") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/test/test_simplified_access.py b/test/test_simplified_access.py index d202337..6b1c96a 100644 --- a/test/test_simplified_access.py +++ b/test/test_simplified_access.py @@ -4,11 +4,11 @@ import sys import os # Add the src directory to the path so we can import our modules -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from simplified_access import ( - CustomerData, - CustomerFactory, + CustomerData, + CustomerFactory, ResGuestFactory, HotelReservationIdData, HotelReservationIdFactory, @@ -20,7 +20,7 @@ from simplified_access import ( NotifResGuests, RetrieveResGuests, NotifHotelReservationId, - RetrieveHotelReservationId + RetrieveHotelReservationId, ) @@ -35,7 +35,7 @@ def sample_customer_data(): phone_numbers=[ ("+1234567890", PhoneTechType.MOBILE), ("+0987654321", PhoneTechType.VOICE), - ("+1111111111", None) + ("+1111111111", None), ], email_address="john.doe@example.com", email_newsletter=True, @@ -46,17 +46,14 @@ def sample_customer_data(): address_catalog=False, gender="Male", birth_date="1980-01-01", - language="en" + language="en", ) @pytest.fixture def minimal_customer_data(): """Fixture providing minimal customer data (only required fields).""" - return CustomerData( - given_name="Jane", - surname="Smith" - ) + return CustomerData(given_name="Jane", surname="Smith") @pytest.fixture @@ -66,21 +63,19 @@ def sample_hotel_reservation_id_data(): res_id_type="123", res_id_value="RESERVATION-456", res_id_source="HOTEL_SYSTEM", - res_id_source_context="BOOKING_ENGINE" + res_id_source_context="BOOKING_ENGINE", ) @pytest.fixture def minimal_hotel_reservation_id_data(): """Fixture providing minimal hotel reservation ID data (only required fields).""" - return HotelReservationIdData( - res_id_type="999" - ) + return HotelReservationIdData(res_id_type="999") class TestCustomerData: """Test the CustomerData dataclass.""" - + def test_customer_data_creation_full(self, sample_customer_data): """Test creating CustomerData with all fields.""" assert sample_customer_data.given_name == "John" @@ -89,7 +84,7 @@ class TestCustomerData: assert sample_customer_data.email_address == "john.doe@example.com" assert sample_customer_data.email_newsletter is True assert len(sample_customer_data.phone_numbers) == 3 - + def test_customer_data_creation_minimal(self, minimal_customer_data): """Test creating CustomerData with only required fields.""" assert minimal_customer_data.given_name == "Jane" @@ -97,7 +92,7 @@ class TestCustomerData: assert minimal_customer_data.phone_numbers == [] assert minimal_customer_data.email_address is None assert minimal_customer_data.address_line is None - + def test_phone_numbers_default_initialization(self): """Test that phone_numbers gets initialized to empty list.""" customer_data = CustomerData(given_name="Test", surname="User") @@ -106,54 +101,56 @@ class TestCustomerData: class TestCustomerFactory: """Test the CustomerFactory class.""" - + def test_create_notif_customer_full(self, sample_customer_data): """Test creating a NotifCustomer with full data.""" customer = CustomerFactory.create_notif_customer(sample_customer_data) - + assert isinstance(customer, NotifCustomer) assert customer.person_name.given_name == "John" assert customer.person_name.surname == "Doe" assert customer.person_name.name_prefix == "Mr." assert customer.person_name.name_title == "Jr." - + # Check telephone assert len(customer.telephone) == 3 assert customer.telephone[0].phone_number == "+1234567890" assert customer.telephone[0].phone_tech_type == "5" # MOBILE assert customer.telephone[1].phone_tech_type == "1" # VOICE assert customer.telephone[2].phone_tech_type is None - + # Check email assert customer.email.value == "john.doe@example.com" assert customer.email.remark == "newsletter:yes" - + # Check address assert customer.address.address_line == "123 Main Street" assert customer.address.city_name == "Anytown" assert customer.address.postal_code == "12345" assert customer.address.country_name.code == "US" assert customer.address.remark == "catalog:no" - + # Check other attributes assert customer.gender == "Male" assert customer.birth_date == "1980-01-01" assert customer.language == "en" - + def test_create_retrieve_customer_full(self, sample_customer_data): """Test creating a RetrieveCustomer with full data.""" customer = CustomerFactory.create_retrieve_customer(sample_customer_data) - + assert isinstance(customer, RetrieveCustomer) assert customer.person_name.given_name == "John" assert customer.person_name.surname == "Doe" # Same structure as NotifCustomer, so we don't need to test all fields again - + def test_create_customer_minimal(self, minimal_customer_data): """Test creating customers with minimal data.""" notif_customer = CustomerFactory.create_notif_customer(minimal_customer_data) - retrieve_customer = CustomerFactory.create_retrieve_customer(minimal_customer_data) - + retrieve_customer = CustomerFactory.create_retrieve_customer( + minimal_customer_data + ) + for customer in [notif_customer, retrieve_customer]: assert customer.person_name.given_name == "Jane" assert customer.person_name.surname == "Smith" @@ -165,73 +162,97 @@ class TestCustomerFactory: assert customer.gender is None assert customer.birth_date is None assert customer.language is None - + def test_email_newsletter_options(self): """Test different email newsletter options.""" # Newsletter yes - data_yes = CustomerData(given_name="Test", surname="User", - email_address="test@example.com", email_newsletter=True) + data_yes = CustomerData( + given_name="Test", + surname="User", + email_address="test@example.com", + email_newsletter=True, + ) customer = CustomerFactory.create_notif_customer(data_yes) assert customer.email.remark == "newsletter:yes" - + # Newsletter no - data_no = CustomerData(given_name="Test", surname="User", - email_address="test@example.com", email_newsletter=False) + data_no = CustomerData( + given_name="Test", + surname="User", + email_address="test@example.com", + email_newsletter=False, + ) customer = CustomerFactory.create_notif_customer(data_no) assert customer.email.remark == "newsletter:no" - + # Newsletter not specified - data_none = CustomerData(given_name="Test", surname="User", - email_address="test@example.com", email_newsletter=None) + data_none = CustomerData( + given_name="Test", + surname="User", + email_address="test@example.com", + email_newsletter=None, + ) customer = CustomerFactory.create_notif_customer(data_none) assert customer.email.remark is None - + def test_address_catalog_options(self): """Test different address catalog options.""" # Catalog no - data_no = CustomerData(given_name="Test", surname="User", - address_line="123 Street", address_catalog=False) + data_no = CustomerData( + given_name="Test", + surname="User", + address_line="123 Street", + address_catalog=False, + ) customer = CustomerFactory.create_notif_customer(data_no) assert customer.address.remark == "catalog:no" - + # Catalog yes - data_yes = CustomerData(given_name="Test", surname="User", - address_line="123 Street", address_catalog=True) + data_yes = CustomerData( + given_name="Test", + surname="User", + address_line="123 Street", + address_catalog=True, + ) customer = CustomerFactory.create_notif_customer(data_yes) assert customer.address.remark == "catalog:yes" - + # Catalog not specified - data_none = CustomerData(given_name="Test", surname="User", - address_line="123 Street", address_catalog=None) + data_none = CustomerData( + given_name="Test", + surname="User", + address_line="123 Street", + address_catalog=None, + ) customer = CustomerFactory.create_notif_customer(data_none) assert customer.address.remark is None - + def test_from_notif_customer_roundtrip(self, sample_customer_data): """Test converting NotifCustomer back to CustomerData.""" customer = CustomerFactory.create_notif_customer(sample_customer_data) converted_data = CustomerFactory.from_notif_customer(customer) - + assert converted_data == sample_customer_data - + def test_from_retrieve_customer_roundtrip(self, sample_customer_data): """Test converting RetrieveCustomer back to CustomerData.""" customer = CustomerFactory.create_retrieve_customer(sample_customer_data) converted_data = CustomerFactory.from_retrieve_customer(customer) - + assert converted_data == sample_customer_data - + def test_phone_tech_type_conversion(self): """Test that PhoneTechType enum values are properly converted.""" data = CustomerData( - given_name="Test", + given_name="Test", surname="User", phone_numbers=[ ("+1111111111", PhoneTechType.VOICE), ("+2222222222", PhoneTechType.FAX), - ("+3333333333", PhoneTechType.MOBILE) - ] + ("+3333333333", PhoneTechType.MOBILE), + ], ) - + customer = CustomerFactory.create_notif_customer(data) assert customer.telephone[0].phone_tech_type == "1" # VOICE assert customer.telephone[1].phone_tech_type == "3" # FAX @@ -240,15 +261,21 @@ class TestCustomerFactory: class TestHotelReservationIdData: """Test the HotelReservationIdData dataclass.""" - - def test_hotel_reservation_id_data_creation_full(self, sample_hotel_reservation_id_data): + + def test_hotel_reservation_id_data_creation_full( + self, sample_hotel_reservation_id_data + ): """Test creating HotelReservationIdData with all fields.""" assert sample_hotel_reservation_id_data.res_id_type == "123" assert sample_hotel_reservation_id_data.res_id_value == "RESERVATION-456" assert sample_hotel_reservation_id_data.res_id_source == "HOTEL_SYSTEM" - assert sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE" - - def test_hotel_reservation_id_data_creation_minimal(self, minimal_hotel_reservation_id_data): + assert ( + sample_hotel_reservation_id_data.res_id_source_context == "BOOKING_ENGINE" + ) + + def test_hotel_reservation_id_data_creation_minimal( + self, minimal_hotel_reservation_id_data + ): """Test creating HotelReservationIdData with only required fields.""" assert minimal_hotel_reservation_id_data.res_id_type == "999" assert minimal_hotel_reservation_id_data.res_id_value is None @@ -258,124 +285,158 @@ class TestHotelReservationIdData: class TestHotelReservationIdFactory: """Test the HotelReservationIdFactory class.""" - - def test_create_notif_hotel_reservation_id_full(self, sample_hotel_reservation_id_data): + + def test_create_notif_hotel_reservation_id_full( + self, sample_hotel_reservation_id_data + ): """Test creating a NotifHotelReservationId with full data.""" - reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) - + reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id( + sample_hotel_reservation_id_data + ) + assert isinstance(reservation_id, NotifHotelReservationId) assert reservation_id.res_id_type == "123" assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_source == "HOTEL_SYSTEM" assert reservation_id.res_id_source_context == "BOOKING_ENGINE" - - def test_create_retrieve_hotel_reservation_id_full(self, sample_hotel_reservation_id_data): + + def test_create_retrieve_hotel_reservation_id_full( + self, sample_hotel_reservation_id_data + ): """Test creating a RetrieveHotelReservationId with full data.""" - reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) - + reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id( + sample_hotel_reservation_id_data + ) + assert isinstance(reservation_id, RetrieveHotelReservationId) assert reservation_id.res_id_type == "123" assert reservation_id.res_id_value == "RESERVATION-456" assert reservation_id.res_id_source == "HOTEL_SYSTEM" assert reservation_id.res_id_source_context == "BOOKING_ENGINE" - - def test_create_hotel_reservation_id_minimal(self, minimal_hotel_reservation_id_data): + + def test_create_hotel_reservation_id_minimal( + self, minimal_hotel_reservation_id_data + ): """Test creating hotel reservation IDs with minimal data.""" - notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(minimal_hotel_reservation_id_data) - retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(minimal_hotel_reservation_id_data) - + notif_reservation_id = ( + HotelReservationIdFactory.create_notif_hotel_reservation_id( + minimal_hotel_reservation_id_data + ) + ) + retrieve_reservation_id = ( + HotelReservationIdFactory.create_retrieve_hotel_reservation_id( + minimal_hotel_reservation_id_data + ) + ) + for reservation_id in [notif_reservation_id, retrieve_reservation_id]: assert reservation_id.res_id_type == "999" assert reservation_id.res_id_value is None assert reservation_id.res_id_source is None assert reservation_id.res_id_source_context is None - - def test_from_notif_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data): + + def test_from_notif_hotel_reservation_id_roundtrip( + self, sample_hotel_reservation_id_data + ): """Test converting NotifHotelReservationId back to HotelReservationIdData.""" - reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) - converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id(reservation_id) - + reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id( + sample_hotel_reservation_id_data + ) + converted_data = HotelReservationIdFactory.from_notif_hotel_reservation_id( + reservation_id + ) + assert converted_data == sample_hotel_reservation_id_data - - def test_from_retrieve_hotel_reservation_id_roundtrip(self, sample_hotel_reservation_id_data): + + def test_from_retrieve_hotel_reservation_id_roundtrip( + self, sample_hotel_reservation_id_data + ): """Test converting RetrieveHotelReservationId back to HotelReservationIdData.""" - reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) - converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(reservation_id) - + reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id( + sample_hotel_reservation_id_data + ) + converted_data = HotelReservationIdFactory.from_retrieve_hotel_reservation_id( + reservation_id + ) + assert converted_data == sample_hotel_reservation_id_data class TestResGuestFactory: """Test the ResGuestFactory class.""" - + def test_create_notif_res_guests(self, sample_customer_data): """Test creating NotifResGuests structure.""" res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) - + assert isinstance(res_guests, NotifResGuests) - + # Navigate down the nested structure customer = res_guests.res_guest.profiles.profile_info.profile.customer assert customer.person_name.given_name == "John" assert customer.person_name.surname == "Doe" assert customer.email.value == "john.doe@example.com" - + def test_create_retrieve_res_guests(self, sample_customer_data): """Test creating RetrieveResGuests structure.""" res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) - + assert isinstance(res_guests, RetrieveResGuests) - + # Navigate down the nested structure customer = res_guests.res_guest.profiles.profile_info.profile.customer assert customer.person_name.given_name == "John" assert customer.person_name.surname == "Doe" assert customer.email.value == "john.doe@example.com" - + def test_create_res_guests_minimal(self, minimal_customer_data): """Test creating ResGuests with minimal customer data.""" - notif_res_guests = ResGuestFactory.create_notif_res_guests(minimal_customer_data) - retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(minimal_customer_data) - + notif_res_guests = ResGuestFactory.create_notif_res_guests( + minimal_customer_data + ) + retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests( + minimal_customer_data + ) + for res_guests in [notif_res_guests, retrieve_res_guests]: customer = res_guests.res_guest.profiles.profile_info.profile.customer assert customer.person_name.given_name == "Jane" assert customer.person_name.surname == "Smith" assert customer.email is None assert customer.address is None - + def test_extract_primary_customer_notif(self, sample_customer_data): """Test extracting primary customer from NotifResGuests.""" res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) extracted_data = ResGuestFactory.extract_primary_customer(res_guests) - + assert extracted_data == sample_customer_data - + def test_extract_primary_customer_retrieve(self, sample_customer_data): """Test extracting primary customer from RetrieveResGuests.""" res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) extracted_data = ResGuestFactory.extract_primary_customer(res_guests) - + assert extracted_data == sample_customer_data - + def test_roundtrip_conversion_notif(self, sample_customer_data): """Test complete roundtrip: CustomerData -> NotifResGuests -> CustomerData.""" res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) extracted_data = ResGuestFactory.extract_primary_customer(res_guests) - + assert extracted_data == sample_customer_data - + def test_roundtrip_conversion_retrieve(self, sample_customer_data): """Test complete roundtrip: CustomerData -> RetrieveResGuests -> CustomerData.""" res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) extracted_data = ResGuestFactory.extract_primary_customer(res_guests) - + assert extracted_data == sample_customer_data class TestPhoneTechType: """Test the PhoneTechType enum.""" - + def test_enum_values(self): """Test that enum values are correct.""" assert PhoneTechType.VOICE.value == "1" @@ -385,95 +446,121 @@ class TestPhoneTechType: class TestAlpineBitsFactory: """Test the unified AlpineBitsFactory class.""" - + def test_create_customer_notif(self, sample_customer_data): """Test creating customer using unified factory for NOTIF.""" customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF) assert isinstance(customer, NotifCustomer) assert customer.person_name.given_name == "John" assert customer.person_name.surname == "Doe" - + def test_create_customer_retrieve(self, sample_customer_data): """Test creating customer using unified factory for RETRIEVE.""" - customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE) + customer = AlpineBitsFactory.create( + sample_customer_data, OtaMessageType.RETRIEVE + ) assert isinstance(customer, RetrieveCustomer) assert customer.person_name.given_name == "John" assert customer.person_name.surname == "Doe" - + def test_create_hotel_reservation_id_notif(self, sample_hotel_reservation_id_data): """Test creating hotel reservation ID using unified factory for NOTIF.""" - reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF) + reservation_id = AlpineBitsFactory.create( + sample_hotel_reservation_id_data, OtaMessageType.NOTIF + ) assert isinstance(reservation_id, NotifHotelReservationId) assert reservation_id.res_id_type == "123" assert reservation_id.res_id_value == "RESERVATION-456" - - def test_create_hotel_reservation_id_retrieve(self, sample_hotel_reservation_id_data): + + def test_create_hotel_reservation_id_retrieve( + self, sample_hotel_reservation_id_data + ): """Test creating hotel reservation ID using unified factory for RETRIEVE.""" - reservation_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE) + reservation_id = AlpineBitsFactory.create( + sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE + ) assert isinstance(reservation_id, RetrieveHotelReservationId) assert reservation_id.res_id_type == "123" assert reservation_id.res_id_value == "RESERVATION-456" - + def test_create_res_guests_notif(self, sample_customer_data): """Test creating ResGuests using unified factory for NOTIF.""" - res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF) + res_guests = AlpineBitsFactory.create_res_guests( + sample_customer_data, OtaMessageType.NOTIF + ) assert isinstance(res_guests, NotifResGuests) customer = res_guests.res_guest.profiles.profile_info.profile.customer assert customer.person_name.given_name == "John" - + def test_create_res_guests_retrieve(self, sample_customer_data): """Test creating ResGuests using unified factory for RETRIEVE.""" - res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE) + res_guests = AlpineBitsFactory.create_res_guests( + sample_customer_data, OtaMessageType.RETRIEVE + ) assert isinstance(res_guests, RetrieveResGuests) customer = res_guests.res_guest.profiles.profile_info.profile.customer assert customer.person_name.given_name == "John" - + def test_extract_data_from_customer(self, sample_customer_data): """Test extracting data from customer objects.""" # Create both types and extract data back - notif_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.NOTIF) - retrieve_customer = AlpineBitsFactory.create(sample_customer_data, OtaMessageType.RETRIEVE) - + notif_customer = AlpineBitsFactory.create( + sample_customer_data, OtaMessageType.NOTIF + ) + retrieve_customer = AlpineBitsFactory.create( + sample_customer_data, OtaMessageType.RETRIEVE + ) + notif_extracted = AlpineBitsFactory.extract_data(notif_customer) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_customer) - + assert notif_extracted == sample_customer_data assert retrieve_extracted == sample_customer_data - - def test_extract_data_from_hotel_reservation_id(self, sample_hotel_reservation_id_data): + + def test_extract_data_from_hotel_reservation_id( + self, sample_hotel_reservation_id_data + ): """Test extracting data from hotel reservation ID objects.""" # Create both types and extract data back - notif_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.NOTIF) - retrieve_res_id = AlpineBitsFactory.create(sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE) - + notif_res_id = AlpineBitsFactory.create( + sample_hotel_reservation_id_data, OtaMessageType.NOTIF + ) + retrieve_res_id = AlpineBitsFactory.create( + sample_hotel_reservation_id_data, OtaMessageType.RETRIEVE + ) + notif_extracted = AlpineBitsFactory.extract_data(notif_res_id) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_id) - + assert notif_extracted == sample_hotel_reservation_id_data assert retrieve_extracted == sample_hotel_reservation_id_data - + def test_extract_data_from_res_guests(self, sample_customer_data): """Test extracting data from ResGuests objects.""" # Create both types and extract data back - notif_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.NOTIF) - retrieve_res_guests = AlpineBitsFactory.create_res_guests(sample_customer_data, OtaMessageType.RETRIEVE) - + notif_res_guests = AlpineBitsFactory.create_res_guests( + sample_customer_data, OtaMessageType.NOTIF + ) + retrieve_res_guests = AlpineBitsFactory.create_res_guests( + sample_customer_data, OtaMessageType.RETRIEVE + ) + notif_extracted = AlpineBitsFactory.extract_data(notif_res_guests) retrieve_extracted = AlpineBitsFactory.extract_data(retrieve_res_guests) - + assert notif_extracted == sample_customer_data assert retrieve_extracted == sample_customer_data - + def test_unsupported_data_type_error(self): """Test that unsupported data types raise ValueError.""" with pytest.raises(ValueError, match="Unsupported data type"): AlpineBitsFactory.create("invalid_data", OtaMessageType.NOTIF) - + def test_unsupported_object_type_error(self): """Test that unsupported object types raise ValueError in extract_data.""" with pytest.raises(ValueError, match="Unsupported object type"): AlpineBitsFactory.extract_data("invalid_object") - + def test_complete_workflow_with_unified_factory(self): """Test a complete workflow using only the unified factory.""" # Original data @@ -481,34 +568,47 @@ class TestAlpineBitsFactory: given_name="Unified", surname="Factory", email_address="unified@factory.com", - phone_numbers=[("+1234567890", PhoneTechType.MOBILE)] + phone_numbers=[("+1234567890", PhoneTechType.MOBILE)], ) - + reservation_data = HotelReservationIdData( - res_id_type="999", - res_id_value="UNIFIED-TEST" + res_id_type="999", res_id_value="UNIFIED-TEST" ) - + # Create using unified factory customer_notif = AlpineBitsFactory.create(customer_data, OtaMessageType.NOTIF) - customer_retrieve = AlpineBitsFactory.create(customer_data, OtaMessageType.RETRIEVE) - + customer_retrieve = AlpineBitsFactory.create( + customer_data, OtaMessageType.RETRIEVE + ) + res_id_notif = AlpineBitsFactory.create(reservation_data, OtaMessageType.NOTIF) - res_id_retrieve = AlpineBitsFactory.create(reservation_data, OtaMessageType.RETRIEVE) - - res_guests_notif = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.NOTIF) - res_guests_retrieve = AlpineBitsFactory.create_res_guests(customer_data, OtaMessageType.RETRIEVE) - + res_id_retrieve = AlpineBitsFactory.create( + reservation_data, OtaMessageType.RETRIEVE + ) + + res_guests_notif = AlpineBitsFactory.create_res_guests( + customer_data, OtaMessageType.NOTIF + ) + res_guests_retrieve = AlpineBitsFactory.create_res_guests( + customer_data, OtaMessageType.RETRIEVE + ) + # Extract everything back extracted_customer_from_notif = AlpineBitsFactory.extract_data(customer_notif) - extracted_customer_from_retrieve = AlpineBitsFactory.extract_data(customer_retrieve) - + extracted_customer_from_retrieve = AlpineBitsFactory.extract_data( + customer_retrieve + ) + extracted_res_id_from_notif = AlpineBitsFactory.extract_data(res_id_notif) extracted_res_id_from_retrieve = AlpineBitsFactory.extract_data(res_id_retrieve) - - extracted_from_res_guests_notif = AlpineBitsFactory.extract_data(res_guests_notif) - extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data(res_guests_retrieve) - + + extracted_from_res_guests_notif = AlpineBitsFactory.extract_data( + res_guests_notif + ) + extracted_from_res_guests_retrieve = AlpineBitsFactory.extract_data( + res_guests_retrieve + ) + # Verify everything matches assert extracted_customer_from_notif == customer_data assert extracted_customer_from_retrieve == customer_data @@ -520,37 +620,72 @@ class TestAlpineBitsFactory: class TestIntegration: """Integration tests combining both factories.""" - + def test_both_factories_produce_same_customer_data(self, sample_customer_data): """Test that both factories can work with the same customer data.""" # Create using CustomerFactory notif_customer = CustomerFactory.create_notif_customer(sample_customer_data) - retrieve_customer = CustomerFactory.create_retrieve_customer(sample_customer_data) - + retrieve_customer = CustomerFactory.create_retrieve_customer( + sample_customer_data + ) + # Create using ResGuestFactory and extract customers notif_res_guests = ResGuestFactory.create_notif_res_guests(sample_customer_data) - retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(sample_customer_data) - - notif_from_res_guests = notif_res_guests.res_guest.profiles.profile_info.profile.customer - retrieve_from_res_guests = retrieve_res_guests.res_guest.profiles.profile_info.profile.customer - + retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests( + sample_customer_data + ) + + notif_from_res_guests = ( + notif_res_guests.res_guest.profiles.profile_info.profile.customer + ) + retrieve_from_res_guests = ( + retrieve_res_guests.res_guest.profiles.profile_info.profile.customer + ) + # Compare customer names (structure should be identical) - assert notif_customer.person_name.given_name == notif_from_res_guests.person_name.given_name - assert notif_customer.person_name.surname == notif_from_res_guests.person_name.surname - assert retrieve_customer.person_name.given_name == retrieve_from_res_guests.person_name.given_name - assert retrieve_customer.person_name.surname == retrieve_from_res_guests.person_name.surname - - def test_hotel_reservation_id_factories_produce_same_data(self, sample_hotel_reservation_id_data): + assert ( + notif_customer.person_name.given_name + == notif_from_res_guests.person_name.given_name + ) + assert ( + notif_customer.person_name.surname + == notif_from_res_guests.person_name.surname + ) + assert ( + retrieve_customer.person_name.given_name + == retrieve_from_res_guests.person_name.given_name + ) + assert ( + retrieve_customer.person_name.surname + == retrieve_from_res_guests.person_name.surname + ) + + def test_hotel_reservation_id_factories_produce_same_data( + self, sample_hotel_reservation_id_data + ): """Test that both HotelReservationId factories produce equivalent results.""" - notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(sample_hotel_reservation_id_data) - retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(sample_hotel_reservation_id_data) - + notif_reservation_id = ( + HotelReservationIdFactory.create_notif_hotel_reservation_id( + sample_hotel_reservation_id_data + ) + ) + retrieve_reservation_id = ( + HotelReservationIdFactory.create_retrieve_hotel_reservation_id( + sample_hotel_reservation_id_data + ) + ) + # Both should have the same field values assert notif_reservation_id.res_id_type == retrieve_reservation_id.res_id_type assert notif_reservation_id.res_id_value == retrieve_reservation_id.res_id_value - assert notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source - assert notif_reservation_id.res_id_source_context == retrieve_reservation_id.res_id_source_context - + assert ( + notif_reservation_id.res_id_source == retrieve_reservation_id.res_id_source + ) + assert ( + notif_reservation_id.res_id_source_context + == retrieve_reservation_id.res_id_source_context + ) + def test_complex_customer_workflow(self): """Test a complex workflow with multiple operations.""" # Create original data @@ -559,7 +694,7 @@ class TestIntegration: surname="Johnson", phone_numbers=[ ("+1555123456", PhoneTechType.MOBILE), - ("+1555654321", PhoneTechType.VOICE) + ("+1555654321", PhoneTechType.VOICE), ], email_address="alice.johnson@company.com", email_newsletter=False, @@ -569,22 +704,24 @@ class TestIntegration: country_code="CA", address_catalog=True, gender="Female", - language="fr" + language="fr", ) - + # Create ResGuests for both types notif_res_guests = ResGuestFactory.create_notif_res_guests(original_data) retrieve_res_guests = ResGuestFactory.create_retrieve_res_guests(original_data) - + # Extract data back from both notif_extracted = ResGuestFactory.extract_primary_customer(notif_res_guests) - retrieve_extracted = ResGuestFactory.extract_primary_customer(retrieve_res_guests) - + retrieve_extracted = ResGuestFactory.extract_primary_customer( + retrieve_res_guests + ) + # All should be equal assert original_data == notif_extracted assert original_data == retrieve_extracted assert notif_extracted == retrieve_extracted - + def test_complex_hotel_reservation_id_workflow(self): """Test a complex workflow with HotelReservationId operations.""" # Create original reservation ID data @@ -592,18 +729,30 @@ class TestIntegration: res_id_type="456", res_id_value="COMPLEX-RESERVATION-789", res_id_source="INTEGRATION_SYSTEM", - res_id_source_context="API_CALL" + res_id_source_context="API_CALL", ) - + # Create HotelReservationId for both types - notif_reservation_id = HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data) - retrieve_reservation_id = HotelReservationIdFactory.create_retrieve_hotel_reservation_id(original_data) - + notif_reservation_id = ( + HotelReservationIdFactory.create_notif_hotel_reservation_id(original_data) + ) + retrieve_reservation_id = ( + HotelReservationIdFactory.create_retrieve_hotel_reservation_id( + original_data + ) + ) + # Extract data back from both - notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id(notif_reservation_id) - retrieve_extracted = HotelReservationIdFactory.from_retrieve_hotel_reservation_id(retrieve_reservation_id) - + notif_extracted = HotelReservationIdFactory.from_notif_hotel_reservation_id( + notif_reservation_id + ) + retrieve_extracted = ( + HotelReservationIdFactory.from_retrieve_hotel_reservation_id( + retrieve_reservation_id + ) + ) + # All should be equal assert original_data == notif_extracted assert original_data == retrieve_extracted - assert notif_extracted == retrieve_extracted \ No newline at end of file + assert notif_extracted == retrieve_extracted diff --git a/test_handshake.py b/test_handshake.py index 47ff199..00c87b0 100644 --- a/test_handshake.py +++ b/test_handshake.py @@ -6,24 +6,31 @@ Test the handshake functionality with the real AlpineBits sample file. import asyncio from alpine_bits_python.alpinebits_server import AlpineBitsServer + async def main(): print("๐Ÿ”„ Testing AlpineBits Handshake with Sample File") print("=" * 60) - + # Create server instance server = AlpineBitsServer() - + # Read the sample handshake request - with open("AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml", "r") as f: + with open( + "AlpineBits-HotelData-2024-10/files/samples/Handshake/Handshake-OTA_PingRQ.xml", + "r", + ) as f: ping_request_xml = f.read() - + print("๐Ÿ“ค Sending handshake request...") - + # Handle the ping request - response = await server.handle_request("OTA_Ping:Handshaking", ping_request_xml, "2024-10") - + response = await server.handle_request( + "OTA_Ping:Handshaking", ping_request_xml, "2024-10" + ) + print(f"\n๐Ÿ“ฅ Response Status: {response.status_code}") print(f"๐Ÿ“„ Response XML:\n{response.xml_content}") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())