import contextlib
from collections import Counter
from collections.abc import Callable, Collection, Iterable
from copy import deepcopy
from enum import Enum
from functools import cache
from typing import Any, TypeAlias, overload
import pandas as pd
from nltk.grammar import Nonterminal, Production
from nltk.tree import ParentedTree
# Names exported by `from <module> import *`.
__all__ = [
    'TREE_POS',
    'Forest',
    'NodeLabel',
    'NodeType',
    'Tree',
    'has_type',
    'reduce',
    'reduce_all',
]

# Tuple-of-ints index type accepted by `Tree.__setitem__` / `Tree.__delitem__` (e.g. `t[0, 1]`).
TREE_POS: TypeAlias = tuple[int, ...]
[docs]
class NodeType(str, Enum):
    """
    Kinds of typed node a tree label can carry.

    Members are also `str` instances, so each one compares equal to its value (e.g. ``NodeType.ENT == 'ENT'``).
    The value is the prefix used in textual labels such as ``'ENT::person'``.
    """

    ENT = 'ENT'  # entity node
    GROUP = 'GROUP'  # group node
    REL = 'REL'  # presumably "relation" -- TODO confirm
    COLL = 'COLL'  # presumably "collection" -- TODO confirm
[docs]
class NodeLabel(str):
    """
    String label of a tree node that also carries a node type, a name and optional extra data.

    The string value is ``'<TYPE>::<name>'`` when a name is given and ``'<TYPE>'`` otherwise, where ``<TYPE>`` is
    the value of the :class:`NodeType`. Equality and hashing are inherited from `str`, so `data` does not take
    part in comparisons.

    :param label_type: The type of the node.
    :param label: The name of the node (may be empty).
    :param data: Optional additional data attached to the label.
    """

    type: NodeType
    name: str
    data: dict[str, Any] | None

    __slots__ = ('data', 'name', 'type')

    def __new__(cls, label_type: NodeType, label: str = '', data: dict[str, Any] | None = None) -> 'NodeLabel':
        # `str` is immutable, so the textual value has to be built here rather than in `__init__`.
        # `__new__` and `__init__` receive the same call arguments, hence the parameter names must match:
        # otherwise `NodeLabel(t, name, data=...)` is rejected by one of the two. `data` itself is stored
        # by `__init__`.
        string_value = f'{label_type.value}::{label}' if label else label_type.value
        return super().__new__(cls, string_value)  # type: ignore

    def __init__(self, label_type: NodeType, label: str = '', data: dict[str, Any] | None = None) -> None:
        self.name = label
        self.type = label_type
        self.data = data

    def __reduce__(self) -> tuple[Callable[..., 'NodeLabel'], tuple[Any, ...]]:
        """Support pickling and copying by rebuilding the label from its three components."""
        return NodeLabel, (self.type, self.name, self.data)
[docs]
class Tree(ParentedTree):
    """
    Parented tree whose labels are turned into typed :class:`NodeLabel` values whenever possible.

    A string label of the form ``'<TYPE>'`` or ``'<TYPE>::<name>'`` whose ``<TYPE>`` is a :class:`NodeType` value
    is converted to a :class:`NodeLabel` at construction; any other label is kept unchanged.

    Derived information (height, depth, groups, entities, ...) is memoised with :func:`functools.cache`.
    Every mutating method clears those caches so that stale results are never returned.
    """

    # NOTE(review): this is a plain class attribute named `slots`, not `__slots__`, so it does not restrict
    # instance attributes -- confirm whether `__slots__` was intended.
    slots = ('_parent', '_label')

    _parent: 'Tree | None'
    _label: NodeLabel | str

    def __init__(self, node: NodeLabel | str, children: Iterable['Tree | str'] | None = None) -> None:
        """
        Build a tree and convert a textual label into a :class:`NodeLabel` when it names a known node type.

        :param node: The label of the root node.
        :param children: The children of the root node (subtrees or leaf strings).
        """
        super().__init__(node, children)

        if isinstance(node, NodeLabel):
            return

        # Without '::' in the label, `partition` yields (label, '', ''), so one code path covers both the
        # 'TYPE::name' and the bare 'TYPE' forms.
        type_name, _, name = self._label.partition('::')

        # An unknown type name raises ValueError: the label is then left as a plain string.
        with contextlib.suppress(ValueError):
            self._label = NodeLabel(NodeType(type_name), name)

    def __hash__(self) -> int:
        # Identity-based hash: it makes (mutable) trees usable as keys of the method caches.
        return id(self)

    def __repr__(self) -> str:
        return f'{type(self)}(len={len(self)})'

    def __reduce__(self) -> tuple[Callable[..., 'Tree'], tuple[Any, ...]]:
        """Support pickling and copying by rebuilding the tree from its label and children."""
        return type(self), (self._label, tuple(self))

    @cache
    def height(self) -> int:
        """
        Get the height of the tree.

        >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
        >>> t.height()
        4
        >>> t[0].height()
        3
        >>> t[0, 0].height()
        2
        """
        return super().height()

    @cache
    def depth(self) -> int:
        """
        Get the depth of the tree.

        >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
        >>> t.depth()
        1
        >>> t[0].depth()
        2
        >>> t[0, 0].depth()
        3
        """
        return len(self.treeposition()) + 1

    @cache
    def groups(self) -> set[str]:
        """
        Get the set of group names present within the tree.

        The returned set is the cached object itself: do not mutate it.

        :return: A set of unique group names within the tree.

        >>> t = Tree.fromstring('(S (GROUP::A x) (GROUP::B y) (X (GROUP::C z)))')
        >>> sorted(t.groups())
        ['A', 'B', 'C']
        >>> sorted(t[0].groups())
        ['A']
        """
        result = set()

        if has_type(self, NodeType.GROUP):
            result.add(self.label().name)

        for child in self:
            if isinstance(child, Tree):
                result.update(child.groups())

        return result

    @cache
    def group_instances(self, group_name: str) -> pd.DataFrame:
        """
        Get a DataFrame containing all instances of a specified group within the tree.

        Each row in the DataFrame represents an instance of the group, and each column represents an entity in that
        group, with the value being a concatenated string of that entity's leaves.
        Duplicated rows are dropped.

        :param group_name: The name of the group to search for.
        :return: A pandas DataFrame containing instances of the specified group.

        >>> t = Tree.fromstring('(S (GROUP::A (ENT::person Alice) (ENT::fruit apple)) '
        ...                     '(GROUP::A (ENT::person Bob) (ENT::fruit banana)) '
        ...                     '(GROUP::B (ENT::person Charlie) (ENT::animal dog)))')
        >>> t.group_instances("A")
          person   fruit
        0  Alice   apple
        1    Bob  banana
        >>> t.group_instances("B")
            person animal
        0  Charlie    dog
        >>> t.group_instances("C")
        Empty DataFrame
        Columns: []
        Index: []
        >>> t[0].group_instances("A")
          person  fruit
        0  Alice  apple
        """
        # Instances found below this node.
        dataframes = [child.group_instances(group_name) for child in self if isinstance(child, Tree)]

        # This node is itself an instance: one row built from its direct entity children.
        if has_type(self, NodeType.GROUP) and self.label().name == group_name:
            root_dataframe = pd.DataFrame(
                [
                    {
                        sub_child.label().name: ' '.join(sub_child.leaves())
                        for sub_child in self
                        if has_type(sub_child, NodeType.ENT)
                    }
                ]
            )
            dataframes.append(root_dataframe)

        if not dataframes:
            return pd.DataFrame()

        return pd.concat(dataframes, ignore_index=True).drop_duplicates()

    @cache
    def entities(self) -> tuple['Tree', ...]:
        """
        Get a tuple of subtrees that are entities.

        The tree itself is included when it is an entity.

        >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
        >>> list(t.entities()) == [t[0, 0], t[0, 1], t[1, 0], t[1, 1]]
        True
        >>> del t[0]
        >>> list(t.entities()) == [t[0, 0], t[0, 1]]
        True
        >>> list(t[0, 0].entities()) == [t[0, 0]]
        True
        """
        result = []

        if has_type(self, NodeType.ENT):
            result.append(self)

        for child in self:
            if isinstance(child, Tree):
                result.extend(child.entities())

        return tuple(result)

    @cache
    def entity_labels(self) -> set[str]:
        """
        Get the set of entity labels present in the tree.

        The returned set is the cached object itself: do not mutate it.

        >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
        >>> sorted(t.entity_labels())
        ['animal', 'fruit', 'person']
        >>> sorted(t[0].entity_labels())
        ['fruit', 'person']
        >>> del t[0]
        >>> sorted(t.entity_labels())
        ['animal', 'person']
        """
        return {node.label().name for node in self.entities()}

    @cache
    def entity_label_count(self) -> Counter[str]:
        """
        Return a Counter object that counts the label names of entity subtrees.

        >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
        >>> t.entity_label_count()
        Counter({'person': 2, 'fruit': 1, 'animal': 1})
        """
        return Counter(ent.label().name for ent in self.entities())

    @cache
    def has_duplicate_entity(self) -> bool:
        """
        Check if there are duplicate entity labels.

        >>> from architxt.tree import Tree
        >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
        >>> t.has_duplicate_entity()
        True
        >>> t[0].has_duplicate_entity()
        False
        """
        return any(v > 1 for v in self.entity_label_count().values())

    @cache
    def has_entity_child(self) -> bool:
        """
        Check if there is at least one entity as direct children.

        >>> from architxt.tree import Tree
        >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
        >>> t.has_entity_child()
        False
        >>> t[0].has_entity_child()
        True
        """
        return any(has_type(child, NodeType.ENT) for child in self)

    def has_unlabelled_nodes(self) -> bool:
        """
        Check if at least one direct child is not a typed node.

        Only direct children are inspected; a plain string leaf is not typed and therefore counts as unlabelled.
        """
        return any(not has_type(subtree) for subtree in self)

    def merge(self, tree: 'Tree') -> 'Tree':
        """
        Merge two trees into one.

        The root of both trees becomes one while maintaining the level of each subtree.
        The result is a new tree labelled ``'SENT'`` holding deep copies of the children of both trees,
        so neither input is modified.

        :param tree: The tree whose children are appended after the children of this tree.
        :return: The merged tree.
        """
        return type(self)('SENT', deepcopy([*self, *tree]))

    def __reset_cache(self) -> None:
        """
        Reset cached properties of this tree and of all its ancestors.

        The caches belong to the methods, not to individual instances, so `cache_clear` discards the memoised
        results of every tree, not only those of this one.
        """
        self.height.cache_clear()
        self.depth.cache_clear()
        self.groups.cache_clear()
        self.group_instances.cache_clear()
        self.entities.cache_clear()
        self.entity_labels.cache_clear()
        self.entity_label_count.cache_clear()
        self.has_duplicate_entity.cache_clear()
        self.has_entity_child.cache_clear()

        # Remove cache recursively
        if parent := self.parent():
            parent.__reset_cache()

    def _prune_if_empty(self) -> None:
        """Remove this tree from its parent when it has no children left (cascading upwards)."""
        parent = self._parent
        if len(self) == 0 and parent is not None:
            # Remove by position, located by identity: `parent.remove(self)` compares with `==` and could
            # detach an earlier sibling that merely equals this (now empty) tree.
            parent.pop(self.parent_index())

    @overload
    def __setitem__(self, pos: TREE_POS | int, subtree: 'Tree | str') -> None: ...

    @overload
    def __setitem__(self, pos: slice, subtree: 'list[Tree | str]') -> None: ...

    def __setitem__(self, pos: TREE_POS | int | slice, subtree: 'list[Tree | str] | Tree | str') -> None:
        super().__setitem__(pos, subtree)
        self.__reset_cache()

    def __delitem__(self, pos: TREE_POS | int | slice) -> None:
        super().__delitem__(pos)
        self.__reset_cache()

    def set_label(self, label: NodeLabel | str) -> None:
        """
        Change the label of the node.

        :param label: The new label.
        """
        super().set_label(label)

        # The label determines whether this node is a group or an entity and under which name, so the node's own
        # cached results (groups, entities, ...) are stale too, not only those of its ancestors.
        self.__reset_cache()

    def append(self, child: 'Tree | str') -> None:
        """Add `child` as the last child of the tree."""
        super().append(child)
        self.__reset_cache()

    def extend(self, children: 'Iterable[Tree | str]') -> None:
        """Add every element of `children` at the end of the tree."""
        super().extend(children)
        self.__reset_cache()

    def remove(self, child: 'Tree | str', *, recursive: bool = True) -> None:
        """
        Remove `child` from the tree.

        :param child: The child to remove.
        :param recursive: If an empty tree should be removed from the parent.
        """
        super().remove(child)
        self.__reset_cache()

        if recursive:
            self._prune_if_empty()

    def insert(self, pos: int, child: 'Tree | str') -> None:
        """Insert `child` before position `pos`."""
        super().insert(pos, child)
        self.__reset_cache()

    def pop(self, pos: int = -1, *, recursive: bool = True) -> 'Tree | str':
        """
        Delete an element from the tree at the specified position `pos`.

        If the parent tree becomes empty after the deletion, recursively deletes the parent node.

        :param pos: The position (index) of the element to delete in the tree.
        :param recursive: If an empty tree should be removed from the parent.
        :return: The element at the position. The function modifies the tree in place.

        >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))")
        >>> print(t[(1, 1)].pformat(margin=255))
        (NP (NNS apples))
        >>> subtree = t[1, 1].pop(0)
        >>> print(t.pformat(margin=255))
        (S (NP Alice) (VP (VB like)))
        >>> subtree = t.pop(0)
        >>> print(t.pformat(margin=255))
        (S (VP (VB like)))
        >>> subtree = t[0].pop(0, recursive=False)
        >>> print(t.pformat(margin=255))
        (S (VP ))
        """
        child = super().pop(pos)
        self.__reset_cache()

        if recursive:
            self._prune_if_empty()

        return child
# A forest is any sized, iterable container of trees (list, tuple, set, ...).
Forest: TypeAlias = Collection[Tree]
[docs]
def has_type(t: Any, types: set[NodeType | str] | NodeType | str | None = None) -> bool:
    """
    Check if the given tree object has the specified type(s).

    An object whose label is not a `NodeLabel` (for instance a plain string label or a leaf) has no type,
    so the result is always False for it.

    :param t: The object to check type for (can be a Tree, Production, Nonterminal, or NodeLabel).
    :param types: The types to check for: a single `NodeType` or string, a collection of them
        (set, frozenset, list or tuple), or None to accept any `NodeType`.
    :return: True if the object has the specified type(s), False otherwise.

    >>> tree = Tree.fromstring('(S (ENT Alice) (REL Bob))')
    >>> has_type(tree, NodeType.ENT)  # The root label 'S' is not a typed label
    False
    >>> has_type(tree[0], NodeType.ENT)
    True
    >>> has_type(tree[0], 'ENT')
    True
    >>> has_type(tree[1], NodeType.ENT)
    False
    >>> has_type(tree[1], {NodeType.ENT, NodeType.REL})
    True
    >>> has_type(tree[1], (NodeType.ENT, NodeType.REL))
    True
    """
    assert t is not None

    # Normalize type input to a set of plain type names.
    if types is None:
        candidates = set(NodeType)
    elif isinstance(types, (set, frozenset, list, tuple)):
        # Any standard collection is a group of types, not only `set`.
        candidates = types
    else:
        candidates = (types,)

    accepted = {item.value if isinstance(item, NodeType) else str(item) for item in candidates}

    # Fetch the label from the respective object
    if isinstance(t, NodeLabel):
        label = t
    elif isinstance(t, Tree):
        label = t.label()
    elif isinstance(t, Production):
        label = t.lhs().symbol()
    elif isinstance(t, Nonterminal):
        label = t.symbol()
    else:
        return False

    return isinstance(label, NodeLabel) and label.type.value in accepted
[docs]
def reduce(tree: Tree, pos: int, types: set[str | NodeType] | None = None) -> bool:
    """
    Reduces a subtree within a tree at the specified position `pos`.

    The child at `pos` must be a subtree (leaves are never reduced). Then:

    - when `types` is given, the subtree is reduced unless it has one of these types;
    - when `types` is empty or `None`, the subtree is reduced only if it has at most one child.

    If the subtree can be reduced, copies of its children are lifted into the parent node at `pos`.

    :param tree: The tree in which the reduction will take place.
    :param pos: The index of the subtree to attempt to reduce (negative indices count from the end).
    :param types: A set of `NodeType` or string labels that should be kept, or `None` to reduce based on length.
    :return: `True` if the subtree was reduced, `False` otherwise.
    :raises IndexError: If `pos` is out of range.

    >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))")
    >>> reduce(t[1], 1)
    True
    >>> print(t.pformat(margin=255))
    (S (NP Alice) (VP (VB like) (NNS apples)))
    >>> reduce(t, 0)
    True
    >>> print(t.pformat(margin=255))
    (S Alice (VP (VB like) (NNS apples)))
    """
    assert tree is not None

    # Normalise a valid negative index: `tree[-1:0]` is an empty slice, so the slice assignment below would
    # insert the children without removing the subtree. Out-of-range values are left as is and make
    # `tree[pos]` raise IndexError.
    if -len(tree) <= pos < 0:
        pos += len(tree)

    subtree = tree[pos]

    # Only subtrees can be reduced, leaves are kept as they are
    if not isinstance(subtree, Tree):
        return False

    if types:
        # Subtrees of a type to keep are never reduced
        if has_type(subtree, types):
            return False
    elif len(subtree) > 1:
        # If no types, only reduce if it has one child
        return False

    # Replace the original subtree by its children into the parent at `pos`.
    # Children are copied as they are still attached to the subtree being replaced.
    tree[pos : pos + 1] = [deepcopy(child) for child in subtree]

    return True
[docs]
def reduce_all(tree: Tree, skip_types: set[str | NodeType] | None = None) -> None:
    """
    Reduce every eligible subtree of a tree, repeating until nothing can be reduced anymore.

    Each subtree that has a parent is offered to :func:`reduce`; subtrees whose type is listed in
    `skip_types` are kept.

    :param tree: The tree in which reductions will be attempted.
    :param skip_types: A set of `NodeType` or string labels that should be kept, or `None` to reduce based on length.
    :return: None. The tree is modified in place.

    >>> t = Tree.fromstring("(S (X (Y (Z (NP Alice)))) (VP (VB likes) (NP (NNS apples))))")
    >>> reduce_all(t)
    >>> print(t.pformat(margin=255))
    (S Alice (VP likes apples))
    """
    assert tree is not None

    def _reduce_first() -> bool:
        """Perform the first possible reduction of a full traversal and report whether one happened."""
        for node in tree.subtrees():
            if not isinstance(node, Tree):
                continue

            parent = node.parent()
            if parent is None:
                continue

            if reduce(parent, node.parent_index(), types=skip_types):
                # The tree changed under the traversal: stop here, the caller restarts from the root.
                return True

        return False

    while _reduce_first():
        pass