Fixes unit tests in the new ingest pipeline #28

Merged · 4 commits · Dec 17, 2024
.github/workflows/pytest.yml (3 additions, 1 deletion)

@@ -2,7 +2,9 @@ name: Unit tests

 on:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'feature/new-ingest-process'
   workflow_dispatch:

 jobs:
.github/workflows/ruff.yml (3 additions, 1 deletion)

@@ -2,7 +2,9 @@ name: Ruff

 on:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'feature/new-ingest-process'
   workflow_dispatch:

 jobs:
src/matchbox/common/results.py (15 additions, 4 deletions)

@@ -4,7 +4,7 @@
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
 from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Generic, Hashable, Iterator, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, Iterator, TypeVar

 import numpy as np
 import pyarrow as pa

@@ -556,14 +556,20 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa.Table:


 def to_hierarchical_clusters(
-    probabilities: pa.Table, dtype: pa.DataType = pa.int32
+    probabilities: pa.Table,
+    proc_func: Callable[[pa.Table, pa.DataType], pa.Table] = component_to_hierarchy,
+    dtype: pa.DataType = pa.int32,
+    timeout: int = 300,
 ) -> pa.Table:
     """
     Converts a table of pairwise probabilities into a table of hierarchical clusters.

     Args:
         probabilities: Arrow table with columns ['component', 'left', 'right',
             'probability']
+        proc_func: Function to process each component
+        dtype: Arrow data type for parent/child columns
+        timeout: Maximum seconds to wait for each component to process

     Returns:
         Arrow table with columns ['parent', 'child', 'probability']

@@ -598,14 +604,19 @@ def to_hierarchical_clusters(
     results = []
     with ProcessPoolExecutor(max_workers=n_cores) as executor:
         futures = [
-            executor.submit(component_to_hierarchy, component_table, dtype)
+            executor.submit(proc_func, component_table, dtype)
             for component_table in component_tables
         ]

         for future in futures:
             try:
-                result = future.result()
+                result = future.result(timeout=timeout)
                 results.append(result)
+            except TimeoutError:
+                logic_logger.error(
+                    f"Component processing timed out after {timeout} seconds"
+                )
+                continue
             except Exception as e:
                 logic_logger.error(f"Error processing component: {str(e)}")
                 continue
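The change above does two things: the per-component worker is now injectable via proc_func (defaulting to component_to_hierarchy), and each future gets a bounded wait via future.result(timeout=...). For context, a minimal sketch of that bounded-wait pattern; slow_component, run_all, and the timings are illustrative stand-ins, not matchbox code:

import logging
import time
from concurrent.futures import ProcessPoolExecutor, TimeoutError

logger = logging.getLogger(__name__)


def slow_component(seconds: float) -> float:
    """Stand-in for the real per-component worker: sleeps, then returns its input."""
    time.sleep(seconds)
    return seconds


def run_all(durations: list[float], timeout: float = 2.0) -> list[float]:
    """Collect results, logging and skipping any task that exceeds the timeout."""
    results = []
    with ProcessPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(slow_component, d) for d in durations]
        for future in futures:
            try:
                # Wait at most `timeout` seconds for this task's result.
                results.append(future.result(timeout=timeout))
            except TimeoutError:
                logger.error("Task timed out after %s seconds", timeout)
            except Exception as exc:
                logger.error("Error processing task: %s", exc)
    return results


if __name__ == "__main__":
    # The 10-second task exceeds the 2-second wait and is not collected.
    print(run_all([0.5, 10.0]))

One caveat the sketch makes visible: Future.result(timeout=...) only bounds how long the caller waits. The worker process itself is not cancelled, so the pool's shutdown at the end of the with block still waits for a stray task to finish.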
test/client/test_hierarchy.py (60 additions, 24 deletions)

@@ -1,6 +1,8 @@
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
 from functools import lru_cache
 from itertools import chain
-from typing import Any
+from typing import Any, Iterator
 from unittest.mock import patch

 import pyarrow as pa

@@ -31,6 +33,23 @@ def _combine_strings(*n: str) -> str:
     return "".join(sorted(letters))


+@contextmanager
+def parallel_pool_for_tests(
+    max_workers: int = 2, timeout: int = 30
+) -> Iterator[ThreadPoolExecutor]:
+    """Context manager for safe parallel execution in tests using threads.
+
+    Args:
+        max_workers: Maximum number of worker threads
+        timeout: Maximum seconds to wait for each task
+    """
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        try:
+            yield executor
+        finally:
+            executor.shutdown(wait=False, cancel_futures=True)
+
+
 @pytest.mark.parametrize(
     ("parameters"),
     [

@@ -241,11 +260,16 @@ def test_component_to_hierarchy(
                 "probability": [90, 85, 80],
             },
             {
-                ("abc", "a", 90),
-                ("abc", "b", 90),
-                ("abc", "c", 85),
+                ("ab", "a", 90),
+                ("ab", "b", 90),
+                ("bc", "b", 85),
+                ("bc", "c", 85),
+                ("abc", "ab", 85),
+                ("abc", "bc", 85),
                 ("cd", "c", 80),
                 ("cd", "d", 80),
                 ("abcd", "abc", 80),
+                ("abcd", "d", 80),
+                ("abcd", "cd", 80),
             },
         ),
         # Multiple components test case

@@ -257,12 +281,18 @@ def test_component_to_hierarchy(
                 "left": ["a", "b", "x", "y"],
                 "probability": [90, 85, 95, 92],
             },
             {
-                ("abc", "a", 90),
-                ("abc", "b", 90),
-                ("abc", "c", 85),
-                ("xyz", "x", 95),
-                ("xyz", "y", 95),
-                ("xyz", "z", 92),
+                ("xy", "x", 95),
+                ("xy", "y", 95),
+                ("yz", "y", 92),
+                ("yz", "z", 92),
+                ("xyz", "xy", 92),
+                ("xyz", "yz", 92),
+                ("ab", "a", 90),
+                ("ab", "b", 90),
+                ("bc", "b", 85),
+                ("bc", "c", 85),
+                ("abc", "ab", 85),
+                ("abc", "bc", 85),
             },
         ),
@@ -304,19 +334,25 @@ def test_hierarchical_clusters(input_data, expected_hierarchy):
         )
     )

-    with patch(
-        "matchbox.common.results.combine_integers", side_effect=_combine_strings
-    ):
-        result = to_hierarchical_clusters(probabilities, dtype=pa.string)
-
-        # Sort result the same way as expected for comparison
-        result = result.sort_by(
-            [
-                ("probability", "descending"),
-                ("parent", "ascending"),
-                ("child", "ascending"),
-            ]
-        )
-
-        assert result.schema == expected.schema
-        assert result.equals(expected)
+    # Run and compare
+    with (
+        patch(
+            "matchbox.common.results.ProcessPoolExecutor",
+            lambda *args, **kwargs: parallel_pool_for_tests(timeout=30),
+        ),
+        patch("matchbox.common.results.combine_integers", side_effect=_combine_strings),
+    ):
+        result = to_hierarchical_clusters(
+            probabilities, dtype=pa.string, proc_func=component_to_hierarchy
+        )
+
+    result = result.sort_by(
+        [
+            ("probability", "descending"),
+            ("parent", "ascending"),
+            ("child", "ascending"),
+        ]
+    )
+
+    assert result.schema == expected.schema
+    assert result.equals(expected)
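A likely reason the test swaps ProcessPoolExecutor for the thread-based parallel_pool_for_tests rather than running real worker processes (a reading of the change, not something the PR states): unittest.mock.patch only rebinds names in the current process, so under the spawn start method a child process would import an unpatched module and never see combine_integers replaced by _combine_strings, while threads always share the patched module state. An illustrative sketch with hypothetical stand-ins combine and worker:

from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch


def combine(x: int) -> int:
    """Stand-in for combine_integers."""
    return x + 1


def worker(x: int) -> int:
    """Stand-in for the per-component worker: calls the module-level combine()."""
    return combine(x)


with patch(f"{__name__}.combine", side_effect=lambda x: x * 10):
    with ThreadPoolExecutor(max_workers=2) as executor:
        # The worker threads resolve the patched combine(), so the side effect applies.
        results = list(executor.map(worker, [1, 2]))

assert results == [10, 20]

The shutdown(wait=False, cancel_futures=True) call in the context manager's finally block then drops any still-queued work rather than blocking test teardown if something hangs.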