Experimenting with pydantic
This commit is contained in:
@@ -6,6 +6,9 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from alpine_bits_python.db import Customer, Reservation
|
||||
from alpine_bits_python.schemas import (
|
||||
HotelReservationIdData as HotelReservationIdDataValidated,
|
||||
)
|
||||
|
||||
# Import the generated classes
|
||||
from .generated.alpinebits import (
|
||||
@@ -21,12 +24,12 @@ _LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
|
||||
# Define type aliases for the two Customer types
|
||||
NotifCustomer = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGuests.ResGuest.Profiles.ProfileInfo.Profile.Customer
|
||||
RetrieveCustomer = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGuests.ResGuest.Profiles.ProfileInfo.Profile.Customer
|
||||
NotifCustomer = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGuests.ResGuest.Profiles.ProfileInfo.Profile.Customer # noqa: E501
|
||||
RetrieveCustomer = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGuests.ResGuest.Profiles.ProfileInfo.Profile.Customer # noqa: E501
|
||||
|
||||
# Define type aliases for HotelReservationId types
|
||||
NotifHotelReservationId = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId
|
||||
RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId
|
||||
NotifHotelReservationId = OtaHotelResNotifRq.HotelReservations.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId # noqa: E501
|
||||
RetrieveHotelReservationId = OtaResRetrieveRs.ReservationsList.HotelReservation.ResGlobalInfo.HotelReservationIds.HotelReservationId # noqa: E501
|
||||
|
||||
# Define type aliases for Comments types
|
||||
NotifComments = (
|
||||
@@ -326,7 +329,7 @@ class CustomerFactory:
|
||||
|
||||
@dataclass
|
||||
class HotelReservationIdData:
|
||||
"""Simple data class to hold hotel reservation ID information without nested type constraints."""
|
||||
"""Hold hotel reservation ID information without nested type constraints."""
|
||||
|
||||
res_id_type: str # Required field - pattern: [0-9]+
|
||||
res_id_value: None | str = None # Max 64 characters
|
||||
@@ -359,7 +362,7 @@ class HotelReservationIdFactory:
|
||||
def _create_hotel_reservation_id(
|
||||
hotel_reservation_id_class: type, data: HotelReservationIdData
|
||||
) -> Any:
|
||||
"""Internal method to create a hotel reservation id of the specified type."""
|
||||
"""Create a hotel reservation id of the specified type."""
|
||||
return hotel_reservation_id_class(
|
||||
res_id_type=data.res_id_type,
|
||||
res_id_value=data.res_id_value,
|
||||
@@ -538,7 +541,7 @@ class ResGuestFactory:
|
||||
def _create_res_guests(
|
||||
res_guests_class: type, customer_class: type, customer_data: CustomerData
|
||||
) -> Any:
|
||||
"""Internal method to create complete ResGuests structure."""
|
||||
"""Create the complete ResGuests structure."""
|
||||
# Create the customer using the existing CustomerFactory
|
||||
customer = CustomerFactory._create_customer(customer_class, customer_data)
|
||||
|
||||
@@ -712,8 +715,6 @@ def _process_single_reservation(
|
||||
reservation.num_adults, children_ages, message_type
|
||||
)
|
||||
|
||||
unique_id_string = reservation.unique_id
|
||||
|
||||
if message_type == OtaMessageType.NOTIF:
|
||||
UniqueId = NotifUniqueId
|
||||
RoomStays = NotifRoomStays
|
||||
@@ -727,8 +728,15 @@ def _process_single_reservation(
|
||||
else:
|
||||
raise ValueError("Unsupported message type: %s", message_type.value)
|
||||
|
||||
unique_id_str = reservation.unique_id
|
||||
|
||||
# TODO MAGIC shortening
|
||||
if len(unique_id_str) > 32:
|
||||
# strip to first 35 chars
|
||||
unique_id_str = unique_id_str[:32]
|
||||
|
||||
# UniqueID
|
||||
unique_id = UniqueId(type_value=UniqueIdType2.VALUE_14, id=unique_id_string)
|
||||
unique_id = UniqueId(type_value=UniqueIdType2.VALUE_14, id=unique_id_str)
|
||||
|
||||
# TimeSpan
|
||||
time_span = RoomStays.RoomStay.TimeSpan(
|
||||
@@ -744,52 +752,37 @@ def _process_single_reservation(
|
||||
)
|
||||
|
||||
res_id_source = "website"
|
||||
klick_id = None
|
||||
|
||||
if reservation.fbclid != "":
|
||||
klick_id = reservation.fbclid
|
||||
klick_id = str(reservation.fbclid)
|
||||
res_id_source = "meta"
|
||||
elif reservation.gclid != "":
|
||||
klick_id = reservation.gclid
|
||||
klick_id = str(reservation.gclid)
|
||||
res_id_source = "google"
|
||||
|
||||
# explicitly set klick_id to None otherwise an empty string will be sent
|
||||
if klick_id in (None, "", "None"):
|
||||
klick_id = None
|
||||
else: # extract string from Column object
|
||||
klick_id = str(klick_id)
|
||||
# Get utm_medium if available, otherwise use source
|
||||
if reservation.utm_medium is not None and str(reservation.utm_medium) != "":
|
||||
res_id_source = str(reservation.utm_medium)
|
||||
|
||||
hotel_res_id_data = HotelReservationIdData(
|
||||
# Use Pydantic model for automatic validation and truncation
|
||||
# It will automatically:
|
||||
# - Trim whitespace
|
||||
# - Truncate to 64 characters if needed
|
||||
# - Convert empty strings to None
|
||||
hotel_res_id_data_validated = HotelReservationIdDataValidated(
|
||||
res_id_type="13",
|
||||
res_id_value=klick_id,
|
||||
res_id_source=res_id_source,
|
||||
res_id_source_context="99tales",
|
||||
)
|
||||
|
||||
# explicitly set klick_id to None otherwise an empty string will be sent
|
||||
if klick_id in (None, "", "None"):
|
||||
klick_id = None
|
||||
else: # extract string from Column object
|
||||
klick_id = str(klick_id)
|
||||
|
||||
utm_medium = (
|
||||
str(reservation.utm_medium)
|
||||
if reservation.utm_medium is not None and str(reservation.utm_medium) != ""
|
||||
else "website"
|
||||
)
|
||||
|
||||
# shorten klick_id if longer than 64 characters
|
||||
# TODO MAGIC SHORTENING
|
||||
if klick_id is not None and len(klick_id) > 64:
|
||||
klick_id = klick_id[:64]
|
||||
|
||||
if klick_id == "":
|
||||
klick_id = None
|
||||
|
||||
# Convert back to dataclass for the factory
|
||||
hotel_res_id_data = HotelReservationIdData(
|
||||
res_id_type="13",
|
||||
res_id_value=klick_id,
|
||||
res_id_source=utm_medium,
|
||||
res_id_source_context="99tales",
|
||||
res_id_type=hotel_res_id_data_validated.res_id_type,
|
||||
res_id_value=hotel_res_id_data_validated.res_id_value,
|
||||
res_id_source=hotel_res_id_data_validated.res_id_source,
|
||||
res_id_source_context=hotel_res_id_data_validated.res_id_source_context,
|
||||
)
|
||||
|
||||
hotel_res_id = alpine_bits_factory.create(hotel_res_id_data, message_type)
|
||||
|
||||
@@ -405,11 +405,6 @@ async def process_wix_form_submission(request: Request, data: dict[str, Any], db
|
||||
|
||||
unique_id = data.get("submissionId", generate_unique_id())
|
||||
|
||||
# TODO MAGIC shortening
|
||||
if len(unique_id) > 32:
|
||||
# strip to first 35 chars
|
||||
unique_id = unique_id[:32]
|
||||
|
||||
# use database session
|
||||
|
||||
# Save all relevant data to DB (including new fields)
|
||||
|
||||
247
src/alpine_bits_python/schemas.py
Normal file
247
src/alpine_bits_python/schemas.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Pydantic models for data validation in AlpineBits.
|
||||
|
||||
These models provide validation for data before it's passed to:
|
||||
- SQLAlchemy database models
|
||||
- AlpineBits XML generation
|
||||
- API endpoints
|
||||
|
||||
Separating validation (Pydantic) from persistence (SQLAlchemy) and
|
||||
from XML generation (xsdata) follows clean architecture principles.
|
||||
"""
|
||||
|
||||
from datetime import date
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class PhoneNumber(BaseModel):
|
||||
"""Phone number with optional type."""
|
||||
|
||||
number: str = Field(..., min_length=1, max_length=50, pattern=r"^\+?[0-9\s\-()]+$")
|
||||
tech_type: str | None = Field(None, pattern="^[135]$") # 1=voice, 3=fax, 5=mobile
|
||||
|
||||
@field_validator("number")
|
||||
@classmethod
|
||||
def clean_phone_number(cls, v: str) -> str:
|
||||
"""Remove extra spaces from phone number."""
|
||||
return " ".join(v.split())
|
||||
|
||||
|
||||
class CustomerData(BaseModel):
|
||||
"""Validated customer data for creating reservations and guests."""
|
||||
|
||||
given_name: str = Field(..., min_length=1, max_length=100)
|
||||
surname: str = Field(..., min_length=1, max_length=100)
|
||||
name_prefix: str | None = Field(None, max_length=20)
|
||||
name_title: str | None = Field(None, max_length=20)
|
||||
phone_numbers: list[PhoneNumber] = Field(default_factory=list)
|
||||
email_address: EmailStr | None = None
|
||||
email_newsletter: bool | None = None
|
||||
address_line: str | None = Field(None, max_length=255)
|
||||
city_name: str | None = Field(None, max_length=100)
|
||||
postal_code: str | None = Field(None, max_length=20)
|
||||
country_code: str | None = Field(
|
||||
None, min_length=2, max_length=2, pattern="^[A-Z]{2}$"
|
||||
)
|
||||
address_catalog: bool | None = None
|
||||
gender: str | None = Field(None, pattern="^(Male|Female|Unknown)$")
|
||||
birth_date: str | None = Field(None, pattern=r"^\d{4}-\d{2}-\d{2}$") # ISO format
|
||||
language: str | None = Field(None, min_length=2, max_length=2, pattern="^[a-z]{2}$")
|
||||
|
||||
@field_validator("given_name", "surname")
|
||||
@classmethod
|
||||
def name_must_not_be_empty(cls, v: str) -> str:
|
||||
"""Ensure names are not just whitespace."""
|
||||
if not v.strip():
|
||||
raise ValueError("Name cannot be empty or whitespace")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("country_code")
|
||||
@classmethod
|
||||
def normalize_country_code(cls, v: str | None) -> str | None:
|
||||
"""Normalize country code to uppercase."""
|
||||
return v.upper() if v else None
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def normalize_language(cls, v: str | None) -> str | None:
|
||||
"""Normalize language code to lowercase."""
|
||||
return v.lower() if v else None
|
||||
|
||||
model_config = {"from_attributes": True} # Allow creation from ORM models
|
||||
|
||||
|
||||
class HotelReservationIdData(BaseModel):
|
||||
"""Validated hotel reservation ID data."""
|
||||
|
||||
res_id_type: str = Field(..., pattern=r"^[0-9]+$") # Must be numeric string
|
||||
res_id_value: str | None = None
|
||||
res_id_source: str | None = None
|
||||
res_id_source_context: str | None = None
|
||||
|
||||
@field_validator(
|
||||
"res_id_value", "res_id_source", "res_id_source_context", mode="before"
|
||||
)
|
||||
@classmethod
|
||||
def trim_and_truncate(cls, v: str | None) -> str | None:
|
||||
"""Trim whitespace and truncate to max length if needed.
|
||||
|
||||
Runs BEFORE field validation to ensure values are cleaned and truncated
|
||||
before max_length constraints are checked.
|
||||
"""
|
||||
if not v:
|
||||
return None
|
||||
# Convert to string if needed
|
||||
v = str(v)
|
||||
# Strip whitespace
|
||||
v = v.strip()
|
||||
# Convert empty strings to None
|
||||
if not v:
|
||||
return None
|
||||
# Truncate to 64 characters if needed
|
||||
if len(v) > 64:
|
||||
v = v[:64]
|
||||
return v
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CommentListItemData(BaseModel):
|
||||
"""Validated comment list item."""
|
||||
|
||||
value: str = Field(..., min_length=1, max_length=1000)
|
||||
list_item: str = Field(..., pattern=r"^[0-9]+$") # Numeric identifier
|
||||
language: str = Field(..., min_length=2, max_length=2, pattern=r"^[a-z]{2}$")
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def normalize_language(cls, v: str) -> str:
|
||||
"""Normalize language to lowercase."""
|
||||
return v.lower()
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CommentData(BaseModel):
|
||||
"""Validated comment data."""
|
||||
|
||||
name: str # Should be validated against CommentName2 enum
|
||||
text: str | None = Field(None, max_length=4000)
|
||||
list_items: list[CommentListItemData] = Field(default_factory=list)
|
||||
|
||||
@field_validator("list_items")
|
||||
@classmethod
|
||||
def validate_list_items(
|
||||
cls, v: list[CommentListItemData]
|
||||
) -> list[CommentListItemData]:
|
||||
"""Ensure list items have unique identifiers."""
|
||||
if v:
|
||||
item_ids = [item.list_item for item in v]
|
||||
if len(item_ids) != len(set(item_ids)):
|
||||
raise ValueError("List items must have unique identifiers")
|
||||
return v
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CommentsData(BaseModel):
|
||||
"""Validated comments collection."""
|
||||
|
||||
comments: list[CommentData] = Field(default_factory=list, max_length=3)
|
||||
|
||||
@field_validator("comments")
|
||||
@classmethod
|
||||
def validate_comment_count(cls, v: list[CommentData]) -> list[CommentData]:
|
||||
"""Ensure maximum 3 comments."""
|
||||
if len(v) > 3:
|
||||
raise ValueError("Maximum 3 comments allowed")
|
||||
return v
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ReservationData(BaseModel):
|
||||
"""Validated reservation data."""
|
||||
|
||||
unique_id: str = Field(..., min_length=1, max_length=35)
|
||||
start_date: date
|
||||
end_date: date
|
||||
num_adults: int = Field(..., ge=1, le=20)
|
||||
num_children: int = Field(0, ge=0, le=10)
|
||||
children_ages: list[int] = Field(default_factory=list)
|
||||
hotel_code: str = Field(..., min_length=1, max_length=50)
|
||||
hotel_name: str | None = Field(None, max_length=200)
|
||||
offer: str | None = Field(None, max_length=500)
|
||||
user_comment: str | None = Field(None, max_length=2000)
|
||||
fbclid: str | None = Field(None, max_length=100)
|
||||
gclid: str | None = Field(None, max_length=100)
|
||||
utm_source: str | None = Field(None, max_length=100)
|
||||
utm_medium: str | None = Field(None, max_length=100)
|
||||
utm_campaign: str | None = Field(None, max_length=100)
|
||||
utm_term: str | None = Field(None, max_length=100)
|
||||
utm_content: str | None = Field(None, max_length=100)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_dates(self) -> "ReservationData":
|
||||
"""Ensure end_date is after start_date."""
|
||||
if self.end_date <= self.start_date:
|
||||
raise ValueError("end_date must be after start_date")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_children_ages(self) -> "ReservationData":
|
||||
"""Ensure children_ages matches num_children."""
|
||||
if len(self.children_ages) != self.num_children:
|
||||
raise ValueError(
|
||||
f"Number of children ages ({len(self.children_ages)}) "
|
||||
f"must match num_children ({self.num_children})"
|
||||
)
|
||||
for age in self.children_ages:
|
||||
if age < 0 or age > 17:
|
||||
raise ValueError(f"Child age {age} must be between 0 and 17")
|
||||
return self
|
||||
|
||||
@field_validator("unique_id")
|
||||
@classmethod
|
||||
def validate_unique_id_length(cls, v: str) -> str:
|
||||
"""Ensure unique_id doesn't exceed max length."""
|
||||
if len(v) > 35:
|
||||
raise ValueError(f"unique_id length {len(v)} exceeds maximum of 35")
|
||||
return v
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# Example usage in a service layer
|
||||
class ReservationService:
|
||||
"""Example service showing how to use Pydantic models with SQLAlchemy."""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.db_session = db_session
|
||||
|
||||
async def create_reservation(
|
||||
self, reservation_data: ReservationData, customer_data: CustomerData
|
||||
):
|
||||
"""Create a reservation with validated data.
|
||||
|
||||
The data has already been validated by Pydantic before reaching here.
|
||||
"""
|
||||
from alpine_bits_python.db import Customer, Reservation
|
||||
|
||||
# Convert validated Pydantic model to SQLAlchemy model
|
||||
db_customer = Customer(**customer_data.model_dump(exclude_none=True))
|
||||
self.db_session.add(db_customer)
|
||||
await self.db_session.flush() # Get the customer ID
|
||||
|
||||
# Create reservation linked to customer
|
||||
db_reservation = Reservation(
|
||||
customer_id=db_customer.id,
|
||||
**reservation_data.model_dump(
|
||||
exclude={"children_ages"}
|
||||
), # Handle separately
|
||||
children_ages=",".join(map(str, reservation_data.children_ages)),
|
||||
)
|
||||
self.db_session.add(db_reservation)
|
||||
await self.db_session.commit()
|
||||
|
||||
return db_reservation, db_customer
|
||||
Reference in New Issue
Block a user