diff --git a/dask_cuda/__init__.py b/dask_cuda/__init__.py index 1a95f0a7..a93cb1cb 100644 --- a/dask_cuda/__init__.py +++ b/dask_cuda/__init__.py @@ -5,7 +5,6 @@ import dask import dask.utils -import dask.dataframe as dd import dask.dataframe.shuffle from .explicit_comms.dataframe.shuffle import patch_shuffle_expression from distributed.protocol.cuda import cuda_deserialize, cuda_serialize @@ -18,17 +17,6 @@ from .proxify_device_objects import proxify_decorator, unproxify_decorator -try: - if not dd._dask_expr_enabled(): - raise ValueError( - "Dask-CUDA no longer supports the legacy Dask DataFrame API. " - "Please set the 'dataframe.query-planning' config to `True` " - "or None, or downgrade RAPIDS to <=24.12." - ) -except AttributeError: - pass - - # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True` patch_shuffle_expression() # Monkey patching Dask to make use of proxify and unproxify in compatibility mode diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index 2d80cebe..b2a313dc 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -31,6 +31,20 @@ Proxify = Callable[[T], T] +try: + from dask.dataframe import dask_expr + +except ImportError: + # TODO: Remove when pinned to dask>2024.12.1 + import dask_expr + + if not dd._dask_expr_enabled(): + raise ValueError( + "The legacy DataFrame API is not supported in dask_cudf>24.12. " + "Please enable query-planning, or downgrade to dask_cudf<=24.12" + ) + + def get_proxify(worker: Worker) -> Proxify: """Get function to proxify objects""" from dask_cuda.proxify_host_file import ProxifyHostFile @@ -576,7 +590,6 @@ def patch_shuffle_expression() -> None: an `ECShuffle` expression when the 'explicit-comms' config is set to `True`. """ - import dask_expr class ECShuffle(dask_expr._shuffle.TaskShuffle): """Explicit-Comms Shuffle Expression."""