Source code for architxt.simplification.simple_rewrite
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import more_itertools
from tqdm.auto import tqdm
from architxt.bucket import TreeBucket
from architxt.tree import NodeLabel, NodeType, Tree, has_type
from architxt.utils import get_commit_batch_size
if TYPE_CHECKING:
from collections.abc import Iterable
# Only the forest-level entry point is part of the public API;
# the per-tree helper stays private.
__all__ = ["simple_rewrite"]
def _simple_rewrite_tree(tree: Tree, group_ids: dict[tuple[str, ...], str]) -> None:
    """
    Collapse the entities of a single tree under one group node, in place.

    The tree is left untouched when it is typed as an entity
    (``has_type(tree, NodeType.ENT)``) or when it has no unlabelled nodes.
    Otherwise, all of its children are replaced by a single GROUP node whose
    children are copies of the tree's entities, keeping only the first entity
    encountered (in ``tree.entities()`` order) for each entity label.

    :param tree: The tree to rewrite in place.
    :param group_ids: Mapping from a sorted tuple of entity labels to a group ID.
        It is updated in place and is meant to be shared across calls, so that
        identical label combinations receive the same ID. New combinations get
        the next sequential ID as a string ("1", "2", ...).
    """
    if has_type(tree, NodeType.ENT) or not tree.has_unlabelled_nodes():
        return

    # Work on a private copy: the loop below consumes this set to drop duplicate
    # entities, and we must not mutate whatever collection ``entity_labels()``
    # hands back (it may be cached by the tree, or be immutable).
    remaining_labels = set(tree.entity_labels())

    group_key = tuple(sorted(remaining_labels))
    if group_key not in group_ids:
        # IDs are assigned sequentially in order of first appearance.
        group_ids[group_key] = str(len(group_ids) + 1)
    group_label = NodeLabel(NodeType.GROUP, group_ids[group_key])

    group_entities: list[Tree] = []
    for entity in tree.entities():
        name = entity.label.name
        if name in remaining_labels:
            # Copy rather than move: presumably the original entity node is
            # still attached to ``tree`` at this point.
            group_entities.append(entity.copy())
            # Later entities with the same label are skipped (deduplication).
            remaining_labels.remove(name)

    # NOTE(review): a tree with unlabelled nodes but no entities is still
    # replaced by an empty group node (key ``()``) — confirm this is intended.
    tree[:] = [Tree(group_label, group_entities)]
[docs]
def simple_rewrite(forest: Iterable[Tree], *, commit: bool | int = True) -> None:
    """
    Rewrite a forest into a valid schema, treating each tree as a distinct group.

    Every tree that contains unlabelled nodes has its entities collapsed under a
    single group node. Trees sharing the same combination of entity labels are
    given the same group ID, and duplicate entities within a tree are dropped.

    :param forest: The forest to rewrite in place.
    :param commit: Automatic commit policy, only used when ``forest`` is a TreeBucket.
        If a transaction is already open, no commit is applied.

        - If False, no commit is made; the current transaction is relied upon.
        - If True (default), trees are committed in batches.
        - If an integer, a commit is made every N trees.

        Incremental commits are recommended for large iterables to avoid memory issues.
    """
    batch_size = get_commit_batch_size(commit)
    use_transaction = bool(commit) and isinstance(forest, TreeBucket)

    # Shared across the whole forest so identical label combinations map to one ID.
    group_ids: dict[tuple[str, ...], str] = {}

    progress = tqdm(forest, desc="Rewriting trees")
    for batch in more_itertools.ichunked(progress, batch_size):
        # One transaction per batch (presumably committed when the context exits);
        # a no-op context otherwise.
        scope = forest.transaction() if use_transaction else nullcontext()
        with scope:
            for current_tree in batch:
                _simple_rewrite_tree(current_tree, group_ids)