Skip to content

Commit

Permalink
Clean up imports for dask>2024.12.1 support (#1424)
Browse files Browse the repository at this point in the history
Follow up to #1417

Cleans up some imports (some of which don't work for `dask>2024.12.1`).

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #1424
  • Loading branch information
rjzamora authored Jan 9, 2025
1 parent c19a1a7 commit afc27f4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 41 deletions.
10 changes: 0 additions & 10 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

import dask
import dask.utils
import dask.dataframe.core
import dask.dataframe.shuffle
from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
from dask.dataframe import DASK_EXPR_ENABLED
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import dask_deserialize, dask_serialize

Expand All @@ -19,14 +17,6 @@
from .proxify_device_objects import proxify_decorator, unproxify_decorator


if not DASK_EXPR_ENABLED:
raise ValueError(
"Dask-CUDA no longer supports the legacy Dask DataFrame API. "
"Please set the 'dataframe.query-planning' config to `True` "
"or None, or downgrade RAPIDS to <=24.12."
)


# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
patch_shuffle_expression()
# Monkey patching Dask to make use of proxify and unproxify in compatibility mode
Expand Down
19 changes: 16 additions & 3 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dask.base import tokenize
from dask.dataframe import DataFrame, Series
from dask.dataframe.core import _concat as dd_concat
from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
from dask.dataframe.dispatch import group_split_dispatch, hash_object_dispatch
from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize
from distributed.worker import Worker
Expand All @@ -31,6 +31,20 @@
Proxify = Callable[[T], T]


try:
from dask.dataframe import dask_expr

except ImportError:
# TODO: Remove when pinned to dask>2024.12.1
import dask_expr

if not dd._dask_expr_enabled():
raise ValueError(
"The legacy DataFrame API is not supported in dask_cudf>24.12. "
"Please enable query-planning, or downgrade to dask_cudf<=24.12"
)


def get_proxify(worker: Worker) -> Proxify:
"""Get function to proxify objects"""
from dask_cuda.proxify_host_file import ProxifyHostFile
Expand Down Expand Up @@ -576,7 +590,6 @@ def patch_shuffle_expression() -> None:
an `ECShuffle` expression when the 'explicit-comms'
config is set to `True`.
"""
import dask_expr

class ECShuffle(dask_expr._shuffle.TaskShuffle):
"""Explicit-Comms Shuffle Expression."""
Expand All @@ -585,7 +598,7 @@ def _layer(self):
# Execute an explicit-comms shuffle
if not hasattr(self, "_ec_shuffled"):
on = self.partitioning_index
df = dask_expr._collection.new_collection(self.frame)
df = dask_expr.new_collection(self.frame)
self._ec_shuffled = shuffle(
df,
[on] if isinstance(on, str) else on,
Expand Down
33 changes: 13 additions & 20 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import dask
import dask.array.core
import dask.dataframe.methods
import dask.dataframe.backends
import dask.dataframe.dispatch
import dask.dataframe.utils
import dask.utils
import distributed.protocol
Expand All @@ -22,16 +23,6 @@

from dask_cuda.disk_io import disk_read

try:
from dask.dataframe.backends import concat_pandas
except ImportError:
from dask.dataframe.methods import concat_pandas

try:
from dask.dataframe.dispatch import make_meta_dispatch as make_meta_dispatch
except ImportError:
from dask.dataframe.utils import make_meta as make_meta_dispatch

from .disk_io import SpillToDiskFile
from .is_device_object import is_device_object

Expand Down Expand Up @@ -893,10 +884,12 @@ def obj_pxy_dask_deserialize(header, frames):
return subclass(pxy)


@dask.dataframe.core.get_parallel_type.register(ProxyObject)
@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
def get_parallel_type_proxy_object(obj: ProxyObject):
# Notice, `get_parallel_type()` needs a instance not a type object
return dask.dataframe.core.get_parallel_type(obj.__class__.__new__(obj.__class__))
return dask.dataframe.dispatch.get_parallel_type(
obj.__class__.__new__(obj.__class__)
)


def unproxify_input_wrapper(func):
Expand All @@ -913,24 +906,24 @@ def wrapper(*args, **kwargs):

# Register dispatch of ProxyObject on all known dispatch objects
for dispatch in (
dask.dataframe.core.hash_object_dispatch,
make_meta_dispatch,
dask.dataframe.dispatch.hash_object_dispatch,
dask.dataframe.dispatch.make_meta_dispatch,
dask.dataframe.utils.make_scalar,
dask.dataframe.core.group_split_dispatch,
dask.dataframe.dispatch.group_split_dispatch,
dask.array.core.tensordot_lookup,
dask.array.core.einsum_lookup,
dask.array.core.concatenate_lookup,
):
dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))

dask.dataframe.methods.concat_dispatch.register(
ProxyObject, unproxify_input_wrapper(dask.dataframe.methods.concat)
dask.dataframe.dispatch.concat_dispatch.register(
ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
)


# We overwrite the Dask dispatch of Pandas objects in order to
# deserialize all ProxyObjects before concatenating
dask.dataframe.methods.concat_dispatch.register(
dask.dataframe.dispatch.concat_dispatch.register(
(pandas.DataFrame, pandas.Series, pandas.Index),
unproxify_input_wrapper(concat_pandas),
unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
)
16 changes: 8 additions & 8 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,27 +504,27 @@ def test_pandas():
df1 = pandas.DataFrame({"a": range(10)})
df2 = pandas.DataFrame({"a": range(10)})

res = dask.dataframe.methods.concat([df1, df2])
got = dask.dataframe.methods.concat([df1, df2])
res = dask.dataframe.dispatch.concat([df1, df2])
got = dask.dataframe.dispatch.concat([df1, df2])
assert_frame_equal(res, got)

got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
assert_frame_equal(res, got)

got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
assert_frame_equal(res, got)

df1 = pandas.Series(range(10))
df2 = pandas.Series(range(10))

res = dask.dataframe.methods.concat([df1, df2])
got = dask.dataframe.methods.concat([df1, df2])
res = dask.dataframe.dispatch.concat([df1, df2])
got = dask.dataframe.dispatch.concat([df1, df2])
assert all(res == got)

got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
assert all(res == got)

got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
assert all(res == got)


Expand Down

0 comments on commit afc27f4

Please sign in to comment.