Skip to content

Commit

Permalink
mmrotate sdk module (open-mmlab#450)
Browse files Browse the repository at this point in the history
* support mmrotate

* fix name

* windows default link to cudart_static.lib, which is not compatible with static build && python_api

* python api

* fix ci

* fix type & remove unused meta info

* fix doxygen, add [out] to @param

* fix mmrotate-c-api

* refactor naming

* refactor naming

* fix lint

* fix lint

* move replace_RResize -> get_preprocess

* Update cuda.cmake

On windows, make static lib and python api build success.

* fix ptr

* Use unique ptr to prevent memory leaks

* move unique_ptr

* remove deleter

Co-authored-by: chenxin2 <chenxin2@sensetime.com>
Co-authored-by: cx <cx@ubuntu20.04>
  • Loading branch information
3 people authored May 17, 2022
1 parent 1a8d7ac commit 0ce7c83
Show file tree
Hide file tree
Showing 18 changed files with 631 additions and 6 deletions.
5 changes: 5 additions & 0 deletions cmake/cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ if (${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.18.0")
cmake_policy(SET CMP0104 OLD)
endif ()

if (MSVC)
# use shared, on windows, python api can't build with static lib.
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
endif ()

# nvcc compiler settings
find_package(CUDA REQUIRED)
#message(STATUS "CUDA VERSION: ${CUDA_VERSION_STRING}")
Expand Down
8 changes: 8 additions & 0 deletions configs/mmrotate/rotated-detection_sdk_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
_base_ = ['./rotated-detection_static.py', '../_base_/backends/sdk.py']

codebase_config = dict(model_type='sdk')

backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'),
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
])
6 changes: 5 additions & 1 deletion csrc/apis/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ project(capis)
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)

if ("all" IN_LIST MMDEPLOY_CODEBASES)
set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;pose_detector;restorer;model")
set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;"
"pose_detector;restorer;model;rotated_detector")
else ()
set(TASK_LIST "model")
if ("mmcls" IN_LIST MMDEPLOY_CODEBASES)
Expand All @@ -27,6 +28,9 @@ else ()
if ("mmpose" IN_LIST MMDEPLOY_CODEBASES)
list(APPEND TASK_LIST "pose_detector")
endif ()
if ("mmrotate" IN_LIST MMDEPLOY_CODEBASES)
list(APPEND TASK_LIST "rotated_detector")
endif()
endif ()

foreach (TASK ${TASK_LIST})
Expand Down
2 changes: 1 addition & 1 deletion csrc/apis/c/detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ MMDEPLOY_API int mmdeploy_detector_create_by_path(const char* model_path, const
* @param[in] mat_count number of images in the batch
* @param[out] results a linear buffer to save detection results of each image. It must be released
* by \ref mmdeploy_detector_release_result
* @param result_count a linear buffer with length being \p mat_count to save the number of
* @param[out] result_count a linear buffer with length being \p mat_count to save the number of
* detection results of each image. And it must be released by \ref
* mmdeploy_detector_release_result
* @return status of inference
Expand Down
142 changes: 142 additions & 0 deletions csrc/apis/c/rotated_detector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "rotated_detector.h"

#include <numeric>

#include "codebase/mmrotate/mmrotate.h"
#include "core/device.h"
#include "core/graph.h"
#include "core/mat.h"
#include "core/utils/formatter.h"
#include "handle.h"

using namespace std;
using namespace mmdeploy;

namespace {

Value& config_template() {
// clang-format off
static Value v{
{
"pipeline", {
{"input", {"image"}},
{"output", {"det"}},
{
"tasks",{
{
{"name", "mmrotate"},
{"type", "Inference"},
{"params", {{"model", "TBD"}}},
{"input", {"image"}},
{"output", {"det"}}
}
}
}
}
}
};
// clang-format on
return v;
}

template <class ModelType>
int mmdeploy_rotated_detector_create_impl(ModelType&& m, const char* device_name, int device_id,
mm_handle_t* handle) {
try {
auto value = config_template();
value["pipeline"]["tasks"][0]["params"]["model"] = std::forward<ModelType>(m);

auto pose_estimator = std::make_unique<Handle>(device_name, device_id, std::move(value));

*handle = pose_estimator.release();
return MM_SUCCESS;

} catch (const std::exception& e) {
MMDEPLOY_ERROR("exception caught: {}", e.what());
} catch (...) {
MMDEPLOY_ERROR("unknown exception caught");
}
return MM_E_FAIL;
}

} // namespace

int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name, int device_id,
mm_handle_t* handle) {
return mmdeploy_rotated_detector_create_impl(*static_cast<Model*>(model), device_name, device_id,
handle);
}

int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* device_name,
int device_id, mm_handle_t* handle) {
return mmdeploy_rotated_detector_create_impl(model_path, device_name, device_id, handle);
}

int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats, int mat_count,
mm_rotated_detect_t** results, int** result_count) {
if (handle == nullptr || mats == nullptr || mat_count == 0 || results == nullptr ||
result_count == nullptr) {
return MM_E_INVALID_ARG;
}

try {
auto detector = static_cast<Handle*>(handle);

Value input{Value::kArray};
for (int i = 0; i < mat_count; ++i) {
mmdeploy::Mat _mat{mats[i].height, mats[i].width, PixelFormat(mats[i].format),
DataType(mats[i].type), mats[i].data, Device{"cpu"}};
input.front().push_back({{"ori_img", _mat}});
}

auto output = detector->Run(std::move(input)).value().front();
auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(output);

vector<int> _result_count;
_result_count.reserve(mat_count);
for (const auto& det_output : detector_outputs) {
_result_count.push_back((int)det_output.detections.size());
}

auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0);

std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
std::copy(_result_count.begin(), _result_count.end(), result_count_data.get());

std::unique_ptr<mm_rotated_detect_t[]> result_data(new mm_rotated_detect_t[total]{});
auto result_ptr = result_data.get();

for (const auto& det_output : detector_outputs) {
for (const auto& detection : det_output.detections) {
result_ptr->label_id = detection.label_id;
result_ptr->score = detection.score;
const auto& rbbox = detection.rbbox;
for (int i = 0; i < 5; i++) {
result_ptr->rbbox[i] = rbbox[i];
}
++result_ptr;
}
}

*result_count = result_count_data.release();
*results = result_data.release();

return MM_SUCCESS;

} catch (const std::exception& e) {
MMDEPLOY_ERROR("exception caught: {}", e.what());
} catch (...) {
MMDEPLOY_ERROR("unknown exception caught");
}
return MM_E_FAIL;
}

void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results,
const int* result_count) {
delete[] results;
delete[] result_count;
}

void mmdeploy_rotated_detector_destroy(mm_handle_t handle) { delete static_cast<Handle*>(handle); }
82 changes: 82 additions & 0 deletions csrc/apis/c/rotated_detector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) OpenMMLab. All rights reserved.

/**
* @file rotated_detector.h
* @brief Interface to MMRotate task
*/

#ifndef MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_
#define MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_

#include "common.h"

#ifdef __cplusplus
extern "C" {
#endif

typedef struct mm_rotated_detect_t {
int label_id;
float score;
float rbbox[5]; // cx, cy, w, h, angle
} mm_rotated_detect_t;

/**
* @brief Create rotated detector's handle
* @param[in] model an instance of mmrotate sdk model created by
* \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
* @param[in] device_name name of device, such as "cpu", "cuda", etc.
* @param[in] device_id id of device.
* @param[out] handle instance of a rotated detector
* @return status of creating rotated detector's handle
*/
MMDEPLOY_API int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name,
int device_id, mm_handle_t* handle);

/**
* @brief Create rotated detector's handle
* @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter
* @param[in] device_name name of device, such as "cpu", "cuda", etc.
* @param[in] device_id id of device.
* @param[out] handle instance of a rotated detector
* @return status of creating rotated detector's handle
*/
MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path,
const char* device_name, int device_id,
mm_handle_t* handle);

/**
* @brief Apply rotated detector to batch images and get their inference results
* @param[in] handle rotated detector's handle created by \ref
* mmdeploy_rotated_detector_create_by_path
* @param[in] mats a batch of images
* @param[in] mat_count number of images in the batch
* @param[out] results a linear buffer to save detection results of each image. It must be released
* by \ref mmdeploy_rotated_detector_release_result
* @param[out] result_count a linear buffer with length being \p mat_count to save the number of
* detection results of each image. And it must be released by \ref
* mmdeploy_rotated_detector_release_result
* @return status of inference
*/
MMDEPLOY_API int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats,
int mat_count, mm_rotated_detect_t** results,
int** result_count);

/** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply
* @param[in] results rotated detection results buffer
* @param[in] result_count \p results size buffer
*/
MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results,
const int* result_count);

/**
* @brief Destroy rotated detector's handle
* @param[in] handle rotated detector's handle created by \ref
* mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create
*/
MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mm_handle_t handle);

#ifdef __cplusplus
}
#endif

#endif // MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_
1 change: 1 addition & 0 deletions csrc/apis/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mmdeploy_python_add_module(text_detector)
mmdeploy_python_add_module(text_recognizer)
mmdeploy_python_add_module(restorer)
mmdeploy_python_add_module(pose_detector)
mmdeploy_python_add_module(rotated_detector)

pybind11_add_module(${PROJECT_NAME} ${MMDEPLOY_PYTHON_SRCS})

Expand Down
83 changes: 83 additions & 0 deletions csrc/apis/python/rotated_detector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "rotated_detector.h"

#include "common.h"
#include "core/logger.h"

namespace mmdeploy {

class PyRotatedDetector {
public:
PyRotatedDetector(const char *model_path, const char *device_name, int device_id) {
MMDEPLOY_INFO("{}, {}, {}", model_path, device_name, device_id);
auto status =
mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &handle_);
if (status != MM_SUCCESS) {
throw std::runtime_error("failed to create rotated detector");
}
}
py::list Apply(const std::vector<PyImage> &imgs) {
std::vector<mm_mat_t> mats;
mats.reserve(imgs.size());
for (const auto &img : imgs) {
auto mat = GetMat(img);
mats.push_back(mat);
}

mm_rotated_detect_t *rbboxes{};
int *res_count{};
auto status = mmdeploy_rotated_detector_apply(handle_, mats.data(), (int)mats.size(), &rbboxes,
&res_count);
if (status != MM_SUCCESS) {
throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status));
}
auto output = py::list{};
auto result = rbboxes;
auto counts = res_count;
for (int i = 0; i < mats.size(); i++) {
auto _dets = py::array_t<float>({*counts, 6});
auto _labels = py::array_t<int>({*counts});
auto dets = _dets.mutable_data();
auto labels = _labels.mutable_data();
for (int j = 0; j < *counts; j++) {
for (int k = 0; k < 5; k++) {
*dets++ = result->rbbox[k];
}
*dets++ = result->score;
*labels++ = result->label_id;
result++;
}
counts++;
output.append(py::make_tuple(std::move(_dets), std::move(_labels)));
}
mmdeploy_rotated_detector_release_result(rbboxes, res_count);
return output;
}
~PyRotatedDetector() {
mmdeploy_rotated_detector_destroy(handle_);
handle_ = {};
}

private:
mm_handle_t handle_{};
};

static void register_python_rotated_detector(py::module &m) {
py::class_<PyRotatedDetector>(m, "RotatedDetector")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id);
}))
.def("__call__", &PyRotatedDetector::Apply);
}

class PythonRotatedDetectorRegisterer {
public:
PythonRotatedDetectorRegisterer() {
gPythonBindings().emplace("rotated_detector", register_python_rotated_detector);
}
};

static PythonRotatedDetectorRegisterer python_rotated_detector_registerer;

} // namespace mmdeploy
1 change: 1 addition & 0 deletions csrc/codebase/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ if ("all" IN_LIST MMDEPLOY_CODEBASES)
list(APPEND CODEBASES "mmocr")
list(APPEND CODEBASES "mmedit")
list(APPEND CODEBASES "mmpose")
list(APPEND CODEBASES "mmrotate")
else ()
set(CODEBASES ${MMDEPLOY_CODEBASES})
endif ()
Expand Down
11 changes: 11 additions & 0 deletions csrc/codebase/mmrotate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.14)
project(mmdeploy_mmrotate)

include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)

file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils)
add_library(mmdeploy::mmrotate ALIAS ${PROJECT_NAME})
Loading

0 comments on commit 0ce7c83

Please sign in to comment.