"""Streamlit page for importing data into the forest (source of ``architxt.ui.page.importer``)."""

from collections.abc import Generator
from functools import partial

import anyio
import pandas as pd
import streamlit as st
from spacy.util import get_installed_models
from streamlit_tags import st_tags

from architxt.bucket import TreeBucket
from architxt.cli.loader import ENTITIES_FILTER, ENTITIES_MAPPING, RELATIONS_FILTER
from architxt.database import loader
from architxt.database.loader import read_document
from architxt.forest import import_forest_from_jsonl
from architxt.nlp import raw_load_corpus
from architxt.nlp.entity_resolver import ScispacyResolver
from architxt.nlp.parser.benepar import BeneparParser
from architxt.nlp.parser.corenlp import CoreNLPParser
from architxt.tree import Tree
from architxt.ui.utils import get_forest, get_neo4j_driver, get_sql_engine, update_metrics

RESOLVER_NAMES = {
    None: 'No resolution',
    'umls': 'Unified Medical Language System (UMLS)',
    'mesh': 'Medical Subject Headings (MeSH)',
    'rxnorm': 'RxNorm',
    'go': 'Gene Ontology (GO)',
    'hpo': 'Human Phenotype Ontology (HPO)',
}


@st.cache_resource
def get_corenlp_parser(corenlp_url: str) -> CoreNLPParser:
    """Return a ``CoreNLPParser`` pointed at the server at ``corenlp_url``.

    Decorated with ``st.cache_resource`` so the parser is reused across reruns for the same URL.
    """
    parser = CoreNLPParser(corenlp_url=corenlp_url)
    return parser


@st.cache_resource
def get_benepar_parser(language: str, spacy_model: str) -> BeneparParser:
    """Return a ``BeneparParser`` that uses ``spacy_model`` for ``language``.

    Decorated with ``st.cache_resource`` so the parser is reused across reruns for the same
    (language, model) pair.
    """
    models_by_language = {language: spacy_model}
    return BeneparParser(spacy_models=models_by_language)


def _render_text_corpus_import(forest: TreeBucket) -> None:
    """Render the text-corpus import form and load the parsed corpora into *forest*.

    The form collects the uploaded corpora and their language. It also collects the parser
    (Benepar backed by an installed spaCy model, or a CoreNLP server), an optional entity
    resolver and an optional sample size. Finally, it collects the entity/relation filters
    and entity mapping that are passed to ``raw_load_corpus``. The parser and resolver are
    only built once "Load Corpus" is clicked.

    :param forest: Bucket receiving the loaded trees; the update is committed.
    """
    uploaded_files = st.file_uploader('Corpora', ['.tar.gz', '.tar.xz', '.txt'], accept_multiple_files=True)

    col1, col2, col3 = st.columns(3)
    language = col1.selectbox("Language", ['English', 'French', 'German'])
    parser_type = col2.selectbox("Parser", ["Benepar", "CoreNLP"])

    # ``make_parser`` stays None when the selected parser cannot be configured, which disables
    # the load button. Construction is deferred to the click so that browsing the form does not
    # build a parser for every model/URL the user passes through.
    make_parser = None
    if parser_type == "Benepar":
        spacy_model = col3.selectbox("Spacy Model", get_installed_models())
        # With no installed model the selectbox has no options and returns None.
        if spacy_model is None:
            st.warning("No spaCy model is installed: install one to use the Benepar parser.")
        else:
            make_parser = partial(get_benepar_parser, language, spacy_model)
    else:
        corenlp_url = col3.text_input('CoreNLP URL', value='http://localhost:9000').strip()
        if corenlp_url:
            make_parser = partial(get_corenlp_parser, corenlp_url)
        else:
            st.warning("Enter the URL of a CoreNLP server to use the CoreNLP parser.")

    col1, col2 = st.columns(2)
    resolver_name = col1.selectbox('Entity Resolver', options=RESOLVER_NAMES.keys(), format_func=RESOLVER_NAMES.get)
    # ``value=None`` leaves the field empty, in which case ``sample`` is None.
    sample = col2.number_input('Sample', min_value=1, value=None)

    with st.expander("Advanced NLP Settings"):
        entities_filter = st_tags(label='Excluded entities', value=list(ENTITIES_FILTER))
        relations_filter = st_tags(label='Excluded relations', value=list(RELATIONS_FILTER))
        st.text('Entity mapping')

        # Convert dict to DataFrame for editing; rows left incomplete in the editor are dropped.
        mapping_df = pd.DataFrame(ENTITIES_MAPPING.items(), columns=['From', 'To'])
        edited_mapping_df = st.data_editor(mapping_df, width='stretch', num_rows="dynamic").dropna()
        # Convert back to dict
        entity_mapping = dict(zip(edited_mapping_df['From'], edited_mapping_df['To']))

    if st.button("Load Corpus", disabled=not uploaded_files or make_parser is None):
        with st.spinner("Parsing and Loading Corpus..."):
            parser = make_parser()

            resolver = None
            if resolver_name:
                try:
                    resolver = ScispacyResolver(kb_name=resolver_name)
                except Exception as e:
                    # Best effort: report the problem and load the corpus without resolution.
                    st.warning(f"Could not initialize resolver: {e}")

            trees = raw_load_corpus(
                uploaded_files,
                [language] * len(uploaded_files),
                entities_filter=set(entities_filter),
                relations_filter=set(relations_filter),
                entities_mapping=entity_mapping,
                parser=parser,
                resolver=resolver,
                sample=sample,
                cache=st.session_state.get("cache", True),
            )

            # The forest update is a coroutine function; run it to completion here.
            forest_update = partial(forest.async_update, commit=True)
            anyio.run(forest_update, trees)

        update_metrics()


def _render_sql_import(forest: TreeBucket) -> None:
    """Render the SQL import form and load the database content into *forest*.

    :param forest: Bucket receiving the trees produced by ``loader.read_sql``; the update is committed.
    """
    sql_uri = st.text_input(
        "SQL URI",
        value="sqlite:///example.db",
        help="e.g., postgresql://user:pass@localhost/db",
    ).strip()

    # Like the file-based importers, keep the button disabled until the required input is present:
    # an empty URI is not a usable database URL.
    if st.button("Load SQL", disabled=not sql_uri):
        engine = get_sql_engine(sql_uri)

        # The forest is updated while the connection is still open, in case read_sql yields lazily.
        with (
            st.spinner("Loading from SQL..."),
            engine.connect() as connection,
        ):
            trees = loader.read_sql(connection)
            forest.update(trees, commit=True)

        update_metrics()


def _render_graph_import(forest: TreeBucket) -> None:
    """Render the graph-database import form and load the graph content into *forest*.

    :param forest: Bucket receiving the trees produced by ``loader.read_cypher``; the update is committed.
    """
    col1, col2, col3 = st.columns(3)
    graph_uri = col1.text_input("Graph URI", value="bolt://localhost:7687").strip()
    username = col2.text_input("Username", value="neo4j")
    password = col3.text_input("Password", type="password")

    # Like the file-based importers, keep the button disabled until the required input is present:
    # there is nothing to connect to without a URI.
    if st.button("Load Graph", disabled=not graph_uri):
        driver = get_neo4j_driver(graph_uri, username=username, password=password)

        # The forest is updated while the session is still open, in case read_cypher yields lazily.
        with (
            st.spinner("Loading from Graph..."),
            driver.session() as session,
        ):
            trees = loader.read_cypher(session)
            forest.update(trees, commit=True)

        update_metrics()


def _render_document_import(forest: TreeBucket) -> None:
    """Render the structured-document upload form and add the documents' trees to *forest*.

    Every uploaded file is read with ``read_document``, using the file name as the root name.
    All resulting trees are passed to a single committed ``forest.update`` call.
    """
    accepted_extensions = [".json", ".toml", ".yml", ".yaml", ".xml", ".csv", ".xls", ".xlsx"]
    uploaded_documents = st.file_uploader("Document File", accepted_extensions, accept_multiple_files=True)

    if not st.button("Load document", disabled=not uploaded_documents):
        return

    with st.spinner("Loading from document..."):
        # A generator expression, so trees are produced lazily, one document after another.
        trees: Generator[Tree, None, None] = (
            tree
            for document in uploaded_documents
            for tree in read_document(document, root_name=document.name)
        )
        forest.update(trees, commit=True)

    update_metrics()


def _render_jsonl_import(forest: TreeBucket) -> None:
    """Render the JSONL upload form and add the trees stored in the files to *forest*.

    Trees from every uploaded file are read with ``import_forest_from_jsonl``. They are passed
    to a single committed ``forest.update`` call.
    """
    uploaded_jsonls = st.file_uploader("JSONL File", ["jsonl"], accept_multiple_files=True)

    if not st.button("Load JSONL", disabled=not uploaded_jsonls):
        return

    with st.spinner("Loading from JSONL..."):
        # A generator expression, so trees are produced lazily, one file after another.
        trees: Generator[Tree, None, None] = (
            tree
            for jsonl_file in uploaded_jsonls
            for tree in import_forest_from_jsonl(jsonl_file)
        )
        forest.update(trees, commit=True)

    update_metrics()


# Data sources offered on the import page: each user-facing label maps to the function that
# renders that source's form, called with the current forest. Insertion order is the order
# of the "Source Type" choices.
IMPORTER = dict(
    [
        ("JSONL File", _render_jsonl_import),
        ("Document File", _render_document_import),
        ("SQL Database", _render_sql_import),
        ("Graph Database", _render_graph_import),
        ("Text Corpus", _render_text_corpus_import),
    ]
)


@st.fragment
def importer() -> None:
    """Render the "Import Data" page.

    The user picks a source type. The matching renderer from ``IMPORTER`` then draws that
    source's form and loads data into the current forest. The function is a Streamlit
    fragment, so widget interactions on this page rerun only this function.
    """
    st.header("Import Data")
    forest = get_forest()
    source_type = st.selectbox("Source Type", IMPORTER.keys())
    IMPORTER[source_type](forest)