Starting to implement action_OTA_HotelResNotif_GuestRequests. Necessary to fully comply with spec

This commit is contained in:
Jonas Linter
2025-10-01 09:31:11 +02:00
parent 228aed6d58
commit 13df12afc6
7 changed files with 167 additions and 129 deletions

View File

@@ -9,7 +9,13 @@ from typing import Tuple
from alpine_bits_python.db import Customer, Reservation from alpine_bits_python.db import Customer, Reservation
# Import the generated classes # Import the generated classes
from .generated.alpinebits import HotelReservationResStatus, OtaHotelResNotifRq, OtaResRetrieveRs, CommentName2, UniqueIdType2 from .generated.alpinebits import (
HotelReservationResStatus,
OtaHotelResNotifRq,
OtaResRetrieveRs,
CommentName2,
UniqueIdType2,
)
import logging import logging
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -431,7 +437,9 @@ class CommentFactory:
@staticmethod @staticmethod
def _create_comments( def _create_comments(
comments_class: type[RetrieveComments] | type[NotifComments], comment_class: type[RetrieveComment] | type[NotifComment], data: CommentsData comments_class: type[RetrieveComments] | type[NotifComments],
comment_class: type[RetrieveComment] | type[NotifComment],
data: CommentsData,
) -> Any: ) -> Any:
"""Internal method to create comments of the specified type.""" """Internal method to create comments of the specified type."""
@@ -440,7 +448,9 @@ class CommentFactory:
# Create list items # Create list items
list_items = [] list_items = []
for item_data in comment_data.list_items: for item_data in comment_data.list_items:
_LOGGER.info(f"Creating list item: value={item_data.value}, list_item={item_data.list_item}, language={item_data.language}") _LOGGER.info(
f"Creating list item: value={item_data.value}, list_item={item_data.list_item}, language={item_data.language}"
)
list_item = comment_class.ListItem( list_item = comment_class.ListItem(
value=item_data.value, value=item_data.value,
@@ -662,7 +672,7 @@ class AlpineBitsFactory:
def create_xml_from_db(list: list[Tuple[Reservation, Customer]]): def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
""" Create RetrievedReservation XML from database entries. """Create RetrievedReservation XML from database entries.
list of pairs (Reservation, Customer) list of pairs (Reservation, Customer)
""" """
@@ -670,11 +680,16 @@ def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
reservations_list = [] reservations_list = []
for reservation, customer in list: for reservation, customer in list:
_LOGGER.info(f"Creating XML for reservation {reservation.form_id} and customer {customer.given_name}") _LOGGER.info(
f"Creating XML for reservation {reservation.form_id} and customer {customer.given_name}"
)
try: try:
phone_numbers = (
phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone is not None else [] [(customer.phone, PhoneTechType.MOBILE)]
if customer.phone is not None
else []
)
customer_data = CustomerData( customer_data = CustomerData(
given_name=customer.given_name, given_name=customer.given_name,
surname=customer.surname, surname=customer.surname,
@@ -703,10 +718,8 @@ def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
reservation.num_adults, children_ages reservation.num_adults, children_ages
) )
unique_id_string = reservation.form_id unique_id_string = reservation.form_id
if len(unique_id_string) > 32: if len(unique_id_string) > 32:
unique_id_string = unique_id_string[:32] # Truncate to 32 characters unique_id_string = unique_id_string[:32] # Truncate to 32 characters
@@ -717,7 +730,9 @@ def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
# TimeSpan # TimeSpan
time_span = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan( time_span = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.TimeSpan(
start=reservation.start_date.isoformat() if reservation.start_date else None, start=reservation.start_date.isoformat()
if reservation.start_date
else None,
end=reservation.end_date.isoformat() if reservation.end_date else None, end=reservation.end_date.isoformat() if reservation.end_date else None,
) )
room_stay = ( room_stay = (
@@ -781,12 +796,15 @@ def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
comments_xml = None comments_xml = None
if comments: if comments:
for c in comments: for c in comments:
_LOGGER.info(f"Creating comment: name={c.name}, text={c.text}, list_items={len(c.list_items)}") _LOGGER.info(
f"Creating comment: name={c.name}, text={c.text}, list_items={len(c.list_items)}"
)
comments_data = CommentsData(comments=comments) comments_data = CommentsData(comments=comments)
comments_xml = alpine_bits_factory.create(comments_data, OtaMessageType.RETRIEVE) comments_xml = alpine_bits_factory.create(
comments_data, OtaMessageType.RETRIEVE
)
res_global_info = ( res_global_info = (
OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo( OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo(
@@ -796,8 +814,6 @@ def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
) )
) )
hotel_reservation = OtaResRetrieveRs.ReservationsList.HotelReservation( hotel_reservation = OtaResRetrieveRs.ReservationsList.HotelReservation(
create_date_time=datetime.now(timezone.utc).isoformat(), create_date_time=datetime.now(timezone.utc).isoformat(),
res_status=HotelReservationResStatus.REQUESTED, res_status=HotelReservationResStatus.REQUESTED,
@@ -811,7 +827,9 @@ def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
reservations_list.append(hotel_reservation) reservations_list.append(hotel_reservation)
except Exception as e: except Exception as e:
_LOGGER.error(f"Error creating XML for reservation {reservation.form_id} and customer {customer.given_name}: {e}") _LOGGER.error(
f"Error creating XML for reservation {reservation.form_id} and customer {customer.given_name}: {e}"
)
retrieved_reservations = OtaResRetrieveRs.ReservationsList( retrieved_reservations = OtaResRetrieveRs.ReservationsList(
hotel_reservation=reservations_list hotel_reservation=reservations_list
@@ -830,7 +848,6 @@ def create_xml_from_db(list: list[Tuple[Reservation, Customer]]):
return ota_res_retrieve_rs return ota_res_retrieve_rs
# Usage examples # Usage examples
if __name__ == "__main__": if __name__ == "__main__":
# Create customer data using simple data class # Create customer data using simple data class

View File

@@ -12,7 +12,7 @@ import difflib
import json import json
import inspect import inspect
import re import re
from typing import Dict, List, Optional, Any, Union, Tuple, Type from typing import Dict, List, Optional, Any, Union, Tuple, Type, override
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, IntEnum from enum import Enum, IntEnum
@@ -20,7 +20,6 @@ from enum import Enum, IntEnum
from alpine_bits_python.alpine_bits_helpers import PhoneTechType, create_xml_from_db from alpine_bits_python.alpine_bits_helpers import PhoneTechType, create_xml_from_db
from .generated.alpinebits import OtaPingRq, OtaPingRs, WarningStatus, OtaReadRq from .generated.alpinebits import OtaPingRq, OtaPingRs, WarningStatus, OtaReadRq
from xsdata_pydantic.bindings import XmlSerializer from xsdata_pydantic.bindings import XmlSerializer
from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.formats.dataclass.serializers.config import SerializerConfig
@@ -37,7 +36,6 @@ logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class HttpStatusCode(IntEnum): class HttpStatusCode(IntEnum):
"""Allowed HTTP status codes for AlpineBits responses.""" """Allowed HTTP status codes for AlpineBits responses."""
@@ -114,6 +112,15 @@ class Version(str, Enum):
# Add other versions as needed # Add other versions as needed
class AlpineBitsClientInfo:
"""Wrapper for username, password, client_id"""
def __init__(self, username: str, password: str, client_id: str | None = None):
self.username = username
self.password = password
self.client_id = client_id
@dataclass @dataclass
class AlpineBitsResponse: class AlpineBitsResponse:
"""Response data structure for AlpineBits actions.""" """Response data structure for AlpineBits actions."""
@@ -139,7 +146,13 @@ class AlpineBitsAction(ABC):
) # list of versions in case action supports multiple versions ) # list of versions in case action supports multiple versions
async def handle( async def handle(
self, action: str, request_xml: str, version: Version, dbsession=None, server_capabilities=None, username=None, password=None, config: Dict = None self,
action: str,
request_xml: str,
version: Version,
client_info: AlpineBitsClientInfo,
dbsession=None,
server_capabilities=None,
) -> AlpineBitsResponse: ) -> AlpineBitsResponse:
""" """
Handle the incoming request XML and return response XML. Handle the incoming request XML and return response XML.
@@ -268,7 +281,7 @@ class ServerCapabilities:
class PingAction(AlpineBitsAction): class PingAction(AlpineBitsAction):
"""Implementation for OTA_Ping action (handshaking).""" """Implementation for OTA_Ping action (handshaking)."""
def __init__(self, config: Dict = None): def __init__(self, config: Dict = {}):
self.name = AlpineBitsActionName.OTA_PING self.name = AlpineBitsActionName.OTA_PING
self.version = [ self.version = [
Version.V2024_10, Version.V2024_10,
@@ -276,11 +289,13 @@ class PingAction(AlpineBitsAction):
] # Supports multiple versions ] # Supports multiple versions
self.config = config self.config = config
@override
async def handle( async def handle(
self, self,
action: str, action: str,
request_xml: str, request_xml: str,
version: Version, version: Version,
client_info: AlpineBitsClientInfo,
server_capabilities: None | ServerCapabilities = None, server_capabilities: None | ServerCapabilities = None,
) -> AlpineBitsResponse: ) -> AlpineBitsResponse:
"""Handle ping requests.""" """Handle ping requests."""
@@ -352,7 +367,7 @@ class PingAction(AlpineBitsAction):
capabilities_json = json.dumps(matching_capabilities, indent=2) capabilities_json = json.dumps(matching_capabilities, indent=2)
warning = OtaPingRs.Warnings.Warning( warning = OtaPingRs.Warnings.Warning(
status=WarningStatus.ALPINEBITS_HANDSHAKE.value, status=WarningStatus.ALPINEBITS_HANDSHAKE,
type_value="11", type_value="11",
content=[capabilities_json], content=[capabilities_json],
) )
@@ -379,19 +394,24 @@ class PingAction(AlpineBitsAction):
) )
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
def strip_control_chars(s): def strip_control_chars(s):
# Remove all control characters (ASCII < 32 and DEL) # Remove all control characters (ASCII < 32 and DEL)
return re.sub(r'[\x00-\x1F\x7F]', '', s) return re.sub(r"[\x00-\x1F\x7F]", "", s)
def validate_hotel_authentication(username: str, password: str, hotelid: str, config: Dict) -> bool:
""" Validate hotel authentication based on username, password, and hotel ID.
Example config def validate_hotel_authentication(
alpine_bits_auth: username: str, password: str, hotelid: str, config: Dict
- hotel_id: "123" ) -> bool:
hotel_name: "Frangart Inn" """Validate hotel authentication based on username, password, and hotel ID.
username: "alice"
password: !secret ALICE_PASSWORD Example config
alpine_bits_auth:
- hotel_id: "123"
hotel_name: "Frangart Inn"
username: "alice"
password: !secret ALICE_PASSWORD
""" """
if not config or "alpine_bits_auth" not in config: if not config or "alpine_bits_auth" not in config:
@@ -409,20 +429,22 @@ def validate_hotel_authentication(username: str, password: str, hotelid: str, co
# look for hotelid in config # look for hotelid in config
class ReadAction(AlpineBitsAction): class ReadAction(AlpineBitsAction):
"""Implementation for OTA_Read action.""" """Implementation for OTA_Read action."""
def __init__(self, config: Dict = None): def __init__(self, config: Dict = {}):
self.name = AlpineBitsActionName.OTA_READ self.name = AlpineBitsActionName.OTA_READ
self.version = [Version.V2024_10, Version.V2022_10] self.version = [Version.V2024_10, Version.V2022_10]
self.config = config self.config = config
async def handle( async def handle(
self, action: str, request_xml: str, version: Version, dbsession=None, username=None, password=None self,
action: str,
request_xml: str,
version: Version,
client_info: AlpineBitsClientInfo,
dbsession=None,
server_capabilities=None,
) -> AlpineBitsResponse: ) -> AlpineBitsResponse:
"""Handle read requests.""" """Handle read requests."""
@@ -430,9 +452,9 @@ class ReadAction(AlpineBitsAction):
clean_expected = strip_control_chars(self.name.value[1]).strip() clean_expected = strip_control_chars(self.name.value[1]).strip()
if clean_action != clean_expected: if clean_action != clean_expected:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Invalid action {action}, expected {self.name.value[1]}", HttpStatusCode.BAD_REQUEST f"Error: Invalid action {action}, expected {self.name.value[1]}",
HttpStatusCode.BAD_REQUEST,
) )
if dbsession is None: if dbsession is None:
@@ -450,22 +472,24 @@ class ReadAction(AlpineBitsAction):
if hotelname is None: if hotelname is None:
hotelname = "unknown" hotelname = "unknown"
if username is None or password is None or hotelid is None: if hotelid is None:
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unauthorized Read Request for this specific hotel {hotelname}. Check credentials", HttpStatusCode.UNAUTHORIZED f"Error: Unauthorized Read Request. No target hotel specified. Check credentials",
HttpStatusCode.UNAUTHORIZED,
) )
if not validate_hotel_authentication(username, password, hotelid, self.config): if not validate_hotel_authentication(client_info.username, client_info.password, hotelid, self.config):
return AlpineBitsResponse( return AlpineBitsResponse(
f"Error: Unauthorized Read Request for this specific hotel {hotelname}. Check credentials", HttpStatusCode.UNAUTHORIZED f"Error: Unauthorized Read Request for this specific hotel {hotelname}. Check credentials",
HttpStatusCode.UNAUTHORIZED,
) )
start_date = None start_date = None
if hotel_read_request.selection_criteria is not None: if hotel_read_request.selection_criteria is not None:
start_date = datetime.fromisoformat(hotel_read_request.selection_criteria.start) start_date = datetime.fromisoformat(
hotel_read_request.selection_criteria.start
)
# query all reservations for this hotel from the database, where start_date is greater than or equal to the given start_date # query all reservations for this hotel from the database, where start_date is greater than or equal to the given start_date
@@ -478,11 +502,17 @@ class ReadAction(AlpineBitsAction):
stmt = stmt.filter(Reservation.start_date >= start_date) stmt = stmt.filter(Reservation.start_date >= start_date)
result = await dbsession.execute(stmt) result = await dbsession.execute(stmt)
reservation_customer_pairs: list[tuple[Reservation, Customer]] = result.all() # List of (Reservation, Customer) tuples reservation_customer_pairs: list[tuple[Reservation, Customer]] = (
result.all()
) # List of (Reservation, Customer) tuples
_LOGGER.info(f"Querying reservations and customers for hotel {hotelid} from database") _LOGGER.info(
f"Querying reservations and customers for hotel {hotelid} from database"
)
for reservation, customer in reservation_customer_pairs: for reservation, customer in reservation_customer_pairs:
_LOGGER.info(f"Reservation: {reservation.id}, Customer: {customer.given_name}") _LOGGER.info(
f"Reservation: {reservation.id}, Customer: {customer.given_name}"
)
res_retrive_rs = create_xml_from_db(reservation_customer_pairs) res_retrive_rs = create_xml_from_db(reservation_customer_pairs)
@@ -497,58 +527,28 @@ class ReadAction(AlpineBitsAction):
return AlpineBitsResponse(response_xml, HttpStatusCode.OK) return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class NotifReportReadAction(AlpineBitsAction):
"""Necessary for read action to follow specification. Clients need to report acknowledgements"""
def __init__(self, config: Dict = {}):
self.name = AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS
self.version = [Version.V2024_10, Version.V2022_10]
self.config = config
async def handle(
self,
action: str,
request_xml: str,
version: Version,
dbsession=None,
username=None,
password=None,
) -> AlpineBitsResponse:
"""Handle read requests."""
return AlpineBitsResponse(
f"Error: Action {action} not implemented", HttpStatusCode.BAD_REQUEST
)
# For demonstration, just echo back a simple XML response
response_xml = """<?xml version="1.0" encoding="UTF-8"?>
<OTA_ReadRS xmlns="http://www.opentravel.org/OTA/2003/
05" Version="8.000">
<Success/>
</OTA_ReadRS>"""
return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
# class HotelAvailNotifAction(AlpineBitsAction):
# """Implementation for Hotel Availability Notification action with supports."""
# def __init__(self):
# self.name = AlpineBitsActionName.OTA_HOTEL_AVAIL_NOTIF
# self.version = Version.V2022_10
# self.supports = [
# "OTA_HotelAvailNotif_accept_rooms",
# "OTA_HotelAvailNotif_accept_categories",
# "OTA_HotelAvailNotif_accept_deltas",
# "OTA_HotelAvailNotif_accept_BookingThreshold",
# ]
# async def handle(
# self, action: str, request_xml: str, version: Version
# ) -> AlpineBitsResponse:
# """Handle hotel availability notifications."""
# response_xml = """<?xml version="1.0" encoding="UTF-8"?>
# <OTA_HotelAvailNotifRS xmlns="http://www.opentravel.org/OTA/2003/05" Version="8.000">
# <Success/>
# </OTA_HotelAvailNotifRS>"""
# return AlpineBitsResponse(response_xml, HttpStatusCode.OK)
class GuestRequestsAction(AlpineBitsAction): class GuestRequestsAction(AlpineBitsAction):
@@ -576,7 +576,6 @@ class AlpineBitsServer:
self.config = config self.config = config
self._initialize_action_instances() self._initialize_action_instances()
def _initialize_action_instances(self): def _initialize_action_instances(self):
"""Initialize instances of all discovered action classes.""" """Initialize instances of all discovered action classes."""
for capability_name, action_class in self.capabilities.action_registry.items(): for capability_name, action_class in self.capabilities.action_registry.items():
@@ -591,7 +590,12 @@ class AlpineBitsServer:
return self.capabilities.get_capabilities_json() return self.capabilities.get_capabilities_json()
async def handle_request( async def handle_request(
self, request_action_name: str, request_xml: str, version: str = "2024-10", dbsession=None, username=None, password=None self,
request_action_name: str,
request_xml: str,
client_info: AlpineBitsClientInfo,
version: str = "2024-10",
dbsession=None,
) -> AlpineBitsResponse: ) -> AlpineBitsResponse:
""" """
Handle an incoming AlpineBits request by routing to appropriate action handler. Handle an incoming AlpineBits request by routing to appropriate action handler.
@@ -642,11 +646,15 @@ class AlpineBitsServer:
# Special case for ping action - pass server capabilities # Special case for ping action - pass server capabilities
if capability_name == "action_OTA_Ping": if capability_name == "action_OTA_Ping":
return await action_instance.handle( return await action_instance.handle(
request_action_name, request_xml, version_enum, self.capabilities action=request_action_name, request_xml=request_xml, version=version_enum, server_capabilities=self.capabilities, client_info=client_info
) )
else: else:
return await action_instance.handle( return await action_instance.handle(
request_action_name, request_xml, version_enum, dbsession=dbsession, username=username, password=password action=request_action_name,
request_xml=request_xml,
version=version_enum,
dbsession=dbsession,
client_info=client_info,
) )
except Exception as e: except Exception as e:
print(f"Error handling request {request_action_name}: {str(e)}") print(f"Error handling request {request_action_name}: {str(e)}")
@@ -669,7 +677,7 @@ class AlpineBitsServer:
return sorted(request_names) return sorted(request_names)
def is_action_supported( def is_action_supported(
self, request_action_name: str, version: str = None self, request_action_name: str, version: str | None = None
) -> bool: ) -> bool:
""" """
Check if a request action is supported. Check if a request action is supported.

View File

@@ -33,7 +33,7 @@ import json
import os import os
import gzip import gzip
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from .alpinebits_server import AlpineBitsServer, Version from .alpinebits_server import AlpineBitsClientInfo, AlpineBitsServer, Version
import urllib.parse import urllib.parse
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
@@ -83,6 +83,7 @@ async def lifespan(app: FastAPI):
# Optional: Dispose engine on shutdown # Optional: Dispose engine on shutdown
await engine.dispose() await engine.dispose()
async def get_async_session(request: Request): async def get_async_session(request: Request):
async_sessionmaker = request.app.state.async_sessionmaker async_sessionmaker = request.app.state.async_sessionmaker
async with async_sessionmaker() as session: async with async_sessionmaker() as session:
@@ -93,7 +94,7 @@ app = FastAPI(
title="Wix Form Handler API", title="Wix Form Handler API",
description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting", description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting",
version="1.0.0", version="1.0.0",
lifespan=lifespan lifespan=lifespan,
) )
# Create API router with /api prefix # Create API router with /api prefix
@@ -155,8 +156,6 @@ async def process_form_submission(submission_data: Dict[str, Any]) -> None:
_LOGGER.error(f"Error processing form submission: {str(e)}") _LOGGER.error(f"Error processing form submission: {str(e)}")
@api_router.get("/") @api_router.get("/")
@limiter.limit(DEFAULT_RATE_LIMIT) @limiter.limit(DEFAULT_RATE_LIMIT)
async def root(request: Request): async def root(request: Request):
@@ -512,7 +511,9 @@ def parse_multipart_data(content_type: str, body: bytes) -> Dict[str, Any]:
@api_router.post("/alpinebits/server-2024-10") @api_router.post("/alpinebits/server-2024-10")
@limiter.limit("60/minute") @limiter.limit("60/minute")
async def alpinebits_server_handshake( async def alpinebits_server_handshake(
request: Request, credentials_tupel: tuple = Depends(validate_basic_auth), dbsession=Depends(get_async_session) request: Request,
credentials_tupel: tuple = Depends(validate_basic_auth),
dbsession=Depends(get_async_session),
): ):
""" """
AlpineBits server endpoint implementing the handshake protocol. AlpineBits server endpoint implementing the handshake protocol.
@@ -615,14 +616,22 @@ async def alpinebits_server_handshake(
# Get optional request XML # Get optional request XML
request_xml = form_data.get("request") request_xml = form_data.get("request")
server = app.state.alpine_bits_server server: AlpineBitsServer = app.state.alpine_bits_server
version = Version.V2024_10 version = Version.V2024_10
username, password = credentials_tupel username, password = credentials_tupel
client_info = AlpineBitsClientInfo(username=username, password=password, client_id=client_id)
# Create successful handshake response # Create successful handshake response
response = await server.handle_request(action, request_xml, version, dbsession=dbsession, username=username, password=password) response = await server.handle_request(
action,
request_xml,
client_info=client_info,
version=version,
dbsession=dbsession,
)
response_xml = response.xml_content response_xml = response.xml_content

View File

@@ -1,6 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List
from annotatedyaml.loader import ( from annotatedyaml.loader import (
HAS_C_LOADER, HAS_C_LOADER,
JSON_TYPE, JSON_TYPE,
@@ -12,7 +12,15 @@ from annotatedyaml.loader import (
parse_yaml as parse_annotated_yaml, parse_yaml as parse_annotated_yaml,
secret_yaml as annotated_secret_yaml, secret_yaml as annotated_secret_yaml,
) )
from voluptuous import Schema, Required, All, Length, PREVENT_EXTRA, MultipleInvalid from voluptuous import (
Schema,
Required,
All,
Length,
PREVENT_EXTRA,
MultipleInvalid,
Optional,
)
# --- Voluptuous schemas --- # --- Voluptuous schemas ---
database_schema = Schema({Required("url"): str}, extra=PREVENT_EXTRA) database_schema = Schema({Required("url"): str}, extra=PREVENT_EXTRA)
@@ -24,6 +32,11 @@ hotel_auth_schema = Schema(
Required("hotel_name"): str, Required("hotel_name"): str,
Required("username"): str, Required("username"): str,
Required("password"): str, Required("password"): str,
Optional("push_endpoint"): {
Required("url"): str,
Required("token"): str,
Optional("username"): str,
},
}, },
extra=PREVENT_EXTRA, extra=PREVENT_EXTRA,
) )

View File

View File

@@ -18,9 +18,6 @@ def get_database_url(config=None):
return db_url return db_url
class Customer(Base): class Customer(Base):
__tablename__ = "customers" __tablename__ = "customers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
@@ -71,7 +68,6 @@ class Reservation(Base):
customer = relationship("Customer", back_populates="reservations") customer = relationship("Customer", back_populates="reservations")
class HashedCustomer(Base): class HashedCustomer(Base):
__tablename__ = "hashed_customers" __tablename__ = "hashed_customers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)

View File

@@ -53,7 +53,6 @@ logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def setup_db(config): async def setup_db(config):
DATABASE_URL = get_database_url(config) DATABASE_URL = get_database_url(config)
engine = create_async_engine(DATABASE_URL, echo=True) engine = create_async_engine(DATABASE_URL, echo=True)
@@ -67,7 +66,6 @@ async def setup_db(config):
return engine, AsyncSessionLocal return engine, AsyncSessionLocal
async def main(): async def main():
print("🚀 Starting AlpineBits XML generation script...") print("🚀 Starting AlpineBits XML generation script...")
# Load config (yaml, annotatedyaml) # Load config (yaml, annotatedyaml)
@@ -92,7 +90,6 @@ async def main():
# # Ensure DB schema is created (async) # # Ensure DB schema is created (async)
engine, AsyncSessionLocal = await setup_db(config) engine, AsyncSessionLocal = await setup_db(config)
async with engine.begin() as conn: async with engine.begin() as conn:
@@ -227,8 +224,6 @@ async def main():
def create_xml_from_db(customer: DBCustomer, reservation: DBReservation): def create_xml_from_db(customer: DBCustomer, reservation: DBReservation):
# Prepare data for XML # Prepare data for XML
phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else [] phone_numbers = [(customer.phone, PhoneTechType.MOBILE)] if customer.phone else []
customer_data = CustomerData( customer_data = CustomerData(