Source code for architxt.nlp.entity_resolver

from __future__ import annotations

import abc
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING

from aiostream import pipe, stream
from typing_extensions import Self

if TYPE_CHECKING:
    from collections.abc import AsyncIterable, AsyncIterator, Iterable
    from types import TracebackType

    from architxt.nlp.model import AnnotatedSentence, Entity

__all__ = ['EntityResolver']


class EntityResolver(AbstractAsyncContextManager):
    """
    Abstract base class for asynchronous entity resolvers.

    A resolver is an awaitable callable: subclasses implement :meth:`__call__`,
    which receives one entity and returns one entity. The class is also an
    asynchronous context manager whose default ``__aenter__``/``__aexit__`` do
    nothing, so a resolver can always be used with ``async with``; subclasses
    may override them.

    :meth:`batch` and :meth:`batch_sentences` apply :meth:`__call__` over
    streams of entities or annotated sentences using ``aiostream``.
    """

    @property
    def name(self) -> str:
        """Return the resolver's name, i.e. the name of its concrete class."""
        return self.__class__.__name__

    @abc.abstractmethod
    async def __call__(self, entity: Entity) -> Entity:
        """
        Resolve a single entity.

        :param entity: The entity to resolve.
        :return: The resolved entity.
        """

    async def __aenter__(self) -> Self:
        """Enter the asynchronous context and return the resolver itself."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """
        Exit the asynchronous context.

        The default implementation does nothing and returns ``None``, so an
        exception raised inside the ``async with`` body is not suppressed.

        :param exc_type: Type of the exception raised in the body, if any.
        :param exc_value: The exception raised in the body, if any.
        :param traceback: Traceback of that exception, if any.
        """

    async def batch(
        self,
        entities: Iterable[Entity] | AsyncIterable[Entity],
        *,
        batch_size: int = 16,
    ) -> AsyncIterator[Entity]:
        """
        Resolve a stream of entities.

        Each input entity is passed to :meth:`__call__` through aiostream's
        ``amap`` operator and the results are yielded as the stream produces them.

        :param entities: Synchronous or asynchronous iterable of entities to resolve.
        :param batch_size: Forwarded to ``amap`` as ``task_limit``, the cap on how
            many :meth:`__call__` coroutines may run at the same time.
        :yield: The resolved entities.
        """
        entity_stream = stream.iterate(entities) | pipe.amap(self.__call__, task_limit=batch_size)

        # `.stream()` yields a context-managed streamer so the underlying
        # tasks are cleaned up when iteration stops early or fails.
        async with entity_stream.stream() as streamer:
            async for entity in streamer:
                yield entity

    async def batch_sentences(
        self,
        sentences: Iterable[AnnotatedSentence] | AsyncIterable[AnnotatedSentence],
        *,
        batch_size: int = 16,
    ) -> AsyncIterator[AnnotatedSentence]:
        """
        Resolve the entities of every sentence in a stream of annotated sentences.

        Each sentence is modified in place: its ``entities`` attribute is replaced
        by the list of entities returned by :meth:`batch`, and the same sentence
        object is then yielded.

        :param sentences: Synchronous or asynchronous iterable of annotated sentences.
        :param batch_size: Passed to :meth:`batch` for the entities of each sentence.
        :yield: The input sentences, with their entities resolved.
        """

        async def _resolve(sentence: AnnotatedSentence) -> AnnotatedSentence:
            # The new list is fully built before being assigned, so a failure
            # while resolving leaves the sentence's original entities untouched.
            sentence.entities = [entity async for entity in self.batch(sentence.entities, batch_size=batch_size)]
            return sentence

        # task_limit=1: sentences are handled one at a time; concurrency only
        # happens across the entities of a sentence, bounded by `batch_size`.
        sentence_stream = stream.iterate(sentences) | pipe.amap(_resolve, task_limit=1)

        async with sentence_stream.stream() as streamer:
            async for sent in streamer:
                yield sent