Merging schema_extension #9

Merged
jonas merged 15 commits from schema_extension into main 2025-10-20 07:19:26 +00:00
2 changed files with 184 additions and 1 deletions
Showing only changes of commit a5006b2faf - Show all commits

View File

@@ -0,0 +1,152 @@
#!/usr/bin/env python3
"""Fix PostgreSQL sequence values after migration from SQLite.
This script resets all ID sequence values to match the current maximum ID
in each table. This is necessary because the migration script inserts records
with explicit IDs, which doesn't automatically advance PostgreSQL sequences.
Usage:
# Using config file
uv run python -m alpine_bits_python.util.fix_postgres_sequences
# Using DATABASE_URL environment variable
DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \
uv run python -m alpine_bits_python.util.fix_postgres_sequences
# Using command line argument
uv run python -m alpine_bits_python.util.fix_postgres_sequences \
--database-url postgresql+asyncpg://user:pass@host/db
"""
import argparse
import asyncio
import sys
from pathlib import Path
# Add parent directory to path so we can import alpine_bits_python
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from alpine_bits_python.config_loader import load_config
from alpine_bits_python.db import get_database_url
from alpine_bits_python.logging_config import get_logger, setup_logging
_LOGGER = get_logger(__name__)
async def fix_sequences(database_url: str) -> None:
    """Fix PostgreSQL sequences to match current max IDs.

    For each known table, reads ``MAX(id)`` and the sequence's
    ``last_value``, then calls ``setval(..., max_id + 1, false)`` so the
    next ``nextval()`` returns an unused ID. Empty tables get their
    sequence reset so the next ID is 1.

    Args:
        database_url: PostgreSQL database URL.

    Raises:
        Exception: re-raised after logging if any query fails; the
            engine is always disposed.
    """
    _LOGGER.info("=" * 70)
    _LOGGER.info("PostgreSQL Sequence Fix")
    _LOGGER.info("=" * 70)
    # Strip the "user:pass@" prefix so credentials are never logged.
    _LOGGER.info(
        "Database: %s",
        database_url.split("@")[-1] if "@" in database_url else database_url,
    )
    _LOGGER.info("=" * 70)

    # Create engine and session factory.
    engine = create_async_engine(database_url, echo=False)
    session_maker = async_sessionmaker(engine, expire_on_commit=False)

    try:
        async with session_maker() as session:
            # Tables and their ID sequences. Names are hardcoded (not
            # user input), so interpolating them into SQL is safe here.
            tables = [
                ("customers", "customers_id_seq"),
                ("hashed_customers", "hashed_customers_id_seq"),
                ("reservations", "reservations_id_seq"),
                ("acked_requests", "acked_requests_id_seq"),
            ]

            _LOGGER.info("\nResetting sequences...")

            for table_name, sequence_name in tables:
                # Current max ID in the table (None when empty).
                result = await session.execute(
                    text(f"SELECT MAX(id) FROM {table_name}")
                )
                max_id = result.scalar()

                # Current sequence position.
                result = await session.execute(
                    text(f"SELECT last_value FROM {sequence_name}")
                )
                current_seq = result.scalar()

                if max_id is None:
                    # Empty table: setval(..., 1, false) means the next
                    # nextval() returns exactly 1.
                    _LOGGER.info(
                        "  %s: empty table, setting sequence to 1", table_name
                    )
                    await session.execute(
                        text(f"SELECT setval('{sequence_name}', 1, false)")
                    )
                elif current_seq <= max_id:
                    # Sequence lags behind the data; advance it past max_id.
                    new_seq = max_id + 1
                    _LOGGER.info(
                        "  %s: max_id=%s, old_seq=%s, new_seq=%s",
                        table_name,
                        max_id,
                        current_seq,
                        new_seq,
                    )
                    await session.execute(
                        text(
                            f"SELECT setval('{sequence_name}', {new_seq}, false)"
                        )
                    )
                else:
                    _LOGGER.info(
                        "  %s: sequence already correct (max_id=%s, seq=%s)",
                        table_name,
                        max_id,
                        current_seq,
                    )

            await session.commit()

        _LOGGER.info("\n" + "=" * 70)
        _LOGGER.info("✓ Sequences fixed successfully!")
        _LOGGER.info("=" * 70)
        _LOGGER.info("\nYou can now insert new records without ID conflicts.")
    except Exception:
        # .exception() already records the message and full traceback.
        _LOGGER.exception("Failed to fix sequences")
        raise
    finally:
        await engine.dispose()
async def main():
    """Parse CLI arguments, resolve the database URL, and run the fix.

    Exits with status 1 when the resolved URL is not a PostgreSQL URL.
    """
    parser = argparse.ArgumentParser(
        description="Fix PostgreSQL sequences after SQLite migration",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--database-url",
        help="PostgreSQL database URL (default: from config or DATABASE_URL env var)",
    )
    args = parser.parse_args()

    try:
        # Load config (also configures logging).
        config = load_config()
        setup_logging(config)
    except Exception as e:
        # Best-effort: the fix can still run without a config file.
        _LOGGER.warning("Failed to load config: %s. Using defaults.", e)
        config = {}

    # Explicit CLI argument wins over config / DATABASE_URL env var.
    database_url = args.database_url or get_database_url(config)

    # "postgres" is a substring of "postgresql", so one check covers
    # both URL schemes (postgres:// and postgresql+asyncpg://).
    if "postgres" not in database_url:
        _LOGGER.error("This script only works with PostgreSQL databases.")
        _LOGGER.error("Current database URL: %s", database_url)
        sys.exit(1)

    # Run the fix.
    await fix_sequences(database_url)


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -43,7 +43,7 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
import yaml
from sqlalchemy import select
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from alpine_bits_python.config_loader import load_config
@@ -94,6 +94,31 @@ async def get_table_counts(session: AsyncSession) -> dict[str, int]:
return counts
async def reset_sequences(session: AsyncSession) -> None:
    """Reset PostgreSQL sequences to match the current max ID values.

    This is necessary after migrating data with explicit IDs from SQLite,
    as PostgreSQL sequences won't automatically advance when IDs are set explicitly.

    Args:
        session: Open async session bound to the PostgreSQL target database.
            The function commits on this session after all sequences are set.
    """
    # Tables and their ID sequences; names are hardcoded (not user
    # input), so interpolating them into SQL is safe here.
    tables = [
        ("customers", "customers_id_seq"),
        ("hashed_customers", "hashed_customers_id_seq"),
        ("reservations", "reservations_id_seq"),
        ("acked_requests", "acked_requests_id_seq"),
    ]

    for table_name, sequence_name in tables:
        # Set sequence to max(id) + 1, or 1 if table is empty
        # (COALESCE maps the empty-table NULL to 0). The final 'false'
        # argument makes the next nextval() return exactly this value.
        query = text(f"""
            SELECT setval('{sequence_name}',
                          COALESCE((SELECT MAX(id) FROM {table_name}), 0) + 1,
                          false)
        """)
        await session.execute(query)

    await session.commit()
async def migrate_data(
source_url: str,
target_url: str,
@@ -320,6 +345,12 @@ async def migrate_data(
_LOGGER.info("✓ Migrated %d acked requests", len(acked_requests))
# Reset PostgreSQL sequences
_LOGGER.info("\n[5/5] Resetting PostgreSQL sequences...")
async with TargetSession() as target_session:
await reset_sequences(target_session)
_LOGGER.info("✓ Sequences reset to match current max IDs")
# Verify migration
_LOGGER.info("\n" + "=" * 70)
_LOGGER.info("Verifying migration...")