Source code for architxt.nlp.parser

import abc
import uuid
import warnings
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator
from copy import deepcopy
from types import TracebackType

from aiostream import pipe, stream
from nltk.tokenize.util import align_tokens

from architxt.nlp.entity_resolver import EntityResolver
from architxt.nlp.model import AnnotatedSentence, Entity, Relation, TreeEntity, TreeRel
from architxt.tree import NodeLabel, NodeType, Tree, has_type, reduce_all

__all__ = ['Parser']


[docs] class Parser(abc.ABC): def __enter__(self) -> 'Parser': return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None ) -> None: pass
[docs] async def parse_batch( self, sentences: Iterable[AnnotatedSentence] | AsyncIterable[AnnotatedSentence], language: str, resolver: EntityResolver | None = None, batch_size: int = 128, ) -> AsyncIterator[tuple[AnnotatedSentence, Tree]]: """ Parse a batch of annotated sentences into enriched syntax trees. This function processes an iterable (or asynchronous iterable) of sentences, parses each sentence into a syntax tree, enriches the tree by resolving coordination structures, and applies further enhancements like entity and relation enrichment. Optionally, an external entity resolver can be used to unify entities and relations across sentences. :param sentences: An iterable or asynchronous iterable of `AnnotatedSentence` objects to be parsed. :param language: The language to use for parsing. :param resolver: An optional entity resolver used to resolve entities within the parsed trees. If `None`, no entity resolution is performed. :param batch_size: The maximum number of concurrent parsing tasks that can run at once. It will only load at most `batch_size` element from the input iterable. :yields: A tuple of the original `AnnotatedSentence` and its enriched `Tree`. Each sentence is parsed independently, and results are yielded as they become available. Example: .. code-block:: python with Parser(corenlp_url="http://localhost:9000") as parser: async for sentence, tree in parser.parse_batch(sentences, language="English"): print(sentence) print(tree) """ def parse( batch_sentences: list[AnnotatedSentence], *_: list[AnnotatedSentence] ) -> AsyncIterable[tuple[AnnotatedSentence, Tree]]: texts = (sentence.txt for sentence in batch_sentences) trees = self.raw_parse(texts, language=language, batch_size=len(batch_sentences)) return stream.iterate(zip(batch_sentences, trees)) async def process( sent_tree: tuple[AnnotatedSentence, Tree], *_: tuple[AnnotatedSentence, Tree] | None ) -> tuple[AnnotatedSentence, Tree] | None: sentence, tree = sent_tree if enriched_tree := await process_tree(sentence, tree, resolver=resolver): return sentence, enriched_tree return None tree_stream = ( stream.iterate(sentences) | pipe.chunks(batch_size) | pipe.flatmap(parse) | pipe.amap(process, ordered=False, task_limit=batch_size) | pipe.filter(lambda x: x is not None) ) async with tree_stream.stream() as streamer: async for sentence, tree in streamer: yield sentence, tree
[docs] async def parse( self, sentence: AnnotatedSentence, *, language: str, resolver: EntityResolver | None = None, ) -> Tree | None: """ Parse an annotated sentence into an enriched syntax tree. This function takes an annotated sentence, parses it into a syntax tree, enriches the tree by fixing coordination structures, adding extra information (entities and relations), and applying reductions. An external entity resolver could be used to unify entities and relations. :param sentence: The annotated sentence to parse. :param language: The language to use for parsing. :param resolver: An optional entity resolver used to resolve entities within the parsed trees. If `None`, no entity resolution is performed. :returns: An enriched tree object. Example: .. code-block:: python with Parser(corenlp_url="http://localhost:9000") as parser: tree = parse(sentence, language='English') print(tree) """ for tree in self.raw_parse([sentence.txt], language=language, batch_size=1): if enriched_tree := await process_tree(sentence, tree, resolver=resolver): return enriched_tree return None
[docs] @abc.abstractmethod def raw_parse(self, sentences: Iterable[str], *, language: str, batch_size: int = 64) -> Iterator[Tree]: """ Parse a sentences into syntax trees using CoreNLP server. :param sentences: The sentences to parse. :param language: The language to use for parsing. :param batch_size: The maximum number of concurrent parsing tasks that can run at once. :returns: The parse trees of the sentences. Example: .. code-block:: python with Parser(corenlp_url="http://localhost:9000") as parser: for tree in parser.raw_parse(sentences, language='English'): print(tree) """ raise NotImplementedError
[docs] async def process_tree( sentence: AnnotatedSentence, tree: Tree, *, resolver: EntityResolver | None = None, ) -> Tree | None: # Replace specific parenthesis tokens ('-LRB-' for '(', '-RRB-' for ')') in the leaf nodes for subtree in tree.subtrees(lambda x: x.height() == 2 and len(x) == 1 and x[0] in {'-LRB-', '-RRB-'}): subtree[0] = '(' if subtree[0] == '-LRB-' else ')' # Flatten the coordination in the tree structure fix_all_coord(tree) # Enrich the tree with named entities and relations from the sentence try: enrich_tree(tree, sentence.txt, sentence.entities, sentence.rels) except ValueError as error: # Alignment issue, skip the tree warnings.warn(f'Alignment issue: {error}') return None # Reduce the tree structure removing unneeded nodes reduce_all(tree, set(NodeType)) # Don't yield an empty tree if not len(tree) or any(isinstance(child, str) for child in tree): return None # Do not keep root with only one child if len(tree) == 1 and isinstance(tree[0], Tree): tree = Tree.convert(tree[0]) # Rename nodes to unique undefined names # This is needed when measuring statistics for subtree in tree.subtrees(lambda x: not has_type(x, NodeType.ENT)): subtree.set_label('ROOT' if subtree.parent() is None else f'UNDEF_{uuid.uuid4().hex}') if resolver: await resolve_tree(tree, resolver) return tree
[docs] def enrich_tree(tree: Tree, sentence: str, entities: list[Entity], relations: list[Relation]) -> None: """ Enriches a syntactic tree (tree) by inserting entities and relationships, and removing unused subtrees. The function processes a list of entities and relations, inserting them into the tree, unnesting entities as needed, and finally deleting any subtrees that are not part of the enriched structure. :param tree: A tree representing the syntactic tree to enrich. :param sentence: The original sentence from which the tree is derived. :param entities: A list of `Entity` objects to be inserted into the tree. :param relations: A list of `Relation` objects representing the relationships between entities (currently not used). >>> t = Tree.fromstring("(S (NP Alice) (VP (VB likes) (NP (NNS apples) (CCONJ and) (NNS oranges))))") >>> e1 = Entity(name="person", start=0, end=5, id="E1") >>> e2 = Entity(name="fruit", start=12, end=18, id="E2") >>> e3 = Entity(name="fruit", start=23, end=30, id="E3") >>> enrich_tree(t, "Alice likes apples and oranges", [e1, e2, e3], []) >>> print(t.pformat(margin=255)) (S (ENT::person Alice) (VP (NP (ENT::fruit apples) (ENT::fruit oranges)))) >>> t = Tree.fromstring("(S (NP XXX) (NP YYY))") >>> e1 = Entity(name="nested1", start=0, end=3, id="E1") >>> e2 = Entity(name="nested2", start=4, end=7, id="E2") >>> e3 = Entity(name="overlap", start=0, end=7, id="E3") >>> enrich_tree(t, "XXX YYY", [e1, e2, e3], []) >>> print(t.pformat(margin=255)) (S (REL (ENT::overlap XXX YYY) (nested (ENT::nested1 XXX) (ENT::nested2 YYY)))) """ assert tree is not None assert sentence tokens = align_tokens(tree.leaves(), sentence) entity_tokens = { entity.id: tuple(i for i, token in enumerate(tokens) if entity.start <= token[1] and token[0] < entity.end) for entity in entities } # Insert entities into the tree by length (descending) to handle larger entities first computed_spans: set[tuple[int, ...]] = set() entity_trees: list[Tree] = [] for entity in sorted(entities, key=lambda entity: len(entity_tokens[entity.id]), reverse=True): entity_span = entity_tokens[entity.id] # Check for conflicts and skip problematic entities if is_conflicting_entity(entity, entity_span, computed_spans, tree): continue tree_entity = TreeEntity(entity.name, [tree.leaf_treeposition(i) for i in entity_span]) entity_tree = ins_ent(tree, tree_entity) entity_trees.append(entity_tree) computed_spans.add(entity_span) for et in entity_trees: assert et.parent() is not None, str(et) # Unnest any nested entities in reverse order (to avoid modifying parent indices during the process) for entity_tree in sorted(entity_trees, key=lambda x: x.height()): unnest_ent(entity_tree.parent(), entity_tree.parent_index()) # Currently, the relation part is commented out, but can be enabled when relations are processed. # for relation in relations: # tree_rel = TreeRel((), (), relation.name) # ins_rel(tree, tree_rel) if relations: warnings.warn("Relations are not yet supported and will be skipped.") # Remove subtrees that have no specific entity or relation (i.e., generic subtrees) for subtree in list(tree.subtrees(lambda x: x.height() == 2 and not has_type(x))): subtree.parent().remove(subtree)
[docs] def fix_coord(tree: Tree, pos: int) -> bool: """ Fix the coordination structure in a tree at the specified position `pos`. This function modifies the tree to ensure that the conjunctions are structured correctly according to the grammar rules of coordination. :param tree: The tree in which coordination adjustments will be made. :param pos: The index of the subtree within the parent tree that contains the coordination to fix. :return: `True` if the coordination was successfully fixed, `False` otherwise. >>> t = Tree.fromstring("(S (NP Alice) (VP (VB eats) (NP (NNS apples) (COORD (CCONJ and) (NP (NNS oranges))))))") >>> fix_coord(t[1], 1) True >>> print(t.pformat(margin=255)) (S (NP Alice) (VP (VB eats) (CONJ (NP (NNS apples)) (NP (NNS oranges))))) """ assert tree is not None coord = None # Identify the coordination subtree for child in tree[pos]: if ( isinstance(child, Tree) and child.label() == 'COORD' and len(child) > 0 and isinstance(child[0], Tree) and child[0].label() == 'CCONJ' ): coord = child break if coord is None: return False coord_index = coord.parent_index() # Create the left and right parts of the conjunction left = Tree(tree[pos].label(), children=[Tree.convert(child) for child in tree[pos][:coord_index]]) right = [Tree.convert(child) for child in coord[1:]] # Get all NPs after the conjunction # Create the conjunction tree conj = Tree('CONJ', children=[left, *right]) # CONJ should include the left NP and the conjuncts # Insert the new structure back into the tree # If children remain on the right of the coordination, we keep the existing level new_tree = ( Tree(tree[pos].label(), children=[conj] + [Tree.convert(child) for child in remaining_children]) if (remaining_children := tree[pos][coord_index + 1 :]) else conj ) # Replace the old subtree tree[pos] = new_tree return True
[docs] def fix_conj(tree: Tree, pos: int) -> bool: """ Fix conjunction structures in a tree at the specified position `pos`. If the node at `pos` is labeled 'CONJ', the function flattens any nested conjunctions by replacing the node with a new tree that combines its children. :param tree: The tree in which the conjunction structure will be fixed. :param pos: The index of the 'CONJ' node to be processed. :return: `True` if the conjunction structure was modified, `False` otherwise. >>> t = Tree.fromstring("(S (NP Alice) (VP (VB eats) (CONJ (NP (NNS apples)) (NP (NNS oranges)))))") >>> fix_conj(t[1], 1) False >>> t = Tree.fromstring("(S (NP Alice) (VP (VB eats) (CONJ (NP (NNS apples)) (CONJ (NP (NNS oranges)) (NP (NNS bananas))))))") >>> fix_conj(t[1], 1) True >>> print(t.pformat(margin=255)) (S (NP Alice) (VP (VB eats) (CONJ (NP (NNS apples)) (NP (NNS oranges)) (NP (NNS bananas))))) """ assert tree is not None # Check if the specified position is valid and corresponds to a 'CONJ' node if not isinstance(tree[pos], Tree) or tree[pos].label() != 'CONJ': return False new_children: list[Tree | str] = [] # Collect children, flattening nested conjunctions for child in tree[pos]: if isinstance(child, Tree) and child.label() == 'CONJ': new_children.extend(child) # Extend with children of the nested CONJ else: new_children.append(child) # Append non-CONJ children # If no changes were made, return False if len(new_children) <= len(tree[pos]): return False # Create a new tree for the flattened conjunction new_tree = Tree('CONJ', children=[Tree.convert(t) for t in new_children]) # Replace the original 'CONJ' node with the new tree tree[pos] = new_tree return True
[docs] def fix_all_coord(tree: Tree) -> None: """ Fix all coordination structures in a tree. This function iteratively applies `fix_coord` and `fix_conj` to the tree until no further modifications can be made. It ensures that the tree adheres to the correct syntactic structure for coordination and conjunctions. :param tree: The tree in which coordination structures will be fixed. >>> t = Tree.fromstring("(S (NP Alice) (VP (VB eats) (NP (NNS apples) (COORD (CCONJ and) (NP (NNS oranges))))))") >>> fix_all_coord(t) >>> print(t.pformat(margin=255)) (S (NP Alice) (VP (VB eats) (CONJ (NP (NNS apples)) (NP (NNS oranges))))) >>> t2 = Tree.fromstring("(S (NP Alice) (VP (VB eats) (NP (NNS apples) (COORD (CCONJ and) (NP (NNS oranges) (COORD (CCONJ and) (NP (NNS bananas))))))))") >>> fix_all_coord(t2) >>> print(t2.pformat(margin=255)) (S (NP Alice) (VP (VB eats) (CONJ (NP (NNS apples)) (NP (NNS oranges)) (NP (NNS bananas))))) """ assert tree is not None # Fix coordination coord_fixed = True while coord_fixed: coord_fixed = False for pos in tree.treepositions(): if len(pos) < 1: continue # Attempt to fix coordination if fix_coord(tree[pos[:-1]], pos[-1]): coord_fixed = True break # Fix conjunctions conj_fixed = True while conj_fixed: conj_fixed = False for pos in tree.treepositions(): if len(pos) < 1: continue # Attempt to fix conjunctions if fix_conj(tree[pos[:-1]], pos[-1]): conj_fixed = True break
[docs] def ins_ent(tree: Tree, tree_ent: TreeEntity) -> Tree: """ Insert a tree entity into the appropriate position within a parented tree. The function modifies the tree structure to insert an entity at the correct level based on its positions and root position. :param tree: A tree representing the syntactic tree. :param tree_ent: A `TreeEntity` containing the entity name and its positions in the tree. :return: The updated subtree where the entity was inserted. >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))") >>> tree_ent1 = TreeEntity(name="person", positions=[(0, 0)]) >>> tree_ent2 = TreeEntity(name="fruit", positions=[(1, 1, 0, 0)]) >>> ent_tree = ins_ent(t, tree_ent1) >>> print(t.pformat(margin=255)) (S (ENT::person Alice) (VP (VB like) (NP (NNS apples)))) >>> ent_tree = ins_ent(t, tree_ent2) >>> print(t.pformat(margin=255)) (S (ENT::person Alice) (VP (VB like) (ENT::fruit apples))) >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))") >>> t_ent = TreeEntity(name="xxx", positions=[(1, 0, 0), (1, 1, 0, 0)]) >>> ent_tree = ins_ent(t, t_ent) >>> print(t.pformat(margin=255)) (S (NP Alice) (ENT::xxx like apples)) >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))") >>> t_ent = TreeEntity(name="xxx", positions=[(0, 0), (1, 1, 0, 0)]) >>> ent_tree = ins_ent(t, t_ent) >>> print(t.pformat(margin=255)) (S (ENT::xxx Alice apples) (VP (VB like))) >>> t = Tree.fromstring("(S (NP Alice) (VP (VB like) (NP (NNS apples))))") >>> t_ent = TreeEntity(name="xxx", positions=[(0, 0), (1, 0, 0), (1, 1, 0, 0)]) >>> ent_tree = ins_ent(t, t_ent) >>> print(t.pformat(margin=255)) (S (ENT::xxx Alice like apples)) >>> t_ent = TreeEntity(name="yyy", positions=[(0, 2)]) >>> ent_tree = ins_ent(t, t_ent) >>> print(t.pformat(margin=255)) (S (ENT::xxx Alice like (ENT::yyy apples))) >>> t = Tree.fromstring("(S x y z)") >>> t_ent = TreeEntity(name="XY", positions=[(0,), (1,)]) >>> ent_tree = ins_ent(t, t_ent) >>> print(t.pformat(margin=255)) (S (ENT::XY x y) z) >>> t_ent = TreeEntity(name="YZ", positions=[(0, 1), (1,)]) >>> ent_tree = ins_ent(t, t_ent) >>> print(t.pformat(margin=255)) (S (ENT::XY x y) (ENT::YZ y z)) """ assert tree is not None # Determine the insertion point based on the positions of the entity anchor_pos = tree_ent.root_pos anchor_pos_len = len(anchor_pos) child_pos = tree_ent.positions[0] if sum(child_pos[anchor_pos_len + 1 :]) > 0: # Entity has children; attach to the common parent at the first child index + 1 entity_index = child_pos[anchor_pos_len] + 1 elif ( tree[tree_ent.root_pos].parent() is None or child_pos[anchor_pos_len] > 0 or tree_ent.positions[-1][anchor_pos_len] < (len(tree[tree_ent.root_pos]) - 1) ): # Attach to common parent at the correct index entity_index = child_pos[anchor_pos_len] else: # Attach to the grandparent at the common parent index entity_index = tree_ent.root_pos[-1] anchor_pos = tree_ent.root_pos[:-1] # Adjust anchor position upwards if necessary while len(tree[anchor_pos]) == 1 and tree[anchor_pos].parent(): entity_index = anchor_pos[-1] anchor_pos = anchor_pos[:-1] # Collect and delete children from the original positions children = [] for child_position in reversed(tree_ent.positions): parent_position = child_position[:-1] if not has_type(tree[parent_position], NodeType.ENT): # The entity has no conflict children.append(tree[child_position]) tree[parent_position].pop(child_position[-1], recursive=False) elif len(parent_position) <= len(anchor_pos) and parent_position == anchor_pos[: len(parent_position)]: # The entity is a child of another children.append(tree[child_position]) tree[parent_position].pop(child_position[-1], recursive=False) elif any( leaf_position not in tree_ent.positions for leaf_position in tree[parent_position].treepositions('leaves') ): # The entity overlap with another we need to duplicate overlapping leaves # Else, the entity is a parent entity, so we include only leaves not present in nested entities children.append(tree[child_position]) # Create a new tree node for the entity and insert it into the tree new_tree = Tree(NodeLabel(NodeType.ENT, tree_ent.name), children=reversed(children)) tree[anchor_pos].insert(entity_index, new_tree) # Return the modified subtree where the entity was inserted entity_tree = tree[anchor_pos][entity_index] # Remove empty subtree left in place for subtree in list(tree.subtrees(lambda st: len(st) == 0)): if subtree.parent(): subtree.parent().remove(subtree) return entity_tree
[docs] def unnest_ent(tree: Tree, pos: int) -> None: """ Un-nest an entity in a tree at the specified position `pos`. If the node at `pos` is labeled as an entity (ENT), the function converts the nested structure into a flat structure, creating a relationship (REL) between the entity and its nested entities. :param tree: The tree in which the entity will be un-nested. :param pos: The index of the 'ENT' node to be processed. >>> t = Tree.fromstring('(S (ENT::person Alice (ENT::person Bob) (ENT::person Charlie)))') >>> unnest_ent(t[0], 0) >>> print(t.pformat(margin=255)) (S (ENT::person Alice (ENT::person Bob) (ENT::person Charlie))) >>> unnest_ent(t, 0) >>> print(t.pformat(margin=255)) (S (REL (ENT::person Alice Bob Charlie) (nested (ENT::person Bob) (ENT::person Charlie)))) """ assert tree is not None # Check if the specified position corresponds to an 'ENT' node if not has_type(tree[pos], NodeType.ENT): return # Create the main entity tree and collect nested entities entity_tree = Tree(tree[pos].label(), children=tree[pos].leaves()) # Collect nested entities and ensure they are only from the children of the current entity nested_entities = [deepcopy(child) for child in tree[pos] if has_type(child, NodeType.ENT)] nested_tree = Tree('nested', children=nested_entities) # Construct a new relationship tree with the entity and its nested entities new_tree = Tree(NodeLabel(NodeType.REL), children=[entity_tree, nested_tree]) if nested_entities else entity_tree # Replace the original entity node with the new structure tree[pos] = new_tree
[docs] def ins_rel(tree: Tree, _: TreeRel) -> None: assert tree is not None
[docs] def is_conflicting_entity( entity: Entity, entity_span: tuple[int, ...], computed_spans: set[tuple[int, ...]], tree: Tree ) -> bool: """Check for conflicts with other entities (overlapping or duplicate spans).""" if entity_span in computed_spans: warnings.warn( f"Entity {entity.name} with tokens {entity_span} ('{' '.join(tree.leaves()[i] for i in entity_span)}') " f"conflicts with a previously inserted entity." ) return True for span in computed_spans: if any(token in span for token in entity_span) and not all(token in span for token in entity_span): warnings.warn( f"Entity {entity.name} with tokens {entity_span} ('{' '.join(tree.leaves()[i] for i in entity_span)}') " f"partially overlaps with a previously inserted entity with tokens {span} ('{' '.join(tree.leaves()[i] for i in span)}')." "Overlapping tokens will be duplicated." ) return False return False
[docs] async def resolve_tree(tree: Tree, resolver: EntityResolver) -> None: """Resolve entities in a tree using the provided entity resolver.""" ent_trees = list(tree.subtrees(lambda x: has_type(x, NodeType.ENT))) ent_texts = await resolver(' '.join(ent_tree.leaves()) for ent_tree in ent_trees) for ent_tree, ent_text in zip(ent_trees, ent_texts, strict=True): ent_tree.clear() ent_tree.append(ent_text)