Merging schema_extension #9
152
src/alpine_bits_python/util/fix_postgres_sequences.py
Normal file
152
src/alpine_bits_python/util/fix_postgres_sequences.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Fix PostgreSQL sequence values after migration from SQLite.
|
||||||
|
|
||||||
|
This script resets all ID sequence values to match the current maximum ID
|
||||||
|
in each table. This is necessary because the migration script inserts records
|
||||||
|
with explicit IDs, which doesn't automatically advance PostgreSQL sequences.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Using config file
|
||||||
|
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
||||||
|
|
||||||
|
# Using DATABASE_URL environment variable
|
||||||
|
DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \
|
||||||
|
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
||||||
|
|
||||||
|
# Using command line argument
|
||||||
|
uv run python -m alpine_bits_python.util.fix_postgres_sequences \
|
||||||
|
--database-url postgresql+asyncpg://user:pass@host/db
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path so we can import alpine_bits_python
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from alpine_bits_python.config_loader import load_config
|
||||||
|
from alpine_bits_python.db import get_database_url
|
||||||
|
from alpine_bits_python.logging_config import get_logger, setup_logging
|
||||||
|
|
||||||
|
_LOGGER = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def fix_sequences(database_url: str) -> None:
|
||||||
|
"""Fix PostgreSQL sequences to match current max IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: PostgreSQL database URL
|
||||||
|
"""
|
||||||
|
_LOGGER.info("=" * 70)
|
||||||
|
_LOGGER.info("PostgreSQL Sequence Fix")
|
||||||
|
_LOGGER.info("=" * 70)
|
||||||
|
_LOGGER.info("Database: %s", database_url.split("@")[-1] if "@" in database_url else database_url)
|
||||||
|
_LOGGER.info("=" * 70)
|
||||||
|
|
||||||
|
# Create engine and session
|
||||||
|
engine = create_async_engine(database_url, echo=False)
|
||||||
|
SessionMaker = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with SessionMaker() as session:
|
||||||
|
# List of tables and their sequence names
|
||||||
|
tables = [
|
||||||
|
("customers", "customers_id_seq"),
|
||||||
|
("hashed_customers", "hashed_customers_id_seq"),
|
||||||
|
("reservations", "reservations_id_seq"),
|
||||||
|
("acked_requests", "acked_requests_id_seq"),
|
||||||
|
]
|
||||||
|
|
||||||
|
_LOGGER.info("\nResetting sequences...")
|
||||||
|
for table_name, sequence_name in tables:
|
||||||
|
# Get current max ID
|
||||||
|
result = await session.execute(
|
||||||
|
text(f"SELECT MAX(id) FROM {table_name}")
|
||||||
|
)
|
||||||
|
max_id = result.scalar()
|
||||||
|
|
||||||
|
# Get current sequence value
|
||||||
|
result = await session.execute(
|
||||||
|
text(f"SELECT last_value FROM {sequence_name}")
|
||||||
|
)
|
||||||
|
current_seq = result.scalar()
|
||||||
|
|
||||||
|
if max_id is None:
|
||||||
|
_LOGGER.info(f" {table_name}: empty table, setting sequence to 1")
|
||||||
|
await session.execute(
|
||||||
|
text(f"SELECT setval('{sequence_name}', 1, false)")
|
||||||
|
)
|
||||||
|
elif current_seq <= max_id:
|
||||||
|
new_seq = max_id + 1
|
||||||
|
_LOGGER.info(
|
||||||
|
f" {table_name}: max_id={max_id}, "
|
||||||
|
f"old_seq={current_seq}, new_seq={new_seq}"
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(f"SELECT setval('{sequence_name}', {new_seq}, false)")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_LOGGER.info(
|
||||||
|
f" {table_name}: sequence already correct "
|
||||||
|
f"(max_id={max_id}, seq={current_seq})"
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
_LOGGER.info("\n" + "=" * 70)
|
||||||
|
_LOGGER.info("✓ Sequences fixed successfully!")
|
||||||
|
_LOGGER.info("=" * 70)
|
||||||
|
_LOGGER.info("\nYou can now insert new records without ID conflicts.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
_LOGGER.exception("Failed to fix sequences: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run the sequence fix."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Fix PostgreSQL sequences after SQLite migration",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog=__doc__,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--database-url",
|
||||||
|
help="PostgreSQL database URL (default: from config or DATABASE_URL env var)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load config
|
||||||
|
config = load_config()
|
||||||
|
setup_logging(config)
|
||||||
|
except Exception as e:
|
||||||
|
_LOGGER.warning("Failed to load config: %s. Using defaults.", e)
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
# Determine database URL
|
||||||
|
if args.database_url:
|
||||||
|
database_url = args.database_url
|
||||||
|
else:
|
||||||
|
database_url = get_database_url(config)
|
||||||
|
|
||||||
|
if "postgresql" not in database_url and "postgres" not in database_url:
|
||||||
|
_LOGGER.error("This script only works with PostgreSQL databases.")
|
||||||
|
_LOGGER.error("Current database URL: %s", database_url)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Run the fix
|
||||||
|
await fix_sequences(database_url)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -43,7 +43,7 @@ from pathlib import Path
|
|||||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from alpine_bits_python.config_loader import load_config
|
from alpine_bits_python.config_loader import load_config
|
||||||
@@ -94,6 +94,31 @@ async def get_table_counts(session: AsyncSession) -> dict[str, int]:
|
|||||||
return counts
|
return counts
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_sequences(session: AsyncSession) -> None:
|
||||||
|
"""Reset PostgreSQL sequences to match the current max ID values.
|
||||||
|
|
||||||
|
This is necessary after migrating data with explicit IDs from SQLite,
|
||||||
|
as PostgreSQL sequences won't automatically advance when IDs are set explicitly.
|
||||||
|
"""
|
||||||
|
tables = [
|
||||||
|
("customers", "customers_id_seq"),
|
||||||
|
("hashed_customers", "hashed_customers_id_seq"),
|
||||||
|
("reservations", "reservations_id_seq"),
|
||||||
|
("acked_requests", "acked_requests_id_seq"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for table_name, sequence_name in tables:
|
||||||
|
# Set sequence to max(id) + 1, or 1 if table is empty
|
||||||
|
query = text(f"""
|
||||||
|
SELECT setval('{sequence_name}',
|
||||||
|
COALESCE((SELECT MAX(id) FROM {table_name}), 0) + 1,
|
||||||
|
false)
|
||||||
|
""")
|
||||||
|
await session.execute(query)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
async def migrate_data(
|
async def migrate_data(
|
||||||
source_url: str,
|
source_url: str,
|
||||||
target_url: str,
|
target_url: str,
|
||||||
@@ -320,6 +345,12 @@ async def migrate_data(
|
|||||||
|
|
||||||
_LOGGER.info("✓ Migrated %d acked requests", len(acked_requests))
|
_LOGGER.info("✓ Migrated %d acked requests", len(acked_requests))
|
||||||
|
|
||||||
|
# Reset PostgreSQL sequences
|
||||||
|
_LOGGER.info("\n[5/5] Resetting PostgreSQL sequences...")
|
||||||
|
async with TargetSession() as target_session:
|
||||||
|
await reset_sequences(target_session)
|
||||||
|
_LOGGER.info("✓ Sequences reset to match current max IDs")
|
||||||
|
|
||||||
# Verify migration
|
# Verify migration
|
||||||
_LOGGER.info("\n" + "=" * 70)
|
_LOGGER.info("\n" + "=" * 70)
|
||||||
_LOGGER.info("Verifying migration...")
|
_LOGGER.info("Verifying migration...")
|
||||||
|
|||||||
Reference in New Issue
Block a user