Source code for architxt.simplification.tree_rewriting.operations.collections

from copy import deepcopy
from itertools import groupby
from typing import Any

import more_itertools

from architxt.similarity import TREE_CLUSTER
from architxt.tree import NodeLabel, NodeType, Tree, has_type

from .operation import Operation

# Public API of this module: only the operation class is exported.
__all__ = [
    'FindCollectionsOperation',
]


class FindCollectionsOperation(Operation):
    """
    Identify and group nodes into collections.

    The operation can operate in two modes:

    1. Naming-only mode: simply assigns labels to valid collections without
       altering the tree's structure.
    2. Structural modification mode: groups nodes into collection sets,
       updates their labels, and restructures the tree by creating
       collection nodes.

    :param naming_only: If true, run in naming-only mode; otherwise run in
        structural modification mode.
    """

    def __init__(self, *args: Any, naming_only: bool = False, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.naming_only = naming_only

    @staticmethod
    def _collection_sets(subtree: Tree) -> list[list[Tree]]:
        """Return the sets of same-named GROUP/REL/COLL children of ``subtree`` that hold more than one node."""

        def name_of(node: Tree) -> str:
            return node.label().name

        # ``groupby`` only merges *consecutive* items, so the candidates must be
        # sorted with the very same key that is used for grouping. The sort is
        # stable, hence children sharing a name keep their relative order.
        candidates = sorted(
            (child for child in subtree if has_type(child, {NodeType.GROUP, NodeType.REL, NodeType.COLL})),
            key=name_of,
        )

        coll_tree_sets = []
        for _, equiv_set in groupby(candidates, key=name_of):
            members = sorted(equiv_set, key=lambda x: x.parent_index())
            # A single node is not enough to form a collection
            if len(members) > 1:
                coll_tree_sets.append(members)

        # Process the sets in the order in which their first member appears in the subtree
        coll_tree_sets.sort(key=lambda x: x[0].parent_index())
        return coll_tree_sets

    def apply(self, tree: Tree, *, equiv_subtrees: TREE_CLUSTER) -> tuple[Tree, bool]:  # noqa: ARG002
        """
        Find collections in ``tree`` and label or build them.

        Untyped subtrees that contain at least one GROUP, REL or COLL child are
        visited from the deepest to the shallowest.

        :param tree: The tree to simplify. It is modified in place.
        :param equiv_subtrees: Unused by this operation.
        :return: The given tree and a flag telling whether anything was changed.
        """
        simplified = False

        # The candidate list is materialised before any modification takes place
        candidates = sorted(
            tree.subtrees(
                lambda x: not has_type(x) and any(has_type(y, {NodeType.GROUP, NodeType.REL, NodeType.COLL}) for y in x)
            ),
            key=lambda x: x.depth(),
            reverse=True,
        )

        for subtree in candidates:
            # Naming-only mode: apply labels without modifying tree structure
            if self.naming_only:
                if has_type(subtree[0], {NodeType.GROUP, NodeType.REL}) and more_itertools.all_equal(
                    subtree, key=lambda x: x.label()
                ):
                    subtree.set_label(NodeLabel(NodeType.COLL, subtree[0].label().name))
                    simplified = True
                continue

            # Group nodes by shared label and organize them into collection sets for structural modification.
            # The sets are computed up front, before the subtree is altered below.
            for coll_tree_set in self._collection_sets(subtree):
                # Prepare a new collection of nodes (merging if some nodes are already collections)
                coll_elements = []
                for coll_tree in coll_tree_set:
                    if has_type(coll_tree, NodeType.COLL):
                        coll_elements.extend(coll_tree)  # Merge collection elements
                    else:
                        coll_elements.append(coll_tree)

                # Prepare the collection node; elements are copied so they can be
                # attached to the new node while the originals are removed.
                label = NodeLabel(NodeType.COLL, coll_tree_set[0].label().name)
                children = [deepcopy(element) for element in coll_elements]

                # Log the creation of a new collection in MLFlow, if active
                self._log_to_mlflow(
                    {
                        'name': label.name,
                        'size': len(children),
                    }
                )
                simplified = True

                # If the entire subtree is a single collection, update its label and structure directly
                if len(subtree) == len(coll_tree_set):
                    subtree.set_label(label)
                    subtree.clear()
                    subtree.extend(children)

                else:
                    # The first member has the smallest index, so removing the
                    # other members does not shift the insertion position.
                    index = coll_tree_set[0].parent_index()

                    # Remove nodes of the current collection set from the subtree
                    for coll_tree in coll_tree_set:
                        subtree.pop(coll_tree.parent_index(), recursive=False)

                    # Insert the new collection node at the appropriate index
                    coll_tree = Tree(label, children=children)
                    subtree.insert(index, coll_tree)

        return tree, simplified