Skip to content

Commit

Permalink
POC pytorch in dask-cuda
Browse files Browse the repository at this point in the history
Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
  • Loading branch information
VibhuJawa committed Oct 8, 2024
1 parent fe16796 commit 12e0051
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 2 deletions.
10 changes: 10 additions & 0 deletions dask_cuda/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def cuda():
incompatible with RMM pools and managed memory, trying to enable both will
result in failure.""",
)
@click.option(
"--set-rmm-allocator-for-libs",
default=None,
show_default=True,
help="""
Set RMM as the allocator for external libraries. Provide a comma-separated
list of libraries to set, e.g., "torch,cupy". Supported options are: torch, cupy.""",
)
@click.option(
"--rmm-release-threshold",
default=None,
Expand Down Expand Up @@ -351,6 +359,7 @@ def worker(
rmm_maximum_pool_size,
rmm_managed_memory,
rmm_async,
rmm_allocator_external_lib_list,
rmm_release_threshold,
rmm_log_directory,
rmm_track_allocations,
Expand Down Expand Up @@ -425,6 +434,7 @@ def worker(
rmm_maximum_pool_size,
rmm_managed_memory,
rmm_async,
rmm_allocator_external_lib_list,
rmm_release_threshold,
rmm_log_directory,
rmm_track_allocations,
Expand Down
5 changes: 5 additions & 0 deletions dask_cuda/cuda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
rmm_maximum_pool_size=None,
rmm_managed_memory=False,
rmm_async=False,
rmm_allocator_external_lib_list=None,
rmm_release_threshold=None,
rmm_log_directory=None,
rmm_track_allocations=False,
Expand Down Expand Up @@ -202,6 +203,9 @@ def del_pid_file():
"processes set `CUDF_SPILL=on` as well. To disable this warning "
"set `DASK_CUDF_SPILL_WARNING=False`."
)

# The CLI passes a comma-separated string, but programmatic callers may pass
# an already-split list; only a string needs splitting (calling ``.split`` on
# a list would raise AttributeError).
if isinstance(rmm_allocator_external_lib_list, str):
    rmm_allocator_external_lib_list = [
        s.strip() for s in rmm_allocator_external_lib_list.split(",")
    ]

self.nannies = [
Nanny(
Expand Down Expand Up @@ -231,6 +235,7 @@ def del_pid_file():
release_threshold=rmm_release_threshold,
log_directory=rmm_log_directory,
track_allocations=rmm_track_allocations,
external_lib_list=rmm_allocator_external_lib_list,
),
PreImport(pre_import),
CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats),
Expand Down
11 changes: 11 additions & 0 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ class LocalCUDACluster(LocalCluster):
The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
incompatible with RMM pools and managed memory. Trying to enable both will
result in an exception.
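rmm_allocator_external_lib_list: str, list or None, default None
Set RMM as the allocator for external libraries. Can be a comma-separated
string (like ``"torch,cupy"``) or a list of strings (like ``["torch", "cupy"]``).
Supported libraries are ``"torch"`` and ``"cupy"``.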
rmm_release_threshold: int, str or None, default None
When ``rmm.async is True`` and the pool size grows beyond this value, unused
memory held by the pool will be released at the next synchronization point.
Expand Down Expand Up @@ -231,6 +235,7 @@ def __init__(
rmm_maximum_pool_size=None,
rmm_managed_memory=False,
rmm_async=False,
rmm_allocator_external_lib_list=None,
rmm_release_threshold=None,
rmm_log_directory=None,
rmm_track_allocations=False,
Expand Down Expand Up @@ -284,6 +289,11 @@ def __init__(
self.rmm_managed_memory = rmm_managed_memory
self.rmm_async = rmm_async
self.rmm_release_threshold = rmm_release_threshold
# The parameter is documented as "str or list or None": split only when a
# comma-separated string was given and pass any other value through
# unchanged (calling ``.split`` on a list would raise AttributeError).
if isinstance(rmm_allocator_external_lib_list, str):
    rmm_allocator_external_lib_list = [
        s.strip() for s in rmm_allocator_external_lib_list.split(",")
    ]
self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list

if rmm_pool_size is not None or rmm_managed_memory or rmm_async:
try:
import rmm # noqa F401
Expand Down Expand Up @@ -437,6 +447,7 @@ def new_worker_spec(self):
release_threshold=self.rmm_release_threshold,
log_directory=self.rmm_log_directory,
track_allocations=self.rmm_track_allocations,
external_lib_list=self.rmm_allocator_external_lib_list
),
PreImport(self.pre_import),
CUDFSetup(self.enable_cudf_spill, self.cudf_spill_stats),
Expand Down
9 changes: 8 additions & 1 deletion dask_cuda/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from distributed import WorkerPlugin

from .utils import get_rmm_log_file_name, parse_device_memory_limit
from .utils import get_rmm_log_file_name, parse_device_memory_limit, enable_rmm_memory_for_library


class CPUAffinity(WorkerPlugin):
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
release_threshold,
log_directory,
track_allocations,
external_lib_list,
):
if initial_pool_size is None and maximum_pool_size is not None:
raise ValueError(
Expand All @@ -61,6 +62,8 @@ def __init__(
self.logging = log_directory is not None
self.log_directory = log_directory
self.rmm_track_allocations = track_allocations
self.external_lib_list = external_lib_list


def setup(self, worker=None):
if self.initial_pool_size is not None:
Expand Down Expand Up @@ -122,6 +125,10 @@ def setup(self, worker=None):

mr = rmm.mr.get_current_device_resource()
rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))

if self.external_lib_list is not None:
for lib in self.external_lib_list:
enable_rmm_memory_for_library(lib)


class PreImport(WorkerPlugin):
Expand Down
59 changes: 58 additions & 1 deletion dask_cuda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import suppress
from functools import singledispatch
from multiprocessing import cpu_count
from typing import Optional
from typing import Optional, Callable, Dict

import numpy as np
import pynvml
Expand Down Expand Up @@ -764,3 +764,60 @@ def get_rmm_memory_resource_stack(mr) -> list:
if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
return mr.allocation_counts["current_bytes"]
return None


def enable_rmm_memory_for_library(lib_name: str) -> None:
    """
    Configure a third-party library to allocate memory through RMM.

    The name is looked up in a fixed table of supported libraries and the
    matching setup routine is invoked.

    Parameters
    ----------
    lib_name : str
        Name of the library to configure. ``"torch"`` and ``"cupy"`` are the
        recognized values.

    Raises
    ------
    ValueError
        If ``lib_name`` is not one of the supported library names.
    ImportError
        If the selected library is not installed (raised by its setup
        routine).
    """
    # Dispatch table: supported library name -> zero-argument setup routine.
    handlers: Dict[str, Callable[[], None]] = {
        "torch": _setup_rmm_for_torch,
        "cupy": _setup_rmm_for_cupy,
    }

    handler = handlers.get(lib_name)
    if handler is None:
        supported_libs = ", ".join(handlers)
        raise ValueError(
            f"The library '{lib_name}' is not supported for RMM integration. "
            f"Supported libraries are: {supported_libs}."
        )

    handler()

def _setup_rmm_for_torch() -> None:
    """
    Install RMM's PyTorch allocator as PyTorch's current CUDA allocator.

    Raises
    ------
    ImportError
        If ``torch`` cannot be imported.
    """
    try:
        import torch
    except ImportError as err:
        raise ImportError("PyTorch is not installed.") from err
    else:
        # Imported only after torch is known to be available.
        from rmm.allocators.torch import rmm_torch_allocator as allocator

        # NOTE(review): PyTorch appears to allow swapping the allocator only
        # before any CUDA memory has been allocated -- confirm this runs
        # early enough in worker startup.
        torch.cuda.memory.change_current_allocator(allocator)

def _setup_rmm_for_cupy() -> None:
    """
    Install RMM's CuPy allocator as CuPy's CUDA memory allocator.

    Raises
    ------
    ImportError
        If ``cupy`` cannot be imported.
    """
    try:
        import cupy
    except ImportError as err:
        raise ImportError("CuPy is not installed.") from err
    else:
        # Imported only after cupy is known to be available.
        from rmm.allocators.cupy import rmm_cupy_allocator as allocator

        cupy.cuda.set_allocator(allocator)

0 comments on commit 12e0051

Please sign in to comment.