Source code for architxt.inspector

from collections import Counter
from collections.abc import Generator, Iterable

from architxt.tree import Tree

__all__ = ['ForestInspector']


class ForestInspector:
    """
    Collect summary statistics about a forest while streaming its trees.

    An instance is a callable pass-through: calling it with an iterable of
    trees returns a generator that yields every tree unchanged and updates the
    counters below as each tree is consumed. Because the work happens lazily,
    the statistics are only complete once the returned generator has been
    exhausted. Counters accumulate across successive calls on the same
    instance.

    Attributes set by ``__init__`` and updated by ``__call__``:

    - ``total_trees``: number of trees seen.
    - ``total_entities``: number of items returned by ``tree.entities()``.
    - ``total_nodes``: number of nodes returned by ``tree.subtrees()``.
    - ``sum_children`` / ``max_children``: sum and maximum of ``len(node)``.
    - ``sum_height`` / ``max_height``: sum and maximum of ``tree.height``.
    - ``sum_size`` / ``max_size``: sum and maximum of ``len(tree.leaves())``.
    - ``entity_count``: occurrences of each entity label.
    - ``largest_tree``: the first tree that reached ``max_height``.
    """

    def __init__(self) -> None:
        self.total_trees = 0
        self.total_entities = 0
        self.total_nodes = 0

        self.sum_children = 0
        self.max_children = 0

        self.sum_height = 0
        self.max_height = 0

        self.sum_size = 0
        self.max_size = 0

        self.entity_count = Counter[str]()
        self.largest_tree: Tree | None = None

    @property
    def avg_height(self) -> float:
        """Get the average height of all trees (``0.0`` if no tree was seen)."""
        return self.sum_height / self.total_trees if self.total_trees else 0.0

    @property
    def avg_size(self) -> float:
        """Get the average size (number of leaves) of all trees (``0.0`` if no tree was seen)."""
        return self.sum_size / self.total_trees if self.total_trees else 0.0

    @property
    def avg_branching(self) -> float:
        """Get the average branching factor (children per node) across all trees (``0.0`` if no node was seen)."""
        return self.sum_children / self.total_nodes if self.total_nodes else 0.0

    def __call__(self, forest: Iterable[Tree]) -> Generator[Tree, None, None]:
        """
        Yield each tree of ``forest`` unchanged while recording its statistics.

        :param forest: The trees to inspect.
        :yield: The same trees, in the same order.
        """
        for tree in forest:
            self.total_trees += 1

            # Height statistics. `largest_tree` is selected by height, not by
            # leaf count; the strict comparison keeps the first tree on ties.
            height = tree.height
            self.sum_height += height
            if height > self.max_height:
                self.max_height = height
                self.largest_tree = tree

            # Size statistics (number of leaves).
            size = len(tree.leaves())
            self.sum_size += size
            self.max_size = max(self.max_size, size)

            # Entity statistics, keyed by entity label.
            labels = [entity.label for entity in tree.entities()]
            self.total_entities += len(labels)
            self.entity_count.update(labels)

            # Branching factor: every node reported by `subtrees()` counts,
            # and its number of children is `len(node)`.
            for node in tree.subtrees():
                nb_children = len(node)
                self.total_nodes += 1
                self.sum_children += nb_children
                self.max_children = max(self.max_children, nb_children)

            yield tree