diff --git a/.gitignore b/.gitignore index 7fc53e2..e05a15e 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,9 @@ logs/* # ignore secrets secrets.yaml +# ignore PostgreSQL config (contains credentials) +config/postgres.yaml + # ignore db alpinebits.db diff --git a/alpinebits.log b/alpinebits.log index 94a6c66..46220ea 100644 --- a/alpinebits.log +++ b/alpinebits.log @@ -14073,3 +14073,41 @@ IndexError: list index out of range 2025-10-15 08:52:56 - root - INFO - Logging to file: alpinebits.log 2025-10-15 08:52:56 - root - INFO - Logging configured at INFO level 2025-10-15 08:52:58 - alpine_bits_python.email_service - INFO - Email service initialized: smtp.titan.email:465 +2025-10-16 16:15:42 - root - INFO - Logging to file: alpinebits.log +2025-10-16 16:15:42 - root - INFO - Logging configured at INFO level +2025-10-16 16:15:42 - alpine_bits_python.email_monitoring - INFO - DailyReportScheduler initialized: send_time=08:00, recipients=[] +2025-10-16 16:15:42 - root - INFO - Daily report scheduler configured for Pushover (primary worker) +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Application startup initiated (primary_worker=True) +2025-10-16 16:15:42 - alpine_bits_python.alpinebits_server - INFO - Initializing action instance for AlpineBitsActionName.OTA_HOTEL_NOTIF_REPORT +2025-10-16 16:15:42 - alpine_bits_python.alpinebits_server - INFO - Initializing action instance for AlpineBitsActionName.OTA_PING +2025-10-16 16:15:42 - alpine_bits_python.alpinebits_server - INFO - Initializing action instance for AlpineBitsActionName.OTA_HOTEL_RES_NOTIF_GUEST_REQUESTS +2025-10-16 16:15:42 - alpine_bits_python.alpinebits_server - INFO - Initializing action instance for AlpineBitsActionName.OTA_READ +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Hotel 39054_001 has no push_endpoint configured +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Hotel 135 has no push_endpoint configured +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Hotel 39052_001 has no push_endpoint configured +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Hotel 39040_001 has no push_endpoint configured +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Starting database migrations... +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Running migration: add_room_types +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Adding column reservations.room_type_code (VARCHAR) +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Successfully added column reservations.room_type_code +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Adding column reservations.room_classification_code (VARCHAR) +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Successfully added column reservations.room_classification_code +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Adding column reservations.room_type (VARCHAR) +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Successfully added column reservations.room_type +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Migration add_room_types: Added 3 columns +2025-10-16 16:15:42 - alpine_bits_python.migrations - INFO - Database migrations completed successfully +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Database tables checked/created at startup. +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - All existing customers already have hashed data +2025-10-16 16:15:42 - alpine_bits_python.email_monitoring - INFO - ReservationStatsCollector initialized with 4 hotels +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Stats collector initialized and hooked up to report scheduler +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Sending test daily report on startup (last 24 hours) +2025-10-16 16:15:42 - alpine_bits_python.email_monitoring - INFO - Collecting reservation stats from 2025-10-15 16:15:42 to 2025-10-16 16:15:42 +2025-10-16 16:15:42 - alpine_bits_python.email_monitoring - INFO - Collected stats: 9 total reservations across 1 hotels +2025-10-16 16:15:42 - alpine_bits_python.email_service - WARNING - No recipients specified for email: AlpineBits Daily Report - 2025-10-16 +2025-10-16 16:15:42 - alpine_bits_python.api - ERROR - Failed to send test daily report via email on startup +2025-10-16 16:15:42 - alpine_bits_python.pushover_service - INFO - Pushover notification sent successfully: AlpineBits Daily Report - 2025-10-16 +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Test daily report sent via Pushover successfully on startup +2025-10-16 16:15:42 - alpine_bits_python.email_monitoring - INFO - Daily report scheduler started +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Daily report scheduler started +2025-10-16 16:15:42 - alpine_bits_python.api - INFO - Application startup complete +2025-10-16 16:15:42 - alpine_bits_python.email_monitoring - INFO - Next daily report scheduled for 2025-10-17 08:00:00 (in 15.7 hours) diff --git a/config/config.yaml b/config/config.yaml index 7dbc4c9..24af8d7 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -41,7 +41,7 @@ alpine_bits_auth: api_tokens: - tLTI8wXF1OVEvUX7kdZRhSW3Qr5feBCz0mHo-kbnEp0 -# Email configuration for monitoring and alerts +# Email configuration (SMTP service config - kept for when port is unblocked) email: # SMTP server configuration smtp: @@ -56,62 +56,41 @@ email: from_address: "info@99tales.net" # Sender address from_name: "AlpineBits Monitor" # Sender display name - # Monitoring and alerting - monitoring: - # Daily report configuration - daily_report: - enabled: false # Set to true to enable daily reports - recipients: - - "jonas@vaius.ai" - #- "dev@99tales.com" - send_time: "08:00" # Time to send daily report (24h format, local time) - include_stats: true # Include reservation/customer stats - include_errors: true # Include error summary - - # Error alert configuration (hybrid approach) - error_alerts: - enabled: false # Set to true to enable error alerts - recipients: - - "jonas@vaius.ai" - #- "oncall@99tales.com" - # Alert is sent immediately if threshold is reached - error_threshold: 5 # Send immediate alert after N errors - # Otherwise, alert is sent after buffer time expires - buffer_minutes: 15 # Wait N minutes before sending buffered errors - # Cooldown period to prevent alert spam - cooldown_minutes: 15 # Wait N min before sending another alert - # Error severity levels to monitor - log_levels: - - "ERROR" - - "CRITICAL" - -# Pushover configuration for push notifications (alternative to email) +# Pushover configuration (push notification service config) pushover: # Pushover API credentials (get from https://pushover.net) user_key: !secret PUSHOVER_USER_KEY # Your user/group key api_token: !secret PUSHOVER_API_TOKEN # Your application API token - # Monitoring and alerting (same structure as email) - monitoring: - # Daily report configuration - daily_report: - enabled: true # Set to true to enable daily reports - send_time: "08:00" # Time to send daily report (24h format, local time) - include_stats: true # Include reservation/customer stats - include_errors: true # Include error summary - priority: 0 # Pushover priority: -2=lowest, -1=low, 0=normal, 1=high, 2=emergency +# Unified notification system - recipient-based routing +notifications: + # Recipients and their preferred notification methods + recipients: + - name: "jonas" + methods: + # Uncomment email when port is unblocked + #- type: "email" + # address: "jonas@vaius.ai" + - type: "pushover" + priority: 1 # Pushover priority: -2=lowest, -1=low, 0=normal, 1=high, 2=emergency - # Error alert configuration (hybrid approach) - error_alerts: - enabled: true # Set to true to enable error alerts - # Alert is sent immediately if threshold is reached - error_threshold: 5 # Send immediate alert after N errors - # Otherwise, alert is sent after buffer time expires - buffer_minutes: 15 # Wait N minutes before sending buffered errors - # Cooldown period to prevent alert spam - cooldown_minutes: 15 # Wait N min before sending another alert - # Error severity levels to monitor - log_levels: - - "ERROR" - - "CRITICAL" - priority: 1 # Pushover priority: -2=lowest, -1=low, 0=normal, 1=high, 2=emergency + # Daily report configuration (applies to all recipients) + daily_report: + enabled: true # Set to true to enable daily reports + send_time: "08:00" # Time to send daily report (24h format, local time) + include_stats: true # Include reservation/customer stats + include_errors: true # Include error summary + + # Error alert configuration (applies to all recipients) + error_alerts: + enabled: true # Set to true to enable error alerts + # Alert is sent immediately if threshold is reached + error_threshold: 5 # Send immediate alert after N errors + # Otherwise, alert is sent after buffer time expires + buffer_minutes: 15 # Wait N minutes before sending buffered errors + # Cooldown period to prevent alert spam + cooldown_minutes: 15 # Wait N min before sending another alert + # Error severity levels to monitor + log_levels: + - "ERROR" + - "CRITICAL" diff --git a/config/postgres.yaml.example b/config/postgres.yaml.example new file mode 100644 index 0000000..d746353 --- /dev/null +++ b/config/postgres.yaml.example @@ -0,0 +1,14 @@ +# PostgreSQL configuration for migration +# Copy this file to postgres.yaml and fill in your PostgreSQL credentials +# This file should NOT be committed to git (add postgres.yaml to .gitignore) + +database: + url: "postgresql+asyncpg://username:password@hostname:5432/database_name" + # Example: "postgresql+asyncpg://alpinebits_user:your_password@localhost:5432/alpinebits" + +# If using annotatedyaml secrets: +# database: +# url: !secret POSTGRES_URL +# +# Then in secrets.yaml: +# POSTGRES_URL: "postgresql+asyncpg://username:password@hostname:5432/database_name" diff --git a/pyproject.toml b/pyproject.toml index 19b289e..acb50c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ requires-python = ">=3.13" dependencies = [ "aiosqlite>=0.21.0", "annotatedyaml>=1.0.0", + "asyncpg>=0.30.0", "dotenv>=0.9.9", "fast-langdetect>=1.0.0", "fastapi>=0.117.1", diff --git a/src/alpine_bits_python/alpine_bits_helpers.py b/src/alpine_bits_python/alpine_bits_helpers.py index 0a15b04..cc32d6f 100644 --- a/src/alpine_bits_python/alpine_bits_helpers.py +++ b/src/alpine_bits_python/alpine_bits_helpers.py @@ -25,6 +25,7 @@ from .generated.alpinebits import ( OtaHotelResNotifRq, OtaResRetrieveRs, ProfileProfileType, + RoomTypeRoomType, UniqueIdType2, ) @@ -76,6 +77,13 @@ RetrieveRoomStays = OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays NotifHotelReservation = OtaHotelResNotifRq.HotelReservations.HotelReservation RetrieveHotelReservation = OtaResRetrieveRs.ReservationsList.HotelReservation +NotifRoomTypes = ( + OtaHotelResNotifRq.HotelReservations.HotelReservation.RoomStays.RoomStay.RoomTypes +) +RetrieveRoomTypes = ( + OtaResRetrieveRs.ReservationsList.HotelReservation.RoomStays.RoomStay.RoomTypes +) + from .const import RESERVATION_ID_TYPE @@ -697,9 +705,29 @@ def _process_single_reservation( start=reservation.start_date.isoformat() if reservation.start_date else None, end=reservation.end_date.isoformat() if reservation.end_date else None, ) + + # RoomTypes (optional) - only create if at least one field is present + room_types = None + if any([reservation.room_type_code, reservation.room_classification_code, reservation.room_type]): + # Convert room_type string to enum if present + room_type_enum = None + if reservation.room_type: + room_type_enum = RoomTypeRoomType(reservation.room_type) + + # Create RoomType instance + room_type_obj = RoomStays.RoomStay.RoomTypes.RoomType( + room_type_code=reservation.room_type_code, + room_classification_code=reservation.room_classification_code, + room_type=room_type_enum, + ) + + # Create RoomTypes container + room_types = RoomStays.RoomStay.RoomTypes(room_type=room_type_obj) + room_stay = RoomStays.RoomStay( time_span=time_span, guest_counts=guest_counts, + room_types=room_types, ) room_stays = RoomStays( room_stay=[room_stay], diff --git a/src/alpine_bits_python/alpinebits_server.py b/src/alpine_bits_python/alpinebits_server.py index 80c07e3..5fe494a 100644 --- a/src/alpine_bits_python/alpinebits_server.py +++ b/src/alpine_bits_python/alpinebits_server.py @@ -11,7 +11,7 @@ import re from abc import ABC from dataclasses import dataclass from datetime import datetime -from enum import Enum, IntEnum +from enum import Enum from typing import Any, Optional, override from xsdata.formats.dataclass.serializers.config import SerializerConfig @@ -23,6 +23,7 @@ from alpine_bits_python.alpine_bits_helpers import ( ) from alpine_bits_python.logging_config import get_logger +from .const import HttpStatusCode from .db import Customer, Reservation from .generated.alpinebits import ( OtaNotifReportRq, @@ -38,15 +39,6 @@ from .reservation_service import ReservationService _LOGGER = get_logger(__name__) -class HttpStatusCode(IntEnum): - """Allowed HTTP status codes for AlpineBits responses.""" - - OK = 200 - BAD_REQUEST = 400 - UNAUTHORIZED = 401 - INTERNAL_SERVER_ERROR = 500 - - def dump_json_for_xml(json_content: Any) -> str: """Dump JSON content as a pretty-printed string for embedding in XML. diff --git a/src/alpine_bits_python/api.py b/src/alpine_bits_python/api.py index f210d7d..4f84fce 100644 --- a/src/alpine_bits_python/api.py +++ b/src/alpine_bits_python/api.py @@ -1,3 +1,5 @@ +"""API endpoints for the form-data and the alpinebits server.""" + import asyncio import gzip import json @@ -36,6 +38,7 @@ from .alpinebits_server import ( ) from .auth import generate_unique_id, validate_api_key from .config_loader import load_config +from .const import HttpStatusCode from .customer_service import CustomerService from .db import Base, get_database_url from .db import Customer as DBCustomer @@ -43,8 +46,7 @@ from .db import Reservation as DBReservation from .email_monitoring import ReservationStatsCollector from .email_service import create_email_service from .logging_config import get_logger, setup_logging -from .notification_adapters import EmailNotificationAdapter, PushoverNotificationAdapter -from .notification_service import NotificationService +from .migrations import run_all_migrations from .pushover_service import create_pushover_service from .rate_limit import ( BURST_RATE_LIMIT, @@ -81,6 +83,8 @@ class LanguageDetectionResponse(BaseModel): # --- Enhanced event dispatcher with hotel-specific routing --- class EventDispatcher: + """Simple event dispatcher for AlpineBits push requests.""" + def __init__(self): self.listeners = defaultdict(list) self.hotel_listeners = defaultdict(list) # hotel_code -> list of listeners @@ -148,7 +152,7 @@ async def push_listener(customer: DBCustomer, reservation: DBReservation, hotel) version=Version.V2024_10, ) - if request.status_code != 200: + if request.status_code != HttpStatusCode.OK: _LOGGER.error( "Failed to generate push request for hotel %s, reservation %s: %s", hotel_id, @@ -235,9 +239,9 @@ async def lifespan(app: FastAPI): # Initialize pushover service pushover_service = create_pushover_service(config) - # Setup logging from config with email and pushover monitoring + # Setup logging from config with unified notification monitoring # Only primary worker should have the report scheduler running - email_handler, report_scheduler = setup_logging( + alert_handler, report_scheduler = setup_logging( config, email_service, pushover_service, loop, enable_scheduler=is_primary ) _LOGGER.info("Application startup initiated (primary_worker=%s)", is_primary) @@ -253,7 +257,7 @@ async def lifespan(app: FastAPI): app.state.event_dispatcher = event_dispatcher app.state.email_service = email_service app.state.pushover_service = pushover_service - app.state.email_handler = email_handler + app.state.alert_handler = alert_handler app.state.report_scheduler = report_scheduler # Register push listeners for hotels with push_endpoint @@ -276,11 +280,18 @@ async def lifespan(app: FastAPI): elif hotel_id and not push_endpoint: _LOGGER.info("Hotel %s has no push_endpoint configured", hotel_id) - # Create tables + # Create tables first (all workers) + # This ensures tables exist before migrations try to alter them async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) _LOGGER.info("Database tables checked/created at startup.") + # Run migrations after tables exist (only primary worker for race conditions) + if is_primary: + await run_all_migrations(engine) + else: + _LOGGER.info("Skipping migrations (non-primary worker)") + # Hash any existing customers (only in primary worker to avoid race conditions) if is_primary: async with AsyncSessionLocal() as session: @@ -306,44 +317,6 @@ async def lifespan(app: FastAPI): report_scheduler.set_stats_collector(stats_collector.collect_stats) _LOGGER.info("Stats collector initialized and hooked up to report scheduler") - # Send a test daily report on startup for testing (with 24-hour lookback) - _LOGGER.info("Sending test daily report on startup (last 24 hours)") - try: - # Use lookback_hours=24 to get stats from last 24 hours - stats = await stats_collector.collect_stats(lookback_hours=24) - - # Send via email (if configured) - if email_service: - success = await email_service.send_daily_report( - recipients=report_scheduler.recipients, - stats=stats, - errors=None, - ) - if success: - _LOGGER.info("Test daily report sent via email successfully on startup") - else: - _LOGGER.error("Failed to send test daily report via email on startup") - - # Send via Pushover (if configured) - if pushover_service: - pushover_config = config.get("pushover", {}) - pushover_monitoring = pushover_config.get("monitoring", {}) - pushover_daily_report = pushover_monitoring.get("daily_report", {}) - priority = pushover_daily_report.get("priority", 0) - - success = await pushover_service.send_daily_report( - stats=stats, - errors=None, - priority=priority, - ) - if success: - _LOGGER.info("Test daily report sent via Pushover successfully on startup") - else: - _LOGGER.error("Failed to send test daily report via Pushover on startup") - - except Exception: - _LOGGER.exception("Error sending test daily report on startup") - # Start daily report scheduler report_scheduler.start() _LOGGER.info("Daily report scheduler started") @@ -360,10 +333,10 @@ async def lifespan(app: FastAPI): report_scheduler.stop() _LOGGER.info("Daily report scheduler stopped") - # Close email alert handler (flush any remaining errors) - if email_handler: - email_handler.close() - _LOGGER.info("Email alert handler closed") + # Close alert handler (flush any remaining errors) + if alert_handler: + alert_handler.close() + _LOGGER.info("Alert handler closed") # Shutdown email service thread pool if email_service: diff --git a/src/alpine_bits_python/config_loader.py b/src/alpine_bits_python/config_loader.py index 6a79854..3f7e1da 100644 --- a/src/alpine_bits_python/config_loader.py +++ b/src/alpine_bits_python/config_loader.py @@ -192,14 +192,69 @@ pushover_schema = Schema( extra=PREVENT_EXTRA, ) +# Unified notification method schema +notification_method_schema = Schema( + { + Required("type"): In(["email", "pushover"]), + Optional("address"): str, # For email + Optional("priority"): Range(min=-2, max=2), # For pushover + }, + extra=PREVENT_EXTRA, +) + +# Unified notification recipient schema +notification_recipient_schema = Schema( + { + Required("name"): str, + Required("methods"): [notification_method_schema], + }, + extra=PREVENT_EXTRA, +) + +# Unified daily report configuration schema (without recipients) +unified_daily_report_schema = Schema( + { + Required("enabled", default=False): Boolean(), + Required("send_time", default="08:00"): str, + Required("include_stats", default=True): Boolean(), + Required("include_errors", default=True): Boolean(), + }, + extra=PREVENT_EXTRA, +) + +# Unified error alerts configuration schema (without recipients) +unified_error_alerts_schema = Schema( + { + Required("enabled", default=False): Boolean(), + Required("error_threshold", default=5): Range(min=1), + Required("buffer_minutes", default=15): Range(min=1), + Required("cooldown_minutes", default=15): Range(min=0), + Required("log_levels", default=["ERROR", "CRITICAL"]): [ + In(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]) + ], + }, + extra=PREVENT_EXTRA, +) + +# Unified notifications configuration schema +notifications_schema = Schema( + { + Required("recipients", default=[]): [notification_recipient_schema], + Optional("daily_report", default={}): unified_daily_report_schema, + Optional("error_alerts", default={}): unified_error_alerts_schema, + }, + extra=PREVENT_EXTRA, +) + config_schema = Schema( { Required(CONF_DATABASE): database_schema, Required(CONF_ALPINE_BITS_AUTH): basic_auth_schema, Required(CONF_SERVER): server_info, Required(CONF_LOGGING): logger_schema, - Optional("email"): email_schema, # Email is optional - Optional("pushover"): pushover_schema, # Pushover is optional + Optional("email"): email_schema, # Email is optional (service config only) + Optional("pushover"): pushover_schema, # Pushover is optional (service config only) + Optional("notifications"): notifications_schema, # Unified notification config Optional("api_tokens", default=[]): [str], # API tokens for bearer auth }, extra=PREVENT_EXTRA, diff --git a/src/alpine_bits_python/const.py b/src/alpine_bits_python/const.py index 7e4d2ce..cf939d6 100644 --- a/src/alpine_bits_python/const.py +++ b/src/alpine_bits_python/const.py @@ -1,5 +1,16 @@ +from enum import IntEnum from typing import Final + +class HttpStatusCode(IntEnum): + """Allowed HTTP status codes for AlpineBits responses.""" + + OK = 200 + BAD_REQUEST = 400 + UNAUTHORIZED = 401 + INTERNAL_SERVER_ERROR = 500 + + RESERVATION_ID_TYPE: str = ( "13" # Default reservation ID type for Reservation. 14 would be cancellation ) diff --git a/src/alpine_bits_python/db.py b/src/alpine_bits_python/db.py index 0bc9a8e..9b504c4 100644 --- a/src/alpine_bits_python/db.py +++ b/src/alpine_bits_python/db.py @@ -97,7 +97,7 @@ class HashedCustomer(Base): hashed_country_code = Column(String(64)) hashed_gender = Column(String(64)) hashed_birth_date = Column(String(64)) - created_at = Column(DateTime) + created_at = Column(DateTime(timezone=True)) customer = relationship("Customer", backref="hashed_version") @@ -114,7 +114,7 @@ class Reservation(Base): num_children = Column(Integer) children_ages = Column(String) # comma-separated offer = Column(String) - created_at = Column(DateTime) + created_at = Column(DateTime(timezone=True)) # Add all UTM fields and user comment for XML utm_source = Column(String) utm_medium = Column(String) @@ -127,6 +127,10 @@ class Reservation(Base): # Add hotel_code and hotel_name for XML hotel_code = Column(String) hotel_name = Column(String) + # RoomTypes fields (optional) + room_type_code = Column(String) + room_classification_code = Column(String) + room_type = Column(String) customer = relationship("Customer", back_populates="reservations") @@ -138,4 +142,4 @@ class AckedRequest(Base): unique_id = Column( String, index=True ) # Should match Reservation.form_id or another unique field - timestamp = Column(DateTime) + timestamp = Column(DateTime(timezone=True)) diff --git a/src/alpine_bits_python/logging_config.py b/src/alpine_bits_python/logging_config.py index 565a9a2..4ecfc31 100644 --- a/src/alpine_bits_python/logging_config.py +++ b/src/alpine_bits_python/logging_config.py @@ -25,7 +25,7 @@ def setup_logging( pushover_service: "PushoverService | None" = None, loop: asyncio.AbstractEventLoop | None = None, enable_scheduler: bool = True, -) -> tuple["EmailAlertHandler | None", "DailyReportScheduler | None"]: +) -> tuple[logging.Handler | None, object | None]: """Configure logging based on application config. Args: @@ -37,7 +37,7 @@ def setup_logging( (should be False for non-primary workers) Returns: - Tuple of (email_alert_handler, daily_report_scheduler) if monitoring + Tuple of (alert_handler, daily_report_scheduler) if monitoring is enabled, otherwise (None, None) Logger config format: @@ -92,88 +92,67 @@ def setup_logging( root_logger.info("Logging configured at %s level", level) - # Setup notification monitoring if configured - email_handler = None + # Setup unified notification monitoring if configured + alert_handler = None report_scheduler = None - # Setup email monitoring if configured - if email_service: - email_config = config.get("email", {}) - monitoring_config = email_config.get("monitoring", {}) - - # Setup error alert handler - error_alerts_config = monitoring_config.get("error_alerts", {}) - if error_alerts_config.get("enabled", False): - try: - # Import here to avoid circular dependencies - from alpine_bits_python.email_monitoring import EmailAlertHandler - - email_handler = EmailAlertHandler( - email_service=email_service, - config=error_alerts_config, - loop=loop, - ) - email_handler.setLevel(logging.ERROR) - root_logger.addHandler(email_handler) - root_logger.info("Email alert handler enabled for error monitoring") - except Exception: - root_logger.exception("Failed to setup email alert handler") - - # Setup daily report scheduler (only if enabled and this is primary worker) - daily_report_config = monitoring_config.get("daily_report", {}) - if daily_report_config.get("enabled", False) and enable_scheduler: - try: - # Import here to avoid circular dependencies - from alpine_bits_python.email_monitoring import DailyReportScheduler - - report_scheduler = DailyReportScheduler( - email_service=email_service, - config=daily_report_config, - ) - root_logger.info("Daily report scheduler configured (primary worker)") - except Exception: - root_logger.exception("Failed to setup daily report scheduler") - elif daily_report_config.get("enabled", False) and not enable_scheduler: - root_logger.info( - "Daily report scheduler disabled (non-primary worker)" + # Check if unified notifications are configured + notifications_config = config.get("notifications", {}) + if notifications_config and (email_service or pushover_service): + try: + # Import here to avoid circular dependencies + from alpine_bits_python.notification_manager import ( + get_notification_config, + setup_notification_service, + ) + from alpine_bits_python.unified_monitoring import ( + UnifiedAlertHandler, + UnifiedDailyReportScheduler, ) - # Check if Pushover daily reports are enabled - # If so and no report_scheduler exists yet, create one - if pushover_service and not report_scheduler: - pushover_config = config.get("pushover", {}) - pushover_monitoring = pushover_config.get("monitoring", {}) - pushover_daily_report = pushover_monitoring.get("daily_report", {}) - - if pushover_daily_report.get("enabled", False) and enable_scheduler: - try: - # Import here to avoid circular dependencies - from alpine_bits_python.email_monitoring import DailyReportScheduler - - # Create a dummy config for the scheduler - # (it doesn't need email-specific fields if email is disabled) - scheduler_config = { - "send_time": pushover_daily_report.get("send_time", "08:00"), - "include_stats": pushover_daily_report.get("include_stats", True), - "include_errors": pushover_daily_report.get("include_errors", True), - "recipients": [], # Not used for Pushover - } - - report_scheduler = DailyReportScheduler( - email_service=email_service, # Can be None - config=scheduler_config, - ) - root_logger.info( - "Daily report scheduler configured for Pushover (primary worker)" - ) - except Exception: - root_logger.exception("Failed to setup Pushover daily report scheduler") - elif pushover_daily_report.get("enabled", False) and not enable_scheduler: - root_logger.info( - "Pushover daily report scheduler disabled (non-primary worker)" + # Setup unified notification service + notification_service = setup_notification_service( + config=config, + email_service=email_service, + pushover_service=pushover_service, ) - return email_handler, report_scheduler + if notification_service: + # Setup error alert handler + error_alerts_config = get_notification_config("error_alerts", config) + if error_alerts_config.get("enabled", False): + try: + alert_handler = UnifiedAlertHandler( + notification_service=notification_service, + config=error_alerts_config, + loop=loop, + ) + alert_handler.setLevel(logging.ERROR) + root_logger.addHandler(alert_handler) + root_logger.info("Unified alert handler enabled for error monitoring") + except Exception: + root_logger.exception("Failed to setup unified alert handler") + + # Setup daily report scheduler (only if enabled and this is primary worker) + daily_report_config = get_notification_config("daily_report", config) + if daily_report_config.get("enabled", False) and enable_scheduler: + try: + report_scheduler = UnifiedDailyReportScheduler( + notification_service=notification_service, + config=daily_report_config, + ) + root_logger.info("Unified daily report scheduler configured (primary worker)") + except Exception: + root_logger.exception("Failed to setup unified daily report scheduler") + elif daily_report_config.get("enabled", False) and not enable_scheduler: + root_logger.info( + "Unified daily report scheduler disabled (non-primary worker)" + ) + + except Exception: + root_logger.exception("Failed to setup unified notification monitoring") + + return alert_handler, report_scheduler def get_logger(name: str) -> logging.Logger: diff --git a/src/alpine_bits_python/migrations.py b/src/alpine_bits_python/migrations.py new file mode 100644 index 0000000..5702e17 --- /dev/null +++ b/src/alpine_bits_python/migrations.py @@ -0,0 +1,115 @@ +"""Database migrations for AlpineBits. + +This module contains migration functions that are automatically run at app startup +to update existing database schemas without losing data. +""" + +from sqlalchemy import inspect, text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .logging_config import get_logger + +_LOGGER = get_logger(__name__) + + +async def check_column_exists(engine: AsyncEngine, table_name: str, column_name: str) -> bool: + """Check if a column exists in a table. + + Args: + engine: SQLAlchemy async engine + table_name: Name of the table to check + column_name: Name of the column to check + + Returns: + True if column exists, False otherwise + """ + async with engine.connect() as conn: + def _check(connection): + inspector = inspect(connection) + columns = [col['name'] for col in inspector.get_columns(table_name)] + return column_name in columns + + result = await conn.run_sync(_check) + return result + + +async def add_column_if_not_exists( + engine: AsyncEngine, + table_name: str, + column_name: str, + column_type: str = "VARCHAR" +) -> bool: + """Add a column to a table if it doesn't already exist. + + Args: + engine: SQLAlchemy async engine + table_name: Name of the table + column_name: Name of the column to add + column_type: SQL type of the column (default: VARCHAR) + + Returns: + True if column was added, False if it already existed + """ + exists = await check_column_exists(engine, table_name, column_name) + + if exists: + _LOGGER.debug("Column %s.%s already exists, skipping", table_name, column_name) + return False + + _LOGGER.info("Adding column %s.%s (%s)", table_name, column_name, column_type) + + async with engine.begin() as conn: + sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" + await conn.execute(text(sql)) + + _LOGGER.info("Successfully added column %s.%s", table_name, column_name) + return True + + +async def migrate_add_room_types(engine: AsyncEngine) -> None: + """Migration: Add RoomTypes fields to reservations table. + + This migration adds three optional fields: + - room_type_code: String (max 8 chars) + - room_classification_code: String (numeric pattern) + - room_type: String (enum: 1-5) + + Safe to run multiple times - will skip if columns already exist. + """ + _LOGGER.info("Running migration: add_room_types") + + added_count = 0 + + # Add each column if it doesn't exist + if await add_column_if_not_exists(engine, "reservations", "room_type_code", "VARCHAR"): + added_count += 1 + + if await add_column_if_not_exists(engine, "reservations", "room_classification_code", "VARCHAR"): + added_count += 1 + + if await add_column_if_not_exists(engine, "reservations", "room_type", "VARCHAR"): + added_count += 1 + + if added_count > 0: + _LOGGER.info("Migration add_room_types: Added %d columns", added_count) + else: + _LOGGER.info("Migration add_room_types: No changes needed (already applied)") + + +async def run_all_migrations(engine: AsyncEngine) -> None: + """Run all pending migrations. + + This function should be called at app startup, after Base.metadata.create_all. + Each migration function should be idempotent (safe to run multiple times). + """ + _LOGGER.info("Starting database migrations...") + + try: + # Add new migrations here in chronological order + await migrate_add_room_types(engine) + + _LOGGER.info("Database migrations completed successfully") + + except Exception as e: + _LOGGER.exception("Migration failed: %s", e) + raise diff --git a/src/alpine_bits_python/notification_manager.py b/src/alpine_bits_python/notification_manager.py new file mode 100644 index 0000000..090328f --- /dev/null +++ b/src/alpine_bits_python/notification_manager.py @@ -0,0 +1,156 @@ +"""Unified notification manager for setting up recipient-based notification routing. + +This module provides helpers to initialize the unified notification system +based on the recipients configuration. +""" + +from typing import Any + +from .email_service import EmailService +from .logging_config import get_logger +from .notification_adapters import EmailNotificationAdapter, PushoverNotificationAdapter +from .notification_service import NotificationService +from .pushover_service import PushoverService + +_LOGGER = get_logger(__name__) + + +def setup_notification_service( + config: dict[str, Any], + email_service: EmailService | None = None, + pushover_service: PushoverService | None = None, +) -> NotificationService | None: + """Set up unified notification service from config. + + Args: + config: Full configuration dictionary + email_service: Optional EmailService instance + pushover_service: Optional PushoverService instance + + Returns: + NotificationService instance, or None if no recipients configured + + """ + notifications_config = config.get("notifications", {}) + recipients = notifications_config.get("recipients", []) + + if not recipients: + _LOGGER.info("No notification recipients configured") + return None + + notification_service = NotificationService() + + # Process each recipient and their methods + for recipient in recipients: + recipient_name = recipient.get("name", "unknown") + methods = recipient.get("methods", []) + + for method in methods: + method_type = method.get("type") + + if method_type == "email": + if not email_service: + _LOGGER.warning( + "Email method configured for %s but email service not available", + recipient_name, + ) + continue + + email_address = method.get("address") + if not email_address: + _LOGGER.warning( + "Email method for %s missing address", recipient_name + ) + continue + + # Create a unique backend name for this recipient's email + backend_name = f"email_{recipient_name}" + + # Check if we already have an email backend + if not notification_service.has_backend("email"): + # Create email adapter with all email recipients + email_recipients = [] + for r in recipients: + for m in r.get("methods", []): + if m.get("type") == "email" and m.get("address"): + email_recipients.append(m.get("address")) + + if email_recipients: + email_adapter = EmailNotificationAdapter( + email_service, email_recipients + ) + notification_service.register_backend("email", email_adapter) + _LOGGER.info( + "Registered email backend with %d recipient(s)", + len(email_recipients), + ) + + elif method_type == "pushover": + if not pushover_service: + _LOGGER.warning( + "Pushover method configured for %s but pushover service not available", + recipient_name, + ) + continue + + priority = method.get("priority", 0) + + # Check if we already have a pushover backend + if not notification_service.has_backend("pushover"): + # Pushover sends to user_key configured in pushover service + pushover_adapter = PushoverNotificationAdapter( + pushover_service, priority + ) + notification_service.register_backend("pushover", pushover_adapter) + _LOGGER.info("Registered pushover backend with priority %d", priority) + + if not notification_service.backends: + _LOGGER.warning("No notification backends could be configured") + return None + + _LOGGER.info( + "Notification service configured with backends: %s", + list(notification_service.backends.keys()), + ) + return notification_service + + +def get_enabled_backends( + notification_type: str, config: dict[str, Any] +) -> list[str] | None: + """Get list of enabled backends for a notification type. + + Args: + notification_type: "daily_report" or "error_alerts" + config: Full configuration dictionary + + Returns: + List of backend names to use, or None for all backends + + """ + notifications_config = config.get("notifications", {}) + notification_config = notifications_config.get(notification_type, {}) + + if not notification_config.get("enabled", False): + return [] + + # Return None to indicate all backends should be used + # The NotificationService will send to all registered backends + return None + + +def get_notification_config( + notification_type: str, config: dict[str, Any] +) -> dict[str, Any]: + """Get configuration for a specific notification type. + + Args: + notification_type: "daily_report" or "error_alerts" + config: Full configuration dictionary + + Returns: + Configuration dictionary for the notification type + + """ + notifications_config = config.get("notifications", {}) + return notifications_config.get(notification_type, {}) diff --git a/src/alpine_bits_python/schemas.py b/src/alpine_bits_python/schemas.py index 4affec1..0ee730c 100644 --- a/src/alpine_bits_python/schemas.py +++ b/src/alpine_bits_python/schemas.py @@ -58,6 +58,10 @@ class ReservationData(BaseModel): utm_campaign: str | None = Field(None, max_length=150) utm_term: str | None = Field(None, max_length=150) utm_content: str | None = Field(None, max_length=150) + # RoomTypes fields (optional) + room_type_code: str | None = Field(None, min_length=1, max_length=8) + room_classification_code: str | None = Field(None, pattern=r"[0-9]+") + room_type: str | None = Field(None, pattern=r"^[1-5]$") @model_validator(mode="after") def ensure_md5(self) -> "ReservationData": diff --git a/src/alpine_bits_python/unified_monitoring.py b/src/alpine_bits_python/unified_monitoring.py new file mode 100644 index 0000000..7add00a --- /dev/null +++ b/src/alpine_bits_python/unified_monitoring.py @@ -0,0 +1,390 @@ +"""Unified monitoring with support for multiple notification backends. + +This module provides alert handlers and schedulers that work with the +unified notification service to send alerts through multiple channels. +""" + +import asyncio +import logging +import threading +from collections import deque +from datetime import datetime, timedelta +from typing import Any + +from .email_monitoring import ErrorRecord, ReservationStatsCollector +from .logging_config import get_logger +from .notification_service import NotificationService + +_LOGGER = get_logger(__name__) + + +class UnifiedAlertHandler(logging.Handler): + """Custom logging handler that sends alerts through unified notification service. + + This handler uses a hybrid approach: + - Accumulates errors in a buffer + - Sends immediately if error threshold is reached + - Otherwise sends after buffer duration expires + - Always sends buffered errors (no minimum threshold for time-based flush) + - Implements cooldown to prevent alert spam + + The handler is thread-safe and works with asyncio event loops. + """ + + def __init__( + self, + notification_service: NotificationService, + config: dict[str, Any], + loop: asyncio.AbstractEventLoop | None = None, + ): + """Initialize the unified alert handler. + + Args: + notification_service: Unified notification service + config: Configuration dictionary for error alerts + loop: Asyncio event loop (will use current loop if not provided) + + """ + super().__init__() + self.notification_service = notification_service + self.config = config + self.loop = loop # Will be set when first error occurs if not provided + + # Configuration + self.error_threshold = config.get("error_threshold", 5) + self.buffer_minutes = config.get("buffer_minutes", 15) + self.cooldown_minutes = config.get("cooldown_minutes", 15) + self.log_levels = config.get("log_levels", ["ERROR", "CRITICAL"]) + + # State + self.error_buffer: deque[ErrorRecord] = deque() + self.last_sent = datetime.min # Last time we sent an alert + self._flush_task: asyncio.Task | None = None + self._lock = threading.Lock() # Thread-safe for multi-threaded logging + + _LOGGER.info( + "UnifiedAlertHandler initialized: threshold=%d, buffer=%dmin, cooldown=%dmin", + self.error_threshold, + self.buffer_minutes, + self.cooldown_minutes, + ) + + def emit(self, record: logging.LogRecord) -> None: + """Handle a log record. + + This is called automatically by the logging system when an error is logged. + It's important that this method is fast and doesn't block. + + Args: + record: The log record to handle + + """ + # Only handle configured log levels + if record.levelname not in self.log_levels: + return + + try: + # Ensure we have an event loop + if self.loop is None: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + # No running loop, we'll need to handle this differently + _LOGGER.warning("No asyncio event loop available for alerts") + return + + # Add error to buffer (thread-safe) + with self._lock: + error_record = ErrorRecord(record) + self.error_buffer.append(error_record) + buffer_size = len(self.error_buffer) + + # Determine if we should send immediately + should_send_immediately = buffer_size >= self.error_threshold + + if should_send_immediately: + # Cancel any pending flush task + if self._flush_task and not self._flush_task.done(): + self._flush_task.cancel() + + # Schedule immediate flush + self._flush_task = asyncio.run_coroutine_threadsafe( + self._flush_buffer(immediate=True), + self.loop, + ) + # Schedule delayed flush if not already scheduled + elif not self._flush_task or self._flush_task.done(): + self._flush_task = asyncio.run_coroutine_threadsafe( + self._schedule_delayed_flush(), + self.loop, + ) + + except Exception: + # Never let the handler crash - just log and continue + _LOGGER.exception("Error in UnifiedAlertHandler.emit") + + async def _schedule_delayed_flush(self) -> None: + """Schedule a delayed buffer flush after buffer duration.""" + await asyncio.sleep(self.buffer_minutes * 60) + await self._flush_buffer(immediate=False) + + async def _flush_buffer(self, *, immediate: bool) -> None: + """Flush the error buffer and send alert. + + Args: + immediate: Whether this is an immediate flush (threshold hit) + + """ + # Check cooldown period + now = datetime.now() + time_since_last = (now - self.last_sent).total_seconds() / 60 + + if time_since_last < self.cooldown_minutes: + _LOGGER.info( + "Alert cooldown active (%.1f min remaining), buffering errors", + self.cooldown_minutes - time_since_last, + ) + # Don't clear buffer - let errors accumulate until cooldown expires + return + + # Get all buffered errors (thread-safe) + with self._lock: + if not self.error_buffer: + return + + errors = list(self.error_buffer) + self.error_buffer.clear() + + # Update last sent time + self.last_sent = now + + # Format alert + error_count = len(errors) + time_range = ( + f"{errors[0].timestamp.strftime('%H:%M:%S')} to " + f"{errors[-1].timestamp.strftime('%H:%M:%S')}" + ) + + # Determine alert type + alert_type = "Immediate Alert" if immediate else "Scheduled Alert" + if immediate: + reason = f"(threshold of {self.error_threshold} exceeded)" + else: + reason = f"({self.buffer_minutes} minute buffer)" + + title = f"AlpineBits Error {alert_type}: {error_count} errors {reason}" + + # Build message + message = f"Error Alert - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n" + message += "=" * 70 + "\n\n" + message += f"Alert Type: {alert_type}\n" + message += f"Error Count: {error_count}\n" + message += f"Time Range: {time_range}\n" + message += f"Reason: {reason}\n" + message += "\n" + "=" * 70 + "\n\n" + + # Add individual errors + message += "Errors:\n" + message += "-" * 70 + "\n\n" + for error in errors: + message += error.format_plain_text() + message += "\n" + + message += "-" * 70 + "\n" + message += f"Generated by AlpineBits Monitoring at {now.strftime('%Y-%m-%d %H:%M:%S')}\n" + + # Send through unified notification service + try: + results = await self.notification_service.send_alert( + title=title, + message=message, + backends=None, # Send to all backends + ) + + success_count = sum(1 for success in results.values() if success) + if success_count > 0: + _LOGGER.info( + "Alert sent successfully through %d/%d backend(s): %d errors", + success_count, + len(results), + error_count, + ) + else: + _LOGGER.error("Failed to send alert through any backend: %d errors", error_count) + + except Exception: + _LOGGER.exception("Exception while sending alert") + + def close(self) -> None: + """Close the handler and flush any remaining errors. + + This is called when the logging system shuts down. + """ + # Cancel any pending flush tasks + if self._flush_task and not self._flush_task.done(): + self._flush_task.cancel() + + # Flush any remaining errors immediately + if self.error_buffer and self.loop: + try: + # Check if the loop is still running + if not self.loop.is_closed(): + future = asyncio.run_coroutine_threadsafe( + self._flush_buffer(immediate=False), + self.loop, + ) + future.result(timeout=5) + else: + _LOGGER.warning( + "Event loop closed, cannot flush %d remaining errors", + len(self.error_buffer), + ) + except Exception: + _LOGGER.exception("Error flushing buffer on close") + + super().close() + + +class UnifiedDailyReportScheduler: + """Scheduler for sending daily reports through unified notification service. + + This runs as a background task and sends daily reports containing + statistics and error summaries through all configured notification backends. + """ + + def __init__( + self, + notification_service: NotificationService, + config: dict[str, Any], + ): + """Initialize the unified daily report scheduler. + + Args: + notification_service: Unified notification service + config: Configuration for daily reports + + """ + self.notification_service = notification_service + self.config = config + self.send_time = config.get("send_time", "08:00") # Default 8 AM + self.include_stats = config.get("include_stats", True) + self.include_errors = config.get("include_errors", True) + + self._task: asyncio.Task | None = None + self._stats_collector = None # Will be set by application + self._error_log: list[dict[str, Any]] = [] + + _LOGGER.info( + "UnifiedDailyReportScheduler initialized: send_time=%s", + self.send_time, + ) + + def start(self) -> None: + """Start the daily report scheduler.""" + if self._task is None or self._task.done(): + self._task = asyncio.create_task(self._run()) + _LOGGER.info("Daily report scheduler started") + + def stop(self) -> None: + """Stop the daily report scheduler.""" + if self._task and not self._task.done(): + self._task.cancel() + _LOGGER.info("Daily report scheduler stopped") + + def log_error(self, error: dict[str, Any]) -> None: + """Log an error for inclusion in daily report. + + Args: + error: Error information dictionary + + """ + self._error_log.append(error) + + async def _run(self) -> None: + """Run the daily report scheduler loop.""" + while True: + try: + # Calculate time until next report + now = datetime.now() + target_hour, target_minute = map(int, self.send_time.split(":")) + + # Calculate next send time + next_send = now.replace( + hour=target_hour, + minute=target_minute, + second=0, + microsecond=0, + ) + + # If time has passed today, schedule for tomorrow + if next_send <= now: + next_send += timedelta(days=1) + + # Calculate sleep duration + sleep_seconds = (next_send - now).total_seconds() + + _LOGGER.info( + "Next daily report scheduled for %s (in %.1f hours)", + next_send.strftime("%Y-%m-%d %H:%M:%S"), + sleep_seconds / 3600, + ) + + # Wait until send time + await asyncio.sleep(sleep_seconds) + + # Send report + await self._send_report() + + except asyncio.CancelledError: + _LOGGER.info("Daily report scheduler cancelled") + break + except Exception: + _LOGGER.exception("Error in daily report scheduler") + # Sleep a bit before retrying + await asyncio.sleep(60) + + async def _send_report(self) -> None: + """Send the daily report.""" + stats = {} + + # Collect statistics if enabled + if self.include_stats and self._stats_collector: + try: + stats = await self._stats_collector() + except Exception: + _LOGGER.exception("Error collecting statistics for daily report") + + # Get errors if enabled + errors = self._error_log.copy() if self.include_errors else None + + # Send report through unified notification service + try: + results = await self.notification_service.send_daily_report( + stats=stats, + errors=errors, + backends=None, # Send to all backends + ) + + success_count = sum(1 for success in results.values() if success) + if success_count > 0: + _LOGGER.info( + "Daily report sent successfully through %d/%d backend(s)", + success_count, + len(results), + ) + # Clear error log after successful send + self._error_log.clear() + else: + _LOGGER.error("Failed to send daily report through any backend") + + except Exception: + _LOGGER.exception("Exception while sending daily report") + + def set_stats_collector(self, collector) -> None: + """Set the statistics collector function. + + Args: + collector: Async function that returns statistics dictionary + + """ + self._stats_collector = collector diff --git a/src/alpine_bits_python/util/fix_postgres_sequences.py b/src/alpine_bits_python/util/fix_postgres_sequences.py new file mode 100644 index 0000000..d04f2f9 --- /dev/null +++ b/src/alpine_bits_python/util/fix_postgres_sequences.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +"""Fix PostgreSQL sequences and migrate datetime columns after SQLite migration. + +This script performs two operations: +1. Migrates DateTime columns to TIMESTAMP WITH TIME ZONE for timezone-aware support +2. Resets all ID sequence values to match the current maximum ID in each table + +The sequence reset is necessary because the migration script inserts records +with explicit IDs, which doesn't automatically advance PostgreSQL sequences. + +The datetime migration ensures proper handling of timezone-aware datetimes, +which is required by the application code. + +Usage: + # Using default config.yaml + uv run python -m alpine_bits_python.util.fix_postgres_sequences + + # Using a specific config file + uv run python -m alpine_bits_python.util.fix_postgres_sequences \ + --config config/postgres.yaml + + # Using DATABASE_URL environment variable + DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \ + uv run python -m alpine_bits_python.util.fix_postgres_sequences + + # Using command line argument + uv run python -m alpine_bits_python.util.fix_postgres_sequences \ + --database-url postgresql+asyncpg://user:pass@host/db +""" + +import argparse +import asyncio +import os +import sys +from pathlib import Path + +# Add parent directory to path so we can import alpine_bits_python +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import yaml +from sqlalchemy import text +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from alpine_bits_python.config_loader import load_config +from alpine_bits_python.db import get_database_url +from alpine_bits_python.logging_config import get_logger, setup_logging + +_LOGGER = get_logger(__name__) + + +async def migrate_datetime_columns(session) -> None: + """Migrate DateTime columns to TIMESTAMP WITH TIME ZONE. + + This updates the columns to properly handle timezone-aware datetimes. + """ + _LOGGER.info("\nMigrating DateTime columns to timezone-aware...") + + datetime_columns = [ + ("hashed_customers", "created_at"), + ("reservations", "created_at"), + ("acked_requests", "timestamp"), + ] + + for table_name, column_name in datetime_columns: + _LOGGER.info(f" {table_name}.{column_name}: Converting to TIMESTAMPTZ") + await session.execute( + text( + f"ALTER TABLE {table_name} " + f"ALTER COLUMN {column_name} TYPE TIMESTAMP WITH TIME ZONE" + ) + ) + + await session.commit() + _LOGGER.info("✓ DateTime columns migrated to timezone-aware") + + +async def fix_sequences(database_url: str) -> None: + """Fix PostgreSQL sequences to match current max IDs and migrate datetime columns. + + Args: + database_url: PostgreSQL database URL + + """ + _LOGGER.info("=" * 70) + _LOGGER.info("PostgreSQL Migration & Sequence Fix") + _LOGGER.info("=" * 70) + _LOGGER.info( + "Database: %s", + database_url.split("@")[-1] if "@" in database_url else database_url, + ) + _LOGGER.info("=" * 70) + + # Create engine and session + engine = create_async_engine(database_url, echo=False) + SessionMaker = async_sessionmaker(engine, expire_on_commit=False) + + try: + # Migrate datetime columns first + async with SessionMaker() as session: + await migrate_datetime_columns(session) + + # Then fix sequences + async with SessionMaker() as session: + # List of tables and their sequence names + tables = [ + ("customers", "customers_id_seq"), + ("hashed_customers", "hashed_customers_id_seq"), + ("reservations", "reservations_id_seq"), + ("acked_requests", "acked_requests_id_seq"), + ] + + _LOGGER.info("\nResetting sequences...") + for table_name, sequence_name in tables: + # Get current max ID + result = await session.execute( + text(f"SELECT MAX(id) FROM {table_name}") + ) + max_id = result.scalar() + + # Get current sequence value + result = await session.execute( + text(f"SELECT last_value FROM {sequence_name}") + ) + current_seq = result.scalar() + + if max_id is None: + _LOGGER.info(f" {table_name}: empty table, setting sequence to 1") + await session.execute( + text(f"SELECT setval('{sequence_name}', 1, false)") + ) + elif current_seq <= max_id: + new_seq = max_id + 1 + _LOGGER.info( + f" {table_name}: max_id={max_id}, " + f"old_seq={current_seq}, new_seq={new_seq}" + ) + await session.execute( + text(f"SELECT setval('{sequence_name}', {new_seq}, false)") + ) + else: + _LOGGER.info( + f" {table_name}: sequence already correct " + f"(max_id={max_id}, seq={current_seq})" + ) + + await session.commit() + + _LOGGER.info("\n" + "=" * 70) + _LOGGER.info("✓ Migration completed successfully!") + _LOGGER.info("=" * 70) + _LOGGER.info("\nChanges applied:") + _LOGGER.info(" 1. DateTime columns are now timezone-aware (TIMESTAMPTZ)") + _LOGGER.info(" 2. Sequences are reset to match current max IDs") + _LOGGER.info("\nYou can now insert new records without conflicts.") + + except Exception as e: + _LOGGER.exception("Failed to fix sequences: %s", e) + raise + + finally: + await engine.dispose() + + +async def main(): + """Run the sequence fix.""" + parser = argparse.ArgumentParser( + description="Fix PostgreSQL sequences after SQLite migration", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--database-url", + help="PostgreSQL database URL (default: from config or DATABASE_URL env var)", + ) + parser.add_argument( + "--config", + help=( + "Path to config file containing PostgreSQL database URL " + "(keeps password out of bash history)" + ), + ) + + args = parser.parse_args() + + try: + # Load config + config = load_config() + setup_logging(config) + except Exception as e: + _LOGGER.warning("Failed to load config: %s. Using defaults.", e) + config = {} + + # Determine database URL (same logic as migrate_sqlite_to_postgres) + if args.database_url: + database_url = args.database_url + elif args.config: + # Load config file manually (simpler YAML without secrets) + _LOGGER.info("Loading database config from: %s", args.config) + try: + config_path = Path(args.config) + config_text = config_path.read_text() + target_config = yaml.safe_load(config_text) + database_url = target_config["database"]["url"] + _LOGGER.info("Successfully loaded config") + except (FileNotFoundError, ValueError, KeyError): + _LOGGER.exception("Failed to load config") + _LOGGER.info( + "Config file should contain: database.url with PostgreSQL connection" + ) + sys.exit(1) + else: + database_url = os.environ.get("DATABASE_URL") + if not database_url: + # Try from default config + database_url = get_database_url(config) + + if "postgresql" not in database_url and "postgres" not in database_url: + _LOGGER.error("This script only works with PostgreSQL databases.") + url_type = database_url.split("+")[0] if "+" in database_url else "unknown" + _LOGGER.error("Current database URL type detected: %s", url_type) + _LOGGER.error("\nSpecify PostgreSQL database using one of:") + _LOGGER.error(" - --config config/postgres.yaml") + _LOGGER.error(" - DATABASE_URL environment variable") + _LOGGER.error(" - --database-url postgresql+asyncpg://user:pass@host/db") + sys.exit(1) + + # Run the fix + await fix_sequences(database_url) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/alpine_bits_python/util/migrate_add_room_types.py b/src/alpine_bits_python/util/migrate_add_room_types.py new file mode 100644 index 0000000..e9de879 --- /dev/null +++ b/src/alpine_bits_python/util/migrate_add_room_types.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +"""Migration script to add RoomTypes fields to Reservation table. + +This migration adds three optional fields to the reservations table: +- room_type_code: String (max 8 chars) +- room_classification_code: String (numeric pattern) +- room_type: String (enum: 1-5) + +This script can be run manually before starting the server, or the changes +will be applied automatically when the server starts via Base.metadata.create_all. +""" + +import asyncio +import sys +from pathlib import Path + +# Add parent directory to path so we can import alpine_bits_python +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from sqlalchemy import inspect, text +from sqlalchemy.ext.asyncio import create_async_engine + +from alpine_bits_python.config_loader import load_config +from alpine_bits_python.db import get_database_url +from alpine_bits_python.logging_config import get_logger, setup_logging + +_LOGGER = get_logger(__name__) + + +async def check_columns_exist(engine, table_name: str, columns: list[str]) -> dict[str, bool]: + """Check which columns exist in the table. + + Returns a dict mapping column name to whether it exists. + """ + async with engine.connect() as conn: + def _check(connection): + inspector = inspect(connection) + existing_cols = [col['name'] for col in inspector.get_columns(table_name)] + return {col: col in existing_cols for col in columns} + + result = await conn.run_sync(_check) + return result + + +async def add_room_types_columns(engine): + """Add RoomTypes columns to reservations table if they don't exist.""" + from alpine_bits_python.db import Base + + table_name = "reservations" + columns_to_add = ["room_type_code", "room_classification_code", "room_type"] + + # First, ensure the table exists by creating all tables if needed + _LOGGER.info("Ensuring database tables exist...") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + _LOGGER.info("Database tables checked/created.") + + _LOGGER.info("Checking which columns need to be added to %s table...", table_name) + + # Check which columns already exist + columns_exist = await check_columns_exist(engine, table_name, columns_to_add) + + columns_to_create = [col for col, exists in columns_exist.items() if not exists] + + if not columns_to_create: + _LOGGER.info("All RoomTypes columns already exist in %s table. No migration needed.", table_name) + return + + _LOGGER.info("Adding columns to %s table: %s", table_name, ", ".join(columns_to_create)) + + # Build ALTER TABLE statements for missing columns + # Note: SQLite supports ALTER TABLE ADD COLUMN but not ADD MULTIPLE COLUMNS + async with engine.begin() as conn: + for column in columns_to_create: + sql = f"ALTER TABLE {table_name} ADD COLUMN {column} VARCHAR" + _LOGGER.info("Executing: %s", sql) + await conn.execute(text(sql)) + + _LOGGER.info("Successfully added %d columns to %s table", len(columns_to_create), table_name) + + +async def main(): + """Run the migration.""" + try: + # Load config + config = load_config() + setup_logging(config) + except Exception as e: + _LOGGER.warning("Failed to load config: %s. Using defaults.", e) + config = {} + + _LOGGER.info("=" * 60) + _LOGGER.info("Starting RoomTypes Migration") + _LOGGER.info("=" * 60) + + # Get database URL + database_url = get_database_url(config) + _LOGGER.info("Database URL: %s", database_url.replace("://", "://***:***@").split("@")[-1]) + + # Create engine + engine = create_async_engine(database_url, echo=False) + + try: + # Run migration + await add_room_types_columns(engine) + + _LOGGER.info("=" * 60) + _LOGGER.info("Migration completed successfully!") + _LOGGER.info("=" * 60) + + except Exception as e: + _LOGGER.exception("Migration failed: %s", e) + sys.exit(1) + finally: + await engine.dispose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/alpine_bits_python/util/migrate_sqlite_to_postgres.py b/src/alpine_bits_python/util/migrate_sqlite_to_postgres.py new file mode 100644 index 0000000..5eefc1b --- /dev/null +++ b/src/alpine_bits_python/util/migrate_sqlite_to_postgres.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 +"""Migration script to copy data from SQLite to PostgreSQL. + +This script: +1. Connects to both SQLite and PostgreSQL databases +2. Reads all data from SQLite using SQLAlchemy models +3. Writes data to PostgreSQL using the same models +4. Ensures data integrity and provides progress feedback + +Prerequisites: +- PostgreSQL database must be created and empty (or you can use --drop-tables flag) +- asyncpg must be installed: uv pip install asyncpg +- Configure target PostgreSQL URL in config.yaml or via DATABASE_URL env var + +Usage: + # Dry run (preview what will be migrated) + uv run python -m alpine_bits_python.util.migrate_sqlite_to_postgres --dry-run + + # Actual migration using target config file + uv run python -m alpine_bits_python.util.migrate_sqlite_to_postgres \ + --target-config config/postgres.yaml + + # Drop existing tables first (careful!) + uv run python -m alpine_bits_python.util.migrate_sqlite_to_postgres \ + --target-config config/postgres.yaml --drop-tables + + # Alternative: use DATABASE_URL environment variable + DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \ + uv run python -m alpine_bits_python.util.migrate_sqlite_to_postgres + + # Alternative: specify URLs directly + uv run python -m alpine_bits_python.util.migrate_sqlite_to_postgres \ + --source sqlite+aiosqlite:///old.db \ + --target postgresql+asyncpg://user:pass@localhost/dbname +""" + +import argparse +import asyncio +import sys +from pathlib import Path + +# Add parent directory to path so we can import alpine_bits_python +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import yaml +from sqlalchemy import select, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from alpine_bits_python.config_loader import load_config +from alpine_bits_python.db import ( + AckedRequest, + Base, + Customer, + HashedCustomer, + Reservation, + get_database_url, +) +from alpine_bits_python.logging_config import get_logger, setup_logging + +_LOGGER = get_logger(__name__) + + +def mask_db_url(url: str) -> str: + """Mask sensitive parts of database URL for logging.""" + if "://" not in url: + return url + protocol, rest = url.split("://", 1) + if "@" in rest: + credentials, location = rest.split("@", 1) + return f"{protocol}://***:***@{location}" + return url + + +async def get_table_counts(session: AsyncSession) -> dict[str, int]: + """Get row counts for all tables.""" + counts = {} + + # Count customers + result = await session.execute(select(Customer)) + counts["customers"] = len(result.scalars().all()) + + # Count hashed_customers + result = await session.execute(select(HashedCustomer)) + counts["hashed_customers"] = len(result.scalars().all()) + + # Count reservations + result = await session.execute(select(Reservation)) + counts["reservations"] = len(result.scalars().all()) + + # Count acked_requests + result = await session.execute(select(AckedRequest)) + counts["acked_requests"] = len(result.scalars().all()) + + return counts + + +async def reset_sequences(session: AsyncSession) -> None: + """Reset PostgreSQL sequences to match the current max ID values. + + This is necessary after migrating data with explicit IDs from SQLite, + as PostgreSQL sequences won't automatically advance when IDs are set explicitly. + """ + tables = [ + ("customers", "customers_id_seq"), + ("hashed_customers", "hashed_customers_id_seq"), + ("reservations", "reservations_id_seq"), + ("acked_requests", "acked_requests_id_seq"), + ] + + for table_name, sequence_name in tables: + # Set sequence to max(id) + 1, or 1 if table is empty + query = text(f""" + SELECT setval('{sequence_name}', + COALESCE((SELECT MAX(id) FROM {table_name}), 0) + 1, + false) + """) + await session.execute(query) + + await session.commit() + + +async def migrate_data( + source_url: str, + target_url: str, + dry_run: bool = False, + drop_tables: bool = False, +) -> None: + """Migrate data from source database to target database. + + Args: + source_url: Source database URL (SQLite) + target_url: Target database URL (PostgreSQL) + dry_run: If True, only preview what would be migrated + drop_tables: If True, drop existing tables in target before creating + """ + _LOGGER.info("=" * 70) + _LOGGER.info("SQLite to PostgreSQL Migration") + _LOGGER.info("=" * 70) + _LOGGER.info("Source: %s", mask_db_url(source_url)) + _LOGGER.info("Target: %s", mask_db_url(target_url)) + _LOGGER.info("Mode: %s", "DRY RUN" if dry_run else "LIVE MIGRATION") + _LOGGER.info("=" * 70) + + # Create engines + _LOGGER.info("Creating database connections...") + source_engine = create_async_engine(source_url, echo=False) + target_engine = create_async_engine(target_url, echo=False) + + # Create session makers + SourceSession = async_sessionmaker(source_engine, expire_on_commit=False) + TargetSession = async_sessionmaker(target_engine, expire_on_commit=False) + + try: + # Check source database + _LOGGER.info("\nChecking source database...") + async with SourceSession() as source_session: + source_counts = await get_table_counts(source_session) + + _LOGGER.info("Source database contains:") + for table, count in source_counts.items(): + _LOGGER.info(" - %s: %d rows", table, count) + + total_rows = sum(source_counts.values()) + if total_rows == 0: + _LOGGER.warning("Source database is empty. Nothing to migrate.") + return + + if dry_run: + _LOGGER.info("\n" + "=" * 70) + _LOGGER.info("DRY RUN: Would migrate %d total rows", total_rows) + _LOGGER.info("=" * 70) + return + + # Prepare target database + _LOGGER.info("\nPreparing target database...") + + if drop_tables: + _LOGGER.warning("Dropping existing tables in target database...") + async with target_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + _LOGGER.info("Tables dropped.") + + _LOGGER.info("Creating tables in target database...") + async with target_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + _LOGGER.info("Tables created.") + + # Check if target already has data + _LOGGER.info("\nChecking target database...") + async with TargetSession() as target_session: + target_counts = await get_table_counts(target_session) + + if sum(target_counts.values()) > 0: + _LOGGER.warning("Target database is not empty:") + for table, count in target_counts.items(): + if count > 0: + _LOGGER.warning(" - %s: %d rows", table, count) + + response = input("\nContinue anyway? This may cause conflicts. (yes/no): ") + if response.lower() != "yes": + _LOGGER.info("Migration cancelled.") + return + + # Migrate data table by table + _LOGGER.info("\n" + "=" * 70) + _LOGGER.info("Starting data migration...") + _LOGGER.info("=" * 70) + + # 1. Migrate Customers first (no dependencies) + _LOGGER.info("\n[1/4] Migrating Customers...") + async with SourceSession() as source_session: + result = await source_session.execute(select(Customer)) + customers = result.scalars().all() + + if customers: + async with TargetSession() as target_session: + for i, customer in enumerate(customers, 1): + # Create new instance with same data + new_customer = Customer( + id=customer.id, + given_name=customer.given_name, + contact_id=customer.contact_id, + surname=customer.surname, + name_prefix=customer.name_prefix, + email_address=customer.email_address, + phone=customer.phone, + email_newsletter=customer.email_newsletter, + address_line=customer.address_line, + city_name=customer.city_name, + postal_code=customer.postal_code, + country_code=customer.country_code, + gender=customer.gender, + birth_date=customer.birth_date, + language=customer.language, + address_catalog=customer.address_catalog, + name_title=customer.name_title, + ) + target_session.add(new_customer) + + if i % 100 == 0: + _LOGGER.info(" Progress: %d/%d customers", i, len(customers)) + + await target_session.commit() + + _LOGGER.info("✓ Migrated %d customers", len(customers)) + + # 2. Migrate HashedCustomers (depends on Customers) + _LOGGER.info("\n[2/4] Migrating HashedCustomers...") + async with SourceSession() as source_session: + result = await source_session.execute(select(HashedCustomer)) + hashed_customers = result.scalars().all() + + if hashed_customers: + async with TargetSession() as target_session: + for i, hashed in enumerate(hashed_customers, 1): + new_hashed = HashedCustomer( + id=hashed.id, + customer_id=hashed.customer_id, + contact_id=hashed.contact_id, + hashed_email=hashed.hashed_email, + hashed_phone=hashed.hashed_phone, + hashed_given_name=hashed.hashed_given_name, + hashed_surname=hashed.hashed_surname, + hashed_city=hashed.hashed_city, + hashed_postal_code=hashed.hashed_postal_code, + hashed_country_code=hashed.hashed_country_code, + hashed_gender=hashed.hashed_gender, + hashed_birth_date=hashed.hashed_birth_date, + created_at=hashed.created_at, + ) + target_session.add(new_hashed) + + if i % 100 == 0: + _LOGGER.info(" Progress: %d/%d hashed customers", i, len(hashed_customers)) + + await target_session.commit() + + _LOGGER.info("✓ Migrated %d hashed customers", len(hashed_customers)) + + # 3. Migrate Reservations (depends on Customers) + _LOGGER.info("\n[3/4] Migrating Reservations...") + async with SourceSession() as source_session: + result = await source_session.execute(select(Reservation)) + reservations = result.scalars().all() + + if reservations: + async with TargetSession() as target_session: + for i, reservation in enumerate(reservations, 1): + new_reservation = Reservation( + id=reservation.id, + customer_id=reservation.customer_id, + unique_id=reservation.unique_id, + md5_unique_id=reservation.md5_unique_id, + start_date=reservation.start_date, + end_date=reservation.end_date, + num_adults=reservation.num_adults, + num_children=reservation.num_children, + children_ages=reservation.children_ages, + offer=reservation.offer, + created_at=reservation.created_at, + utm_source=reservation.utm_source, + utm_medium=reservation.utm_medium, + utm_campaign=reservation.utm_campaign, + utm_term=reservation.utm_term, + utm_content=reservation.utm_content, + user_comment=reservation.user_comment, + fbclid=reservation.fbclid, + gclid=reservation.gclid, + hotel_code=reservation.hotel_code, + hotel_name=reservation.hotel_name, + room_type_code=reservation.room_type_code, + room_classification_code=reservation.room_classification_code, + room_type=reservation.room_type, + ) + target_session.add(new_reservation) + + if i % 100 == 0: + _LOGGER.info(" Progress: %d/%d reservations", i, len(reservations)) + + await target_session.commit() + + _LOGGER.info("✓ Migrated %d reservations", len(reservations)) + + # 4. Migrate AckedRequests (no dependencies) + _LOGGER.info("\n[4/4] Migrating AckedRequests...") + async with SourceSession() as source_session: + result = await source_session.execute(select(AckedRequest)) + acked_requests = result.scalars().all() + + if acked_requests: + async with TargetSession() as target_session: + for i, acked in enumerate(acked_requests, 1): + new_acked = AckedRequest( + id=acked.id, + client_id=acked.client_id, + unique_id=acked.unique_id, + timestamp=acked.timestamp, + ) + target_session.add(new_acked) + + if i % 100 == 0: + _LOGGER.info(" Progress: %d/%d acked requests", i, len(acked_requests)) + + await target_session.commit() + + _LOGGER.info("✓ Migrated %d acked requests", len(acked_requests)) + + # Migrate datetime columns to timezone-aware + _LOGGER.info("\n[5/6] Converting DateTime columns to timezone-aware...") + async with target_engine.begin() as conn: + await conn.execute( + text( + "ALTER TABLE hashed_customers " + "ALTER COLUMN created_at TYPE TIMESTAMP WITH TIME ZONE" + ) + ) + await conn.execute( + text( + "ALTER TABLE reservations " + "ALTER COLUMN created_at TYPE TIMESTAMP WITH TIME ZONE" + ) + ) + await conn.execute( + text( + "ALTER TABLE acked_requests " + "ALTER COLUMN timestamp TYPE TIMESTAMP WITH TIME ZONE" + ) + ) + _LOGGER.info("✓ DateTime columns converted to timezone-aware") + + # Reset PostgreSQL sequences + _LOGGER.info("\n[6/6] Resetting PostgreSQL sequences...") + async with TargetSession() as target_session: + await reset_sequences(target_session) + _LOGGER.info("✓ Sequences reset to match current max IDs") + + # Verify migration + _LOGGER.info("\n" + "=" * 70) + _LOGGER.info("Verifying migration...") + _LOGGER.info("=" * 70) + + async with TargetSession() as target_session: + final_counts = await get_table_counts(target_session) + + _LOGGER.info("Target database now contains:") + all_match = True + for table, count in final_counts.items(): + source_count = source_counts[table] + match = "✓" if count == source_count else "✗" + _LOGGER.info(" %s %s: %d rows (source: %d)", match, table, count, source_count) + if count != source_count: + all_match = False + + if all_match: + _LOGGER.info("\n" + "=" * 70) + _LOGGER.info("✓ Migration completed successfully!") + _LOGGER.info("=" * 70) + _LOGGER.info("\nNext steps:") + _LOGGER.info("1. Test your application with PostgreSQL") + _LOGGER.info("2. Update config.yaml or DATABASE_URL to use PostgreSQL") + _LOGGER.info("3. Keep SQLite backup until you're confident everything works") + else: + _LOGGER.error("\n" + "=" * 70) + _LOGGER.error("✗ Migration completed with mismatches!") + _LOGGER.error("=" * 70) + _LOGGER.error("Please review the counts above and investigate.") + + except Exception as e: + _LOGGER.exception("Migration failed: %s", e) + raise + + finally: + await source_engine.dispose() + await target_engine.dispose() + + +async def main(): + """Run the migration.""" + parser = argparse.ArgumentParser( + description="Migrate data from SQLite to PostgreSQL", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--source", + help="Source database URL (default: from config or sqlite+aiosqlite:///alpinebits.db)", + ) + parser.add_argument( + "--target", + help=( + "Target database URL " + "(default: from DATABASE_URL env var or --target-config)" + ), + ) + parser.add_argument( + "--target-config", + help=( + "Path to config file containing target PostgreSQL database URL " + "(keeps password out of bash history)" + ), + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Preview migration without making changes", + ) + parser.add_argument( + "--drop-tables", + action="store_true", + help="Drop existing tables in target database before migration", + ) + + args = parser.parse_args() + + try: + # Load config + config = load_config() + setup_logging(config) + except Exception as e: + _LOGGER.warning("Failed to load config: %s. Using defaults.", e) + config = {} + + # Determine source URL (default to SQLite) + if args.source: + source_url = args.source + else: + source_url = get_database_url(config) + if "sqlite" not in source_url: + _LOGGER.error("Source database must be SQLite. Use --source to specify.") + sys.exit(1) + + # Determine target URL (must be PostgreSQL) + if args.target: + target_url = args.target + elif args.target_config: + # Load target config file manually (simpler YAML without secrets) + _LOGGER.info("Loading target database config from: %s", args.target_config) + try: + config_path = Path(args.target_config) + with config_path.open() as f: + target_config = yaml.safe_load(f) + target_url = target_config["database"]["url"] + _LOGGER.info("Successfully loaded target config") + except (FileNotFoundError, ValueError, KeyError): + _LOGGER.exception("Failed to load target config") + _LOGGER.info( + "Config file should contain: database.url with PostgreSQL connection" + ) + sys.exit(1) + else: + import os + target_url = os.environ.get("DATABASE_URL") + if not target_url: + _LOGGER.error("Target database URL not specified.") + _LOGGER.error("Specify target database using one of:") + _LOGGER.error(" - --target-config config/postgres.yaml") + _LOGGER.error(" - DATABASE_URL environment variable") + _LOGGER.error(" - --target postgresql+asyncpg://user:pass@host/db") + sys.exit(1) + + if "postgresql" not in target_url and "postgres" not in target_url: + _LOGGER.error("Target database must be PostgreSQL.") + sys.exit(1) + + # Run migration + await migrate_data( + source_url=source_url, + target_url=target_url, + dry_run=args.dry_run, + drop_tables=args.drop_tables, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/alpine_bits_python/worker_coordination.py b/src/alpine_bits_python/worker_coordination.py index a30b4b1..199aa7d 100644 --- a/src/alpine_bits_python/worker_coordination.py +++ b/src/alpine_bits_python/worker_coordination.py @@ -21,12 +21,32 @@ class WorkerLock: services like schedulers only run on one worker. """ - def __init__(self, lock_file: str = "/tmp/alpinebits_primary_worker.lock"): + def __init__(self, lock_file: str | None = None): """Initialize the worker lock. Args: - lock_file: Path to the lock file + lock_file: Path to the lock file. If None, will try /var/run first, + falling back to /tmp if /var/run is not writable. """ + if lock_file is None: + # Try /var/run first (more persistent), fall back to /tmp + for candidate in ["/var/run/alpinebits_primary_worker.lock", + "/tmp/alpinebits_primary_worker.lock"]: + try: + candidate_path = Path(candidate) + candidate_path.parent.mkdir(parents=True, exist_ok=True) + # Test if we can write to this location + test_file = candidate_path.parent / ".alpinebits_test" + test_file.touch() + test_file.unlink() + lock_file = candidate + break + except (PermissionError, OSError): + continue + else: + # If all fail, default to /tmp + lock_file = "/tmp/alpinebits_primary_worker.lock" + self.lock_file = Path(lock_file) self.lock_fd = None self.is_primary = False @@ -107,6 +127,7 @@ def is_primary_worker() -> tuple[bool, WorkerLock | None]: """Determine if this worker should run singleton services. Uses file-based locking to coordinate between workers. + Includes stale lock detection and cleanup. Returns: Tuple of (is_primary, lock_object) @@ -114,6 +135,31 @@ def is_primary_worker() -> tuple[bool, WorkerLock | None]: - lock_object: WorkerLock instance (must be kept alive) """ lock = WorkerLock() + + # Check for stale locks from dead processes + if lock.lock_file.exists(): + try: + with open(lock.lock_file, 'r') as f: + old_pid_str = f.read().strip() + if old_pid_str: + old_pid = int(old_pid_str) + # Check if the process with this PID still exists + try: + os.kill(old_pid, 0) # Signal 0 just checks existence + _LOGGER.debug("Lock held by active process pid=%d", old_pid) + except ProcessLookupError: + # Process is dead, remove stale lock + _LOGGER.warning( + "Removing stale lock file from dead process pid=%d", + old_pid + ) + try: + lock.lock_file.unlink() + except Exception as e: + _LOGGER.warning("Failed to remove stale lock: %s", e) + except (ValueError, FileNotFoundError, PermissionError) as e: + _LOGGER.warning("Error checking lock file: %s", e) + is_primary = lock.acquire() return is_primary, lock diff --git a/test_migration.db b/test_migration.db new file mode 100644 index 0000000..512a004 Binary files /dev/null and b/test_migration.db differ diff --git a/tests/test_alpine_bits_server_read.py b/tests/test_alpine_bits_server_read.py index 4f193ed..6b36243 100644 --- a/tests/test_alpine_bits_server_read.py +++ b/tests/test_alpine_bits_server_read.py @@ -16,14 +16,12 @@ from xsdata_pydantic.bindings import XmlParser, XmlSerializer from alpine_bits_python.alpine_bits_helpers import create_res_retrieve_response from alpine_bits_python.alpinebits_server import AlpineBitsClientInfo, AlpineBitsServer +from alpine_bits_python.const import HttpStatusCode from alpine_bits_python.db import AckedRequest, Base, Customer, Reservation from alpine_bits_python.generated import OtaReadRq from alpine_bits_python.generated.alpinebits import OtaResRetrieveRs from alpine_bits_python.schemas import ReservationData -# HTTP status code constants -HTTP_OK = 200 - @pytest_asyncio.fixture async def test_db_engine(): @@ -558,7 +556,7 @@ class TestAcknowledgments: ) assert response is not None - assert response.status_code == HTTP_OK + assert response.status_code == HttpStatusCode.OK assert response.xml_content is not None # Verify response contains reservation data @@ -609,7 +607,7 @@ class TestAcknowledgments: ) assert ack_response is not None - assert ack_response.status_code == HTTP_OK + assert ack_response.status_code == HttpStatusCode.OK assert "OTA_NotifReportRS" in ack_response.xml_content @pytest.mark.asyncio @@ -920,7 +918,7 @@ class TestAcknowledgments: ) assert response is not None - assert response.status_code == HTTP_OK + assert response.status_code == HttpStatusCode.OK # Parse response to verify both reservations are returned parser = XmlParser() diff --git a/tests/test_alpinebits_server_ping.py b/tests/test_alpinebits_server_ping.py index 3ad8009..6d339aa 100644 --- a/tests/test_alpinebits_server_ping.py +++ b/tests/test_alpinebits_server_ping.py @@ -4,6 +4,7 @@ import pytest from xsdata_pydantic.bindings import XmlParser from alpine_bits_python.alpinebits_server import AlpineBitsClientInfo, AlpineBitsServer +from alpine_bits_python.const import HttpStatusCode from alpine_bits_python.generated.alpinebits import OtaPingRs @@ -60,7 +61,7 @@ async def test_ping_action_response_success(): client_info=client_info, version="2024-10", ) - assert response.status_code == 200 + assert response.status_code == HttpStatusCode.OK assert "", ) - assert response.status_code == 401 + assert response.status_code == HttpStatusCode.UNAUTHORIZED def test_xml_upload_invalid_path(self, client, basic_auth_headers): """Test XML upload with path traversal attempt. @@ -805,7 +803,7 @@ class TestAuthentication: ) # Should not be 401 - assert response.status_code != 401 + assert response.status_code != HttpStatusCode.UNAUTHORIZED def test_basic_auth_missing_credentials(self, client): """Test basic auth with missing credentials.""" @@ -814,7 +812,7 @@ class TestAuthentication: data={"action": "OTA_Ping:Handshaking"}, ) - assert response.status_code == 401 + assert response.status_code == HttpStatusCode.UNAUTHORIZED def test_basic_auth_malformed_header(self, client): """Test basic auth with malformed Authorization header.""" @@ -839,7 +837,7 @@ class TestEventDispatcher: # The async task runs in background and doesn't affect response response = client.post("/api/webhook/wix-form", json=sample_wix_form_data) - assert response.status_code == 200 + assert response.status_code == HttpStatusCode.OK # Event dispatcher is tested separately in its own test suite @@ -902,7 +900,7 @@ class TestCORS: # TestClient returns 400 for OPTIONS requests # In production, CORS middleware handles preflight correctly - assert response.status_code in [200, 400, 405] + assert response.status_code in [HttpStatusCode.OK, 400, 405] class TestRateLimiting: @@ -917,7 +915,7 @@ class TestRateLimiting: responses.append(response.status_code) # All should succeed if under limit - assert all(status == 200 for status in responses) + assert all(status == HttpStatusCode.OK for status in responses) if __name__ == "__main__": diff --git a/uv.lock b/uv.lock index 5f02cd1..8ced167 100644 --- a/uv.lock +++ b/uv.lock @@ -21,6 +21,7 @@ source = { editable = "." } dependencies = [ { name = "aiosqlite" }, { name = "annotatedyaml" }, + { name = "asyncpg" }, { name = "dotenv" }, { name = "fast-langdetect" }, { name = "fastapi" }, @@ -50,6 +51,7 @@ dev = [ requires-dist = [ { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "annotatedyaml", specifier = ">=1.0.0" }, + { name = "asyncpg", specifier = ">=0.30.0" }, { name = "dotenv", specifier = ">=0.9.9" }, { name = "fast-langdetect", specifier = ">=1.0.0" }, { name = "fastapi", specifier = ">=0.117.1" }, @@ -135,6 +137,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, ] +[[package]] +name = "asyncpg" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/4c/7c991e080e106d854809030d8584e15b2e996e26f16aee6d757e387bc17d/asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851", size = 957746, upload-time = "2024-10-20T00:30:41.127Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/22/e20602e1218dc07692acf70d5b902be820168d6282e69ef0d3cb920dc36f/asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70", size = 670373, upload-time = "2024-10-20T00:29:55.165Z" }, + { url = "https://files.pythonhosted.org/packages/3d/b3/0cf269a9d647852a95c06eb00b815d0b95a4eb4b55aa2d6ba680971733b9/asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3", size = 634745, upload-time = "2024-10-20T00:29:57.14Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6d/a4f31bf358ce8491d2a31bfe0d7bcf25269e80481e49de4d8616c4295a34/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33", size = 3512103, upload-time = "2024-10-20T00:29:58.499Z" }, + { url = "https://files.pythonhosted.org/packages/96/19/139227a6e67f407b9c386cb594d9628c6c78c9024f26df87c912fabd4368/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4", size = 3592471, upload-time = "2024-10-20T00:30:00.354Z" }, + { url = "https://files.pythonhosted.org/packages/67/e4/ab3ca38f628f53f0fd28d3ff20edff1c975dd1cb22482e0061916b4b9a74/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4", size = 3496253, upload-time = "2024-10-20T00:30:02.794Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5f/0bf65511d4eeac3a1f41c54034a492515a707c6edbc642174ae79034d3ba/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba", size = 3662720, upload-time = "2024-10-20T00:30:04.501Z" }, + { url = "https://files.pythonhosted.org/packages/e7/31/1513d5a6412b98052c3ed9158d783b1e09d0910f51fbe0e05f56cc370bc4/asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590", size = 560404, upload-time = "2024-10-20T00:30:06.537Z" }, + { url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623, upload-time = "2024-10-20T00:30:09.024Z" }, +] + [[package]] name = "certifi" version = "2025.8.3"