"""Pydantic models for data validation in AlpineBits.

These models validate data before it is passed to:

- SQLAlchemy database models
- AlpineBits XML generation
- API endpoints

Separating validation (Pydantic) from persistence (SQLAlchemy) and from XML
generation (xsdata) follows clean architecture principles.
"""
|
|
|
|
import hashlib
|
|
from datetime import date, datetime
|
|
from enum import Enum
|
|
|
|
from pydantic import BaseModel, EmailStr, Field, field_validator, model_validator
|
|
|
|
|
|
# Phone technology type codes: "1" = voice, "3" = fax, "5" = mobile.
class PhoneTechType(Enum):
    """Kind of phone line, represented by its one-character string code."""

    VOICE = "1"
    FAX = "3"
    MOBILE = "5"
|
|
|
|
|
|
class PhoneNumber(BaseModel):
    """A phone number together with an optional technology-type code.

    ``number`` must be 1-50 characters consisting of digits, whitespace,
    hyphens and parentheses, optionally preceded by a single ``+``.
    ``tech_type``, when given, must be one of the codes ``"1"`` (voice),
    ``"3"`` (fax) or ``"5"`` (mobile).
    """

    number: str = Field(..., min_length=1, max_length=50, pattern=r"^\+?[0-9\s\-()]+$")
    tech_type: str | None = Field(None, pattern="^[135]$")

    @field_validator("number")
    @classmethod
    def clean_phone_number(cls, v: str) -> str:
        """Collapse each run of whitespace to one space and trim both ends."""
        # str.split() with no argument splits on any whitespace run and drops
        # leading/trailing whitespace, so re-joining yields the compact form.
        pieces = v.split()
        return " ".join(pieces)
|
|
|
|
|
|
class ReservationData(BaseModel):
    """Validated reservation data.

    Besides the per-field constraints declared below, two model-level checks
    run after field validation:

    * ``md5_unique_id`` is filled with the MD5 hex digest of ``unique_id``
      when it was not supplied.
    * ``children_ages`` must contain exactly ``num_children`` entries, each
      between 0 and 17 inclusive.

    NOTE(review): the relative order of ``start_date`` and ``end_date`` is not
    checked by this model.
    """

    unique_id: str = Field(..., min_length=1, max_length=200)
    # 32 characters is the length of an MD5 hex digest.
    md5_unique_id: str | None = Field(None, min_length=1, max_length=32)
    start_date: date
    end_date: date
    # Defaults to the (naive, local) time at which the model is created.
    created_at: datetime = Field(default_factory=datetime.now)
    num_adults: int = Field(..., ge=1)
    num_children: int = Field(0, ge=0, le=10)
    children_ages: list[int] = Field(default_factory=list)
    hotel_code: str = Field(..., min_length=1, max_length=50)
    hotel_name: str | None = Field(None, max_length=200)
    offer: str | None = Field(None, max_length=500)
    user_comment: str | None = Field(None, max_length=2000)
    # Optional click-id / UTM tracking parameters.
    fbclid: str | None = Field(None, max_length=300)
    gclid: str | None = Field(None, max_length=300)
    utm_source: str | None = Field(None, max_length=150)
    utm_medium: str | None = Field(None, max_length=150)
    utm_campaign: str | None = Field(None, max_length=150)
    utm_term: str | None = Field(None, max_length=150)
    utm_content: str | None = Field(None, max_length=150)

    @model_validator(mode="after")
    def ensure_md5(self) -> "ReservationData":
        """Fill ``md5_unique_id`` from ``unique_id`` when it was not provided.

        Runs in 'after' mode so every field is already validated and can be
        read directly from the instance; the digest is assigned in place.

        Returns:
            The same instance, with ``md5_unique_id`` populated.
        """
        if not self.md5_unique_id:
            # The digest is a derived identifier, not a security measure.
            # Declaring that keeps hashlib.md5 usable on Python builds that
            # restrict MD5 (e.g. FIPS mode); the resulting hex string is the
            # same as a plain hashlib.md5 call would produce.
            digest = hashlib.md5(
                self.unique_id.encode("utf-8"), usedforsecurity=False
            )
            self.md5_unique_id = digest.hexdigest()
        return self

    @model_validator(mode="after")
    def validate_children_ages(self) -> "ReservationData":
        """Ensure ``children_ages`` is consistent with ``num_children``.

        Raises:
            ValueError: If the number of ages differs from ``num_children``
                or any age lies outside 0-17.
        """
        if len(self.children_ages) != self.num_children:
            raise ValueError(
                f"Number of children ages ({len(self.children_ages)}) "
                f"must match num_children ({self.num_children})"
            )
        for age in self.children_ages:
            if age < 0 or age > 17:
                raise ValueError(f"Child age {age} must be between 0 and 17")
        return self
|
|
|
|
|
|
class CustomerData(BaseModel):
    """Validated customer data for creating reservations and guests.

    ``given_name`` and ``surname`` are stripped and must contain at least one
    non-whitespace character. ``country_code`` and ``language`` are
    case-normalized *before* their pattern constraints are applied, so
    ``"it"`` is accepted as country code ``"IT"`` and ``"DE"`` as language
    ``"de"``.
    """

    model_config = {"from_attributes": True}  # Allow creation from ORM models

    given_name: str = Field(..., min_length=1, max_length=100)
    surname: str = Field(..., min_length=1, max_length=100)
    name_prefix: str | None = Field(None, max_length=20)
    name_title: str | None = Field(None, max_length=20)
    # Each entry is a (number, optional technology type) pair.
    phone_numbers: list[tuple[str, None | PhoneTechType]] = Field(default_factory=list)
    email_address: EmailStr | None = None
    email_newsletter: bool | None = None
    address_line: str | None = Field(None, max_length=255)
    city_name: str | None = Field(None, max_length=100)
    postal_code: str | None = Field(None, max_length=20)
    # Exactly two upper-case ASCII letters (after normalization).
    country_code: str | None = Field(
        None, min_length=2, max_length=2, pattern="^[A-Z]{2}$"
    )
    address_catalog: bool | None = None
    gender: str | None = Field(None, pattern="^(Male|Female|Unknown)$")
    # Shape check only (YYYY-MM-DD); the pattern does not verify that the
    # value is a real calendar date.
    birth_date: str | None = Field(None, pattern=r"^\d{4}-\d{2}-\d{2}$")  # ISO format
    # Exactly two lower-case ASCII letters (after normalization).
    language: str | None = Field(None, min_length=2, max_length=2, pattern="^[a-z]{2}$")

    @field_validator("given_name", "surname")
    @classmethod
    def name_must_not_be_empty(cls, v: str) -> str:
        """Strip the name and reject values that are only whitespace.

        Raises:
            ValueError: If nothing remains after stripping.
        """
        stripped = v.strip()
        if not stripped:
            raise ValueError("Name cannot be empty or whitespace")
        return stripped

    # The two normalizers below run in "before" mode on purpose: in the
    # default "after" mode the field's pattern has already rejected
    # wrong-case input by the time the validator is called, so the
    # normalization could never take effect.
    @field_validator("country_code", mode="before")
    @classmethod
    def normalize_country_code(cls, v: object) -> object:
        """Upper-case a string country code before constraint checks.

        Non-string input (including ``None``) is returned unchanged so that
        the regular field validation decides whether it is acceptable.
        """
        return v.upper() if isinstance(v, str) else v

    @field_validator("language", mode="before")
    @classmethod
    def normalize_language(cls, v: object) -> object:
        """Lower-case a string language code before constraint checks.

        Non-string input (including ``None``) is returned unchanged so that
        the regular field validation decides whether it is acceptable.
        """
        return v.lower() if isinstance(v, str) else v
|
|
|
|
|
|
class HotelReservationIdData(BaseModel):
    """Validated hotel reservation ID data.

    ``res_id_type`` is required and must be a string of digits. The three
    optional text fields are trimmed and cut to 64 characters before their
    length constraints are evaluated; blank values become ``None``.
    """

    model_config = {"from_attributes": True}

    res_id_type: str = Field(..., pattern=r"^[0-9]+$")  # Must be numeric string
    res_id_value: str | None = Field(None, min_length=1, max_length=64)
    res_id_source: str | None = Field(None, min_length=1, max_length=64)
    res_id_source_context: str | None = Field(None, min_length=1, max_length=64)

    @field_validator(
        "res_id_value", "res_id_source", "res_id_source_context", mode="before"
    )
    @classmethod
    def trim_and_truncate(cls, v: str | None) -> str | None:
        """Clean a raw value so it fits the 64-character limit.

        Because this validator runs in "before" mode, the cleaned value is
        what the ``min_length``/``max_length`` constraints later see.

        Returns:
            ``None`` for falsy input or input that is blank once stripped;
            otherwise the stripped text, limited to its first 64 characters.
        """
        if not v:
            return None
        # The raw input may not be a str yet, so coerce before stripping.
        cleaned = str(v).strip()
        if not cleaned:
            return None
        # Slicing is a no-op for values that are already short enough.
        return cleaned[:64]
|
|
|
|
|
|
class CommentListItemData(BaseModel):
    """Validated comment list item.

    ``language`` is lower-cased *before* its two-letter pattern is applied,
    so upper- or mixed-case codes such as ``"DE"`` are accepted as ``"de"``.
    """

    model_config = {"from_attributes": True}

    value: str = Field(..., min_length=1, max_length=1000)
    list_item: str = Field(..., pattern=r"^[0-9]+$")  # Numeric identifier
    language: str = Field(..., min_length=2, max_length=2, pattern=r"^[a-z]{2}$")

    # Runs in "before" mode on purpose: in the default "after" mode the
    # lower-case-only pattern has already rejected upper-case input by the
    # time the validator is called, so normalizing would have no effect.
    @field_validator("language", mode="before")
    @classmethod
    def normalize_language(cls, v: object) -> object:
        """Lower-case a string language code before constraint checks.

        Non-string input is returned unchanged so that the regular field
        validation reports the type error.
        """
        return v.lower() if isinstance(v, str) else v
|
|
|
|
|
|
class CommentData(BaseModel):
    """Validated comment data: a name, optional text and optional list items."""

    model_config = {"from_attributes": True}

    name: str  # Should be validated against CommentName2 enum
    text: str | None = Field(None, max_length=4000)
    list_items: list[CommentListItemData] = Field(default_factory=list)

    @field_validator("list_items")
    @classmethod
    def validate_list_items(
        cls, v: list[CommentListItemData]
    ) -> list[CommentListItemData]:
        """Reject item lists in which a ``list_item`` identifier repeats.

        Raises:
            ValueError: If two or more items share the same ``list_item``.
        """
        # A set drops duplicates, so a size mismatch means an identifier
        # occurred more than once. An empty list trivially passes (0 == 0).
        distinct_ids = {entry.list_item for entry in v}
        if len(distinct_ids) != len(v):
            raise ValueError("List items must have unique identifiers")
        return v
|
|
|
|
|
|
class CommentsData(BaseModel):
    """Validated comments collection holding at most three comments."""

    model_config = {"from_attributes": True}

    comments: list[CommentData] = Field(default_factory=list, max_length=3)

    @field_validator("comments")
    @classmethod
    def validate_comment_count(cls, v: list[CommentData]) -> list[CommentData]:
        """Return the comments unchanged when there are no more than three.

        The field already declares ``max_length=3``; this validator states
        the same limit explicitly.

        Raises:
            ValueError: If more than three comments are present.
        """
        if len(v) <= 3:
            return v
        raise ValueError("Maximum 3 comments allowed")
|
|
|
|
|
|
# Example usage in a service layer
|
|
class ReservationService:
    """Example service showing how to use Pydantic models with SQLAlchemy."""

    def __init__(self, db_session):
        # The session's flush() and commit() are awaited below, so this is
        # presumably an async SQLAlchemy session; confirm against the caller.
        self.db_session = db_session

    async def create_reservation(
        self, reservation_data: ReservationData, customer_data: CustomerData
    ):
        """Persist a customer and a reservation linked to it.

        Both arguments are Pydantic models, so their contents were validated
        when they were constructed; no further validation happens here.

        Args:
            reservation_data: Reservation fields. ``children_ages`` is stored
                as a comma-separated string instead of a list.
            customer_data: Customer fields. Fields that are ``None`` are not
                passed to the ORM constructor.

        Returns:
            A ``(db_reservation, db_customer)`` tuple of the ORM objects that
            were added and committed.
        """
        # Imported inside the method rather than at module level
        # (presumably to avoid an import cycle; confirm).
        from alpine_bits_python.db import Customer, Reservation

        session = self.db_session

        # Convert the validated customer model into its ORM counterpart.
        customer_fields = customer_data.model_dump(exclude_none=True)
        db_customer = Customer(**customer_fields)
        session.add(db_customer)
        # Flush so the customer's primary key is available for the link below.
        await session.flush()

        # children_ages is a list on the Pydantic side; it is excluded from
        # the dump and passed separately as a comma-joined string.
        reservation_fields = reservation_data.model_dump(exclude={"children_ages"})
        ages_csv = ",".join(str(age) for age in reservation_data.children_ages)
        db_reservation = Reservation(
            customer_id=db_customer.id,
            **reservation_fields,
            children_ages=ages_csv,
        )
        session.add(db_reservation)
        await session.commit()

        return db_reservation, db_customer
|