Source code for architxt.utils

from __future__ import annotations

import sys
from random import randrange
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import more_itertools

if TYPE_CHECKING:
    from collections.abc import Generator, Iterable, Sequence

__all__ = ['BATCH_SIZE', 'ExceptionGroup', 'get_commit_batch_size', 'update_url_queries', 'windowed_shuffle']

BATCH_SIZE = 1024
T = TypeVar('T')


if sys.version_info < (3, 11):

    class ExceptionGroup(BaseException):
        def __init__(self, message: str, exceptions: Sequence[BaseException]) -> None:
            message += '\n'.join(f'  ({i}) {exc!r}' for i, exc in enumerate(exceptions, 1))
            super().__init__(message)

else:
    from builtins import ExceptionGroup


[docs] def update_url_queries(url: str, **p: Any) -> str: """ Update query parameters in a URL. Merges existing query parameters with provided keyword arguments. If a parameter already exists, it will be overwritten. >>> update_url_queries('https://example.com?foo=1', bar='2') 'https://example.com?foo=1&bar=2' >>> update_url_queries('https://example.com?foo=1', foo='overwritten') 'https://example.com?foo=overwritten' :param url: The URL to update. :param p: Query parameters to add or update. :return: The URL with updated query parameters. """ u = urlparse(url) q = dict(parse_qsl(u.query)) q.update(p) return urlunparse(u._replace(query=urlencode(q)))
[docs] def get_commit_batch_size(commit: bool | int) -> int: """ Derive the batch size for commit operations. :param commit: Commit mode. - If True or False, returns the default BATCH_SIZE. - If a positive integer, returns that value as the batch size. :return: The batch size to use for chunked operations. :raises ValueError: If commit is a non-positive integer. """ if isinstance(commit, bool): batch_size = BATCH_SIZE elif commit > 0: batch_size = commit else: msg = 'Commit should be a boolean or a positive integer' raise ValueError(msg) return batch_size
[docs] def windowed_shuffle(iterable: Iterable[T], window_size: int = 10) -> Generator[T, None, None]: """ Shuffle an :py:class:`~Iterable` by yielding items in a randomized order using a sliding window buffer. :param iterable: Iterable to shuffle. :param window_size: Size of the sliding window buffer. :yield: Shuffled items. :raise ValueError: If window_size is <= 1. """ if window_size <= 1: msg = "window_size must be > 1" raise ValueError(msg) it = iter(iterable) buf = list(more_itertools.take(window_size, it)) for item in it: idx = randrange(len(buf)) yield buf.pop(idx) buf.append(item) while buf: idx = randrange(len(buf)) yield buf.pop(idx)