diff --git a/src/alpine_bits_python/api.py b/src/alpine_bits_python/api.py index 4f84fce..b9d409c 100644 --- a/src/alpine_bits_python/api.py +++ b/src/alpine_bits_python/api.py @@ -70,6 +70,86 @@ security_bearer = HTTPBearer() # Constants for token sanitization TOKEN_LOG_LENGTH = 10 +# Country name to ISO 3166-1 alpha-2 code mapping +COUNTRY_NAME_TO_CODE = { + # English names + "germany": "DE", + "italy": "IT", + "austria": "AT", + "switzerland": "CH", + "france": "FR", + "netherlands": "NL", + "belgium": "BE", + "spain": "ES", + "portugal": "PT", + "united kingdom": "GB", + "uk": "GB", + "czech republic": "CZ", + "poland": "PL", + "hungary": "HU", + "croatia": "HR", + "slovenia": "SI", + # German names + "deutschland": "DE", + "italien": "IT", + "österreich": "AT", + "schweiz": "CH", + "frankreich": "FR", + "niederlande": "NL", + "belgien": "BE", + "spanien": "ES", + "vereinigtes königreich": "GB", + "tschechien": "CZ", + "polen": "PL", + "ungarn": "HU", + "kroatien": "HR", + "slowenien": "SI", + # Italian names + "germania": "DE", + "italia": "IT", + "svizzera": "CH", + "francia": "FR", + "paesi bassi": "NL", + "belgio": "BE", + "spagna": "ES", + "portogallo": "PT", + "regno unito": "GB", + "repubblica ceca": "CZ", + "polonia": "PL", + "ungheria": "HU", + "croazia": "HR", +} + + +def normalize_country_input(country_input: str | None) -> str | None: + """Normalize country input to ISO 3166-1 alpha-2 code. 
+ + Handles: + - Country names in English, German, and Italian + - Already valid 2-letter codes (case-insensitive) + - None/empty values + + Args: + country_input: Country name or code (case-insensitive) + + Returns: + 2-letter ISO country code (uppercase) or None if input is None/empty + + """ + if not country_input: + return None + + country_input = country_input.strip() + + # If already 2 letters, assume it's a country code (ISO 3166-1 alpha-2) + iso_country_code_length = 2 + if len(country_input) == iso_country_code_length and country_input.isalpha(): + return country_input.upper() + + # Try to match as country name (case-insensitive) + country_lower = country_input.lower() + return COUNTRY_NAME_TO_CODE.get(country_lower, country_input) + # Pydantic models for language detection class LanguageDetectionRequest(BaseModel): @@ -738,6 +818,9 @@ async def process_generic_webhook_submission( city = form_data.get("stadt", "") country = form_data.get("land", "") + # Normalize country input (convert names to codes, handle case) + country = normalize_country_input(country) + # Parse dates - handle DD.MM.YYYY format start_date_str = form_data.get("anreise") end_date_str = form_data.get("abreise") diff --git a/src/alpine_bits_python/customer_service.py b/src/alpine_bits_python/customer_service.py index 03243b7..7614e75 100644 --- a/src/alpine_bits_python/customer_service.py +++ b/src/alpine_bits_python/customer_service.py @@ -1,12 +1,16 @@ """Customer service layer for handling customer and hashed customer operations.""" from datetime import UTC, datetime -from typing import Optional +from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from .db import Customer, HashedCustomer +from .logging_config import get_logger +from .schemas import CustomerData + +_LOGGER = get_logger(__name__) class CustomerService: @@ -27,9 +31,27 @@ class CustomerService: Returns: The created Customer instance (with hashed_version 
relationship populated) + + Raises: + ValidationError: If customer_data fails validation + (e.g., invalid country code) + """ - # Create the customer - customer = Customer(**customer_data) + # Validate customer data through Pydantic model + validated_data = CustomerData(**customer_data) + + # Create the customer with validated data + # Exclude 'phone_numbers' as Customer model uses 'phone' field + customer = Customer( + **validated_data.model_dump(exclude_none=True, exclude={"phone_numbers"}) + ) + + # Set fields not in CustomerData model separately + if "contact_id" in customer_data: + customer.contact_id = customer_data["contact_id"] + if "phone" in customer_data: + customer.phone = customer_data["phone"] + self.session.add(customer) await self.session.flush() # Flush to get the customer.id @@ -43,9 +65,7 @@ class CustomerService: return customer - async def update_customer( - self, customer: Customer, update_data: dict - ) -> Customer: + async def update_customer(self, customer: Customer, update_data: dict) -> Customer: """Update an existing customer and sync its hashed version. 
Args: @@ -54,17 +74,62 @@ class CustomerService: Returns: The updated Customer instance + + Raises: + ValidationError: If update_data fails validation + (e.g., invalid country code) + """ - # Update customer fields - for key, value in update_data.items(): + # Validate update data through Pydantic model + # We need to merge with existing data for validation + existing_data = { + "given_name": customer.given_name, + "surname": customer.surname, + "name_prefix": customer.name_prefix, + "email_address": customer.email_address, + "phone": customer.phone, + "email_newsletter": customer.email_newsletter, + "address_line": customer.address_line, + "city_name": customer.city_name, + "postal_code": customer.postal_code, + "country_code": customer.country_code, + "gender": customer.gender, + "birth_date": customer.birth_date, + "language": customer.language, + "address_catalog": customer.address_catalog, + "name_title": customer.name_title, + } + # Merge update_data into existing_data (only CustomerData fields) + # Filter to include only fields that exist in CustomerData model + customer_data_fields = set(CustomerData.model_fields.keys()) + # Include 'phone' field (maps to CustomerData) + existing_data.update( + { + k: v + for k, v in update_data.items() + if k in customer_data_fields or k == "phone" + } + ) + + # Validate merged data + validated_data = CustomerData(**existing_data) + + # Update customer fields with validated data + # Exclude 'phone_numbers' as Customer model uses 'phone' field + # Note: We don't use exclude_none=True to allow setting fields to None + for key, value in validated_data.model_dump(exclude={"phone_numbers"}).items(): if hasattr(customer, key): setattr(customer, key, value) + # Update fields not in CustomerData model separately + if "contact_id" in update_data: + customer.contact_id = update_data["contact_id"] + if "phone" in update_data: + customer.phone = update_data["phone"] + # Update or create hashed version result = await 
self.session.execute( - select(HashedCustomer).where( - HashedCustomer.customer_id == customer.id - ) + select(HashedCustomer).where(HashedCustomer.customer_id == customer.id) ) hashed_customer = result.scalar_one_or_none() @@ -91,9 +156,7 @@ class CustomerService: return customer - async def get_customer_by_contact_id( - self, contact_id: str - ) -> Optional[Customer]: + async def get_customer_by_contact_id(self, contact_id: str) -> Customer | None: """Get a customer by contact_id. Args: @@ -101,6 +164,7 @@ class CustomerService: Returns: Customer instance if found, None otherwise + """ result = await self.session.execute( select(Customer).where(Customer.contact_id == contact_id) @@ -118,6 +182,7 @@ class CustomerService: Returns: Existing or newly created Customer instance + """ contact_id = customer_data.get("contact_id") @@ -130,9 +195,7 @@ class CustomerService: # Create new customer (either no contact_id or customer doesn't exist) return await self.create_customer(customer_data) - async def get_hashed_customer( - self, customer_id: int - ) -> Optional[HashedCustomer]: + async def get_hashed_customer(self, customer_id: int) -> HashedCustomer | None: """Get the hashed version of a customer. Args: @@ -140,11 +203,10 @@ class CustomerService: Returns: HashedCustomer instance if found, None otherwise + """ result = await self.session.execute( - select(HashedCustomer).where( - HashedCustomer.customer_id == customer_id - ) + select(HashedCustomer).where(HashedCustomer.customer_id == customer_id) ) return result.scalar_one_or_none() @@ -154,25 +216,79 @@ class CustomerService: This is useful for backfilling hashed data for customers created before the hashing system was implemented. + Also validates and sanitizes customer data (e.g., normalizes country + codes to uppercase). Customers with invalid data that cannot be fixed + will be skipped and logged. 
+ Returns: Number of customers that were hashed + """ # Get all customers result = await self.session.execute(select(Customer)) customers = result.scalars().all() hashed_count = 0 + skipped_count = 0 + for customer in customers: # Check if this customer already has a hashed version existing_hashed = await self.get_hashed_customer(customer.id) if not existing_hashed: - # Create hashed version - hashed_customer = customer.create_hashed_customer() - hashed_customer.created_at = datetime.now(UTC) - self.session.add(hashed_customer) - hashed_count += 1 + # Validate and sanitize customer data before hashing + customer_dict = { + "given_name": customer.given_name, + "surname": customer.surname, + "name_prefix": customer.name_prefix, + "email_address": customer.email_address, + "phone": customer.phone, + "email_newsletter": customer.email_newsletter, + "address_line": customer.address_line, + "city_name": customer.city_name, + "postal_code": customer.postal_code, + "country_code": customer.country_code, + "gender": customer.gender, + "birth_date": customer.birth_date, + "language": customer.language, + "address_catalog": customer.address_catalog, + "name_title": customer.name_title, + } + + try: + # Validate through Pydantic (normalizes country code) + validated = CustomerData(**customer_dict) + + # Update customer with sanitized data + # Exclude 'phone_numbers' as Customer model uses 'phone' field + for key, value in validated.model_dump( + exclude_none=True, exclude={"phone_numbers"} + ).items(): + if hasattr(customer, key): + setattr(customer, key, value) + + # Create hashed version with sanitized data + hashed_customer = customer.create_hashed_customer() + hashed_customer.created_at = datetime.now(UTC) + self.session.add(hashed_customer) + hashed_count += 1 + + except ValidationError as e: + # Skip customers with invalid data and log + skipped_count += 1 + _LOGGER.warning( + "Skipping customer ID %s due to validation error: %s", + customer.id, + e, + ) if hashed_count 
@field_validator("country_code", mode="before")
@classmethod
def normalize_country_code(cls, v: str | None) -> str | None:
    """Normalize a country code to uppercase ISO 3166-1 alpha-2 form.

    Runs in 'before' mode so the raw input is cleaned up before the field's
    own constraints are applied.

    Args:
        v: Raw country code. Surrounding whitespace and letter case are
            ignored; non-string values are converted with ``str()``.

    Returns:
        The two-letter uppercase code, or None when the input is None,
        empty, or whitespace-only.

    Raises:
        ValueError: If the cleaned input is not exactly two ASCII letters.

    """
    if v is None:
        return None

    raw = str(v).strip()
    if not raw:
        # Treat blank input like the empty string: no country given.
        return None

    v = raw.upper()

    # Validate the text as entered, not the uppercased form: str.isalpha()
    # accepts non-ASCII letters, and uppercasing can turn one character into
    # two ASCII ones ("ß" -> "SS"), which would slip through a post-upper
    # check.
    iso_country_code_length = 2
    is_alpha2 = (
        len(raw) == iso_country_code_length and raw.isascii() and raw.isalpha()
    )
    if not is_alpha2:
        raise ValueError(
            f"Country code must be exactly 2 letters (ISO 3166-1 alpha-2), "
            f"got '{v}'"
        )

    return v