diff --git a/src/alpine_bits_python/alpinebits_server.py b/src/alpine_bits_python/alpinebits_server.py index 20404dc..ab0e41a 100644 --- a/src/alpine_bits_python/alpinebits_server.py +++ b/src/alpine_bits_python/alpinebits_server.py @@ -15,7 +15,6 @@ from enum import Enum, IntEnum from typing import Any, Optional, override from zoneinfo import ZoneInfo -from sqlalchemy import select from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata_pydantic.bindings import XmlParser, XmlSerializer @@ -26,6 +25,7 @@ from alpine_bits_python.alpine_bits_helpers import ( from alpine_bits_python.logging_config import get_logger from .db import AckedRequest, Customer, Reservation +from .reservation_service import ReservationService from .generated.alpinebits import ( OtaNotifReportRq, OtaNotifReportRs, @@ -510,32 +510,29 @@ class ReadAction(AlpineBitsAction): hotel_read_request.selection_criteria.start ) - # query all reservations for this hotel from the database, where start_date is greater than or equal to the given start_date + # Use ReservationService to query reservations + reservation_service = ReservationService(dbsession) - stmt = ( - select(Reservation, Customer) - .join(Customer, Reservation.customer_id == Customer.id) - .filter(Reservation.hotel_code == hotelid) - ) if start_date: _LOGGER.info("Filtering reservations from start date %s", start_date) - stmt = stmt.filter(Reservation.created_at >= start_date) - # remove reservations that have been acknowledged via client_id - elif client_info.client_id: - subquery = ( - select(Reservation.id) - .join( - AckedRequest, - Reservation.md5_unique_id == AckedRequest.unique_id, + reservation_customer_pairs = ( + await reservation_service.get_reservations_with_filters( + start_date=start_date, hotel_code=hotelid + ) + ) + elif client_info.client_id: + # Remove reservations that have been acknowledged via client_id + reservation_customer_pairs = ( + await reservation_service.get_unacknowledged_reservations( + client_id=client_info.client_id, hotel_code=hotelid + ) + ) + else: + reservation_customer_pairs = ( + await reservation_service.get_reservations_with_filters( + hotel_code=hotelid ) - .filter(AckedRequest.client_id == client_info.client_id) ) - stmt = stmt.filter(~Reservation.id.in_(subquery)) - - result = await dbsession.execute(stmt) - reservation_customer_pairs: list[tuple[Reservation, Customer]] = ( - result.all() - ) # List of (Reservation, Customer) tuples _LOGGER.info( "Querying reservations and customers for hotel %s from database", @@ -616,19 +613,16 @@ class NotifReportReadAction(AlpineBitsAction): "Error: Something went wrong", HttpStatusCode.INTERNAL_SERVER_ERROR ) - timestamp = datetime.now(ZoneInfo("UTC")) + # Use ReservationService to record acknowledgements + reservation_service = ReservationService(dbsession) + for entry in ( notif_report_details.hotel_notif_report.hotel_reservations.hotel_reservation ): # type: ignore unique_id = entry.unique_id.id - acked_request = AckedRequest( - unique_id=unique_id, - client_id=client_info.client_id, - timestamp=timestamp, + await reservation_service.record_acknowledgement( + client_id=client_info.client_id, unique_id=unique_id ) - dbsession.add(acked_request) - - await dbsession.commit() return AlpineBitsResponse(response_xml, HttpStatusCode.OK) diff --git a/src/alpine_bits_python/api.py b/src/alpine_bits_python/api.py index b3f0615..6269c54 100644 --- a/src/alpine_bits_python/api.py +++ b/src/alpine_bits_python/api.py @@ -32,6 +32,7 @@ from .db import Base, get_database_url from .db import Customer as DBCustomer from .db import Reservation as DBReservation from .logging_config import get_logger, setup_logging +from .reservation_service import ReservationService from .rate_limit import ( BURST_RATE_LIMIT, DEFAULT_RATE_LIMIT, @@ -305,22 +306,6 @@ async def health_check(request: Request): } -def create_db_reservation_from_data( - reservation_model: ReservationData, db_customer_id: int -) -> DBReservation: - """Convert ReservationData to DBReservation, handling children_ages conversion.""" - data = reservation_model.model_dump(exclude_none=True) - - children_list = data.pop("children_ages", []) - children_csv = ",".join(str(int(a)) for a in children_list) if children_list else "" - data["children_ages"] = children_csv - - # Inject FK - data["customer_id"] = db_customer_id - - return DBReservation(**data) - - # Extracted business logic for handling Wix form submissions async def process_wix_form_submission(request: Request, data: dict[str, Any], db): """Shared business logic for handling Wix form submissions (test and production).""" @@ -458,10 +443,11 @@ async def process_wix_form_submission(request: Request, data: dict[str, Any], db if reservation.md5_unique_id is None: raise HTTPException(status_code=400, detail="Failed to generate md5_unique_id") - db_reservation = create_db_reservation_from_data(reservation, db_customer.id) - db.add(db_reservation) - await db.commit() - await db.refresh(db_reservation) + # Use ReservationService to create reservation + reservation_service = ReservationService(db) + db_reservation = await reservation_service.create_reservation( + reservation, db_customer.id + ) async def push_event(): # Fire event for listeners (push, etc.) - hotel-specific dispatch diff --git a/src/alpine_bits_python/reservation_service.py b/src/alpine_bits_python/reservation_service.py new file mode 100644 index 0000000..4c4a020 --- /dev/null +++ b/src/alpine_bits_python/reservation_service.py @@ -0,0 +1,263 @@ +"""Reservation service layer for handling reservation database operations.""" + +import hashlib +from datetime import UTC, datetime +from typing import Optional + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from .db import AckedRequest, Customer, Reservation +from .schemas import ReservationData + + +class ReservationService: + """Service for managing reservations and related operations. + + Handles all database operations for reservations including creation, + retrieval, and acknowledgement tracking. + """ + + def __init__(self, session: AsyncSession): + self.session = session + + def _convert_reservation_data_to_db( + self, reservation_model: ReservationData, customer_id: int + ) -> Reservation: + """Convert ReservationData to Reservation model. + + Args: + reservation_model: ReservationData instance + customer_id: Customer ID to link to + + Returns: + Reservation instance ready for database insertion + """ + data = reservation_model.model_dump(exclude_none=True) + + # Convert children_ages list to CSV string + children_list = data.pop("children_ages", []) + children_csv = ( + ",".join(str(int(a)) for a in children_list) if children_list else "" + ) + data["children_ages"] = children_csv + + # Inject foreign key + data["customer_id"] = customer_id + + return Reservation(**data) + + async def create_reservation( + self, reservation_data: ReservationData, customer_id: int + ) -> Reservation: + """Create a new reservation. + + Args: + reservation_data: ReservationData containing reservation details + customer_id: ID of the customer making the reservation + + Returns: + Created Reservation instance + """ + reservation = self._convert_reservation_data_to_db( + reservation_data, customer_id + ) + self.session.add(reservation) + await self.session.commit() + await self.session.refresh(reservation) + return reservation + + async def get_reservation_by_unique_id( + self, unique_id: str + ) -> Optional[Reservation]: + """Get a reservation by unique_id. + + Args: + unique_id: The unique_id to search for + + Returns: + Reservation instance if found, None otherwise + """ + result = await self.session.execute( + select(Reservation).where(Reservation.unique_id == unique_id) + ) + return result.scalar_one_or_none() + + async def get_reservation_by_md5_unique_id( + self, md5_unique_id: str + ) -> Optional[Reservation]: + """Get a reservation by md5_unique_id. + + Args: + md5_unique_id: The MD5 hash of unique_id + + Returns: + Reservation instance if found, None otherwise + """ + result = await self.session.execute( + select(Reservation).where( + Reservation.md5_unique_id == md5_unique_id + ) + ) + return result.scalar_one_or_none() + + async def check_duplicate_reservation( + self, unique_id: str, md5_unique_id: str + ) -> bool: + """Check if a reservation already exists. + + Args: + unique_id: The unique_id to check + md5_unique_id: The MD5 hash to check + + Returns: + True if reservation exists, False otherwise + """ + existing = await self.get_reservation_by_unique_id(unique_id) + if existing: + return True + + existing_md5 = await self.get_reservation_by_md5_unique_id(md5_unique_id) + return existing_md5 is not None + + async def get_reservations_for_customer( + self, customer_id: int + ) -> list[Reservation]: + """Get all reservations for a customer. + + Args: + customer_id: The customer ID + + Returns: + List of Reservation instances + """ + result = await self.session.execute( + select(Reservation).where(Reservation.customer_id == customer_id) + ) + return list(result.scalars().all()) + + async def get_reservations_with_filters( + self, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + hotel_code: Optional[str] = None, + ) -> list[tuple[Reservation, Customer]]: + """Get reservations with optional filters, joined with customers. + + Args: + start_date: Filter by created_at >= this value + end_date: Filter by created_at <= this value + hotel_code: Filter by hotel code + + Returns: + List of (Reservation, Customer) tuples + """ + query = select(Reservation, Customer).join( + Customer, Reservation.customer_id == Customer.id + ) + + filters = [] + if start_date: + filters.append(Reservation.created_at >= start_date) + if end_date: + filters.append(Reservation.created_at <= end_date) + if hotel_code: + filters.append(Reservation.hotel_code == hotel_code) + + if filters: + query = query.where(and_(*filters)) + + result = await self.session.execute(query) + return list(result.all()) + + async def get_unacknowledged_reservations( + self, + client_id: str, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + hotel_code: Optional[str] = None, + ) -> list[tuple[Reservation, Customer]]: + """Get reservations that haven't been acknowledged by a client. + + Args: + client_id: The client ID to check acknowledgements for + start_date: Filter by start date >= this value + end_date: Filter by end date <= this value + hotel_code: Filter by hotel code + + Returns: + List of (Reservation, Customer) tuples that are unacknowledged + """ + # Get all acknowledged md5_unique_ids for this client + acked_result = await self.session.execute( + select(AckedRequest.unique_id).where( + AckedRequest.client_id == client_id + ) + ) + acked_md5_ids = {row[0] for row in acked_result.all()} + + # Get all reservations with filters + all_reservations = await self.get_reservations_with_filters( + start_date, end_date, hotel_code + ) + + # Filter out acknowledged ones (comparing md5_unique_id) + return [ + (res, cust) + for res, cust in all_reservations + if res.md5_unique_id not in acked_md5_ids + ] + + async def record_acknowledgement( + self, client_id: str, unique_id: str + ) -> AckedRequest: + """Record that a client has acknowledged a reservation. + + Args: + client_id: The client ID + unique_id: The unique_id of the reservation + + Returns: + Created AckedRequest instance + """ + acked = AckedRequest( + client_id=client_id, + unique_id=unique_id, + timestamp=datetime.now(UTC), + ) + self.session.add(acked) + await self.session.commit() + await self.session.refresh(acked) + return acked + + async def is_acknowledged(self, client_id: str, unique_id: str) -> bool: + """Check if a reservation has been acknowledged by a client. + + Args: + client_id: The client ID + unique_id: The reservation unique_id + + Returns: + True if acknowledged, False otherwise + """ + result = await self.session.execute( + select(AckedRequest).where( + and_( + AckedRequest.client_id == client_id, + AckedRequest.unique_id == unique_id, + ) + ) + ) + return result.scalar_one_or_none() is not None + + @staticmethod + def generate_md5_unique_id(unique_id: str) -> str: + """Generate MD5 hash of unique_id. + + Args: + unique_id: The unique_id to hash + + Returns: + MD5 hash as hex string + """ + return hashlib.md5(unique_id.encode("utf-8")).hexdigest()