diff --git a/src/alpine_bits_python/api.py b/src/alpine_bits_python/api.py index 78a1381..eb124cd 100644 --- a/src/alpine_bits_python/api.py +++ b/src/alpine_bits_python/api.py @@ -42,7 +42,7 @@ from .config_loader import load_config from .const import CONF_GOOGLE_ACCOUNT, CONF_HOTEL_ID, CONF_META_ACCOUNT, HttpStatusCode from .conversion_service import ConversionService from .customer_service import CustomerService -from .db import Base, create_database_engine +from .db import Base, ResilientAsyncSession, create_database_engine from .db import Customer as DBCustomer from .db import Reservation as DBReservation from .email_monitoring import ReservationStatsCollector @@ -291,8 +291,12 @@ async def lifespan(app: FastAPI): engine = create_database_engine(config=config, echo=False) AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False) + # Create resilient session wrapper for automatic connection recovery + resilient_session = ResilientAsyncSession(AsyncSessionLocal, engine) + app.state.engine = engine app.state.async_sessionmaker = AsyncSessionLocal + app.state.resilient_session = resilient_session app.state.config = config app.state.alpine_bits_server = AlpineBitsServer(config) app.state.event_dispatcher = event_dispatcher @@ -394,11 +398,25 @@ async def lifespan(app: FastAPI): async def get_async_session(request: Request): + """Get a database session with automatic connection recovery. + + This dependency provides an async session that will automatically + retry on connection errors, disposing the pool and reconnecting. + """ async_sessionmaker = request.app.state.async_sessionmaker async with async_sessionmaker() as session: yield session +def get_resilient_session(request: Request) -> ResilientAsyncSession: + """Get the resilient session manager from app state. + + This provides access to the ResilientAsyncSession for use in handlers + that need retry capability on connection errors. + """ + return request.app.state.resilient_session + + app = FastAPI( title="Wix Form Handler API", description="Secure API endpoint to receive and process Wix form submissions with authentication and rate limiting", diff --git a/src/alpine_bits_python/db.py b/src/alpine_bits_python/db.py index a1415de..b9767fa 100644 --- a/src/alpine_bits_python/db.py +++ b/src/alpine_bits_python/db.py @@ -1,8 +1,11 @@ +import asyncio import hashlib import os +from typing import Any, AsyncGenerator, Callable, TypeVar from sqlalchemy import Boolean, Column, Date, DateTime, ForeignKey, Integer, String -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy.exc import DBAPIError, InternalServerError +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.orm import declarative_base, relationship from .logging_config import get_logger @@ -11,6 +14,14 @@ _LOGGER = get_logger(__name__) Base = declarative_base() +# Type variable for async functions +T = TypeVar("T") + +# Maximum number of retries for session operations +MAX_RETRIES = 3 +# Delay between retries in seconds +RETRY_DELAY = 0.5 + # Async SQLAlchemy setup def get_database_url(config=None): @@ -92,6 +103,108 @@ def create_database_engine(config=None, echo=False) -> AsyncEngine: return create_async_engine(database_url, echo=echo, connect_args=connect_args) +class ResilientAsyncSession: + """Wrapper around AsyncSession that handles connection recovery. + + This wrapper automatically retries operations on connection loss or OID errors, + disposing the connection pool and creating a fresh session on failure. + """ + + def __init__( + self, + async_sessionmaker_: async_sessionmaker[AsyncSession], + engine: AsyncEngine, + ): + """Initialize the resilient session wrapper. + + Args: + async_sessionmaker_: Factory for creating async sessions + engine: The SQLAlchemy async engine for connection recovery + """ + self.async_sessionmaker = async_sessionmaker_ + self.engine = engine + + async def execute_with_retry( + self, func: Callable[..., T], *args, **kwargs + ) -> T: + """Execute a function with automatic retry on connection errors. + + Args: + func: Async function that takes a session as first argument + *args: Positional arguments to pass to func (first arg should be session) + **kwargs: Keyword arguments to pass to func + + Returns: + Result of the function call + + Raises: + The original exception if all retries are exhausted + """ + last_error = None + + for attempt in range(MAX_RETRIES): + try: + async with self.async_sessionmaker() as session: + return await func(session, *args, **kwargs) + except (InternalServerError, DBAPIError) as e: + last_error = e + error_msg = str(e).lower() + + # Check if this is an OID error or connection loss + if ( + "could not open relation" in error_msg + or "lost connection" in error_msg + or "connection closed" in error_msg + or "connection refused" in error_msg + ): + _LOGGER.warning( + "Connection error on attempt %d/%d: %s. Disposing pool and retrying...", + attempt + 1, + MAX_RETRIES, + e.__class__.__name__, + ) + + # Dispose the entire connection pool to force new connections + await self.engine.dispose() + + # Wait before retry (exponential backoff) + if attempt < MAX_RETRIES - 1: + wait_time = RETRY_DELAY * (2 ** attempt) + await asyncio.sleep(wait_time) + else: + # Not a connection-related error, re-raise immediately + raise + except Exception: + # Any other exception, re-raise immediately + raise + + # All retries exhausted + _LOGGER.error( + "Failed to execute query after %d retries: %s", + MAX_RETRIES, + last_error.__class__.__name__, + ) + raise last_error + + +async def get_resilient_session( + resilient_session: "ResilientAsyncSession", +) -> AsyncGenerator[AsyncSession, None]: + """Dependency for FastAPI that provides a resilient async session. + + This generator creates a new session with automatic retry capability + on connection errors. Used as a dependency in FastAPI endpoints. + + Args: + resilient_session: ResilientAsyncSession instance from app state + + Yields: + AsyncSession instance for database operations + """ + async with resilient_session.async_sessionmaker() as session: + yield session + + class Customer(Base): __tablename__ = "customers" id = Column(Integer, primary_key=True)