Source code for architxt.nlp

import asyncio
import hashlib
import tarfile
import zipfile
from collections.abc import AsyncGenerator, Iterable, Sequence
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, BinaryIO

import mlflow
from mlflow.data.code_dataset_source import CodeDatasetSource
from mlflow.data.meta_dataset import MetaDataset
from platformdirs import user_cache_path
from rich.console import Console
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn

from architxt.bucket.zodb import ZODBTreeBucket
from architxt.nlp.brat import load_brat_dataset
from architxt.nlp.entity_resolver import EntityResolver, ScispacyResolver
from architxt.nlp.parser import Parser
from architxt.tree import Tree
from architxt.utils import BATCH_SIZE

if TYPE_CHECKING:
    from architxt.nlp.model import AnnotatedSentence

__all__ = ['raw_load_corpus']

console = Console()

CACHE_DIR = user_cache_path('architxt')


async def _get_cache_key(
    archive_file: BytesIO | BinaryIO,
    *,
    entities_filter: set[str] | None = None,
    relations_filter: set[str] | None = None,
    entities_mapping: dict[str, str] | None = None,
    relations_mapping: dict[str, str] | None = None,
    language: str,
    resolver: EntityResolver | None = None,
) -> str:
    """Generate a cache key based on the archive file's content and settings."""
    cursor = archive_file.tell()
    file_hash = await asyncio.to_thread(hashlib.file_digest, archive_file, hashlib.md5)
    archive_file.seek(cursor)

    file_hash.update(language.encode())

    if entities_filter:
        file_hash.update('$E'.join(sorted(entities_filter)).encode())
    if relations_filter:
        file_hash.update('$R'.join(sorted(relations_filter)).encode())
    if entities_mapping:
        file_hash.update('$EM'.join(sorted(f'{key}={value}' for key, value in entities_mapping.items())).encode())
    if relations_mapping:
        file_hash.update('$RM'.join(sorted(f'{key}={value}' for key, value in relations_mapping.items())).encode())
    if resolver:
        file_hash.update(resolver.name.encode())

    return file_hash.hexdigest()
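

# Illustrative check (not part of the original module): the cache key covers both the archive
# bytes and the loading settings, so the same payload loaded with a different language, filter,
# mapping, or resolver gets its own cache entry. The in-memory payload below is a placeholder.
async def _demo_cache_key_sensitivity() -> None:
    payload = BytesIO(b'placeholder archive bytes')
    key_en = await _get_cache_key(payload, language='english')
    key_fr = await _get_cache_key(payload, language='french')
    assert key_en != key_fr  # different settings produce different cache keys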


def open_archive(archive_file: BytesIO | BinaryIO) -> zipfile.ZipFile | tarfile.TarFile:
    cursor = archive_file.tell()
    signature = archive_file.read(4)
    archive_file.seek(cursor)

    if signature.startswith(b'PK\x03\x04'):  # ZIP file signature
        return zipfile.ZipFile(archive_file)

    if signature.startswith(b'\x1f\x8b'):  # GZIP signature (tar.gz)
        return tarfile.TarFile.open(fileobj=archive_file)

    msg = "Unsupported file format"
    raise ValueError(msg)
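

# Illustrative helper (not part of the original module): `open_archive` returns either a
# ZipFile or a TarFile, so callers that only need the member names must branch on the type.
# The helper name and its use are assumptions for this sketch.
def _demo_list_members(path: Path) -> list[str]:
    with path.open('rb') as file, open_archive(file) as archive:
        # ZipFile exposes `namelist()`, TarFile exposes `getnames()`.
        return archive.namelist() if isinstance(archive, zipfile.ZipFile) else archive.getnames()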


async def _load_or_cache_corpus(  # noqa: C901
    archive_file: str | Path | BytesIO | BinaryIO,
    queue: asyncio.Queue[Tree],
    progress: Progress,
    *,
    entities_filter: set[str] | None = None,
    relations_filter: set[str] | None = None,
    entities_mapping: dict[str, str] | None = None,
    relations_mapping: dict[str, str] | None = None,
    parser: Parser,
    language: str,
    name: str | None = None,
    resolver: EntityResolver | None = None,
    cache: bool = True,
    sample: int | None = None,
) -> None:
    """
    Load the corpus from disk or from the cache.

    :param archive_file: A path or an in-memory file object of the corpus archive.
    :param queue: The queue onto which parsed trees are pushed as they become available.
    :param progress: The progress display used to report loading status.
    :param entities_filter: A set of entity types to exclude from the output. If None, no filtering is applied.
    :param relations_filter: A set of relation types to exclude from the output. If None, no filtering is applied.
    :param entities_mapping: A dictionary mapping entity names to new values. If None, no mapping is applied.
    :param relations_mapping: A dictionary mapping relation names to new values. If None, no mapping is applied.
    :param parser: The NLP parser to use.
    :param language: The language to use for parsing.
    :param name: The corpus name.
    :param resolver: An optional entity resolver to use.
    :param cache: Whether to cache the computed forest or not.
    :param sample: The maximum number of trees to push onto the queue. If None, the whole corpus is used.
    """
    should_close = False
    corpus_cache_path: Path | None = None

    if isinstance(archive_file, str | Path):
        archive_file = Path(archive_file).open('rb')  # noqa: SIM115
        should_close = True

    try:
        key = await _get_cache_key(
            archive_file,
            entities_filter=entities_filter,
            entities_mapping=entities_mapping,
            relations_filter=relations_filter,
            relations_mapping=relations_mapping,
            language=language,
            resolver=resolver,
        )

        if cache:
            directory = CACHE_DIR / 'corpus_cache'
            directory.mkdir(parents=True, exist_ok=True)
            corpus_cache_path = directory / key

        if mlflow.active_run():
            mlflow.log_input(
                MetaDataset(
                    CodeDatasetSource(
                        {
                            'entities_filter': sorted(entities_filter or []),
                            'relations_filter': sorted(relations_filter or []),
                            'entities_mapping': entities_mapping,
                            'relations_mapping': relations_mapping,
                            'cache_file': str(corpus_cache_path.absolute()) if corpus_cache_path else None,
                        }
                    ),
                    name=name or archive_file.name,
                    digest=key,
                )
            )

        with ZODBTreeBucket(storage_path=corpus_cache_path) as forest:
            count = 0

            if cache and len(forest):
                # Attempt to load from cache if available
                for tree in progress.track(
                    forest,
                    description=f'[green]Loading corpus {archive_file.name} from cache...[/]',
                    total=sample,
                ):
                    await queue.put(tree.copy())
                    count += 1

                    if sample and count >= sample:
                        break

            else:
                # Load data from disk
                with (
                    open_archive(archive_file) as corpus,
                    TemporaryDirectory() as tmp_dir,
                ):
                    # Extract archive contents to a temporary directory
                    await asyncio.to_thread(corpus.extractall, path=tmp_dir)

                    # Parse sentences and enrich the forest
                    sentences: Iterable[AnnotatedSentence] = load_brat_dataset(
                        Path(tmp_dir),
                        entities_filter=entities_filter,
                        relations_filter=relations_filter,
                        entities_mapping=entities_mapping,
                        relations_mapping=relations_mapping,
                    )
                    sentences = progress.track(
                        sentences,
                        description=f'[yellow]Loading corpus {archive_file.name} from disk...[/]',
                    )

                    batch = []
                    async for _, tree in parser.parse_batch(sentences, language=language, resolver=resolver):
                        batch.append(tree)

                        if not sample or count < sample:
                            await queue.put(tree.copy())
                            count += 1

                        if len(batch) >= BATCH_SIZE:
                            await asyncio.to_thread(forest.update, batch)
                            batch.clear()

                    if batch:
                        await asyncio.to_thread(forest.update, batch)

    except Exception as e:
        console.print(f'[red]Error while processing corpus:[/] {e}')
        raise

    finally:
        if should_close:
            archive_file.close()


async def raw_load_corpus(
    corpus_archives: Sequence[str | Path | BytesIO | BinaryIO],
    languages: Sequence[str],
    *,
    parser: Parser,
    entities_filter: set[str] | None = None,
    relations_filter: set[str] | None = None,
    entities_mapping: dict[str, str] | None = None,
    relations_mapping: dict[str, str] | None = None,
    resolver_name: str | None = None,
    cache: bool = True,
    sample: int | None = None,
    batch_size: int = BATCH_SIZE,
) -> AsyncGenerator[Tree, None]:
    """
    Asynchronously load a set of corpora from disk or in-memory archives, parse them, and yield the enriched trees.

    This function handles both local and in-memory corpus archives, processes the data according to the
    specified filters and mappings, and uses the provided parser for parsing.
    Optionally, caching can be enabled to avoid repeated computations.

    The resulting forest is not a valid database instance; it must first be passed to the automatic
    structuration algorithm.

    :param corpus_archives: A list of corpus archive sources, which can be:

        - Paths to files on disk, or
        - In-memory file-like objects.

        The list can include both local and in-memory sources, and its size should match the length of `languages`.
    :param languages: A list of languages corresponding to each corpus archive.
        The number of languages must match the number of archives.
    :param parser: The parser to use to parse the sentences.
    :param entities_filter: A set of entity types to exclude from the output. If None, no filtering is applied.
    :param relations_filter: A set of relation types to exclude from the output. If None, no filtering is applied.
    :param entities_mapping: A dictionary mapping entity names to new values. If None, no mapping is applied.
    :param relations_mapping: A dictionary mapping relation names to new values. If None, no mapping is applied.
    :param resolver_name: The name of the entity resolver to use. If None, no entity resolution is performed.
    :param cache: A boolean flag indicating whether to cache the computed forest for faster future access.
    :param sample: The number of examples to take from each corpus.
    :param batch_size: The number of sentences to process in each batch. This parameter controls memory usage.
    :returns: An async generator yielding the parsed and enriched trees.
    """
    with (
        parser as parser_ctx,
        Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            console=console,
        ) as progress,
    ):
        resolver_ctx = (
            ScispacyResolver(cleanup=True, translate=True, kb_name=resolver_name) if resolver_name else nullcontext()
        )

        async with resolver_ctx as resolver:
            queue: asyncio.Queue[Tree] = asyncio.Queue(batch_size)
            pending = {
                asyncio.create_task(
                    _load_or_cache_corpus(
                        corpus,
                        queue,
                        progress,
                        entities_filter=entities_filter,
                        relations_filter=relations_filter,
                        entities_mapping=entities_mapping,
                        relations_mapping=relations_mapping,
                        parser=parser_ctx,
                        language=language,
                        resolver=resolver,
                        cache=cache,
                        sample=sample,
                    )
                )
                for corpus, language in zip(corpus_archives, languages, strict=True)
            }

            try:
                while pending or not queue.empty():
                    while not queue.empty():
                        yield await queue.get()

                    if pending:
                        _, pending = await asyncio.wait(pending, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in pending:
                    task.cancel()
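

# Minimal consumption sketch (illustrative, not part of the original module). It assumes an
# already-constructed `Parser` instance and a local BRAT archive named 'corpus.tar.gz'; both
# the archive name and the 'english' language tag are placeholders, not values required by architxt.
async def _demo_raw_load_corpus(parser: Parser) -> None:
    # Stream the first few enriched trees out of a single corpus archive without caching.
    async for tree in raw_load_corpus(
        ['corpus.tar.gz'],
        ['english'],
        parser=parser,
        cache=False,
        sample=5,
    ):
        console.print(tree)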