Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implement connected components with disjoint sets #31

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 53 additions & 129 deletions src/matchbox/common/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, Iterator, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, TypeVar

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -388,130 +388,50 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
)


class UnionFindWithDiff(Generic[T]):
"""A UnionFind data structure with diff capabilities."""
class DisjointSet(Generic[T]):
"""
Disjoint set forest with "path compression" and "union by rank" heuristics.

This follows implementation from Cormen, Thomas H., et al. Introduction to
algorithms. MIT press, 2022
"""

def __init__(self):
self.parent: dict[T, T] = {}
self.rank: dict[T, int] = {}
self._shadow_parent: dict[T, T] = {}
self._shadow_rank: dict[T, int] = {}
self._pending_pairs: list[tuple[T, T]] = []

def make_set(self, x: T) -> None:
if x not in self.parent:
self.parent[x] = x
self.rank[x] = 0

def find(self, x: T, parent_dict: dict[T, T] | None = None) -> T:
if parent_dict is None:
parent_dict = self.parent

if x not in parent_dict:
self.make_set(x)
if parent_dict is self._shadow_parent:
self._shadow_parent[x] = x
self._shadow_rank[x] = 0

# TODO: Instead of being a `while`, could this be an `if`?
while parent_dict[x] != x:
parent_dict[x] = parent_dict[parent_dict[x]]
x = parent_dict[x]
return x
def _make_set(self, x: T) -> None:
self.parent[x] = x
self.rank[x] = 0

def union(self, x: T, y: T) -> None:
root_x = self.find(x)
root_y = self.find(y)

if root_x != root_y:
self._pending_pairs.append((x, y))
self._link(self._find(x), self._find(y))

if self.rank[root_x] < self.rank[root_y]:
root_x, root_y = root_y, root_x
self.parent[root_y] = root_x
if self.rank[root_x] == self.rank[root_y]:
self.rank[root_x] += 1
def _link(self, x: T, y: T) -> None:
if self.rank[x] > self.rank[y]:
self.parent[y] = x
else:
self.parent[x] = y
if self.rank[x] == self.rank[y]:
self.rank[y] += 1

def get_component(self, x: T, parent_dict: dict[T, T] | None = None) -> set[T]:
if parent_dict is None:
parent_dict = self.parent
def _find(self, x: T) -> T:
if x not in self.parent:
self._make_set(x)
return x

root = self.find(x, parent_dict)
return {y for y in parent_dict if self.find(y, parent_dict) == root}
if x != self.parent[x]:
self.parent[x] = self._find(self.parent[x])

def get_components(self, parent_dict: dict[T, T] | None = None) -> list[set[T]]:
if parent_dict is None:
parent_dict = self.parent
return self.parent[x]

def get_components(self) -> list[set[T]]:
components = defaultdict(set)
for x in parent_dict:
root = self.find(x, parent_dict)
for x in self.parent:
root = self._find(x)
components[root].add(x)
return list(components.values())

def diff(self) -> Iterator[tuple[set[T], set[T]]]:
"""
Returns differences including all pairwise merges that occurred since last diff,
excluding cases where old_comp == new_comp.
"""
# Get current state before processing pairs
current_components = self.get_components()
reported_pairs = set()

# Process pending pairs
for x, y in self._pending_pairs:
# Find the final component containing the pair
final_component = next(
comp for comp in current_components if x in comp and y in comp
)

# Only report if the pair forms a proper subset of the final component
pair_component = {x, y}
if (
pair_component != final_component
and frozenset((frozenset(pair_component), frozenset(final_component)))
not in reported_pairs
):
reported_pairs.add(
frozenset((frozenset(pair_component), frozenset(final_component)))
)
yield (pair_component, final_component)

self._pending_pairs.clear()

# Handle initial state
if not self._shadow_parent:
self._shadow_parent = self.parent.copy()
self._shadow_rank = self.rank.copy()
return

# Get old components
old_components = self.get_components(self._shadow_parent)

# Report changes between old and new states
for old_comp in old_components:
if len(old_comp) > 1: # Only consider non-singleton old components
sample_elem = next(iter(old_comp))
new_comp = next(
comp for comp in current_components if sample_elem in comp
)

# Only yield if the components are different and this pair
# hasn't been reported
if (
old_comp != new_comp
and frozenset((frozenset(old_comp), frozenset(new_comp)))
not in reported_pairs
):
reported_pairs.add(
frozenset((frozenset(old_comp), frozenset(new_comp)))
)
yield (old_comp, new_comp)

# Update shadow copy
self._shadow_parent = self.parent.copy()
self._shadow_rank = self.rank.copy()


def component_to_hierarchy(
table: pa.Table, dtype: pa.DataType = pa.int32, salt: int = 1
Expand All @@ -527,36 +447,42 @@ def component_to_hierarchy(
Returns:
Arrow Table with columns ['parent', 'child', 'probability']
"""
hierarchy: list[tuple[int, int, float]] = []
uf = UnionFindWithDiff[int]()
im = IntMap(salt=salt)
probs = pc.unique(table["probability"])
probs = np.sort(pc.unique(table["probability"]).to_numpy())[::-1]

djs = DisjointSet[int]() # implements connected components
im = IntMap(salt=salt) # generates IDs for new clusters
current_roots: dict[int, set[int]] = defaultdict(set) # tracks ultimate parents
hierarchy: list[tuple[int, int, float]] = [] # the output of this function

for threshold in probs:
# Get current probability rows
mask = pc.equal(table["probability"], threshold)
current_probs = table.filter(mask)

# Add rows to union-find
for row in zip(
# Add new pairwise relationships at this threshold
for left, right in zip(
current_probs["left"].to_numpy(),
current_probs["right"].to_numpy(),
strict=True,
):
left, right = row
uf.union(left, right)
djs.union(left, right)
parent = im.index(left, right)
hierarchy.extend([(parent, left, threshold), (parent, right, threshold)])
current_roots[left].add(parent)
current_roots[right].add(parent)

for children in djs.get_components():
if len(children) <= 2:
continue # Skip pairs already handled by pairwise probabilities

parent = im.index(*children)
prev_roots: set[int] = set()
for child in children:
prev_roots.update(current_roots[child])
current_roots[child] = {parent}

# Process union-find diffs
for old_comp, new_comp in uf.diff():
if len(old_comp) > 1:
parent = im.index(*new_comp)
child = im.index(*old_comp)
hierarchy.extend([(parent, child, threshold)])
else:
parent = im.index(*new_comp)
hierarchy.extend([(parent, old_comp.pop(), threshold)])
for r in prev_roots:
hierarchy.append((parent, r, threshold))

parents, children, probs = zip(*hierarchy, strict=True)
return pa.table(
Expand Down Expand Up @@ -597,9 +523,7 @@ def to_hierarchical_clusters(
TimeRemainingColumn(),
]

probabilities = probabilities.sort_by(
[("component", "ascending"), ("probability", "descending")]
)
probabilities = probabilities.sort_by([("component", "ascending")])
components = pc.unique(probabilities["component"])
n_cores = multiprocessing.cpu_count()
n_components = len(components)
Expand Down
37 changes: 37 additions & 0 deletions test/common/test_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from matchbox.common.results import DisjointSet


class TestDisjointSet:
wpfl-dbt marked this conversation as resolved.
Show resolved Hide resolved
def test_disjoint_set_empty(self):
dsj = DisjointSet()

assert dsj.get_components() == []

def test_disjoint_set_same(self):
dsj = DisjointSet()
dsj.union(1, 1)

assert dsj.get_components() == [{1}]

def test_disjoint_set_redundant(self):
dsj = DisjointSet()
dsj.union(1, 2)

assert dsj.get_components() == [{1, 2}]

dsj.union(2, 1)

assert dsj.get_components() == [{1, 2}]

def test_disjoint_set_union(self):
dsj = DisjointSet()
dsj.union(1, 2)
dsj.union(3, 4)
dsj.union(5, 6)

assert sorted(dsj.get_components()) == [{1, 2}, {3, 4}, {5, 6}]

dsj.union(2, 3)
dsj.union(4, 5)

assert dsj.get_components() == [{1, 2, 3, 4, 5, 6}]
Loading