From cdeab050c608d937e5ee0e7ae1c0e28c4d457afb Mon Sep 17 00:00:00 2001 From: pierce <48131946+pierce314159@users.noreply.github.com> Date: Thu, 17 Nov 2022 17:42:36 -0500 Subject: [PATCH] Closes #1912: Updates to `search_interval` (#1913) This PR (closes #1912): - Changes `hierarchical` to True by default - Addresses bug on boundaries in `non_overlapping` check - Adds `hierarchical` to `interval_lookup` which is False by default Co-authored-by: Pierce Hayes --- arkouda/alignment.py | 40 +++++++++++++++++++++------------------- tests/alignment_tests.py | 16 ++++++++-------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/arkouda/alignment.py b/arkouda/alignment.py index 71e609fd05..0c7094324d 100644 --- a/arkouda/alignment.py +++ b/arkouda/alignment.py @@ -279,7 +279,7 @@ def in1d_intervals(vals, intervals, symmetric=False): return found -def search_intervals(vals, intervals, tiebreak=None, hierarchical=False): +def search_intervals(vals, intervals, tiebreak=None, hierarchical=True): """ Given an array of query vals and non-overlapping, closed intervals, return the index of the best (see tiebreak) interval containing each query value, @@ -414,27 +414,28 @@ def search_intervals(vals, intervals, tiebreak=None, hierarchical=False): bounds_okay = True break needtocheck &= lo == hi + # check non_overlapping + left = high[0][:-1] + right = low[0][1:] + not_overlapping = True + if (left <= right).any(): + not_overlapping = False + else: + boundary = left != right + for lo, hi in zip(low[1:], high[1:]): + left = hi[:-1] + right = lo[1:] + _ = left <= right + if not (_ | boundary).all(): + not_overlapping = False + break + boundary = boundary | (left != right) else: bounds_okay = all((hi >= lo).all() for hi, lo in zip(high, low)) if not bounds_okay: raise ValueError("Upper bounds must be greater than lower bounds") - left = high[0][:-1] - right = low[0][1:] - not_overlapping = True - if (left < right).any(): - not_overlapping = False - else: - boundary = left != right - for lo, hi in zip(low[1:], high[1:]): - left = hi[:-1] - right = lo[1:] - if not ((left <= right) | boundary).all(): - not_overlapping = False - break - boundary = boundary | (left != right) - perm = coargsort([concatenate((lo, va, hi)) for lo, va, hi in zip(low, vals, high)]) if singleton or (isinstance(vals, Sequence) and hierarchical): @@ -569,13 +570,14 @@ def is_cosorted(arrays): for array in arrays[1:]: left = array[:-1] right = array[1:] - if not ((left <= right) | boundary).all(): + _ = left <= right + if not (_ | boundary).all(): return False boundary = boundary | (left != right) return True -def interval_lookup(keys, values, arguments, fillvalue=-1, tiebreak=None): +def interval_lookup(keys, values, arguments, fillvalue=-1, tiebreak=None, hierarchical=False): """ Apply a function defined over intervals to an array of arguments. @@ -605,7 +607,7 @@ def interval_lookup(keys, values, arguments, fillvalue=-1, tiebreak=None): if isinstance(values, Categorical): codes = interval_lookup(keys, values.codes, arguments, fillvalue=values._NAcode) return Categorical.from_codes(codes, values.categories, NAvalue=values.NAvalue) - idx = search_intervals(arguments, keys, tiebreak=tiebreak) + idx = search_intervals(arguments, keys, tiebreak=tiebreak, hierarchical=hierarchical) arguments_size = arguments.size if isinstance(arguments, pdarray) else arguments[0].size res = zeros(arguments_size, dtype=values.dtype) if fillvalue is not None: diff --git a/tests/alignment_tests.py b/tests/alignment_tests.py index 5d38f7c098..8a4812c8aa 100644 --- a/tests/alignment_tests.py +++ b/tests/alignment_tests.py @@ -36,17 +36,17 @@ def test_multi_array_search_interval(self): ends = (ak.array([4, 14, 24]), ak.array([4, 14, 24])) vals = (ak.array([3, 13, 23]), ak.array([23, 13, 3])) ans = [-1, 1, -1] - self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends)).to_list()) + self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list()) self.assertListEqual(ans, ak.interval_lookup((starts, ends), ak.arange(3), vals).to_list()) vals = (ak.array([23, 13, 3]), ak.array([23, 13, 3])) ans = [2, 1, 0] - self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends)).to_list()) + self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list()) self.assertListEqual(ans, ak.interval_lookup((starts, ends), ak.arange(3), vals).to_list()) vals = (ak.array([23, 13, 33]), ak.array([23, 13, 3])) ans = [2, 1, -1] - self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends)).to_list()) + self.assertListEqual(ans, ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list()) self.assertListEqual(ans, ak.interval_lookup((starts, ends), ak.arange(3), vals).to_list()) # test hierarchical flag @@ -55,11 +55,11 @@ def test_multi_array_search_interval(self): vals = (ak.array([0, 0, 2, 5, 5, 6, 6, 9]), ak.array([0, 20, 1, 5, 15, 0, 12, 30])) self.assertListEqual( - ak.search_intervals(vals, (starts, ends)).to_list(), [0, -1, 0, 0, 1, -1, 1, -1] + ak.search_intervals(vals, (starts, ends), hierarchical=False).to_list(), + [0, -1, 0, 0, 1, -1, 1, -1], ) self.assertListEqual( - ak.search_intervals(vals, (starts, ends), hierarchical=True).to_list(), - [0, 0, 0, 0, 1, 1, 1, -1], + ak.search_intervals(vals, (starts, ends)).to_list(), [0, 0, 0, 0, 1, 1, 1, -1] ) def test_search_interval_nonunique(self): @@ -159,7 +159,7 @@ def test_representative_cases(self): tiebreak_smallest = (y1 - y0) * (x1 - x0) first_answer = [-1, -1, 0, 0, -1, 0, 2, 0, -1, 0, 0, 3, -1] smallest_answer = [-1, -1, 0, 2, -1, 2, 2, 1, -1, 0, 0, 3, -1] - first_result = ak.search_intervals(values, intervals) + first_result = ak.search_intervals(values, intervals, hierarchical=False) self.assertListEqual(first_result.to_list(), first_answer) - smallest_result = ak.search_intervals(values, intervals, tiebreak=tiebreak_smallest) + smallest_result = ak.search_intervals(values, intervals, tiebreak=tiebreak_smallest, hierarchical=False) self.assertListEqual(smallest_result.to_list(), smallest_answer)