import contextlib
import re
import uuid
import weakref
from collections import Counter
from collections.abc import (
Callable,
Collection,
Generator,
Hashable,
Iterable,
MutableMapping,
Sequence,
)
from copy import copy, deepcopy
from enum import Enum
from functools import partial, total_ordering
from typing import TYPE_CHECKING, Any, Literal, TextIO, TypeAlias, TypeGuard, overload
from urllib.parse import quote, unquote
import pandas as pd
from cachetools import cachedmethod, keys
from nltk import slice_bounds
from nltk.grammar import Nonterminal, Production
from persistent.list import PersistentList
from persistent.mapping import PersistentMapping
from architxt.utils import ExceptionGroup
__all__ = [
'Forest',
'NodeLabel',
'NodeType',
'Tree',
'TreeOID',
'TreePosition',
'has_type',
]
# A tree position: the sequence of child indices leading from a tree to one of its nodes; `()` is the tree itself.
TreePosition: TypeAlias = tuple[int, ...]
# Identifier carried by every tree; `Tree.__init__` generates one with `uuid.uuid4()` when none is supplied.
TreeOID: TypeAlias = uuid.UUID
# Tokenizer for the bracketed notation. Alternatives, in order: an opening parenthesis followed by its label,
# a lone parenthesis, or a leaf token (any run of characters that are neither whitespace nor parentheses).
TREE_PARSER_RE = re.compile(r"\(\s*[^\s()]+|[()]|[^\s()]+")
[docs]
class NodeType(str, Enum):
    """
    Closed set of node types a typed label can carry.

    Each member's value is the textual prefix found in typed labels (e.g. ``ENT::person``).
    """

    # Entity nodes (see `Tree.entities`).
    ENT = 'ENT'
    # Group nodes (see `Tree.groups`).
    GROUP = 'GROUP'
    # NOTE(review): presumably "relation" -- confirm against the project documentation.
    REL = 'REL'
    # NOTE(review): presumably "collection" -- confirm against the project documentation.
    COLL = 'COLL'
[docs]
class NodeLabel(str):
    """
    A string label that also carries a node type and an optional name.

    The string value is ``TYPE::name`` when a name is given and just ``TYPE`` otherwise.
    """

    type: NodeType
    name: str

    __slots__ = ('name', 'type')

    def __new__(cls, label_type: NodeType, label: str = '') -> 'NodeLabel':
        """Build the underlying string from the type value and the optional name."""
        text = label_type.value
        if label:
            text = f'{text}::{label}'
        return super().__new__(cls, text)

    def __init__(self, label_type: NodeType, label: str = '') -> None:
        """Record the type and the name alongside the string value."""
        self.type = label_type
        self.name = label

    def __reduce__(self) -> tuple[Callable[..., 'NodeLabel'], tuple[Any, ...]]:
        """Rebuild the label from its ``(type, name)`` pair when pickled or copied."""
        return NodeLabel, (self.type, self.name)

    @classmethod
    def fromstring(cls, label: 'NodeLabel | str') -> 'NodeLabel | str':
        """
        Convert a textual label into a :py:class:`NodeLabel` when it names a known node type.

        :param label: An existing label, a ``TYPE::name`` string, a bare ``TYPE`` string, or any other string.
        :return: The label itself if it is already a ``NodeLabel``, a new ``NodeLabel`` if the text before the
            first ``::`` (or the whole text when there is none) is a valid :py:class:`NodeType`,
            and the unchanged string otherwise.
        """
        if isinstance(label, NodeLabel):
            return label

        if '::' in label:
            prefix, _, name = label.partition('::')
        else:
            prefix, name = label, ''

        try:
            return NodeLabel(NodeType(prefix), name)
        except ValueError:
            # Unknown type prefix: keep the plain string label.
            return label
[docs]
@total_ordering
class Tree(PersistentList['_SubTree | str']):
    """
    A labelled tree whose children are either sub-trees or string leaves.

    Every sub-tree keeps a weak reference to its parent, and the results of several accessors are memoised
    per instance in ``_v_cache``.
    """

    # Label of the node: a `NodeLabel` when the text names a known node type, a plain string otherwise.
    _label: NodeLabel | str
    # Free-form metadata attached to the node.
    _metadata: MutableMapping[str, Any]
    # Identifier of the tree; `__hash__` is derived from it.
    _oid: TreeOID
    # Weak reference to the parent tree, or None for a root. Rebuilt by `__setstate__` rather than restored.
    _v_parent: weakref.ReferenceType['Tree'] | None
    # Per-instance memo used by the `cachedmethod`-decorated accessors. Reset by `__setstate__`.
    _v_cache: MutableMapping[Hashable, Any]
    __slots__ = ('_label', '_metadata', '_oid', '_v_cache', '_v_parent')
def __init__(
    self,
    label: NodeLabel | str,
    children: Iterable['Tree | str'] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
    oid: TreeOID | None = None,
) -> None:
    """
    Create a tree node and adopt its children.

    :param label: The node label; text naming a known node type is converted to a ``NodeLabel``.
    :param children: The initial children (sub-trees or string leaves).
    :param metadata: Optional metadata, copied into a new ``PersistentMapping``.
    :param oid: Optional identifier; a random UUID is generated when omitted.
    :raise ExceptionGroup: If a child would create a cycle or is already attached to another tree.
    """
    super().__init__(children)

    self._label = NodeLabel.fromstring(label)
    self._metadata = PersistentMapping(metadata if metadata else {})
    self._oid = oid if oid else uuid.uuid4()
    self._v_parent = None
    self._v_cache = {}

    # Validate before touching any child, so a rejected child keeps its current parent pointer.
    self._check_children(self)

    for node in self:
        if not isinstance(node, Tree):
            continue
        node._v_parent = weakref.ref(self)
def _check_children(self, children: 'Iterable[Tree | str]') -> None:
    """
    Validate that the given children can be attached to this tree.

    String leaves are always accepted. A sub-tree is rejected when it contains this very tree (cycle)
    or when it is already attached to a different parent. All problems are collected and reported together.

    The iterable is consumed, so callers must pass a re-iterable collection.

    :param children: The candidate children.
    :raise ExceptionGroup: Wrapping one ``ValueError`` per offending child.
    """
    errors = []

    for index, child in enumerate(children):
        if not isinstance(child, Tree):
            continue

        # Compare by identity: `self in child.subtrees()` would fall back to the structural `__eq__`
        # and wrongly report a cycle when the child merely contains a subtree *equal* to this one.
        if any(node is self for node in child.subtrees()):
            msg = f'Child at index {index} creates a cyclic reference: a tree cannot contain itself.'
            errors.append(ValueError(msg))

        if is_sub_tree(child) and child.parent is not self:
            msg = f'Child at index {index} is already attached to another parent: {child.parent}.'
            errors.append(ValueError(msg))

    if errors:
        msg = 'Invalid tree children detected'
        raise ExceptionGroup(msg, errors)
def _invalidate_cache(self) -> None:
    """Drop the memoised values of this tree and of every ancestor."""
    self._v_cache.clear()

    # A tree is a list, so its truthiness reflects emptiness; test for None explicitly so that
    # an empty parent is still invalidated.
    parent = self.parent
    if parent is not None:
        parent._invalidate_cache()
def __eq__(self, other: object) -> bool:
    """
    Tell whether two trees are structurally equal.

    Trees are equal when their labels match and their children are equal (recursively).
    Neither the parent reference nor the identifier takes part in the comparison.

    :param other: The object to compare against.
    :return: True if ``other`` is a tree with the same label and children, False otherwise.
    """
    if not isinstance(other, Tree):
        return False
    if self.label != other.label:
        return False
    return super().__eq__(other)
def __hash__(self) -> int:
    """Hash on the tree identifier (note that ``__eq__`` compares structure, not identifiers)."""
    identifier = self._oid
    return identifier.int
def __repr__(self) -> str:
    """Return a short debug representation: the class followed by the number of children."""
    return '{}(len={})'.format(type(self), len(self))
def __str__(self) -> str:
    """Return the textual form produced by ``pformat``."""
    text = self.pformat()
    return text
def __setstate__(self, state: object) -> None:
    """
    Restore the stored state, then rebuild the attributes that are not part of it.

    The parent pointer and the memo start empty, and every sub-tree child is linked back to this tree.
    """
    super().__setstate__(state)

    self._v_parent = None
    self._v_cache = {}

    for node in self:
        if not isinstance(node, Tree):
            continue
        node._v_parent = weakref.ref(self)
def __copy__(self) -> 'Tree':
    """Hook for ``copy.copy()``: delegates to :py:meth:`copy`."""
    duplicate = self.copy()
    return duplicate
def __deepcopy__(self, _memo: dict[int, Any]) -> 'Tree':
    """Hook for ``copy.deepcopy()``: delegates to :py:meth:`copy`; the memo is not used."""
    duplicate = self.copy()
    return duplicate
def __lt__(self, other: Any) -> bool:
    """
    Order trees by number of children, then by label.

    When ``other`` is not a tree, the label text is compared with ``str(other)`` instead.
    """
    if not isinstance(other, Tree):
        return str(self.label) < str(other)

    own_key = (len(self), self.label)
    other_key = (len(other), other.label)
    return own_key < other_key
@property
def oid(self) -> TreeOID:
    """The identifier of this tree."""
    identifier = self._oid
    return identifier
@property
def metadata(self) -> MutableMapping[str, Any]:
    """The metadata mapping attached to this tree (the live object, not a copy)."""
    mapping = self._metadata
    return mapping
@property
def parent(self) -> 'Tree | None':
    """
    The tree this one is attached to, or None when it has no parent.

    >>> t = Tree.fromstring('(S (A xxx) (A xxx))')
    >>> t.parent
    >>> t[0].parent is t
    True
    >>> t[1].parent is t
    True
    """
    reference = self._v_parent
    if not reference:
        return None
    # Dereference the weak reference; this yields None once the parent has been collected.
    return reference()
@property
def parent_index(self) -> int | None:
"""
The index of this tree in its parent.
I.e., ``tree.parent[tree.parent_index] is tree``.
Note that ``tree.parent_index`` is not necessarily equal to ``tree.parent.index(tree)``,
since the ``index()`` method returns the first child that is equal to its argument.
>>> t = Tree.fromstring('(S (A xxx) (A xxx))')
>>> t.parent_index
>>> t[0].parent_index
0
>>> t[1].parent_index
1
"""
if self.parent is None:
return None
for i, child in enumerate(self.parent):
if child is self:
return i
msg = "The tree is not in it's parent"
raise ValueError(msg)
@property
def label(self) -> NodeLabel | str:
    """The label of this tree."""
    return self._label


@label.setter
def label(self, label: NodeLabel | str) -> None:
    """
    Replace the label and invalidate the memoised values of this tree and its ancestors.

    The value goes through ``NodeLabel.fromstring`` exactly as in ``__init__``, so assigning
    text such as ``'ENT::person'`` yields a typed label instead of a plain string.
    """
    self._label = NodeLabel.fromstring(label)
    self._invalidate_cache()
@property
def root(self) -> 'Tree':
    """
    The root of this tree.

    I.e., the unique ancestor of this tree whose parent is None.
    If ``tree.parent`` is None, then ``tree`` is its own root.

    The value is computed on every access: it depends on the ancestors, and the per-instance memo
    is only invalidated upward (from a modified node to its ancestors), so a memoised root would
    be stale once the tree is detached or moved.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> t[0, 0].root is t
    True
    >>> detached = t[0].detach()
    >>> detached.root is detached
    True
    """
    node = self
    while (parent := node.parent) is not None:
        node = parent
    return node
@property
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='height'))
def height(self) -> int:
    """
    The height of the tree: one more than the tallest child, a string leaf counting for 1.

    A tree without children has height 1.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> t.height
    4
    >>> t[0].height
    3
    >>> t[0, 0].height
    2
    """
    tallest = 0
    for child in self:
        child_height = child.height if isinstance(child, Tree) else 1
        tallest = max(tallest, child_height)
    return tallest + 1
@property
def depth(self) -> int:
    """
    The depth of this tree: 1 for a root, plus one per ancestor.

    The value is computed on every access by walking up the parents: it depends on the ancestors,
    and the per-instance memo is only invalidated upward, so a memoised depth would be stale
    once the tree is detached or moved.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> t.depth
    1
    >>> t[0].depth
    2
    >>> t[0, 0].depth
    3
    """
    levels = 1
    node = self
    while (node := node.parent) is not None:
        levels += 1
    return levels
@property
def position(self) -> 'TreePosition':
    """
    The tree position of this tree, relative to the root of the tree.

    The value is computed on every access: it depends on the ancestors and on the siblings
    (inserting or deleting a sibling shifts the index), while the per-instance memo is only
    invalidated upward, so a memoised position would be stale after such changes.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> t.position
    ()
    >>> t[1, 0].position
    (1, 0)
    >>> del t[0]
    >>> t[0, 0].position
    (0, 0)
    """
    indices = []
    node = self
    while (parent := node.parent) is not None:
        indices.append(node.parent_index)
        node = parent

    # Indices were collected bottom-up; a position reads from the root down.
    indices.reverse()
    return tuple(indices)
[docs]
def positions(
    self, *, order: Literal['preorder', 'postorder', 'bothorder', 'leaves'] = 'preorder'
) -> Generator[TreePosition, None, None]:
    """
    Iterate over the positions of the tree.

    The tree's own position ``()`` is emitted before its children for ``preorder`` and ``bothorder``,
    and after them for ``postorder`` and ``bothorder``; with ``leaves`` only leaf positions are produced.

    >>> t = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
    >>> list(t.positions())  # doctest: +ELLIPSIS
    [(), (0,), (0, 0), (0, 0, 0), (0, 1), (0, 1, 0), (1,), (1, 0), (1, 0, 0), ...]
    >>> for pos in t.positions(order='leaves'):
    ...     t[pos] = t[pos][::-1].upper()
    >>> print(t)
    (S (NP (D EHT) (N GOD)) (VP (V DESAHC) (NP (D EHT) (N TAC))))

    :param order: One of: ``preorder``, ``postorder``, ``bothorder``, ``leaves``.
    :yield: All positions in the tree in the given order
    """
    if order in ('preorder', 'bothorder'):
        yield ()

    for index, child in enumerate(self):
        if not isinstance(child, Tree):
            yield (index,)
            continue
        # Prefix every position of the sub-tree with the index of that sub-tree.
        for tail in child.positions(order=order):
            yield (index, *tail)

    if order in ('postorder', 'bothorder'):
        yield ()
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='leaves'))
def leaves(self) -> list[str]:
    """
    Collect the leaves of the tree, from left to right.

    >>> t = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
    >>> list(t.leaves())
    ['the', 'dog', 'chased', 'the', 'cat']

    :return: A list containing this tree's leaves.
        The order reflects the order of the leaves in the tree's hierarchical structure.
    """
    collected = []
    for child in self:
        if isinstance(child, Tree):
            collected += child.leaves()
        else:
            collected.append(child)
    return collected
[docs]
def subtrees(self, filter_fn: Callable[['Tree'], bool] | None = None) -> Generator['Tree', None, None]:
    """
    Iterate over this tree and all its sub-trees, parents before children.

    :param filter_fn: Optional predicate; when given, only the trees it accepts are produced
        (rejected trees are still traversed).

    >>> t = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
    >>> for s in t.subtrees(lambda t: t.height == 2):
    ...     print(s)
    (D the)
    (N dog)
    (V chased)
    (D the)
    (N cat)
    """
    include_self = True if not filter_fn else filter_fn(self)
    if include_self:
        yield self

    for child in self:
        if not isinstance(child, Tree):
            continue
        for descendant in child.subtrees(filter_fn):
            yield descendant
[docs]
def productions(self) -> list[Production]:
    """
    List the grammar productions matching the non-terminal nodes of the tree.

    Each subtree of the form (P: C1 C2 ... Cn) contributes a production P -> C1 C2 ... Cn.
    The production of a node comes first, followed by those of its sub-trees from left to right.

    >>> t = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
    >>> t.productions()  # doctest: +NORMALIZE_WHITESPACE
    [S -> NP VP, NP -> D N, D -> 'the', N -> 'dog', VP -> V NP, V -> 'chased',
     NP -> D N, D -> 'the', N -> 'cat']
    """
    right_hand_side = []
    nested = []
    for child in self:
        if isinstance(child, Tree):
            right_hand_side.append(Nonterminal(child.label))
            nested.extend(child.productions())
        else:
            right_hand_side.append(child)

    own_production = Production(Nonterminal(self.label), right_hand_side)
    return [own_production, *nested]
[docs]
def leaf_position(self, index: int) -> TreePosition:
    """
    Return the tree position of the `index`-th leaf in this tree.

    The tree position is a tuple of indices that corresponds to the
    location of the `index`-th leaf in the tree structure.
    If `tp = self.leaf_position(i)`, then `self[tp]` should be
    the same as `self.leaves()[i]`.

    :param index: The index of the leaf for which to find the tree position.
    :returns: A tuple representing the tree position of the `index`-th leaf.
    :raise IndexError: If `index` is negative or if there are fewer than `index + 1` leaves in the tree.

    >>> t = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
    >>> t.leaf_position(0)
    (0, 0, 0)
    >>> t.leaf_position(4)
    (1, 1, 1, 0)
    """
    if index < 0:
        msg = 'index must be non-negative'
        raise IndexError(msg)

    # Iterative depth-first traversal; each entry pairs a node with its position.
    stack: list[tuple[Tree | str, tuple[int, ...]]] = [(self, ())]
    count = 0

    while stack:
        node, pos = stack.pop()

        if isinstance(node, Tree):
            # Add children in reverse to simulate left-to-right traversal
            stack.extend((node[i], (*pos, i)) for i in reversed(range(len(node))))
        else:
            if count == index:
                return pos
            count += 1

    # The bound is the number of leaves, not the number of direct children (`len(self)`).
    msg = 'index must be less than the number of leaves in the tree'
    raise IndexError(msg)
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='groups'))
def groups(self) -> set[str]:
    """
    Collect the names of all the groups found in the tree, this node included.

    :return: A set of unique group names within the tree.

    >>> t = Tree.fromstring('(S (GROUP::A x) (GROUP::B y) (X (GROUP::C z)))')
    >>> sorted(t.groups())
    ['A', 'B', 'C']
    >>> sorted(t[0].groups())
    ['A']
    """
    names = set()

    own_label = self.label
    if isinstance(own_label, NodeLabel) and has_type(self, NodeType.GROUP):
        names.add(own_label.name)

    for child in self:
        if isinstance(child, Tree):
            names |= child.groups()

    return names
[docs]
def group_instances(self, group_name: str) -> pd.DataFrame:
    """
    Build a DataFrame listing every instance of the given group found in the tree.

    Each row is one instance of the group and each column one entity of that group; the cell holds
    the leaves of the entity joined with spaces. Duplicate rows are dropped.

    :param group_name: The name of the group to search for.
    :return: A pandas DataFrame containing instances of the specified group.

    >>> t = Tree.fromstring('(S (GROUP::A (ENT::person Alice) (ENT::fruit apple)) '
    ...                     '(GROUP::A (ENT::person Bob) (ENT::fruit banana)) '
    ...                     '(GROUP::B (ENT::person Charlie) (ENT::animal dog)))')
    >>> t.group_instances("A")
      person   fruit
    0  Alice   apple
    1    Bob  banana
    >>> t.group_instances("B")
        person animal
    0  Charlie    dog
    >>> t.group_instances("C")
    Empty DataFrame
    Columns: []
    Index: []
    >>> t[0].group_instances("A")
      person  fruit
    0  Alice  apple
    """
    # Instances found below this node come first.
    frames = []
    for child in self:
        if isinstance(child, Tree):
            frames.append(child.group_instances(group_name))

    own_label = self.label
    if isinstance(own_label, NodeLabel) and has_type(self, NodeType.GROUP) and own_label.name == group_name:
        # This node is itself an instance: one row built from its direct entity children.
        row = {}
        for member in self:
            if has_type(member, NodeType.ENT):
                row[member.label.name] = ' '.join(member.leaves())
        frames.append(pd.DataFrame([row]))

    if not frames:
        return pd.DataFrame()

    combined = pd.concat(frames, ignore_index=True)
    return combined.drop_duplicates()
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='entities'))
def entities(self) -> tuple['_TypedTree', ...]:
    """
    Gather the entity subtrees of this tree (the tree itself included when it is an entity).

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> list(t.entities()) == [t[0, 0], t[0, 1], t[1, 0], t[1, 1]]
    True
    >>> del t[0]
    >>> list(t.entities()) == [t[0, 0], t[0, 1]]
    True
    >>> list(t[0, 0].entities()) == [t[0, 0]]
    True
    """

    def _is_entity(node: 'Tree') -> bool:
        return has_type(node, NodeType.ENT)

    return tuple(self.subtrees(_is_entity))
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='entity_labels'))
def entity_labels(self) -> set[str]:
    """
    Collect the distinct names of the entities found in the tree.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> sorted(t.entity_labels())
    ['animal', 'fruit', 'person']
    >>> sorted(t[0].entity_labels())
    ['fruit', 'person']
    >>> del t[0]
    >>> sorted(t.entity_labels())
    ['animal', 'person']
    """
    names = set()
    for entity in self.entities():
        names.add(entity.label.name)
    return names
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='entity_label_count'))
def entity_label_count(self) -> Counter[str]:
    """
    Count how many entity subtrees carry each entity name.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> t.entity_label_count()
    Counter({'person': 2, 'fruit': 1, 'animal': 1})
    """
    occurrences = Counter()
    for entity in self.entities():
        occurrences[entity.label.name] += 1
    return occurrences
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='has_duplicate_entity'))
def has_duplicate_entity(self) -> bool:
    """
    Tell whether at least one entity name occurs more than once in the tree.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> t.has_duplicate_entity()
    True
    >>> t[0].has_duplicate_entity()
    False
    """
    occurrences = self.entity_label_count()
    return max(occurrences.values(), default=0) > 1
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='has_entity_child'))
def has_entity_child(self) -> bool:
    """
    Tell whether at least one direct child is an entity.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> t.has_entity_child()
    False
    >>> t[0].has_entity_child()
    True
    """
    for child in self:
        if has_type(child, NodeType.ENT):
            return True
    return False
[docs]
@cachedmethod(lambda self: self._v_cache, key=partial(keys.methodkey, method='has_unlabelled_nodes'))
def has_unlabelled_nodes(self) -> bool:
    """
    Tell whether at least one direct child does not carry a typed label.

    String leaves never carry a typed label, so a direct leaf child also makes this True.

    :return: A boolean indicating if the node contains any non-typed label.

    >>> t = Tree.fromstring('(S (X xxx) (Y yyy) (Z zzz))')
    >>> t.has_unlabelled_nodes()
    True
    >>> t = Tree.fromstring('(S (ENT::X xxx) (REL::Y yyy) (COLL::Z zzz))')
    >>> t.has_unlabelled_nodes()
    False
    """
    every_child_typed = all(has_type(child) for child in self)
    return not every_child_typed
[docs]
def merge(self, tree: 'Tree') -> 'Tree':
    """
    Combine this tree and another one under a new ``ROOT`` node.

    A tree labelled ``ROOT`` contributes its children; any other tree is added as a single child,
    so every subtree keeps its level. The children are deep-copied: neither input is modified.

    :param tree: The tree to merge with this one.
    :return: A new tree labelled ``ROOT``.
    """

    def _top_level(node: 'Tree') -> list:
        return list(node) if node.label == 'ROOT' else [node]

    merged_children: list[Tree] = [*_top_level(self), *_top_level(tree)]
    return type(self)('ROOT', deepcopy(merged_children))
[docs]
def reduce(self, skip_types: set[str | NodeType] | None = None) -> bool:
    """
    Attempt to reduce this subtree, lifting its child into the parent node in place of the subtree.

    Reduction happens if:
    - The tree has a parent, AND
    - The tree has exactly one child, AND
    - Its type is not in `skip_types` (if `skip_types` is given)

    :param skip_types: A set of node types that should be kept, or `None` to reduce all single-child nodes.
    :return: `True` if the subtree was reduced, `False` otherwise.

    >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))")
    >>> t[1, 1].reduce()
    True
    >>> print(t)
    (S (NP Alice) (VP (VB like) (NNS apples)))
    >>> t[0].reduce()
    True
    >>> print(t)
    (S Alice (VP (VB like) (NNS apples)))
    """
    if not is_sub_tree(self) or len(self) != 1 or (skip_types and has_type(self, skip_types)):
        return False
    # Remember where this subtree sits in its parent before anything is moved.
    parent_index = self.parent_index
    # A sub-tree child is detached from this node first, so that it has no parent
    # when the slice assignment attaches it to the parent of this node.
    self.parent[parent_index : parent_index + 1] = [
        child.detach() if isinstance(child, Tree) else child for child in self
    ]
    return True
[docs]
def reduce_all(self, skip_types: set[str | NodeType] | None = None) -> None:
    """
    Reduce every reducible subtree of the tree, repeating until nothing can be reduced anymore.

    Subtrees whose type is listed in `skip_types` are left untouched.

    :param skip_types: A set of node types that should be kept, or `None` to reduce all single-child nodes.

    >>> t = Tree.fromstring("(S (X (Y (Z (NP Alice)))) (VP (VB likes) (NP (NNS apples))))")
    >>> t.reduce_all()
    >>> print(t)
    (S Alice (VP likes apples))
    """
    # `any` stops at the first successful reduction; the traversal is then restarted from scratch
    # because the structure has changed. The loop ends after a full pass without any reduction.
    while any(subtree.reduce(skip_types=skip_types) for subtree in self.subtrees()):
        continue
@overload
def __getitem__(self, pos: TreePosition | int) -> '_SubTree | str': ...
@overload
def __getitem__(self, pos: slice) -> 'list[_SubTree | str]': ...
def __getitem__(self, pos: TreePosition | int | slice) -> '_SubTree | str | list[_SubTree | str]':
    """
    Look up a child by index, a list of children by slice, or any node by tree position.

    >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))')
    >>> print(t[0])
    (X (ENT::person Alice) (ENT::fruit apple))
    >>> print(t[0, 1])
    (ENT::fruit apple)
    >>> print(t[1:][0])
    (Y (ENT::person Bob) (ENT::animal rabbit))

    :raise TypeError: If ``pos`` is not an integer, a slice, or a tuple.
    """
    if isinstance(pos, tuple):
        if not pos:
            # The empty position designates the tree itself.
            return self
        first_step = self[pos[0]]
        remaining = pos[1:]
        return first_step[remaining] if remaining else first_step

    if isinstance(pos, int | slice):
        # Read from `data` directly: going through `UserList` would wrap a slice into
        # an instance of this class instead of returning a plain list.
        return self.data[pos]

    msg = f'{type(self).__name__} indices must be integers, not {type(pos).__name__}'
    raise TypeError(msg)
@overload
def __setitem__(self, pos: TreePosition | int, subtree: 'Tree | str') -> None: ...
@overload
def __setitem__(self, pos: slice, subtree: 'Iterable[Tree | str]') -> None: ...
def __setitem__(self, pos: TreePosition | int | slice, subtree: 'Tree | str | Iterable[Tree | str]') -> None: # noqa: C901
    """
    Replace a child (integer index), a run of children (slice), or a node at a tree position (tuple).

    Parent pointers of the replaced and of the new sub-trees are updated, and the memoised values
    of this tree and its ancestors are invalidated.

    :param pos: An integer index, a slice, or a non-empty tree position.
    :param subtree: The new child, or an iterable of children when ``pos`` is a slice.
    :raise IndexError: If ``pos`` is the empty position or an out-of-range negative index.
    :raise TypeError: If ``pos`` is not an integer, a slice, or a tuple.
    :raise ExceptionGroup: For a slice, if a new child would create a cycle or belongs to another tree.
    """
    # ptree[start:stop] = subtree
    if isinstance(pos, slice):
        start, stop, step = slice_bounds(self, pos, allow_step=True)
        # make a copy of subtree, in case it's an iterator
        if not isinstance(subtree, list | tuple):
            subtree = list(subtree)
        # Check for any error conditions, so we can avoid ending
        # up in an inconsistent state if an error does occur.
        self._check_children(subtree)
        # clear the parent pointers of all the children we're removing
        for i in range(start, stop, step):
            if isinstance((child := self[i]), Tree):
                child._v_parent = None
        # set the parent pointers of the new children. We do this
        # after clearing *all* the old pointers, in case we're e.g.
        # reversing the elements in a tree.
        for i, child in enumerate(subtree):
            if isinstance(child, Tree):
                child._v_parent = weakref.ref(self)
        # finally, update the content of the child list itself.
        super().__setitem__(pos, subtree)
    # ptree[i] = subtree
    elif isinstance(pos, int):
        if pos < 0:
            pos += len(self)
        if pos < 0:
            msg = 'pos out of range'
            raise IndexError(msg)
        # if the subtree is not changing, do nothing.
        if subtree is self[pos]:
            return
        # Set the new child's parent pointer.
        # NOTE(review): unlike the slice branch, the new child is not validated with `_check_children` here.
        if isinstance(subtree, Tree):
            subtree._v_parent = weakref.ref(self)
        # Remove the old child's parent pointer
        if isinstance((child := self[pos]), Tree):
            child._v_parent = None
        # Update our child list.
        super().__setitem__(pos, subtree)
    elif isinstance(pos, tuple):
        # ptree[()] = subtree
        if len(pos) == 0:
            msg = 'The tree position () may not be assigned to.'
            raise IndexError(msg)
        # ptree[(i,)] = subtree
        if len(pos) == 1:
            self[pos[0]] = subtree
        # ptree[i1, i2, i3] = subtree
        else:
            self[pos[0]][pos[1:]] = subtree
    else:
        msg = f'{type(self).__name__} indices must be integers, not {type(pos).__name__}'
        raise TypeError(msg)
    self._invalidate_cache()
def __delitem__(self, pos: TreePosition | int | slice) -> None: # noqa: C901
    """
    Delete a child (integer index), a run of children (slice), or a node at a tree position (tuple).

    Parent pointers of the removed sub-trees are cleared, and the memoised values of this tree
    and its ancestors are invalidated.

    :param pos: An integer index, a slice, or a non-empty tree position.
    :raise IndexError: If ``pos`` is the empty position or an out-of-range negative index.
    :raise TypeError: If ``pos`` is not an integer, a slice, or a tuple.
    """
    # del ptree[start:stop]
    if isinstance(pos, slice):
        start, stop, step = slice_bounds(self, pos, allow_step=True)
        # Clear the parent pointers of all the removed children.
        for i in range(start, stop, step):
            if isinstance((child := self[i]), Tree):
                child._v_parent = None
        # Delete the children from our child list.
        super().__delitem__(pos)
    # del ptree[i]
    elif isinstance(pos, int):
        if pos < 0:
            pos += len(self)
        if pos < 0:
            msg = 'pos out of range'
            raise IndexError(msg)
        # Clear the child's parent pointer.
        if isinstance((child := self[pos]), Tree):
            child._v_parent = None
        # Remove the child from our child list.
        super().__delitem__(pos)
    elif isinstance(pos, tuple):
        # del ptree[()]
        if len(pos) == 0:
            msg = 'The tree position () may not be deleted.'
            raise IndexError(msg)
        # del ptree[(i,)]
        if len(pos) == 1:
            del self[pos[0]]
        # del ptree[i1, i2, i3]
        else:
            del self[pos[0]][pos[1:]]
    else:
        msg = f'{type(self).__name__} indices must be integers, not {type(pos).__name__}'
        raise TypeError(msg)
    self._invalidate_cache()
[docs]
def clear(self) -> None:
    """
    Remove all the children of this tree.

    The parent pointer of every removed sub-tree is cleared, as the other removal operations do,
    so that the former children no longer claim to belong to this tree.
    """
    for child in self.data:
        if isinstance(child, Tree):
            child._v_parent = None

    super().clear()
    self._invalidate_cache()
[docs]
def append(self, child: 'Tree | str') -> None:
    """
    Add a child at the end of this tree.

    :param child: The sub-tree or string leaf to add.
    :raise ExceptionGroup: If the sub-tree would create a cycle or is already attached to another tree.
    """
    is_tree = isinstance(child, Tree)

    if is_tree:
        # Validate first, so a rejected child keeps its current parent pointer.
        self._check_children((child,))
        child._v_parent = weakref.ref(self)

    super().append(child)
    self._invalidate_cache()
[docs]
def extend(self, children: 'Iterable[Tree | str]') -> None:
    """
    Add several children at the end of this tree.

    :param children: The sub-trees or string leaves to add; any iterable is accepted, including one-shot iterators.
    :raise ExceptionGroup: If a sub-tree would create a cycle or is already attached to another tree.
    """
    # The children are traversed several times (validation, parent pointers, insertion):
    # materialise one-shot iterables first, otherwise validation would exhaust them
    # and nothing would be added.
    if not isinstance(children, list | tuple):
        children = list(children)

    self._check_children(children)

    for child in children:
        if isinstance(child, Tree):
            child._v_parent = weakref.ref(self)

    super().extend(children)
    self._invalidate_cache()
[docs]
def remove(self, child: 'Tree | str', *, recursive: bool = True) -> None:
    """
    Remove a child from this tree.

    The child is located by identity first, so that the requested sub-tree is removed even when
    an equal sibling precedes it. When no child is that very object (e.g. for string leaves given
    by value), the first child equal to ``child`` is removed, as ``list.remove`` does.

    :param child: The child to remove.
    :param recursive: If this tree should in turn be removed from its parent when it becomes empty.
    :raise ValueError: If the child is not in this tree.
    """
    index = next((i for i, item in enumerate(self.data) if item is child), None)
    if index is None:
        index = self.data.index(child)

    # Clear the parent pointer of the element actually removed, which may differ from the argument
    # when the match was found by equality.
    removed = self.data[index]
    super().__delitem__(index)
    if isinstance(removed, Tree):
        removed._v_parent = None

    if recursive and len(self) == 0 and (parent := self.parent) is not None:
        parent.remove(self)

    self._invalidate_cache()
[docs]
def insert(self, pos: int, child: 'Tree | str') -> None:
    """
    Insert a child before the given index.

    :param pos: The index before which the child is inserted.
    :param child: The sub-tree or string leaf to insert.
    :raise ExceptionGroup: If the sub-tree would create a cycle or is already attached to another tree.
    """
    is_tree = isinstance(child, Tree)

    if is_tree:
        # Validate first, so a rejected child keeps its current parent pointer.
        self._check_children((child,))
        child._v_parent = weakref.ref(self)

    super().insert(pos, child)
    self._invalidate_cache()
[docs]
def pop(self, pos: int = -1, *, recursive: bool = True) -> 'Tree | str':
    """
    Delete an element from the tree at the specified position `pos`.

    If the parent tree becomes empty after the deletion, parent nodes are recursively deleted.

    :param pos: The position (index) of the element to delete in the tree.
    :param recursive: If an empty tree should be removed from the parent.
    :return: The element at the position. The function modifies the tree in place.

    >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))")
    >>> print(t[(1, 1)])
    (NP (NNS apples))
    >>> subtree = t[1, 1].pop(0)
    >>> print(t)
    (S (NP Alice) (VP (VB like)))
    >>> subtree = t.pop(0)
    >>> print(t)
    (S (VP (VB like)))
    >>> subtree = t[0].pop(0, recursive=False)
    >>> print(t)
    (S (VP ))
    """
    child = super().pop(pos)
    if isinstance(child, Tree):
        child._v_parent = None

    if recursive and len(self) == 0 and (parent := self.parent) is not None:
        # Remove this (now empty) tree by its identity-based index: an equality-based removal
        # could drop an equal sibling instead (e.g. another empty tree with the same label).
        parent.pop(self.parent_index)

    self._invalidate_cache()
    return child
[docs]
def detach(self) -> 'Tree':
    """
    Detach a subtree from its parent.

    The subtree is removed by its identity-based index, so the right node is detached even when
    the parent holds an equal sibling before it. An emptied parent is left in place.
    A tree without parent is returned unchanged.

    :return: The detached tree.

    >>> t = Tree.fromstring('(S (A xxx) (B yyy))')
    >>> detached = t[0].detach()
    >>> print(detached.root)
    (A xxx)
    >>> print(t)
    (S (B yyy))
    >>> t = Tree.fromstring('(S (A xxx) (A xxx))')
    >>> second = t[1]
    >>> second.detach() is second
    True
    >>> second.parent is None and t[0].parent is t
    True
    """
    parent = self.parent
    if parent is not None:
        # Index-based deletion clears our parent pointer and invalidates the parent's memo.
        del parent[self.parent_index]
    return self
[docs]
def copy(self) -> 'Tree':
    """
    Duplicate the whole tree.

    Sub-trees are copied recursively and string leaves are reused. The copy keeps the identifier
    of the original and receives its metadata.

    :return: A new copy of the tree.
    """
    duplicated_children = []
    for child in self:
        duplicated_children.append(child.copy() if isinstance(child, Tree) else child)

    return Tree(
        label=copy(self.label),
        children=duplicated_children,
        metadata=self.metadata,
        oid=self.oid,
    )
[docs]
@classmethod
def fromstring(cls, text: str) -> 'Tree':
    """
    Parse a tree written in a LISP-style bracketed notation.

    Trees are represented as nested brackettings, such as:

        (S (NP (NNP John)) (VP (V runs)))

    Labels and leaves are percent-decoded with ``urllib.parse.unquote``.

    :param text: The string to read
    :return: A tree corresponding to the string representation ``text``.
    :raise ValueError: If the text is not exactly one well-formed tree.

    >>> t = Tree.fromstring('(S (X xxx) (Y yyy))')
    >>> print(t)
    (S (X xxx) (Y yyy))
    """
    # Stack of (label, children) frames; the bottom frame collects the finished top-level tree.
    frames: list[tuple[str | None, list[Any]]] = [(None, [])]
    node_label: str | None

    for match in TREE_PARSER_RE.finditer(text):
        token = match.group()
        at_top_level = len(frames) == 1

        if token.startswith('('):
            # Opening of a tree/subtree; a second top-level tree is not allowed.
            if at_top_level and frames[0][1]:
                cls._parse_error(text, 'end-of-string', match)
            node_label = unquote(token[1:].lstrip())
            frames.append((node_label, []))

        elif token == ')':
            # Closing of a tree/subtree; there must be an open one.
            if at_top_level:
                if not frames[0][1]:
                    cls._parse_error(text, '(', match)
                else:
                    cls._parse_error(text, 'end-of-string', match)
            node_label, node_children = frames.pop()
            frames[-1][1].append(cls(node_label, node_children))

        else:
            # Leaf token; leaves cannot appear outside a tree.
            if at_top_level:
                cls._parse_error(text, '(', match)
            frames[-1][1].append(unquote(token))

    # Exactly one complete tree must have been read.
    if len(frames) > 1:
        cls._parse_error(text, ')')
    elif not frames[0][1]:
        cls._parse_error(text, '(')
    else:
        assert frames[0][0] is None
        assert len(frames[0][1]) == 1

    return frames[0][1][0]
@classmethod
def _parse_error(cls, text: str, expecting: str, match: re.Match[str] | None = None) -> None:
    """
    Raise a readable error describing why parsing a tree string failed.

    The message names the expected and the actual token, and shows an excerpt of the text
    with a caret under the offending location.

    :param text: The string we're parsing.
    :param expecting: What we expected to see instead.
    :param match: Regexp match of the problem token or `None` if end-of-string.
    :raise ValueError: Always.
    """
    # Locate the problem: the matched token, or the end of the text.
    pos = match.start() if match else len(text)
    token = match.group() if match else 'end-of-string'

    msg = f"{cls.__name__}.read(): expected {expecting!r} but got {token!r}\n{' ' * 12}at index {pos}."

    # Build a one-line excerpt of at most ten characters on each side of the problem.
    excerpt = text.replace("\n", " ").replace("\t", " ")
    caret_offset = pos
    if len(excerpt) > pos + 10:
        excerpt = excerpt[: pos + 10] + "..."
    if pos > 10:
        excerpt = "..." + excerpt[pos - 10 :]
        # Three dots plus the ten characters kept before the problem.
        caret_offset = 13

    msg += '\n{}"{}"\n{}^'.format(" " * 16, excerpt, " " * (17 + caret_offset))
    raise ValueError(msg)
[docs]
def pretty_print(
    self, highlight: Sequence['Tree | int'] = (), stream: TextIO | None = None, maxwidth: int = 32
) -> None:
    """
    Print a drawing of this tree using Unicode line characters.

    It relies on :py:class:`nltk.tree.prettyprinter.TreePrettyPrinter`: the tree is converted
    to an NLTK tree through its string form before being rendered.

    :param stream: The file to print to.
    :param highlight: Optionally, a sequence of Tree objects in `tree` which should be highlighted.
        Has the effect of only applying colors to nodes in this sequence.
    :param maxwidth: maximum number of characters before a label starts to wrap.
        Leaf with more than `maxwidth` characters will be truncated.
    """
    from nltk.tree import Tree as NLTKTree
    from nltk.tree.prettyprinter import TreePrettyPrinter

    nltk_tree = NLTKTree.fromstring(str(self))

    # Shorten over-long leaves, ending them with an ellipsis.
    for leaf_position in nltk_tree.treepositions(order='leaves'):
        leaf = nltk_tree[leaf_position]
        if len(leaf) > maxwidth:
            nltk_tree[leaf_position] = leaf[: maxwidth - 3] + '...'

    printer = TreePrettyPrinter(nltk_tree, highlight=highlight)
    rendered = printer.text(unicodelines=True, maxwidth=maxwidth)
    print(rendered, file=stream)
[docs]
def to_svg(self, highlight: Sequence['Tree | int'] = ()) -> str:
    """
    Render this tree as SVG.

    It relies on :py:class:`nltk.tree.prettyprinter.TreePrettyPrinter`: the tree is converted
    to an NLTK tree through its string form before being rendered.

    :param highlight: Optionally, a sequence of Tree objects in `tree` which should be highlighted.
        Has the effect of only applying colors to nodes in this sequence.
    :return: The SVG markup.
    """
    from nltk.tree import Tree as NLTKTree
    from nltk.tree.prettyprinter import TreePrettyPrinter

    nltk_tree = NLTKTree.fromstring(str(self))
    printer = TreePrettyPrinter(nltk_tree, highlight=highlight)
    return printer.svg()
# Typing-only refinements of `Tree`, used as `TypeGuard` targets by `is_sub_tree` and `has_type`.
# They are never defined at runtime.
if TYPE_CHECKING:
    # A tree known to have a parent.
    class _SubTree(Tree):
        parent: Tree
        parent_index: int
    # A tree known to carry a typed label.
    class _TypedTree(Tree):
        label: NodeLabel
    # A tree known to have both a parent and a typed label.
    class _TypedSubTree(_SubTree, _TypedTree): ...
[docs]
def is_sub_tree(tree: Tree) -> TypeGuard['_SubTree']:
    """
    Tell whether the given tree is attached to a parent.

    The function doubles as a type guard: when it returns True, static type checkers
    such as mypy treat `tree` as a subtree whose parent is known to exist.

    :param tree: The tree instance to check.
    :return: True if `tree` is a subtree (i.e., has a parent), False otherwise.

    >>> t = Tree.fromstring('(S (X xxx) (Y yyy))')
    >>> is_sub_tree(t)
    False
    >>> is_sub_tree(t[0])
    True
    """
    owner = tree.parent
    return owner is not None
@overload
def has_type(
    t: '_SubTree', types: set[NodeType | str] | NodeType | str | None = None
) -> TypeGuard['_TypedSubTree']: ...
@overload
def has_type(t: Tree, types: set[NodeType | str] | NodeType | str | None = None) -> TypeGuard['_TypedTree']: ...
@overload
def has_type(t: Any, types: set[NodeType | str] | NodeType | str | None = None) -> bool: ...
def has_type(t: Any, types: set[NodeType | str] | NodeType | str | None = None) -> bool:
    """
    Tell whether an object carries a typed label of one of the given types.

    The label is taken from the object itself (``NodeLabel``), from its label (``Tree``),
    from its left-hand side (``Production``), or from its symbol (``Nonterminal``).
    Any other object, and any object whose label is a plain string, has no type.

    :param t: The object to check type for (can be a Tree, Production, Nonterminal, or NodeLabel).
    :param types: The accepted types (a set, a single type, or None to accept any node type).
    :return: True if the object has the specified type(s), False otherwise.

    >>> tree = Tree.fromstring('(S (ENT Alice) (REL Bob))')
    >>> has_type(tree, NodeType.ENT)
    False
    >>> has_type(tree[0], NodeType.ENT)
    True
    >>> has_type(tree[0], 'ENT')
    True
    >>> has_type(tree[1], NodeType.ENT)
    False
    >>> has_type(tree[1], {NodeType.ENT, NodeType.REL})
    True
    """
    # Bring the accepted types to a set of plain type names.
    if types is None:
        accepted = set(NodeType)
    elif isinstance(types, set):
        accepted = types
    else:
        accepted = {types}
    accepted_names = {item.value if isinstance(item, NodeType) else str(item) for item in accepted}

    # Find the label of the inspected object.
    label: NodeLabel | str
    if isinstance(t, NodeLabel):
        label = t
    elif isinstance(t, Tree):
        label = t.label
    elif isinstance(t, Production):
        label = t.lhs().symbol()
    elif isinstance(t, Nonterminal):
        label = t.symbol()
    else:
        return False

    if not isinstance(label, NodeLabel):
        return False
    return label.type.value in accepted_names
# A forest is any sized, iterable container of trees.
Forest: TypeAlias = Collection[Tree]