Source code for architxt.simplification.simple_rewrite

from copy import deepcopy
from typing import Any

from tqdm.auto import tqdm

from architxt.tree import Forest, NodeLabel, NodeType, Tree


def simple_rewrite(forest: Forest, **_: Any) -> Forest:
    """
    Rewrite a given forest into a valid schema, treating each tree as a distinct group.

    Entities within a tree are grouped together, and duplicate entities are discarded:
    only the first entity found for each label (in the order given by ``tree.entities()``)
    is kept. A unique group name is created for each distinct set of entity labels;
    names are the strings ``"1"``, ``"2"``, ... assigned in order of first appearance,
    so two trees with the same set of entity labels receive the same group name.

    Trees for which ``has_unlabelled_nodes()`` is false are considered already valid
    and are appended to the result unchanged (the same object, not a copy).

    :param forest: Input forest consisting of a list of Tree objects.
    :param _: Extra keyword arguments are accepted and ignored.
    :return: A new forest where each tree is restructured as a valid group.
        Already valid trees are kept as is.
    """
    new_forest: list[Tree] = []
    # Maps the sorted tuple of entity labels of a tree to the group name given to that set.
    group_ids: dict[tuple[str, ...], str] = {}

    for tree in tqdm(forest):
        if not tree.has_unlabelled_nodes():
            new_forest.append(tree)
            continue

        labels = tree.entity_labels()
        # Sorting makes the key independent of the iteration order of the labels.
        group_key = tuple(sorted(labels))

        if group_key not in group_ids:
            group_ids[group_key] = str(len(group_ids) + 1)

        group_label = NodeLabel(NodeType.GROUP, group_ids[group_key])

        # Track the labels still to be emitted on a private copy, so the collection
        # returned by the tree is never mutated (safe even if the tree shares or caches it).
        remaining = set(labels)
        group_entities: list[Tree] = []

        for entity in tree.entities():
            name = entity.label().name
            if name not in remaining:
                # A previous entity with this label was already kept: drop the duplicate.
                continue

            # Deep-copy so the new group gets its own nodes and the input tree is left untouched.
            group_entities.append(deepcopy(entity))
            remaining.remove(name)

        group_tree = Tree(group_label, group_entities)
        new_forest.append(Tree('ROOT', [group_tree]))

    return new_forest