Enable PyTorch to share the same memory pool as RMM via CLI #1392

Merged
Changes from 6 commits
15 changes: 14 additions & 1 deletion dask_cuda/cli.py
@@ -13,7 +13,7 @@
from distributed.utils import import_term

from .cuda_worker import CUDAWorker
from .utils import print_cluster_config
from .utils import CommaSeparatedChoice, print_cluster_config

logger = logging.getLogger(__name__)

@@ -164,6 +164,17 @@ def cuda():
incompatible with RMM pools and managed memory, trying to enable both will
result in failure.""",
)
@click.option(
"--set-rmm-allocator-for-libs",
"rmm_allocator_external_lib_list",
type=CommaSeparatedChoice(["cupy", "torch"]),
default=None,
show_default=True,
help="""
Set RMM as the allocator for external libraries. Provide a comma-separated
list of libraries to set, e.g., "torch,cupy".
Supported options are: torch, cupy.""",
)
@click.option(
"--rmm-release-threshold",
default=None,
@@ -351,6 +362,7 @@ def worker(
rmm_maximum_pool_size,
rmm_managed_memory,
rmm_async,
rmm_allocator_external_lib_list,
rmm_release_threshold,
rmm_log_directory,
rmm_track_allocations,
@@ -425,6 +437,7 @@ def worker(
rmm_maximum_pool_size,
rmm_managed_memory,
rmm_async,
rmm_allocator_external_lib_list,
rmm_release_threshold,
rmm_log_directory,
rmm_track_allocations,
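For context, a minimal invocation of the new flag might look like the following (the scheduler address is illustrative; the flag name and accepted values come from the option definition above):

dask-cuda-worker "tcp://scheduler:8786" --set-rmm-allocator-for-libs "torch,cupy"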
2 changes: 2 additions & 0 deletions dask_cuda/cuda_worker.py
@@ -47,6 +47,7 @@ def __init__(
rmm_maximum_pool_size=None,
rmm_managed_memory=False,
rmm_async=False,
rmm_allocator_external_lib_list=None,
rmm_release_threshold=None,
rmm_log_directory=None,
rmm_track_allocations=False,
@@ -231,6 +232,7 @@ def del_pid_file():
release_threshold=rmm_release_threshold,
log_directory=rmm_log_directory,
track_allocations=rmm_track_allocations,
external_lib_list=rmm_allocator_external_lib_list,
),
PreImport(pre_import),
CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats),
11 changes: 11 additions & 0 deletions dask_cuda/local_cuda_cluster.py
@@ -143,6 +143,10 @@ class LocalCUDACluster(LocalCluster):
The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
incompatible with RMM pools and managed memory. Trying to enable both will
result in an exception.
rmm_allocator_external_lib_list: list or None, default None
List of external libraries for which to set RMM as the allocator.
Supported options are: ``["torch", "cupy"]``. If None, no external
libraries will use RMM as their allocator.
rmm_release_threshold: int, str or None, default None
When ``rmm.async is True`` and the pool size grows beyond this value, unused
memory held by the pool will be released at the next synchronization point.
@@ -231,6 +235,7 @@ def __init__(
rmm_maximum_pool_size=None,
rmm_managed_memory=False,
rmm_async=False,
rmm_allocator_external_lib_list=None,
rmm_release_threshold=None,
rmm_log_directory=None,
rmm_track_allocations=False,
@@ -265,6 +270,9 @@ def __init__(
n_workers = len(CUDA_VISIBLE_DEVICES)
if n_workers < 1:
raise ValueError("Number of workers cannot be less than 1.")

if isinstance(rmm_allocator_external_lib_list, str):
    # Accept the same comma-separated form used by the CLI flag.
    rmm_allocator_external_lib_list = [
        v.strip() for v in rmm_allocator_external_lib_list.split(",")
    ]
Member

Suggested change
if isinstance(rmm_allocator_external_lib_list, str):
rmm_allocator_external_lib_list = []

Member Author

So I added a type check here. The reason is that I trust lazy users like me to pass in the same config that they do in the CLI; for example, this is how the CLI looks right now:

dask-cuda-worker "tcp://10.33.227.161:8786" --set-rmm-allocator-for-libs "torch"

With the updated behavior I complain loudly (see the example below):

cluster = LocalCUDACluster(rmm_allocator_external_lib_list="torch")
ValueError                                Traceback (most recent call last)
Cell In[2], line 1
----> 1 cluster = LocalCUDACluster(rmm_allocator_external_lib_list="torch")

File ~/dask-cuda/dask_cuda/local_cuda_cluster.py:275, in LocalCUDACluster.__init__(self, CUDA_VISIBLE_DEVICES, n_workers, threads_per_worker, memory_limit, device_memory_limit, enable_cudf_spill, cudf_spill_stats, data, local_directory, shared_filesystem, protocol, enable_tcp_over_ucx, enable_infiniband, enable_nvlink, enable_rdmacm, rmm_pool_size, rmm_maximum_pool_size, rmm_managed_memory, rmm_async, rmm_allocator_external_lib_list, rmm_release_threshold, rmm_log_directory, rmm_track_allocations, jit_unspill, log_spilling, worker_class, pre_import, **kwargs)
    272     raise ValueError("Number of workers cannot be less than 1.")
    274 if rmm_allocator_external_lib_list is not None and not isinstance(rmm_allocator_external_lib_list, list):
--> 275     raise ValueError(
    276         "rmm_allocator_external_lib_list must be a list of strings. "
    277         "Valid examples: ['torch'], ['cupy'], or ['torch', 'cupy']. "
    278         f"Received: {type(rmm_allocator_external_lib_list)} "
    279         f"with value: {rmm_allocator_external_lib_list}"
    280     )
    282 # Set nthreads=1 when parsing mem_limit since it only depends on n_workers
    283 logger = logging.getLogger(__name__)

ValueError: rmm_allocator_external_lib_list must be a list of strings. Valid examples: ['torch'], ['cupy'], or ['torch', 'cupy']. Received: <class 'str'> with value: torch

Member

Makes sense, but in that case I think the amount of work/code needed to support a string is relatively similar; instead of raising the exception, should we just support passing a comma-separated string as well then?

Member Author

Added it here: 2517874.

Let me know if you want me to change anything. Thanks for the suggestion, I agree it made sense.

# Set nthreads=1 when parsing mem_limit since it only depends on n_workers
logger = logging.getLogger(__name__)
self.memory_limit = parse_memory_limit(
@@ -284,6 +292,8 @@ def __init__(
self.rmm_managed_memory = rmm_managed_memory
self.rmm_async = rmm_async
self.rmm_release_threshold = rmm_release_threshold
self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list

if rmm_pool_size is not None or rmm_managed_memory or rmm_async:
try:
import rmm # noqa F401
@@ -437,6 +447,7 @@ def new_worker_spec(self):
release_threshold=self.rmm_release_threshold,
log_directory=self.rmm_log_directory,
track_allocations=self.rmm_track_allocations,
external_lib_list=self.rmm_allocator_external_lib_list,
),
PreImport(self.pre_import),
CUDFSetup(self.enable_cudf_spill, self.cudf_spill_stats),
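For the Python API, a minimal sketch of how the new keyword can be passed (the pool size and library choice are illustrative, not part of this diff):

from distributed import Client

from dask_cuda import LocalCUDACluster

# Start a local GPU cluster whose workers set RMM as the allocator
# for both PyTorch and CuPy.
cluster = LocalCUDACluster(
    rmm_pool_size="4GB",  # illustrative pool size
    rmm_allocator_external_lib_list=["torch", "cupy"],
)
client = Client(cluster)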
12 changes: 11 additions & 1 deletion dask_cuda/plugins.py
@@ -3,7 +3,11 @@

from distributed import WorkerPlugin

from .utils import get_rmm_log_file_name, parse_device_memory_limit
from .utils import (
enable_rmm_memory_for_library,
get_rmm_log_file_name,
parse_device_memory_limit,
)


class CPUAffinity(WorkerPlugin):
@@ -39,6 +43,7 @@ def __init__(
release_threshold,
log_directory,
track_allocations,
external_lib_list,
):
if initial_pool_size is None and maximum_pool_size is not None:
raise ValueError(
@@ -61,6 +66,7 @@ def __init__(
self.logging = log_directory is not None
self.log_directory = log_directory
self.rmm_track_allocations = track_allocations
self.external_lib_list = external_lib_list

def setup(self, worker=None):
if self.initial_pool_size is not None:
@@ -123,6 +129,10 @@ def setup(self, worker=None):
mr = rmm.mr.get_current_device_resource()
rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))

if self.external_lib_list is not None:
for lib in self.external_lib_list:
enable_rmm_memory_for_library(lib)


class PreImport(WorkerPlugin):
def __init__(self, libraries):
73 changes: 72 additions & 1 deletion dask_cuda/utils.py
@@ -7,8 +7,9 @@
from contextlib import suppress
from functools import singledispatch
from multiprocessing import cpu_count
from typing import Optional
from typing import Callable, Dict, Optional

import click
import numpy as np
import pynvml
import toolz
@@ -764,3 +765,73 @@ def get_rmm_memory_resource_stack(mr) -> list:
if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
return mr.allocation_counts["current_bytes"]
return None


def enable_rmm_memory_for_library(lib_name: str) -> None:
"""
Enable RMM memory pool support for a specified third-party library.

This function allows the given library to utilize RMM's memory pool if it supports
integration with RMM. The library name is passed as a string argument, and if the
library is compatible, its memory allocator will be configured to use RMM.

Parameters
----------
lib_name : str
The name of the third-party library to enable RMM memory pool support for.

Raises
------
ValueError
If the library name is not supported or does not have RMM integration.
ImportError
If the required library is not installed.
"""

# Mapping of supported libraries to their respective setup functions
setup_functions: Dict[str, Callable[[], None]] = {
"torch": _setup_rmm_for_torch,
"cupy": _setup_rmm_for_cupy,
}

if lib_name not in setup_functions:
supported_libs = ", ".join(setup_functions.keys())
raise ValueError(
f"The library '{lib_name}' is not supported for RMM integration. "
f"Supported libraries are: {supported_libs}."
)

# Call the setup function for the specified library
setup_functions[lib_name]()


def _setup_rmm_for_torch() -> None:
try:
import torch
except ImportError as e:
raise ImportError("PyTorch is not installed.") from e

from rmm.allocators.torch import rmm_torch_allocator

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


def _setup_rmm_for_cupy() -> None:
try:
import cupy
except ImportError as e:
raise ImportError("CuPy is not installed.") from e

from rmm.allocators.cupy import rmm_cupy_allocator

cupy.cuda.set_allocator(rmm_cupy_allocator)


class CommaSeparatedChoice(click.Choice):
def convert(self, value, param, ctx):
values = [v.strip() for v in value.split(",")]
for v in values:
if v not in self.choices:
choices_str = ", ".join(f"'{c}'" for c in self.choices)
self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})")
return values
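
As a rough illustration of the two helpers added above (the --libs option below is hypothetical and exists only to demonstrate parsing; only CommaSeparatedChoice and enable_rmm_memory_for_library come from this diff):

import click

from dask_cuda.utils import CommaSeparatedChoice, enable_rmm_memory_for_library

@click.command()
@click.option(
    "--libs",  # hypothetical option used only for this demo
    type=CommaSeparatedChoice(["cupy", "torch"]),
    default=None,
)
def demo(libs):
    # "torch,cupy" on the command line is parsed into ['torch', 'cupy'];
    # each entry is then routed to the matching RMM setup function.
    if libs is not None:
        for lib in libs:
            enable_rmm_memory_for_library(lib)

if __name__ == "__main__":
    demo()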