Refactored db logic. Can now specify schema in config

This commit is contained in:
Jonas Linter
2025-11-04 09:20:02 +01:00
parent e7b789fcac
commit eb10e070b1
6 changed files with 107 additions and 15 deletions

View File

@@ -1,3 +1,5 @@
This python project is managed by uv. Use uv run to execute app and tests. This python project is managed by uv. Use uv run to execute app and tests.
The Configuration is handled in a config.yaml file. The annotatedyaml library is used to load secrets. !secret SOME_SECRET in the yaml file refers to a secret definition in a secrets.yaml file The Configuration is handled in a config.yaml file. The annotatedyaml library is used to load secrets. !secret SOME_SECRET in the yaml file refers to a secret definition in a secrets.yaml file
When adding something to the config make sure to also add it to the voluptuos schema in config. If the config changes and there is an easy way to migrate an old config file do so. If its an addition then don't.

View File

@@ -3,7 +3,8 @@
database: database:
url: "sqlite+aiosqlite:///alpinebits.db" # For local dev, use SQLite. For prod, override with PostgreSQL URL. url: "sqlite+aiosqlite:///alpinebits.db" # For local dev, use SQLite. For prod, override with PostgreSQL URL.
# url: "postgresql://user:password@host:port/dbname" # Example for Postgres # url: "postgresql+asyncpg://user:password@host:port/dbname" # Example for Postgres
# schema: "alpinebits" # Optional: PostgreSQL schema name (default: public)
# AlpineBits Python config # AlpineBits Python config
# Use annotatedyaml for secrets and environment-specific overrides # Use annotatedyaml for secrets and environment-specific overrides

View File

@@ -5,10 +5,12 @@
database: database:
url: "postgresql+asyncpg://username:password@hostname:5432/database_name" url: "postgresql+asyncpg://username:password@hostname:5432/database_name"
# Example: "postgresql+asyncpg://alpinebits_user:your_password@localhost:5432/alpinebits" # Example: "postgresql+asyncpg://alpinebits_user:your_password@localhost:5432/alpinebits"
schema: "alpinebits" # Optional: PostgreSQL schema name (default: public)
# If using annotatedyaml secrets: # If using annotatedyaml secrets:
# database: # database:
# url: !secret POSTGRES_URL # url: !secret POSTGRES_URL
# schema: "alpinebits" # Optional: PostgreSQL schema name
# #
# Then in secrets.yaml: # Then in secrets.yaml:
# POSTGRES_URL: "postgresql+asyncpg://username:password@hostname:5432/database_name" # POSTGRES_URL: "postgresql+asyncpg://username:password@hostname:5432/database_name"

View File

@@ -219,9 +219,9 @@ class ServerCapabilities:
def _is_action_implemented(self, action_class: type[AlpineBitsAction]) -> bool: def _is_action_implemented(self, action_class: type[AlpineBitsAction]) -> bool:
"""Check if an action is actually implemented or just uses the default behavior. """Check if an action is actually implemented or just uses the default behavior.
This is a simple check - in practice, you might want more sophisticated detection. This is a simple check - in practice, you might want more sophisticated detection.
""" """
# Check if the class has overridden the handle method
return "handle" in action_class.__dict__ return "handle" in action_class.__dict__
def create_capabilities_dict(self) -> None: def create_capabilities_dict(self) -> None:

View File

@@ -27,7 +27,7 @@ from fastapi.security import (
from pydantic import BaseModel from pydantic import BaseModel
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker
from alpine_bits_python.schemas import ReservationData from alpine_bits_python.schemas import ReservationData
@@ -42,7 +42,7 @@ from .config_loader import load_config
from .const import CONF_GOOGLE_ACCOUNT, CONF_HOTEL_ID, CONF_META_ACCOUNT, HttpStatusCode from .const import CONF_GOOGLE_ACCOUNT, CONF_HOTEL_ID, CONF_META_ACCOUNT, HttpStatusCode
from .conversion_service import ConversionService from .conversion_service import ConversionService
from .customer_service import CustomerService from .customer_service import CustomerService
from .db import Base, get_database_url from .db import Base, create_database_engine
from .db import Customer as DBCustomer from .db import Customer as DBCustomer
from .db import Reservation as DBReservation from .db import Reservation as DBReservation
from .email_monitoring import ReservationStatsCollector from .email_monitoring import ReservationStatsCollector
@@ -287,8 +287,8 @@ async def lifespan(app: FastAPI):
) )
_LOGGER.info("Application startup initiated (primary_worker=%s)", is_primary) _LOGGER.info("Application startup initiated (primary_worker=%s)", is_primary)
DATABASE_URL = get_database_url(config) # Create database engine with schema support
engine = create_async_engine(DATABASE_URL, echo=False) engine = create_database_engine(config=config, echo=False)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False) AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
app.state.engine = engine app.state.engine = engine

View File

@@ -2,8 +2,13 @@ import hashlib
import os import os
from sqlalchemy import Boolean, Column, Date, DateTime, ForeignKey, Integer, String from sqlalchemy import Boolean, Column, Date, DateTime, ForeignKey, Integer, String
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.orm import declarative_base, relationship from sqlalchemy.orm import declarative_base, relationship
from .logging_config import get_logger
_LOGGER = get_logger(__name__)
Base = declarative_base() Base = declarative_base()
@@ -19,6 +24,74 @@ def get_database_url(config=None):
return db_url return db_url
def get_database_schema(config=None):
"""Get the PostgreSQL schema name from config.
Args:
config: Configuration dictionary
Returns:
Schema name string, or None if not configured
"""
if config and "database" in config and "schema" in config["database"]:
return config["database"]["schema"]
return os.environ.get("DATABASE_SCHEMA")
def configure_schema(schema_name=None):
"""Configure the database schema for all models.
This should be called before creating tables or running migrations.
For PostgreSQL, this sets the schema for all tables.
For other databases, this is a no-op.
Args:
schema_name: Name of the schema to use (e.g., "alpinebits")
"""
if schema_name:
# Update the schema for all tables in Base metadata
for table in Base.metadata.tables.values():
table.schema = schema_name
def create_database_engine(config=None, echo=False) -> AsyncEngine:
"""Create a configured database engine with schema support.
This function:
1. Gets the database URL from config
2. Gets the schema name (if configured)
3. Configures all models to use the schema
4. Creates the async engine with appropriate connect_args for PostgreSQL
Args:
config: Configuration dictionary
echo: Whether to echo SQL statements (default: False)
Returns:
Configured AsyncEngine instance
"""
database_url = get_database_url(config)
schema_name = get_database_schema(config)
# Configure schema for all models if specified
if schema_name:
configure_schema(schema_name)
_LOGGER.info("Configured database schema: %s", schema_name)
# Create engine with connect_args to set search_path for PostgreSQL
connect_args = {}
if schema_name and "postgresql" in database_url:
connect_args = {
"server_settings": {"search_path": f"{schema_name},public"}
}
_LOGGER.info("Setting PostgreSQL search_path to: %s,public", schema_name)
return create_async_engine(database_url, echo=echo, connect_args=connect_args)
class Customer(Base): class Customer(Base):
__tablename__ = "customers" __tablename__ = "customers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
@@ -48,9 +121,10 @@ class Customer(Base):
# Normalize: lowercase, strip whitespace # Normalize: lowercase, strip whitespace
normalized = str(value).lower().strip() normalized = str(value).lower().strip()
# Remove spaces for phone numbers # Remove spaces for phone numbers
is_phone = normalized.startswith("+") or normalized.replace( is_phone = (
"-", "" normalized.startswith("+")
).replace(" ", "").isdigit() or normalized.replace("-", "").replace(" ", "").isdigit()
)
if is_phone: if is_phone:
chars_to_remove = [" ", "-", "(", ")"] chars_to_remove = [" ", "-", "(", ")"]
for char in chars_to_remove: for char in chars_to_remove:
@@ -155,13 +229,18 @@ class Conversion(Base):
of a reservation stay. Linked to reservations via advertising tracking data of a reservation stay. Linked to reservations via advertising tracking data
(fbclid, gclid, etc) stored in advertisingCampagne field. (fbclid, gclid, etc) stored in advertisingCampagne field.
""" """
__tablename__ = "conversions" __tablename__ = "conversions"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
# Link to reservation (nullable since matching may not always work) # Link to reservation (nullable since matching may not always work)
reservation_id = Column(Integer, ForeignKey("reservations.id"), nullable=True, index=True) reservation_id = Column(
Integer, ForeignKey("reservations.id"), nullable=True, index=True
)
customer_id = Column(Integer, ForeignKey("customers.id"), nullable=True, index=True) customer_id = Column(Integer, ForeignKey("customers.id"), nullable=True, index=True)
hashed_customer_id = Column(Integer, ForeignKey("hashed_customers.id"), nullable=True, index=True) hashed_customer_id = Column(
Integer, ForeignKey("hashed_customers.id"), nullable=True, index=True
)
# Reservation metadata from XML # Reservation metadata from XML
hotel_id = Column(String, index=True) # hotelID attribute hotel_id = Column(String, index=True) # hotelID attribute
@@ -173,9 +252,15 @@ class Conversion(Base):
booking_channel = Column(String) # bookingChannel attribute booking_channel = Column(String) # bookingChannel attribute
# Advertising/tracking data - used for matching to existing reservations # Advertising/tracking data - used for matching to existing reservations
advertising_medium = Column(String, index=True) # advertisingMedium (e.g., "99TALES") advertising_medium = Column(
advertising_partner = Column(String, index=True) # advertisingPartner (e.g., "cpc", "website") String, index=True
advertising_campagne = Column(String, index=True) # advertisingCampagne (contains fbclid/gclid) ) # advertisingMedium (e.g., "99TALES")
advertising_partner = Column(
String, index=True
) # advertisingPartner (e.g., "cpc", "website")
advertising_campagne = Column(
String, index=True
) # advertisingCampagne (contains fbclid/gclid)
# Room reservation details # Room reservation details
arrival_date = Column(Date) arrival_date = Column(Date)
@@ -188,7 +273,9 @@ class Conversion(Base):
# Daily sales data (one row per day) # Daily sales data (one row per day)
sale_date = Column(Date, index=True) # date attribute from dailySale sale_date = Column(Date, index=True) # date attribute from dailySale
revenue_total = Column(String) # revenueTotal - keeping as string to preserve decimals revenue_total = Column(
String
) # revenueTotal - keeping as string to preserve decimals
revenue_logis = Column(String) # revenueLogis (accommodation) revenue_logis = Column(String) # revenueLogis (accommodation)
revenue_board = Column(String) # revenueBoard (meal plan) revenue_board = Column(String) # revenueBoard (meal plan)
revenue_fb = Column(String) # revenueFB (food & beverage) revenue_fb = Column(String) # revenueFB (food & beverage)