Updated fix sequences scripts

Jonas Linter
2025-11-04 09:36:22 +01:00
parent eb10e070b1
commit 1f7649fffe


@@ -11,19 +11,24 @@ with explicit IDs, which doesn't automatically advance PostgreSQL sequences.
 The datetime migration ensures proper handling of timezone-aware datetimes,
 which is required by the application code.
 
+Schema Support:
+    The script automatically detects and uses the schema configured in your config file.
+    If you have database.schema: "alpinebits" in your config, it will work with that schema.
+
 Usage:
-    # Using default config.yaml
+    # Using default config.yaml (includes schema if configured)
     uv run python -m alpine_bits_python.util.fix_postgres_sequences
 
-    # Using a specific config file
+    # Using a specific config file (with schema support)
     uv run python -m alpine_bits_python.util.fix_postgres_sequences \
         --config config/postgres.yaml
 
-    # Using DATABASE_URL environment variable
+    # Using DATABASE_URL environment variable (schema from config or DATABASE_SCHEMA env var)
     DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \
+    DATABASE_SCHEMA="alpinebits" \
     uv run python -m alpine_bits_python.util.fix_postgres_sequences
 
-    # Using command line argument
+    # Using command line argument (schema from config)
     uv run python -m alpine_bits_python.util.fix_postgres_sequences \
         --database-url postgresql+asyncpg://user:pass@host/db
 """
@@ -42,16 +47,21 @@ from sqlalchemy import text
 from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
 
 from alpine_bits_python.config_loader import load_config
-from alpine_bits_python.db import get_database_url
+from alpine_bits_python.db import get_database_schema, get_database_url
 from alpine_bits_python.logging_config import get_logger, setup_logging
 
 _LOGGER = get_logger(__name__)
 
 
-async def migrate_datetime_columns(session) -> None:
+async def migrate_datetime_columns(session, schema_prefix: str = "") -> None:
     """Migrate DateTime columns to TIMESTAMP WITH TIME ZONE.
 
     This updates the columns to properly handle timezone-aware datetimes.
 
     Args:
         session: Database session
+        schema_prefix: Schema prefix (e.g., "alpinebits." or "")
     """
     _LOGGER.info("\nMigrating DateTime columns to timezone-aware...")
@@ -62,10 +72,11 @@ async def migrate_datetime_columns(session) -> None:
     ]
 
     for table_name, column_name in datetime_columns:
-        _LOGGER.info(f" {table_name}.{column_name}: Converting to TIMESTAMPTZ")
+        full_table = f"{schema_prefix}{table_name}"
+        _LOGGER.info(f" {full_table}.{column_name}: Converting to TIMESTAMPTZ")
         await session.execute(
             text(
-                f"ALTER TABLE {table_name} "
+                f"ALTER TABLE {full_table} "
                 f"ALTER COLUMN {column_name} TYPE TIMESTAMP WITH TIME ZONE"
             )
         )
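
Note: with a schema configured, this loop renders statements such as the
sketch below. The column name "created_at" is hypothetical, since the
datetime_columns list itself falls outside this hunk:

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import AsyncSession

    async def convert_one(session: AsyncSession) -> None:
        # Rendered form for schema_prefix = "alpinebits.",
        # table "reservations", hypothetical column "created_at".
        await session.execute(
            text(
                "ALTER TABLE alpinebits.reservations "
                "ALTER COLUMN created_at TYPE TIMESTAMP WITH TIME ZONE"
            )
        )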
@@ -74,11 +85,12 @@ async def migrate_datetime_columns(session) -> None:
_LOGGER.info("✓ DateTime columns migrated to timezone-aware")
async def fix_sequences(database_url: str) -> None:
async def fix_sequences(database_url: str, schema_name: str = None) -> None:
"""Fix PostgreSQL sequences to match current max IDs and migrate datetime columns.
Args:
database_url: PostgreSQL database URL
schema_name: Schema name (e.g., "alpinebits") or None for public
"""
_LOGGER.info("=" * 70)
@@ -88,16 +100,27 @@ async def fix_sequences(database_url: str) -> None:
"Database: %s",
database_url.split("@")[-1] if "@" in database_url else database_url,
)
if schema_name:
_LOGGER.info("Schema: %s", schema_name)
_LOGGER.info("=" * 70)
# Create engine and session
engine = create_async_engine(database_url, echo=False)
# Create engine and session with schema support
connect_args = {}
if schema_name:
connect_args = {
"server_settings": {"search_path": f"{schema_name},public"}
}
engine = create_async_engine(database_url, echo=False, connect_args=connect_args)
SessionMaker = async_sessionmaker(engine, expire_on_commit=False)
# Determine schema prefix for SQL statements
schema_prefix = f"{schema_name}." if schema_name else ""
try:
# Migrate datetime columns first
async with SessionMaker() as session:
await migrate_datetime_columns(session)
await migrate_datetime_columns(session, schema_prefix)
# Then fix sequences
async with SessionMaker() as session:
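
Note: asyncpg sends server_settings as session defaults at connection
startup, so every pooled connection begins with the schema first on its
search_path. A quick way to verify, as a sketch with placeholder credentials:

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def show_search_path() -> None:
        engine = create_async_engine(
            "postgresql+asyncpg://user:pass@host/db",  # placeholder URL
            connect_args={"server_settings": {"search_path": "alpinebits,public"}},
        )
        async with engine.connect() as conn:
            # Expected output: alpinebits, public
            print((await conn.execute(text("SHOW search_path"))).scalar())
        await engine.dispose()

    asyncio.run(show_search_path())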
@@ -107,39 +130,43 @@ async def fix_sequences(database_url: str) -> None:
("hashed_customers", "hashed_customers_id_seq"),
("reservations", "reservations_id_seq"),
("acked_requests", "acked_requests_id_seq"),
("conversions", "conversions_id_seq"),
]
_LOGGER.info("\nResetting sequences...")
for table_name, sequence_name in tables:
full_table = f"{schema_prefix}{table_name}"
full_sequence = f"{schema_prefix}{sequence_name}"
# Get current max ID
result = await session.execute(
text(f"SELECT MAX(id) FROM {table_name}")
text(f"SELECT MAX(id) FROM {full_table}")
)
max_id = result.scalar()
# Get current sequence value
result = await session.execute(
text(f"SELECT last_value FROM {sequence_name}")
text(f"SELECT last_value FROM {full_sequence}")
)
current_seq = result.scalar()
if max_id is None:
_LOGGER.info(f" {table_name}: empty table, setting sequence to 1")
_LOGGER.info(f" {full_table}: empty table, setting sequence to 1")
await session.execute(
text(f"SELECT setval('{sequence_name}', 1, false)")
text(f"SELECT setval('{full_sequence}', 1, false)")
)
elif current_seq <= max_id:
new_seq = max_id + 1
_LOGGER.info(
f" {table_name}: max_id={max_id}, "
f" {full_table}: max_id={max_id}, "
f"old_seq={current_seq}, new_seq={new_seq}"
)
await session.execute(
text(f"SELECT setval('{sequence_name}', {new_seq}, false)")
text(f"SELECT setval('{full_sequence}', {new_seq}, false)")
)
else:
_LOGGER.info(
f" {table_name}: sequence already correct "
f" {full_table}: sequence already correct "
f"(max_id={max_id}, seq={current_seq})"
)
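
Note: the branching above exists mainly for per-table logging; the reset rule
itself collapses to one statement per table. A sketch of the equivalent for
one table from the list (setval with is_called = false makes the next
nextval() return exactly the value passed in):

    from sqlalchemy import text

    async def reset_one(session) -> None:
        # COALESCE covers the empty-table branch: MAX(id) is NULL -> 0 + 1 = 1,
        # matching the setval('...', 1, false) call above.
        await session.execute(
            text(
                "SELECT setval('alpinebits.reservations_id_seq', "
                "COALESCE((SELECT MAX(id) FROM alpinebits.reservations), 0) + 1, "
                "false)"
            )
        )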
@@ -191,8 +218,11 @@ async def main():
         config = {}
 
     # Determine database URL (same logic as migrate_sqlite_to_postgres)
+    schema_name = None
     if args.database_url:
         database_url = args.database_url
+        # Get schema from default config if available
+        schema_name = get_database_schema(config)
     elif args.config:
         # Load config file manually (simpler YAML without secrets)
         _LOGGER.info("Loading database config from: %s", args.config)
@@ -201,6 +231,7 @@ async def main():
             config_text = config_path.read_text()
             target_config = yaml.safe_load(config_text)
             database_url = target_config["database"]["url"]
+            schema_name = target_config.get("database", {}).get("schema")
             _LOGGER.info("Successfully loaded config")
         except (FileNotFoundError, ValueError, KeyError):
             _LOGGER.exception("Failed to load config")
@@ -213,6 +244,8 @@ async def main():
     if not database_url:
         # Try from default config
         database_url = get_database_url(config)
+        # Get schema from config or environment
+        schema_name = get_database_schema(config)
 
     if "postgresql" not in database_url and "postgres" not in database_url:
         _LOGGER.error("This script only works with PostgreSQL databases.")
@@ -225,7 +258,7 @@ async def main():
         sys.exit(1)
 
     # Run the fix
-    await fix_sequences(database_url)
+    await fix_sequences(database_url, schema_name)
 
 
 if __name__ == "__main__":
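
Note: since the schema now threads through as a plain argument, the fixer can
also be driven programmatically rather than via the CLI. A sketch with
placeholder URL and schema:

    import asyncio

    from alpine_bits_python.util.fix_postgres_sequences import fix_sequences

    asyncio.run(
        fix_sequences(
            "postgresql+asyncpg://user:pass@host/db",
            schema_name="alpinebits",
        )
    )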