Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only set CUDA_DEVICE_MAX_CONNECTIONS=1 for Hopper/cc9.0 runs #1236

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ ENV BUILD_DATE=${BUILD_DATE}
ENV XLA_FLAGS=""
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_NVLS_ENABLE=0

COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
Expand Down
7 changes: 6 additions & 1 deletion .github/container/test-maxtext.sh
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,12 @@ pushd ${MAXTEXT_DIR}

export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

local_arch=$(local_cuda_arch)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't aware that local_cuda_arch is installed in /usr/local/bin.

if [[ "${local_arch}" == "9.0" ]]; then
echo "Setting CUDA_DEVICE_MAX_CONNECTIONS=1 for cc${local_arch} devices"
export CUDA_DEVICE_MAX_CONNECTIONS=1
fi

export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb

| Environment Variable | Value | Explanation |
| -------------------- | ----- | ----------- |
| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
| `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |

There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
Expand Down
Loading