Refactored db logic. Can now specify schema in config

This commit is contained in:
Jonas Linter
2025-11-04 09:20:02 +01:00
parent e7b789fcac
commit eb10e070b1
6 changed files with 107 additions and 15 deletions

View File

@@ -219,9 +219,9 @@ class ServerCapabilities:
def _is_action_implemented(self, action_class: type[AlpineBitsAction]) -> bool:
"""Check if an action is actually implemented or just uses the default behavior.
This is a simple check - in practice, you might want more sophisticated detection.
"""
# Check if the class has overridden the handle method
return "handle" in action_class.__dict__
def create_capabilities_dict(self) -> None:

View File

@@ -27,7 +27,7 @@ from fastapi.security import (
from pydantic import BaseModel
from slowapi.errors import RateLimitExceeded
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import async_sessionmaker
from alpine_bits_python.schemas import ReservationData
@@ -42,7 +42,7 @@ from .config_loader import load_config
from .const import CONF_GOOGLE_ACCOUNT, CONF_HOTEL_ID, CONF_META_ACCOUNT, HttpStatusCode
from .conversion_service import ConversionService
from .customer_service import CustomerService
from .db import Base, get_database_url
from .db import Base, create_database_engine
from .db import Customer as DBCustomer
from .db import Reservation as DBReservation
from .email_monitoring import ReservationStatsCollector
@@ -287,8 +287,8 @@ async def lifespan(app: FastAPI):
)
_LOGGER.info("Application startup initiated (primary_worker=%s)", is_primary)
DATABASE_URL = get_database_url(config)
engine = create_async_engine(DATABASE_URL, echo=False)
# Create database engine with schema support
engine = create_database_engine(config=config, echo=False)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
app.state.engine = engine

View File

@@ -2,8 +2,13 @@ import hashlib
import os
from sqlalchemy import Boolean, Column, Date, DateTime, ForeignKey, Integer, String
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.orm import declarative_base, relationship
from .logging_config import get_logger
_LOGGER = get_logger(__name__)
Base = declarative_base()
@@ -19,6 +24,74 @@ def get_database_url(config=None):
return db_url
def get_database_schema(config=None):
"""Get the PostgreSQL schema name from config.
Args:
config: Configuration dictionary
Returns:
Schema name string, or None if not configured
"""
if config and "database" in config and "schema" in config["database"]:
return config["database"]["schema"]
return os.environ.get("DATABASE_SCHEMA")
def configure_schema(schema_name=None):
"""Configure the database schema for all models.
This should be called before creating tables or running migrations.
For PostgreSQL, this sets the schema for all tables.
For other databases, this is a no-op.
Args:
schema_name: Name of the schema to use (e.g., "alpinebits")
"""
if schema_name:
# Update the schema for all tables in Base metadata
for table in Base.metadata.tables.values():
table.schema = schema_name
def create_database_engine(config=None, echo=False) -> AsyncEngine:
"""Create a configured database engine with schema support.
This function:
1. Gets the database URL from config
2. Gets the schema name (if configured)
3. Configures all models to use the schema
4. Creates the async engine with appropriate connect_args for PostgreSQL
Args:
config: Configuration dictionary
echo: Whether to echo SQL statements (default: False)
Returns:
Configured AsyncEngine instance
"""
database_url = get_database_url(config)
schema_name = get_database_schema(config)
# Configure schema for all models if specified
if schema_name:
configure_schema(schema_name)
_LOGGER.info("Configured database schema: %s", schema_name)
# Create engine with connect_args to set search_path for PostgreSQL
connect_args = {}
if schema_name and "postgresql" in database_url:
connect_args = {
"server_settings": {"search_path": f"{schema_name},public"}
}
_LOGGER.info("Setting PostgreSQL search_path to: %s,public", schema_name)
return create_async_engine(database_url, echo=echo, connect_args=connect_args)
class Customer(Base):
__tablename__ = "customers"
id = Column(Integer, primary_key=True)
@@ -48,9 +121,10 @@ class Customer(Base):
# Normalize: lowercase, strip whitespace
normalized = str(value).lower().strip()
# Remove spaces for phone numbers
is_phone = normalized.startswith("+") or normalized.replace(
"-", ""
).replace(" ", "").isdigit()
is_phone = (
normalized.startswith("+")
or normalized.replace("-", "").replace(" ", "").isdigit()
)
if is_phone:
chars_to_remove = [" ", "-", "(", ")"]
for char in chars_to_remove:
@@ -155,13 +229,18 @@ class Conversion(Base):
of a reservation stay. Linked to reservations via advertising tracking data
(fbclid, gclid, etc) stored in advertisingCampagne field.
"""
__tablename__ = "conversions"
id = Column(Integer, primary_key=True)
# Link to reservation (nullable since matching may not always work)
reservation_id = Column(Integer, ForeignKey("reservations.id"), nullable=True, index=True)
reservation_id = Column(
Integer, ForeignKey("reservations.id"), nullable=True, index=True
)
customer_id = Column(Integer, ForeignKey("customers.id"), nullable=True, index=True)
hashed_customer_id = Column(Integer, ForeignKey("hashed_customers.id"), nullable=True, index=True)
hashed_customer_id = Column(
Integer, ForeignKey("hashed_customers.id"), nullable=True, index=True
)
# Reservation metadata from XML
hotel_id = Column(String, index=True) # hotelID attribute
@@ -173,9 +252,15 @@ class Conversion(Base):
booking_channel = Column(String) # bookingChannel attribute
# Advertising/tracking data - used for matching to existing reservations
advertising_medium = Column(String, index=True) # advertisingMedium (e.g., "99TALES")
advertising_partner = Column(String, index=True) # advertisingPartner (e.g., "cpc", "website")
advertising_campagne = Column(String, index=True) # advertisingCampagne (contains fbclid/gclid)
advertising_medium = Column(
String, index=True
) # advertisingMedium (e.g., "99TALES")
advertising_partner = Column(
String, index=True
) # advertisingPartner (e.g., "cpc", "website")
advertising_campagne = Column(
String, index=True
) # advertisingCampagne (contains fbclid/gclid)
# Room reservation details
arrival_date = Column(Date)
@@ -188,7 +273,9 @@ class Conversion(Base):
# Daily sales data (one row per day)
sale_date = Column(Date, index=True) # date attribute from dailySale
revenue_total = Column(String) # revenueTotal - keeping as string to preserve decimals
revenue_total = Column(
String
) # revenueTotal - keeping as string to preserve decimals
revenue_logis = Column(String) # revenueLogis (accommodation)
revenue_board = Column(String) # revenueBoard (meal plan)
revenue_fb = Column(String) # revenueFB (food & beverage)