Added function to compile queries from matchbox #40

Merged 3 commits on Jan 10, 2025
8 changes: 4 additions & 4 deletions .github/workflows/ruff.yml
@@ -12,14 +12,14 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
-      - uses: chartboost/ruff-action@v1
+      - uses: astral-sh/ruff-action@v3
         with:
-          args: format --check
+          args: 'format --check'
 
   ruff-lint:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
-      - uses: chartboost/ruff-action@v1
+      - uses: astral-sh/ruff-action@v3
         with:
-          args: check
+          args: 'check'
1 change: 1 addition & 0 deletions pyproject.toml
@@ -51,6 +51,7 @@ typing = [
 [tool.uv]
 default-groups = ["dev", "typing"]
 package = true
+upgrade-package = ["ruff"]
 
 [tool.ruff]
 exclude = [
8 changes: 2 additions & 6 deletions src/matchbox/client/clean/steps/clean_basic.py
@@ -56,11 +56,7 @@ def clean_punctuation(column: str) -> str:
     return rf"""
     trim(
         regexp_replace(
-            lower({
-                punctuation_to_spaces(
-                    periods_to_nothing(column)
-                )
-            }),
+            lower({punctuation_to_spaces(periods_to_nothing(column))}),
             '\s+',
             ' ',
             'g'
@@ -218,7 +214,7 @@ def regex_remove_list_of_strings(column: str, list_of_strings: List[str]) -> str
             '',
             'g'
         ),
-        '\s{2,}',
+        '\s{(2,)}',
         ' ',
         'g'
     )
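For context on the hunk above: these cleaning functions compose SQL snippets as strings, each helper wrapping the previous expression in another function call. A minimal sketch of that pattern follows — the periods_to_nothing and punctuation_to_spaces bodies here are illustrative stand-ins, not matchbox's actual implementations:

def periods_to_nothing(column: str) -> str:
    # Stand-in: strip periods so "Ltd." and "Ltd" normalise together
    return f"regexp_replace({column}, '[.]', '', 'g')"


def punctuation_to_spaces(column: str) -> str:
    # Stand-in: replace remaining punctuation with spaces
    return f"regexp_replace({column}, '[^\\w\\s]', ' ', 'g')"


# Each call nests the previous SQL expression one level deeper
print(punctuation_to_spaces(periods_to_nothing("company_name")))
# regexp_replace(regexp_replace(company_name, '[.]', '', 'g'), '[^\w\s]', ' ', 'g')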
2 changes: 1 addition & 1 deletion src/matchbox/client/helpers/comparison.py
@@ -18,7 +18,7 @@ def comparison(sql_condition: str, dialect: str = "postgres") -> str:
             node[0], (exp.Connector, exp.Predicate, exp.Condition, exp.Identifier)
         ):
             raise ParseError(
-                "Must be valid WHERE clause statements. " f"Found {type(node[0])}"
+                f"Must be valid WHERE clause statements. Found {type(node[0])}"
             )
 
     left = False
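Given the signature in the hunk header, a usage sketch of this validator — assuming, as the diff suggests, that invalid input surfaces as sqlglot's ParseError:

from sqlglot.errors import ParseError

from matchbox.client.helpers.comparison import comparison

# A genuine join predicate passes validation
valid = comparison("l.company_name = r.company_name", dialect="postgres")

# Anything that isn't a WHERE-style condition should be rejected
try:
    comparison("select * from companies", dialect="postgres")
except ParseError as err:
    print(f"Rejected: {err}")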
2 changes: 1 addition & 1 deletion src/matchbox/client/results.py
@@ -97,7 +97,7 @@ def check_probabilities(cls, value: pa.Table | DataFrame) -> pa.Table:
         optional_fields = {"id"}
 
         if table_fields - optional_fields != expected_fields:
-            raise ValueError(f"Expected {expected_fields}. \n" f"Found {table_fields}.")
+            raise ValueError(f"Expected {expected_fields}. \nFound {table_fields}.")
 
         # If a process produces floats, we multiply by 100 and coerce to uint8
         if pa.types.is_floating(value["probability"].type):
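The comment in this hunk describes how matchbox stores probabilities: floats (presumably in [0, 1]) are rescaled to 0-100 and held as uint8. A standalone sketch of that coercion with pyarrow — the real validator may round or truncate differently:

import pyarrow as pa
import pyarrow.compute as pc

probabilities = pa.array([0.0, 0.5, 0.975, 1.0], type=pa.float64())

# Rescale to 0-100, then coerce to uint8 as the comment above describes
coerced = pc.round(pc.multiply(probabilities, 100)).cast(pa.uint8())
print(coerced)  # [0, 50, 98, 100]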
55 changes: 55 additions & 0 deletions src/matchbox/server/postgresql/benchmark/query.py
@@ -0,0 +1,55 @@
+from sqlalchemy.orm import Session
+
+from matchbox.common.db import get_schema_table_names
+from matchbox.server.postgresql.db import MBDB
+from matchbox.server.postgresql.orm import (
+    Resolutions,
+    Sources,
+)
+from matchbox.server.postgresql.utils.query import (
+    _resolve_cluster_hierarchy,
+)
+
+
+def compile_query_sql(point_of_truth: str, dataset_name: str) -> str:
+    """Compiles the SQL for query() based on a single point of truth and dataset.
+
+    Args:
+        point_of_truth (string): The name of the resolution to use, like "linker_1"
+        dataset_name (string): The name of the dataset to retrieve, like "dbt.companies"
+
+    Returns:
+        A compiled PostgreSQL query, including semicolon, ready to run on Matchbox
+    """
+    engine = MBDB.get_engine()
+
+    source_schema, source_table = get_schema_table_names(dataset_name)
+
+    with Session(engine) as session:
+        point_of_truth_resolution = (
+            session.query(Resolutions)
+            .filter(Resolutions.name == point_of_truth)
+            .first()
+        )
+        dataset_id = (
+            session.query(Resolutions.resolution_id)
+            .join(Sources, Sources.resolution_id == Resolutions.resolution_id)
+            .filter(
+                Sources.schema == source_schema,
+                Sources.table == source_table,
+            )
+            .scalar()
+        )
+
+    id_query = _resolve_cluster_hierarchy(
+        dataset_id=dataset_id,
+        resolution=point_of_truth_resolution,
+        threshold=None,
+        engine=engine,
+    )
+
+    compiled_stmt = id_query.compile(
+        dialect=engine.dialect, compile_kwargs={"literal_binds": True}
+    )
+
+    return str(compiled_stmt) + ";"
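The new module's entry point is compile_query_sql. A minimal usage sketch, borrowing the docstring's illustrative names — "linker_1" and "dbt.companies" are placeholders, not objects that exist in a real deployment:

from matchbox.server.postgresql.benchmark.query import compile_query_sql

# Placeholder names taken from the docstring's examples
sql = compile_query_sql(
    point_of_truth="linker_1",
    dataset_name="dbt.companies",
)

# The result is a literal-bound PostgreSQL statement ending in ";",
# suitable for EXPLAIN ANALYZE or direct execution against Matchbox
print(sql)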
3 changes: 1 addition & 2 deletions src/matchbox/server/postgresql/utils/insert.py
@@ -524,8 +524,7 @@ def insert_results(
     with engine.connect() as conn:
         try:
             logic_logger.info(
-                f"[{resolution.name}] Inserting {clusters.shape[0]:,} results "
-                "objects"
+                f"[{resolution.name}] Inserting {clusters.shape[0]:,} results objects"
             )
 
             batch_ingest(
2 changes: 1 addition & 1 deletion test/client/test_helpers.py
@@ -100,7 +100,7 @@ def test_process(
 def test_comparisons():
     comparison_name_id = comparison(
         sql_condition=(
-            "l.company_name = r.company_name" " and l.data_hub_id = r.data_hub_id"
+            "l.company_name = r.company_name and l.data_hub_id = r.data_hub_id"
        )
     )
 
2 changes: 1 addition & 1 deletion test/fixtures/db.py
@@ -80,7 +80,7 @@ def _db_add_dedupe_models_and_data(
         model = make_model(
             model_name=deduper_name,
             description=(
-                f"Dedupe of {fx_data.source} " f"with {fx_deduper.name} method."
+                f"Dedupe of {fx_data.source} with {fx_deduper.name} method."
             ),
             model_class=fx_deduper.cls,
             model_settings=deduper_settings,
70 changes: 69 additions & 1 deletion test/server/test_postgresql.py
@@ -1,14 +1,20 @@
-from typing import Iterable
+from typing import Any, Iterable
 
+import pandas as pd
 import pyarrow as pa
 import pytest
+from sqlalchemy import text
 
 from matchbox.common.db import Source
 from matchbox.server.postgresql import MatchboxPostgres
 from matchbox.server.postgresql.benchmark.generate_tables import generate_all_tables
 from matchbox.server.postgresql.benchmark.init_schema import create_tables, empty_schema
+from matchbox.server.postgresql.benchmark.query import compile_query_sql
 from matchbox.server.postgresql.db import MBDB
 from matchbox.server.postgresql.utils.insert import HashIDMap
 
+from ..fixtures.db import SetupDatabaseCallable
+
 
 def test_benchmark_init_schema():
     schema = MBDB.MatchboxBase.metadata.schema
@@ -96,3 +102,65 @@ def test_hash_id_map():
     with pytest.raises(ValueError) as exc_info:
         hash_map.get_hashes(pa.array([999], type=pa.uint64()))
     assert "not found in lookup table" in str(exc_info.value)
+
+
+@pytest.mark.parametrize(
+    ("parameters"),
+    [
+        # Test case 1: CDMS/CRN linker, CRN dataset
+        {
+            "point_of_truth": "deterministic_naive_test.cdms_naive_test.crn",
+            "source_index": 0,  # CRN
+            "unique_ids": 1_000,
+            "unique_pks": 3_000,
+        },
+        # Test case 2: CDMS/CRN linker, CDMS dataset
+        {
+            "point_of_truth": "deterministic_naive_test.cdms_naive_test.crn",
+            "source_index": 2,  # CDMS
+            "unique_ids": 1_000,
+            "unique_pks": 2_000,
+        },
+        # Test case 3: CRN/DUNS linker, CRN dataset
+        {
+            "point_of_truth": "deterministic_naive_test.crn_naive_test.duns",
+            "source_index": 0,  # CRN
+            "unique_ids": 1_000,
+            "unique_pks": 3_000,
+        },
+        # Test case 4: CRN/DUNS linker, DUNS dataset
+        {
+            "point_of_truth": "deterministic_naive_test.crn_naive_test.duns",
+            "source_index": 1,  # DUNS
+            "unique_ids": 500,
+            "unique_pks": 500,
+        },
+    ],
+    ids=["cdms-crn_crn", "cdms-crn_cdms", "crn-duns_crn", "crn-duns_duns"],
+)
+def test_benchmark_query_generation(
+    setup_database: SetupDatabaseCallable,
+    matchbox_postgres: MatchboxPostgres,
+    warehouse_data: list[Source],
+    parameters: dict[str, Any],
+):
+    setup_database(matchbox_postgres, warehouse_data, "link")
+
+    engine = MBDB.get_engine()
+    point_of_truth = parameters["point_of_truth"]
+    idx = parameters["source_index"]
+    dataset_name = f"{warehouse_data[idx].db_schema}.{warehouse_data[idx].db_table}"
+
+    sql_query = compile_query_sql(
+        point_of_truth=point_of_truth, dataset_name=dataset_name
+    )
+
+    assert isinstance(sql_query, str)
+
+    with engine.connect() as conn:
+        res = conn.execute(text(sql_query)).all()
+
+    df = pd.DataFrame(res, columns=["id", "pk"])
+
+    assert df.id.nunique() == parameters["unique_ids"]
+    assert df.pk.nunique() == parameters["unique_pks"]
42 changes: 21 additions & 21 deletions uv.lock

Some generated files are not rendered by default.