diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax
index f4420e234..9de87e017 100644
--- a/.github/container/Dockerfile.jax
+++ b/.github/container/Dockerfile.jax
@@ -85,7 +85,6 @@ ENV BUILD_DATE=${BUILD_DATE}
 ENV XLA_FLAGS=""
 ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
 ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
-ENV CUDA_DEVICE_MAX_CONNECTIONS=1
 ENV NCCL_NVLS_ENABLE=0
 
 COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh
index 90e7c2488..3b64b9e7b 100755
--- a/.github/container/test-maxtext.sh
+++ b/.github/container/test-maxtext.sh
@@ -221,7 +221,12 @@ pushd ${MAXTEXT_DIR}
 
 export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
 export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
-export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+local_arch=$(local_cuda_arch)
+if [[ "${local_arch}" == "9.0" ]]; then
+    echo "Setting CUDA_DEVICE_MAX_CONNECTIONS=1 for cc${local_arch} devices"
+    export CUDA_DEVICE_MAX_CONNECTIONS=1
+fi
 
 export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
                 --xla_gpu_enable_triton_gemm=false
diff --git a/README.md b/README.md
index 4438f7efc..e4803248c 100644
--- a/README.md
+++ b/README.md
@@ -340,7 +340,6 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
 
 | Environment Variable | Value | Explanation |
 | -------------------- | ----- | ----------- |
-| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
 | `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |
 
 There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
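
The guard in `test-maxtext.sh` calls a `local_cuda_arch` helper that is not defined in this hunk; it is presumably provided elsewhere in the container's shell utilities. A minimal sketch of what such a helper might look like, assuming `nvidia-smi` is available in the container and using its `compute_cap` query field, which reports a GPU's compute capability as a string such as `9.0` on Hopper:

```bash
# Hypothetical sketch -- the real local_cuda_arch helper lives outside this hunk.
# Query the compute capability of the first visible GPU via nvidia-smi.
local_cuda_arch() {
    nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1
}
```

With a helper along these lines, the conditional above exports `CUDA_DEVICE_MAX_CONNECTIONS=1` only on compute capability 9.0 (Hopper) devices and leaves the default connection count in place everywhere else.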