From 1f7649fffe8771bd16d29089d5700b3e936d3a54 Mon Sep 17 00:00:00 2001 From: Jonas Linter <{email_address}> Date: Tue, 4 Nov 2025 09:36:22 +0100 Subject: [PATCH] Updated fix sequences scripts --- .../util/fix_postgres_sequences.py | 73 ++++++++++++++----- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/src/alpine_bits_python/util/fix_postgres_sequences.py b/src/alpine_bits_python/util/fix_postgres_sequences.py index d04f2f9..a864deb 100644 --- a/src/alpine_bits_python/util/fix_postgres_sequences.py +++ b/src/alpine_bits_python/util/fix_postgres_sequences.py @@ -11,19 +11,24 @@ with explicit IDs, which doesn't automatically advance PostgreSQL sequences. The datetime migration ensures proper handling of timezone-aware datetimes, which is required by the application code. +Schema Support: + The script automatically detects and uses the schema configured in your config file. + If you have database.schema: "alpinebits" in your config, it will work with that schema. + Usage: - # Using default config.yaml + # Using default config.yaml (includes schema if configured) uv run python -m alpine_bits_python.util.fix_postgres_sequences - # Using a specific config file + # Using a specific config file (with schema support) uv run python -m alpine_bits_python.util.fix_postgres_sequences \ --config config/postgres.yaml - # Using DATABASE_URL environment variable + # Using DATABASE_URL environment variable (schema from config or DATABASE_SCHEMA env var) DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \ + DATABASE_SCHEMA="alpinebits" \ uv run python -m alpine_bits_python.util.fix_postgres_sequences - # Using command line argument + # Using command line argument (schema from config) uv run python -m alpine_bits_python.util.fix_postgres_sequences \ --database-url postgresql+asyncpg://user:pass@host/db """ @@ -42,16 +47,21 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from alpine_bits_python.config_loader import load_config -from alpine_bits_python.db import get_database_url +from alpine_bits_python.db import get_database_schema, get_database_url from alpine_bits_python.logging_config import get_logger, setup_logging _LOGGER = get_logger(__name__) -async def migrate_datetime_columns(session) -> None: +async def migrate_datetime_columns(session, schema_prefix: str = "") -> None: """Migrate DateTime columns to TIMESTAMP WITH TIME ZONE. This updates the columns to properly handle timezone-aware datetimes. + + Args: + session: Database session + schema_prefix: Schema prefix (e.g., "alpinebits." or "") + """ _LOGGER.info("\nMigrating DateTime columns to timezone-aware...") @@ -62,10 +72,11 @@ async def migrate_datetime_columns(session) -> None: ] for table_name, column_name in datetime_columns: - _LOGGER.info(f" {table_name}.{column_name}: Converting to TIMESTAMPTZ") + full_table = f"{schema_prefix}{table_name}" + _LOGGER.info(f" {full_table}.{column_name}: Converting to TIMESTAMPTZ") await session.execute( text( - f"ALTER TABLE {table_name} " + f"ALTER TABLE {full_table} " f"ALTER COLUMN {column_name} TYPE TIMESTAMP WITH TIME ZONE" ) ) @@ -74,11 +85,12 @@ async def migrate_datetime_columns(session) -> None: _LOGGER.info("✓ DateTime columns migrated to timezone-aware") -async def fix_sequences(database_url: str) -> None: +async def fix_sequences(database_url: str, schema_name: str = None) -> None: """Fix PostgreSQL sequences to match current max IDs and migrate datetime columns. Args: database_url: PostgreSQL database URL + schema_name: Schema name (e.g., "alpinebits") or None for public """ _LOGGER.info("=" * 70) @@ -88,16 +100,27 @@ async def fix_sequences(database_url: str) -> None: "Database: %s", database_url.split("@")[-1] if "@" in database_url else database_url, ) + if schema_name: + _LOGGER.info("Schema: %s", schema_name) _LOGGER.info("=" * 70) - # Create engine and session - engine = create_async_engine(database_url, echo=False) + # Create engine and session with schema support + connect_args = {} + if schema_name: + connect_args = { + "server_settings": {"search_path": f"{schema_name},public"} + } + + engine = create_async_engine(database_url, echo=False, connect_args=connect_args) SessionMaker = async_sessionmaker(engine, expire_on_commit=False) + # Determine schema prefix for SQL statements + schema_prefix = f"{schema_name}." if schema_name else "" + try: # Migrate datetime columns first async with SessionMaker() as session: - await migrate_datetime_columns(session) + await migrate_datetime_columns(session, schema_prefix) # Then fix sequences async with SessionMaker() as session: @@ -107,39 +130,43 @@ async def fix_sequences(database_url: str) -> None: ("hashed_customers", "hashed_customers_id_seq"), ("reservations", "reservations_id_seq"), ("acked_requests", "acked_requests_id_seq"), + ("conversions", "conversions_id_seq"), ] _LOGGER.info("\nResetting sequences...") for table_name, sequence_name in tables: + full_table = f"{schema_prefix}{table_name}" + full_sequence = f"{schema_prefix}{sequence_name}" + # Get current max ID result = await session.execute( - text(f"SELECT MAX(id) FROM {table_name}") + text(f"SELECT MAX(id) FROM {full_table}") ) max_id = result.scalar() # Get current sequence value result = await session.execute( - text(f"SELECT last_value FROM {sequence_name}") + text(f"SELECT last_value FROM {full_sequence}") ) current_seq = result.scalar() if max_id is None: - _LOGGER.info(f" {table_name}: empty table, setting sequence to 1") + _LOGGER.info(f" {full_table}: empty table, setting sequence to 1") await session.execute( - text(f"SELECT setval('{sequence_name}', 1, false)") + text(f"SELECT setval('{full_sequence}', 1, false)") ) elif current_seq <= max_id: new_seq = max_id + 1 _LOGGER.info( - f" {table_name}: max_id={max_id}, " + f" {full_table}: max_id={max_id}, " f"old_seq={current_seq}, new_seq={new_seq}" ) await session.execute( - text(f"SELECT setval('{sequence_name}', {new_seq}, false)") + text(f"SELECT setval('{full_sequence}', {new_seq}, false)") ) else: _LOGGER.info( - f" {table_name}: sequence already correct " + f" {full_table}: sequence already correct " f"(max_id={max_id}, seq={current_seq})" ) @@ -191,8 +218,11 @@ async def main(): config = {} # Determine database URL (same logic as migrate_sqlite_to_postgres) + schema_name = None if args.database_url: database_url = args.database_url + # Get schema from default config if available + schema_name = get_database_schema(config) elif args.config: # Load config file manually (simpler YAML without secrets) _LOGGER.info("Loading database config from: %s", args.config) @@ -201,6 +231,7 @@ async def main(): config_text = config_path.read_text() target_config = yaml.safe_load(config_text) database_url = target_config["database"]["url"] + schema_name = target_config.get("database", {}).get("schema") _LOGGER.info("Successfully loaded config") except (FileNotFoundError, ValueError, KeyError): _LOGGER.exception("Failed to load config") @@ -213,6 +244,8 @@ async def main(): if not database_url: # Try from default config database_url = get_database_url(config) + # Get schema from config or environment + schema_name = get_database_schema(config) if "postgresql" not in database_url and "postgres" not in database_url: _LOGGER.error("This script only works with PostgreSQL databases.") @@ -225,7 +258,7 @@ async def main(): sys.exit(1) # Run the fix - await fix_sequences(database_url) + await fix_sequences(database_url, schema_name) if __name__ == "__main__":