From ccff1f93672a354d0d5bb3b5efc38f3bba31e38f Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Fri, 9 Oct 2020 06:18:59 +0000 Subject: [PATCH] add clas cpp inference --- .clang_format.hook | 15 ++ .pre-commit-config.yaml | 70 ++++---- deploy/cpp_infer/CMakeLists.txt | 201 +++++++++++++++++++++++ deploy/cpp_infer/include/cls.h | 86 ++++++++++ deploy/cpp_infer/include/config.h | 82 +++++++++ deploy/cpp_infer/include/preprocess_op.h | 58 +++++++ deploy/cpp_infer/include/utility.h | 46 ++++++ deploy/cpp_infer/src/cls.cpp | 101 ++++++++++++ deploy/cpp_infer/src/config.cpp | 64 ++++++++ deploy/cpp_infer/src/main.cpp | 68 ++++++++ deploy/cpp_infer/src/preprocess_op.cpp | 92 +++++++++++ deploy/cpp_infer/src/utility.cpp | 39 +++++ deploy/cpp_infer/tools/build.sh | 20 +++ deploy/cpp_infer/tools/config.txt | 12 ++ deploy/cpp_infer/tools/run.sh | 2 + 15 files changed, 921 insertions(+), 35 deletions(-) create mode 100644 .clang_format.hook create mode 100755 deploy/cpp_infer/CMakeLists.txt create mode 100644 deploy/cpp_infer/include/cls.h create mode 100644 deploy/cpp_infer/include/config.h create mode 100644 deploy/cpp_infer/include/preprocess_op.h create mode 100644 deploy/cpp_infer/include/utility.h create mode 100644 deploy/cpp_infer/src/cls.cpp create mode 100755 deploy/cpp_infer/src/config.cpp create mode 100644 deploy/cpp_infer/src/main.cpp create mode 100644 deploy/cpp_infer/src/preprocess_op.cpp create mode 100644 deploy/cpp_infer/src/utility.cpp create mode 100755 deploy/cpp_infer/tools/build.sh create mode 100755 deploy/cpp_infer/tools/config.txt create mode 100755 deploy/cpp_infer/tools/run.sh diff --git a/.clang_format.hook b/.clang_format.hook new file mode 100644 index 00000000..1d928216 --- /dev/null +++ b/.clang_format.hook @@ -0,0 +1,15 @@ +#!/bin/bash +set -e + +readonly VERSION="3.8" + +version=$(clang-format -version) + +if ! [[ $version == *"$VERSION"* ]]; then + echo "clang-format version check failed." + echo "a version contains '$VERSION' is needed, but get '$version'" + echo "you can install the right version, and make an soft-link to '\$PATH' env" + exit -1 +fi + +clang-format $@ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15d8e516..1584bc76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,35 +1,35 @@ -- repo: https://github.com/PaddlePaddle/mirrors-yapf.git - sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 - hooks: - - id: yapf - files: \.py$ - -- repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v1.5 - hooks: - - id: autopep8 - -- repo: https://github.com/Lucas-C/pre-commit-hooks - sha: v1.0.1 - hooks: - - id: forbid-crlf - files: \.(md|yml)$ - - id: remove-crlf - files: \.(md|yml)$ - - id: forbid-tabs - files: \.(md|yml)$ - - id: remove-tabs - files: \.(md|yml)$ - -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.5.0 - hooks: - - id: check-yaml - - id: check-merge-conflict - - id: detect-private-key - files: (?!.*paddle)^.*$ - - id: end-of-file-fixer - files: \.(md|yml)$ - - id: trailing-whitespace - files: \.(md|yml)$ - - id: check-case-conflict +- repo: https://github.com/PaddlePaddle/mirrors-yapf.git + sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 + hooks: + - id: yapf + files: \.py$ +- repo: https://github.com/pre-commit/pre-commit-hooks + sha: a11d9314b22d8f8c7556443875b731ef05965464 + hooks: + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + files: (?!.*paddle)^.*$ + - id: end-of-file-fixer + files: \.md$ + - id: trailing-whitespace + files: \.md$ +- repo: https://github.com/Lucas-C/pre-commit-hooks + sha: v1.0.1 + hooks: + - id: forbid-crlf + files: \.md$ + - id: remove-crlf + files: \.md$ + - id: forbid-tabs + files: \.md$ + - id: remove-tabs + files: \.md$ +- repo: local + hooks: + - id: clang-format + name: clang-format + description: Format files with ClangFormat + entry: bash .clang_format.hook -i + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ diff --git a/deploy/cpp_infer/CMakeLists.txt b/deploy/cpp_infer/CMakeLists.txt new file mode 100755 index 00000000..3f51d654 --- /dev/null +++ b/deploy/cpp_infer/CMakeLists.txt @@ -0,0 +1,201 @@ +project(clas_system CXX C) + +option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) +option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) +option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) +option(WITH_TENSORRT "Compile demo with TensorRT." OFF) + +SET(PADDLE_LIB "" CACHE PATH "Location of libraries") +SET(OPENCV_DIR "" CACHE PATH "Location of libraries") +SET(CUDA_LIB "" CACHE PATH "Location of libraries") +SET(CUDNN_LIB "" CACHE PATH "Location of libraries") +SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT") + +set(DEMO_NAME "clas_system") + + +macro(safe_set_static_flag) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/MD") + endforeach(flag_var) +endmacro() + +if (WITH_MKL) + ADD_DEFINITIONS(-DUSE_MKL) +endif() + +if(NOT DEFINED PADDLE_LIB) + message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") +endif() + +if(NOT DEFINED OPENCV_DIR) + message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv") +endif() + + +if (WIN32) + include_directories("${PADDLE_LIB}/paddle/fluid/inference") + include_directories("${PADDLE_LIB}/paddle/include") + link_directories("${PADDLE_LIB}/paddle/fluid/inference") + find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH) + +else () + find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/share/OpenCV NO_DEFAULT_PATH) + include_directories("${PADDLE_LIB}/paddle/include") + link_directories("${PADDLE_LIB}/paddle/lib") +endif () +include_directories(${OpenCV_INCLUDE_DIRS}) + +if (WIN32) + add_definitions("/DGOOGLE_GLOG_DLL_DECL=") + set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd") + set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT") + if (WITH_STATIC_LIB) + safe_set_static_flag() + add_definitions(-DSTATIC_LIB) + endif() +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o3 -std=c++11") + set(CMAKE_STATIC_LIBRARY_PREFIX "") +endif() +message("flags" ${CMAKE_CXX_FLAGS}) + + +if (WITH_GPU) + if (NOT DEFINED CUDA_LIB OR ${CUDA_LIB} STREQUAL "") + message(FATAL_ERROR "please set CUDA_LIB with -DCUDA_LIB=/path/cuda-8.0/lib64") + endif() + if (NOT WIN32) + if (NOT DEFINED CUDNN_LIB) + message(FATAL_ERROR "please set CUDNN_LIB with -DCUDNN_LIB=/path/cudnn_v7.4/cuda/lib64") + endif() + endif(NOT WIN32) +endif() + +include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") +include_directories("${PADDLE_LIB}/third_party/install/glog/include") +include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") +include_directories("${PADDLE_LIB}/third_party/install/zlib/include") +include_directories("${PADDLE_LIB}/third_party/boost") +include_directories("${PADDLE_LIB}/third_party/eigen3") + +include_directories("${CMAKE_SOURCE_DIR}/") + +if (NOT WIN32) + if (WITH_TENSORRT AND WITH_GPU) + include_directories("${TENSORRT_DIR}/include") + link_directories("${TENSORRT_DIR}/lib") + endif() +endif(NOT WIN32) + +link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") + +link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") +link_directories("${PADDLE_LIB}/third_party/install/glog/lib") +link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") +link_directories("${PADDLE_LIB}/paddle/lib") + + +if(WITH_MKL) + include_directories("${PADDLE_LIB}/third_party/install/mklml/include") + if (WIN32) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.lib + ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md.lib) + else () + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) + execute_process(COMMAND cp -r ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} /usr/lib) + endif () + set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") + if(EXISTS ${MKLDNN_PATH}) + include_directories("${MKLDNN_PATH}/include") + if (WIN32) + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib) + else () + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) + endif () + endif() +else() + set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) +endif() + +# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a +if(WITH_STATIC_LIB) + set(DEPS + ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) +else() + set(DEPS + ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) +endif() + +if (NOT WIN32) + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags protobuf z xxhash + ) + if(EXISTS "${PADDLE_LIB}/third_party/install/snappystream/lib") + set(DEPS ${DEPS} snappystream) + endif() + if (EXISTS "${PADDLE_LIB}/third_party/install/snappy/lib") + set(DEPS ${DEPS} snappy) + endif() +else() + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags_static libprotobuf xxhash) + set(DEPS ${DEPS} libcmt shlwapi) + if (EXISTS "${PADDLE_LIB}/third_party/install/snappy/lib") + set(DEPS ${DEPS} snappy) + endif() + if(EXISTS "${PADDLE_LIB}/third_party/install/snappystream/lib") + set(DEPS ${DEPS} snappystream) + endif() +endif(NOT WIN32) + + +if(WITH_GPU) + if(NOT WIN32) + if (WITH_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() + set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX}) + else() + set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDNN_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() +endif() + + +if (NOT WIN32) + set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread") + set(DEPS ${DEPS} ${EXTERNAL_LIB}) +endif() + +set(DEPS ${DEPS} ${OpenCV_LIBS}) + +AUX_SOURCE_DIRECTORY(./src SRCS) +add_executable(${DEMO_NAME} ${SRCS}) + +target_link_libraries(${DEMO_NAME} ${DEPS}) + +if (WIN32 AND WITH_MKL) + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.dll ./mklml.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md.dll ./libiomp5md.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.dll ./mkldnn.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.dll ./release/mklml.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md.dll ./release/libiomp5md.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.dll ./release/mkldnn.dll + ) +endif() diff --git a/deploy/cpp_infer/include/cls.h b/deploy/cpp_infer/include/cls.h new file mode 100644 index 00000000..3a65f729 --- /dev/null +++ b/deploy/cpp_infer/include/cls.h @@ -0,0 +1,86 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace PaddleClas { + +class Classifier { +public: + explicit Classifier(const std::string &model_dir, const bool &use_gpu, + const int &gpu_id, const int &gpu_mem, + const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const bool &use_zero_copy_run, + const int &resize_short_size, const int &crop_size) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + this->use_zero_copy_run_ = use_zero_copy_run; + + this->resize_short_size_ = resize_short_size; + this->crop_size_ = crop_size; + + LoadModel(model_dir); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir); + + // Run predictor + void Run(cv::Mat &img); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + bool use_zero_copy_run_ = false; + + std::vector mean_ = {0.485f, 0.456f, 0.406f}; + std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + bool is_scale_ = true; + + int resize_short_size_ = 256; + int crop_size_ = 224; + + // pre-process + ResizeImg resize_op_; + Normalize normalize_op_; + Permute permute_op_; + CenterCropImg crop_op_; +}; + +} // namespace PaddleClas \ No newline at end of file diff --git a/deploy/cpp_infer/include/config.h b/deploy/cpp_infer/include/config.h new file mode 100644 index 00000000..57788866 --- /dev/null +++ b/deploy/cpp_infer/include/config.h @@ -0,0 +1,82 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "include/utility.h" + +namespace PaddleClas { + +class Config { +public: + explicit Config(const std::string &config_file) { + config_map_ = LoadConfig(config_file); + + this->use_gpu = bool(stoi(config_map_["use_gpu"])); + + this->gpu_id = stoi(config_map_["gpu_id"]); + + this->gpu_mem = stoi(config_map_["gpu_mem"]); + + this->cpu_math_library_num_threads = + stoi(config_map_["cpu_math_library_num_threads"]); + + this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"])); + + this->use_zero_copy_run = bool(stoi(config_map_["use_zero_copy_run"])); + + this->cls_model_dir.assign(config_map_["cls_model_dir"]); + + this->resize_short_size = stoi(config_map_["resize_short_size"]); + + this->crop_size = stoi(config_map_["crop_size"]); + } + + bool use_gpu = false; + + int gpu_id = 0; + + int gpu_mem = 4000; + + int cpu_math_library_num_threads = 1; + + bool use_mkldnn = false; + + bool use_zero_copy_run = false; + + std::string cls_model_dir; + + int resize_short_size = 256; + int crop_size = 224; + + void PrintConfigInfo(); + +private: + // Load configuration + std::map LoadConfig(const std::string &config_file); + + std::vector split(const std::string &str, + const std::string &delim); + + std::map config_map_; +}; + +} // namespace PaddleClas diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h new file mode 100644 index 00000000..adc73d64 --- /dev/null +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -0,0 +1,58 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace std; +using namespace paddle; + +namespace PaddleClas { + +class Normalize { +public: + virtual void Run(cv::Mat *im, const std::vector &mean, + const std::vector &scale, const bool is_scale = true); +}; + +// RGB -> CHW +class Permute { +public: + virtual void Run(const cv::Mat *im, float *data); +}; + +class CenterCropImg { +public: + virtual void Run(cv::Mat &im, const int crop_size = 224); +}; + +class ResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, + float &ratio_h, float &ratio_w); +}; + +} // namespace PaddleClas \ No newline at end of file diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h new file mode 100644 index 00000000..8dc15247 --- /dev/null +++ b/deploy/cpp_infer/include/utility.h @@ -0,0 +1,46 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" + +namespace PaddleClas { + +class Utility { +public: + static std::vector ReadDict(const std::string &path); + + // template + // inline static size_t argmax(ForwardIterator first, ForwardIterator last) + // { + // return std::distance(first, std::max_element(first, last)); + // } +}; + +} // namespace PaddleClas \ No newline at end of file diff --git a/deploy/cpp_infer/src/cls.cpp b/deploy/cpp_infer/src/cls.cpp new file mode 100644 index 00000000..60f92e25 --- /dev/null +++ b/deploy/cpp_infer/src/cls.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace PaddleClas { + +void Classifier::LoadModel(const std::string &model_dir) { + AnalysisConfig config; + config.SetModel(model_dir + "/model", model_dir + "/params"); + + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + // cache 10 different shapes for mkldnn to avoid memory leak + config.SetMkldnnCacheCapacity(10); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + // false for zero copy tensor + // true for commom tensor + config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = CreatePaddlePredictor(config); +} + +void Classifier::Run(cv::Mat &img) { + float ratio_h{}; + float ratio_w{}; + + cv::Mat srcimg; + cv::Mat resize_img; + img.copyTo(srcimg); + + this->resize_op_.Run(img, resize_img, this->resize_short_size_, ratio_h, + ratio_w); + this->crop_op_.Run(resize_img, this->crop_size_); + + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); + this->permute_op_.Run(&resize_img, input.data()); + + // Inference. + if (this->use_zero_copy_run_) { + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputTensor(input_names[0]); + input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + input_t->copy_from_cpu(input.data()); + this->predictor_->ZeroCopyRun(); + } else { + paddle::PaddleTensor input_t; + input_t.shape = {1, 3, resize_img.rows, resize_img.cols}; + input_t.data = + paddle::PaddleBuf(input.data(), input.size() * sizeof(float)); + input_t.dtype = PaddleDType::FLOAT32; + std::vector outputs; + this->predictor_->Run({input_t}, &outputs, 1); + } + + std::vector out_data; + auto output_names = this->predictor_->GetOutputNames(); + auto output_t = this->predictor_->GetOutputTensor(output_names[0]); + std::vector output_shape = output_t->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + + out_data.resize(out_num); + output_t->copy_to_cpu(out_data.data()); + + int maxPosition = + max_element(out_data.begin(), out_data.end()) - out_data.begin(); + std::cout << "result: " << std::endl; + std::cout << "\tclass: " << maxPosition << std::endl; + std::cout << std::fixed << std::setprecision(10) + << "\tscore: " << double(out_data[maxPosition]) << std::endl; +} + +} // namespace PaddleClas diff --git a/deploy/cpp_infer/src/config.cpp b/deploy/cpp_infer/src/config.cpp new file mode 100755 index 00000000..e8995520 --- /dev/null +++ b/deploy/cpp_infer/src/config.cpp @@ -0,0 +1,64 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace PaddleClas { + +std::vector Config::split(const std::string &str, + const std::string &delim) { + std::vector res; + if ("" == str) + return res; + char *strs = new char[str.length() + 1]; + std::strcpy(strs, str.c_str()); + + char *d = new char[delim.length() + 1]; + std::strcpy(d, delim.c_str()); + + char *p = std::strtok(strs, d); + while (p) { + std::string s = p; + res.push_back(s); + p = std::strtok(NULL, d); + } + + return res; +} + +std::map +Config::LoadConfig(const std::string &config_path) { + auto config = Utility::ReadDict(config_path); + + std::map dict; + for (int i = 0; i < config.size(); i++) { + // pass for empty line or comment + if (config[i].size() <= 1 || config[i][0] == '#') { + continue; + } + std::vector res = split(config[i], " "); + dict[res[0]] = res[1]; + } + return dict; +} + +void Config::PrintConfigInfo() { + std::cout << "=======Paddle Class inference config======" << std::endl; + for (auto iter = config_map_.begin(); iter != config_map_.end(); iter++) { + std::cout << iter->first << " : " << iter->second << std::endl; + } + std::cout << "=======End of Paddle Class inference config======" << std::endl; +} + +} // namespace PaddleClas \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp new file mode 100644 index 00000000..01397ef9 --- /dev/null +++ b/deploy/cpp_infer/src/main.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +using namespace std; +using namespace cv; +using namespace PaddleClas; + +int main(int argc, char **argv) { + if (argc < 3) { + std::cerr << "[ERROR] usage: " << argv[0] + << " configure_filepath image_path\n"; + exit(1); + } + + Config config(argv[1]); + + config.PrintConfigInfo(); + + std::string img_path(argv[2]); + + cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR); + cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB); + + Classifier classifier(config.cls_model_dir, config.use_gpu, config.gpu_id, + config.gpu_mem, config.cpu_math_library_num_threads, + config.use_mkldnn, config.use_zero_copy_run, + config.resize_short_size, config.crop_size); + + auto start = std::chrono::system_clock::now(); + classifier.Run(srcimg); + auto end = std::chrono::system_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); + std::cout << "Cost " + << double(duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den + << " s" << std::endl; + + return 0; +} diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp new file mode 100644 index 00000000..a7c41e49 --- /dev/null +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace PaddleClas { + +void Permute::Run(const cv::Mat *im, float *data) { + int rh = im->rows; + int rw = im->cols; + int rc = im->channels(); + for (int i = 0; i < rc; ++i) { + cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i); + } +} + +void Normalize::Run(cv::Mat *im, const std::vector &mean, + const std::vector &scale, const bool is_scale) { + double e = 1.0; + if (is_scale) { + e /= 255.0; + } + (*im).convertTo(*im, CV_32FC3, e); + for (int h = 0; h < im->rows; h++) { + for (int w = 0; w < im->cols; w++) { + im->at(h, w)[0] = + (im->at(h, w)[0] - mean[0]) * scale[0]; + im->at(h, w)[1] = + (im->at(h, w)[1] - mean[1]) * scale[1]; + im->at(h, w)[2] = + (im->at(h, w)[2] - mean[2]) * scale[2]; + } + } +} + +void CenterCropImg::Run(cv::Mat &img, const int crop_size) { + int resize_w = img.cols; + int resize_h = img.rows; + int w_start = int((resize_w - crop_size) / 2); + int h_start = int((resize_h - crop_size) / 2); + cv::Rect rect(w_start, h_start, crop_size, crop_size); + img = img(rect); +} + +void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, + float &ratio_h, float &ratio_w) { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + if (h < w) { + ratio = float(max_size_len) / float(h); + } else { + ratio = float(max_size_len) / float(w); + } + + int resize_h = int(float(h) * ratio); + int resize_w = int(float(w) * ratio); + + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); + + ratio_h = float(resize_h) / float(h); + ratio_w = float(resize_w) / float(w); +} + +} // namespace PaddleClas \ No newline at end of file diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp new file mode 100644 index 00000000..e6b572a5 --- /dev/null +++ b/deploy/cpp_infer/src/utility.cpp @@ -0,0 +1,39 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include + +namespace PaddleClas { + +std::vector Utility::ReadDict(const std::string &path) { + std::ifstream in(path); + std::string line; + std::vector m_vec; + if (in) { + while (getline(in, line)) { + m_vec.push_back(line); + } + } else { + std::cout << "no such label file: " << path << ", exit the program..." + << std::endl; + exit(1); + } + return m_vec; +} + +} // namespace PaddleClas \ No newline at end of file diff --git a/deploy/cpp_infer/tools/build.sh b/deploy/cpp_infer/tools/build.sh new file mode 100755 index 00000000..0de61f04 --- /dev/null +++ b/deploy/cpp_infer/tools/build.sh @@ -0,0 +1,20 @@ +OPENCV_DIR=/PaddleClas/PaddleOCR/opencv-3.4.7/opencv3/ +LIB_DIR=/PaddleClas/PaddleOCR/fluid_inference/ +CUDA_LIB_DIR=/usr/local/cuda/lib64 +CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/ + +BUILD_DIR=build +rm -rf ${BUILD_DIR} +mkdir ${BUILD_DIR} +cd ${BUILD_DIR} +cmake .. \ + -DPADDLE_LIB=${LIB_DIR} \ + -DWITH_MKL=ON \ + -DWITH_GPU=OFF \ + -DWITH_STATIC_LIB=OFF \ + -DUSE_TENSORRT=OFF \ + -DOPENCV_DIR=${OPENCV_DIR} \ + -DCUDNN_LIB=${CUDNN_LIB_DIR} \ + -DCUDA_LIB=${CUDA_LIB_DIR} \ + +make -j diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt new file mode 100755 index 00000000..e203b5d8 --- /dev/null +++ b/deploy/cpp_infer/tools/config.txt @@ -0,0 +1,12 @@ +# model load config +use_gpu 0 +gpu_id 0 +gpu_mem 4000 +cpu_math_library_num_threads 10 +use_mkldnn 1 +use_zero_copy_run 1 + +# cls config +cls_model_dir ./inference/ +resize_short_size 256 +crop_size 224 diff --git a/deploy/cpp_infer/tools/run.sh b/deploy/cpp_infer/tools/run.sh new file mode 100755 index 00000000..2c4b5434 --- /dev/null +++ b/deploy/cpp_infer/tools/run.sh @@ -0,0 +1,2 @@ + +./build/clas_system ./tools/config.txt ./ILSVRC2012_val_00000001.JPEG -- GitLab