Enable PyTorch to share the same memory pool as RMM via the CLI #1392

Merged
14 changes: 13 additions & 1 deletion dask_cuda/cli.py
@@ -13,7 +13,7 @@
from distributed.utils import import_term

from .cuda_worker import CUDAWorker
from .utils import print_cluster_config
from .utils import CommaSeparatedChoice, print_cluster_config

logger = logging.getLogger(__name__)

@@ -164,6 +164,16 @@ def cuda():
    incompatible with RMM pools and managed memory; trying to enable both will
    result in failure.""",
)
@click.option(
"--set-rmm-allocator-for-libs",
"rmm_allocator_external_lib_list",
type=CommaSeparatedChoice(["cupy", "torch"]),
default=None,
show_default=True,
help="""
Set RMM as the allocator for external libraries. Provide a comma-separated
list of libraries to set, e.g., "torch,cupy".""",
)
@click.option(
"--rmm-release-threshold",
default=None,
@@ -351,6 +361,7 @@ def worker(
rmm_maximum_pool_size,
rmm_managed_memory,
rmm_async,
rmm_allocator_external_lib_list,
rmm_release_threshold,
rmm_log_directory,
rmm_track_allocations,
@@ -425,6 +436,7 @@ def worker(
rmm_maximum_pool_size,
rmm_managed_memory,
rmm_async,
rmm_allocator_external_lib_list,
rmm_release_threshold,
rmm_log_directory,
rmm_track_allocations,
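
The new `--set-rmm-allocator-for-libs` flag is validated by the `CommaSeparatedChoice` click type added in `dask_cuda/utils.py` (see the last file in this diff), so unsupported library names are rejected at argument-parsing time, before a worker starts. A minimal sketch of that parsing behavior, using a throwaway command that mirrors the option declaration above (the command itself is illustrative and not part of this PR):

```python
# Illustrative sketch: a throwaway command mirroring the option added above.
import click
from click.testing import CliRunner

from dask_cuda.utils import CommaSeparatedChoice


@click.command()
@click.option(
    "--set-rmm-allocator-for-libs",
    "rmm_allocator_external_lib_list",
    type=CommaSeparatedChoice(["cupy", "torch"]),
    default=None,
)
def fake_worker(rmm_allocator_external_lib_list):
    # The option value arrives as a list of validated library names.
    click.echo(f"parsed: {rmm_allocator_external_lib_list}")


runner = CliRunner()
ok = runner.invoke(fake_worker, ["--set-rmm-allocator-for-libs", "torch,cupy"])
print(ok.output)      # parsed: ['torch', 'cupy']

bad = runner.invoke(fake_worker, ["--set-rmm-allocator-for-libs", "torch,numpy"])
print(bad.exit_code)  # non-zero: "numpy" is not a supported choice
```
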
2 changes: 2 additions & 0 deletions dask_cuda/cuda_worker.py
@@ -47,6 +47,7 @@ def __init__(
rmm_maximum_pool_size=None,
rmm_managed_memory=False,
rmm_async=False,
rmm_allocator_external_lib_list=None,
rmm_release_threshold=None,
rmm_log_directory=None,
rmm_track_allocations=False,
@@ -231,6 +232,7 @@ def del_pid_file():
release_threshold=rmm_release_threshold,
log_directory=rmm_log_directory,
track_allocations=rmm_track_allocations,
external_lib_list=rmm_allocator_external_lib_list,
),
PreImport(pre_import),
CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats),
22 changes: 22 additions & 0 deletions dask_cuda/local_cuda_cluster.py
@@ -143,6 +143,11 @@ class LocalCUDACluster(LocalCluster):
The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
incompatible with RMM pools and managed memory. Trying to enable both will
result in an exception.
rmm_allocator_external_lib_list: str, list or None, default None
List of external libraries for which to set RMM as the allocator.
Supported options are: ``["torch", "cupy"]``. Can be a comma-separated string
(like ``"torch,cupy"``) or a list of strings (like ``["torch", "cupy"]``).
If ``None``, no external libraries will use RMM as their allocator.
rmm_release_threshold: int, str or None, default None
When ``rmm.async is True`` and the pool size grows beyond this value, unused
memory held by the pool will be released at the next synchronization point.
@@ -231,6 +236,7 @@ def __init__(
rmm_maximum_pool_size=None,
rmm_managed_memory=False,
rmm_async=False,
rmm_allocator_external_lib_list=None,
rmm_release_threshold=None,
rmm_log_directory=None,
rmm_track_allocations=False,
@@ -265,6 +271,19 @@ def __init__(
n_workers = len(CUDA_VISIBLE_DEVICES)
if n_workers < 1:
raise ValueError("Number of workers cannot be less than 1.")

if rmm_allocator_external_lib_list is not None:
if isinstance(rmm_allocator_external_lib_list, str):
rmm_allocator_external_lib_list = [
v.strip() for v in rmm_allocator_external_lib_list.split(",")
]
elif not isinstance(rmm_allocator_external_lib_list, list):
raise ValueError(
"rmm_allocator_external_lib_list must be either a comma-separated "
"string or a list of strings. Examples: 'torch,cupy' "
"or ['torch', 'cupy']"
)

# Set nthreads=1 when parsing mem_limit since it only depends on n_workers
logger = logging.getLogger(__name__)
self.memory_limit = parse_memory_limit(
@@ -284,6 +303,8 @@ def __init__(
self.rmm_managed_memory = rmm_managed_memory
self.rmm_async = rmm_async
self.rmm_release_threshold = rmm_release_threshold
self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list

if rmm_pool_size is not None or rmm_managed_memory or rmm_async:
try:
import rmm # noqa F401
@@ -437,6 +458,7 @@ def new_worker_spec(self):
release_threshold=self.rmm_release_threshold,
log_directory=self.rmm_log_directory,
track_allocations=self.rmm_track_allocations,
external_lib_list=self.rmm_allocator_external_lib_list,
),
PreImport(self.pre_import),
CUDFSetup(self.enable_cudf_spill, self.cudf_spill_stats),
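
For the Python API, the same behavior is exposed through the new `rmm_allocator_external_lib_list` keyword documented above. A minimal sketch, assuming a CUDA-capable machine with this branch of `dask_cuda` plus `rmm`, `cupy`, and `torch` installed (the pool size is an arbitrary example):

```python
# Minimal sketch: each worker sets RMM as the allocator for PyTorch and CuPy,
# so their GPU allocations draw from the worker's RMM pool.
from dask.distributed import Client

from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster(
    rmm_pool_size="4GB",                                # arbitrary example size
    rmm_allocator_external_lib_list=["torch", "cupy"],  # or the string "torch,cupy"
)
client = Client(cluster)
```
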
67 changes: 67 additions & 0 deletions dask_cuda/plugins.py
@@ -1,5 +1,6 @@
import importlib
import os
from typing import Callable, Dict

from distributed import WorkerPlugin

@@ -39,6 +40,7 @@ def __init__(
release_threshold,
log_directory,
track_allocations,
external_lib_list,
):
if initial_pool_size is None and maximum_pool_size is not None:
raise ValueError(
@@ -61,6 +63,7 @@ def __init__(
self.logging = log_directory is not None
self.log_directory = log_directory
self.rmm_track_allocations = track_allocations
self.external_lib_list = external_lib_list

def setup(self, worker=None):
if self.initial_pool_size is not None:
@@ -123,6 +126,70 @@ def setup(self, worker=None):
mr = rmm.mr.get_current_device_resource()
rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))

if self.external_lib_list is not None:
for lib in self.external_lib_list:
enable_rmm_memory_for_library(lib)


def enable_rmm_memory_for_library(lib_name: str) -> None:
"""Enable RMM memory pool support for a specified third-party library.

This function allows the given library to utilize RMM's memory pool if it supports
integration with RMM. The library name is passed as a string argument, and if the
library is compatible, its memory allocator will be configured to use RMM.

Parameters
----------
lib_name : str
The name of the third-party library to enable RMM memory pool support for.
Supported libraries are "cupy" and "torch".

Raises
------
ValueError
If the library name is not supported or does not have RMM integration.
ImportError
If the required library is not installed.
"""

# Mapping of supported libraries to their respective setup functions
setup_functions: Dict[str, Callable[[], None]] = {
"torch": _setup_rmm_for_torch,
"cupy": _setup_rmm_for_cupy,
}

if lib_name not in setup_functions:
supported_libs = ", ".join(setup_functions.keys())
raise ValueError(
f"The library '{lib_name}' is not supported for RMM integration. "
f"Supported libraries are: {supported_libs}."
)

# Call the setup function for the specified library
setup_functions[lib_name]()


def _setup_rmm_for_torch() -> None:
try:
import torch
except ImportError as e:
raise ImportError("PyTorch is not installed.") from e

from rmm.allocators.torch import rmm_torch_allocator

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


def _setup_rmm_for_cupy() -> None:
try:
import cupy
except ImportError as e:
raise ImportError("CuPy is not installed.") from e

from rmm.allocators.cupy import rmm_cupy_allocator

cupy.cuda.set_allocator(rmm_cupy_allocator)


class PreImport(WorkerPlugin):
def __init__(self, libraries):
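
Outside of the worker plugin, the two private helpers reduce to the standard RMM allocator hooks. Roughly, the equivalent standalone calls in a single process look like the sketch below (assumes `rmm`, `cupy`, and `torch` are installed; the pool size is an arbitrary example):

```python
# Standalone sketch of what RMMSetup plus the helpers above arrange per worker.
import cupy
import rmm
import torch
from rmm.allocators.cupy import rmm_cupy_allocator
from rmm.allocators.torch import rmm_torch_allocator

# Configure an RMM pool on the current device (size is an arbitrary example).
rmm.reinitialize(pool_allocator=True, initial_pool_size=2**30)

# Route CuPy allocations through RMM.
cupy.cuda.set_allocator(rmm_cupy_allocator)

# Route PyTorch allocations through RMM; this must happen before PyTorch's
# first CUDA allocation, which is why the plugin runs it at worker setup.
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
```
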
11 changes: 11 additions & 0 deletions dask_cuda/utils.py
@@ -9,6 +9,7 @@
from multiprocessing import cpu_count
from typing import Optional

import click
import numpy as np
import pynvml
import toolz
@@ -764,3 +765,13 @@ def get_rmm_memory_resource_stack(mr) -> list:
if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
return mr.allocation_counts["current_bytes"]
return None


class CommaSeparatedChoice(click.Choice):
def convert(self, value, param, ctx):
values = [v.strip() for v in value.split(",")]
for v in values:
if v not in self.choices:
choices_str = ", ".join(f"'{c}'" for c in self.choices)
self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})")
return values
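
For completeness, the converter can also be exercised directly; `param` and `ctx` may be `None` outside of a real `click` invocation (illustrative snippet, not part of this PR):

```python
# Illustrative: call the converter directly, outside of a CLI invocation.
import click

from dask_cuda.utils import CommaSeparatedChoice

choice = CommaSeparatedChoice(["cupy", "torch"])

# Whitespace around entries is stripped before validation.
print(choice.convert("torch, cupy", None, None))  # ['torch', 'cupy']

# An unsupported entry triggers self.fail(), which raises a click usage error.
try:
    choice.convert("torch,numpy", None, None)
except click.UsageError as exc:
    print(exc)  # invalid choice(s): numpy. (choices are: 'cupy', 'torch')
```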