Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up imports for dask>2024.12.1 support #1424

Merged
merged 2 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

import dask
import dask.utils
import dask.dataframe.core
import dask.dataframe.shuffle
from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
from dask.dataframe import DASK_EXPR_ENABLED
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import dask_deserialize, dask_serialize

Expand All @@ -19,14 +17,6 @@
from .proxify_device_objects import proxify_decorator, unproxify_decorator


if not DASK_EXPR_ENABLED:
raise ValueError(
"Dask-CUDA no longer supports the legacy Dask DataFrame API. "
"Please set the 'dataframe.query-planning' config to `True` "
"or None, or downgrade RAPIDS to <=24.12."
)
Comment on lines -22 to -27
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this logic into explicit_comms.dataframe.shuffle since that's really the only place in Dask-CUDA where dask-expr matters.



# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
patch_shuffle_expression()
# Monkey patching Dask to make use of proxify and unproxify in compatibility mode
Expand Down
19 changes: 16 additions & 3 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dask.base import tokenize
from dask.dataframe import DataFrame, Series
from dask.dataframe.core import _concat as dd_concat
from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
from dask.dataframe.dispatch import group_split_dispatch, hash_object_dispatch
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All dispatch functions have been centralized in the dispatch module for a long time now. Many changes in this PR are just using that preferred module.

from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize
from distributed.worker import Worker
Expand All @@ -31,6 +31,20 @@
Proxify = Callable[[T], T]


try:
from dask.dataframe import dask_expr

except ImportError:
# TODO: Remove when pinned to dask>2024.12.1
import dask_expr

if not dd._dask_expr_enabled():
raise ValueError(
"The legacy DataFrame API is not supported in dask_cudf>24.12. "
"Please enable query-planning, or downgrade to dask_cudf<=24.12"
)


def get_proxify(worker: Worker) -> Proxify:
"""Get function to proxify objects"""
from dask_cuda.proxify_host_file import ProxifyHostFile
Expand Down Expand Up @@ -576,7 +590,6 @@ def patch_shuffle_expression() -> None:
an `ECShuffle` expression when the 'explicit-comms'
config is set to `True`.
"""
import dask_expr

class ECShuffle(dask_expr._shuffle.TaskShuffle):
"""Explicit-Comms Shuffle Expression."""
Expand All @@ -585,7 +598,7 @@ def _layer(self):
# Execute an explicit-comms shuffle
if not hasattr(self, "_ec_shuffled"):
on = self.partitioning_index
df = dask_expr._collection.new_collection(self.frame)
df = dask_expr.new_collection(self.frame)
self._ec_shuffled = shuffle(
df,
[on] if isinstance(on, str) else on,
Expand Down
33 changes: 13 additions & 20 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import dask
import dask.array.core
import dask.dataframe.methods
import dask.dataframe.backends
import dask.dataframe.dispatch
import dask.dataframe.utils
import dask.utils
import distributed.protocol
Expand All @@ -22,16 +23,6 @@

from dask_cuda.disk_io import disk_read

try:
from dask.dataframe.backends import concat_pandas
except ImportError:
from dask.dataframe.methods import concat_pandas

try:
from dask.dataframe.dispatch import make_meta_dispatch as make_meta_dispatch
except ImportError:
from dask.dataframe.utils import make_meta as make_meta_dispatch

from .disk_io import SpillToDiskFile
from .is_device_object import is_device_object

Expand Down Expand Up @@ -893,10 +884,12 @@ def obj_pxy_dask_deserialize(header, frames):
return subclass(pxy)


@dask.dataframe.core.get_parallel_type.register(ProxyObject)
@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
def get_parallel_type_proxy_object(obj: ProxyObject):
# Notice, `get_parallel_type()` needs an instance, not a type object
return dask.dataframe.core.get_parallel_type(obj.__class__.__new__(obj.__class__))
return dask.dataframe.dispatch.get_parallel_type(
obj.__class__.__new__(obj.__class__)
)


def unproxify_input_wrapper(func):
Expand All @@ -913,24 +906,24 @@ def wrapper(*args, **kwargs):

# Register dispatch of ProxyObject on all known dispatch objects
for dispatch in (
dask.dataframe.core.hash_object_dispatch,
make_meta_dispatch,
dask.dataframe.dispatch.hash_object_dispatch,
dask.dataframe.dispatch.make_meta_dispatch,
dask.dataframe.utils.make_scalar,
dask.dataframe.core.group_split_dispatch,
dask.dataframe.dispatch.group_split_dispatch,
dask.array.core.tensordot_lookup,
dask.array.core.einsum_lookup,
dask.array.core.concatenate_lookup,
):
dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))

dask.dataframe.methods.concat_dispatch.register(
ProxyObject, unproxify_input_wrapper(dask.dataframe.methods.concat)
dask.dataframe.dispatch.concat_dispatch.register(
ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
)


# We overwrite the Dask dispatch of Pandas objects in order to
# deserialize all ProxyObjects before concatenating
dask.dataframe.methods.concat_dispatch.register(
dask.dataframe.dispatch.concat_dispatch.register(
(pandas.DataFrame, pandas.Series, pandas.Index),
unproxify_input_wrapper(concat_pandas),
unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
)
16 changes: 8 additions & 8 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,27 +504,27 @@ def test_pandas():
df1 = pandas.DataFrame({"a": range(10)})
df2 = pandas.DataFrame({"a": range(10)})

res = dask.dataframe.methods.concat([df1, df2])
got = dask.dataframe.methods.concat([df1, df2])
res = dask.dataframe.dispatch.concat([df1, df2])
got = dask.dataframe.dispatch.concat([df1, df2])
assert_frame_equal(res, got)

got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
assert_frame_equal(res, got)

got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
assert_frame_equal(res, got)

df1 = pandas.Series(range(10))
df2 = pandas.Series(range(10))

res = dask.dataframe.methods.concat([df1, df2])
got = dask.dataframe.methods.concat([df1, df2])
res = dask.dataframe.dispatch.concat([df1, df2])
got = dask.dataframe.dispatch.concat([df1, df2])
assert all(res == got)

got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
assert all(res == got)

got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
assert all(res == got)


Expand Down
Loading