Source code for architxt.simplification.simple_rewrite
from collections.abc import Iterable
from contextlib import nullcontext
import more_itertools
from tqdm.auto import tqdm
from architxt.bucket import TreeBucket
from architxt.tree import NodeLabel, NodeType, Tree, has_type
from architxt.utils import BATCH_SIZE
# Public API: only the forest-level entry point is exported; the per-tree helper stays private.
__all__ = ["simple_rewrite"]
def _simple_rewrite_tree(tree: Tree, group_ids: dict[tuple[str, ...], str]) -> None:
    """
    Collapse the entities of a single tree into one group node, in place.

    Trees that are themselves ``ENT`` nodes, or that have no unlabelled nodes, are left
    unchanged. Otherwise every child of *tree* is replaced by a single ``GROUP`` node whose
    children are copies of the tree's entities, keeping only the first entity encountered
    (in ``tree.entities()`` order) for each label returned by ``tree.entity_labels()``.

    :param tree: The tree to rewrite in place.
    :param group_ids: Mapping from a sorted tuple of entity labels to a group identifier.
        It is shared across calls and updated here: a label combination not yet present is
        assigned ``str(len(group_ids) + 1)``.
    """
    if has_type(tree, NodeType.ENT):
        return
    if not tree.has_unlabelled_nodes():
        return

    # Labels still waiting for a representative; each one is discarded once it has been used.
    pending_labels = tree.entity_labels()
    signature = tuple(sorted(pending_labels))

    # setdefault only stores the freshly numbered identifier when the signature is new.
    group_id = group_ids.setdefault(signature, str(len(group_ids) + 1))
    label = NodeLabel(NodeType.GROUP, group_id)

    members: list[Tree] = []
    for candidate in tree.entities():
        name = candidate.label.name
        if name not in pending_labels:
            continue  # a representative for this label was already kept
        members.append(candidate.copy())
        pending_labels.remove(name)

    tree[:] = [Tree(label, members)]
[docs]
def simple_rewrite(forest: Iterable[Tree], *, commit: bool | int = BATCH_SIZE) -> None:
    """
    Rewrite a forest into a valid schema, treating each tree as a distinct group.

    This function processes each tree in the forest, collapsing its entities into a single
    group node if the tree contains unlabelled nodes.
    Each unique combination of entity labels is assigned a consistent group ID.
    Duplicate entities are removed: only the first entity of each label is kept in a group.

    :param forest: A forest to be rewritten in place.
    :param commit: When working with a `TreeBucket`, changes can be committed automatically.

        - If False (or 0), no commits are made. Use this for small forests where you want to commit manually later.
        - If True, commits after processing the entire forest in one transaction.
        - If a positive integer N, commits after processing every N trees.

        To avoid memory issues with large forests, we recommend using batch commit on large forests.
        This argument is ignored when *forest* is not a `TreeBucket`.
    """
    group_ids: dict[tuple[str, ...], str] = {}

    # ``bool`` is a subclass of ``int`` (``isinstance(True, int)`` is True), so booleans must be
    # excluded explicitly: otherwise ``commit=True`` would be used as a batch size of 1 and
    # commit after every single tree instead of once for the whole forest.
    is_batch_size = isinstance(commit, int) and not isinstance(commit, bool)

    if commit and is_batch_size and isinstance(forest, TreeBucket):
        # Batch mode: one transaction per chunk of ``commit`` trees.
        for chunk in more_itertools.ichunked(tqdm(forest, desc="Rewriting trees"), commit):
            with forest.transaction():
                for tree in chunk:
                    _simple_rewrite_tree(tree, group_ids)
        return

    # Single transaction for the whole bucket when commit is truthy; no transaction otherwise.
    with forest.transaction() if commit and isinstance(forest, TreeBucket) else nullcontext():
        for tree in tqdm(forest, desc="Rewriting trees"):
            _simple_rewrite_tree(tree, group_ids)