import asyncio
import random
from copy import deepcopy
import mlflow
import pandas as pd
import streamlit as st
from streamlit_agraph import Config , agraph
from streamlit_agraph import Edge as _Edge
from streamlit_agraph import Node as _Node
from streamlit_tags import st_tags
from architxt.cli import ENTITIES_FILTER , ENTITIES_MAPPING , RELATIONS_FILTER
from architxt.nlp import raw_load_corpus
from architxt.nlp.parser.corenlp import CoreNLPParser
from architxt.schema import Schema
from architxt.simplification.tree_rewriting import rewrite
from architxt.tree import Forest , Tree
# Human-readable labels for the entity resolvers offered in the sidebar.
# The keys are the values handed on as ``resolver_name``; ``None`` means the
# corpus is loaded without any entity resolution.
# NOTE: insertion order matters — it is the order shown in the selectbox.
RESOLVER_NAMES = {
    None: 'No resolution',
    'umls': 'Unified Medical Language System (UMLS)',
    'mesh': 'Medical Subject Headings (MeSH)',
    'rxnorm': 'RxNorm',
    'go': 'Gene Ontology (GO)',
    'hpo': 'Human Phenotype Ontology (HPO)',
}
class Node(_Node):
    """Graph node whose identity is its ``id`` alone.

    Equality and hashing ignore every other attribute (label, styling, ...),
    so collecting nodes in a ``set`` keeps a single node per id.
    """

    def __eq__(self, other: object) -> bool:
        # Only instances of this same class can ever compare equal.
        if not isinstance(other, self.__class__):
            return False
        return self.id == other.id

    def __hash__(self) -> int:
        # Must agree with __eq__: hash on the id only.
        return hash(self.id)

    def __repr__(self) -> str:
        return f'Node({self.id})'
class Edge(_Edge):
    """Graph edge whose identity is its ``(source, to)`` endpoint pair.

    ``to`` is where the base edge class keeps the ``target`` given to its
    constructor. Labels and other attributes are ignored by equality and
    hashing, so a ``set`` keeps only the first edge inserted between a
    given source and target.
    """

    def __eq__(self, other: object) -> bool:
        # Only instances of this same class can ever compare equal.
        if not isinstance(other, self.__class__):
            return False
        return self.source == other.source and self.to == other.to

    def __hash__(self) -> int:
        # Must agree with __eq__: hash on both endpoints.
        endpoints = (self.source, self.to)
        return hash(endpoints)

    def __repr__(self) -> str:
        return f'Edge({self.source}, {self.to})'
@st.fragment()
def graph(schema: Schema) -> None:
    """Draw the schema as a directed graph.

    Every entity and every group becomes a node. Each group is linked to
    the entities it contains, and each relation is drawn as an edge
    labelled with its name, from its first group to its second.
    Nodes and edges are gathered in sets, so duplicates (same node id, or
    same source/target pair) are rendered once; the first one added wins.
    """
    nodes = {Node(id=entity, label=entity) for entity in schema.entities}
    edges = set()

    for group_name, members in schema.groups.items():
        nodes.add(Node(id=group_name, label=group_name))
        edges.update(Edge(source=group_name, target=member) for member in members)

    for relation_name, (source_group, target_group) in schema.relations.items():
        edges.add(Edge(source=source_group, target=target_group, label=relation_name))

    agraph(nodes=nodes, edges=edges, config=Config(directed=True))
@st.fragment()
def dataframe(forest: Forest) -> None:
    """Render the instances of one user-selected group as a table.

    The trees are deep-copied and placed under a new ``ROOT`` tree, so the
    caller's forest is not modified when the fragment reruns. If the forest
    contains no group, an informational message is shown instead of a table.

    :param forest: Trees whose group instances should be tabulated.
    """
    final_tree = Tree('ROOT', deepcopy(forest))
    group_name = st.selectbox('Group', sorted(final_tree.groups()))

    # st.selectbox returns None when it has no options, i.e. no group exists;
    # don't hand None to group_instances.
    if group_name is None:
        st.info('No group instances to display.')
        return

    st.dataframe(final_tree.group_instances(group_name), use_container_width=True)
st.title("ArchiTXT")

# Sidebar: CoreNLP endpoint and entity-resolution choice.
with st.sidebar:
    corenlp_url = st.text_input('Corenlp URL', value='http://localhost:9000')
    resolver_name = st.selectbox(
        'Entity Resolver',
        options=RESOLVER_NAMES.keys(),
        format_func=RESOLVER_NAMES.get,
    )

input_tab, stats_tab, schema_tab, instance_tab = st.tabs(['📖 Corpus', '📊 Metrics', '📐 Schema', '🗄️ Instance'])

with input_tab:
    uploaded_files = st.file_uploader('Corpora', ['.tar.gz', '.tar.xz'], accept_multiple_files=True)

    # Always bind the mapping: it is read after "Start" is pressed, and it
    # used to be undefined (NameError) when no corpus had been uploaded.
    file_language: dict[str, str] = {}
    if uploaded_files:
        # One row per uploaded archive so the user can pick its language.
        file_language_table = st.data_editor(
            pd.DataFrame([{'Corpora': file.name, 'Language': 'English'} for file in uploaded_files]),
            column_config={
                'Corpora': st.column_config.TextColumn(disabled=True),
                'Language': st.column_config.SelectboxColumn(options=['English', 'French'], required=True),
            },
            hide_index=True,
            use_container_width=True,
        )
        file_language = {row['Corpora']: row['Language'] for _, row in file_language_table.iterrows()}

    st.divider()

    # Processing parameters, submitted together through the "Start" button.
    with st.form(key='corpora', enter_to_submit=False):
        entities_filter = st_tags(label='Excluded entities', value=list(ENTITIES_FILTER))
        relations_filter = st_tags(label='Excluded relations', value=list(RELATIONS_FILTER))

        st.text('Entity mapping')
        entity_mapping = st.data_editor(ENTITIES_MAPPING, use_container_width=True, hide_index=True, num_rows="dynamic")

        st.divider()

        col1, col2, col3 = st.columns(3)
        tau = col1.number_input('Tau', min_value=0.05, max_value=1.0, step=0.05, value=0.5)
        epoch = col2.number_input('Epoch', min_value=1, step=1, value=100)
        min_support = col3.number_input('Minimum Support', min_value=1, step=1, value=10)
        sample = col1.number_input('Sample size', min_value=0, step=1, value=0, help='0 means no sampling')
        shuffle = col2.selectbox('Shuffle', options=[True, False])

        submitted = st.form_submit_button("Start")
async def load_forest() -> list[Tree]:
    """Parse every uploaded corpus into trees.

    Reads the module-level widget values (uploaded files and their chosen
    languages, entity/relation filters, entity mapping, CoreNLP URL and
    resolver) and forwards them to ``raw_load_corpus``.

    :return: The parsed trees, or an empty list when nothing was uploaded.
    """
    if uploaded_files:
        # Languages must line up one-to-one with the uploaded files.
        corpus_languages = [file_language[uploaded.name] for uploaded in uploaded_files]
        return await raw_load_corpus(
            uploaded_files,
            corpus_languages,
            entities_filter=set(entities_filter),
            relations_filter=set(relations_filter),
            entities_mapping=entity_mapping,
            parser=CoreNLPParser(corenlp_url=corenlp_url),
            resolver_name=resolver_name,
        )

    return []
# Run the pipeline once the form is submitted. Test ``uploaded_files`` rather
# than ``file_language``: the latter may not be bound when nothing was uploaded.
if submitted and uploaded_files:
    try:
        # Close any run left open by a previous interaction before starting a new one.
        if mlflow.active_run():
            mlflow.end_run()

        with st.spinner('Computing...'), mlflow.start_run(description='UI run', log_system_metrics=True) as mlflow_run:
            forest = asyncio.run(load_forest())

            if sample:
                # random.sample raises ValueError when k exceeds the population
                # size, so never ask for more trees than were loaded.
                forest = random.sample(forest, min(sample, len(forest)))

            if shuffle:
                random.shuffle(forest)

            rewrite_forest = rewrite(
                forest,
                tau=tau,
                epoch=epoch,
                min_support=min_support,
            )

            # Display statistics tab
            with stats_tab:
                run_id = mlflow_run.info.run_id
                client = mlflow.tracking.MlflowClient()

                st.line_chart(
                    pd.DataFrame(
                        [
                            metric.to_dictionary()
                            for metric_name in [
                                'coverage',
                                'cluster_ami',
                                'cluster_completeness',
                                'overlap',
                                'balance',
                            ]
                            for metric in client.get_metric_history(run_id, metric_name)
                        ]
                    ),
                    x='step',
                    y='value',
                    color='key',
                )

                st.line_chart(
                    {
                        metric: [x.value for x in client.get_metric_history(run_id, metric)]
                        for metric in [
                            'num_productions',
                            'unlabeled_nodes',
                            'group_instance_total',
                            'relation_instance_total',
                            'collection_instance_total',
                        ]
                    }
                )

                st.bar_chart([x.value for x in client.get_metric_history(run_id, 'edit_op')])

            schema = Schema.from_forest(rewrite_forest, keep_unlabelled=False)

            # Display schema graph
            with schema_tab:
                graph(schema)

            # Display instance data
            with instance_tab:
                clean_forest = schema.extract_valid_trees(rewrite_forest)
                dataframe(clean_forest)

    except Exception as e:
        # UI boundary: surface any pipeline failure to the user instead of crashing the page.
        st.error(f"An error occurred: {e!s}")

elif submitted:
    st.warning('Please upload at least one corpus before starting.')