RuntimeError: CUDA error: named symbol not found #1298

Open
mikekgfb opened this issue Oct 14, 2024 · 3 comments

Labels
Compile / AOTI Issues related to AOT Inductor and torch compile

Comments

@mikekgfb
Contributor
🐛 Describe the bug

python torchchat.py generate stories110M --quant torchchat/quant_config/cuda.json --prompt "It was a dark and stormy night, and"

Using device=cuda Tesla T4
Loading model...
Time to load model: 0.73 seconds
Quantizing the model with: {'executor': {'accelerator': 'cuda'}, 'precision': {'dtype': 'bf16'}, 'linear:int4': {'groupsize': 256}}
Time to quantize model: 0.35 seconds
Traceback (most recent call last):
  File "/content/torchchat-1/torchchat.py", line 88, in <module>
    generate_main(args)
  File "/content/torchchat-1/torchchat/generate.py", line 1210, in main
    gen = Generator(
  File "/content/torchchat-1/torchchat/generate.py", line 290, in __init__
    self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
  File "/content/torchchat-1/torchchat/cli/builder.py", line 574, in _initialize_model
    quantize_model(
  File "/content/torchchat-1/torchchat/utils/quantize.py", line 114, in quantize_model
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 462, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 202, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 202, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 202, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 2 more times]
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 198, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 392, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 553, in apply_int4_weight_only_quant
    return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq)
  File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 286, in from_hp_to_intx
    layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type)
  File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 1033, in from_plain
    scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/utils.py", line 322, in pack_tinygemm_scales_and_zeros
    torch.cat(
RuntimeError: CUDA error: named symbol not found
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
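
For reference, the failing call path can likely be reproduced outside torchchat in a few lines. This is a minimal sketch (the toy nn.Linear and its sizes are illustrative), assuming the torchao 0.5.0 quantize_/int4_weight_only API shown in the traceback:

# Minimal repro sketch: mirrors the call torchchat makes in
# torchchat/utils/quantize.py line 114, with a toy module in place of stories110M.
import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only

# bf16 weights on a Tesla T4, matching the 'precision': {'dtype': 'bf16'} config
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(device="cuda", dtype=torch.bfloat16)

# Expected to raise "RuntimeError: CUDA error: named symbol not found" on a T4
quantize_(model, int4_weight_only(256))  # 256 = groupsize, as in cuda.json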

Versions

/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 79
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 0
BogoMIPS: 4399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 256 KiB (1 instance)
L3 cache: 55 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Vulnerable (Syscall hardening enabled)
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.13.0
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241002+cu121
[pip3] torchao==0.5.0
[pip3] torchaudio==2.4.1+cu121
[pip3] torchsummary==1.5.1
[pip3] torchtune==0.3.0.dev20240928+cu121
[pip3] torchvision==0.20.0.dev20241002+cu121
[conda] Could not collect

@byjlw byjlw added the Compile / AOTI Issues related to AOT Inductor and torch compile label Oct 15, 2024
@desertfire
Contributor

The stack dump suggests this is a torchao issue. For RuntimeError: CUDA error: named symbol not found, does it print out what exact symbol is missing? Can you share your installed CUDA version?
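
One way to surface that information (a sketch using standard PyTorch APIs; CUDA_LAUNCH_BLOCKING is the knob the error message itself suggests):

# Sketch: report the CUDA toolkit/device details asked about above, and make
# kernel errors synchronous so the failing call is the one in the stack trace.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before CUDA initializes

import torch
print(torch.__version__)                    # e.g. 2.6.0.dev20241002+cu121
print(torch.version.cuda)                   # CUDA version PyTorch was built against
print(torch.cuda.get_device_name(0))        # e.g. Tesla T4
print(torch.cuda.get_device_capability(0))  # e.g. (7, 5) for T4
print(torch.cuda.is_bf16_supported())       # bf16 support check (semantics vary by PyTorch version)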

@mikekgfb
Contributor Author

mikekgfb commented Oct 17, 2024

@desertfire:

The stack dump suggests this is a torchao issue. For RuntimeError: CUDA error: named symbol not found, does it print out what exact symbol is missing? Can you share your installed CUDA version?

NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2

This ran on Google Colab. Detailed traceback/repro: https://colab.research.google.com/drive/1PRneJBaS5TlJaIgc4Lwv2muiePp6T9Ss?usp=sharing

$ nvidia-smi
Thu Oct 17 19:52:30 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

$ python torchchat.py generate stories110M --quant torchchat/quant_config/cuda.json --prompt "It was a dark and stormy night, and"
2024-10-17 19:52:38.809413: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-17 19:52:39.031641: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-17 19:52:39.096690: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-17 19:52:39.466437: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-10-17 19:52:41.795530: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Downloading builder script: 100% 5.67k/5.67k [00:00<00:00, 18.9MB/s]
Downloading https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt...
Downloading https://github.com/karpathy/llama2.c/raw/master/tokenizer.model...
Moving model to /root/.torchchat/model-cache/stories110M.
Using device=cuda Tesla T4
Loading model...
Time to load model: 0.40 seconds
Quantizing the model with: {'executor': {'accelerator': 'cuda'}, 'precision': {'dtype': 'bf16'}, 'linear:int4': {'groupsize': 256}}
Time to quantize model: 0.13 seconds
Traceback (most recent call last):
  File "/content/torchchat-1/torchchat.py", line 88, in <module>
    generate_main(args)
  File "/content/torchchat-1/torchchat/generate.py", line 1215, in main
    gen = Generator(
  File "/content/torchchat-1/torchchat/generate.py", line 290, in __init__
    self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
  File "/content/torchchat-1/torchchat/cli/builder.py", line 574, in _initialize_model
    quantize_model(
  File "/content/torchchat-1/torchchat/utils/quantize.py", line 114, in quantize_model
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 462, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 202, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 202, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 202, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 2 more times]
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 198, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 392, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/quant_api.py", line 553, in apply_int4_weight_only_quant
    return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq)
  File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 286, in from_hp_to_intx
    layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type)
  File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 1033, in from_plain
    scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/utils.py", line 322, in pack_tinygemm_scales_and_zeros
    torch.cat(
RuntimeError: CUDA error: named symbol not found
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

mikekgfb added a commit to mikekgfb/torchchat-1 that referenced this issue Nov 5, 2024
Address pytorch#1298, which causes models to fail on T4 (and other pre-9.0 arch level GPUs), by selecting an alternate dtype when possible and issuing an error otherwise
@mikekgfb
Contributor Author

mikekgfb commented Nov 6, 2024

The stack dump suggests this is a torchao issue. For RuntimeError: CUDA error: named symbol not found, does it print out what exact symbol is missing? Can you share your installed CUDA version?

It's a torchao issue (see also pytorch/ao#1110), or more generally a question of whether we should expect BF16 to work on platforms that lack hardware support for it (see pytorch/pytorch#124996). On pre-9.0 architectures, PyTorch emulates BF16 (at least for some/most kernels), but torchao does not for its linear:int4 kernel.
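
The fix referenced above boils down to a capability check before committing to bf16. A minimal sketch of that idea (the helper name, the SM 8.0 threshold, and the fallback policy are illustrative, not torchchat's actual code):

# Hypothetical helper: fall back from bf16 where the GPU lacks hardware
# support, instead of failing later inside torchao's int4 packing.
import torch

def resolve_dtype(requested: torch.dtype) -> torch.dtype:
    if requested is torch.bfloat16 and torch.cuda.is_available():
        # bf16 tensor-core hardware arrived with Ampere (SM 8.0); this issue
        # reports the linear:int4 kernel failing on earlier parts such as T4.
        if torch.cuda.get_device_capability() < (8, 0):
            # Substitute fp16 so quantization stays usable; alternatively,
            # raise a clear error here rather than falling back silently.
            return torch.float16
    return requested

# Usage: dtype = resolve_dtype(torch.bfloat16)  # -> torch.float16 on a T4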
