Source code for architxt.similarity

import math
from collections import defaultdict
from collections.abc import Callable, Collection, Iterable
from itertools import combinations

import mlflow
import numpy as np
import numpy.typing as npt
import plotly.figure_factory as ff
from Levenshtein import jaro_winkler
from Levenshtein import ratio as levenshtein_ratio
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from tqdm.auto import tqdm

from architxt.tree import Forest, NodeType, Tree, has_type

METRIC_FUNC = Callable[[Collection[str], Collection[str]], float]
TREE_CLUSTER = set[tuple[Tree, ...]]


[docs] def jaccard(x: Collection[str], y: Collection[str]) -> float: """ Jaccard similarity. :param x: The first sequence of strings. :param y: The second sequence of strings. :return: The Jaccard similarity as a float between 0 and 1, where 1 means identical sequences. >>> jaccard({"A", "B"}, {"A", "B", "C"}) 0.6666666666666666 >>> jaccard({"apple", "banana", "cherry"}, {"apple", "cherry", "date"}) 0.5 >>> jaccard(set(), set()) 1.0 """ x_set = set(x) y_set = set(y) return len(x_set & y_set) / len(x_set | y_set) if x_set or y_set else 1.0
[docs] def levenshtein(x: Collection[str], y: Collection[str]) -> float: """Levenshtein similarity.""" return levenshtein_ratio(sorted(x), sorted(y))
[docs] def jaro(x: Collection[str], y: Collection[str]) -> float: """Jaro winkler similarity.""" return jaro_winkler(sorted(x), sorted(y))
DEFAULT_METRIC: METRIC_FUNC = jaro # jaccard, levenshtein, jaro
[docs] def similarity(x: Tree, y: Tree, *, metric: METRIC_FUNC = DEFAULT_METRIC) -> float: """ Compute the similarity between two tree objects based on their entity labels and context. The function uses a specified metric (such as Jaccard, Levenshtein, or Jaro-Winkler) to calculate the similarity between the labels of entities in the trees. The similarity is computed as recursive weighted mean for each tree anestor. :param x: The first tree object. :param y: The second tree object. :param metric: A metric function to compute the similarity between the entity labels of the two trees. :return: A similarity score between 0 and 1, where 1 indicates maximum similarity. >>> from architxt.tree import Tree >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))') >>> similarity(t[0], t[1], metric=jaccard) 0.5555555555555555 """ assert x is not None assert y is not None if x is y or x.label() == y.label(): return 1.0 weight_sum = 0.0 sim_sum = 0.0 distance = 1 while x is not None and y is not None: # Extract the entity labels as sets for faster lookup x_labels = x.entity_labels() y_labels = y.entity_labels() # If no common entity labels, return similarity 0 early if x_labels.isdisjoint(y_labels): return 0.0 # Calculate similarity for current level and accumulate weighted sum weight = 1 / distance weight_sum += weight sim_sum += weight * metric(x_labels, y_labels) # Move to parent nodes x = x.parent() y = y.parent() distance += 1 return min(max(sim_sum / weight_sum, 0), 1) # Need to fix float issues
[docs] def sim(x: Tree, y: Tree, tau: float, metric: METRIC_FUNC = DEFAULT_METRIC) -> bool: """ Determine whether the similarity between two tree objects exceeds a given threshold `tau`. :param x: The first tree object to compare. :param y: The second tree object to compare. :param tau: The threshold value for similarity. :param metric: A callable similarity metric to compute the similarity between the two trees. :return: `True` if the similarity between `x` and `y` is greater than or equal to `tau`, otherwise `False`. >>> from architxt.tree import Tree >>> t = Tree.fromstring('(S (X (ENT::person Alice) (ENT::fruit apple)) (Y (ENT::person Bob) (ENT::animal rabbit)))') >>> sim(t[0], t[1], tau=0.5, metric=jaccard) True """ return similarity(x, y, metric=metric) >= tau
[docs] def compute_dist_matrix(subtrees: Collection[Tree], *, metric: METRIC_FUNC) -> npt.NDArray[np.uint16]: """ Compute the condensed distance matrix for a collection of subtrees. This function computes pairwise distances between all subtrees and stores the results in a condensed distance matrix format (1D array), which is suitable for hierarchical clustering. The computation is sequential. :param subtrees: A list of subtrees for which pairwise distances will be calculated. :param metric: A callable similarity metric to compute the similarity between the two trees. :return: A 1D numpy array containing the condensed distance matrix (only a triangle of the full matrix). """ nb_combinations = math.comb(len(subtrees), 2) distances = ( (1 - similarity(x, y, metric=metric)) if abs(x.height() - y.height()) < 5 else 1.0 for x, y in combinations(subtrees, 2) ) return np.fromiter( tqdm( distances, desc='similarity', total=nb_combinations, leave=False, unit_scale=True, ), count=nb_combinations, dtype=np.float32, )
[docs] def equiv_cluster( trees: Forest, *, tau: float, metric: METRIC_FUNC = DEFAULT_METRIC, _all_subtrees: bool = True, _step: int | None = None, ) -> TREE_CLUSTER: """ Cluster subtrees of a given tree based on their similarity. The clusters are created by applying a distance threshold `tau` to the linkage matrix which is derived from pairwise subtree similarity calculations. Subtrees that are similar enough (based on `tau` and the `metric`) are grouped into clusters. Each cluster is represented as a tuple of subtrees. :param trees: The forest from which to extract and cluster subtrees. :param tau: The similarity threshold for clustering. :param metric: The similarity metric function used to compute the similarity between subtrees. :return: A set of tuples, where each tuple represents a cluster of subtrees that meet the similarity threshold. """ subtrees = ( [ subtree for tree in trees for subtree in tree.subtrees(lambda x: not has_type(x, NodeType.ENT) and not x.has_duplicate_entity()) ] if _all_subtrees else list(trees) ) if len(subtrees) < 2: return set() # Compute distance matrix for all subtrees dist_matrix = compute_dist_matrix(subtrees, metric=metric) # Perform hierarchical clustering based on the distance threshold tau linkage_matrix = hierarchy.linkage(dist_matrix, method='single') clusters = hierarchy.fcluster(linkage_matrix, 1 - tau, criterion='distance') square_dist_matrix = squareform(dist_matrix) if mlflow.active_run() and _step is not None: labels = [st.label() for st in subtrees] fig = ff.create_annotated_heatmap(z=square_dist_matrix, colorscale='Cividis', x=labels, y=labels) mlflow.log_figure(fig, f'similarity/{_step}/heatmap.html') fig = ff.create_dendrogram( linkage_matrix, orientation='left', color_threshold=1 - tau, labels=labels, linkagefun=lambda _: linkage_matrix, ) mlflow.log_figure(fig, f'similarity/{_step}/dendrogram.html') # Group subtrees by cluster ID subtree_clusters = defaultdict(list) for idx, cluster_id in enumerate(clusters): subtree_clusters[cluster_id].append(idx) # Sort clusters based on the center element (the closest subtree to all others) # We determine the center by computing the sum of distances for each subtree to all others in the cluster. # The index of the subtree with the smallest sum of distances is the center. sorted_clusters = set() for cluster_indices in subtree_clusters.values(): sum_distances = np.sum(square_dist_matrix[np.ix_(cluster_indices, cluster_indices)], axis=1) center_index = cluster_indices[np.argmin(sum_distances)] # Sort the cluster based on distance to the center sorted_cluster = sorted(cluster_indices, key=lambda idx: square_dist_matrix[center_index][idx]) # Add the sorted cluster as a tuple to the set (immutable and hashable) sorted_clusters.add(tuple(subtrees[i] for i in sorted_cluster)) return sorted_clusters
[docs] def get_equiv_of( t: Tree, equiv_subtrees: TREE_CLUSTER, *, tau: float, metric: METRIC_FUNC = DEFAULT_METRIC ) -> tuple[Tree, ...]: """ Get the cluster containing the specified tree `t` based on similarity comparisons with the given set of clusters. The clusters are assessed using the provided similarity metric and threshold `tau`. :param t: The tree from which to extract and cluster subtrees. :param equiv_subtrees: The set of equivalent subtrees. :param tau: The similarity threshold for clustering. :param metric: The similarity metric function used to compute the similarity between subtrees. :return: A tuple representing the cluster of subtrees that meet the similarity threshold. """ distance_to_center = {} for cluster in equiv_subtrees: if t in cluster or (cluster_sim := similarity(t, cluster[0], metric=metric)) >= tau: return cluster distance_to_center[cluster] = cluster_sim # Sort equiv subtrees by similarity to the center element (the first one as the cluster are sorted) sorted_equiv_subtrees = sorted(distance_to_center.items(), key=lambda x: x[1], reverse=True) for cluster, _ in sorted_equiv_subtrees: # Early exit: stop checking once we find a matching cluster if t in cluster or any(sim(x, t, tau, metric) for x in cluster): return cluster # Return empty tuple if no similar cluster is found return ()
[docs] def entity_labels(forest: Forest, *, tau: float, metric: METRIC_FUNC | None = DEFAULT_METRIC) -> dict[str, int]: """ Process the given forest to assign labels to entities based on clustering of their ancestor. :param forest: The forest from which to extract and cluster entities. :param tau: The similarity threshold for clustering. :param metric: The similarity metric function used to compute the similarity between subtrees. If None, use the parent label as the equivalent class. :return: A dictionary mapping entities to their respective cluster IDs. """ entity_parents = [ subtree for tree in forest for subtree in tree.subtrees(lambda x: not has_type(x, NodeType.ENT) and x.has_entity_child()) ] equiv_subtrees: Iterable[Iterable[Tree]] if metric is None: equiv_subtrees_map: dict[str, list[Tree]] = defaultdict(list) for subtree in entity_parents: equiv_subtrees_map[subtree.label()].append(subtree) equiv_subtrees = equiv_subtrees_map.values() else: equiv_subtrees = equiv_cluster(entity_parents, tau=tau, metric=metric, _all_subtrees=False) return { f"{child.label().name}${' '.join(child)}": i for i, cluster in enumerate(equiv_subtrees) for subtree in cluster for child in subtree if has_type(child, NodeType.ENT) }