diff --git a/cmake/external/prometheus.cmake b/cmake/external/prometheus.cmake
index 9abefcc1fb8301ba310d78a94050042114269369..2c7cc2be66a925d4a63b484b3b46cd4af2440916 100644
--- a/cmake/external/prometheus.cmake
+++ b/cmake/external/prometheus.cmake
@@ -32,13 +32,18 @@ ExternalProject_Add(
     CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
                -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
+               -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
+               -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
+               -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
+               -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
+               -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
+               -DCMAKE_INSTALL_PREFIX:PATH=${PROMETHEUS_INSTALL_DIR}
+               -DCMAKE_INSTALL_LIBDIR=${PROMETHEUS_INSTALL_DIR}/lib
+               -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
                -DBUILD_SHARED_LIBS=OFF
                -DENABLE_PUSH=OFF
                -DENABLE_COMPRESSION=OFF
                -DENABLE_TESTING=OFF
-               -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-               -DCMAKE_INSTALL_PREFIX:PATH=${PROMETHEUS_INSTALL_DIR}
-               -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
     BUILD_BYPRODUCTS ${PROMETHEUS_LIBRARIES}
 )
diff --git a/core/general-server/op/CMakeLists.txt b/core/general-server/op/CMakeLists.txt
index 1b631fd7a749e2e6f2f4c5b347ada6a1509842cd..0d0b5e0a1e7fe494c0a9fa438151c4656530633f 100644
--- a/core/general-server/op/CMakeLists.txt
+++ b/core/general-server/op/CMakeLists.txt
@@ -1,9 +1,13 @@
 FILE(GLOB op_srcs ${CMAKE_CURRENT_LIST_DIR}/*.cpp ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/quant.cpp)
 if(WITH_OPENCV)
     FILE(GLOB ocrtools_srcs ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/ocrtools/*.cpp)
+    FILE(GLOB ppshitu_tools_srcs ${CMAKE_CURRENT_LIST_DIR}/../../predictor/tools/pp_shitu_tools/*.cpp)
+    LIST(APPEND op_srcs ${ppshitu_tools_srcs})
     LIST(APPEND op_srcs ${ocrtools_srcs})
 else()
-    set (EXCLUDE_DIR "general_detection_op.cpp")
+    # Sources to exclude when OpenCV is unavailable; the loop below drops each match.
+    set (EXCLUDE_OPS "general_detection_op.cpp"
+                     "general_picodet_op.cpp"
+                     "general_feature_extract_op.cpp")
 foreach (TMP_PATH ${op_srcs})
-    string (FIND ${TMP_PATH} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND)
+    set (EXCLUDE_DIR_FOUND -1)
+    foreach (EXCLUDE_OP ${EXCLUDE_OPS})
+        string (FIND ${TMP_PATH} ${EXCLUDE_OP} MATCH_POS)
+        if (NOT ${MATCH_POS} EQUAL -1)
+            set (EXCLUDE_DIR_FOUND ${MATCH_POS})
+        endif()
+    endforeach()
     if (NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
diff --git a/core/general-server/op/general_feature_extract_op.cpp b/core/general-server/op/general_feature_extract_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..988cb6477be6305ad0e3593ac4f8c1f305035462
--- /dev/null
+++ b/core/general-server/op/general_feature_extract_op.cpp
@@ -0,0 +1,108 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "core/general-server/op/general_feature_extract_op.h" +#include +#include +#include +#include +#include "core/predictor/framework/infer.h" +#include "core/predictor/framework/memory.h" +#include "core/predictor/framework/resource.h" +#include "core/util/include/timer.h" + +namespace baidu { +namespace paddle_serving { +namespace serving { + +using baidu::paddle_serving::Timer; +using baidu::paddle_serving::predictor::MempoolWrapper; +using baidu::paddle_serving::predictor::general_model::Tensor; +using baidu::paddle_serving::predictor::general_model::Response; +using baidu::paddle_serving::predictor::general_model::Request; +using baidu::paddle_serving::predictor::InferManager; +using baidu::paddle_serving::predictor::PaddleGeneralModelConfig; + +int GeneralFeatureExtractOp::inference() { + VLOG(2) << "Going to run inference"; + const std::vector pre_node_names = pre_names(); + if (pre_node_names.size() != 1) { + LOG(ERROR) << "This op(" << op_name() + << ") can only have one predecessor op, but received " + << pre_node_names.size(); + return -1; + } + const std::string pre_name = pre_node_names[0]; + + const GeneralBlob *input_blob = get_depend_argument(pre_name); + if (!input_blob) { + LOG(ERROR) << "input_blob is nullptr,error"; + return -1; + } + + uint64_t log_id = input_blob->GetLogId(); + VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name; + + GeneralBlob *output_blob = mutable_data(); + if (!output_blob) { + LOG(ERROR) << "output_blob is nullptr,error"; + return -1; + } + output_blob->SetLogId(log_id); + + if (!input_blob) { + LOG(ERROR) << "(logid=" << log_id + << ") Failed mutable depended argument, op:" << pre_name; + return -1; + } + + const TensorVector *in = &input_blob->tensor_vector; + TensorVector *out = &output_blob->tensor_vector; + + int batch_size = input_blob->_batch_size; + output_blob->_batch_size = batch_size; + VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size; + + Timer timeline; + int64_t start = timeline.TimeStampUS(); + timeline.Start(); + + paddle::PaddleTensor boxes = in->at(1); + TensorVector* real_in = new TensorVector(); + if (!real_in) { + LOG(ERROR) << "real_in is nullptr, error"; + return -1; + } + + real_in->push_back(in->at(0)); + if (InferManager::instance().infer( + engine_name().c_str(), real_in, out, batch_size)) { + LOG(ERROR) << "(logid=" << log_id + << ") Failed do infer in fluid model: " << engine_name().c_str(); + return -1; + } + out->push_back(boxes); + + int64_t end = timeline.TimeStampUS(); + CopyBlobInfo(input_blob, output_blob); + AddBlobInfo(output_blob, start); + AddBlobInfo(output_blob, end); + return 0; +} + +DEFINE_OP(GeneralFeatureExtractOp); + +} // namespace serving +} // namespace paddle_serving +} // namespace baidu diff --git a/core/general-server/op/general_feature_extract_op.h b/core/general-server/op/general_feature_extract_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d2919d4a415447b3bc1e1cc5b4f6199d94c1873f --- /dev/null +++ b/core/general-server/op/general_feature_extract_op.h @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/core/general-server/op/general_feature_extract_op.h b/core/general-server/op/general_feature_extract_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..d2919d4a415447b3bc1e1cc5b4f6199d94c1873f
--- /dev/null
+++ b/core/general-server/op/general_feature_extract_op.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "core/general-server/general_model_service.pb.h"
+#include "core/general-server/op/general_infer_helper.h"
+#include "paddle_inference_api.h"  // NOLINT
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+class GeneralFeatureExtractOp
+    : public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
+ public:
+  typedef std::vector<paddle::PaddleTensor> TensorVector;
+
+  DECLARE_OP(GeneralFeatureExtractOp);
+
+  int inference();
+};
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/core/general-server/op/general_picodet_op.cpp b/core/general-server/op/general_picodet_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ee9c1b3723b647b9caa7caa94de0a613b101edd7
--- /dev/null
+++ b/core/general-server/op/general_picodet_op.cpp
@@ -0,0 +1,371 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "core/general-server/op/general_picodet_op.h" + +#include +#include +#include +#include + +#include "core/predictor/framework/infer.h" +#include "core/predictor/framework/memory.h" +#include "core/predictor/framework/resource.h" +#include "core/util/include/timer.h" + +namespace baidu { +namespace paddle_serving { +namespace serving { + +using baidu::paddle_serving::Timer; +using baidu::paddle_serving::predictor::MempoolWrapper; +using baidu::paddle_serving::predictor::general_model::Tensor; +using baidu::paddle_serving::predictor::general_model::Response; +using baidu::paddle_serving::predictor::general_model::Request; +using baidu::paddle_serving::predictor::InferManager; +using baidu::paddle_serving::predictor::PaddleGeneralModelConfig; + +int GeneralPicodetOp::inference() { + VLOG(2) << "Going to run inference"; + const std::vector pre_node_names = pre_names(); + if (pre_node_names.size() != 1) { + LOG(ERROR) << "This op(" << op_name() + << ") can only have one predecessor op, but received " + << pre_node_names.size(); + return -1; + } + const std::string pre_name = pre_node_names[0]; + + const GeneralBlob* input_blob = get_depend_argument(pre_name); + if (!input_blob) { + LOG(ERROR) << "input_blob is nullptr,error"; + return -1; + } + uint64_t log_id = input_blob->GetLogId(); + VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name; + + GeneralBlob* output_blob = mutable_data(); + if (!output_blob) { + LOG(ERROR) << "output_blob is nullptr,error"; + return -1; + } + output_blob->SetLogId(log_id); + + if (!input_blob) { + LOG(ERROR) << "(logid=" << log_id + << ") Failed mutable depended argument, op:" << pre_name; + return -1; + } + + const TensorVector* in = &input_blob->tensor_vector; + TensorVector* out = &output_blob->tensor_vector; + int batch_size = input_blob->_batch_size; + VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size; + output_blob->_batch_size = batch_size; + + //get image shape + float* data = (float*)in->at(0).data.data(); + int height = data[0]; + int width = data[1]; + VLOG(2) << "image width: " << width; + VLOG(2) << "image height: " << height; + + ///////////////////det preprocess begin///////////////////////// + //show raw image + unsigned char* img_data = static_cast(in->at(1).data.data()); + cv::Mat origin(height, width, CV_8UC3, img_data); + // cv::imwrite("/workspace/origin_image.jpg", origin); + + cv::Mat origin_img = origin.clone(); + cv::cvtColor(origin, origin, cv::COLOR_BGR2RGB); + InitInfo_Run(&origin, &imgblob); + Resize_Run(&origin, &imgblob); + NormalizeImage_Run(&origin, &imgblob); + Permute_Run(&origin, &imgblob); + ///////////////////det preprocess end///////////////////////// + + Timer timeline; + int64_t start = timeline.TimeStampUS(); + timeline.Start(); + + //generate real_in + TensorVector* real_in = new TensorVector(); + if (!real_in) { + LOG(ERROR) << "real_in is nullptr, error"; + return -1; + } + + //generate im_shape + int in_num = 2; + size_t databuf_size = in_num * sizeof(float); + void *databuf_data = MempoolWrapper::instance().malloc(databuf_size); + if (!databuf_data) { + LOG(ERROR) << "Malloc failed, size: " << databuf_size; + return -1; + } + float* databuf_float = reinterpret_cast(databuf_data); + *databuf_float = imgblob.im_shape_[0]; + databuf_float++; + *databuf_float = imgblob.im_shape_[1]; + + char* databuf_char = reinterpret_cast(databuf_data); + paddle::PaddleBuf paddleBuf(databuf_char, databuf_size); + paddle::PaddleTensor tensor_in; + tensor_in.name = "im_shape"; + 
+  tensor_in.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in.shape = {1, 2};
+  tensor_in.lod = in->at(0).lod;
+  tensor_in.data = paddleBuf;
+  real_in->push_back(tensor_in);
+
+  // generate scale_factor
+  databuf_size = in_num * sizeof(float);
+  databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+  if (!databuf_data) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+    return -1;
+  }
+  databuf_float = reinterpret_cast<float *>(databuf_data);
+  *databuf_float = imgblob.scale_factor_[0];
+  databuf_float++;
+  *databuf_float = imgblob.scale_factor_[1];
+  databuf_char = reinterpret_cast<char *>(databuf_data);
+  paddle::PaddleBuf paddleBuf_2(databuf_char, databuf_size);
+  paddle::PaddleTensor tensor_in_2;
+  tensor_in_2.name = "scale_factor";
+  tensor_in_2.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in_2.shape = {1, 2};
+  tensor_in_2.lod = in->at(0).lod;
+  tensor_in_2.data = paddleBuf_2;
+  real_in->push_back(tensor_in_2);
+
+  // generate image
+  in_num = imgblob.im_data_.size();
+  databuf_size = in_num * sizeof(float);
+  databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+  if (!databuf_data) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+    return -1;
+  }
+  memcpy(databuf_data, imgblob.im_data_.data(), databuf_size);
+  databuf_char = reinterpret_cast<char *>(databuf_data);
+  paddle::PaddleBuf paddleBuf_3(databuf_char, databuf_size);
+  paddle::PaddleTensor tensor_in_3;
+  tensor_in_3.name = "image";
+  tensor_in_3.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in_3.shape = {1, 3, imgblob.in_net_shape_[0], imgblob.in_net_shape_[1]};
+  tensor_in_3.lod = in->at(0).lod;
+  tensor_in_3.data = paddleBuf_3;
+  real_in->push_back(tensor_in_3);
+
+  if (InferManager::instance().infer(
+          engine_name().c_str(), real_in, out, batch_size)) {
+    LOG(ERROR) << "(logid=" << log_id
+               << ") Failed do infer in fluid model: " << engine_name().c_str();
+    return -1;
+  }
+
+  /////////////////// det postprocess begin /////////////////////////
+  // copy the raw detector output into output_data_
+  std::vector<float> output_data_;
+  int infer_outnum = out->size();
+  paddle::PaddleTensor element = out->at(0);
+  std::vector<int> element_shape = element.shape;
+  // get data len
+  int total_num = 1;
+  for (auto value_shape : element_shape) {
+    total_num *= value_shape;
+  }
+  output_data_.resize(total_num);
+
+  float *data_out = (float *)element.data.data();
+  for (int j = 0; j < total_num; j++) {
+    output_data_[j] = data_out[j];
+  }
+
+  // det postprocess
+  // 1) get detect result, keeping at most max_detect_results boxes
+  if (output_data_.size() > max_detect_results * 6) {
+    output_data_.resize(max_detect_results * 6);
+  }
+  std::vector<ObjectResult> result;
+  int detect_num = output_data_.size() / 6;
+  for (int m = 0; m < detect_num; m++) {
+    // Class id
+    int class_id = static_cast<int>(round(output_data_[0 + m * 6]));
+    // Confidence score
+    float score = output_data_[1 + m * 6];
+    // Box coordinates
+    int xmin = (output_data_[2 + m * 6]);
+    int ymin = (output_data_[3 + m * 6]);
+    int xmax = (output_data_[4 + m * 6]);
+    int ymax = (output_data_[5 + m * 6]);
+
+    ObjectResult result_item;
+    result_item.rect = {xmin, ymin, xmax, ymax};
+    result_item.class_id = class_id;
+    result_item.confidence = score;
+    result.push_back(result_item);
+  }
+
+  // 2) add the whole image as an extra candidate box
+  ObjectResult result_whole_img = {{0, 0, width - 1, height - 1}, 0, 1.0};
+  result.push_back(result_whole_img);
+
+  // 3) crop image and do preprocess; concatenate the data
+  cv::Mat srcimg;
+  cv::cvtColor(origin_img, srcimg, cv::COLOR_BGR2RGB);
+  std::vector<float> all_data;
+  for (size_t j = 0; j < result.size(); ++j) {
+    int w = result[j].rect[2] - result[j].rect[0];
+    int h = result[j].rect[3] - result[j].rect[1];
+    cv::Rect rect(result[j].rect[0], result[j].rect[1], w, h);
+    cv::Mat crop_img = srcimg(rect);
+    cv::Mat resize_img;
+    resize_op_.Run(crop_img, resize_img, resize_short_, resize_size_);
+    normalize_op_.Run(&resize_img, mean_, std_, scale_);
+    std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
+    permute_op_.Run(&resize_img, input.data());
+    for (size_t m = 0; m < input.size(); m++) {
+      all_data.push_back(input[m]);
+    }
+  }
+  /////////////////// det postprocess end /////////////////////////
+
+  // generate new tensors
+  // "x": the preprocessed crops
+  int out_num = all_data.size();
+  int databuf_size_out = out_num * sizeof(float);
+  void *databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out);
+  if (!databuf_data_out) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
+    return -1;
+  }
+  memcpy(databuf_data_out, all_data.data(), databuf_size_out);
+  char *databuf_char_out = reinterpret_cast<char *>(databuf_data_out);
+  paddle::PaddleBuf paddleBuf_out(databuf_char_out, databuf_size_out);
+  paddle::PaddleTensor tensor_out;
+
+  tensor_out.name = "x";
+  tensor_out.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_out.shape = {static_cast<int>(result.size()), 3, 224, 224};
+  tensor_out.data = paddleBuf_out;
+  tensor_out.lod = in->at(0).lod;
+  out->push_back(tensor_out);
+
+  // "boxes": detector boxes plus the whole-image box appended at the end
+  int box_size_out = result.size() * 6 * sizeof(float);
+  void *box_data_out = MempoolWrapper::instance().malloc(box_size_out);
+  if (!box_data_out) {
+    LOG(ERROR) << "Malloc failed, size: " << box_size_out;
+    return -1;
+  }
+  memcpy(box_data_out, out->at(0).data.data(),
+         box_size_out - 6 * sizeof(float));
+  float *box_float_out = reinterpret_cast<float *>(box_data_out);
+  box_float_out += (result.size() - 1) * 6;
+  box_float_out[0] = 0.0;
+  box_float_out[1] = 1.0;
+  box_float_out[2] = 0.0;
+  box_float_out[3] = 0.0;
+  box_float_out[4] = width - 1;
+  box_float_out[5] = height - 1;
+  char *box_char_out = reinterpret_cast<char *>(box_data_out);
+  paddle::PaddleBuf paddleBuf_out_2(box_char_out, box_size_out);
+  paddle::PaddleTensor tensor_out_2;
+
+  tensor_out_2.name = "boxes";
+  tensor_out_2.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_out_2.shape = {static_cast<int>(result.size()), 6};
+  tensor_out_2.data = paddleBuf_out_2;
+  tensor_out_2.lod = in->at(0).lod;
+  out->push_back(tensor_out_2);
+  out->erase(out->begin(), out->begin() + infer_outnum);
+
+  int64_t end = timeline.TimeStampUS();
+  CopyBlobInfo(input_blob, output_blob);
+  AddBlobInfo(output_blob, start);
+  AddBlobInfo(output_blob, end);
+  return 0;
+}
+
+DEFINE_OP(GeneralPicodetOp);
+
+void GeneralPicodetOp::Postprocess(const std::vector<cv::Mat> mats,
+                                   std::vector<ObjectResult> *result,
+                                   std::vector<int> bbox_num,
+                                   bool is_rbox,
+                                   std::vector<float> output_data_,
+                                   std::vector<int> out_bbox_num_data_) {
+  result->clear();
+  int start_idx = 0;
+  for (size_t im_id = 0; im_id < mats.size(); im_id++) {
+    cv::Mat raw_mat = mats[im_id];
+    int rh = 1;
+    int rw = 1;
+    for (int j = start_idx; j < start_idx + bbox_num[im_id]; j++) {
+      if (is_rbox) {
+        // Rotated box: class id + score + 8 coordinates
+        // Class id
+        int class_id = static_cast<int>(round(output_data_[0 + j * 10]));
+        // Confidence score
+        float score = output_data_[1 + j * 10];
+        int x1 = (output_data_[2 + j * 10] * rw);
+        int y1 = (output_data_[3 + j * 10] * rh);
+        int x2 = (output_data_[4 + j * 10] * rw);
+        int y2 = (output_data_[5 + j * 10] * rh);
+        int x3 = (output_data_[6 + j * 10] * rw);
+        int y3 = (output_data_[7 + j * 10] * rh);
+        int x4 = (output_data_[8 + j * 10] * rw);
+        int y4 = (output_data_[9 + j * 10] * rh);
+
+        ObjectResult result_item;
+        result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4};
+        result_item.class_id = class_id;
+        result_item.confidence = score;
+        result->push_back(result_item);
+      } else {
+        // Class id
+        int class_id = static_cast<int>(round(output_data_[0 + j * 6]));
+        // Confidence score
+        float score = output_data_[1 + j * 6];
+
+        // xmin, ymin, xmax, ymax
+        int xmin = (output_data_[2 + j * 6] * rw);
+        int ymin = (output_data_[3 + j * 6] * rh);
+        int xmax = (output_data_[4 + j * 6] * rw);
+        int ymax = (output_data_[5 + j * 6] * rh);
+
+        ObjectResult result_item;
+        result_item.rect = {xmin, ymin, xmax, ymax};
+        result_item.class_id = class_id;
+        result_item.confidence = score;
+        result->push_back(result_item);
+      }
+    }
+    start_idx += bbox_num[im_id];
+  }
+}
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/core/general-server/op/general_picodet_op.h b/core/general-server/op/general_picodet_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f5bcca8e645b9b3d12b3ddf512d1d6ebcbf0469
--- /dev/null
+++ b/core/general-server/op/general_picodet_op.h
@@ -0,0 +1,183 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "core/general-server/general_model_service.pb.h"
+#include "core/general-server/op/general_infer_helper.h"
+
+#include "core/predictor/tools/pp_shitu_tools/preprocess_op.h"
+#include "paddle_inference_api.h"  // NOLINT
+
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+struct ObjectResult {
+  // Rectangle coordinates of detected object: left, top, right, bottom
+  std::vector<int> rect;
+  // Class id of detected object
+  int class_id;
+  // Confidence of detected object
+  float confidence;
+};
+
+class ImageBlob {
+ public:
+  // image width and height
+  std::vector<float> im_shape_;
+  // Buffer for image data after preprocessing
+  std::vector<float> im_data_;
+  // in-net data shape (after padding)
+  std::vector<float> in_net_shape_;
+  // Scale factor from origin image size to net input size
+  std::vector<float> scale_factor_;
+};
+
+class GeneralPicodetOp
+    : public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
+ public:
+  typedef std::vector<paddle::PaddleTensor> TensorVector;
+
+  DECLARE_OP(GeneralPicodetOp);
+
+  int inference();  // op entry point
+
+ private:
+  // rec preprocess
+  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
+  std::vector<float> std_ = {0.229f, 0.224f, 0.225f};
+  float scale_ = 0.00392157;  // 1 / 255
+  int resize_size_ = 224;
+  int resize_short_ = 224;
+
+  Feature::ResizeImg resize_op_;
+  Feature::Normalize normalize_op_;
+  Feature::Permute permute_op_;
+
+ private:
+  // det preprocess
+  ImageBlob imgblob;
+
+  // resize
+  int interp_ = 2;
+  bool keep_ratio_ = false;
+  std::vector<int> target_size_ = {640, 640};
+  std::vector<float> in_net_shape_;
+
+  void InitInfo_Run(cv::Mat *im, ImageBlob *data) {
+    data->im_shape_ = {static_cast<float>(im->rows),
+                       static_cast<float>(im->cols)};
+    data->scale_factor_ = {1., 1.};
+    data->in_net_shape_ = {static_cast<float>(im->rows),
+                           static_cast<float>(im->cols)};
+  }
+
+  void NormalizeImage_Run(cv::Mat *im, ImageBlob *data) {
+    double e = 1.0;
+    e /= 255.0;
+    (*im).convertTo(*im, CV_32FC3, e);
+    for (int h = 0; h < im->rows; h++) {
+      for (int w = 0; w < im->cols; w++) {
+        im->at<cv::Vec3f>(h, w)[0] =
+            (im->at<cv::Vec3f>(h, w)[0] - mean_[0]) / std_[0];
+        im->at<cv::Vec3f>(h, w)[1] =
+            (im->at<cv::Vec3f>(h, w)[1] - mean_[1]) / std_[1];
+        im->at<cv::Vec3f>(h, w)[2] =
+            (im->at<cv::Vec3f>(h, w)[2] - mean_[2]) / std_[2];
+      }
+    }
+    VLOG(2) << "enter NormalizeImage_Run run";
+    VLOG(2) << data->im_shape_[0];
+    VLOG(2) << data->im_shape_[1];
+    VLOG(2) << data->scale_factor_[0];
+    VLOG(2) << data->scale_factor_[1];
+  }
+
+  void Resize_Run(cv::Mat *im, ImageBlob *data) {
+    auto resize_scale = GenerateScale(*im);
+    data->im_shape_ = {static_cast<float>(im->cols * resize_scale.first),
+                       static_cast<float>(im->rows * resize_scale.second)};
+    data->in_net_shape_ = {static_cast<float>(im->cols * resize_scale.first),
+                           static_cast<float>(im->rows * resize_scale.second)};
+    cv::resize(*im, *im, cv::Size(), resize_scale.first, resize_scale.second,
+               interp_);
+    data->im_shape_ = {
+        static_cast<float>(im->rows), static_cast<float>(im->cols),
+    };
+    data->scale_factor_ = {
+        resize_scale.second, resize_scale.first,
+    };
+    VLOG(2) << "enter resize run";
+    VLOG(2) << data->im_shape_[0];
+    VLOG(2) << data->im_shape_[1];
+    VLOG(2) << data->scale_factor_[0];
+    VLOG(2) << data->scale_factor_[1];
+  }
+
+  std::pair<float, float> GenerateScale(const cv::Mat &im) {
+    std::pair<float, float> resize_scale;
+    int origin_w = im.cols;
+    int origin_h = im.rows;
+
+    if (keep_ratio_) {
+      int im_size_max = std::max(origin_w, origin_h);
+      int im_size_min = std::min(origin_w, origin_h);
+      int target_size_max =
+          *std::max_element(target_size_.begin(), target_size_.end());
+      int target_size_min =
+          *std::min_element(target_size_.begin(), target_size_.end());
+      double scale_min =
+          static_cast<double>(target_size_min) / static_cast<double>(im_size_min);
+      double scale_max =
+          static_cast<double>(target_size_max) / static_cast<double>(im_size_max);
+      double scale_ratio = std::min(scale_min, scale_max);
+      resize_scale = {static_cast<float>(scale_ratio),
+                      static_cast<float>(scale_ratio)};
+    } else {
+      resize_scale.first =
+          static_cast<float>(target_size_[1]) / static_cast<float>(origin_w);
+      resize_scale.second =
+          static_cast<float>(target_size_[0]) / static_cast<float>(origin_h);
+    }
+    return resize_scale;
+  }
+
+  void Permute_Run(cv::Mat *im, ImageBlob *data) {
+    int rh = im->rows;
+    int rw = im->cols;
+    int rc = im->channels();
+    (data->im_data_).resize(rc * rh * rw);
+    float *base = (data->im_data_).data();
+    for (int i = 0; i < rc; ++i) {
+      cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
+    }
+  }
+
+  // det postprocess
+  int max_detect_results = 5;
+  void Postprocess(const std::vector<cv::Mat> mats,
+                   std::vector<ObjectResult> *result,
+                   std::vector<int> bbox_num,
+                   bool is_rbox,
+                   std::vector<float> output_data_,
+                   std::vector<int> out_bbox_num_data_);
+};  // GeneralPicodetOp
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
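To make the resize bookkeeping above concrete, here is a worked example under the header's defaults (`keep_ratio_ = false`, `target_size_ = {640, 640}`); the input size is illustrative only:

```
// GenerateScale on a 1280x720 (width x height) image:
//   resize_scale.first  = target_size_[1] / width  = 640 / 1280.0f = 0.5f     (x scale)
//   resize_scale.second = target_size_[0] / height = 640 /  720.0f = 0.8889f  (y scale)
// Resize_Run then produces a 640x640 image and records
//   im_shape_     = {640, 640}      // rows, cols after resize
//   scale_factor_ = {0.8889f, 0.5f} // {y_scale, x_scale}, fed to the model as "scale_factor"
```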
+ +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "preprocess_op.h" + +namespace Feature { + void Permute::Run(const cv::Mat *im, float *data) { + int rh = im->rows; + int rw = im->cols; + int rc = im->channels(); + for (int i = 0; i < rc; ++i) { + cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i); + } + } + + void Normalize::Run(cv::Mat *im, const std::vector &mean, + const std::vector &std, float scale) { + (*im).convertTo(*im, CV_32FC3, scale); + for (int h = 0; h < im->rows; h++) { + for (int w = 0; w < im->cols; w++) { + im->at(h, w)[0] = + (im->at(h, w)[0] - mean[0]) / std[0]; + im->at(h, w)[1] = + (im->at(h, w)[1] - mean[1]) / std[1]; + im->at(h, w)[2] = + (im->at(h, w)[2] - mean[2]) / std[2]; + } + } + } + + void CenterCropImg::Run(cv::Mat &img, const int crop_size) { + int resize_w = img.cols; + int resize_h = img.rows; + int w_start = int((resize_w - crop_size) / 2); + int h_start = int((resize_h - crop_size) / 2); + cv::Rect rect(w_start, h_start, crop_size, crop_size); + img = img(rect); + } + + void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + int resize_short_size, int size) { + int resize_h = 0; + int resize_w = 0; + if (size > 0) { + resize_h = size; + resize_w = size; + } else { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + if (h < w) { + ratio = float(resize_short_size) / float(h); + } else { + ratio = float(resize_short_size) / float(w); + } + resize_h = round(float(h) * ratio); + resize_w = round(float(w) * ratio); + } + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); + } + +} // namespace Feature diff --git a/core/predictor/tools/pp_shitu_tools/preprocess_op.h b/core/predictor/tools/pp_shitu_tools/preprocess_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5cdc60cae7f89abc41356cce44b6dbcbe7d2f063 --- /dev/null +++ b/core/predictor/tools/pp_shitu_tools/preprocess_op.h @@ -0,0 +1,55 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/core/predictor/tools/pp_shitu_tools/preprocess_op.h b/core/predictor/tools/pp_shitu_tools/preprocess_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..5cdc60cae7f89abc41356cce44b6dbcbe7d2f063
--- /dev/null
+++ b/core/predictor/tools/pp_shitu_tools/preprocess_op.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"
+#include <chrono>
+#include <iomanip>
+#include <iostream>
+#include <ostream>
+#include <vector>
+
+#include <cstring>
+#include <fstream>
+#include <numeric>
+
+namespace Feature {
+
+class Normalize {
+ public:
+  virtual void Run(cv::Mat *im, const std::vector<float> &mean,
+                   const std::vector<float> &std, float scale);
+};
+
+// RGB -> CHW
+class Permute {
+ public:
+  virtual void Run(const cv::Mat *im, float *data);
+};
+
+class CenterCropImg {
+ public:
+  virtual void Run(cv::Mat &im, const int crop_size = 224);
+};
+
+class ResizeImg {
+ public:
+  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
+                   int size = 0);
+};
+
+}  // namespace Feature
diff --git a/doc/images/wechat_group_1.jpeg b/doc/images/wechat_group_1.jpeg
index 22f934ebe204527782bad9b98be56c17c9ec4417..18978f92fd6912e508d51ad2d02faa4b9c34bfde 100644
Binary files a/doc/images/wechat_group_1.jpeg and b/doc/images/wechat_group_1.jpeg differ
diff --git a/examples/C++/PaddleClas/pp_shitu/README.md b/examples/C++/PaddleClas/pp_shitu/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..19363c5b3fbdc534f0483dac77b189e4071e5d0e
--- /dev/null
+++ b/examples/C++/PaddleClas/pp_shitu/README.md
@@ -0,0 +1,24 @@
+# PP-Shitu
+
+## Get Model
+```
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/serving/pp_shitu.tar.gz
+tar -xzvf pp_shitu.tar.gz
+```
+
+## Get Test Images and Index
+```
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar
+tar -xvf drink_dataset_v1.0.tar
+```
+
+## RPC Service
+### Start Service
+```
+sh run_cpp_serving.sh
+```
+
+### Client Prediction
+```
+python3 test_cpp_serving_pipeline.py ./drink_dataset_v1.0/test_images/nongfu_spring.jpeg
+```
diff --git a/examples/C++/PaddleClas/pp_shitu/README_CN.md b/examples/C++/PaddleClas/pp_shitu/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..fd051be9448726287b580474bd10afbaafb10053
--- /dev/null
+++ b/examples/C++/PaddleClas/pp_shitu/README_CN.md
@@ -0,0 +1,24 @@
+# PP-Shitu
+
+## 获取模型
+```
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/serving/pp_shitu.tar.gz
+tar -xzvf pp_shitu.tar.gz
+```
+
+## 获取测试图像和index
+```
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar
+tar -xvf drink_dataset_v1.0.tar
+```
+
+## RPC 服务
+### 启动服务端
+```
+sh run_cpp_serving.sh
+```
+
+### 客户端预测
+```
+python3 test_cpp_serving_pipeline.py ./drink_dataset_v1.0/test_images/nongfu_spring.jpeg
+```
diff --git a/examples/C++/PaddleClas/pp_shitu/run_cpp_serving.sh b/examples/C++/PaddleClas/pp_shitu/run_cpp_serving.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0ad4c82d018bb8c39fbd48a68fce0c6a31544e40
--- /dev/null
+++ b/examples/C++/PaddleClas/pp_shitu/run_cpp_serving.sh
@@ -0,0 +1,4 @@
+rm -rf log
+rm -rf workdir*
+export GLOG_v=3
+nohup python3 -m paddle_serving_server.serve --model picodet_PPLCNet_x2_5_mainbody_lite_v2.0_serving general_PPLCNet_x2_5_lite_v2.0_serving --op GeneralPicodetOp GeneralFeatureExtractOp --port 9400 &
diff --git a/examples/C++/PaddleClas/pp_shitu/test_cpp_serving_pipeline.py b/examples/C++/PaddleClas/pp_shitu/test_cpp_serving_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5d560d5ea56cc1c60f074a7e69d7e9202aff895
--- /dev/null
+++ b/examples/C++/PaddleClas/pp_shitu/test_cpp_serving_pipeline.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import numpy as np
+
+from paddle_serving_client import Client
+from paddle_serving_app.reader import *
+import cv2
+import faiss
+import os
+import pickle
+
+rec_nms_threshold = 0.05
+rec_score_thres = 0.5
+feature_normalize = True
+return_k = 1
+index_dir = "./drink_dataset_v1.0/index"
+
+
+def init_index(index_dir):
+    assert os.path.exists(os.path.join(
+        index_dir, "vector.index")), "vector.index not found ..."
+    assert os.path.exists(os.path.join(
+        index_dir, "id_map.pkl")), "id_map.pkl not found ..."
+
+    searcher = faiss.read_index(os.path.join(index_dir, "vector.index"))
+
+    with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
+        id_map = pickle.load(fd)
+    return searcher, id_map
+
+
+# NMS over the recognized boxes
+def nms_to_rec_results(results, thresh=0.1):
+    filtered_results = []
+
+    x1 = np.array([r["bbox"][0] for r in results]).astype("float32")
+    y1 = np.array([r["bbox"][1] for r in results]).astype("float32")
+    x2 = np.array([r["bbox"][2] for r in results]).astype("float32")
+    y2 = np.array([r["bbox"][3] for r in results]).astype("float32")
+    scores = np.array([r["rec_scores"] for r in results])
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+    while order.size > 0:
+        i = order[0]
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+        inds = np.where(ovr <= thresh)[0]
+        order = order[inds + 1]
+        filtered_results.append(results[i])
+    return filtered_results
+
+
+def postprocess(fetch_dict, feature_normalize, det_boxes, searcher, id_map,
+                return_k, rec_score_thres, rec_nms_threshold):
+    batch_features = fetch_dict["features"]
+
+    # do feature norm
+    if feature_normalize:
+        feas_norm = np.sqrt(
+            np.sum(np.square(batch_features), axis=1, keepdims=True))
+        batch_features = np.divide(batch_features, feas_norm)
+
+    scores, docs = searcher.search(batch_features, return_k)
+
+    results = []
+    for i in range(scores.shape[0]):
+        pred = {}
+        if scores[i][0] >= rec_score_thres:
+            pred["bbox"] = [int(x) for x in det_boxes[i, 2:]]
+            pred["rec_docs"] = id_map[docs[i][0]].split()[1]
+            pred["rec_scores"] = scores[i][0]
+            results.append(pred)
+
+    # do nms
+    results = nms_to_rec_results(results, rec_nms_threshold)
+    return results
+
+
+# do client
+if __name__ == "__main__":
+    client = Client()
+    client.load_client_config([
+        "picodet_PPLCNet_x2_5_mainbody_lite_v2.0_client",
+        "general_PPLCNet_x2_5_lite_v2.0_client"
+    ])
+    client.connect(['127.0.0.1:9400'])
+
+    im = cv2.imread(sys.argv[1])
+    im_shape = np.array(im.shape[:2]).reshape(-1)
+    fetch_map = client.predict(
+        feed={
+            "image": im,
+            "im_shape": im_shape
+        },
+        fetch=["features", "boxes"],
+        batch=False)
+
+    # add retrieval procedure
+    det_boxes = fetch_map["boxes"]
+    searcher, id_map = init_index(index_dir)
+    results = postprocess(fetch_map, feature_normalize, det_boxes, searcher,
+                          id_map, return_k, rec_score_thres,
+                          rec_nms_threshold)
+    print(results)