Fix autoincrement
This commit is contained in:
152
src/alpine_bits_python/util/fix_postgres_sequences.py
Normal file
152
src/alpine_bits_python/util/fix_postgres_sequences.py
Normal file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Fix PostgreSQL sequence values after migration from SQLite.
|
||||
|
||||
This script resets all ID sequence values to match the current maximum ID
|
||||
in each table. This is necessary because the migration script inserts records
|
||||
with explicit IDs, which doesn't automatically advance PostgreSQL sequences.
|
||||
|
||||
Usage:
|
||||
# Using config file
|
||||
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
||||
|
||||
# Using DATABASE_URL environment variable
|
||||
DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \
|
||||
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
||||
|
||||
# Using command line argument
|
||||
uv run python -m alpine_bits_python.util.fix_postgres_sequences \
|
||||
--database-url postgresql+asyncpg://user:pass@host/db
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path so we can import alpine_bits_python
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from alpine_bits_python.config_loader import load_config
|
||||
from alpine_bits_python.db import get_database_url
|
||||
from alpine_bits_python.logging_config import get_logger, setup_logging
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
async def fix_sequences(database_url: str) -> None:
    """Fix PostgreSQL sequences to match current max IDs.

    For each known table, reads ``MAX(id)`` and the sequence's
    ``last_value``, then advances the sequence with ``setval()`` so the
    next generated ID does not collide with rows that were migrated with
    explicit IDs.

    Args:
        database_url: PostgreSQL database URL (asyncpg driver).

    Raises:
        Exception: re-raised after logging if any database step fails.
    """
    _LOGGER.info("=" * 70)
    _LOGGER.info("PostgreSQL Sequence Fix")
    _LOGGER.info("=" * 70)
    # Avoid logging credentials: only show the part after '@' (host/db).
    _LOGGER.info(
        "Database: %s",
        database_url.split("@")[-1] if "@" in database_url else database_url,
    )
    _LOGGER.info("=" * 70)

    # Create engine and session factory.
    engine = create_async_engine(database_url, echo=False)
    session_maker = async_sessionmaker(engine, expire_on_commit=False)

    try:
        async with session_maker() as session:
            # Tables paired with their default serial sequence names.
            # NOTE: these names are hardcoded, which is what makes the
            # f-string SQL below safe (no user-controlled input).
            tables = [
                ("customers", "customers_id_seq"),
                ("hashed_customers", "hashed_customers_id_seq"),
                ("reservations", "reservations_id_seq"),
                ("acked_requests", "acked_requests_id_seq"),
            ]

            _LOGGER.info("\nResetting sequences...")
            for table_name, sequence_name in tables:
                # Get current max ID in the table (None if table is empty).
                result = await session.execute(
                    text(f"SELECT MAX(id) FROM {table_name}")
                )
                max_id = result.scalar()

                # Get the sequence's current position.
                result = await session.execute(
                    text(f"SELECT last_value FROM {sequence_name}")
                )
                current_seq = result.scalar()

                if max_id is None:
                    # Empty table: reset so the first nextval() returns 1.
                    _LOGGER.info(
                        "  %s: empty table, setting sequence to 1", table_name
                    )
                    await session.execute(
                        text(f"SELECT setval('{sequence_name}', 1, false)")
                    )
                elif current_seq <= max_id:
                    # Sequence lags behind the data: jump past the max ID so
                    # the next nextval() yields max_id + 1.
                    new_seq = max_id + 1
                    _LOGGER.info(
                        "  %s: max_id=%s, old_seq=%s, new_seq=%s",
                        table_name,
                        max_id,
                        current_seq,
                        new_seq,
                    )
                    await session.execute(
                        text(f"SELECT setval('{sequence_name}', {new_seq}, false)")
                    )
                else:
                    _LOGGER.info(
                        "  %s: sequence already correct (max_id=%s, seq=%s)",
                        table_name,
                        max_id,
                        current_seq,
                    )

            await session.commit()

        _LOGGER.info("\n" + "=" * 70)
        _LOGGER.info("✓ Sequences fixed successfully!")
        _LOGGER.info("=" * 70)
        _LOGGER.info("\nYou can now insert new records without ID conflicts.")

    except Exception:
        # .exception() already records the active traceback; passing the
        # exception as a %s argument was redundant.
        _LOGGER.exception("Failed to fix sequences")
        raise

    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def main() -> None:
    """Parse arguments, resolve the database URL, and run the sequence fix."""
    parser = argparse.ArgumentParser(
        description="Fix PostgreSQL sequences after SQLite migration",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--database-url",
        help="PostgreSQL database URL (default: from config or DATABASE_URL env var)",
    )

    args = parser.parse_args()

    try:
        # Load config and set up project logging.
        config = load_config()
        setup_logging(config)
    except Exception as e:
        # Best effort: the fix can still run without project config/logging.
        _LOGGER.warning("Failed to load config: %s. Using defaults.", e)
        config = {}

    # CLI flag wins over config / DATABASE_URL environment variable.
    database_url = args.database_url or get_database_url(config)

    # "postgresql" contains "postgres", so a single substring test covers
    # both spellings (the original second clause was redundant).
    if "postgres" not in database_url:
        _LOGGER.error("This script only works with PostgreSQL databases.")
        _LOGGER.error("Current database URL: %s", database_url)
        sys.exit(1)

    # Run the fix.
    await fix_sequences(database_url)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user