from __future__ import annotations
import itertools
import json
import re
import unicodedata
import warnings
from collections import Counter
from difflib import get_close_matches
from typing import TYPE_CHECKING
import json_repair
import mlflow
import more_itertools
from aiostream import Stream, pipe, stream
from json_repair import JSONReturnType
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import NumberedListOutputParser
from langchain_core.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
)
from mlflow.entities import SpanEvent, SpanType
from tqdm.auto import tqdm, trange
from architxt.bucket import TreeBucket
from architxt.forest import export_forest_to_jsonl
from architxt.metrics import Metrics
from architxt.schema import Schema
from architxt.similarity import DECAY, DEFAULT_METRIC, METRIC_FUNC
from architxt.tree import Forest, NodeLabel, NodeType, Tree, TreeOID, has_type
from architxt.utils import windowed_shuffle
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Collection, Iterable, Sequence
from pathlib import Path
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
__all__ = ['estimate_tokens', 'llm_rewrite']
DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template("""
You are a data-engineering agent whose task is deterministic normalization and schema induction for noisy JSON trees.
Goal: produce one simplified, canonical JSON tree per input tree.
You can restructure JSON trees by adding, removing, renaming, or moving nodes.
ENT = property, GROUP = table, REL = relation.
Keep existing groups and relations when possible.
All trees should share the same vocabulary; rename groups and relations according to the relevant vocabulary.
{vocab}
Node format:
{{"oid":<str|null>,"name":<str>,"type":"GROUP"|"REL"|"ENT"|null,"metadata":<obj|null>,"children":[...]}}
Rules:
- Do **NOT** modify or rename ENT nodes.
- You can duplicate ENT nodes if needed.
- Return one simplified tree per input. No notes or explanations.
- Each output tree must start with root:
{{"oid":null,"name":"ROOT","type":null,"metadata":{{}},"children":[...]}}
- Create meaningful GROUP nodes to collect related ENT nodes.
- Link GROUPs with REL nodes where appropriate.
- Preserve original oids; any new node gets "oid":null.
- Keep the tree structure as close as possible to the original one.
- Create generic semantic group names (e.g., Person). Avoid dataset- or domain-specific names (e.g., prefer Exam over EGC).
Your response should be a numbered list with each item on a new line (do not put line breaks in the resulting JSON).
For example:
1. {{...}}
2. {{...}}
3. {{...}}
"""),
HumanMessage("""
1. {"oid":"1","name":"UNDEF","type":null,"children":[{"oid":"2","name":"FruitName","type":"ENT","children":["banana"]},{"oid":"3","name":"Color","type":"ENT","children":["yellow"]}]}
2. {"oid":"4","name":"UNDEF","type":null,"children":[{"oid":"5","name":"FruitName","type":"ENT","children":["orange"]},{"oid":"6","name":"PersonName","type":"ENT","children":["Alice"]},{"oid":"7","name":"Age","type":"ENT","children":["30"]}]}
"""),
AIMessage("""
1. {"oid":null,"name":"ROOT","type":null,"children":[{"oid":"1","name":"Fruit","type":"GROUP","children":[{"oid":"2","name":"FruitName","type":"ENT","children":["banana"]},{"oid":"3","name":"Color","type":"ENT","children":["yellow"]}]}]}
2. {"oid":null,"name":"ROOT","type":null,"children":[{"oid":null,"name":"Eat","type":"REL","children":[{"oid":null,"name":"Fruit","type":"GROUP","children":[{"oid":"5","name":"FruitName","type":"ENT","children":["orange"]}]},{"oid":null,"name":"Person","type":"GROUP","children":[{"oid":"6","name":"PersonName","type":"ENT","children":["Alice"]},{"oid":"7","name":"Age","type":"ENT","children":["30"]}]}]}]}
"""),
HumanMessagePromptTemplate.from_template("{trees}"),
]
)
def _trees_to_markdown_list(trees: Iterable[Tree]) -> str:
"""
Create a numbered Markdown list where each line is a JSON representation of a :py:class:`~architxt.tree.Tree`.
    :param trees: An iterable of trees to format.
    :return: A string with one item per tree in the form "N. <json>", separated by blank lines,
        using compact separators and stable key ordering.
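
    Example of the produced format (illustrative; keys are sorted and separators compact)::

        1. {"children":["banana"],"name":"FruitName","oid":"2","type":"ENT"}

        2. {"children":["yellow"],"name":"Color","oid":"3","type":"ENT"}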
"""
return '\n\n'.join(
f'{i}. {json.dumps(tree.to_json(), ensure_ascii=False, separators=(",", ":"), sort_keys=True)}'
for i, tree in enumerate(trees, start=1)
if isinstance(tree, Tree)
)
def _parse_tree(json_data: JSONReturnType) -> Tree:
"""
Parse a JSON object into a Tree.
:param json_data: The JSON object to parse.
:raise ValueError: If the JSON object is not a valid tree.
:raise TypeError: If the JSON object is of an invalid type.
:return: The parsed Tree.
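
    Example (a sketch of the two accepted shapes)::

        _parse_tree({"oid": "1", "name": "Fruit", "type": "GROUP", "children": [...]})
        _parse_tree([{...}, {...}])  # list items are wrapped under a new ROOT node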
"""
if not json_data:
msg = 'Empty JSON data cannot be parsed into a tree.'
raise TypeError(msg)
if isinstance(json_data, dict):
tree = Tree.from_json(json_data)
elif isinstance(json_data, list):
children = [Tree.from_json(sub_tree) for sub_tree in json_data if isinstance(sub_tree, dict)]
if children:
tree = Tree('ROOT', children)
else:
msg = 'No valid tree objects found in JSON list data.'
raise ValueError(msg)
else:
msg = f'Invalid JSON data type for tree parsing: {type(json_data)}.'
raise TypeError(msg)
return tree
def _sanitize(tree: Tree, oid: TreeOID) -> Tree:
"""
    Sanitize a :py:class:`~architxt.tree.Tree` in-place by renaming invalid nodes with an `UNDEF_<oid>` label.
:param tree: The tree to sanitize.
:param oid: The Tree OID to use.
:return: The sanitized tree.
"""
    # ensure a ROOT node and reuse the original oid to avoid duplicates
children = [tree] if has_type(tree) else [child.detach() for child in tree]
tree = Tree('ROOT', children, oid=oid)
# ensure groups and relations are valid
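    # a GROUP is valid only if all of its children are ENT nodes; a REL is valid
    # only if it has exactly two GROUP children; anything else is relabeled UNDEF_<oid>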
for st in tree.subtrees(reverse=True):
if (has_type(st, NodeType.GROUP) and not all(has_type(c, NodeType.ENT) for c in st)) or (
has_type(st, NodeType.REL) and (len(st) != 2 or not all(has_type(c, NodeType.GROUP) for c in st))
):
st.label = f'UNDEF_{st.oid.hex}'
return tree
def _fix_vocab(tree: Tree, vocab: Collection[str], vocab_similarity: float = 0.6) -> Tree:
"""
Fix the vocabulary in the tree by updating GROUP and REL labels in-place to match canonical forms.
    :param tree: The tree to fix.
:param vocab: Collection of canonical labels.
:param vocab_similarity: Similarity threshold in [0, 1] for merging labels.
:return: An updated tree with fixed vocabulary.
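
    Example: with ``vocab={'PERSON'}``, a GROUP labeled ``Persons`` normalizes to
    ``PERSONS``, which is close enough to ``PERSON`` (ratio ~0.92) to be renamed.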
"""
for subtree in tree.subtrees():
if (
has_type(subtree, {NodeType.GROUP, NodeType.REL})
and (label := _normalize(subtree.label.name))
and (matches := get_close_matches(label, vocab, n=1, cutoff=vocab_similarity))
):
subtree.label = NodeLabel(subtree.label.type, matches[0])
return tree
def _parse_tree_output(
raw_output: str | None,
*,
fallback: Tree,
vocab: Collection[str] | None = None,
vocab_similarity: float = 0.6,
debug: bool = False,
) -> tuple[Tree, bool]:
"""
Parse a raw LLM output string into a Tree, returning the provided fallback when parsing fails or output is empty.
Attempts to repair and load JSON from raw_output, convert the object into a :py:class:`~architxt.tree.Tree`,
and wrap the parsed content under a ROOT node that reuses the fallback's oid before validating the result.
If parsing fails or the JSON does not contain a suitable object,
the original fallback :py:class:`~architxt.tree.Tree` is returned.
:param raw_output: The raw LLM output string to parse.
:param fallback: The fallback original :py:class:`~architxt.tree.Tree` to return when parsing fails.
:param vocab: Collection of canonical labels.
:param vocab_similarity: Similarity threshold in [0, 1] for merging labels.
:param debug: If True, emit warnings on parse errors and log JSON repair/parse metadata to MLflow.
:return: The parsed :py:class:`~architxt.tree.Tree`, or the original fallback if parsing is unsuccessful.
"""
if not raw_output:
return fallback, False
try:
raw_output = raw_output.strip()
json_data = json_repair.loads(raw_output, skip_json_loads=True, logging=debug)
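        # with logging enabled, json_repair returns a (parsed, repair_log) tuple;
        # surface the textual fixes on the active MLflow span for debugging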
if isinstance(json_data, tuple):
json_data, fixes = json_data
if fixes and (span := mlflow.get_current_active_span()):
event = SpanEvent(
name='JSON fixes', attributes={'json_fixes': [fix['text'] for fix in fixes if 'text' in fix]}
)
span.add_event(event)
tree = _parse_tree(json_data)
tree = _sanitize(tree, oid=fallback.oid)
if vocab:
tree = _fix_vocab(tree, vocab=vocab, vocab_similarity=vocab_similarity)
except (ValueError, TypeError) as error:
if debug:
warnings.warn(str(error), RuntimeWarning)
if span := mlflow.get_current_active_span():
span.record_exception(error)
else:
return tree, tree != fallback
return fallback, False
def _build_simplify_langchain_graph(
llm: BaseChatModel,
prompt: ChatPromptTemplate,
*,
vocab: Collection[str] | None = None,
vocab_similarity: float = 0.6,
debug: bool = False,
) -> Runnable[Sequence[Tree], Sequence[tuple[Tree, bool]]]:
"""
Build a LangChain graph that simplifies :py:class:`~architxt.tree.Tree` using the provided model and prompt.
:param llm: The LLM model to use for simplification.
:param prompt: The prompt template to use for simplification.
:param debug: If True, emit warnings on parse errors and log JSON repair/parse metadata to MLflow.
:return: A Runnable LangChain graph that simplifies :py:class:`~architxt.tree.Tree`.
"""
to_json = RunnableLambda(lambda trees: {"trees": _trees_to_markdown_list(trees)})
llm_chain = to_json | prompt | llm.with_retry(stop_after_attempt=10) | NumberedListOutputParser()
parallel = RunnableParallel(origin=RunnablePassthrough(), simplified=llm_chain)
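    # Pair each original tree with the corresponding LLM output. Extra outputs are
    # truncated to the input length; missing outputs pair with None via zip_longest,
    # so the parser falls back to the original tree.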
tree_parser = RunnableLambda(
lambda result: tuple(
_parse_tree_output(simplified, fallback=origin, vocab=vocab, vocab_similarity=vocab_similarity, debug=debug)
for origin, simplified in itertools.zip_longest(
result['origin'], result['simplified'][: len(result['origin'])]
)
)
)
return parallel | tree_parser
def count_tokens(llm: BaseLanguageModel, trees: Iterable[Tree]) -> int:
"""
    Count the number of tokens needed to encode a set of trees in a prompt.
:param llm: LLM model to use.
:param trees: Sequence of trees to simplify.
    :return: Number of tokens in the formatted tree list.
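
    Example (a sketch; ``llm`` is any LangChain language model and ``max_tokens`` an assumed budget)::

        if count_tokens(llm, [tree]) > max_tokens:
            ...  # the tree is too large for a single prompt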
"""
json_trees = _trees_to_markdown_list(trees)
return llm.get_num_tokens(json_trees)
def estimate_tokens(
trees: Iterable[Tree],
llm: BaseLanguageModel,
max_tokens: int,
*,
prompt: BasePromptTemplate = DEFAULT_PROMPT,
refining_steps: int = 0,
error_adjustment: float = 1.2,
) -> tuple[int, int, int]:
"""
Estimate the total number of tokens (input/output) and queries required for a rewrite.
:param trees: Sequence of trees to simplify.
    :param llm: LLM model to use.
:param max_tokens: Maximum number of tokens to allow per prompt.
:param prompt: Prompt template to use.
:param refining_steps: Number of refining steps to perform after the initial rewrite.
    :param error_adjustment: Safety factor applied to the token estimates to account for retries and estimation error.
:return: The total number of tokens (input/output) and the number of queries estimated for a rewrite.
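
    Example (a sketch; ``llm`` is any LangChain language model)::

        input_tokens, output_tokens, queries = estimate_tokens(
            forest, llm, max_tokens=4096, refining_steps=1
        )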
"""
prompt_tokens = llm.get_num_tokens(prompt.format(trees='', vocab=''))
batches = more_itertools.constrained_batches(
trees,
max_size=max_tokens - prompt_tokens,
get_len=lambda x: count_tokens(llm, [x]),
strict=False,
)
queries = 0
input_tokens = 0
output_tokens = 0
for batch in batches:
queries += 1
tokens = count_tokens(llm, batch)
input_tokens += prompt_tokens + tokens
output_tokens += tokens
return (
int(input_tokens * (refining_steps + 1) * error_adjustment),
int(output_tokens * (refining_steps + 1) * error_adjustment),
queries * (refining_steps + 1),
)
async def llm_simplify(
llm: BaseChatModel,
max_tokens: int,
prompt: ChatPromptTemplate,
trees: Iterable[Tree],
*,
vocab: Collection[str] | None = None,
vocab_similarity: float = 0.6,
task_limit: int = 4,
debug: bool = False,
) -> AsyncGenerator[tuple[Tree, bool], None]:
"""
Simplify parse trees using an LLM.
It uses the following flow where the tree parser falls back to the original tree in case of parsing errors:
.. mermaid::
:alt: ArchiTXT Schema
:align: center
---
config:
theme: neutral
---
flowchart LR
A[Trees] --> B[Convert to JSON] --> C[LLM]
A & C --> E[Tree parser]
E --> F[Simplified trees]
:param llm: LLM model to use.
:param max_tokens: Maximum number of tokens to allow per prompt.
:param prompt: Prompt template to use.
:param trees: Sequence of trees to simplify.
    :param vocab: Optional collection of canonical labels to advertise in the prompt and to match output labels against.
:param vocab_similarity: Similarity threshold in [0, 1] for merging labels.
:param task_limit: Maximum number of concurrent requests to make.
:param debug: Whether to enable debug logging.
    :yield: Tuples of ``(tree, modified)`` where the tree keeps the same oid as its input and the flag indicates whether the LLM changed it.
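
    Example (a sketch; ``llm`` is any LangChain chat model)::

        async for tree, modified in llm_simplify(llm, 4096, DEFAULT_PROMPT, forest):
            ...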
"""
vocab_str = f"Prefer these labels : {', '.join(vocab)}." if vocab else ""
prompt = prompt.partial(vocab=vocab_str)
chain = _build_simplify_langchain_graph(llm, prompt, vocab=vocab, vocab_similarity=vocab_similarity, debug=debug)
prompt_tokens = llm.get_num_tokens(prompt.format(trees=''))
# Group trees respecting the maximum number of tokens per prompt
batches = more_itertools.constrained_batches(
trees,
max_size=max_tokens - prompt_tokens,
get_len=lambda x: count_tokens(llm, [x]),
strict=False,
)
@mlflow.trace(name='llm-invoke', span_type=SpanType.CHAIN)
async def _safe_traced_invoke(tree_batch: Sequence[Tree]) -> Sequence[tuple[Tree, bool]]:
try:
return await chain.ainvoke(tree_batch)
except Exception as error:
warnings.warn(str(error), RuntimeWarning)
if span := mlflow.get_current_active_span():
span.record_exception(error)
return [(orig_tree, False) for orig_tree in tree_batch]
# Run queries concurrently
tree_stream: Stream[Sequence[tuple[Tree, bool]]] = stream.iterate(batches) | pipe.amap(
_safe_traced_invoke, ordered=False, task_limit=task_limit
)
async with tree_stream.stream() as streamer:
async for batch in streamer:
for tree, simplified in batch:
yield tree, simplified
def _normalize(s: str) -> str:
"""
Normalize a string for vocabulary extraction.
Applies Unicode NFKC normalization, removes non-alphanumeric characters,
and converts to upper snake_case (e.g., "hello, world" -> "HELLO_WORLD").
:param s: String to normalize.
:return: Normalized upper snake_case string, or empty string if no alphanumeric characters.
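
    Example::

        >>> _normalize('hello, world')
        'HELLO_WORLD'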
"""
# unicode normalize
s = unicodedata.normalize('NFKC', s)
# keep alnum and spaces
s = ''.join(ch if ch.isalnum() else ' ' for ch in s)
# convert to upper case
s = s.strip().upper()
# convert to snake_case
return re.sub(r'\s+', '_', s)
def _get_mlflow_schema(forest: Forest) -> dict:
    """Summarize the forest size and its induced schema (entities, groups, relations) for MLflow logging."""
    schema = Schema.from_forest(forest)
return {
'forest.size': len(forest),
'schema.size': len(schema.productions()),
'schema.entities': sorted(schema.entities),
'schema.groups': sorted({group.name for group in schema.groups}),
'schema.relations': sorted({relation.name for relation in schema.relations}),
}
async def llm_rewrite(
forest: Forest,
llm: BaseChatModel,
max_tokens: int,
tau: float = 0.7,
decay: float = DECAY,
min_support: int | None = None,
vocab_similarity: float = 0.6,
refining_steps: int = 0,
debug: bool = False,
intermediate_output_path: Path | None = None,
task_limit: int = 1,
metric: METRIC_FUNC = DEFAULT_METRIC,
prompt: ChatPromptTemplate = DEFAULT_PROMPT,
commit: bool | int = True,
) -> Metrics:
"""
    Rewrite a forest into a valid schema using an LLM agent.
:param forest: A forest to be rewritten in place.
:param llm: The LLM model to interact with for rewriting and simplification tasks.
:param max_tokens: The token limit of the prompt.
:param tau: Threshold for subtree similarity when clustering.
:param decay: The similarity decay factor.
The higher the value, the more the weight of context decreases with distance.
:param min_support: Minimum support for vocab.
:param vocab_similarity: Similarity threshold in [0, 1] for merging vocabulary labels.
:param refining_steps: Number of refining steps to perform after the initial rewrite.
:param debug: Whether to enable debug logging.
:param intermediate_output_path: Optional path to save intermediate results after each iteration.
:param task_limit: Maximum number of concurrent requests to make.
:param metric: The metric function used to compute similarity between subtrees.
:param prompt: The prompt template to use for the LLM during the simplification.
:param commit: Commit automatically if using TreeBucket. If already in a transaction, no commit is applied.
- If False, no commits are made, it relies on the current transaction.
- If True (default), commits in batch.
        - If an integer, commits every N trees.
To avoid memory issues, we recommend using incremental commit with large iterables.
:return: A `Metrics` object encapsulating the results and metrics calculated for the LLM rewrite process.
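
    Example (a sketch; ``llm`` is any LangChain chat model and ``forest`` an existing
    :py:class:`~architxt.tree.Forest`)::

        metrics = await llm_rewrite(forest, llm, max_tokens=4096, refining_steps=2)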
"""
metrics = Metrics(forest, tau=tau, decay=decay, metric=metric)
min_support = min_support or max((len(forest) // 20), 2)
if mlflow.active_run():
mlflow.log_params(
{
'nb_sentences': len(forest),
'tau': tau,
'decay': decay,
'min_support': min_support,
'vocab_similarity': vocab_similarity,
'metric': metric.__name__,
'refining_steps': refining_steps,
}
)
metrics.log_to_mlflow(0, debug=debug)
mlflow_schema = _get_mlflow_schema(forest)
for iteration in trange(refining_steps + 1, leave=False, desc='rewriting iterations'):
with mlflow.start_span(
'llm-rewriting',
span_type=SpanType.CHAIN,
attributes={
'step': iteration,
},
) as iteration_span:
iteration_span.set_inputs(mlflow_schema)
vocab = extract_vocab(forest, min_support, vocab_similarity)
shuffled_forest = tqdm(windowed_shuffle(forest), leave=False, total=len(forest), desc='simplifying')
simplification = llm_simplify(
llm,
max_tokens,
prompt,
shuffled_forest,
vocab=vocab,
vocab_similarity=vocab_similarity,
task_limit=task_limit,
debug=debug,
)
# Track if any tree was modified
any_modified = False
async def _simplification_wrap() -> AsyncGenerator[Tree, None]:
nonlocal any_modified
async for tree, simplified in simplification:
if simplified:
any_modified = True
yield tree
if isinstance(forest, TreeBucket):
await forest.async_update(_simplification_wrap(), commit=commit)
else:
forest[:] = [tree async for tree in _simplification_wrap()]
mlflow_schema = _get_mlflow_schema(forest)
iteration_span.set_outputs(mlflow_schema)
iteration_span.set_attribute('simplified', any_modified)
# Save intermediate results
if intermediate_output_path:
intermediate_output_path.mkdir(parents=True, exist_ok=True)
intermediate_file = intermediate_output_path / f'intermediate_{iteration}.jsonl'
export_forest_to_jsonl(intermediate_file, forest)
# Log metrics to MLflow
if mlflow.active_run():
metrics.update()
metrics.log_to_mlflow(iteration + 1, debug=debug)
# Early stopping if no tree was modified
if not any_modified:
break
metrics.update()
return metrics