From 84fbe90c0893215448bfdc8fa7fc44048df60cf0 Mon Sep 17 00:00:00 2001 From: jaketrookman <114928862+jaketrookman@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:16:23 -0400 Subject: [PATCH] Closes #2472, #2800: Update IndexingMsg to use `&` instead of `mod` for setting bigint pdarrays (#2793) * convert from mod to %, except bool * swtich back to the correct bitop, &, and optimizie using bit shift masking, and add testing * Updating if block, add indexing to test, add testing to old implementation, add remove value binops forall loop * making requested changes: fixing tests, moving variable definitions, and adding forall loops * update tests --------- Co-authored-by: jaketrookman --- PROTO_tests/tests/indexing_test.py | 5 +++ src/IndexingMsg.chpl | 68 ++++++++++++++++++++++++------ tests/indexing_test.py | 5 +++ 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/PROTO_tests/tests/indexing_test.py b/PROTO_tests/tests/indexing_test.py index bd6f37fe10..4135a30b39 100644 --- a/PROTO_tests/tests/indexing_test.py +++ b/PROTO_tests/tests/indexing_test.py @@ -103,3 +103,8 @@ def test_bigint_indexing_preserves_max_bits(self): a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=max_bits) assert max_bits == a[ak.arange(10)].max_bits assert max_bits == a[:].max_bits + + def test_handling_bigint_max_bits(self): + a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=3) + a[:] = ak.arange(2**200 - 1, 2**200 + 11) + assert [7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2] == a.to_list() diff --git a/src/IndexingMsg.chpl b/src/IndexingMsg.chpl index 24e08d35e8..c83ef4fcf5 100644 --- a/src/IndexingMsg.chpl +++ b/src/IndexingMsg.chpl @@ -608,7 +608,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = valueArg.getBigIntValue(); if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[idx] = val; } @@ -616,7 +619,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = valueArg.getIntValue():bigint; if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[idx] = val; } @@ -624,7 +630,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = valueArg.getUIntValue():bigint; if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[idx] = val; } @@ -632,7 +641,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = valueArg.getBoolValue():bigint; if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[idx] = val; } @@ -1150,7 +1162,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = value.getBigIntValue(); if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[slice] = val; } @@ -1158,7 +1173,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = value.getIntValue():bigint; if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[slice] = val; } @@ -1166,7 +1184,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = value.getUIntValue():bigint; if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[slice] = val; } @@ -1174,7 +1195,10 @@ module IndexingMsg var e = toSymEntry(gEnt,bigint); var val = value.getBoolValue():bigint; if e.max_bits != -1 { - mod(val, val, e.max_bits); + var max_size = 1:bigint; + max_size <<= e.max_bits; + max_size -= 1; + val &= max_size; } e.a[slice] = val; } @@ -1309,7 +1333,12 @@ module IndexingMsg var x = toSymEntry(gX,bigint); var y = toSymEntry(gY,bigint); if x.max_bits != -1 { - mod(y.a, y.a, x.max_bits); + var max_size = 1:bigint; + max_size <<= x.max_bits; + max_size -= 1; + forall y in y.a with (var local_max_size = max_size) { + y &= local_max_size; + } } x.a[slice] = y.a; } @@ -1318,7 +1347,12 @@ module IndexingMsg var y = toSymEntry(gY,int); var ya = y.a:bigint; if x.max_bits != -1 { - mod(ya, ya, x.max_bits); + var max_size = 1:bigint; + max_size <<= x.max_bits; + max_size -= 1; + forall y in ya with (var local_max_size = max_size) { + y &= local_max_size; + } } x.a[slice] = ya; } @@ -1327,7 +1361,12 @@ module IndexingMsg var y = toSymEntry(gY,uint); var ya = y.a:bigint; if x.max_bits != -1 { - mod(ya, ya, x.max_bits); + var max_size = 1:bigint; + max_size <<= x.max_bits; + max_size -= 1; + forall y in ya with (var local_max_size = max_size) { + y &= local_max_size; + } } x.a[slice] = ya; } @@ -1337,7 +1376,12 @@ module IndexingMsg // TODO change once we can cast directly from bool to bigint var ya = y.a:int:bigint; if x.max_bits != -1 { - mod(ya, ya, x.max_bits); + var max_size = 1:bigint; + max_size <<= x.max_bits; + max_size -= 1; + forall y in ya with (var local_max_size = max_size) { + y &= local_max_size; + } } x.a[slice] = ya; } diff --git a/tests/indexing_test.py b/tests/indexing_test.py index 177948f566..685ddafe78 100644 --- a/tests/indexing_test.py +++ b/tests/indexing_test.py @@ -89,3 +89,8 @@ def test_bigint_indexing_preserves_max_bits(self): a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=max_bits) self.assertEqual(max_bits, a[ak.arange(10)].max_bits) self.assertEqual(max_bits, a[:].max_bits) + + def test_handling_bigint_max_bits(self): + a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=3) + a[:] = ak.arange(2**200 - 1, 2**200 + 11) + self.assertListEqual([7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2], a.to_list()) \ No newline at end of file