233 lines
8.2 KiB
Python
233 lines
8.2 KiB
Python
#!/usr/bin/env python3
|
|
"""Fix PostgreSQL sequences and migrate datetime columns after SQLite migration.
|
|
|
|
This script performs two operations:
|
|
1. Migrates DateTime columns to TIMESTAMP WITH TIME ZONE for timezone-aware support
|
|
2. Resets all ID sequence values to match the current maximum ID in each table
|
|
|
|
The sequence reset is necessary because the migration script inserts records
|
|
with explicit IDs, which doesn't automatically advance PostgreSQL sequences.
|
|
|
|
The datetime migration ensures proper handling of timezone-aware datetimes,
|
|
which is required by the application code.
|
|
|
|
Usage:
|
|
# Using default config.yaml
|
|
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
|
|
|
# Using a specific config file
|
|
uv run python -m alpine_bits_python.util.fix_postgres_sequences \
|
|
--config config/postgres.yaml
|
|
|
|
# Using DATABASE_URL environment variable
|
|
DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \
|
|
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
|
|
|
# Using command line argument
|
|
uv run python -m alpine_bits_python.util.fix_postgres_sequences \
|
|
--database-url postgresql+asyncpg://user:pass@host/db
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add parent directory to path so we can import alpine_bits_python
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
|
|
import yaml
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
|
|
from alpine_bits_python.config_loader import load_config
|
|
from alpine_bits_python.db import get_database_url
|
|
from alpine_bits_python.logging_config import get_logger, setup_logging
|
|
|
|
# Module-level logger for this script; main() calls setup_logging(config)
# once the application config has been loaded.
_LOGGER = get_logger(__name__)
|
|
|
|
|
|
async def migrate_datetime_columns(session, source_timezone: "str | None" = None) -> None:
    """Migrate DateTime columns to TIMESTAMP WITH TIME ZONE.

    This updates the columns to properly handle timezone-aware datetimes.
    Columns whose type is already ``timestamp with time zone`` are skipped
    (and reported as such), so re-running the script is harmless.

    Without a ``USING`` clause, PostgreSQL converts existing
    ``timestamp without time zone`` values by interpreting them in the
    session's ``TimeZone`` setting. Pass ``source_timezone`` to state the zone
    of the stored naive values explicitly instead of depending on server
    configuration.

    Args:
        session: Async SQLAlchemy session; it is committed once after all
            columns have been processed.
        source_timezone: Optional time zone name (anything PostgreSQL's
            ``AT TIME ZONE`` accepts, e.g. ``"UTC"``) used to interpret the
            existing naive values. ``None`` (the default) keeps the previous
            behaviour of relying on the session ``TimeZone``.

    """
    _LOGGER.info("\nMigrating DateTime columns to timezone-aware...")

    datetime_columns = [
        ("hashed_customers", "created_at"),
        ("reservations", "created_at"),
        ("acked_requests", "timestamp"),
    ]

    def quote_ident(name: str) -> str:
        # Double-quote identifiers: "timestamp" is also an SQL keyword.
        return '"' + name.replace('"', '""') + '"'

    tz_literal = None
    if source_timezone is not None:
        # DDL cannot take bind parameters, so embed the zone as an escaped SQL
        # string literal. "\:" keeps text() from treating a colon (e.g. in
        # "+01:00") as a bind-parameter marker; text() turns it back into ":".
        escaped = source_timezone.replace("'", "''").replace(":", "\\:")
        tz_literal = f"'{escaped}'"

    for table_name, column_name in datetime_columns:
        # Look up the current type so already-converted columns are left alone;
        # re-applying "USING col AT TIME ZONE ..." to a timestamptz column
        # would shift its values.
        result = await session.execute(
            text(
                "SELECT data_type FROM information_schema.columns "
                "WHERE table_schema = current_schema() "
                "AND CAST(table_name AS text) = :table_name "
                "AND CAST(column_name AS text) = :column_name"
            ),
            {"table_name": table_name, "column_name": column_name},
        )
        if result.scalar() == "timestamp with time zone":
            _LOGGER.info(
                " %s.%s: already TIMESTAMPTZ, skipping", table_name, column_name
            )
            continue

        _LOGGER.info(" %s.%s: Converting to TIMESTAMPTZ", table_name, column_name)
        column_sql = quote_ident(column_name)
        statement = (
            f"ALTER TABLE {quote_ident(table_name)} "
            f"ALTER COLUMN {column_sql} TYPE TIMESTAMP WITH TIME ZONE"
        )
        if tz_literal is not None:
            statement += f" USING {column_sql} AT TIME ZONE {tz_literal}"
        await session.execute(text(statement))

    await session.commit()
    _LOGGER.info("✓ DateTime columns migrated to timezone-aware")
|
|
|
|
|
|
async def _reset_sequence(session, table_name: str, sequence_name: str) -> None:
    """Move ``sequence_name`` past the largest ``id`` stored in ``table_name``.

    Nothing is committed here; the caller owns the transaction.
    """
    max_result = await session.execute(text(f"SELECT MAX(id) FROM {table_name}"))
    highest_id = max_result.scalar()

    seq_result = await session.execute(text(f"SELECT last_value FROM {sequence_name}"))
    last_value = seq_result.scalar()

    if highest_id is None:
        _LOGGER.info(f" {table_name}: empty table, setting sequence to 1")
        await session.execute(text(f"SELECT setval('{sequence_name}', 1, false)"))
        return

    if last_value > highest_id:
        _LOGGER.info(
            f" {table_name}: sequence already correct "
            f"(max_id={highest_id}, seq={last_value})"
        )
        return

    # is_called=false: the next nextval() hands out exactly next_id.
    next_id = highest_id + 1
    _LOGGER.info(
        f" {table_name}: max_id={highest_id}, "
        f"old_seq={last_value}, new_seq={next_id}"
    )
    await session.execute(text(f"SELECT setval('{sequence_name}', {next_id}, false)"))


async def fix_sequences(database_url: str) -> None:
    """Run the post-migration fix-ups against a PostgreSQL database.

    First converts the DateTime columns to ``TIMESTAMP WITH TIME ZONE`` (via
    ``migrate_datetime_columns``), then moves each ``<table>_id_seq`` sequence
    past the highest ``id`` in its table, because rows copied with explicit
    IDs do not advance sequences. Each step uses its own session; the engine
    is always disposed, and any failure is logged and re-raised.

    Args:
        database_url: PostgreSQL database URL

    """
    separator = "=" * 70
    # Show only the text after the last "@" so credentials are not logged.
    shown_target = database_url.rpartition("@")[2]

    _LOGGER.info(separator)
    _LOGGER.info("PostgreSQL Migration & Sequence Fix")
    _LOGGER.info(separator)
    _LOGGER.info("Database: %s", shown_target)
    _LOGGER.info(separator)

    engine = create_async_engine(database_url, echo=False)
    make_session = async_sessionmaker(engine, expire_on_commit=False)

    sequences_by_table = (
        ("customers", "customers_id_seq"),
        ("hashed_customers", "hashed_customers_id_seq"),
        ("reservations", "reservations_id_seq"),
        ("acked_requests", "acked_requests_id_seq"),
    )

    try:
        # Column types first, in a transaction of their own.
        async with make_session() as session:
            await migrate_datetime_columns(session)

        # Then all sequence resets, committed together.
        async with make_session() as session:
            _LOGGER.info("\nResetting sequences...")
            for table_name, sequence_name in sequences_by_table:
                await _reset_sequence(session, table_name, sequence_name)
            await session.commit()

        summary_lines = (
            "\n" + separator,
            "✓ Migration completed successfully!",
            separator,
            "\nChanges applied:",
            " 1. DateTime columns are now timezone-aware (TIMESTAMPTZ)",
            " 2. Sequences are reset to match current max IDs",
            "\nYou can now insert new records without conflicts.",
        )
        for summary_line in summary_lines:
            _LOGGER.info(summary_line)

    except Exception as exc:
        _LOGGER.exception("Failed to fix sequences: %s", exc)
        raise

    finally:
        await engine.dispose()
|
|
|
|
|
|
async def main():
    """Parse command-line options, resolve the database URL and run the fix.

    The PostgreSQL URL is taken from the first available source:
    ``--database-url``, ``--config`` (a YAML file containing ``database.url``),
    the ``DATABASE_URL`` environment variable, and finally
    ``get_database_url()`` applied to the application config.

    Exits with status 1 when the ``--config`` file cannot be read or lacks
    ``database.url``, or when the resolved URL does not use a PostgreSQL
    dialect.
    """
    import logging

    parser = argparse.ArgumentParser(
        description="Fix PostgreSQL sequences after SQLite migration",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--database-url",
        help="PostgreSQL database URL (default: from config or DATABASE_URL env var)",
    )
    parser.add_argument(
        "--config",
        help=(
            "Path to config file containing PostgreSQL database URL "
            "(keeps password out of bash history)"
        ),
    )

    args = parser.parse_args()

    try:
        # Load config
        config = load_config()
        setup_logging(config)
    except Exception as e:
        # setup_logging() did not run, so INFO progress output would likely be
        # lost. Assuming get_logger() returns a stdlib logger, a basic stderr
        # handler on the root logger restores it; basicConfig() does nothing
        # if the root logger already has handlers.
        logging.basicConfig(level=logging.INFO)
        _LOGGER.warning("Failed to load config: %s. Using defaults.", e)
        config = {}

    # Determine database URL (same logic as migrate_sqlite_to_postgres)
    if args.database_url:
        database_url = args.database_url
    elif args.config:
        # Load config file manually (simpler YAML without secrets)
        _LOGGER.info("Loading database config from: %s", args.config)
        try:
            config_text = Path(args.config).read_text(encoding="utf-8")
            target_config = yaml.safe_load(config_text)
            database_url = target_config["database"]["url"]
        except (OSError, ValueError, yaml.YAMLError, KeyError, TypeError):
            # OSError: missing/unreadable path; ValueError: includes
            # UnicodeDecodeError; YAMLError: malformed YAML; KeyError /
            # TypeError: no database.url mapping (an empty file loads as None).
            _LOGGER.exception("Failed to load config")
            _LOGGER.info(
                "Config file should contain: database.url with PostgreSQL connection"
            )
            sys.exit(1)
        _LOGGER.info("Successfully loaded config")
    else:
        database_url = os.environ.get("DATABASE_URL")
        if not database_url:
            # Try from default config
            database_url = get_database_url(config)

    # Compare the backend name from the URL scheme (text before "://", minus
    # any "+driver") rather than searching the whole string, so e.g.
    # "sqlite:///postgres.db" is rejected. A missing URL counts as unknown.
    scheme, has_scheme, _ = str(database_url or "").partition("://")
    dialect = scheme.partition("+")[0].lower() if has_scheme else ""
    if dialect not in ("postgresql", "postgres"):
        _LOGGER.error("This script only works with PostgreSQL databases.")
        _LOGGER.error("Current database URL type detected: %s", dialect or "unknown")
        _LOGGER.error("\nSpecify PostgreSQL database using one of:")
        _LOGGER.error(" - --config config/postgres.yaml")
        _LOGGER.error(" - DATABASE_URL environment variable")
        _LOGGER.error(" - --database-url postgresql+asyncpg://user:pass@host/db")
        sys.exit(1)

    # Run the fix
    await fix_sequences(database_url)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async main() to completion with asyncio.run().
    asyncio.run(main())
|