Source code for architxt.simplification.tree_rewriting.operations.groups

from copy import deepcopy
from itertools import combinations

import more_itertools

from architxt.similarity import TREE_CLUSTER
from architxt.tree import NodeLabel, NodeType, Tree, has_type

from .operation import Operation

__all__ = [
    'FindSubGroupsOperation',
    'MergeGroupsOperation',
]


class FindSubGroupsOperation(Operation):
    """
    Identify and create subgroups of entities for each subtree.

    A subgroup is created only if the support of the newly created subgroup is greater
    than the support of the subtree it is extracted from (and at least ``self.min_support``).
    """

    def _create_and_evaluate_subgroup(
        self, subtree: Tree, sub_group: tuple[Tree, ...], min_support: int, equiv_subtrees: TREE_CLUSTER
    ) -> tuple[Tree, int] | None:
        """
        Attempt to add a new subtree by creating a new `GROUP` for a given `sub_group` of entities.

        The original `subtree` is not modified: the work is done on a deep copy.
        It also evaluates the new group support within the `equiv_subtrees` equivalence class.

        :param subtree: The tree structure within which a potential subgroup will be created.
        :param sub_group: A tuple of `Tree` entities (children of `subtree`) to be grouped into a new `GROUP` node.
        :param min_support: The support needed for a subgroup to be considered valid.
        :param equiv_subtrees: The cluster of equivalent subtrees in the forest.
        :return: A tuple containing the modified copy of the subtree and the support count of the new group
            if that group meets the minimum support threshold; otherwise, `None`.
        """
        new_subtree = deepcopy(subtree)

        # Create a new GROUP node with the given entities from the sub_group.
        group_tree = Tree(NodeLabel(NodeType.GROUP), children=[deepcopy(ent) for ent in sub_group])

        # Remove the used entities from the copied subtree. Popping from the highest
        # index down keeps the remaining indices valid.
        indices = sorted((ent.parent_index() for ent in sub_group), reverse=True)
        for idx in indices:
            new_subtree.pop(idx)

        # Insert the GROUP node at the position of the earliest entity in sub_group.
        insertion_index = min(indices)
        new_subtree.insert(insertion_index, group_tree)

        # Reset label if subtree becomes invalid as a group (it now contains a nested GROUP).
        if has_type(subtree, NodeType.GROUP):
            new_subtree.set_label('')

        # Compute support for the new group. It is a valid subgroup if its support reaches the given threshold.
        new_group = new_subtree[insertion_index]
        equiv = self.get_equiv_of(new_group, equiv_subtrees=equiv_subtrees)
        support = len(equiv)

        if support >= min_support:
            new_group.set_label(equiv[0].label())
            return new_subtree, support

        return None

    def apply(self, tree: Tree, *, equiv_subtrees: TREE_CLUSTER) -> tuple[Tree, bool]:
        """
        Group entities of every candidate subtree into higher-support `GROUP` nodes.

        :param tree: The tree to simplify; it is modified in place.
        :param equiv_subtrees: The cluster of equivalent subtrees in the forest.
        :return: The (same) tree and whether at least one subgroup was created.
        """
        simplified = False

        # Candidate subtrees are those whose children are all `ENT` nodes,
        # processed in increasing order of height.
        candidate_subtrees = sorted(
            tree.subtrees(lambda sub: all(has_type(child, NodeType.ENT) for child in sub)),
            key=lambda sub: sub.height(),
        )

        for subtree in candidate_subtrees:
            parent = subtree.parent()
            parent_idx = subtree.parent_index()

            # Compute initial support for the subtree
            group_support = len(self.get_equiv_of(subtree, equiv_subtrees=equiv_subtrees))

            entity_trees = [child for child in subtree if has_type(child, NodeType.ENT)]
            entity_labels = {ent.label() for ent in entity_trees}

            # To narrow down the search space, we focus on reducing the entity trees to consider.
            # We retain only those groups that appear in clusters with higher support than the actual subtree,
            # and where the entity set intersects with the current subtree.
            # This allows us to reduce the set of entity labels to consider only those present in these groups.
            entity_groups = {
                tuple(sorted(child.label() for child in member))
                for cluster in equiv_subtrees
                if len(cluster) > group_support
                for member in cluster
                if entity_labels.intersection(child.label() for child in member)
            }

            if not entity_groups:
                continue

            available_labels = {label for group in entity_groups for label in group}

            # In addition to limiting the search to a subset of entity labels,
            # we can also restrict the size of subgroups to consider.
            # This helps prevent combinatorial explosion by avoiding the evaluation of excessively large groups.
            #
            # On one hand, we know that subgroups should be smaller than the actual subtree.
            # On the other hand, similarity is unlikely when groups differ significantly in size.
            # We can limit subgroup size to the largest group in the selected clusters
            # that contain a subset of the available entity labels within the subtree.
            # Larger groups could then be constructed by the merge_group operation.
            entity_trees = [entity for entity in entity_trees if entity.label() in available_labels]
            entity_labels = {ent.label() for ent in entity_trees}
            k = min(
                len(entity_trees),
                len(subtree) - 1,
                max(
                    (len(ent_group) for ent_group in entity_groups if entity_labels.issuperset(ent_group)),
                    default=len(entity_trees),
                ),
            )

            # The threshold does not change while this subtree is processed.
            min_subgroup_support = max(group_support + 1, self.min_support)

            # Iteratively explore k-sized combinations of entity trees and select the one with the maximum support,
            # decreasing k if necessary
            while k > 1:
                # Evaluate all k-groups made of entities with distinct labels
                evaluated_groups = (
                    self._create_and_evaluate_subgroup(
                        subtree,
                        sub_group,
                        min_support=min_subgroup_support,
                        equiv_subtrees=equiv_subtrees,
                    )
                    for sub_group in combinations(entity_trees, k)
                    if more_itertools.all_unique(ent.label() for ent in sub_group)
                )

                # Select the subgroup with maximum support
                valid_subgroups = filter(None, evaluated_groups)
                max_subtree, max_support = max(valid_subgroups, key=lambda x: x[1], default=(None, None))

                # If no suitable k-group found; decrease k and try again
                if max_subtree is None:
                    k -= 1
                    continue

                # Successfully found a valid k-group, mark the tree as simplified
                simplified = True
                self._log_to_mlflow(
                    {
                        'num_instance': max_support,
                        'labels': [str(ent.label()) for ent in max_subtree],
                    }
                )

                # Replace subtree with the newly constructed one
                if parent is not None:
                    subtree = parent[parent_idx] = max_subtree
                else:
                    subtree.clear()
                    subtree.extend(deepcopy(max_subtree[:]))

                # Reset entity trees and k, keeping the label narrowing computed above
                # so later rounds search the same restricted space as the first one.
                entity_trees = [
                    child
                    for child in subtree
                    if has_type(child, NodeType.ENT) and child.label() in available_labels
                ]
                k = min(len(entity_trees), k)

        return tree, simplified
class MergeGroupsOperation(Operation):
    """
    Merge `ENT` nodes and existing `GROUP` nodes of a tree into larger `GROUP` nodes.

    A merge is accepted only when the resulting group's support is at least the
    largest support among the multi-entity groups being merged.
    """

    def _merge_groups_inner(
        self,
        subtree: Tree,
        combined_groups: tuple[Tree, ...],
        equiv_subtrees: TREE_CLUSTER,
    ) -> tuple[Tree, int] | None:
        """
        Try to replace the given `GROUP`/`ENT` children of `subtree` with a single `GROUP` node.

        The work is done on a deep copy of the whole tree containing `subtree`; the
        original is left untouched.

        :param subtree: The subtree whose children are merged.
        :param combined_groups: The children of `subtree` (`GROUP` or `ENT` nodes) to combine.
        :param equiv_subtrees: The cluster of equivalent subtrees in the forest.
        :return: The modified copy of `subtree` and the support of the merged group when that
            support is not lower than the support of any multi-entity group involved;
            otherwise, `None`.
        """
        entities = []
        support_floor = 1
        multi_entity_groups = 0

        for candidate in combined_groups:
            if has_type(candidate, NodeType.ENT):
                # Plain entities are taken as they are
                entities.append(candidate)
                continue

            if not has_type(candidate, NodeType.GROUP):
                continue

            if len(candidate) == 1:
                # A group holding a single entity behaves like that entity
                entities.append(candidate[0])
                continue

            multi_entity_groups += 1
            candidate_support = len(self.get_equiv_of(candidate, equiv_subtrees=equiv_subtrees))
            support_floor = max(support_floor, candidate_support)
            entities.extend(candidate.entities())

        # Nothing to merge, no real group involved, or the same entity label appears twice
        if not entities or multi_entity_groups == 0:
            return None
        if not more_itertools.all_unique(entity.label() for entity in entities):
            return None

        # Work on a copy of the whole tree, then navigate back to the same position
        copied_root = deepcopy(subtree.root())
        copied_subtree = copied_root[subtree.treeposition()]

        merged_group = Tree(NodeLabel(NodeType.GROUP), children=[deepcopy(entity) for entity in entities])

        # Drop the merged children, last position first so earlier positions stay valid
        positions = sorted((member.parent_index() for member in combined_groups), reverse=True)
        for position in positions:
            copied_subtree.pop(position, recursive=False)

        # The merged group takes the place of the left-most merged child
        insert_at = min(positions)
        copied_subtree.insert(insert_at, merged_group)

        new_group = copied_subtree[insert_at]
        equiv = self.get_equiv_of(new_group, equiv_subtrees=equiv_subtrees)
        support = len(equiv)

        if support < support_floor:
            return None

        new_group.set_label(equiv[0].label())
        return copied_subtree, support

    def apply(self, tree: Tree, *, equiv_subtrees: TREE_CLUSTER) -> tuple[Tree, bool]:
        """
        Repeatedly merge `GROUP`/`ENT` children of untyped subtrees that contain a `GROUP`.

        :param tree: The tree to simplify; it is modified in place.
        :param equiv_subtrees: The cluster of equivalent subtrees in the forest.
        :return: The (same) tree and whether at least one merge happened.
        """

        def _mergeable_children(node: Tree) -> tuple[Tree, ...]:
            # Children of `node` that can take part in a merge
            return tuple(child for child in node if has_type(child, {NodeType.GROUP, NodeType.ENT}))

        simplified = False

        candidates = sorted(
            tree.subtrees(
                lambda node: not has_type(node) and any(has_type(child, NodeType.GROUP) for child in node)
            ),
            key=lambda node: node.height(),
        )

        for subtree in candidates:
            mergeable = _mergeable_children(subtree)
            parent = subtree.parent()
            parent_idx = subtree.parent_index()

            # Start from the number of distinct labels and shrink the merge size when nothing fits
            k = len({child.label() for child in mergeable})

            while k > 1:
                attempts = (
                    self._merge_groups_inner(subtree, combo, equiv_subtrees)
                    for combo in combinations(mergeable, k)
                )
                best = max(
                    (attempt for attempt in attempts if attempt is not None),
                    key=lambda attempt: attempt[1],
                    default=None,
                )

                if best is None:
                    k -= 1
                    continue

                best_subtree, best_support = best
                simplified = True
                self._log_to_mlflow(
                    {
                        'num_instance': best_support,
                        'labels': [str(child.label()) for child in best_subtree],
                    }
                )

                # Swap the merged version in place of the current subtree
                if not parent:
                    subtree.clear()
                    subtree.extend(deepcopy(best_subtree[:]))
                else:
                    subtree = parent[parent_idx] = deepcopy(best_subtree)

                # Refresh the candidates for the next round without increasing k
                mergeable = _mergeable_children(subtree)
                k = min(len(mergeable), k)

        return tree, simplified