import dataclasses
import math
import warnings
from collections import Counter, defaultdict
from collections.abc import Generator, Iterable
from enum import Enum, auto
from functools import cached_property
from itertools import combinations
import pandas as pd
from antlr4 import CommonTokenStream, InputStream
from antlr4.error.Errors import CancellationException
from antlr4.error.ErrorStrategy import BailErrorStrategy
from nltk import CFG, Nonterminal, Production
from architxt.grammar.metagrammarLexer import metagrammarLexer
from architxt.grammar.metagrammarParser import metagrammarParser
from architxt.similarity import jaccard
from architxt.tree import Forest, NodeLabel, NodeType, Tree, TreeOID, has_type
__all__ = [
    'Group',
    'Relation',
    'RelationOrientation',
    'Schema',
]

# Ordering rank of each node type, read by ``Schema._get_rank`` to sort productions
# (lower ranks sort first: collections, then relations, groups and finally entities).
# Ranks start at 1; ``Schema._get_rank`` returns 0 for anything not listed here.
_NODE_TYPE_RANK = {
    node_type: rank
    for rank, node_type in enumerate((NodeType.COLL, NodeType.REL, NodeType.GROUP, NodeType.ENT), start=1)
}
[docs]
@dataclasses.dataclass(slots=True, frozen=True)
class Group:
name: str
entities: set[str]
def __hash__(self) -> int:
return hash(self.name)
[docs]
class RelationOrientation(Enum):
    """
    Direction of a relationship between two groups.

    Indicates which side of a relation, if any, is its source (i.e. the orientation of its
    cardinality).
    """

    #: The source of the relationship is the left group.
    LEFT = auto()
    #: The source of the relationship is the right group.
    RIGHT = auto()
    #: The relationship is bidirectional or many-to-many, with no single source.
    BOTH = auto()
[docs]
@dataclasses.dataclass(slots=True, frozen=True)
class Relation:
    """
    A named binary relation between two groups.

    Equality compares every field including ``orientation``; the hash only covers the name and
    the two group names.
    """

    #: Name of the relation.
    name: str
    #: Name of the left group.
    left: str
    #: Name of the right group.
    right: str
    #: Which side, if any, is the source of the relation.
    orientation: RelationOrientation = RelationOrientation.BOTH

    def __hash__(self) -> int:
        identity = (self.name, self.left, self.right)
        return hash(identity)
[docs]
class Schema(CFG):
_groups: set[Group]
_relations: set[Relation]
def __init__(self, productions: Iterable[Production], groups: set[Group], relations: set[Relation]) -> None:
    """
    Build the schema grammar.

    The productions are stably sorted by the node-type rank of their left-hand side, and a
    synthetic ``ROOT`` start production is prepended. Its right-hand side holds the sorted
    left-hand side of every production (one entry per production, so a left-hand side with
    several alternatives appears several times).

    :param productions: The schema productions (iterated once).
    :param groups: The groups described by the schema (stored as-is, not copied).
    :param relations: The relations described by the schema (stored as-is, not copied).
    """
    ordered = sorted(productions, key=lambda production: Schema._get_rank(production.lhs()))
    root_rhs = sorted(production.lhs() for production in ordered)
    root_rule = Production(Nonterminal('ROOT'), root_rhs)

    super().__init__(Nonterminal('ROOT'), [root_rule, *ordered])

    self._groups = groups
    self._relations = relations
@staticmethod
def _get_rank(nt: Nonterminal) -> int:
    """
    Return the ordering rank of a nonterminal.

    :param nt: The nonterminal to rank.
    :return: The rank from ``_NODE_TYPE_RANK`` for the node type of a ``NodeLabel`` symbol, or 0
        when the symbol is not a ``NodeLabel`` (e.g. the plain ``'ROOT'`` string) or its type is unranked.
    """
    label = nt.symbol()
    if not isinstance(label, NodeLabel):
        return 0
    return _NODE_TYPE_RANK.get(label.type, 0)
[docs]
@classmethod
def from_description(
    cls,
    *,
    groups: set[Group] | None = None,
    relations: set[Relation] | None = None,
    collections: bool = True,
) -> 'Schema':
    """
    Create a Schema from a description of groups, relations, and collections.

    Each group becomes a production from its group label to its (sorted) entity labels. Each
    relation becomes a production from its relation label to its left and right group labels, in
    that order. When ``collections`` is true, every one of these productions (groups *and*
    relations) also gets a collection production of the same name whose only child is the
    production's left-hand side.

    :param groups: The groups of the schema (a set of :class:`Group`); each one yields a group production.
    :param relations: The relations of the schema (a set of :class:`Relation`); each one yields a
        relation production between its two groups.
    :param collections: Whether to generate collection productions.
    :return: A Schema object.
    """
    productions: set[Production] = set()

    for group in groups or ():
        group_label = NodeLabel(NodeType.GROUP, group.name)
        entity_labels = sorted(Nonterminal(NodeLabel(NodeType.ENT, entity)) for entity in group.entities)
        productions.add(Production(Nonterminal(group_label), entity_labels))

    for relation in relations or ():
        relation_label = NodeLabel(NodeType.REL, relation.name)
        group_labels = [
            Nonterminal(NodeLabel(NodeType.GROUP, relation.left)),
            Nonterminal(NodeLabel(NodeType.GROUP, relation.right)),
        ]
        productions.add(Production(Nonterminal(relation_label), group_labels))

    if collections:
        # The comprehension is fully built before ``update`` runs, so the set is never mutated
        # while it is being iterated.
        productions.update({
            Production(Nonterminal(NodeLabel(NodeType.COLL, prod.lhs().symbol().name)), [prod.lhs()])
            for prod in productions
        })

    return cls(productions, groups or set(), relations or set())
[docs]
@classmethod
def from_forest(cls, forest: Iterable[Tree], *, keep_unlabelled: bool = True, merge_lhs: bool = True) -> 'Schema':  # noqa: C901
    """
    Create a Schema from a given forest of trees.

    Productions are collected from every tree; lexical productions and the ``ROOT`` production are skipped:

    - a collection production keeps only its first child, and the last one seen wins;
    - a binary relation production is stored with its children sorted;
    - group productions are either merged into a single production per group (``merge_lhs=True``)
      or kept as separate alternatives;
    - any other production is kept only when ``keep_unlabelled`` is true.

    Each group appears exactly once in the resulting schema, with the entities of all its
    (merged or alternative) productions.

    Relation orientations are inferred from the binary relation subtrees: when one instance on a
    side of a relation is linked to more than one distinct counterpart, that side is flagged as
    "multi" (see ``_convert_relations``).

    :param forest: The input forest from which to derive the schema.
    :param keep_unlabelled: Whether to keep uncategorized nodes in the schema.
    :param merge_lhs: Whether to merge nodes in the schema.
    :return: A CFG-based schema representation.
    """
    schema_productions: dict[Nonterminal, set[tuple[Nonterminal, ...]]] = defaultdict(set)

    # Left-hand sides of group productions. Group objects are built once, after the loop: Group
    # hashes by name but compares on (name, entities), so adding a new Group each time a group's
    # entity set grows would leave stale duplicates of the same group in the set.
    group_nonterminals: set[Nonterminal] = set()

    # relation name -> side label name -> child oid -> first (sorted) pair of oids seen for that child
    relations_examples: dict[str, dict[str, dict[TreeOID, tuple[TreeOID, TreeOID]]]] = defaultdict(
        lambda: defaultdict(dict)
    )
    # relation name -> side label name -> whether one instance of that side has several counterparts
    relations_is_multi: dict[str, dict[str, bool]] = defaultdict(lambda: defaultdict(lambda: False))

    for tree in forest:
        for prod in tree.productions():
            if prod.is_lexical() or prod.lhs().symbol() == 'ROOT':
                continue

            if has_type(prod, NodeType.COLL):
                schema_productions[prod.lhs()] = {(prod.rhs()[0],)}

            elif has_type(prod, NodeType.REL) and len(prod) == 2:
                schema_productions[prod.lhs()].add(tuple(sorted(prod.rhs())))

            elif has_type(prod, NodeType.GROUP):
                if merge_lhs:
                    merged_rhs = set(prod.rhs()).union(*schema_productions[prod.lhs()])
                    schema_productions[prod.lhs()] = {tuple(sorted(merged_rhs))}
                else:
                    schema_productions[prod.lhs()].add(tuple(sorted(set(prod.rhs()))))
                group_nonterminals.add(prod.lhs())

            elif keep_unlabelled:
                schema_productions[prod.lhs()].add(tuple(sorted(set(prod.rhs()))))

        for subtree in tree.subtrees(lambda x: has_type(x, NodeType.REL) and len(x) == 2):
            relation_name = subtree.label.name
            pair = tuple(sorted((subtree[0].oid, subtree[1].oid)))
            examples = relations_examples[relation_name]
            flags = relations_is_multi[relation_name]

            for child in subtree:
                side = child.label.name
                # Register the side even if it never turns out to be "multi".
                flags.setdefault(side, False)

                existing = examples[side].get(child.oid)
                if existing is None:
                    examples[side][child.oid] = pair
                elif existing != pair:
                    flags[side] = True

    # Only the multiplicity flags are needed from here on; release the per-instance examples.
    del relations_examples

    groups = {
        Group(
            name=lhs.symbol().name,
            entities={ent.symbol().name for rhs in schema_productions[lhs] for ent in rhs},
        )
        for lhs in group_nonterminals
    }
    productions = (Production(lhs, rhs) for lhs, alternatives in schema_productions.items() for rhs in alternatives)
    relations = cls._convert_relations(relations_is_multi)

    return cls(productions, groups, relations)
@cached_property
def entities(self) -> set[str]:
    """
    Names of all entities that belong to at least one group of the schema.

    Computed on first access, then cached on the instance.
    """
    return set().union(*(group.entities for group in self.groups))
@property
def groups(self) -> set[Group]:
    """Groups of the schema (the very set given at construction, not a copy)."""
    return self._groups
@property
def relations(self) -> set[Relation]:
    """Relations of the schema (the very set given at construction, not a copy)."""
    return self._relations
@staticmethod
def _convert_relations(
    relations_flags: dict[str, dict[str, bool]],
) -> set[Relation]:
    """
    Turn per-relation multiplicity flags into :class:`Relation` objects.

    Relations that do not have exactly two sides are dropped. The left and right sides follow the
    insertion order of the inner mapping. The orientation is ``LEFT`` when only the left side is
    flagged, ``RIGHT`` when only the right side is, and ``BOTH`` when both or neither are.

    :param relations_flags: A dict mapping relation-name -> { side-name: is_multi_flag, ... }
    :return: A set of relations.
    """
    converted: set[Relation] = set()

    for name, flags in relations_flags.items():
        if len(flags) != 2:
            continue

        (left, left_multi), (right, right_multi) = flags.items()
        if left_multi == right_multi:
            orientation = RelationOrientation.BOTH
        else:
            orientation = RelationOrientation.LEFT if left_multi else RelationOrientation.RIGHT

        converted.add(Relation(name=name, left=left, right=right, orientation=orientation))

    return converted
[docs]
def verify(self) -> bool:
    """
    Check the schema's CFG text (see ``as_cfg``) against the meta-grammar.

    The parser uses ``BailErrorStrategy``, so it stops at the first syntax error. A cancelled
    parse, or any other exception raised while parsing, emits a warning and yields False.

    :returns: True if the schema is valid, False otherwise.
    """
    token_stream = CommonTokenStream(metagrammarLexer(InputStream(self.as_cfg())))
    parser = metagrammarParser(token_stream)
    parser._errHandler = BailErrorStrategy()

    try:
        parser.start()
        is_valid = parser.getNumberOfSyntaxErrors() == 0
    except CancellationException:
        warnings.warn("Invalid syntax")
        is_valid = False
    except Exception as error:
        warnings.warn(f"Verification failed: {error!s}")
        is_valid = False

    return is_valid
@property
def group_overlap(self) -> float:
    """
    Get the group overlap ratio as a combined Jaccard index.

    The ratio is the arithmetic mean of the Jaccard index of the entity sets of every pair of
    groups. Schemas with fewer than two groups have no pair and score 0.

    :return: The group overlap ratio as a float value between 0 and 1.
        A higher value indicates a higher degree of overlap between groups.
    """
    group_pairs = list(combinations(self.groups, 2))
    if not group_pairs:
        return 0.0

    total = sum(jaccard(first.entities, second.entities) for first, second in group_pairs)
    return total / len(group_pairs)
@property
def group_balance_score(self) -> float:
    r"""
    Get the balance score of attributes across groups.

    The balance metric (B) is one minus the coefficient of variation of the number of attributes
    per group, and tells whether the schema is well-balanced. A higher value means attributes are
    spread more evenly across groups. A lower value suggests some groups are too large (wide) or
    too small (fragmented).

    .. math::
        B = 1 - \frac{\sigma(A)}{\mu(A)}

    Where:
        - :math:`A`: The attribute counts of all groups.
        - :math:`\mu(A)`: The mean number of attributes per group.
        - :math:`\sigma(A)`: The population standard deviation of the attribute counts.

    A schema without groups scores 1; a schema whose groups are all empty scores 0.

    :return: Balance metric (B), a measure of attribute dispersion.
        - :math:`B \approx 1`: Attributes are evenly distributed.
        - :math:`B \approx 0`: Significant imbalance; some groups are much larger or smaller than others.
    """
    sizes = [len(group.entities) for group in self.groups]
    if not sizes:
        return 1.0

    n_groups = len(sizes)
    mean_size = sum(sizes) / n_groups
    if not mean_size:
        # All groups are empty: the coefficient of variation is undefined and treated as 1.
        return 1 - 1.0

    variance = sum((size - mean_size) ** 2 for size in sizes) / n_groups
    return 1 - math.sqrt(variance) / mean_size
[docs]
def as_cfg(self) -> str:
    """
    Render the schema as CFG text.

    :returns: One production rule per line, in ``self.productions()`` order, each followed by a semicolon.
    """
    rules = map('{};'.format, self.productions())
    return '\n'.join(rules)
[docs]
def find_collapsible_groups(self) -> set[str]:
    """
    Identify all groups eligible for collapsing into attributed relationships.

    A group M is collapsible when it takes part in exactly two relations and, in each of them,
    the orientation designates the *other* group. That is, M is the right side of a ``LEFT``
    relation or the left side of a ``RIGHT`` relation. Any relation where M itself is the
    designated side, or any ``BOTH`` relation involving M, disqualifies it. This targets
    patterns like:

        A --(n-1)--> M <--(1-n)-- B

    which can be turned into a direct n-n edge:

        A --[attributed edge]-- B

    :return: A set of groups that can be turned into attributed edges.

    >>> schema = Schema.from_description(relations={
    ...     Relation(name='R1', left='A', right='M', orientation=RelationOrientation.LEFT),
    ...     Relation(name='R2', left='M', right='B', orientation=RelationOrientation.RIGHT),
    ... })
    >>> schema.find_collapsible_groups()
    {'M'}

    >>> schema = Schema.from_description(relations={
    ...     Relation(name='R1', left='M', right='B', orientation=RelationOrientation.RIGHT),
    ...     Relation(name='R2', left='M', right='C', orientation=RelationOrientation.RIGHT),
    ... })
    >>> schema.find_collapsible_groups()
    {'M'}

    >>> schema = Schema.from_description(relations={
    ...     Relation(name='R1', left='A', right='M', orientation=RelationOrientation.BOTH),
    ...     Relation(name='R2', left='M', right='B', orientation=RelationOrientation.RIGHT),
    ... })
    >>> schema.find_collapsible_groups()
    set()

    >>> schema = Schema.from_description(relations={
    ...     Relation(name='R1', left='A', right='M', orientation=RelationOrientation.LEFT),
    ...     Relation(name='R2', left='M', right='B', orientation=RelationOrientation.RIGHT),
    ...     Relation(name='R2', left='M', right='C', orientation=RelationOrientation.RIGHT),
    ... })
    >>> schema.find_collapsible_groups()
    set()
    """
    # (left weight, right weight) per orientation. A qualifying participation adds 1.
    # Any other participation adds 3. Weights never decrease a total, so once a total reaches 3
    # it can no longer equal exactly 2.
    weights_by_orientation = {
        RelationOrientation.LEFT: (3, 1),
        RelationOrientation.RIGHT: (1, 3),
    }

    participation = Counter()
    for relation in self.relations:
        left_weight, right_weight = weights_by_orientation.get(relation.orientation, (3, 3))
        participation[relation.left] += left_weight
        participation[relation.right] += right_weight

    return {name for name, total in participation.items() if total == 2}