From a5006b2fafb94d61ebd52c587b038b49e6248d49 Mon Sep 17 00:00:00 2001 From: Jonas Linter <{email_address}> Date: Fri, 17 Oct 2025 22:27:10 +0200 Subject: [PATCH] Fix autoincrement --- .../util/fix_postgres_sequences.py | 152 ++++++++++++++++++ .../util/migrate_sqlite_to_postgres.py | 33 +++- 2 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 src/alpine_bits_python/util/fix_postgres_sequences.py diff --git a/src/alpine_bits_python/util/fix_postgres_sequences.py b/src/alpine_bits_python/util/fix_postgres_sequences.py new file mode 100644 index 0000000..f1fdba6 --- /dev/null +++ b/src/alpine_bits_python/util/fix_postgres_sequences.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +"""Fix PostgreSQL sequence values after migration from SQLite. + +This script resets all ID sequence values to match the current maximum ID +in each table. This is necessary because the migration script inserts records +with explicit IDs, which doesn't automatically advance PostgreSQL sequences. + +Usage: + # Using config file + uv run python -m alpine_bits_python.util.fix_postgres_sequences + + # Using DATABASE_URL environment variable + DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \ + uv run python -m alpine_bits_python.util.fix_postgres_sequences + + # Using command line argument + uv run python -m alpine_bits_python.util.fix_postgres_sequences \ + --database-url postgresql+asyncpg://user:pass@host/db +""" + +import argparse +import asyncio +import sys +from pathlib import Path + +# Add parent directory to path so we can import alpine_bits_python +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from alpine_bits_python.config_loader import load_config +from alpine_bits_python.db import get_database_url +from alpine_bits_python.logging_config import get_logger, setup_logging + +_LOGGER = get_logger(__name__) + + +async def fix_sequences(database_url: str) -> None: + """Fix PostgreSQL sequences to match current max IDs. + + Args: + database_url: PostgreSQL database URL + """ + _LOGGER.info("=" * 70) + _LOGGER.info("PostgreSQL Sequence Fix") + _LOGGER.info("=" * 70) + _LOGGER.info("Database: %s", database_url.split("@")[-1] if "@" in database_url else database_url) + _LOGGER.info("=" * 70) + + # Create engine and session + engine = create_async_engine(database_url, echo=False) + SessionMaker = async_sessionmaker(engine, expire_on_commit=False) + + try: + async with SessionMaker() as session: + # List of tables and their sequence names + tables = [ + ("customers", "customers_id_seq"), + ("hashed_customers", "hashed_customers_id_seq"), + ("reservations", "reservations_id_seq"), + ("acked_requests", "acked_requests_id_seq"), + ] + + _LOGGER.info("\nResetting sequences...") + for table_name, sequence_name in tables: + # Get current max ID + result = await session.execute( + text(f"SELECT MAX(id) FROM {table_name}") + ) + max_id = result.scalar() + + # Get current sequence value + result = await session.execute( + text(f"SELECT last_value FROM {sequence_name}") + ) + current_seq = result.scalar() + + if max_id is None: + _LOGGER.info(f" {table_name}: empty table, setting sequence to 1") + await session.execute( + text(f"SELECT setval('{sequence_name}', 1, false)") + ) + elif current_seq <= max_id: + new_seq = max_id + 1 + _LOGGER.info( + f" {table_name}: max_id={max_id}, " + f"old_seq={current_seq}, new_seq={new_seq}" + ) + await session.execute( + text(f"SELECT setval('{sequence_name}', {new_seq}, false)") + ) + else: + _LOGGER.info( + f" {table_name}: sequence already correct " + f"(max_id={max_id}, seq={current_seq})" + ) + + await session.commit() + + _LOGGER.info("\n" + "=" * 70) + _LOGGER.info("✓ Sequences fixed successfully!") + _LOGGER.info("=" * 70) + _LOGGER.info("\nYou can now insert new records without ID conflicts.") + + except Exception as e: + _LOGGER.exception("Failed to fix sequences: %s", e) + raise + + finally: + await engine.dispose() + + +async def main(): + """Run the sequence fix.""" + parser = argparse.ArgumentParser( + description="Fix PostgreSQL sequences after SQLite migration", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--database-url", + help="PostgreSQL database URL (default: from config or DATABASE_URL env var)", + ) + + args = parser.parse_args() + + try: + # Load config + config = load_config() + setup_logging(config) + except Exception as e: + _LOGGER.warning("Failed to load config: %s. Using defaults.", e) + config = {} + + # Determine database URL + if args.database_url: + database_url = args.database_url + else: + database_url = get_database_url(config) + + if "postgresql" not in database_url and "postgres" not in database_url: + _LOGGER.error("This script only works with PostgreSQL databases.") + _LOGGER.error("Current database URL: %s", database_url) + sys.exit(1) + + # Run the fix + await fix_sequences(database_url) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/alpine_bits_python/util/migrate_sqlite_to_postgres.py b/src/alpine_bits_python/util/migrate_sqlite_to_postgres.py index 84ff22d..07f21d6 100644 --- a/src/alpine_bits_python/util/migrate_sqlite_to_postgres.py +++ b/src/alpine_bits_python/util/migrate_sqlite_to_postgres.py @@ -43,7 +43,7 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent.parent)) import yaml -from sqlalchemy import select +from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from alpine_bits_python.config_loader import load_config @@ -94,6 +94,31 @@ async def get_table_counts(session: AsyncSession) -> dict[str, int]: return counts +async def reset_sequences(session: AsyncSession) -> None: + """Reset PostgreSQL sequences to match the current max ID values. + + This is necessary after migrating data with explicit IDs from SQLite, + as PostgreSQL sequences won't automatically advance when IDs are set explicitly. + """ + tables = [ + ("customers", "customers_id_seq"), + ("hashed_customers", "hashed_customers_id_seq"), + ("reservations", "reservations_id_seq"), + ("acked_requests", "acked_requests_id_seq"), + ] + + for table_name, sequence_name in tables: + # Set sequence to max(id) + 1, or 1 if table is empty + query = text(f""" + SELECT setval('{sequence_name}', + COALESCE((SELECT MAX(id) FROM {table_name}), 0) + 1, + false) + """) + await session.execute(query) + + await session.commit() + + async def migrate_data( source_url: str, target_url: str, @@ -320,6 +345,12 @@ async def migrate_data( _LOGGER.info("✓ Migrated %d acked requests", len(acked_requests)) + # Reset PostgreSQL sequences + _LOGGER.info("\n[5/5] Resetting PostgreSQL sequences...") + async with TargetSession() as target_session: + await reset_sequences(target_session) + _LOGGER.info("✓ Sequences reset to match current max IDs") + # Verify migration _LOGGER.info("\n" + "=" * 70) _LOGGER.info("Verifying migration...")