Files
DosVault/src/migrate.py
2025-09-06 18:51:10 -04:00

147 lines
5.0 KiB
Python
Executable File

#!/usr/bin/env python
"""
Database migration management script.
"""
import sys
import argparse
from pathlib import Path
from alembic.config import Config
from alembic import command
from sqlalchemy import create_engine, inspect
# Add current directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from libs.config import Config as AppConfig
from libs.database import Base
def get_alembic_config():
    """Build the Alembic configuration object for this project's database.

    Loads ``alembic.ini`` from the parent of this script's directory and
    overrides its ``sqlalchemy.url`` with a SQLite URL built from the
    application's configured ``database_path``.

    Returns:
        alembic.config.Config: configuration passed to ``alembic.command`` calls.
    """
    alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
    app_config = AppConfig()
    url = f"sqlite:///{app_config.database_path}"
    # set_main_option() stores the value through ConfigParser interpolation,
    # so a literal "%" in the path must be doubled or Alembic raises ValueError.
    # Reading the option back (get_main_option/get_section) un-escapes it.
    alembic_cfg.set_main_option("sqlalchemy.url", url.replace("%", "%%"))
    # NOTE(review): if alembic.ini uses a relative script_location (not based on
    # %(here)s), it is resolved against the current working directory — confirm.
    return alembic_cfg
def init_database():
    """Create the tables declared on ``Base.metadata`` directly, bypassing Alembic.

    Intended for first-time setup. The database is NOT stamped with an Alembic
    revision here; run the ``stamp`` command afterwards to put it under Alembic
    control.
    """
    app_config = AppConfig()
    # SQLite creates the database file on first connect, but not any missing
    # parent directories, so make sure the containing directory exists.
    Path(app_config.database_path).parent.mkdir(parents=True, exist_ok=True)
    engine = create_engine(f"sqlite:///{app_config.database_path}")
    try:
        Base.metadata.create_all(engine)
    finally:
        # Release pooled connections (and the SQLite file handle) held by the engine.
        engine.dispose()
    print(f"Database initialized at {app_config.database_path}")
def create_migration(message: str):
    """Autogenerate a new Alembic revision script and report it.

    Args:
        message: description recorded with the new revision.
    """
    command.revision(get_alembic_config(), message=message, autogenerate=True)
    print(f"Created migration: {message}")
def upgrade_database(revision: str = "head"):
    """Apply migrations forward to *revision* and report the target.

    Args:
        revision: Alembic revision identifier; ``"head"`` (the default) means
            the newest revision.
    """
    cfg = get_alembic_config()
    target = revision
    command.upgrade(cfg, target)
    print(f"Database upgraded to {target}")
def downgrade_database(revision: str):
    """Revert migrations back to *revision* and report the target.

    Args:
        revision: Alembic revision identifier to downgrade to (required).
    """
    cfg = get_alembic_config()
    target = revision
    command.downgrade(cfg, target)
    print(f"Database downgraded to {target}")
def show_history():
    """Print the list of known migration revisions via Alembic's history command."""
    command.history(get_alembic_config())
def show_current():
    """Print the revision the database is currently at via Alembic's current command."""
    command.current(get_alembic_config())
def stamp_database(revision: str = "head"):
    """Record *revision* as the database's current revision without running migrations.

    Args:
        revision: Alembic revision identifier to stamp; defaults to ``"head"``.
    """
    command.stamp(get_alembic_config(), revision)
    print(f"Database stamped at {revision}")
def check_database_exists():
    """Report whether the database file exists and is under Alembic control.

    Prints a one-line status message in every case.

    Returns:
        bool: True only if the database file exists and contains an
        ``alembic_version`` table; False otherwise.
    """
    app_config = AppConfig()
    db_path = Path(app_config.database_path)
    if not db_path.exists():
        print("Database does not exist.")
        return False

    # Alembic records the applied revision in the alembic_version table, so its
    # presence is used as the "under Alembic control" signal.
    engine = create_engine(f"sqlite:///{app_config.database_path}")
    try:
        tables = inspect(engine).get_table_names()
    finally:
        # Close the connection opened for inspection instead of leaking it.
        engine.dispose()

    if "alembic_version" not in tables:
        print("Database exists but is not under Alembic control.")
        return False
    print("Database exists and is under Alembic control.")
    return True
def main():
    """Command-line entry point for database migration management.

    Parses ``sys.argv`` and runs the selected sub-command. With no sub-command,
    the help text is printed and the function returns. Any exception raised by
    a command is reported on stderr as ``Error: <message>`` and the process
    exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Database migration management")
    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # Each sub-command registers its own handler with set_defaults(handler=...),
    # so dispatch below is a single call rather than an if/elif chain that must
    # be kept in sync with these parser definitions.

    # Init command
    init_parser = subparsers.add_parser('init', help='Initialize database (for first-time setup)')
    init_parser.set_defaults(handler=lambda ns: init_database())

    # Stamp command
    stamp_parser = subparsers.add_parser('stamp', help='Mark database as being at a specific revision')
    stamp_parser.add_argument('revision', nargs='?', default='head', help='Revision to stamp (default: head)')
    stamp_parser.set_defaults(handler=lambda ns: stamp_database(ns.revision))

    # Create migration command
    create_parser = subparsers.add_parser('create', help='Create a new migration')
    create_parser.add_argument('message', help='Migration message')
    create_parser.set_defaults(handler=lambda ns: create_migration(ns.message))

    # Upgrade command
    upgrade_parser = subparsers.add_parser('upgrade', help='Upgrade database')
    upgrade_parser.add_argument('revision', nargs='?', default='head', help='Target revision (default: head)')
    upgrade_parser.set_defaults(handler=lambda ns: upgrade_database(ns.revision))

    # Downgrade command
    downgrade_parser = subparsers.add_parser('downgrade', help='Downgrade database')
    downgrade_parser.add_argument('revision', help='Target revision')
    downgrade_parser.set_defaults(handler=lambda ns: downgrade_database(ns.revision))

    # History command
    history_parser = subparsers.add_parser('history', help='Show migration history')
    history_parser.set_defaults(handler=lambda ns: show_history())

    # Current command
    current_parser = subparsers.add_parser('current', help='Show current database revision')
    current_parser.set_defaults(handler=lambda ns: show_current())

    # Check command (status is printed; the boolean result does not affect the exit code)
    check_parser = subparsers.add_parser('check', help='Check database status')
    check_parser.set_defaults(handler=lambda ns: check_database_exists())

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return

    try:
        args.handler(args)
    except Exception as e:
        # Top-level CLI boundary: report the failure on stderr (stdout stays
        # reserved for normal command output) and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()