diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index fa4be5841c..21fe7182a4 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -105,6 +105,9 @@ def setup_ort_session(model_pt_path, map_location): else ["CPUExecutionProvider"] ) + if map_location == "cuda" and "TensorrtExecutionProvider" in ort.get_available_providers(): + providers.insert(0, "TensorrtExecutionProvider") + sess_options = ort.SessionOptions() sess_options.intra_op_num_threads = psutil.cpu_count(logical=True)