"""Source code for architxt.ui.page.simplification."""

from functools import partial
from typing import TYPE_CHECKING

import anyio
import streamlit as st

from architxt.bucket import TreeBucket
from architxt.simplification.llm import llm_rewrite
from architxt.simplification.tree_rewriting import rewrite
from architxt.ui.utils import get_forest, update_metrics

if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel


def _render_rule_based_simplification(forest: TreeBucket) -> None:
    """Draw the rule-based simplification controls and run the rewrite on demand.

    Hyper-parameters are laid out over two columns; the rewrite mutates
    *forest* in place and the page metrics are refreshed afterwards.
    """
    left, right = st.columns(2)
    tau = left.number_input('Tau (Similarity)', 0.0, 1.0, 0.8)
    decay = right.number_input('Decay', min_value=0.0, value=2.0)
    min_support = left.number_input('Min Support', min_value=1, value=10)
    epoch = right.number_input('Epochs', min_value=1, value=50)

    # Nothing to do until the user explicitly triggers the simplification.
    if not st.button("Apply Rule-Based Simplification"):
        return

    with st.spinner("Simplifying..."):
        rewrite(forest, tau=tau, decay=decay, epoch=epoch, min_support=min_support)

    update_metrics()


@st.cache_resource(scope="session")
def get_llm(
    provider: str,
    model_name: str,
    max_tokens: int,
    temperature: float,
    openvino: bool = False,
) -> 'BaseChatModel':
    """Build a chat model for the given provider and settings.

    The result is cached per Streamlit session, so repeated calls with the
    same arguments reuse the already-constructed model.

    :param provider: UI provider label; ``"Local (HuggingFace)"`` selects a
        locally-run HuggingFace pipeline, anything else goes through
        LangChain's generic ``init_chat_model`` factory.
    :param model_name: Model identifier understood by the chosen backend.
    :param max_tokens: Generation budget for new tokens.
    :param temperature: Sampling temperature.
    :param openvino: Run the local pipeline through the OpenVINO backend.
    """
    from langchain.chat_models import init_chat_model

    if provider != "Local (HuggingFace)":
        # API-backed provider: delegate model construction to LangChain.
        return init_chat_model(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

    generation_kwargs = {
        'use_cache': True,
        'do_sample': True,
        'return_full_text': False,
        'max_new_tokens': max_tokens,
        'temperature': temperature,
        'repetition_penalty': 1.1,
        'num_return_sequences': 1,
        'pad_token_id': 0,
    }
    # OpenVINO handles device placement itself; the PyTorch path lets
    # accelerate pick devices and dtypes automatically.
    hf_pipeline = HuggingFacePipeline.from_model_id(
        model_id=model_name,
        task='text-generation',
        device_map=None if openvino else 'auto',
        backend='openvino' if openvino else 'pt',
        model_kwargs={'export': True} if openvino else {'torch_dtype': 'auto'},
        pipeline_kwargs=generation_kwargs,
    )
    return ChatHuggingFace(llm=hf_pipeline)


def _render_llm_based_simplification(forest: TreeBucket) -> None:
    """Draw the LLM-based simplification controls and run the rewrite on demand.

    Builds (or fetches from the session cache) the configured chat model,
    then applies ``llm_rewrite`` to *forest* asynchronously via anyio and
    refreshes the page metrics.
    """
    st.warning("LLM Simplification requires extra dependencies and an LLM provider.")

    use_openvino = False
    provider = st.selectbox("Provider", ["Local (HuggingFace)", "API"])
    is_local = provider == "Local (HuggingFace)"
    default_model = "deepseek-ai/DeepSeek-V3" if is_local else "openai:o3-mini"
    model_name = st.text_input("Model Name", value=default_model)

    with st.expander("Advanced LLM Settings"):
        col_a, col_b, col_c = st.columns(3)
        temperature = col_a.number_input("Temperature", 0.0, 1.0, 0.2)
        max_tokens = col_b.number_input("Max Tokens", min_value=256, value=4096, step=128)
        refining_steps = col_c.number_input("Refining Steps", min_value=0, value=3)

        tau = col_a.number_input('Tau (Similarity)', 0.0, 1.0, 0.8)
        decay = col_b.number_input('Decay', min_value=0.0, value=2.0)
        min_support = col_c.number_input('Min Support', min_value=1, value=10)

        # OpenVINO only makes sense for a locally-run model.
        if is_local:
            use_openvino = st.checkbox("Use OpenVINO (if available)", value=False)

    if not st.button("Apply LLM Simplification"):
        return

    llm = get_llm(provider, model_name, max_tokens, temperature, use_openvino)
    transform = partial(
        llm_rewrite,
        llm=llm,
        max_tokens=max_tokens,
        tau=tau,
        decay=decay,
        min_support=min_support,
        refining_steps=refining_steps,
    )

    with st.spinner("Simplifying with LLM..."):
        # llm_rewrite is async; anyio drives it to completion synchronously.
        anyio.run(transform, forest)

    update_metrics()


# Maps the UI method label to the renderer that draws its controls and
# applies the corresponding simplification to the forest.
TRANSFORMER = {
    "Default (rule-based)": _render_rule_based_simplification,
    "LLM": _render_llm_based_simplification,
}


@st.fragment
def simplification() -> None:
    """Render the "Simplify" page.

    Shows an informational message when no data is loaded; otherwise lets the
    user pick a simplification method and delegates rendering (and execution)
    to the matching entry in ``TRANSFORMER``.
    """
    st.header("Simplify")

    forest = get_forest()
    if len(forest) == 0:
        st.info("No data loaded. Please import data using the top menu.")
    else:
        method = st.selectbox("Method", TRANSFORMER.keys())
        TRANSFORMER[method](forest)