Source code for architxt.database.export.sql

import base64
from collections import defaultdict
from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING

from sqlalchemy import (
    BLOB,
    Column,
    Connection,
    Date,
    DateTime,
    ForeignKey,
    MetaData,
    String,
    Table,
    Uuid,
    insert,
)
from tqdm.auto import tqdm

from architxt.schema import Group, Relation, RelationOrientation, Schema
from architxt.tree import Forest, NodeType, Tree, TreeOID, has_type

if TYPE_CHECKING:
    from architxt.tree import _TypedSubTree, _TypedTree

__all__ = ['export_sql']

PKColumnFactory = Callable[[str], str]


[docs] def default_pk_factory( table_name: str, ) -> str: """ Generate the ID column for the given table. :param table_name: The table name to generate ID for. :return: The name of the ID column for the table. """ return f'architxt_{table_name}ID'
[docs] def export_sql( forest: Forest, conn: Connection, *, pk_factory: PKColumnFactory = default_pk_factory, ) -> None: """ Export the forest to the relational database. :param conn: Connection to the relational database. :param forest: Forest to export. :param pk_factory: A column name factory for the groups primary keys. """ schema = Schema.from_forest(forest, keep_unlabelled=False) create_schema(conn, schema, pk_factory) for tree in tqdm(forest, desc="Exporting relational database"): export_tree(tree, conn, schema, pk_factory) conn.commit()
[docs] def create_schema( conn: Connection, schema: Schema, pk_factory: PKColumnFactory, ) -> Schema: """ Create the schema for the relational database. :param conn: Connection to the graph. :param schema: The schema to build. :param pk_factory: A column name factory for the groups primary keys. """ metadata = MetaData() database_schema: dict[str, Table] = {} for group in schema.groups: database_schema[group.name] = create_table_for_group(group, metadata, pk_factory) for rel in schema.relations: if rel.orientation == RelationOrientation.BOTH: create_table_for_relation(database_schema, rel, metadata, pk_factory) else: add_foreign_keys_to_table(database_schema, rel, pk_factory) metadata.create_all(conn) return schema
[docs] def create_table_for_group(group: Group, metadata: MetaData, pk_factory: PKColumnFactory) -> Table: """ Create a table for the given group. :param group: The group to create a table for. :param metadata: SQLAlchemy metadata to attach the table to. :param pk_factory: A column name factory for the groups primary keys. :return: SQLAlchemy Table object. """ return Table( group.name, metadata, Column(pk_factory(group.name), Uuid, primary_key=True), *(Column(entity, String) for entity in group.entities), )
[docs] def add_foreign_keys_to_table( database_schema: dict[str, Table], relation: Relation, pk_factory: PKColumnFactory, ) -> None: """ Add foreign key constraints to the database schema. :param database_schema: The dictionary of tables in the database schema. :param relation: The relation to build as a foreign key. :param pk_factory: A column name factory for the groups primary keys. """ left = database_schema[relation.left.replace(' ', '')] right = database_schema[relation.right.replace(' ', '')] source, target = (left, right) if relation.orientation == RelationOrientation.LEFT else (right, left) column_name = relation.name if source.name == target.name else pk_factory(target.name) target_column_name = target.primary_key.columns.keys()[0] database_schema[source.name].append_column(Column(column_name, ForeignKey(f"{target.name}.{target_column_name}")))
[docs] def create_table_for_relation( database_schema: dict[str, Table], relation: Relation, metadata: MetaData, pk_factory: PKColumnFactory, ) -> None: """ Create a table for the given relation. :param database_schema: The dictionary of tables in the database schema. :param relation: The relation to build the table for. :param metadata: SQLAlchemy metadata to attach the table to. :param pk_factory: A column name factory for the groups primary keys. """ left = database_schema[relation.left.replace(" ", "")] right = database_schema[relation.right.replace(" ", "")] left_key = pk_factory(left.name) right_key = pk_factory(right.name) database_schema[relation.name] = Table(relation.name, metadata) database_schema[relation.name].append_column( Column(left_key, ForeignKey(f"{left.name}.{left_key}"), primary_key=True) ) database_schema[relation.name].append_column( Column(right_key, ForeignKey(f"{right.name}.{right_key}"), primary_key=True) )
[docs] def export_tree( tree: Tree, conn: Connection, schema: Schema, pk_factory: PKColumnFactory, ) -> None: """ Export the tree to the relational database. :param tree: Tree to export. :param conn: Connection to the relational database. :param schema: The schema. :param pk_factory: A column name factory for the groups primary key. """ data_to_export: dict[str, dict[TreeOID, dict[str, str]]] = {} for subtree in tree.subtrees(): if has_type(subtree, NodeType.GROUP): export_group(subtree, data_to_export, pk_factory) elif has_type(subtree, NodeType.REL): export_relation(subtree, data_to_export, schema, pk_factory) export_data(data_to_export, conn)
[docs] def export_relation( tree: '_TypedTree', data: dict[str, dict[TreeOID, dict[str, str]]], schema: Schema, pk_factory: PKColumnFactory, ) -> None: """ Export the relation to the relational database. :param tree: Relation to export. :param data: Data to export. :param schema: The schema. :param pk_factory: A column name factory for the groups primary key. """ relation = next(rel for rel in schema.relations if rel.name == tree.label.name) if relation.orientation == RelationOrientation.BOTH: relation_data: dict[str, str] = {} for child in tree: column_name = pk_factory(child.label.name) relation_data[column_name] = data[child.label.name][child.oid][column_name] data[relation.name] = {tree.oid: relation_data} else: left: _TypedSubTree | None = None right: _TypedSubTree | None = None for child in tree: if not has_type(child, NodeType.GROUP): continue if child.label.name == relation.left: left = child elif child.label.name == relation.right: right = child if not left or not right: return if relation.orientation == RelationOrientation.RIGHT: left, right = right, left column_name = relation.name if right.label.name == left.label.name else pk_factory(right.label.name) data[left.label.name][left.oid][column_name] = data[right.label.name][right.oid][column_name]
[docs] def export_group( group: '_TypedTree', data: dict[str, dict[TreeOID, dict[str, str]]], pk_factory: PKColumnFactory, ) -> None: """ Export the group to the relational database. :param group: Group to export. :param data: Data to export. :param pk_factory: A column name factory for the groups primary key. """ group_name = group.label.name group_data = get_data_from_group(group) group_data[pk_factory(group_name)] = str(group.oid) if group_name not in data: data[group_name] = {} data[group_name][group.oid] = group_data
[docs] def get_data_from_group(group: Tree) -> dict[str, str]: """ Get data from the relational database. :param group: Group to get data from. :return: Data from the group. """ result: dict[str, str] = {} for entity in group: if entity.label.name is None or 'type' not in entity.metadata: continue if entity.metadata and isinstance(entity.metadata['type'], Date) and isinstance(entity[0], str): entity[0] = datetime.strptime(entity[0], '%Y-%m-%d').date() elif entity.metadata and isinstance(entity.metadata['type'], DateTime) and isinstance(entity[0], str): entity[0] = datetime.strptime(entity[0], '%Y-%m-%d %H:%M:%S') elif entity.metadata and isinstance(entity.metadata['type'], BLOB) and isinstance(entity[0], str): entity[0] = base64.b64decode(entity[0]) result[entity.label.name] = entity[0] return result
[docs] def export_data( data: dict[str, dict[TreeOID, dict[str, str]]], conn: Connection, ) -> None: """ Export the data to the relational database. :param data: Data to export. :param conn: Connection to the relational database. """ if not data: return data_to_export: dict[str, dict[TreeOID, dict[str, str]]] = defaultdict(dict) table_to_insert: dict[str, list[dict[str, str]]] = defaultdict(list) for table, dict_info in data.items(): for oid, info in dict_info.items(): has_foreign_key = False for name, x in info.items(): if isinstance(x, dict) and 'primary_key_insert' not in x: has_foreign_key = True elif isinstance(x, dict): data[table][oid][name] = x['primary_key_insert'] if has_foreign_key: data_to_export[table][oid] = info else: table_to_insert[table].append(info) export_table_to_insert(table_to_insert, conn) export_data(data_to_export, conn)
[docs] def export_table_to_insert( table_to_insert: dict[str, list[dict[str, str]]], conn: Connection, ) -> None: """ Export the table to the graph. :param table_to_insert: Tables to insert. :param conn: Connection to the graph. """ for table in table_to_insert: for row in table_to_insert[table]: info = row database_table = Table(table, MetaData(), autoload_with=conn) primary_keys = [col.name for col in database_table.primary_key.columns] query = ( database_table.select() .with_only_columns(*[getattr(database_table.c, key) for key in primary_keys]) .where(*[getattr(database_table.c, key) == value for key, value in info.items() if key in primary_keys]) ) result = conn.execute(query).fetchone() if not result: insert_command = insert(database_table).values(info) result_insert = conn.execute(insert_command) inserted_id = result_insert.inserted_primary_key[0] else: inserted_id = result[0] if inserted_id: info['primary_key_insert'] = inserted_id