Updated fix sequences scripts
This commit is contained in:
@@ -11,19 +11,24 @@ with explicit IDs, which doesn't automatically advance PostgreSQL sequences.
|
||||
The datetime migration ensures proper handling of timezone-aware datetimes,
|
||||
which is required by the application code.
|
||||
|
||||
Schema Support:
|
||||
The script automatically detects and uses the schema configured in your config file.
|
||||
If you have database.schema: "alpinebits" in your config, it will work with that schema.
|
||||
|
||||
Usage:
|
||||
# Using default config.yaml
|
||||
# Using default config.yaml (includes schema if configured)
|
||||
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
||||
|
||||
# Using a specific config file
|
||||
# Using a specific config file (with schema support)
|
||||
uv run python -m alpine_bits_python.util.fix_postgres_sequences \
|
||||
--config config/postgres.yaml
|
||||
|
||||
# Using DATABASE_URL environment variable
|
||||
# Using DATABASE_URL environment variable (schema from config or DATABASE_SCHEMA env var)
|
||||
DATABASE_URL="postgresql+asyncpg://user:pass@host/db" \
|
||||
DATABASE_SCHEMA="alpinebits" \
|
||||
uv run python -m alpine_bits_python.util.fix_postgres_sequences
|
||||
|
||||
# Using command line argument
|
||||
# Using command line argument (schema from config)
|
||||
uv run python -m alpine_bits_python.util.fix_postgres_sequences \
|
||||
--database-url postgresql+asyncpg://user:pass@host/db
|
||||
"""
|
||||
@@ -42,16 +47,21 @@ from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from alpine_bits_python.config_loader import load_config
|
||||
from alpine_bits_python.db import get_database_url
|
||||
from alpine_bits_python.db import get_database_schema, get_database_url
|
||||
from alpine_bits_python.logging_config import get_logger, setup_logging
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
async def migrate_datetime_columns(session) -> None:
|
||||
async def migrate_datetime_columns(session, schema_prefix: str = "") -> None:
|
||||
"""Migrate DateTime columns to TIMESTAMP WITH TIME ZONE.
|
||||
|
||||
This updates the columns to properly handle timezone-aware datetimes.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
schema_prefix: Schema prefix (e.g., "alpinebits." or "")
|
||||
|
||||
"""
|
||||
_LOGGER.info("\nMigrating DateTime columns to timezone-aware...")
|
||||
|
||||
@@ -62,10 +72,11 @@ async def migrate_datetime_columns(session) -> None:
|
||||
]
|
||||
|
||||
for table_name, column_name in datetime_columns:
|
||||
_LOGGER.info(f" {table_name}.{column_name}: Converting to TIMESTAMPTZ")
|
||||
full_table = f"{schema_prefix}{table_name}"
|
||||
_LOGGER.info(f" {full_table}.{column_name}: Converting to TIMESTAMPTZ")
|
||||
await session.execute(
|
||||
text(
|
||||
f"ALTER TABLE {table_name} "
|
||||
f"ALTER TABLE {full_table} "
|
||||
f"ALTER COLUMN {column_name} TYPE TIMESTAMP WITH TIME ZONE"
|
||||
)
|
||||
)
|
||||
@@ -74,11 +85,12 @@ async def migrate_datetime_columns(session) -> None:
|
||||
_LOGGER.info("✓ DateTime columns migrated to timezone-aware")
|
||||
|
||||
|
||||
async def fix_sequences(database_url: str) -> None:
|
||||
async def fix_sequences(database_url: str, schema_name: str = None) -> None:
|
||||
"""Fix PostgreSQL sequences to match current max IDs and migrate datetime columns.
|
||||
|
||||
Args:
|
||||
database_url: PostgreSQL database URL
|
||||
schema_name: Schema name (e.g., "alpinebits") or None for public
|
||||
|
||||
"""
|
||||
_LOGGER.info("=" * 70)
|
||||
@@ -88,16 +100,27 @@ async def fix_sequences(database_url: str) -> None:
|
||||
"Database: %s",
|
||||
database_url.split("@")[-1] if "@" in database_url else database_url,
|
||||
)
|
||||
if schema_name:
|
||||
_LOGGER.info("Schema: %s", schema_name)
|
||||
_LOGGER.info("=" * 70)
|
||||
|
||||
# Create engine and session
|
||||
engine = create_async_engine(database_url, echo=False)
|
||||
# Create engine and session with schema support
|
||||
connect_args = {}
|
||||
if schema_name:
|
||||
connect_args = {
|
||||
"server_settings": {"search_path": f"{schema_name},public"}
|
||||
}
|
||||
|
||||
engine = create_async_engine(database_url, echo=False, connect_args=connect_args)
|
||||
SessionMaker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
# Determine schema prefix for SQL statements
|
||||
schema_prefix = f"{schema_name}." if schema_name else ""
|
||||
|
||||
try:
|
||||
# Migrate datetime columns first
|
||||
async with SessionMaker() as session:
|
||||
await migrate_datetime_columns(session)
|
||||
await migrate_datetime_columns(session, schema_prefix)
|
||||
|
||||
# Then fix sequences
|
||||
async with SessionMaker() as session:
|
||||
@@ -107,39 +130,43 @@ async def fix_sequences(database_url: str) -> None:
|
||||
("hashed_customers", "hashed_customers_id_seq"),
|
||||
("reservations", "reservations_id_seq"),
|
||||
("acked_requests", "acked_requests_id_seq"),
|
||||
("conversions", "conversions_id_seq"),
|
||||
]
|
||||
|
||||
_LOGGER.info("\nResetting sequences...")
|
||||
for table_name, sequence_name in tables:
|
||||
full_table = f"{schema_prefix}{table_name}"
|
||||
full_sequence = f"{schema_prefix}{sequence_name}"
|
||||
|
||||
# Get current max ID
|
||||
result = await session.execute(
|
||||
text(f"SELECT MAX(id) FROM {table_name}")
|
||||
text(f"SELECT MAX(id) FROM {full_table}")
|
||||
)
|
||||
max_id = result.scalar()
|
||||
|
||||
# Get current sequence value
|
||||
result = await session.execute(
|
||||
text(f"SELECT last_value FROM {sequence_name}")
|
||||
text(f"SELECT last_value FROM {full_sequence}")
|
||||
)
|
||||
current_seq = result.scalar()
|
||||
|
||||
if max_id is None:
|
||||
_LOGGER.info(f" {table_name}: empty table, setting sequence to 1")
|
||||
_LOGGER.info(f" {full_table}: empty table, setting sequence to 1")
|
||||
await session.execute(
|
||||
text(f"SELECT setval('{sequence_name}', 1, false)")
|
||||
text(f"SELECT setval('{full_sequence}', 1, false)")
|
||||
)
|
||||
elif current_seq <= max_id:
|
||||
new_seq = max_id + 1
|
||||
_LOGGER.info(
|
||||
f" {table_name}: max_id={max_id}, "
|
||||
f" {full_table}: max_id={max_id}, "
|
||||
f"old_seq={current_seq}, new_seq={new_seq}"
|
||||
)
|
||||
await session.execute(
|
||||
text(f"SELECT setval('{sequence_name}', {new_seq}, false)")
|
||||
text(f"SELECT setval('{full_sequence}', {new_seq}, false)")
|
||||
)
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f" {table_name}: sequence already correct "
|
||||
f" {full_table}: sequence already correct "
|
||||
f"(max_id={max_id}, seq={current_seq})"
|
||||
)
|
||||
|
||||
@@ -191,8 +218,11 @@ async def main():
|
||||
config = {}
|
||||
|
||||
# Determine database URL (same logic as migrate_sqlite_to_postgres)
|
||||
schema_name = None
|
||||
if args.database_url:
|
||||
database_url = args.database_url
|
||||
# Get schema from default config if available
|
||||
schema_name = get_database_schema(config)
|
||||
elif args.config:
|
||||
# Load config file manually (simpler YAML without secrets)
|
||||
_LOGGER.info("Loading database config from: %s", args.config)
|
||||
@@ -201,6 +231,7 @@ async def main():
|
||||
config_text = config_path.read_text()
|
||||
target_config = yaml.safe_load(config_text)
|
||||
database_url = target_config["database"]["url"]
|
||||
schema_name = target_config.get("database", {}).get("schema")
|
||||
_LOGGER.info("Successfully loaded config")
|
||||
except (FileNotFoundError, ValueError, KeyError):
|
||||
_LOGGER.exception("Failed to load config")
|
||||
@@ -213,6 +244,8 @@ async def main():
|
||||
if not database_url:
|
||||
# Try from default config
|
||||
database_url = get_database_url(config)
|
||||
# Get schema from config or environment
|
||||
schema_name = get_database_schema(config)
|
||||
|
||||
if "postgresql" not in database_url and "postgres" not in database_url:
|
||||
_LOGGER.error("This script only works with PostgreSQL databases.")
|
||||
@@ -225,7 +258,7 @@ async def main():
|
||||
sys.exit(1)
|
||||
|
||||
# Run the fix
|
||||
await fix_sequences(database_url)
|
||||
await fix_sequences(database_url, schema_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user