Getting closer
This commit is contained in:
@@ -27,7 +27,31 @@ from .logging_config import get_logger
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# Load schema from config at module level
|
||||
# This happens once when the module is imported
|
||||
try:
|
||||
from .config_loader import load_config
|
||||
|
||||
_app_config = load_config()
|
||||
_SCHEMA = _app_config.get("database", {}).get("schema")
|
||||
except (FileNotFoundError, KeyError, ValueError, ImportError):
|
||||
_SCHEMA = None
|
||||
|
||||
# If schema isn't in config, try environment variable
|
||||
if not _SCHEMA:
|
||||
_SCHEMA = os.environ.get("DATABASE_SCHEMA")
|
||||
|
||||
|
||||
class Base:
|
||||
"""Base class that applies schema to all tables."""
|
||||
|
||||
# # Set schema on all tables if configured
|
||||
# if _SCHEMA:
|
||||
# __table_args__ = {"schema": _SCHEMA}
|
||||
|
||||
|
||||
Base = declarative_base(cls=Base)
|
||||
|
||||
# Type variable for async functions
|
||||
T = TypeVar("T")
|
||||
@@ -60,26 +84,30 @@ def get_database_schema(config=None):
|
||||
Schema name string, or None if not configured
|
||||
|
||||
"""
|
||||
# Check environment variable first (takes precedence)
|
||||
schema = os.environ.get("DATABASE_SCHEMA")
|
||||
if schema:
|
||||
return schema
|
||||
# Fall back to config file
|
||||
if config and "database" in config and "schema" in config["database"]:
|
||||
return config["database"]["schema"]
|
||||
return os.environ.get("DATABASE_SCHEMA")
|
||||
return None
|
||||
|
||||
|
||||
def configure_schema(schema_name=None):
|
||||
def configure_schema(schema_name):
|
||||
"""Configure the database schema for all models.
|
||||
|
||||
This should be called before creating tables or running migrations.
|
||||
For PostgreSQL, this sets the schema for all tables.
|
||||
For other databases, this is a no-op.
|
||||
IMPORTANT: This must be called BEFORE any models are imported/defined.
|
||||
It modifies the Base class to apply schema to all tables.
|
||||
|
||||
Args:
|
||||
schema_name: Name of the schema to use (e.g., "alpinebits")
|
||||
|
||||
"""
|
||||
if schema_name:
|
||||
# Update the schema for all tables in Base metadata
|
||||
for table in Base.metadata.tables.values():
|
||||
table.schema = schema_name
|
||||
# Set __table_args__ on the Base class to apply schema to all tables
|
||||
|
||||
Base.__table_args__ = {"schema": _SCHEMA}
|
||||
|
||||
|
||||
def create_database_engine(config=None, echo=False) -> AsyncEngine:
|
||||
@@ -102,7 +130,7 @@ def create_database_engine(config=None, echo=False) -> AsyncEngine:
|
||||
database_url = get_database_url(config)
|
||||
schema_name = get_database_schema(config)
|
||||
|
||||
# Configure schema for all models if specified
|
||||
# # Configure schema for all models if specified
|
||||
if schema_name:
|
||||
configure_schema(schema_name)
|
||||
_LOGGER.info("Configured database schema: %s", schema_name)
|
||||
|
||||
Reference in New Issue
Block a user