Source code for architxt.ui.page.labelling

import pandas as pd
import streamlit as st

from architxt.bucket import TreeBucket
from architxt.labelling import Renaming, apply_renaming, llm_group_labelling, llm_relation_labelling
from architxt.llm import get_chat_model
from architxt.schema import Schema
from architxt.tree import NodeType
from architxt.ui.utils import get_forest, get_schema, update_metrics


def _render_editor(forest: TreeBucket, schema: Schema) -> None:
    col1, col2 = st.columns(2)

    with col1:
        st.write("**Groups**")
        st.session_state.group_renames = st.data_editor(
            st.session_state.group_renames,
            key="group_editor",
            hide_index=True,
            width="stretch",
            disabled=["Current Name"],
        )

    with col2:
        st.write("**Relations**")
        st.session_state.relation_renames = st.data_editor(
            st.session_state.relation_renames,
            key="relation_editor",
            hide_index=True,
            width="stretch",
            disabled=["Current Name"],
        )

    renames = []
    for _, row in st.session_state.group_renames.iterrows():
        if row["Current Name"] != row["New Name"]:
            renames.append(Renaming(NodeType.GROUP, row["Current Name"], row["New Name"]))

    for _, row in st.session_state.relation_renames.iterrows():
        if row["Current Name"] != row["New Name"]:
            renames.append(Renaming(NodeType.REL, row["Current Name"], row["New Name"]))

    c1, c2 = st.columns(2)
    if c1.button("Reset", width="stretch"):
        _reset_labelling_tables(schema)
        st.rerun()

    if c2.button("Apply Renaming", width="stretch", disabled=not renames):
        with (
            st.spinner("Applying new labels to forest..."),
            forest.transaction(),
        ):
            apply_renaming(forest, renames)

        st.toast(f"Applied {len(renames)} renaming.")
        update_metrics()


def _render_llm_labelling(forest: TreeBucket, schema: Schema) -> None:
    st.subheader('LLM Auto-labelling')

    st.warning("LLM Labelling requires extra dependencies and an LLM provider.")

    local: bool = st.session_state.get("llm_local", False)
    c1, c2 = st.columns(2)
    model_provider = c1.text_input("Provider", "huggingface", disabled=local)
    model_name = c2.text_input("Model", value="deepseek-ai/DeepSeek-V3")

    with st.expander("Advanced LLM Settings"):
        c1, c2, c3 = st.columns(3)
        temperature = c1.number_input("Temperature", 0.0, 1.0, 0.2)
        max_tokens = c2.number_input("Max Tokens", min_value=256, value=4096, step=128)
        sample_size = c3.number_input("Sample Size", min_value=0, value=5)

    if st.button("Get AI Suggestions"):
        with st.status("Computing new labels..."):
            st.write("Loading LLM...")
            llm = get_chat_model(
                model_provider,
                model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                local=local,
                openvino=st.session_state.get("llm_openvino", False),
            )

            st.write("Generating group names...")
            groups_renames = llm_group_labelling(schema, llm, forest=forest, sample_size=sample_size)

            # Update group dataframe in session state
            new_group_df = st.session_state.group_renames.copy()
            rename_dict = {r.old_name: r.new_name for r in groups_renames}
            new_group_df["New Name"] = new_group_df["Current Name"].map(lambda x: rename_dict.get(x, x))
            st.session_state.group_renames = new_group_df

            st.write("Generating relation names...")
            relations_renames = llm_relation_labelling(schema, llm, group_renames=groups_renames)

            # Update relation dataframe in session state
            new_relation_df = st.session_state.relation_renames.copy()
            rel_rename_dict = {r.old_name: r.new_name for r in relations_renames}
            new_relation_df["New Name"] = new_relation_df["Current Name"].map(lambda x: rel_rename_dict.get(x, x))
            st.session_state.relation_renames = new_relation_df

            st.toast("AI suggestions generated! Review them in the tables above.")
            st.rerun()


def _reset_labelling_tables(schema: Schema) -> None:
    group_names = {g.name for g in schema.groups}
    st.session_state.group_renames = pd.DataFrame([{"Current Name": g, "New Name": g} for g in sorted(group_names)])

    rels_names = {r.name for r in schema.relations}
    st.session_state.relation_renames = pd.DataFrame([{"Current Name": r, "New Name": r} for r in sorted(rels_names)])


[docs] @st.fragment def labelling() -> None: st.header("Labelling") forest = get_forest() schema = get_schema() if len(schema.groups) == 0: st.warning("No groups found in the forest.") return # Initialize session state for renaming if "group_renames" not in st.session_state: _reset_labelling_tables(schema) _render_editor(forest, schema) st.divider() _render_llm_labelling(forest, schema)