From afc27f4c8a7bdc73bc777f96d20b03c525b28a83 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Thu, 9 Jan 2025 13:14:15 -0600 Subject: [PATCH] Clean up imports for `dask>2024.12.1` support (#1424) Follow up to https://github.com/rapidsai/dask-cuda/pull/1417 Cleans up some imports (some of which don't work for `dask>2024.12.1`). Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Mads R. B. Kristensen (https://github.com/madsbk) - Peter Andreas Entschev (https://github.com/pentschev) URL: https://github.com/rapidsai/dask-cuda/pull/1424 --- dask_cuda/__init__.py | 10 ------ dask_cuda/explicit_comms/dataframe/shuffle.py | 19 +++++++++-- dask_cuda/proxy_object.py | 33 ++++++++----------- dask_cuda/tests/test_proxy.py | 16 ++++----- 4 files changed, 37 insertions(+), 41 deletions(-) diff --git a/dask_cuda/__init__.py b/dask_cuda/__init__.py index d07634f2d..a93cb1cb5 100644 --- a/dask_cuda/__init__.py +++ b/dask_cuda/__init__.py @@ -5,10 +5,8 @@ import dask import dask.utils -import dask.dataframe.core import dask.dataframe.shuffle from .explicit_comms.dataframe.shuffle import patch_shuffle_expression -from dask.dataframe import DASK_EXPR_ENABLED from distributed.protocol.cuda import cuda_deserialize, cuda_serialize from distributed.protocol.serialize import dask_deserialize, dask_serialize @@ -19,14 +17,6 @@ from .proxify_device_objects import proxify_decorator, unproxify_decorator -if not DASK_EXPR_ENABLED: - raise ValueError( - "Dask-CUDA no longer supports the legacy Dask DataFrame API. " - "Please set the 'dataframe.query-planning' config to `True` " - "or None, or downgrade RAPIDS to <=24.12." - ) - - # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True` patch_shuffle_expression() # Monkey patching Dask to make use of proxify and unproxify in compatibility mode diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index 600da07d5..b2a313dcc 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -18,7 +18,7 @@ from dask.base import tokenize from dask.dataframe import DataFrame, Series from dask.dataframe.core import _concat as dd_concat -from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch +from dask.dataframe.dispatch import group_split_dispatch, hash_object_dispatch from distributed import wait from distributed.protocol import nested_deserialize, to_serialize from distributed.worker import Worker @@ -31,6 +31,20 @@ Proxify = Callable[[T], T] +try: + from dask.dataframe import dask_expr + +except ImportError: + # TODO: Remove when pinned to dask>2024.12.1 + import dask_expr + + if not dd._dask_expr_enabled(): + raise ValueError( + "The legacy DataFrame API is not supported in dask_cudf>24.12. " + "Please enable query-planning, or downgrade to dask_cudf<=24.12" + ) + + def get_proxify(worker: Worker) -> Proxify: """Get function to proxify objects""" from dask_cuda.proxify_host_file import ProxifyHostFile @@ -576,7 +590,6 @@ def patch_shuffle_expression() -> None: an `ECShuffle` expression when the 'explicit-comms' config is set to `True`. """ - import dask_expr class ECShuffle(dask_expr._shuffle.TaskShuffle): """Explicit-Comms Shuffle Expression.""" @@ -585,7 +598,7 @@ def _layer(self): # Execute an explicit-comms shuffle if not hasattr(self, "_ec_shuffled"): on = self.partitioning_index - df = dask_expr._collection.new_collection(self.frame) + df = dask_expr.new_collection(self.frame) self._ec_shuffled = shuffle( df, [on] if isinstance(on, str) else on, diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index ddb7f3292..b42af7b1c 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -12,7 +12,8 @@ import dask import dask.array.core -import dask.dataframe.methods +import dask.dataframe.backends +import dask.dataframe.dispatch import dask.dataframe.utils import dask.utils import distributed.protocol @@ -22,16 +23,6 @@ from dask_cuda.disk_io import disk_read -try: - from dask.dataframe.backends import concat_pandas -except ImportError: - from dask.dataframe.methods import concat_pandas - -try: - from dask.dataframe.dispatch import make_meta_dispatch as make_meta_dispatch -except ImportError: - from dask.dataframe.utils import make_meta as make_meta_dispatch - from .disk_io import SpillToDiskFile from .is_device_object import is_device_object @@ -893,10 +884,12 @@ def obj_pxy_dask_deserialize(header, frames): return subclass(pxy) -@dask.dataframe.core.get_parallel_type.register(ProxyObject) +@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject) def get_parallel_type_proxy_object(obj: ProxyObject): # Notice, `get_parallel_type()` needs a instance not a type object - return dask.dataframe.core.get_parallel_type(obj.__class__.__new__(obj.__class__)) + return dask.dataframe.dispatch.get_parallel_type( + obj.__class__.__new__(obj.__class__) + ) def unproxify_input_wrapper(func): @@ -913,24 +906,24 @@ def wrapper(*args, **kwargs): # Register dispatch of ProxyObject on all known dispatch objects for dispatch in ( - dask.dataframe.core.hash_object_dispatch, - make_meta_dispatch, + dask.dataframe.dispatch.hash_object_dispatch, + dask.dataframe.dispatch.make_meta_dispatch, dask.dataframe.utils.make_scalar, - dask.dataframe.core.group_split_dispatch, + dask.dataframe.dispatch.group_split_dispatch, dask.array.core.tensordot_lookup, dask.array.core.einsum_lookup, dask.array.core.concatenate_lookup, ): dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch)) -dask.dataframe.methods.concat_dispatch.register( - ProxyObject, unproxify_input_wrapper(dask.dataframe.methods.concat) +dask.dataframe.dispatch.concat_dispatch.register( + ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat) ) # We overwrite the Dask dispatch of Pandas objects in order to # deserialize all ProxyObjects before concatenating -dask.dataframe.methods.concat_dispatch.register( +dask.dataframe.dispatch.concat_dispatch.register( (pandas.DataFrame, pandas.Series, pandas.Index), - unproxify_input_wrapper(concat_pandas), + unproxify_input_wrapper(dask.dataframe.backends.concat_pandas), ) diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py index 90b84e90d..c4c6600ed 100644 --- a/dask_cuda/tests/test_proxy.py +++ b/dask_cuda/tests/test_proxy.py @@ -504,27 +504,27 @@ def test_pandas(): df1 = pandas.DataFrame({"a": range(10)}) df2 = pandas.DataFrame({"a": range(10)}) - res = dask.dataframe.methods.concat([df1, df2]) - got = dask.dataframe.methods.concat([df1, df2]) + res = dask.dataframe.dispatch.concat([df1, df2]) + got = dask.dataframe.dispatch.concat([df1, df2]) assert_frame_equal(res, got) - got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2]) + got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2]) assert_frame_equal(res, got) - got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)]) + got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)]) assert_frame_equal(res, got) df1 = pandas.Series(range(10)) df2 = pandas.Series(range(10)) - res = dask.dataframe.methods.concat([df1, df2]) - got = dask.dataframe.methods.concat([df1, df2]) + res = dask.dataframe.dispatch.concat([df1, df2]) + got = dask.dataframe.dispatch.concat([df1, df2]) assert all(res == got) - got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2]) + got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2]) assert all(res == got) - got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)]) + got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)]) assert all(res == got)