Updated fix sequences script

Jonas Linter
2025-11-04 09:36:22 +01:00
parent eb10e070b1
commit 1f7649fffe


@@ -11,19 +11,24 @@ with explicit IDs, which doesn't automatically advance PostgreSQL sequences.
 The datetime migration ensures proper handling of timezone-aware datetimes,
 which is required by the application code.
 
+Schema Support:
+    The script automatically detects and uses the schema configured in your config file.
+    If you have database.schema: "alpinebits" in your config, it will work with that schema.
+
 Usage:
-    # Using default config.yaml
+    # Using default config.yaml (includes schema if configured)
     uv run python -m alpine_bits_python.util.fix_postgres_sequences
 
-    # Using a specific config file
+    # Using a specific config file (with schema support)
     uv run python -m alpine_bits_python.util.fix_postgres_sequences \
        --config config/postgres.yaml
 
-    # Using DATABASE_URL environment variable
+    # Using DATABASE_URL environment variable (schema from config or DATABASE_SCHEMA env var)
    DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \
+    DATABASE_SCHEMA="alpinebits" \
    uv run python -m alpine_bits_python.util.fix_postgres_sequences
 
-    # Using command line argument
+    # Using command line argument (schema from config)
    uv run python -m alpine_bits_python.util.fix_postgres_sequences \
        --database-url postgresql+asyncpg://user:pass@host/db
 """
@@ -42,16 +47,21 @@ from sqlalchemy import text
 from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
 
 from alpine_bits_python.config_loader import load_config
-from alpine_bits_python.db import get_database_url
+from alpine_bits_python.db import get_database_schema, get_database_url
 from alpine_bits_python.logging_config import get_logger, setup_logging
 
 _LOGGER = get_logger(__name__)
 
 
-async def migrate_datetime_columns(session) -> None:
+async def migrate_datetime_columns(session, schema_prefix: str = "") -> None:
     """Migrate DateTime columns to TIMESTAMP WITH TIME ZONE.
 
     This updates the columns to properly handle timezone-aware datetimes.
+
+    Args:
+        session: Database session
+        schema_prefix: Schema prefix (e.g., "alpinebits." or "")
     """
     _LOGGER.info("\nMigrating DateTime columns to timezone-aware...")
@@ -62,10 +72,11 @@ async def migrate_datetime_columns(session) -> None:
     ]
 
     for table_name, column_name in datetime_columns:
-        _LOGGER.info(f" {table_name}.{column_name}: Converting to TIMESTAMPTZ")
+        full_table = f"{schema_prefix}{table_name}"
+        _LOGGER.info(f" {full_table}.{column_name}: Converting to TIMESTAMPTZ")
         await session.execute(
             text(
-                f"ALTER TABLE {table_name} "
+                f"ALTER TABLE {full_table} "
                 f"ALTER COLUMN {column_name} TYPE TIMESTAMP WITH TIME ZONE"
             )
         )
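Worth noting: the f-string above renders to one statement per (table, column) pair, along these lines (table and column names are illustrative here, since the datetime_columns list is truncated at the hunk boundary):

ALTER TABLE alpinebits.reservations ALTER COLUMN created_at TYPE TIMESTAMP WITH TIME ZONE

Because no USING clause is given, PostgreSQL reinterprets the existing naive values in the session's TimeZone setting, so the conversion is only lossless if that setting matches the timezone the naive timestamps were written in.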
@@ -74,11 +85,12 @@ async def migrate_datetime_columns(session) -> None:
_LOGGER.info("✓ DateTime columns migrated to timezone-aware") _LOGGER.info("✓ DateTime columns migrated to timezone-aware")
async def fix_sequences(database_url: str) -> None: async def fix_sequences(database_url: str, schema_name: str = None) -> None:
"""Fix PostgreSQL sequences to match current max IDs and migrate datetime columns. """Fix PostgreSQL sequences to match current max IDs and migrate datetime columns.
Args: Args:
database_url: PostgreSQL database URL database_url: PostgreSQL database URL
schema_name: Schema name (e.g., "alpinebits") or None for public
""" """
_LOGGER.info("=" * 70) _LOGGER.info("=" * 70)
@@ -88,16 +100,27 @@ async def fix_sequences(database_url: str) -> None:
         "Database: %s",
         database_url.split("@")[-1] if "@" in database_url else database_url,
     )
+    if schema_name:
+        _LOGGER.info("Schema: %s", schema_name)
     _LOGGER.info("=" * 70)
 
-    # Create engine and session
-    engine = create_async_engine(database_url, echo=False)
+    # Create engine and session with schema support
+    connect_args = {}
+    if schema_name:
+        connect_args = {
+            "server_settings": {"search_path": f"{schema_name},public"}
+        }
+    engine = create_async_engine(database_url, echo=False, connect_args=connect_args)
     SessionMaker = async_sessionmaker(engine, expire_on_commit=False)
 
+    # Determine schema prefix for SQL statements
+    schema_prefix = f"{schema_name}." if schema_name else ""
+
     try:
         # Migrate datetime columns first
         async with SessionMaker() as session:
-            await migrate_datetime_columns(session)
+            await migrate_datetime_columns(session, schema_prefix)
 
         # Then fix sequences
         async with SessionMaker() as session:
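The connect_args added above rely on asyncpg's server_settings to set search_path per connection; SQLAlchemy passes the dict through to asyncpg untouched. A minimal standalone check, assuming a reachable PostgreSQL instance and placeholder credentials:

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine


async def show_search_path() -> None:
    # Placeholder DSN; substitute real credentials.
    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@localhost/db",
        connect_args={"server_settings": {"search_path": "alpinebits,public"}},
    )
    async with engine.connect() as conn:
        # Should print: alpinebits,public
        print((await conn.execute(text("SHOW search_path"))).scalar())
    await engine.dispose()


asyncio.run(show_search_path())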
@@ -107,39 +130,43 @@ async def fix_sequences(database_url: str) -> None:
                 ("hashed_customers", "hashed_customers_id_seq"),
                 ("reservations", "reservations_id_seq"),
                 ("acked_requests", "acked_requests_id_seq"),
+                ("conversions", "conversions_id_seq"),
             ]
 
             _LOGGER.info("\nResetting sequences...")
 
             for table_name, sequence_name in tables:
+                full_table = f"{schema_prefix}{table_name}"
+                full_sequence = f"{schema_prefix}{sequence_name}"
+
                 # Get current max ID
                 result = await session.execute(
-                    text(f"SELECT MAX(id) FROM {table_name}")
+                    text(f"SELECT MAX(id) FROM {full_table}")
                 )
                 max_id = result.scalar()
 
                 # Get current sequence value
                 result = await session.execute(
-                    text(f"SELECT last_value FROM {sequence_name}")
+                    text(f"SELECT last_value FROM {full_sequence}")
                 )
                 current_seq = result.scalar()
 
                 if max_id is None:
-                    _LOGGER.info(f" {table_name}: empty table, setting sequence to 1")
+                    _LOGGER.info(f" {full_table}: empty table, setting sequence to 1")
                     await session.execute(
-                        text(f"SELECT setval('{sequence_name}', 1, false)")
+                        text(f"SELECT setval('{full_sequence}', 1, false)")
                     )
                 elif current_seq <= max_id:
                     new_seq = max_id + 1
                     _LOGGER.info(
-                        f" {table_name}: max_id={max_id}, "
+                        f" {full_table}: max_id={max_id}, "
                         f"old_seq={current_seq}, new_seq={new_seq}"
                     )
                     await session.execute(
-                        text(f"SELECT setval('{sequence_name}', {new_seq}, false)")
+                        text(f"SELECT setval('{full_sequence}', {new_seq}, false)")
                     )
                 else:
                     _LOGGER.info(
-                        f" {table_name}: sequence already correct "
+                        f" {full_table}: sequence already correct "
                         f"(max_id={max_id}, seq={current_seq})"
                     )
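As an aside, the three branches above could be collapsed into a single round trip per table. A sketch of the equivalent one-statement form (with is_called=false, the next nextval() returns the given value as-is, and COALESCE covers the empty-table case):

from sqlalchemy import text


async def reset_sequence(session, full_table: str, full_sequence: str) -> None:
    # Always resets, even when the sequence is already ahead of MAX(id),
    # which is the one behavioral difference from the loop above.
    await session.execute(
        text(
            f"SELECT setval('{full_sequence}', "
            f"COALESCE((SELECT MAX(id) FROM {full_table}), 0) + 1, false)"
        )
    )

The loop's explicit branches trade that brevity for per-table log output, which is reasonable in a one-off maintenance script.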
@@ -191,8 +218,11 @@ async def main():
         config = {}
 
     # Determine database URL (same logic as migrate_sqlite_to_postgres)
+    schema_name = None
     if args.database_url:
         database_url = args.database_url
+        # Get schema from default config if available
+        schema_name = get_database_schema(config)
     elif args.config:
         # Load config file manually (simpler YAML without secrets)
         _LOGGER.info("Loading database config from: %s", args.config)
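get_database_schema is imported from alpine_bits_python.db but its body is not part of this diff; judging from the docstring's mention of DATABASE_SCHEMA, it presumably behaves along these lines (a hypothetical sketch, not the actual implementation):

import os


def get_database_schema(config: dict) -> str | None:
    # Assumed precedence: config value first, then the environment variable.
    schema = (config.get("database") or {}).get("schema")
    return schema or os.environ.get("DATABASE_SCHEMA")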
@@ -201,6 +231,7 @@ async def main():
             config_text = config_path.read_text()
             target_config = yaml.safe_load(config_text)
             database_url = target_config["database"]["url"]
+            schema_name = target_config.get("database", {}).get("schema")
             _LOGGER.info("Successfully loaded config")
         except (FileNotFoundError, ValueError, KeyError):
             _LOGGER.exception("Failed to load config")
@@ -213,6 +244,8 @@ async def main():
     if not database_url:
         # Try from default config
         database_url = get_database_url(config)
+        # Get schema from config or environment
+        schema_name = get_database_schema(config)
 
     if "postgresql" not in database_url and "postgres" not in database_url:
         _LOGGER.error("This script only works with PostgreSQL databases.")
@@ -225,7 +258,7 @@ async def main():
         sys.exit(1)
 
     # Run the fix
-    await fix_sequences(database_url)
+    await fix_sequences(database_url, schema_name)
 
 
 if __name__ == "__main__":