Source code for advanced_alchemy.cli

import sys
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union, cast

if TYPE_CHECKING:
    from alembic.migration import MigrationContext
    from alembic.operations.ops import MigrationScript, UpgradeOps
    from click import Group

    from advanced_alchemy.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig

# Names exported by `from advanced_alchemy.cli import *`: the two CLI factory functions below.
__all__ = ("add_migration_commands", "get_alchemy_group")


def get_alchemy_group() -> "Group":
    """Get the Advanced Alchemy CLI group.

    The returned ``alchemy`` group takes a required ``--config`` option holding
    a dotted import path. The imported object is stored on the click context as
    ``ctx.obj["configs"]``; a non-sequence object is wrapped in a one-element
    list so that commands can always iterate over the configs.

    Raises:
        MissingDependencyError: If the `click` package is not installed.

    Returns:
        The Advanced Alchemy CLI group.
    """
    from advanced_alchemy.exceptions import MissingDependencyError

    # Prefer rich_click for nicer help output and fall back to plain click.
    try:
        import rich_click as click
    except ImportError:
        try:
            import click  # type: ignore[no-redef]
        except ImportError as e:
            raise MissingDependencyError(package="click", install_package="cli") from e

    @click.group(name="alchemy")
    @click.option(
        "--config",
        help="Dotted path to SQLAlchemy config(s) (e.g. 'myapp.config.alchemy_configs')",
        required=True,
        type=str,
    )
    @click.pass_context
    def alchemy_group(ctx: "click.Context", config: str) -> None:
        """Advanced Alchemy CLI commands."""
        from pathlib import Path

        from rich import get_console

        from advanced_alchemy.utils import module_loader

        console = get_console()
        ctx.ensure_object(dict)

        # Add current working directory to sys.path to allow loading local config modules.
        # Remember whether *we* inserted it so the cleanup below never strips an
        # entry that was already present before this command ran.
        cwd = str(Path.cwd())
        cwd_added = cwd not in sys.path
        if cwd_added:
            sys.path.insert(0, cwd)

        try:
            config_instance = module_loader.import_string(config)
            if isinstance(config_instance, Sequence):
                ctx.obj["configs"] = config_instance
            else:
                # Normalize a single config to a list so consumers can iterate.
                ctx.obj["configs"] = [config_instance]
        except ImportError as e:
            console.print(f"[red]Error loading config: {e}[/]")
            ctx.exit(1)
        finally:
            # Clean up: remove the cwd from sys.path only if this call added it
            # and it is still the leading entry.
            if cwd_added and sys.path and sys.path[0] == cwd:
                sys.path.remove(cwd)

    return alchemy_group
def add_migration_commands(database_group: Optional["Group"] = None) -> "Group":  # noqa: C901, PLR0915
    """Add migration commands to the database group.

    Every command resolves its SQLAlchemy config from ``ctx.obj["configs"]``
    (optionally narrowed with ``--bind-key``) and delegates to
    ``AlembicCommands`` or to the helpers in ``advanced_alchemy.alembic.utils``.

    Args:
        database_group: The database group to add the commands to. When
            ``None``, the group returned by :func:`get_alchemy_group` is used.

    Raises:
        MissingDependencyError: If the `click` package is not installed.

    Returns:
        The database group with the migration commands added.
    """
    from advanced_alchemy.exceptions import MissingDependencyError

    # Prefer rich_click for nicer help output and fall back to plain click.
    try:
        import rich_click as click
    except ImportError:
        try:
            import click  # type: ignore[no-redef]
        except ImportError as e:
            raise MissingDependencyError(package="click", install_package="cli") from e
    from rich import get_console

    console = get_console()

    if database_group is None:
        database_group = get_alchemy_group()

    # Options shared by several commands, defined once and reused as decorators.
    bind_key_option = click.option(
        "--bind-key",
        help="Specify which SQLAlchemy config to use by bind key",
        type=str,
        default=None,
    )
    verbose_option = click.option(
        "--verbose",
        help="Enable verbose output.",
        type=bool,
        default=False,
        is_flag=True,
    )
    no_prompt_option = click.option(
        "--no-prompt",
        help="Do not prompt for confirmation before executing the command.",
        type=bool,
        default=False,
        required=False,
        show_default=True,
        is_flag=True,
    )

    def get_config_by_bind_key(
        ctx: "click.Context", bind_key: Optional[str]
    ) -> "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]":
        """Get the SQLAlchemy config for the specified bind key.

        Args:
            ctx: The click context.
            bind_key: The bind key to get the config for. ``None`` selects the
                first registered config.

        Returns:
            The SQLAlchemy config for the specified bind key. If no config
            matches, an error is printed and the process exits with status 1.
        """
        configs = ctx.obj["configs"]
        if bind_key is None:
            return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", configs[0])

        for config in configs:
            if config.bind_key == bind_key:
                return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", config)

        console.print(f"[red]No config found for bind key: {bind_key}[/]")
        sys.exit(1)

    @database_group.command(
        name="show-current-revision",
        help="Shows the current revision for the database.",
    )
    @bind_key_option
    @verbose_option
    def show_database_revision(bind_key: Optional[str], verbose: bool) -> None:  # pyright: ignore[reportUnusedFunction]
        """Show current database revision."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Listing current revision[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.current(verbose=verbose)

    @database_group.command(
        name="downgrade",
        help="Downgrade database to a specific revision.",
    )
    @bind_key_option
    @click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
    @click.option(
        "--tag",
        help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
        type=str,
        default=None,
    )
    @no_prompt_option
    @click.argument(
        "revision",
        type=str,
        default="-1",
    )
    def downgrade_database(  # pyright: ignore[reportUnusedFunction]
        bind_key: Optional[str], revision: str, sql: bool, tag: Optional[str], no_prompt: bool
    ) -> None:
        """Downgrade the database to the given revision (defaults to ``-1``)."""
        from rich.prompt import Confirm

        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Starting database downgrade process[/]", align="left")
        input_confirmed = (
            True
            if no_prompt
            else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?")
        )
        if input_confirmed:
            sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
            alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
            alembic_commands.downgrade(revision=revision, sql=sql, tag=tag)

    @database_group.command(
        name="upgrade",
        help="Upgrade database to a specific revision.",
    )
    @bind_key_option
    @click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
    @click.option(
        "--tag",
        help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
        type=str,
        default=None,
    )
    @no_prompt_option
    @click.argument(
        "revision",
        type=str,
        default="head",
    )
    def upgrade_database(  # pyright: ignore[reportUnusedFunction]
        bind_key: Optional[str], revision: str, sql: bool, tag: Optional[str], no_prompt: bool
    ) -> None:
        """Upgrade the database to the given revision (defaults to ``head``)."""
        from rich.prompt import Confirm

        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Starting database upgrade process[/]", align="left")
        input_confirmed = (
            True
            if no_prompt
            else Confirm.ask(f"[bold]Are you sure you want migrate the database to the `{revision}` revision?[/]")
        )
        if input_confirmed:
            sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
            alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
            alembic_commands.upgrade(revision=revision, sql=sql, tag=tag)

    @database_group.command(
        help="Stamp the revision table with the given revision; don't run any migrations",
    )
    @click.argument("revision", type=str)
    @bind_key_option
    @click.option("--sql", is_flag=True, default=False, help="Generate SQL output for offline migrations")
    @click.option(
        "--tag", type=str, default=None, help="Arbitrary 'tag' that can be intercepted by custom env.py scripts"
    )
    @click.option("--purge", is_flag=True, default=False, help="Delete all entries in version table before stamping")
    def stamp(bind_key: Optional[str], revision: str, sql: bool, tag: Optional[str], purge: bool) -> None:  # pyright: ignore[reportUnusedFunction]
        """Stamp the revision table with the given revision."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Stamping revision table[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.stamp(revision=revision, sql=sql, tag=tag, purge=purge)

    @database_group.command(
        name="check",
        help="Check if the target database is up to date",
    )
    @bind_key_option
    def check_revision(bind_key: Optional[str]) -> None:  # pyright: ignore[reportUnusedFunction]
        """Check for pending upgrade operations."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Checking for pending migrations[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.check()

    @database_group.command(
        name="edit",
        help="Edit a revision file using $EDITOR",
    )
    @click.argument("revision", type=str)
    @bind_key_option
    def edit_revision(bind_key: Optional[str], revision: str) -> None:  # pyright: ignore[reportUnusedFunction]
        """Edit revision script with system editor."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule(f"[yellow]Opening revision {revision} in editor[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.edit(revision=revision)

    @database_group.command(
        name="ensure-version",
        help="Create the alembic version table if it doesn't exist",
    )
    @bind_key_option
    @click.option("--sql", is_flag=True, default=False, help="Generate SQL output instead of executing")
    def ensure_version_table(bind_key: Optional[str], sql: bool) -> None:  # pyright: ignore[reportUnusedFunction]
        """Ensure alembic version table exists."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Ensuring version table exists[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.ensure_version(sql=sql)

    @database_group.command(
        name="heads",
        help="Show current available heads in the script directory",
    )
    @bind_key_option
    @verbose_option
    @click.option(
        "--resolve-dependencies",
        is_flag=True,
        default=False,
        help="Resolve dependencies between heads",
    )
    def show_heads(bind_key: Optional[str], verbose: bool, resolve_dependencies: bool) -> None:  # pyright: ignore[reportUnusedFunction]
        """Show current heads."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Showing current heads[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.heads(verbose=verbose, resolve_dependencies=resolve_dependencies)

    @database_group.command(
        name="history",
        help="List changeset scripts in chronological order",
    )
    @bind_key_option
    @verbose_option
    @click.option(
        "--rev-range",
        type=str,
        default=None,
        help="Revision range (e.g., 'base:head', 'abc:def')",
    )
    @click.option(
        "--indicate-current",
        is_flag=True,
        default=False,
        help="Indicate the current revision",
    )
    def show_history(  # pyright: ignore[reportUnusedFunction]
        bind_key: Optional[str],
        verbose: bool,
        rev_range: Optional[str],
        indicate_current: bool,
    ) -> None:
        """Show revision history."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Showing revision history[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.history(
            rev_range=rev_range,
            verbose=verbose,
            indicate_current=indicate_current,
        )

    @database_group.command(
        name="merge",
        help="Merge two revisions together, creating a new migration file",
    )
    @click.argument("revisions", type=str)
    @bind_key_option
    @click.option("-m", "--message", type=str, default=None, help="Merge message")
    @click.option("--branch-label", type=str, default=None, help="Branch label for merge revision")
    @click.option("--rev-id", type=str, default=None, help="Specify custom revision ID")
    @no_prompt_option
    def merge_revisions(  # pyright: ignore[reportUnusedFunction]
        bind_key: Optional[str],
        revisions: str,
        message: Optional[str],
        branch_label: Optional[str],
        rev_id: Optional[str],
        no_prompt: bool,
    ) -> None:
        """Merge revisions (resolves multiple heads)."""
        from rich.prompt import Prompt

        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Merging revisions[/]", align="left")

        # Without an explicit message, either ask for one or use a fixed default.
        if message is None:
            message = "merge revisions" if no_prompt else Prompt.ask("Enter merge message")

        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.merge(
            revisions=revisions,
            message=message,
            branch_label=branch_label,
            rev_id=rev_id,
        )

    @database_group.command(
        name="show",
        help="Show the revision denoted by the given symbol",
    )
    @click.argument("revision", type=str)
    @bind_key_option
    def show_revision(bind_key: Optional[str], revision: str) -> None:  # pyright: ignore[reportUnusedFunction]
        """Show details of a specific revision."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule(f"[yellow]Showing revision {revision}[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.show(rev=revision)

    @database_group.command(
        name="branches",
        help="Show current branch points in the migration history",
    )
    @bind_key_option
    @verbose_option
    def show_branches(bind_key: Optional[str], verbose: bool) -> None:  # pyright: ignore[reportUnusedFunction]
        """Show branch points."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Showing branch points[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.branches(verbose=verbose)

    @database_group.command(
        name="list-templates",
        help="List available Alembic migration templates",
    )
    @bind_key_option
    def list_init_templates(bind_key: Optional[str]) -> None:  # pyright: ignore[reportUnusedFunction]
        """List available initialization templates."""
        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Available templates[/]", align="left")
        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.list_templates()

    @database_group.command(
        name="init",
        help="Initialize migrations for the project.",
    )
    @bind_key_option
    @click.argument(
        "directory",
        default=None,
        required=False,
    )
    @click.option("--multidb", is_flag=True, default=False, help="Support multiple databases")
    # NOTE(review): a single-name flag with default=True may not let users toggle
    # this value as the help text suggests; a `--package/--no-package` pair looks
    # like the intent -- confirm against the installed click version.
    @click.option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder")
    @no_prompt_option
    def init_alembic(  # pyright: ignore[reportUnusedFunction]
        bind_key: Optional[str], directory: Optional[str], multidb: bool, package: bool, no_prompt: bool
    ) -> None:
        """Initialize the database migrations.

        With ``--bind-key`` only that config is initialized; otherwise every
        registered config is. When no ``directory`` argument is given, each
        config is initialized in its own ``alembic_config.script_location``.
        """
        from rich.prompt import Confirm

        from advanced_alchemy.alembic.commands import AlembicCommands

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Initializing database migrations.", align="left")
        input_confirmed = (
            True if no_prompt else Confirm.ask("[bold]Are you sure you want initialize migrations for the project?[/]")
        )
        if input_confirmed:
            configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
            for config in configs:
                # Resolve the target per config. Reassigning `directory` here would
                # make every later config reuse the first config's script location.
                target_directory = config.alembic_config.script_location if directory is None else directory
                alembic_commands = AlembicCommands(sqlalchemy_config=config)
                alembic_commands.init(directory=cast("str", target_directory), multidb=multidb, package=package)

    @database_group.command(
        name="make-migrations",
        help="Create a new migration revision.",
    )
    @bind_key_option
    @click.option("-m", "--message", default=None, help="Revision message")
    @click.option(
        "--autogenerate/--no-autogenerate", default=True, help="Automatically populate revision with detected changes"
    )
    @click.option("--sql", is_flag=True, default=False, help="Export to `.sql` instead of writing to the database.")
    @click.option("--head", default="head", help="Specify head revision to use as base for new revision.")
    @click.option(
        "--splice", is_flag=True, default=False, help='Allow a non-head revision as the "head" to splice onto'
    )
    @click.option("--branch-label", default=None, help="Specify a branch label to apply to the new revision")
    @click.option("--version-path", default=None, help="Specify specific path from config for version file")
    @click.option("--rev-id", default=None, help="Specify a ID to use for revision.")
    @no_prompt_option
    def create_revision(  # pyright: ignore[reportUnusedFunction]
        bind_key: Optional[str],
        message: Optional[str],
        autogenerate: bool,
        sql: bool,
        head: str,
        splice: bool,
        branch_label: Optional[str],
        version_path: Optional[str],
        rev_id: Optional[str],
        no_prompt: bool,
    ) -> None:
        """Create a new database revision."""
        from rich.prompt import Prompt

        from advanced_alchemy.alembic.commands import AlembicCommands

        def process_revision_directives(
            context: "MigrationContext",  # noqa: ARG001
            revision: tuple[str],  # noqa: ARG001
            directives: list["MigrationScript"],
        ) -> None:
            """Handle revision directives.

            Clears ``directives`` when autogenerate is on and the first
            directive's upgrade operations are empty, after printing why.
            """
            if autogenerate and cast("UpgradeOps", directives[0].upgrade_ops).is_empty():
                console.rule(
                    "[magenta]The generation of a migration file is being skipped because it would result in an empty file.",
                    style="magenta",
                    align="left",
                )
                console.rule(
                    "[magenta]More information can be found here. https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect",
                    style="magenta",
                    align="left",
                )
                console.rule(
                    "[magenta]If you intend to create an empty migration file, use the --no-autogenerate option.",
                    style="magenta",
                    align="left",
                )
                directives.clear()

        ctx = cast("click.Context", click.get_current_context())
        # This command creates a revision file; it does not upgrade the database,
        # so the banner says so instead of reusing the upgrade command's text.
        console.rule("[yellow]Starting database revision process[/]", align="left")
        if message is None:
            message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision")

        sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.revision(
            message=message,
            autogenerate=autogenerate,
            sql=sql,
            head=head,
            splice=splice,
            branch_label=branch_label,
            version_path=version_path,
            rev_id=rev_id,
            process_revision_directives=process_revision_directives,  # type: ignore[arg-type]
        )

    @database_group.command(name="drop-all", help="Drop all tables from the database.")
    @bind_key_option
    @no_prompt_option
    def drop_all(bind_key: Optional[str], no_prompt: bool) -> None:  # pyright: ignore[reportUnusedFunction]
        """Drop all tables from the database."""
        from anyio import run
        from rich.prompt import Confirm

        # This local import rebinds `drop_all` inside the command body to the
        # async helper that performs the actual drop.
        from advanced_alchemy.alembic.utils import drop_all
        from advanced_alchemy.base import metadata_registry

        ctx = cast("click.Context", click.get_current_context())
        console.rule("[yellow]Dropping all tables from the database[/]", align="left")
        input_confirmed = no_prompt or Confirm.ask(
            "[bold red]Are you sure you want to drop all tables from the database?"
        )

        async def _drop_all(
            configs: "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]",
        ) -> None:
            for config in configs:
                engine = config.get_engine()
                await drop_all(engine, config.alembic_config.version_table_name, metadata_registry.get(config.bind_key))

        if input_confirmed:
            configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
            run(_drop_all, configs)

    @database_group.command(name="dump-data", help="Dump specified tables from the database to JSON files.")
    @bind_key_option
    @click.option(
        "--table",
        "table_names",
        help="Name of the table to dump. Multiple tables can be specified. Use '*' to dump all tables.",
        type=str,
        required=True,
        multiple=True,
    )
    @click.option(
        "--dir",
        "dump_dir",
        help="Directory to save the JSON files. Defaults to WORKDIR/fixtures",
        type=click.Path(path_type=Path),
        # Evaluated once, when the commands are registered.
        default=Path.cwd() / "fixtures",
        required=False,
    )
    def dump_table_data(bind_key: Optional[str], table_names: tuple[str, ...], dump_dir: Path) -> None:  # pyright: ignore[reportUnusedFunction]
        """Dump table data to JSON files."""
        from anyio import run
        from rich.prompt import Confirm

        from advanced_alchemy.alembic.utils import dump_tables
        from advanced_alchemy.base import metadata_registry, orm_registry

        ctx = cast("click.Context", click.get_current_context())
        all_tables = "*" in table_names

        # Dumping everything requires explicit confirmation.
        if all_tables and not Confirm.ask(
            "[yellow bold]You have specified '*'. Are you sure you want to dump all tables from the database?",
        ):
            return console.rule("[red bold]No data was dumped.", style="red", align="left")

        async def _dump_tables() -> None:
            configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
            for config in configs:
                target_tables = set(metadata_registry.get(config.bind_key).tables)

                if not all_tables:
                    # Report requested tables that this config's metadata does not know.
                    for table_name in set(table_names) - target_tables:
                        console.rule(
                            f"[red bold]Skipping table '{table_name}' because it is not available in the default registry",
                            style="red",
                            align="left",
                        )
                    target_tables.intersection_update(table_names)
                else:
                    console.rule("[yellow bold]Dumping all tables", style="yellow", align="left")

                models = [
                    mapper.class_ for mapper in orm_registry.mappers if mapper.class_.__table__.name in target_tables
                ]
                await dump_tables(dump_dir, config.get_session(), models)
                console.rule("[green bold]Data dump complete", align="left")

        return run(_dump_tables)

    return database_group