Source code for architxt.simplification.tree_rewriting.operations.relations

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

import more_itertools

from architxt.tree import NodeLabel, NodeType, Tree, has_type

from .operation import Operation

if TYPE_CHECKING:
    from architxt.similarity import TREE_CLUSTER

__all__ = [
    'FindRelationsOperation',
]


def _is_valid_relation(tree: Tree) -> bool:
    return len(tree) == 2 and has_type(tree[0], NodeType.GROUP) and has_type(tree[1], NodeType.GROUP)


[docs] class FindRelationsOperation(Operation): """ Identifies and establishes hierarchical relationships between `GROUP` nodes within a tree structure. The function scans for subtrees that contain at least two distinct elements. When a `GROUP` node is found to have a relationship with a collection, that relationship is distributed between the `GROUP` node itself and each member of the collection. The operation can operate in two modes: 1. Naming-only mode: Simply assigns labels to valid relations without altering the tree's structure. 2. Structural modification mode: restructures the tree by creating relation nodes between groups and collections. """ def __init__(self, *args: Any, naming_only: bool = False, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.naming_only = naming_only
[docs] def apply(self, tree: Tree, *, equiv_subtrees: TREE_CLUSTER) -> bool: # noqa: ARG002 simplified = False # Traverse subtrees, starting with the deepest, containing exactly 2 children for subtree in sorted( tree.subtrees( lambda x: len(x) == 2 and not has_type(x, {NodeType.ENT, NodeType.GROUP}) and all(has_type(y, {NodeType.GROUP, NodeType.COLL}) for y in x) ), key=lambda x: x.depth, reverse=True, ): if _is_valid_relation(subtree): # Group <-> Group if self._create_group_group_relation(subtree): simplified = True elif not self.naming_only and self._create_group_collection_relation(subtree): # Group <-> Collection simplified = True return simplified
def _create_group_group_relation(self, tree: Tree) -> bool: modified = False if not _is_valid_relation(tree): return False if tree[0].label.name == tree[1].label.name: return False # Create and set the relationship label label = sorted([tree[0].label.name, tree[1].label.name]) rel_label = NodeLabel(NodeType.REL, f'{label[0]}<->{label[1]}') # Log relation creation in MLFlow, if active if not has_type(tree, NodeType.REL): modified = True self._log_to_mlflow({'name': rel_label}) tree.label = rel_label return modified def _create_group_collection_relation(self, tree: Tree) -> bool: if len(tree) != 2: return False if has_type(tree[0], NodeType.GROUP) and has_type(tree[1], NodeType.COLL): group, collection = tree[0], tree[1] elif has_type(tree[0], NodeType.COLL) and has_type(tree[1], NodeType.GROUP): collection, group = tree[0], tree[1] else: return False # If a valid Group-Collection pair is found, create relationships for each if ( len(collection) == 0 or not all(has_type(x, NodeType.GROUP) for x in collection) or not more_itertools.all_equal(collection, lambda x: x.label.name) ): warnings.warn("Collection is empty or does not contain homogeneous GROUP nodes.") return False group_label = group.label.name collection_label = collection[0].label.name if group_label == collection_label: return False label1, label2 = sorted((group_label, collection_label)) rel_label = NodeLabel(NodeType.REL, f'{label1}<->{label2}') # Create relationship nodes for each element in the collection for coll_group in collection[:]: rel_tree = Tree(rel_label, children=[group.copy(), coll_group.detach()]) tree.append(rel_tree) # Add new relationship to subtree # Log relation creation in MLFlow, if active self._log_to_mlflow({'name': rel_label.name}) tree.remove(group) tree.remove(collection) return True