From 068031d3458dedfd34829dcbf9e7e04321487863 Mon Sep 17 00:00:00 2001 From: yhna Date: Sun, 13 Oct 2024 01:44:36 +0900 Subject: [PATCH 1/2] Support trt executor for ort --- ts/torch_handler/base_handler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index fa4be5841c..bff585f651 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -105,6 +105,9 @@ def setup_ort_session(model_pt_path, map_location): else ["CPUExecutionProvider"] ) + if "TensorrtExecutionProvider" in ort.get_available_providers(): + providers.insert(0, "TensorrtExecutionProvider") + sess_options = ort.SessionOptions() sess_options.intra_op_num_threads = psutil.cpu_count(logical=True) From 21e8cb6750d1cecdd55960bee45f1de22b5075bd Mon Sep 17 00:00:00 2001 From: yhna940 Date: Mon, 14 Oct 2024 23:23:26 +0900 Subject: [PATCH 2/2] Add map location contraints --- ts/torch_handler/base_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index bff585f651..21fe7182a4 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -105,7 +105,7 @@ def setup_ort_session(model_pt_path, map_location): else ["CPUExecutionProvider"] ) - if "TensorrtExecutionProvider" in ort.get_available_providers(): + if map_location == "cuda" and "TensorrtExecutionProvider" in ort.get_available_providers(): providers.insert(0, "TensorrtExecutionProvider") sess_options = ort.SessionOptions()