Formatting

This commit is contained in:
Jonas Linter
2025-10-07 09:40:23 +02:00
parent 2d9e90c9a4
commit b4b7a537e1
8 changed files with 213 additions and 198 deletions

View File

@@ -53,17 +53,21 @@ RetrieveGuestCounts = (
OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.GuestCounts
) )
NotifUniqueId = (OtaHotelResNotifRq.HotelReservations.HotelReservation.UniqueId) NotifUniqueId = OtaHotelResNotifRq.HotelReservations.HotelReservation.UniqueId
RetrieveUniqueId = (OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId) RetrieveUniqueId = OtaResRetrieveRs.ReservationsList.HotelReservation.UniqueId
NotifTimeSpan = (OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.TimeSpan) NotifTimeSpan = (
RetrieveTimeSpan = (OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan) OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.TimeSpan
)
RetrieveTimeSpan = (
OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan
)
NotifRoomStays = (OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays) NotifRoomStays = OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays
RetrieveRoomStays = (OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays) RetrieveRoomStays = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays
NotifHotelReservation = (OtaHotelResNotifRq.HotelReservations.HotelReservation) NotifHotelReservation = OtaHotelResNotifRq.HotelReservations.HotelReservation
RetrieveHotelReservation = (OtaResRetrieveRs.ReservationsList.HotelReservation) RetrieveHotelReservation = OtaResRetrieveRs.ReservationsList.HotelReservation
# phonetechtype enum 1,3,5 voice, fax, mobile # phonetechtype enum 1,3,5 voice, fax, mobile
@@ -119,10 +123,13 @@ class CustomerData:
class GuestCountsFactory: class GuestCountsFactory:
"""Factory class to create GuestCounts instances for both OtaHotelResNotifRq and OtaResRetrieveRs.""" """Factory class to create GuestCounts instances for both OtaHotelResNotifRq and OtaResRetrieveRs."""
@staticmethod @staticmethod
def create_guest_counts( def create_guest_counts(
adults: int, kids: Optional[list[int]] = None adults: int,
, message_type: OtaMessageType = OtaMessageType.RETRIEVE) -> NotifGuestCounts: kids: Optional[list[int]] = None,
message_type: OtaMessageType = OtaMessageType.RETRIEVE,
) -> NotifGuestCounts:
""" """
Create a GuestCounts object for OtaHotelResNotifRq or OtaResRetrieveRs. Create a GuestCounts object for OtaHotelResNotifRq or OtaResRetrieveRs.
:param adults: Number of adults :param adults: Number of adults
@@ -130,14 +137,16 @@ class GuestCountsFactory:
:return: GuestCounts instance :return: GuestCounts instance
""" """
if message_type == OtaMessageType.RETRIEVE: if message_type == OtaMessageType.RETRIEVE:
return GuestCountsFactory._create_guest_counts(adults, kids, RetrieveGuestCounts) return GuestCountsFactory._create_guest_counts(
adults, kids, RetrieveGuestCounts
)
elif message_type == OtaMessageType.NOTIF: elif message_type == OtaMessageType.NOTIF:
return GuestCountsFactory._create_guest_counts(adults, kids, NotifGuestCounts) return GuestCountsFactory._create_guest_counts(
adults, kids, NotifGuestCounts
)
else: else:
raise ValueError(f"Unsupported message type: {message_type}") raise ValueError(f"Unsupported message type: {message_type}")
@staticmethod @staticmethod
def _create_guest_counts( def _create_guest_counts(
adults: int, kids: Optional[list[int]], guest_counts_class: type adults: int, kids: Optional[list[int]], guest_counts_class: type
@@ -577,9 +586,6 @@ class ResGuestFactory:
return CustomerFactory.from_retrieve_customer(customer) return CustomerFactory.from_retrieve_customer(customer)
class AlpineBitsFactory: class AlpineBitsFactory:
"""Unified factory class for creating AlpineBits objects with a simple interface.""" """Unified factory class for creating AlpineBits objects with a simple interface."""
@@ -681,24 +687,24 @@ class AlpineBitsFactory:
else: else:
raise ValueError(f"Unsupported object type: {type(obj)}") raise ValueError(f"Unsupported object type: {type(obj)}")
def create_res_retrieve_response(list: list[Tuple[Reservation, Customer]]):
def create_res_retrieve_response(list: list[Tuple[Reservation, Customer]]):
"""Create RetrievedReservation XML from database entries.""" """Create RetrievedReservation XML from database entries."""
return _create_xml_from_db(list, OtaMessageType.RETRIEVE) return _create_xml_from_db(list, OtaMessageType.RETRIEVE)
def create_res_notif_push_message(list: Tuple[Reservation, Customer]): def create_res_notif_push_message(list: Tuple[Reservation, Customer]):
"""Create Reservation Notification XML from database entries.""" """Create Reservation Notification XML from database entries."""
return _create_xml_from_db(list, OtaMessageType.NOTIF) return _create_xml_from_db(list, OtaMessageType.NOTIF)
def _process_single_reservation(reservation: Reservation, customer: Customer, message_type: OtaMessageType): def _process_single_reservation(
reservation: Reservation, customer: Customer, message_type: OtaMessageType
):
phone_numbers = ( phone_numbers = (
[(customer.phone, PhoneTechType.MOBILE)] [(customer.phone, PhoneTechType.MOBILE)] if customer.phone is not None else []
if customer.phone is not None
else []
) )
customer_data = CustomerData( customer_data = CustomerData(
@@ -719,9 +725,7 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
language=customer.language, language=customer.language,
) )
alpine_bits_factory = AlpineBitsFactory() alpine_bits_factory = AlpineBitsFactory()
res_guests = alpine_bits_factory.create_res_guests( res_guests = alpine_bits_factory.create_res_guests(customer_data, message_type)
customer_data, message_type
)
# Guest counts # Guest counts
children_ages = [int(a) for a in reservation.children_ages.split(",") if a] children_ages = [int(a) for a in reservation.children_ages.split(",") if a]
@@ -731,8 +735,6 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
unique_id_string = reservation.unique_id unique_id_string = reservation.unique_id
if message_type == OtaMessageType.NOTIF: if message_type == OtaMessageType.NOTIF:
UniqueId = NotifUniqueId UniqueId = NotifUniqueId
RoomStays = NotifRoomStays RoomStays = NotifRoomStays
@@ -747,25 +749,17 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
raise ValueError(f"Unsupported message type: {message_type}") raise ValueError(f"Unsupported message type: {message_type}")
# UniqueID # UniqueID
unique_id = UniqueId( unique_id = UniqueId(type_value=UniqueIdType2.VALUE_14, id=unique_id_string)
type_value=UniqueIdType2.VALUE_14, id=unique_id_string
)
# TimeSpan # TimeSpan
time_span = RoomStays.RoomStay.TimeSpan( time_span = RoomStays.RoomStay.TimeSpan(
start=reservation.start_date.isoformat() start=reservation.start_date.isoformat() if reservation.start_date else None,
if reservation.start_date
else None,
end=reservation.end_date.isoformat() if reservation.end_date else None, end=reservation.end_date.isoformat() if reservation.end_date else None,
) )
room_stay = ( room_stay = RoomStays.RoomStay(
RoomStays.RoomStay(
time_span=time_span, time_span=time_span,
guest_counts=guest_counts, guest_counts=guest_counts,
) )
)
room_stays = RoomStays( room_stays = RoomStays(
room_stay=[room_stay], room_stay=[room_stay],
) )
@@ -779,7 +773,6 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
klick_id = reservation.gclid klick_id = reservation.gclid
res_id_source = "google" res_id_source = "google"
# explicitly set klick_id to None otherwise an empty string will be sent # explicitly set klick_id to None otherwise an empty string will be sent
if klick_id in (None, "", "None"): if klick_id in (None, "", "None"):
klick_id = None klick_id = None
@@ -799,7 +792,6 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
else: # extract string from Column object else: # extract string from Column object
klick_id = str(klick_id) klick_id = str(klick_id)
utm_medium = ( utm_medium = (
str(reservation.utm_medium) str(reservation.utm_medium)
if reservation.utm_medium is not None and str(reservation.utm_medium) != "" if reservation.utm_medium is not None and str(reservation.utm_medium) != ""
@@ -820,9 +812,7 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
res_id_source_context="99tales", res_id_source_context="99tales",
) )
hotel_res_id = alpine_bits_factory.create( hotel_res_id = alpine_bits_factory.create(hotel_res_id_data, message_type)
hotel_res_id_data, message_type
)
hotel_res_ids = HotelReservation.ResGlobalInfo.HotelReservationIds( hotel_res_ids = HotelReservation.ResGlobalInfo.HotelReservationIds(
hotel_reservation_id=[hotel_res_id] hotel_reservation_id=[hotel_res_id]
) )
@@ -881,16 +871,17 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
) )
comments_data = CommentsData(comments=comments) comments_data = CommentsData(comments=comments)
comments_xml = alpine_bits_factory.create( comments_xml = alpine_bits_factory.create(comments_data, message_type)
comments_data, message_type
company_name = Profile.CompanyInfo.CompanyName(
value="99tales GmbH", code="who knows?", code_context="who knows?"
) )
company_name = Profile.CompanyInfo.CompanyName(value="99tales GmbH", code="who knows?", code_context="who knows?")
company_info = Profile.CompanyInfo(company_name=company_name) company_info = Profile.CompanyInfo(company_name=company_name)
profile = Profile(company_info=company_info, profile_type=ProfileProfileType.VALUE_4) profile = Profile(
company_info=company_info, profile_type=ProfileProfileType.VALUE_4
)
profile_info = HotelReservation.ResGlobalInfo.Profiles.ProfileInfo(profile=profile) profile_info = HotelReservation.ResGlobalInfo.Profiles.ProfileInfo(profile=profile)
@@ -898,14 +889,12 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
profiles = HotelReservation.ResGlobalInfo.Profiles(profile_info=profile_info) profiles = HotelReservation.ResGlobalInfo.Profiles(profile_info=profile_info)
res_global_info = ( res_global_info = HotelReservation.ResGlobalInfo(
HotelReservation.ResGlobalInfo(
hotel_reservation_ids=hotel_res_ids, hotel_reservation_ids=hotel_res_ids,
basic_property_info=basic_property_info, basic_property_info=basic_property_info,
comments=comments_xml, comments=comments_xml,
profiles=profiles, profiles=profiles,
) )
)
hotel_reservation = HotelReservation( hotel_reservation = HotelReservation(
create_date_time=datetime.now(timezone.utc).isoformat(), create_date_time=datetime.now(timezone.utc).isoformat(),
@@ -920,7 +909,10 @@ def _process_single_reservation(reservation: Reservation, customer: Customer, me
return hotel_reservation return hotel_reservation
def _create_xml_from_db(entries: list[Tuple[Reservation, Customer]] | Tuple[Reservation, Customer], type: OtaMessageType): def _create_xml_from_db(
entries: list[Tuple[Reservation, Customer]] | Tuple[Reservation, Customer],
type: OtaMessageType,
):
"""Create RetrievedReservation XML from database entries. """Create RetrievedReservation XML from database entries.
list of pairs (Reservation, Customer) list of pairs (Reservation, Customer)
@@ -933,14 +925,12 @@ def _create_xml_from_db(entries: list[Tuple[Reservation, Customer]] | Tuple[Rese
if not isinstance(entries, list): if not isinstance(entries, list):
entries = [entries] entries = [entries]
for reservation, customer in entries: for reservation, customer in entries:
_LOGGER.info( _LOGGER.info(
f"Creating XML for reservation {reservation.unique_id} and customer {customer.given_name}" f"Creating XML for reservation {reservation.unique_id} and customer {customer.given_name}"
) )
try: try:
hotel_reservation = _process_single_reservation(reservation, customer, type) hotel_reservation = _process_single_reservation(reservation, customer, type)
reservations_list.append(hotel_reservation) reservations_list.append(hotel_reservation)
@@ -968,7 +958,6 @@ def _create_xml_from_db(entries: list[Tuple[Reservation, Customer]] | Tuple[Rese
return ota_hotel_res_notif_rq return ota_hotel_res_notif_rq
elif type == OtaMessageType.RETRIEVE: elif type == OtaMessageType.RETRIEVE:
retrieved_reservations = OtaResRetrieveRs.ReservationsList( retrieved_reservations = OtaResRetrieveRs.ReservationsList(
hotel_reservation=reservations_list hotel_reservation=reservations_list
) )

View File

@@ -18,10 +18,21 @@ from xml.etree import ElementTree as ET
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, IntEnum from enum import Enum, IntEnum
from alpine_bits_python.alpine_bits_helpers import PhoneTechType, create_res_notif_push_message, create_res_retrieve_response from alpine_bits_python.alpine_bits_helpers import (
PhoneTechType,
create_res_notif_push_message,
create_res_retrieve_response,
)
from .generated.alpinebits import OtaNotifReportRq, OtaNotifReportRs, OtaPingRq, OtaPingRs, WarningStatus, OtaReadRq from .generated.alpinebits import (
OtaNotifReportRq,
OtaNotifReportRs,
OtaPingRq,
OtaPingRs,
WarningStatus,
OtaReadRq,
)
from xsdata_pydantic.bindings import XmlSerializer from xsdata_pydantic.bindings import XmlSerializer
from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.formats.dataclass.serializers.config import SerializerConfig
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -268,7 +279,6 @@ class ServerCapabilities:
self.capability_dict = {"versions": list(versions_dict.values())} self.capability_dict = {"versions": list(versions_dict.values())}
# filter duplicates in actions for each version # filter duplicates in actions for each version
for version in self.capability_dict["versions"]: for version in self.capability_dict["versions"]:
seen_actions = set() seen_actions = set()
@@ -283,7 +293,8 @@ class ServerCapabilities:
for version in self.capability_dict["versions"]: for version in self.capability_dict["versions"]:
if version["version"] == "2024-10": if version["version"] == "2024-10":
version["actions"] = [ version["actions"] = [
action for action in version["actions"] action
for action in version["actions"]
if action.get("action") != "action_OTA_Ping" if action.get("action") != "action_OTA_Ping"
] ]
@@ -298,7 +309,6 @@ class ServerCapabilities:
self.create_capabilities_dict() self.create_capabilities_dict()
return self.capability_dict return self.capability_dict
def get_supported_actions(self) -> List[str]: def get_supported_actions(self) -> List[str]:
"""Get list of all supported action names.""" """Get list of all supported action names."""
return list(self.action_registry.keys()) return list(self.action_registry.keys())
@@ -395,7 +405,6 @@ class PingAction(AlpineBitsAction):
# Create successful ping response with matched capabilities # Create successful ping response with matched capabilities
capabilities_json_str = dump_json_for_xml(matching_capabilities) capabilities_json_str = dump_json_for_xml(matching_capabilities)
warning = OtaPingRs.Warnings.Warning( warning = OtaPingRs.Warnings.Warning(
status=WarningStatus.ALPINEBITS_HANDSHAKE, status=WarningStatus.ALPINEBITS_HANDSHAKE,
type_value="11", type_value="11",
@@ -404,8 +413,6 @@ class PingAction(AlpineBitsAction):
warning_response = OtaPingRs.Warnings(warning=[warning]) warning_response = OtaPingRs.Warnings(warning=[warning])
client_response_echo_data = dump_json_for_xml(echo_data_client) client_response_echo_data = dump_json_for_xml(echo_data_client)
response_ota_ping = OtaPingRs( response_ota_ping = OtaPingRs(
@@ -510,7 +517,9 @@ class ReadAction(AlpineBitsAction):
HttpStatusCode.UNAUTHORIZED, HttpStatusCode.UNAUTHORIZED,
) )
if not validate_hotel_authentication(client_info.username, client_info.password, hotelid, self.config): if not validate_hotel_authentication(
client_info.username, client_info.password, hotelid, self.config
):
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unauthorized Read Request for this specific hotel {hotelname}. Check credentials", f"Error: Unauthorized Read Request for this specific hotel {hotelname}. Check credentials",
HttpStatusCode.UNAUTHORIZED, HttpStatusCode.UNAUTHORIZED,
@@ -525,8 +534,6 @@ class ReadAction(AlpineBitsAction):
# query all reservations for this hotel from the database, where start_date is greater than or equal to the given start_date # query all reservations for this hotel from the database, where start_date is greater than or equal to the given start_date
stmt = ( stmt = (
select(Reservation, Customer) select(Reservation, Customer)
.join(Customer, Reservation.customer_id == Customer.id) .join(Customer, Reservation.customer_id == Customer.id)
@@ -547,8 +554,6 @@ class ReadAction(AlpineBitsAction):
) )
stmt = stmt.filter(~Reservation.id.in_(subquery)) stmt = stmt.filter(~Reservation.id.in_(subquery))
result = await dbsession.execute(stmt) result = await dbsession.execute(stmt)
reservation_customer_pairs: list[tuple[Reservation, Customer]] = ( reservation_customer_pairs: list[tuple[Reservation, Customer]] = (
result.all() result.all()
@@ -601,9 +606,7 @@ class NotifReportReadAction(AlpineBitsAction):
warnings = notif_report.warnings warnings = notif_report.warnings
notif_report_details = notif_report.notif_details notif_report_details = notif_report.notif_details
success_message = OtaNotifReportRs( success_message = OtaNotifReportRs(version="7.000", success="")
version="7.000", success=""
)
if client_info.client_id is None: if client_info.client_id is None:
return AlpineBitsResponse( return AlpineBitsResponse(
@@ -622,12 +625,14 @@ class NotifReportReadAction(AlpineBitsAction):
return AlpineBitsResponse( return AlpineBitsResponse(
response_xml, HttpStatusCode.OK response_xml, HttpStatusCode.OK
) # Nothing to process ) # Nothing to process
elif notif_report_details is not None and notif_report_details.hotel_notif_report is None: elif (
notif_report_details is not None
and notif_report_details.hotel_notif_report is None
):
return AlpineBitsResponse( return AlpineBitsResponse(
response_xml, HttpStatusCode.OK response_xml, HttpStatusCode.OK
) # Nothing to process ) # Nothing to process
else: else:
if dbsession is None: if dbsession is None:
return AlpineBitsResponse( return AlpineBitsResponse(
"Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR "Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR
@@ -635,19 +640,17 @@ class NotifReportReadAction(AlpineBitsAction):
timestamp = datetime.now(ZoneInfo("UTC")) timestamp = datetime.now(ZoneInfo("UTC"))
for entry in notif_report_details.hotel_notif_report.hotel_reservations.hotel_reservation: # type: ignore for entry in notif_report_details.hotel_notif_report.hotel_reservations.hotel_reservation: # type: ignore
unique_id = entry.unique_id.id unique_id = entry.unique_id.id
acked_request = AckedRequest( acked_request = AckedRequest(
unique_id=unique_id, client_id=client_info.client_id, timestamp=timestamp unique_id=unique_id,
client_id=client_info.client_id,
timestamp=timestamp,
) )
dbsession.add(acked_request) dbsession.add(acked_request)
await dbsession.commit() await dbsession.commit()
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
return AlpineBitsResponse(
response_xml, HttpStatusCode.OK
)
class PushAction(AlpineBitsAction): class PushAction(AlpineBitsAction):
@@ -671,7 +674,6 @@ class PushAction(AlpineBitsAction):
xml_push_request = create_res_notif_push_message(request_xml) xml_push_request = create_res_notif_push_message(request_xml)
config = SerializerConfig( config = SerializerConfig(
pretty_print=True, xml_declaration=True, encoding="UTF-8" pretty_print=True, xml_declaration=True, encoding="UTF-8"
) )
@@ -683,8 +685,6 @@ class PushAction(AlpineBitsAction):
return AlpineBitsResponse(xml_push_request, HttpStatusCode.OK) return AlpineBitsResponse(xml_push_request, HttpStatusCode.OK)
class AlpineBitsServer: class AlpineBitsServer:
""" """
Asynchronous AlpineBits server for handling hotel data exchange requests. Asynchronous AlpineBits server for handling hotel data exchange requests.
@@ -740,7 +740,9 @@ class AlpineBitsServer:
# Find the action by request name # Find the action by request name
action_enum = AlpineBitsActionName.get_by_request_name(request_action_name) action_enum = AlpineBitsActionName.get_by_request_name(request_action_name)
_LOGGER.info(f"Handling request for action: {request_action_name} with action enum: {action_enum}") _LOGGER.info(
f"Handling request for action: {request_action_name} with action enum: {action_enum}"
)
if not action_enum: if not action_enum:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unknown action {request_action_name}", f"Error: Unknown action {request_action_name}",
@@ -769,7 +771,6 @@ class AlpineBitsServer:
# Special case for ping action - pass server capabilities # Special case for ping action - pass server capabilities
if action_enum == AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS: if action_enum == AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS:
action_instance: PushAction action_instance: PushAction
if request_xml is None or not isinstance(request_xml, tuple): if request_xml is None or not isinstance(request_xml, tuple):
return AlpineBitsResponse( return AlpineBitsResponse(
@@ -777,16 +778,21 @@ class AlpineBitsServer:
HttpStatusCode.BAD_REQUEST, HttpStatusCode.BAD_REQUEST,
) )
return await action_instance.handle( return await action_instance.handle(
action=request_action_name, request_xml=request_xml, version=version_enum, client_info=client_info action=request_action_name,
request_xml=request_xml,
version=version_enum,
client_info=client_info,
) )
if action_enum == AlpineBitsActionName.OTA_PING: if action_enum == AlpineBitsActionName.OTA_PING:
return await action_instance.handle( return await action_instance.handle(
action=request_action_name, request_xml=request_xml, version=version_enum, server_capabilities=self.capabilities, client_info=client_info action=request_action_name,
request_xml=request_xml,
version=version_enum,
server_capabilities=self.capabilities,
client_info=client_info,
) )
else: else:
return await action_instance.handle( return await action_instance.handle(
action=request_action_name, action=request_action_name,
request_xml=request_xml, request_xml=request_xml,
@@ -848,5 +854,3 @@ class AlpineBitsServer:
return False return False
return True return True

View File

@@ -16,7 +16,12 @@ from .config_loader import load_config
from fastapi.responses import HTMLResponse, PlainTextResponse, Response from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from .models import WixFormSubmission from .models import WixFormSubmission
from datetime import datetime, date, timezone from datetime import datetime, date, timezone
from .auth import generate_unique_id, validate_api_key, validate_wix_signature, generate_api_key from .auth import (
generate_unique_id,
validate_api_key,
validate_wix_signature,
generate_api_key,
)
from .rate_limit import ( from .rate_limit import (
limiter, limiter,
webhook_limiter, webhook_limiter,
@@ -34,7 +39,12 @@ import os
import asyncio import asyncio
import gzip import gzip
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from .alpinebits_server import AlpineBitsClientInfo, AlpineBitsServer, Version, AlpineBitsActionName from .alpinebits_server import (
AlpineBitsClientInfo,
AlpineBitsServer,
Version,
AlpineBitsActionName,
)
import urllib.parse import urllib.parse
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from functools import partial from functools import partial
@@ -57,6 +67,7 @@ security_basic = HTTPBasic()
from collections import defaultdict from collections import defaultdict
# --- Enhanced event dispatcher with hotel-specific routing --- # --- Enhanced event dispatcher with hotel-specific routing ---
class EventDispatcher: class EventDispatcher:
def __init__(self): def __init__(self):
@@ -80,6 +91,7 @@ class EventDispatcher:
for func in self.hotel_listeners[key]: for func in self.hotel_listeners[key]:
await func(*args, **kwargs) await func(*args, **kwargs)
event_dispatcher = EventDispatcher() event_dispatcher = EventDispatcher()
# Load config at startup # Load config at startup
@@ -92,30 +104,41 @@ async def push_listener(customer: DBCustomer, reservation: DBReservation, hotel)
""" """
push_endpoint = hotel.get("push_endpoint") push_endpoint = hotel.get("push_endpoint")
if not push_endpoint: if not push_endpoint:
_LOGGER.warning(f"No push endpoint configured for hotel {hotel.get('hotel_id')}") _LOGGER.warning(
f"No push endpoint configured for hotel {hotel.get('hotel_id')}"
)
return return
server: AlpineBitsServer = app.state.alpine_bits_server server: AlpineBitsServer = app.state.alpine_bits_server
hotel_id = hotel['hotel_id'] hotel_id = hotel["hotel_id"]
reservation_hotel_id = reservation.hotel_code reservation_hotel_id = reservation.hotel_code
# Double-check hotel matching (should be guaranteed by dispatcher) # Double-check hotel matching (should be guaranteed by dispatcher)
if hotel_id != reservation_hotel_id: if hotel_id != reservation_hotel_id:
_LOGGER.warning(f"Hotel ID mismatch: listener for {hotel_id}, reservation for {reservation_hotel_id}") _LOGGER.warning(
f"Hotel ID mismatch: listener for {hotel_id}, reservation for {reservation_hotel_id}"
)
return return
_LOGGER.info(f"Processing push notification for hotel {hotel_id}, reservation {reservation.unique_id}") _LOGGER.info(
f"Processing push notification for hotel {hotel_id}, reservation {reservation.unique_id}"
)
# Prepare payload for push notification # Prepare payload for push notification
request = await server.handle_request(
request = await server.handle_request(request_action_name=AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS.request_name, request_xml=(reservation, customer), client_info=None, version=Version.V2024_10) request_action_name=AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS.request_name,
request_xml=(reservation, customer),
client_info=None,
version=Version.V2024_10,
)
if request.status_code != 200: if request.status_code != 200:
_LOGGER.error(f"Failed to generate push request for hotel {hotel_id}, reservation {reservation.unique_id}: {request.xml_content}") _LOGGER.error(
f"Failed to generate push request for hotel {hotel_id}, reservation {reservation.unique_id}: {request.xml_content}"
)
return return
# save push request to file # save push request to file
logs_dir = "logs/push_requests" logs_dir = "logs/push_requests"
@@ -126,28 +149,37 @@ async def push_listener(customer: DBCustomer, reservation: DBReservation, hotel)
f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}" f"Created directory owner: uid:{stat_info.st_uid}, gid:{stat_info.st_gid}"
) )
_LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}") _LOGGER.info(f"Directory mode: {oct(stat_info.st_mode)[-3:]}")
log_filename = ( log_filename = f"{logs_dir}/alpinebits_push_{hotel_id}_{reservation.unique_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xml"
f"{logs_dir}/alpinebits_push_{hotel_id}_{reservation.unique_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xml"
)
with open(log_filename, "w", encoding="utf-8") as f: with open(log_filename, "w", encoding="utf-8") as f:
f.write(request.xml_content) f.write(request.xml_content)
return return
headers = {"Authorization": f"Bearer {push_endpoint.get('token','')}"} if push_endpoint.get('token') else {} headers = (
{"Authorization": f"Bearer {push_endpoint.get('token', '')}"}
if push_endpoint.get("token")
else {}
)
"" ""
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
resp = await client.post(push_endpoint["url"], json=payload, headers=headers, timeout=10) resp = await client.post(
_LOGGER.info(f"Push event fired to {push_endpoint['url']} for hotel {hotel['hotel_id']}, status: {resp.status_code}") push_endpoint["url"], json=payload, headers=headers, timeout=10
)
_LOGGER.info(
f"Push event fired to {push_endpoint['url']} for hotel {hotel['hotel_id']}, status: {resp.status_code}"
)
if resp.status_code not in [200, 201, 202]: if resp.status_code not in [200, 201, 202]:
_LOGGER.warning(f"Push endpoint returned non-success status {resp.status_code}: {resp.text}") _LOGGER.warning(
f"Push endpoint returned non-success status {resp.status_code}: {resp.text}"
)
except Exception as e: except Exception as e:
_LOGGER.error(f"Push event failed for hotel {hotel['hotel_id']}: {e}") _LOGGER.error(f"Push event failed for hotel {hotel['hotel_id']}: {e}")
# Optionally implement retry logic here@asynccontextmanager # Optionally implement retry logic here@asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Setup DB # Setup DB
@@ -167,7 +199,6 @@ async def lifespan(app: FastAPI):
app.state.alpine_bits_server = AlpineBitsServer(config) app.state.alpine_bits_server = AlpineBitsServer(config)
app.state.event_dispatcher = event_dispatcher app.state.event_dispatcher = event_dispatcher
# Register push listeners for hotels with push_endpoint # Register push listeners for hotels with push_endpoint
for hotel in config.get("alpine_bits_auth", []): for hotel in config.get("alpine_bits_auth", []):
push_endpoint = hotel.get("push_endpoint") push_endpoint = hotel.get("push_endpoint")
@@ -176,11 +207,11 @@ async def lifespan(app: FastAPI):
if push_endpoint and hotel_id: if push_endpoint and hotel_id:
# Register hotel-specific listener # Register hotel-specific listener
event_dispatcher.register_hotel_listener( event_dispatcher.register_hotel_listener(
"form_processed", "form_processed", hotel_id, partial(push_listener, hotel=hotel)
hotel_id, )
partial(push_listener, hotel=hotel) _LOGGER.info(
f"Registered push listener for hotel {hotel_id} with endpoint {push_endpoint.get('url')}"
) )
_LOGGER.info(f"Registered push listener for hotel {hotel_id} with endpoint {push_endpoint.get('url')}")
elif push_endpoint and not hotel_id: elif push_endpoint and not hotel_id:
_LOGGER.warning(f"Hotel has push_endpoint but no hotel_id: {hotel}") _LOGGER.warning(f"Hotel has push_endpoint but no hotel_id: {hotel}")
elif hotel_id and not push_endpoint: elif hotel_id and not push_endpoint:
@@ -351,7 +382,7 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
name_prefix = data.get("field:anrede") name_prefix = data.get("field:anrede")
email_newsletter_string = data.get("field:form_field_5a7b", "") email_newsletter_string = data.get("field:form_field_5a7b", "")
yes_values = {"Selezionato", "Angekreuzt", "Checked"} yes_values = {"Selezionato", "Angekreuzt", "Checked"}
email_newsletter = (email_newsletter_string in yes_values) email_newsletter = email_newsletter_string in yes_values
address_line = None address_line = None
city_name = None city_name = None
postal_code = None postal_code = None
@@ -404,8 +435,6 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
# strip to first 35 chars # strip to first 35 chars
unique_id = unique_id[:32] unique_id = unique_id[:32]
# use database session # use database session
# Save all relevant data to DB (including new fields) # Save all relevant data to DB (including new fields)
@@ -431,21 +460,20 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
await db.flush() # This assigns db_customer.id without committing await db.flush() # This assigns db_customer.id without committing
# await db.refresh(db_customer) # await db.refresh(db_customer)
# Determine hotel_code and hotel_name # Determine hotel_code and hotel_name
# Priority: 1) Form field, 2) Configuration default, 3) Hardcoded fallback # Priority: 1) Form field, 2) Configuration default, 3) Hardcoded fallback
hotel_code = ( hotel_code = (
data.get("field:hotelid") or data.get("field:hotelid")
data.get("hotelid") or or data.get("hotelid")
request.app.state.config.get("default_hotel_code") or or request.app.state.config.get("default_hotel_code")
"123" # fallback or "123" # fallback
) )
hotel_name = ( hotel_name = (
data.get("field:hotelname") or data.get("field:hotelname")
data.get("hotelname") or or data.get("hotelname")
request.app.state.config.get("default_hotel_name") or or request.app.state.config.get("default_hotel_name")
"Frangart Inn" # fallback or "Frangart Inn" # fallback
) )
db_reservation = DBReservation( db_reservation = DBReservation(
@@ -473,22 +501,24 @@ async def process_wix_form_submission(request: Request, data: Dict[str, Any], db
await db.commit() await db.commit()
await db.refresh(db_reservation) await db.refresh(db_reservation)
async def push_event(): async def push_event():
# Fire event for listeners (push, etc.) - hotel-specific dispatch # Fire event for listeners (push, etc.) - hotel-specific dispatch
dispatcher = getattr(request.app.state, "event_dispatcher", None) dispatcher = getattr(request.app.state, "event_dispatcher", None)
if dispatcher: if dispatcher:
# Get hotel_code from reservation to target the right listeners # Get hotel_code from reservation to target the right listeners
hotel_code = getattr(db_reservation, 'hotel_code', None) hotel_code = getattr(db_reservation, "hotel_code", None)
if hotel_code and hotel_code.strip(): if hotel_code and hotel_code.strip():
await dispatcher.dispatch_for_hotel("form_processed", hotel_code, db_customer, db_reservation) await dispatcher.dispatch_for_hotel(
"form_processed", hotel_code, db_customer, db_reservation
)
_LOGGER.info(f"Dispatched form_processed event for hotel {hotel_code}") _LOGGER.info(f"Dispatched form_processed event for hotel {hotel_code}")
else: else:
_LOGGER.warning("No hotel_code in reservation, skipping push notifications") _LOGGER.warning(
"No hotel_code in reservation, skipping push notifications"
)
asyncio.create_task(push_event()) asyncio.create_task(push_event())
return { return {
"status": "success", "status": "success",
"message": "Wix form data received successfully", "message": "Wix form data received successfully",
@@ -517,9 +547,7 @@ async def handle_wix_form(
traceback_str = traceback.format_exc() traceback_str = traceback.format_exc()
_LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}") _LOGGER.error(f"Stack trace for handle_wix_form: {traceback_str}")
raise HTTPException( raise HTTPException(status_code=500, detail=f"Error processing Wix form data")
status_code=500, detail=f"Error processing Wix form data"
)
@api_router.post("/webhook/wix-form/test") @api_router.post("/webhook/wix-form/test")
@@ -535,9 +563,7 @@ async def handle_wix_form_test(
return await process_wix_form_submission(request, data, db_session) return await process_wix_form_submission(request, data, db_session)
except Exception as e: except Exception as e:
_LOGGER.error(f"Error in handle_wix_form_test: {str(e)}") _LOGGER.error(f"Error in handle_wix_form_test: {str(e)}")
raise HTTPException( raise HTTPException(status_code=500, detail=f"Error processing test data")
status_code=500, detail=f"Error processing test data"
)
@api_router.post("/admin/generate-api-key") @api_router.post("/admin/generate-api-key")
@@ -773,7 +799,9 @@ async def alpinebits_server_handshake(
username, password = credentials_tupel username, password = credentials_tupel
client_info = AlpineBitsClientInfo(username=username, password=password, client_id=client_id) client_info = AlpineBitsClientInfo(
username=username, password=password, client_id=client_id
)
# Create successful handshake response # Create successful handshake response
response = await server.handle_request( response = await server.handle_request(

View File

@@ -30,6 +30,7 @@ if os.getenv("WIX_API_KEY"):
if os.getenv("ADMIN_API_KEY"): if os.getenv("ADMIN_API_KEY"):
API_KEYS["admin-key"] = os.getenv("ADMIN_API_KEY") API_KEYS["admin-key"] = os.getenv("ADMIN_API_KEY")
def generate_unique_id() -> str: def generate_unique_id() -> str:
"""Generate a unique ID with max length 35 characters""" """Generate a unique ID with max length 35 characters"""
return secrets.token_urlsafe(26)[:35] # 26 bytes -> 35 chars in base64url return secrets.token_urlsafe(26)[:35] # 26 bytes -> 35 chars in base64url

View File

@@ -67,11 +67,12 @@ class Reservation(Base):
customer = relationship("Customer", back_populates="reservations") customer = relationship("Customer", back_populates="reservations")
# Table for tracking acknowledged requests by client # Table for tracking acknowledged requests by client
class AckedRequest(Base): class AckedRequest(Base):
__tablename__ = 'acked_requests' __tablename__ = "acked_requests"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
client_id = Column(String, index=True) client_id = Column(String, index=True)
unique_id = Column(String, index=True) # Should match Reservation.form_id or another unique field unique_id = Column(
String, index=True
) # Should match Reservation.form_id or another unique field
timestamp = Column(DateTime) timestamp = Column(DateTime)

View File

@@ -4,7 +4,6 @@ import sys
import os import os
from alpine_bits_python.alpine_bits_helpers import ( from alpine_bits_python.alpine_bits_helpers import (
CustomerData, CustomerData,
CustomerFactory, CustomerFactory,

View File

@@ -1,7 +1,3 @@
import pytest import pytest
from alpine_bits_python.alpinebits_server import AlpineBitsServer, AlpineBitsClientInfo from alpine_bits_python.alpinebits_server import AlpineBitsServer, AlpineBitsClientInfo
from xsdata_pydantic.bindings import XmlParser from xsdata_pydantic.bindings import XmlParser
@@ -9,14 +5,3 @@ from alpine_bits_python.generated.alpinebits import OtaResRetrieveRs, OtaHotelRe
pass pass

View File

@@ -1,4 +1,3 @@
import json import json
import pytest import pytest
import asyncio import asyncio
@@ -8,8 +7,6 @@ from xsdata_pydantic.bindings import XmlParser
from alpine_bits_python.generated.alpinebits import OtaPingRs from alpine_bits_python.generated.alpinebits import OtaPingRs
def extract_relevant_sections(xml_string): def extract_relevant_sections(xml_string):
# Remove version attribute value, keep only presence # Remove version attribute value, keep only presence
# Use the same XmlParser as AlpineBitsServer # Use the same XmlParser as AlpineBitsServer
@@ -17,21 +14,25 @@ def extract_relevant_sections(xml_string):
obj = parser.from_string(xml_string, OtaPingRs) obj = parser.from_string(xml_string, OtaPingRs)
return obj return obj
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_action_response_matches_expected(): async def test_ping_action_response_matches_expected():
with open("test/test_data/Handshake-OTA_PingRQ.xml", "r", encoding="utf-8") as f: with open("test/test_data/Handshake-OTA_PingRQ.xml", "r", encoding="utf-8") as f:
server = AlpineBitsServer() server = AlpineBitsServer()
with open("test/test_data/Handshake-OTA_PingRQ.xml", "r", encoding="utf-8") as f: with open(
"test/test_data/Handshake-OTA_PingRQ.xml", "r", encoding="utf-8"
) as f:
request_xml = f.read() request_xml = f.read()
with open("test/test_data/Handshake-OTA_PingRS.xml", "r", encoding="utf-8") as f: with open(
"test/test_data/Handshake-OTA_PingRS.xml", "r", encoding="utf-8"
) as f:
expected_xml = f.read() expected_xml = f.read()
client_info = AlpineBitsClientInfo(username="irrelevant", password="irrelevant") client_info = AlpineBitsClientInfo(username="irrelevant", password="irrelevant")
response = await server.handle_request( response = await server.handle_request(
request_action_name="OTA_Ping:Handshaking", request_action_name="OTA_Ping:Handshaking",
request_xml=request_xml, request_xml=request_xml,
client_info=client_info, client_info=client_info,
version="2024-10" version="2024-10",
) )
actual_obj = extract_relevant_sections(response.xml_content) actual_obj = extract_relevant_sections(response.xml_content)
expected_obj = extract_relevant_sections(expected_xml) expected_obj = extract_relevant_sections(expected_xml)
@@ -40,12 +41,17 @@ async def test_ping_action_response_matches_expected():
expected_matches = json.loads(expected_obj.warnings.warning[0].content[0]) expected_matches = json.loads(expected_obj.warnings.warning[0].content[0])
assert actual_matches == expected_matches, f"Expected warnings {expected_matches}, got {actual_matches}" assert actual_matches == expected_matches, (
f"Expected warnings {expected_matches}, got {actual_matches}"
)
actual_capabilities = json.loads(actual_obj.echo_data) actual_capabilities = json.loads(actual_obj.echo_data)
expected_capabilities = json.loads(expected_obj.echo_data) expected_capabilities = json.loads(expected_obj.echo_data)
assert actual_capabilities == expected_capabilities, f"Expected echo data {expected_capabilities}, got {actual_capabilities}" assert actual_capabilities == expected_capabilities, (
f"Expected echo data {expected_capabilities}, got {actual_capabilities}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_action_response_success(): async def test_ping_action_response_success():
@@ -57,13 +63,14 @@ async def test_ping_action_response_success():
request_action_name="OTA_Ping:Handshaking", request_action_name="OTA_Ping:Handshaking",
request_xml=request_xml, request_xml=request_xml,
client_info=client_info, client_info=client_info,
version="2024-10" version="2024-10",
) )
assert response.status_code == 200 assert response.status_code == 200
assert "<OTA_PingRS" in response.xml_content assert "<OTA_PingRS" in response.xml_content
assert "<Success" in response.xml_content assert "<Success" in response.xml_content
assert "Version=" in response.xml_content assert "Version=" in response.xml_content
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_action_response_version_arbitrary(): async def test_ping_action_response_version_arbitrary():
server = AlpineBitsServer() server = AlpineBitsServer()
@@ -74,12 +81,13 @@ async def test_ping_action_response_version_arbitrary():
request_action_name="OTA_Ping:Handshaking", request_action_name="OTA_Ping:Handshaking",
request_xml=request_xml, request_xml=request_xml,
client_info=client_info, client_info=client_info,
version="2022-10" version="2022-10",
) )
assert response.status_code == 200 assert response.status_code == 200
assert "<OTA_PingRS" in response.xml_content assert "<OTA_PingRS" in response.xml_content
assert "Version=" in response.xml_content assert "Version=" in response.xml_content
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_action_response_invalid_action(): async def test_ping_action_response_invalid_action():
server = AlpineBitsServer() server = AlpineBitsServer()
@@ -90,7 +98,7 @@ async def test_ping_action_response_invalid_action():
request_action_name="InvalidAction", request_action_name="InvalidAction",
request_xml=request_xml, request_xml=request_xml,
client_info=client_info, client_info=client_info,
version="2024-10" version="2024-10",
) )
assert response.status_code == 400 assert response.status_code == 400
assert "Error" in response.xml_content assert "Error" in response.xml_content