"""Source code for architxt.cli."""

import shutil
import subprocess
import warnings
from contextlib import AbstractContextManager, nullcontext
from enum import Enum
from functools import partial
from pathlib import Path

import anyio
import mlflow
import more_itertools
import typer
from mlflow.data.code_dataset_source import CodeDatasetSource
from mlflow.data.meta_dataset import MetaDataset
from platformdirs import user_cache_path
from rich.columns import Columns
from rich.panel import Panel
from rich.table import Table
from typer.main import get_command

from architxt.bucket import TreeBucket
from architxt.bucket.zodb import ZODBTreeBucket
from architxt.generator import gen_instance
from architxt.inspector import ForestInspector
from architxt.llm import get_chat_model
from architxt.metrics import Metrics, redundancy_score
from architxt.schema import Group, Relation, Schema
from architxt.similarity import DECAY
from architxt.simplification.tree_rewriting import rewrite

from .export import app as export_app
from .loader import app as loader_app
from .utils import (
    console,
    get_schema_metrics,
    init_forest,
    load_forest,
    show_metrics,
    show_schema,
    show_valid_trees_metrics,
)

# Root Typer application. Sub-command groups are mounted below; commands
# defined in this module attach themselves via the @app.command decorators.
app = typer.Typer(
    help="ArchiTXT is a tool for structuring textual data into a valid database model. "
    "It is guided by a meta-grammar and uses an iterative process of tree rewriting.",
    no_args_is_help=True,
)

# Mount the data-loading and export sub-command groups.
app.add_typer(loader_app, name="load")
app.add_typer(export_app, name="export")


@app.callback()
def mlflow_setup() -> None:
    """Select the MLFlow experiment shared by every CLI command."""
    mlflow.set_experiment('ArchiTXT')
@app.command(
    help="Launch the web-based UI.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def ui(ctx: typer.Context) -> None:
    """Launch the web-based UI using Streamlit.

    Extra CLI arguments are forwarded verbatim to ``streamlit run``.
    """
    try:
        from architxt import ui

        # Run Streamlit as a subprocess on the UI module file.
        subprocess.run(['streamlit', 'run', ui.__file__, *ctx.args], check=True)
    except FileNotFoundError as error:
        # The streamlit executable is missing (optional dependency not installed).
        console.print(
            "[red]Streamlit is not installed or not found. Please install it with `pip install architxt[ui]` to use the UI.[/]"
        )
        raise typer.Exit(code=1) from error
@app.command(help="Cleanup a forest retaining only the valid tree structure")
def cleanup(
    files: list[Path] = typer.Argument(..., exists=True, readable=True, help="Path of the data files to load."),
    *,
    tau: float = typer.Option(0.7, help="The similarity threshold.", min=0, max=1),
    decay: float = typer.Option(DECAY, help="The similarity decay factor.", min=0.001),
    output: Path | None = typer.Option(None, help="Path to save the result."),
    metrics: bool = typer.Option(False, help="Show metrics of the simplification."),
    in_memory: bool = typer.Option(False, help="Perform the cleanup in memory."),
) -> None:
    """Load a forest, keep only the trees valid w.r.t. its schema, and save them."""
    # Either keep the working forest in memory (plain list) or back it with ZODB.
    if in_memory:
        warnings.warn(
            "Performing cleanup in memory. This is not recommended for large datasets as it may lead to high memory usage.",
            UserWarning,
        )
        storage_ctx = nullcontext([])
    else:
        storage_ctx = ZODBTreeBucket()

    with (
        storage_ctx as tmp_forest,
        ZODBTreeBucket(storage_path=output) as output_forest,
    ):
        init_forest(tmp_forest, files)

        schema = Schema.from_forest(tmp_forest, keep_unlabelled=False)
        show_schema(schema)

        # Snapshot metrics on the raw forest before filtering, if requested.
        if metrics:
            result_metrics = Metrics(tmp_forest, tau=tau, decay=decay)

        valid_trees = schema.extract_valid_trees(tmp_forest)
        output_forest.update(valid_trees, commit=True)

        if metrics:
            result_metrics.update(output_forest)
            show_metrics(result_metrics)
class SimilarityMode(str, Enum):
    """Similarity computation modes accepted by the ``simplify`` command."""

    instance = "instance"
    schema = "schema"
@app.command(help="Simplify a bunch of databases together.")  # fixed typo: "databased"
def simplify(
    files: list[Path] = typer.Argument(..., exists=True, readable=True, help="Path of the data files to load."),
    *,
    tau: float = typer.Option(0.7, help="The similarity threshold.", min=0, max=1),
    decay: float = typer.Option(DECAY, help="The similarity decay factor.", min=0.001),
    epoch: int = typer.Option(100, help="Number of iteration for tree rewriting.", min=1),
    min_support: int = typer.Option(20, help="Minimum support for tree patterns.", min=1),
    similarity: SimilarityMode = typer.Option(
        SimilarityMode.instance,
        help="Mode for similarity calculation. 'instance' uses instance-level similarity, 'schema' uses schema-level similarity.",
    ),
    workers: int | None = typer.Option(
        None, help="Number of parallel worker processes to use. Defaults to the number of available CPU cores.", min=1
    ),
    output: Path | None = typer.Option(None, help="Path to save the result."),
    debug: bool = typer.Option(False, help="Enable debug mode for more verbose output."),
    metrics: bool = typer.Option(False, help="Show metrics of the simplification."),
    log: bool = typer.Option(False, help="Enable logging to MLFlow."),
    log_system_metrics: bool = typer.Option(False, help="Enable logging of system metrics to MLFlow."),
    in_memory: bool = typer.Option(False, help="Perform the simplification in memory."),
) -> None:
    """Run the iterative tree-rewriting simplification over the loaded forest.

    Loads ``files`` into a forest, rewrites it for ``epoch`` iterations, then
    displays the resulting schema and (optionally) the simplification metrics.
    """
    # Either keep the working forest in memory (plain list) or back it with ZODB.
    if in_memory:
        warnings.warn(
            "Performing simplification in memory. This is not recommended for large datasets as it may lead to high memory usage.",
            UserWarning,
        )
        storage_ctx = nullcontext([])
    else:
        storage_ctx = ZODBTreeBucket(storage_path=output)

    run_ctx: AbstractContextManager = nullcontext()
    if log:
        # Fixed message typo: "will be send" -> "will be sent".
        console.print(f'[green]MLFlow logging enabled. Logs will be sent to {mlflow.get_tracking_uri()}[/]')
        run_ctx = mlflow.start_run(description='simplification', log_system_metrics=log_system_metrics)
        for file in files:
            mlflow.log_input(MetaDataset(CodeDatasetSource({}), name=file.name))

    with run_ctx, storage_ctx as forest:
        init_forest(forest, files)
        console.print(
            f'[blue]Rewriting {len(forest)} trees with tau={tau}, decay={decay}, epoch={epoch}, min_support={min_support}[/]'
        )

        result_metrics = rewrite(
            forest,
            tau=tau,
            decay=decay,
            epoch=epoch,
            min_support=min_support,
            schema_similarity=(similarity == SimilarityMode.schema),
            debug=debug,
            max_workers=workers,
        )

        # Generate schema
        schema = Schema.from_forest(forest, keep_unlabelled=False)
        show_schema(schema)

        if metrics:
            show_metrics(result_metrics)
            show_valid_trees_metrics(result_metrics, schema, forest, epoch + 1, log)

        # When the in-memory list was used, persist the result explicitly.
        if output and not isinstance(forest, TreeBucket):
            with ZODBTreeBucket(storage_path=output) as output_forest:
                output_forest.update(forest, commit=True)
@app.command(help="Simplify a bunch of databases together using an LLM.")  # fixed copy-pasted help text
def simplify_llm(
    files: list[Path] = typer.Argument(..., exists=True, readable=True, help="Path of the data files to load."),
    *,
    tau: float = typer.Option(0.7, help="The similarity threshold.", min=0, max=1),
    decay: float = typer.Option(DECAY, help="The similarity decay factor.", min=0.001),
    min_support: int = typer.Option(20, help="Minimum support for vocab.", min=1),
    vocab_similarity: float = typer.Option(0.6, help="The vocabulary similarity threshold.", min=0, max=1),
    refining_steps: int = typer.Option(0, help="Number of refining steps."),
    output: Path | None = typer.Option(None, help="Path to save the result."),
    intermediate_output: Path | None = typer.Option(None, help="Path to save intermediate results."),
    debug: bool = typer.Option(False, help="Enable debug mode for more verbose output."),
    metrics: bool = typer.Option(False, help="Show metrics of the simplification."),
    log: bool = typer.Option(False, help="Enable logging to MLFlow."),
    log_system_metrics: bool = typer.Option(False, help="Enable logging of system metrics to MLFlow."),
    model_provider: str = typer.Option('huggingface', help="Provider of the model."),
    model: str = typer.Option('HuggingFaceTB/SmolLM2-135M-Instruct', help="Model to use for the LLM."),
    max_tokens: int = typer.Option(2048, help="Maximum number of tokens to generate."),
    local: bool = typer.Option(True, help="Use local model."),
    openvino: bool = typer.Option(False, help="Enable Intel OpenVINO optimizations."),
    rate_limit: float | None = typer.Option(None, help="Rate limit for the LLM."),
    estimate: bool = typer.Option(False, help="Estimate the number of tokens to generate."),
    temperature: float = typer.Option(0.2, help="Temperature for the LLM."),
    in_memory: bool = typer.Option(False, help="Perform the simplification in memory."),
) -> None:
    """Run the LLM-driven simplification over the loaded forest.

    With ``--estimate`` the command only reports the expected token/query
    budget and returns without rewriting anything.
    """
    try:
        from langchain_core.rate_limiters import InMemoryRateLimiter

        from architxt.simplification.llm import estimate_tokens, llm_rewrite
    except ImportError as error:
        typer.secho(
            "LLM simplification is unavailable because optional dependencies are missing.\n"
            "Install them with: `pip install architxt[llm]`\n"
            "If using an external provider, also install the appropriate bridge, e.g. `pip install langchain-openai`",
            fg="yellow",
            err=True,
        )
        # Chain the original ImportError for easier debugging (consistent with `ui`).
        raise typer.Exit(code=2) from error

    # Either keep the working forest in memory (plain list) or back it with ZODB.
    if in_memory:
        warnings.warn(
            "Performing simplification in memory. This is not recommended for large datasets as it may lead to high memory usage.",
            UserWarning,
        )
        storage_ctx = nullcontext([])
    else:
        storage_ctx = ZODBTreeBucket(storage_path=output)

    run_ctx: AbstractContextManager = nullcontext()
    if log:
        # Fixed message typo: "will be send" -> "will be sent".
        console.print(f'[green]MLFlow logging enabled. Logs will be sent to {mlflow.get_tracking_uri()}[/]')
        run_ctx = mlflow.start_run(description='llm simplification', log_system_metrics=log_system_metrics)
        mlflow.langchain.autolog()
        mlflow.log_params(
            {
                'model_provider': model_provider,
                'model': model,
                'max_tokens': max_tokens,
                'local': local,
                'openvino': openvino,
                'rate_limit': rate_limit,
                'temperature': temperature,
            }
        )
        for file in files:
            mlflow.log_input(MetaDataset(CodeDatasetSource({}), name=file.name))

    rate_limiter = InMemoryRateLimiter(requests_per_second=rate_limit) if rate_limit else None
    llm = get_chat_model(
        model_provider,
        model,
        max_tokens=max_tokens,
        temperature=temperature,
        rate_limiter=rate_limiter,
        local=local,
        openvino=openvino,
    )

    if estimate:
        # Dry run: report the projected token/query budget and exit early.
        num_input_tokens, num_output_tokens, num_queries = estimate_tokens(
            load_forest(files),
            llm=llm,
            max_tokens=max_tokens,
            refining_steps=refining_steps,
        )
        console.print(f'[blue]Estimated number of tokens: input={num_input_tokens}, output={num_output_tokens}[/]')
        if rate_limit:
            console.print(
                f'[blue]Estimated number of queries: {num_queries} queries (~{num_queries / rate_limit:.2f}s)[/]'
            )
        else:
            console.print(f'[blue]Estimated number of queries: {num_queries} queries[/]')
        return

    with run_ctx, storage_ctx as forest:
        init_forest(forest, files)
        console.print(f'[blue]Rewriting {len(forest)} trees with model={model}[/]')

        transform = partial(
            llm_rewrite,
            llm=llm,
            max_tokens=max_tokens,
            tau=tau,
            decay=decay,
            min_support=min_support,
            vocab_similarity=vocab_similarity,
            refining_steps=refining_steps,
            debug=debug,
            intermediate_output=intermediate_output,
        )
        # llm_rewrite is async; drive it with anyio.
        result_metrics = anyio.run(transform, forest)

        # Generate schema
        schema = Schema.from_forest(forest, keep_unlabelled=False)
        show_schema(schema)

        if metrics:
            show_metrics(result_metrics)
            show_valid_trees_metrics(result_metrics, schema, forest, refining_steps + 1, log)

        # When the in-memory list was used, persist the result explicitly.
        if output and not isinstance(forest, TreeBucket):
            with ZODBTreeBucket(storage_path=output) as output_forest:
                output_forest.update(forest, commit=True)
@app.command(help="Display statistics of a dataset.")
def inspect(
    files: list[Path] = typer.Argument(..., exists=True, readable=True, help="Path of the data files to load."),
    redundancy: bool = typer.Option(False, help="Compute redundancy metrics."),
) -> None:
    """Display overall statistics."""
    inspector = ForestInspector()
    with ZODBTreeBucket() as forest:
        trees = inspector(load_forest(files))
        forest.update(trees, commit=True)
        schema = Schema.from_forest(inspector(forest), keep_unlabelled=False)

        # Display the schema
        show_schema(schema)

        # Display the largest tree
        console.print(Panel(str(inspector.largest_tree), title="Largest Tree"))

        # Entity Count
        tables = []
        for chunk in more_itertools.chunked_even(inspector.entity_count.most_common(), 10):
            entity_table = Table(title='Entity Counts')
            entity_table.add_column("Entity", style="cyan", no_wrap=True)
            entity_table.add_column("Count", style="magenta")
            for entity, count in chunk:
                entity_table.add_row(entity, str(count))
            tables.append(entity_table)

        # Display statistics
        stats_table = Table(title='Statistics')
        stats_table.add_column("Metric", style="cyan", no_wrap=True)
        stats_table.add_column("Value", style="magenta")
        stats_table.add_row("Total Trees", str(inspector.total_trees))
        stats_table.add_row("Total Entities", str(inspector.total_entities))
        stats_table.add_row("Total Groups", str(inspector.total_groups))
        stats_table.add_row("Total Relations", str(inspector.total_relations))
        stats_table.add_row("Average Tree Height", f"{inspector.avg_height:.3f}")
        stats_table.add_row("Maximum Tree Height", str(inspector.max_height))
        stats_table.add_row("Average Tree size", f"{inspector.avg_size:.3f}")
        stats_table.add_row("Maximum Tree size", str(inspector.max_size))
        stats_table.add_row("Average Branching", f"{inspector.avg_branching:.3f}")
        stats_table.add_row("Maximum Branching", str(inspector.max_children))

        if redundancy:
            datasets = schema.extract_datasets(forest)
            # Guard against a ZeroDivisionError when no dataset can be extracted.
            if datasets:
                for tau in (1.0, 0.7, 0.5):
                    # Renamed from `redundancy` to avoid shadowing the boolean option.
                    mean_redundancy = sum(redundancy_score(ds, tau=tau) for ds in datasets.values()) / len(datasets)
                    # Bug fix: format spec belongs inside the braces ({tau:.1f}),
                    # the original rendered e.g. "Redundant Trees (0.7:.1f)".
                    stats_table.add_row(f"Redundant Trees ({tau:.1f})", f"{mean_redundancy:.3f}")

        console.print(Columns([*tables, stats_table, get_schema_metrics(schema)], equal=True))
@app.command(help="Compare two datasets.")  # fixed copy-pasted help text from `simplify`
def compare(
    src: Path = typer.Argument(..., exists=True, readable=True, help="Path of the data file to compare to."),
    dst: Path = typer.Argument(..., exists=True, readable=True, help="Path of the data file to compare."),
    *,
    tau: float = typer.Option(0.7, help="The similarity threshold.", min=0, max=1),
    decay: float = typer.Option(DECAY, help="The similarity decay factor.", min=0.001),
) -> None:
    """Compare two forests: metrics, per-entity counts, and structure statistics."""
    # Metrics
    inspector_src = ForestInspector()
    inspector_dst = ForestInspector()
    with (
        ZODBTreeBucket() as forest_src,
        ZODBTreeBucket() as forest_dst,
    ):
        trees_src = inspector_src(load_forest([src]))
        forest_src.update(trees_src, commit=True)
        metrics = Metrics(forest_src, tau=tau, decay=decay)

        trees_dst = inspector_dst(load_forest([dst]))
        forest_dst.update(trees_dst, commit=True)
        metrics.update(forest_dst)

        schema = Schema.from_forest(forest_dst, keep_unlabelled=False)
        show_metrics(metrics)
        show_valid_trees_metrics(metrics, schema, forest_dst, 0, False)

        # Entity Count
        tables = []
        entities = inspector_src.entity_count.keys() | inspector_dst.entity_count.keys()
        for chunk in more_itertools.chunked_even(entities, 10):
            entity_table = Table()
            entity_table.add_column("Entity", style="cyan", no_wrap=True)
            entity_table.add_column("Count source", style="magenta")
            entity_table.add_column("Count destination", style="magenta")
            for entity in chunk:
                entity_table.add_row(
                    entity,
                    str(inspector_src.entity_count[entity]),
                    str(inspector_dst.entity_count[entity]),
                )
            tables.append(entity_table)

        # Display statistics
        stats_table = Table()
        stats_table.add_column("Metric", style="cyan", no_wrap=True)
        stats_table.add_column("Value source", style="magenta")
        stats_table.add_column("Value destination", style="magenta")
        stats_table.add_row("Total Trees", str(inspector_src.total_trees), str(inspector_dst.total_trees))
        stats_table.add_row("Total Entities", str(inspector_src.total_entities), str(inspector_dst.total_entities))
        stats_table.add_row("Total Groups", str(inspector_src.total_groups), str(inspector_dst.total_groups))
        stats_table.add_row("Total Relations", str(inspector_src.total_relations), str(inspector_dst.total_relations))
        stats_table.add_row("Average Tree Height", f"{inspector_src.avg_height:.3f}", f"{inspector_dst.avg_height:.3f}")
        stats_table.add_row("Maximum Tree Height", str(inspector_src.max_height), str(inspector_dst.max_height))
        stats_table.add_row("Average Tree size", f"{inspector_src.avg_size:.3f}", f"{inspector_dst.avg_size:.3f}")
        stats_table.add_row("Maximum Tree size", str(inspector_src.max_size), str(inspector_dst.max_size))
        stats_table.add_row(
            "Average Branching", f"{inspector_src.avg_branching:.3f}", f"{inspector_dst.avg_branching:.3f}"
        )
        stats_table.add_row("Maximum Branching", str(inspector_src.max_children), str(inspector_dst.max_children))

        console.print(Columns([*tables, stats_table], equal=True))
@app.command(name='generate', help="Generate synthetic instance.")
def instance_generator(
    *,
    sample: int = typer.Option(100, help="Number of sentences to sample from the corpus.", min=1),
    output: Path | None = typer.Option(None, help="Path to save the result."),
) -> None:
    """Generate synthetic database instances."""
    # A fixed, hand-written medical schema drives the generation.
    schema = Schema.from_description(
        groups={
            Group(name='SOSY', entities={'SOSY', 'ANATOMIE', 'SUBSTANCE'}),
            Group(name='TREATMENT', entities={'SUBSTANCE', 'DOSAGE', 'ADMINISTRATION', 'FREQUENCY'}),
            Group(name='EXAM', entities={'DIAGNOSTIC_PROCEDURE', 'ANATOMIE'}),
        },
        relations={
            Relation(name='PRESCRIPTION', left='SOSY', right='TREATMENT'),
            Relation(name='EXAM_RESULT', left='EXAM', right='SOSY'),
        },
    )
    show_schema(schema)

    with (
        ZODBTreeBucket(storage_path=output) as forest,
        console.status("[cyan]Generating synthetic instances..."),
    ):
        generated = gen_instance(schema, size=sample, generate_collections=False)
        forest.update(generated, commit=True)
        console.print(f'[green]Generated {len(forest)} synthetic instances.[/]')
@app.command(name='cache-clear', help='Clear all the cache of ArchiTXT')
def clear_cache(
    *,
    force: bool = typer.Option(False, help="Force the deletion of the cache without asking."),
) -> None:
    """Delete the ArchiTXT user cache directory, asking for confirmation unless forced."""
    cache_path = user_cache_path('architxt')

    if not cache_path.exists():
        console.print("[yellow]Cache is already empty or does not exist. Doing nothing.[/]")
        return

    if not force and not typer.confirm('All the cache data will be deleted. Are you sure?'):
        # Bug fix: typer.Abort() was instantiated but never raised, so declining
        # the confirmation still fell through and deleted the cache.
        raise typer.Abort()

    shutil.rmtree(cache_path)
    console.print("[green]Cache cleared.[/]")
# Click command used for Sphinx documentation _click_command = get_command(app)