diff --git a/dask_cuda/cli.py b/dask_cuda/cli.py index a8c6d972..8d335ec3 100644 --- a/dask_cuda/cli.py +++ b/dask_cuda/cli.py @@ -164,6 +164,14 @@ def cuda(): incompatible with RMM pools and managed memory, trying to enable both will result in failure.""", ) +@click.option( + "--set-rmm-allocator-for-libs", + default=None, + show_default=True, + help=""" + Set RMM as the allocator for external libraries. Provide a comma-separated + list of libraries to set, e.g., "torch,cupy". Supported options are: torch, cupy.""", +) @click.option( "--rmm-release-threshold", default=None, @@ -351,6 +359,7 @@ def worker( rmm_maximum_pool_size, rmm_managed_memory, rmm_async, + rmm_allocator_external_lib_list, rmm_release_threshold, rmm_log_directory, rmm_track_allocations, @@ -425,6 +434,7 @@ def worker( rmm_maximum_pool_size, rmm_managed_memory, rmm_async, + rmm_allocator_external_lib_list, rmm_release_threshold, rmm_log_directory, rmm_track_allocations, diff --git a/dask_cuda/cuda_worker.py b/dask_cuda/cuda_worker.py index 3e03ed29..91128eb1 100644 --- a/dask_cuda/cuda_worker.py +++ b/dask_cuda/cuda_worker.py @@ -47,6 +47,7 @@ def __init__( rmm_maximum_pool_size=None, rmm_managed_memory=False, rmm_async=False, + rmm_allocator_external_lib_list=None, rmm_release_threshold=None, rmm_log_directory=None, rmm_track_allocations=False, @@ -202,6 +203,9 @@ def del_pid_file(): "processes set `CUDF_SPILL=on` as well. To disable this warning " "set `DASK_CUDF_SPILL_WARNING=False`." ) + + if rmm_allocator_external_lib_list is not None: + rmm_allocator_external_lib_list = [s.strip() for s in rmm_allocator_external_lib_list.split(',')] self.nannies = [ Nanny( @@ -231,6 +235,7 @@ def del_pid_file(): release_threshold=rmm_release_threshold, log_directory=rmm_log_directory, track_allocations=rmm_track_allocations, + external_lib_list=rmm_allocator_external_lib_list, ), PreImport(pre_import), CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats), diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index c037223b..ed95ba3b 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -143,6 +143,10 @@ class LocalCUDACluster(LocalCluster): The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also incompatible with RMM pools and managed memory. Trying to enable both will result in an exception. + rmm_allocator_external_lib_list: str or list or None, default None + Set RMM as the allocator for external libraries. Can be a comma-separated + string (like "torch,cupy"). + rmm_release_threshold: int, str or None, default None When ``rmm.async is True`` and the pool size grows beyond this value, unused memory held by the pool will be released at the next synchronization point. @@ -231,6 +235,7 @@ def __init__( rmm_maximum_pool_size=None, rmm_managed_memory=False, rmm_async=False, + rmm_allocator_external_lib_list=None, rmm_release_threshold=None, rmm_log_directory=None, rmm_track_allocations=False, @@ -284,6 +289,11 @@ def __init__( self.rmm_managed_memory = rmm_managed_memory self.rmm_async = rmm_async self.rmm_release_threshold = rmm_release_threshold + if rmm_allocator_external_lib_list is not None: + rmm_allocator_external_lib_list = [s.strip() for s in + rmm_allocator_external_lib_list.split(',')] + self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list + if rmm_pool_size is not None or rmm_managed_memory or rmm_async: try: import rmm # noqa F401 @@ -437,6 +447,7 @@ def new_worker_spec(self): release_threshold=self.rmm_release_threshold, log_directory=self.rmm_log_directory, track_allocations=self.rmm_track_allocations, + external_lib_list=self.rmm_allocator_external_lib_list ), PreImport(self.pre_import), CUDFSetup(self.enable_cudf_spill, self.cudf_spill_stats), diff --git a/dask_cuda/plugins.py b/dask_cuda/plugins.py index 122f93ff..4e3726e4 100644 --- a/dask_cuda/plugins.py +++ b/dask_cuda/plugins.py @@ -3,7 +3,7 @@ from distributed import WorkerPlugin -from .utils import get_rmm_log_file_name, parse_device_memory_limit +from .utils import get_rmm_log_file_name, parse_device_memory_limit, enable_rmm_memory_for_library class CPUAffinity(WorkerPlugin): @@ -39,6 +39,7 @@ def __init__( release_threshold, log_directory, track_allocations, + external_lib_list, ): if initial_pool_size is None and maximum_pool_size is not None: raise ValueError( @@ -61,6 +62,8 @@ def __init__( self.logging = log_directory is not None self.log_directory = log_directory self.rmm_track_allocations = track_allocations + self.external_lib_list = external_lib_list + def setup(self, worker=None): if self.initial_pool_size is not None: @@ -122,6 +125,10 @@ def setup(self, worker=None): mr = rmm.mr.get_current_device_resource() rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr)) + + if self.external_lib_list is not None: + for lib in self.external_lib_list: + enable_rmm_memory_for_library(lib) class PreImport(WorkerPlugin): diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index ff4dbbae..2f92d94b 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -7,7 +7,7 @@ from contextlib import suppress from functools import singledispatch from multiprocessing import cpu_count -from typing import Optional +from typing import Optional, Callable, Dict import numpy as np import pynvml @@ -764,3 +764,60 @@ def get_rmm_memory_resource_stack(mr) -> list: if isinstance(mr, rmm.mr.StatisticsResourceAdaptor): return mr.allocation_counts["current_bytes"] return None + + +def enable_rmm_memory_for_library(lib_name: str) -> None: + """ + Enable RMM memory pool support for a specified third-party library. + + This function allows the given library to utilize RMM's memory pool if it supports + integration with RMM. The library name is passed as a string argument, and if the + library is compatible, its memory allocator will be configured to use RMM. + + Parameters + ---------- + lib_name : str + The name of the third-party library to enable RMM memory pool support for. + + Raises + ------ + ValueError + If the library name is not supported or does not have RMM integration. + ImportError + If the required library is not installed. + """ + + # Mapping of supported libraries to their respective setup functions + setup_functions: Dict[str, Callable[[], None]] = { + "torch": _setup_rmm_for_torch, + "cupy": _setup_rmm_for_cupy, + } + + if lib_name not in setup_functions: + supported_libs = ', '.join(setup_functions.keys()) + raise ValueError( + f"The library '{lib_name}' is not supported for RMM integration. " + f"Supported libraries are: {supported_libs}." + ) + + # Call the setup function for the specified library + setup_functions[lib_name]() + +def _setup_rmm_for_torch() -> None: + try: + import torch + except ImportError as e: + raise ImportError("PyTorch is not installed.") from e + + from rmm.allocators.torch import rmm_torch_allocator + + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + +def _setup_rmm_for_cupy() -> None: + try: + import cupy + except ImportError as e: + raise ImportError("CuPy is not installed.") from e + + from rmm.allocators.cupy import rmm_cupy_allocator + cupy.cuda.set_allocator(rmm_cupy_allocator)