Getting closer

This commit is contained in:
Jonas Linter
2025-11-18 13:32:29 +01:00
parent 5a660507d2
commit e7757c8c51
4 changed files with 175 additions and 359 deletions

View File

@@ -27,7 +27,31 @@ from .logging_config import get_logger
_LOGGER = get_logger(__name__)
Base = declarative_base()
# Load schema from config at module level
# This happens once when the module is imported
try:
from .config_loader import load_config
_app_config = load_config()
_SCHEMA = _app_config.get("database", {}).get("schema")
except (FileNotFoundError, KeyError, ValueError, ImportError):
_SCHEMA = None
# If schema isn't in config, try environment variable
if not _SCHEMA:
_SCHEMA = os.environ.get("DATABASE_SCHEMA")
class Base:
"""Base class that applies schema to all tables."""
# # Set schema on all tables if configured
# if _SCHEMA:
# __table_args__ = {"schema": _SCHEMA}
Base = declarative_base(cls=Base)
# Type variable for async functions
T = TypeVar("T")
@@ -60,26 +84,30 @@ def get_database_schema(config=None):
Schema name string, or None if not configured
"""
# Check environment variable first (takes precedence)
schema = os.environ.get("DATABASE_SCHEMA")
if schema:
return schema
# Fall back to config file
if config and "database" in config and "schema" in config["database"]:
return config["database"]["schema"]
return os.environ.get("DATABASE_SCHEMA")
return None
def configure_schema(schema_name=None):
def configure_schema(schema_name):
"""Configure the database schema for all models.
This should be called before creating tables or running migrations.
For PostgreSQL, this sets the schema for all tables.
For other databases, this is a no-op.
IMPORTANT: This must be called BEFORE any models are imported/defined.
It modifies the Base class to apply schema to all tables.
Args:
schema_name: Name of the schema to use (e.g., "alpinebits")
"""
if schema_name:
# Update the schema for all tables in Base metadata
for table in Base.metadata.tables.values():
table.schema = schema_name
# Set __table_args__ on the Base class to apply schema to all tables
Base.__table_args__ = {"schema": _SCHEMA}
def create_database_engine(config=None, echo=False) -> AsyncEngine:
@@ -102,7 +130,7 @@ def create_database_engine(config=None, echo=False) -> AsyncEngine:
database_url = get_database_url(config)
schema_name = get_database_schema(config)
# Configure schema for all models if specified
# # Configure schema for all models if specified
if schema_name:
configure_schema(schema_name)
_LOGGER.info("Configured database schema: %s", schema_name)