# Source code for architxt.nlp.entity_extractor
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from typing import TYPE_CHECKING
from aiostream import pipe, stream
from architxt.nlp.model import AnnotatedSentence, Entity
if TYPE_CHECKING:
    from flair.data import Sentence
    from spacy.tokens import Doc
# Pipeline components that are not needed for entity recognition;
# disabling them speeds up spaCy loading and inference.
SPACY_DISABLED_PIPELINES = {'parser', 'senter', 'sentencizer', 'textcat', 'lemmatizer', 'tagger'}
class EntityExtractor(ABC):
    """Base class for named-entity extractors producing :class:`AnnotatedSentence` objects."""

    @property
    def name(self) -> str:
        """Human-readable name of the extractor, defaulting to the class name."""
        return self.__class__.__name__

    @abstractmethod
    def __call__(self, sentence: str) -> AnnotatedSentence: ...
    async def batch(
        self,
        sentences: Iterable[str] | AsyncIterable[str],
    ) -> AsyncIterator[AnnotatedSentence]:
        """Annotate a (sync or async) iterable of sentences, yielding results as they are produced."""
        sentence_stream = stream.iterate(sentences) | pipe.map(self.__call__)

        async with sentence_stream.stream() as streamer:
            async for sentence in streamer:
                yield sentence
    async def enrich(
        self,
        sentences: Iterable[AnnotatedSentence] | AsyncIterable[AnnotatedSentence],
    ) -> AsyncIterator[AnnotatedSentence]:
        """Run the extractor over already-annotated sentences and append the newly found entities."""

        def _enrich_sentence(annotated: AnnotatedSentence) -> AnnotatedSentence:
            # Re-extract entities from the raw text and merge them into the
            # existing annotation in place.
            new_entities = self(annotated.txt).entities
            annotated.entities.extend(new_entities)
            return annotated

        sentence_stream = stream.iterate(sentences) | pipe.map(_enrich_sentence)

        async with sentence_stream.stream() as streamer:
            async for sentence in streamer:
                yield sentence
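
# How the base-class pipeline is meant to be consumed (a sketch, assuming a
# concrete subclass such as SpacyEntityExtractor below; both ``batch`` and
# ``enrich`` accept plain iterables as well as async iterables):
#
#     extractor = SpacyEntityExtractor()
#     async for annotated in extractor.enrich(previously_annotated_sentences):
#         ...  # annotated.entities now also contains this extractor's entities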
class SpacyEntityExtractor(EntityExtractor):
    """Entity extractor backed by a spaCy pipeline."""

    def __init__(self, model_name: str = "en_core_web_sm") -> None:
        import spacy

        self.nlp = spacy.load(model_name, disable=SPACY_DISABLED_PIPELINES)
    @staticmethod
    def _doc_to_annotated(doc: 'Doc') -> AnnotatedSentence:
        """Convert a spaCy ``Doc`` into an :class:`AnnotatedSentence`."""
        entities = [
            Entity(
                name=ent.label_,
                start=ent.start_char,
                end=ent.end_char,
                id=f"{ent.label_}_{ent.start_char}_{ent.end_char}",
                value=ent.text,
            )
            for ent in doc.ents
        ]
        return AnnotatedSentence(txt=doc.text, entities=entities, rels=[])

    def __call__(self, sentence: str) -> AnnotatedSentence:
        doc = self.nlp(sentence)
        return self._doc_to_annotated(doc)
    async def batch(
        self,
        sentences: Iterable[str] | AsyncIterable[str],
        *,
        batch_size: int = 128,
    ) -> AsyncIterator[AnnotatedSentence]:
        """Annotate sentences in batches, delegating to ``nlp.pipe`` for better throughput."""
        sentence_stream = (
            stream.iterate(sentences)
            | pipe.chunks(batch_size)
            | pipe.flatmap(self.nlp.pipe)
            | pipe.map(self._doc_to_annotated)
        )

        async with sentence_stream.stream() as streamer:
            async for sentence in streamer:
                yield sentence
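
# Synchronous single-sentence use is just a call (a sketch; the exact entity
# labels such as PERSON/ORG depend on the loaded spaCy model):
#
#     extractor = SpacyEntityExtractor("en_core_web_sm")
#     annotated = extractor("Tim Cook leads Apple in Cupertino.")
#     # annotated.entities -> [Entity(name='PERSON', ...), Entity(name='ORG', ...), ...]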
class FlairEntityExtractor(EntityExtractor):
    """Entity extractor backed by a Flair ``SequenceTagger``."""

    def __init__(self, model_name: str = "ner") -> None:
        from flair.models import SequenceTagger

        self.tagger = SequenceTagger.load(model_name)
    @staticmethod
    def _sentence_to_annotated(sentence: 'Sentence') -> AnnotatedSentence:
        """Convert a tagged Flair ``Sentence`` into an :class:`AnnotatedSentence`."""
        entities = [
            Entity(
                name=span.tag,
                start=span.start_position,
                end=span.end_position,
                id=f"{span.tag}_{span.start_position}_{span.end_position}",
                value=span.text,
            )
            for span in sentence.get_spans('ner')
        ]
        return AnnotatedSentence(txt=sentence.to_plain_string(), entities=entities, rels=[])

    def __call__(self, sentence: str) -> AnnotatedSentence:
        from flair.data import Sentence

        flair_sentence = Sentence(sentence)
        self.tagger.predict(flair_sentence)
        return self._sentence_to_annotated(flair_sentence)
    async def batch(
        self,
        sentences: Iterable[str] | AsyncIterable[str],
        *,
        batch_size: int = 128,
    ) -> AsyncIterator[AnnotatedSentence]:
        """Annotate sentences in batches, tagging each chunk with a single ``predict`` call."""
        from flair.data import Sentence

        def _predict_batch(batch: list[Sentence]) -> list[Sentence]:
            # Flair's ``predict`` annotates the sentences in place and returns
            # None, so the tagged batch must be returned explicitly for the
            # next pipeline stage.
            self.tagger.predict(batch)
            return batch

        sentence_stream = (
            stream.iterate(sentences)
            | pipe.map(Sentence)
            | pipe.chunks(batch_size)
            | pipe.flatmap(_predict_batch)
            | pipe.map(self._sentence_to_annotated)
        )

        async with sentence_stream.stream() as streamer:
            async for sentence in streamer:
                yield sentence
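
# A minimal end-to-end sketch (assumptions: the ``en_core_web_sm`` spaCy model
# is installed, and the sample sentence is illustrative only):
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        extractor = SpacyEntityExtractor()
        # batch() accepts a plain iterable and yields annotations asynchronously.
        async for annotated in extractor.batch(["Barack Obama visited Paris in 2015."]):
            for entity in annotated.entities:
                print(entity.name, entity.value)

    asyncio.run(_demo())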