# Source code for architxt.nlp.contrib.flair

from __future__ import annotations

from typing import TYPE_CHECKING

from aiostream import pipe, stream

from architxt.nlp.entity_extractor import EntityExtractor
from architxt.nlp.model import AnnotatedSentence, Entity

if TYPE_CHECKING:
    from collections.abc import AsyncIterable, AsyncIterator, Iterable


# Flair is an optional dependency: import it eagerly so that a missing
# installation fails at module import time with an actionable message.
try:
    from flair.data import Sentence
    from flair.models import SequenceTagger
except ImportError as error:
    # Re-raise as ImportError with an install hint, chained to the original
    # error so the underlying cause stays visible in the traceback.
    msg = f"The '{__name__}' contrib module requires Flair. Install it with: pip install architxt[flair]"
    raise ImportError(msg) from error

# Only the extractor class is exported from this module.
__all__ = ['FlairEntityExtractor']


class FlairEntityExtractor(EntityExtractor):
    """
    Entity extractor backed by a Flair ``SequenceTagger``.

    Each input sentence is wrapped in a Flair ``Sentence``, tagged by the loaded
    model, and every span reported under the ``'ner'`` label is converted into
    an :class:`Entity` of the resulting :class:`AnnotatedSentence`.
    """

    def __init__(self, model_name: str = "ner") -> None:
        """
        Load the Flair sequence tagger used for prediction.

        :param model_name: Identifier forwarded to ``SequenceTagger.load``.
        """
        self.tagger = SequenceTagger.load(model_name)

    @staticmethod
    def _sentence_to_annotated(sentence: Sentence) -> AnnotatedSentence:
        """
        Convert an already-tagged Flair sentence into an ``AnnotatedSentence``.

        :param sentence: A Flair sentence on which the tagger has been run.
        :return: The plain-text sentence with one entity per ``'ner'`` span and
            an empty relation list.
        """
        entities = [
            Entity(
                name=span.tag,
                start=span.start_position,
                end=span.end_position,
                # Identifier derived from the tag and the character offsets of the span.
                id=f"{span.tag}_{span.start_position}_{span.end_position}",
                value=span.text,
            )
            for span in sentence.get_spans('ner')
        ]

        return AnnotatedSentence(txt=sentence.to_plain_string(), entities=entities, rels=[])

    def __call__(self, sentence: str) -> AnnotatedSentence:
        """
        Extract the entities of a single sentence.

        :param sentence: Raw sentence text.
        :return: The annotated sentence produced from the tagger's prediction.
        """
        flair_sentence = Sentence(sentence)
        self.tagger.predict(flair_sentence)  # Annotates ``flair_sentence`` in place.

        return self._sentence_to_annotated(flair_sentence)

    async def batch(
        self,
        sentences: Iterable[str] | AsyncIterable[str],
        *,
        batch_size: int = 128,
    ) -> AsyncIterator[AnnotatedSentence]:
        """
        Extract entities from a stream of sentences, tagging them chunk by chunk.

        :param sentences: Synchronous or asynchronous iterable of raw sentences.
        :param batch_size: Number of sentences grouped into each ``predict`` call.
        :yield: One annotated sentence per input sentence, in input order.
        """
        chunk_stream = stream.iterate(sentences) | pipe.map(Sentence) | pipe.chunks(batch_size)

        async with chunk_stream.stream() as streamer:
            async for chunk in streamer:
                # ``chunks`` yields plain lists. They are unpacked here rather than
                # through ``pipe.flatten()``, which expects asynchronous sub-sequences
                # and does not guarantee output order.
                self.tagger.predict(chunk)  # Synchronous call; annotates the chunk in place.

                for flair_sentence in chunk:
                    yield self._sentence_to_annotated(flair_sentence)