"""Entity resolvers for ``architxt.nlp``: link entity mentions to knowledge-base concepts."""

import contextlib
from abc import ABC, abstractmethod
from collections.abc import Iterable
from types import TracebackType

from googletrans import Translator
from scispacy.candidate_generation import CandidateGenerator
from unidecode import unidecode


class EntityResolver(ABC):
    """Abstract base class for entity resolvers.

    A resolver maps a batch of entity mention strings to resolved
    names or identifiers, asynchronously.
    """

    @property
    def name(self) -> str:
        """Human-readable name of the resolver (defaults to the class name)."""
        return type(self).__name__

    @abstractmethod
    async def __call__(self, texts: Iterable[str]) -> Iterable[str]:
        """Resolve each mention in ``texts`` and yield the resolved form."""
        ...
class ScispacyResolver(EntityResolver):
    """Resolve entities using the SciSpaCy entity linker.

    Mentions are optionally translated to English and uniformized, then
    linked against a SciSpaCy knowledge base via candidate generation.
    """

    def __init__(
        self,
        *,
        kb_name: str = 'umls',
        cleanup: bool = False,
        translate: bool = False,
        batch_size: int = 8,
        threshold: float = 0.7,
        resolve_text: bool = True,
    ) -> None:
        """
        Resolve entities using the SciSpaCy entity linker.

        :param kb_name: The name of the knowledge base to use: `umls`, `mesh`, `rxnorm`, `go`, or `hpo`.
        :param cleanup: True if the resolved text should be uniformized.
        :param translate: True if the text should be translated if it does not correspond to the model language.
        :param batch_size: Number of texts to process in parallel (useful for large corpora).
        :param threshold: The threshold that an entity candidate must reach to be considered.
        :param resolve_text: True if the resolver should return the canonical name instead of the identifier.
        """
        self.translate = translate
        self.cleanup = cleanup
        self.batch_size = batch_size
        self.threshold = threshold
        self.kb_name = kb_name
        self.resolve_text = resolve_text
        # BUG FIX: always define `translator` so `_translate()` can test it safely.
        # Previously it was only assigned inside `__aenter__` when translate=True,
        # so using the resolver without entering the async context raised
        # AttributeError instead of taking the temporary-translator fallback.
        self.translator: Translator | None = None
        self.exit_stack = contextlib.AsyncExitStack()
        self.candidate_generator = CandidateGenerator(name=self.kb_name)

    async def __aenter__(self) -> 'ScispacyResolver':
        # Open a shared translator only when translation was requested;
        # its lifetime is tied to the exit stack closed in __aexit__.
        if self.translate:
            translator = Translator(list_operation_max_concurrency=self.batch_size)
            self.translator = await self.exit_stack.enter_async_context(translator)

        return self

    async def __aexit__(
        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
    ) -> None:
        await self.exit_stack.aclose()

    @property
    def name(self) -> str:
        """Name of the knowledge base backing this resolver."""
        return self.kb_name

    async def _translate(self, texts: list[str]) -> list[str]:
        """
        Translate texts in batch asynchronously.

        Use an existing translator if available, otherwise creates a temporary one.

        :param texts: The texts to translate.
        :return: The translated texts, in the same order.
        """
        if not self.translator:
            async with Translator(list_operation_max_concurrency=self.batch_size) as temp_translator:
                translations = await temp_translator.translate(texts, dest="en")
        else:
            translations = await self.translator.translate(texts, dest="en")

        return [t.text for t in translations]

    def _cleanup_string(self, text: str) -> str:
        """
        Cleanup text to uniformize it.

        :param text: The text document to clean up.
        :return: The uniformized text.
        """
        if text and self.cleanup:
            text = unidecode(text.lower())

        return text

    def _resolve(self, mention_texts: list[str]) -> Iterable[str]:
        """Resolve entity names using SciSpaCy entity linker."""
        for mention, candidates in zip(mention_texts, self.candidate_generator(mention_texts, 10), strict=False):
            best_candidate = None
            best_candidate_score = 0

            for candidate in candidates:
                # A candidate must beat both the global threshold and the current best score.
                if (score := max(candidate.similarities, default=0)) > self.threshold and score > best_candidate_score:
                    best_candidate = candidate
                    best_candidate_score = score

            if not best_candidate:
                # No acceptable candidate: fall back to the raw mention.
                yield mention
            elif self.resolve_text:
                yield self.candidate_generator.kb.cui_to_entity[best_candidate.concept_id].canonical_name
            else:
                yield best_candidate.concept_id

    async def __call__(self, texts: Iterable[str]) -> Iterable[str]:
        """Resolve a batch of entity mentions, translating and cleaning as configured."""
        texts = list(texts)

        if self.translate:
            texts = await self._translate(texts)

        return map(self._cleanup_string, self._resolve(texts))