From cfb3699c19ac1d05ea1d4f6a0c4e5ea181289ed5 Mon Sep 17 00:00:00 2001
From: shangliang Xu
Date: Tue, 7 Jun 2022 17:12:45 +0800
Subject: [PATCH] [TIPC] add serving cpp infer, test=document_fix (#6145)

---
 deploy/serving/cpp/build_server.sh            |  70 +++++
 deploy/serving/cpp/preprocess/ppyoloe_op.cpp  | 258 ++++++++++++++++
 deploy/serving/cpp/preprocess/ppyoloe_op.h    |  70 +++++
 .../preprocess/serving_client_conf.prototxt   |  20 ++
 deploy/serving/cpp/preprocess/yolov3_op.cpp   | 280 ++++++++++++++++++
 deploy/serving/cpp/preprocess/yolov3_op.h     |  69 +++++
 deploy/serving/cpp/serving_client.py          | 118 ++++++++
 deploy/serving/python/pipeline_http_client.py |  18 +-
 deploy/serving/python/web_service.py          |  11 +-
 deploy/third_engine/onnx/infer.py             |   8 +-
 10 files changed, 898 insertions(+), 24 deletions(-)
 create mode 100644 deploy/serving/cpp/build_server.sh
 create mode 100644 deploy/serving/cpp/preprocess/ppyoloe_op.cpp
 create mode 100644 deploy/serving/cpp/preprocess/ppyoloe_op.h
 create mode 100644 deploy/serving/cpp/preprocess/serving_client_conf.prototxt
 create mode 100644 deploy/serving/cpp/preprocess/yolov3_op.cpp
 create mode 100644 deploy/serving/cpp/preprocess/yolov3_op.h
 create mode 100644 deploy/serving/cpp/serving_client.py

diff --git a/deploy/serving/cpp/build_server.sh b/deploy/serving/cpp/build_server.sh
new file mode 100644
index 000000000..28ba46dd6
--- /dev/null
+++ b/deploy/serving/cpp/build_server.sh
@@ -0,0 +1,70 @@
+# Docker image to use:
+# registry.baidubce.com/paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82
+
+# Build the Serving server.
+# The client and app wheels can use the release versions directly;
+# the server must be rebuilt because custom OPs are added.
+
+apt-get update
+apt install -y libcurl4-openssl-dev libbz2-dev
+wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && tar xf centos_ssl.tar && rm -rf centos_ssl.tar && mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k && mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k && ln -sf /usr/lib/libcrypto.so.1.0.2k /usr/lib/libcrypto.so.10 && ln -sf /usr/lib/libssl.so.1.0.2k /usr/lib/libssl.so.10 && ln -sf /usr/lib/libcrypto.so.10 /usr/lib/libcrypto.so && ln -sf /usr/lib/libssl.so.10 /usr/lib/libssl.so
+
+# Install Go dependencies
+rm -rf /usr/local/go
+wget -qO- https://paddle-ci.cdn.bcebos.com/go1.17.2.linux-amd64.tar.gz | tar -xz -C /usr/local
+export GOROOT=/usr/local/go
+export GOPATH=/root/gopath
+export PATH=$PATH:$GOPATH/bin:$GOROOT/bin
+go env -w GO111MODULE=on
+go env -w GOPROXY=https://goproxy.cn,direct
+go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2
+go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2
+go install github.com/golang/protobuf/protoc-gen-go@v1.4.3
+go install google.golang.org/grpc@v1.33.0
+go env -w GO111MODULE=auto
+
+# Download the OpenCV library
+wget https://paddle-qa.bj.bcebos.com/PaddleServing/opencv3.tar.gz && tar -xvf opencv3.tar.gz && rm -rf opencv3.tar.gz
+export OPENCV_DIR=$PWD/opencv3
+
+# Clone Serving
+git clone https://github.com/PaddlePaddle/Serving.git -b develop --depth=1
+cd Serving
+export Serving_repo_path=$PWD
+git submodule update --init --recursive
+python -m pip install -r python/requirements.txt
+
+# Set environment variables
+export PYTHON_INCLUDE_DIR=$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")
+export PYTHON_LIBRARIES=$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))")
+export PYTHON_EXECUTABLE=`which python`
+
+export CUDA_PATH='/usr/local/cuda'
+export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
+export CUDA_CUDART_LIBRARY='/usr/local/cuda/lib64/'
+export TENSORRT_LIBRARY_PATH='/usr/local/TensorRT6-cuda10.1-cudnn7/targets/x86_64-linux-gnu/'
+
+# Copy the custom OP code into the Serving source tree
+\cp ../deploy/serving/cpp/preprocess/ppyoloe_op.* ${Serving_repo_path}/core/general-server/op
+\cp ../deploy/serving/cpp/preprocess/yolov3_op.* ${Serving_repo_path}/core/general-server/op
+
+# Build the server and export SERVING_BIN
+mkdir server-build-gpu-opencv && cd server-build-gpu-opencv
+cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \
+    -DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
+    -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
+    -DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
+    -DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
+    -DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
+    -DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \
+    -DOPENCV_DIR=${OPENCV_DIR} \
+    -DWITH_OPENCV=ON \
+    -DSERVER=ON \
+    -DWITH_GPU=ON ..
+make -j32
+
+python -m pip install python/dist/paddle*
+export SERVING_BIN=$PWD/core/general-server/serving
+cd ../../
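
A minimal launch sketch for the rebuilt server (assumptions: the model has already been exported for Serving into a serving_server/ directory, the --op flag names the custom OP class as in other Paddle Serving C++ examples, and the model directory and GPU id are illustrative; port 9997 matches the default hard-coded in deploy/serving/cpp/serving_client.py):

    # illustrative only: model directory and OP name are assumptions
    cd output_inference/ppyoloe_crn_s_300e_coco
    python -m paddle_serving_server.serve \
        --model serving_server \
        --op PPYOLOEOp \
        --port 9997 --gpu_ids 0
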
diff --git a/deploy/serving/cpp/preprocess/ppyoloe_op.cpp b/deploy/serving/cpp/preprocess/ppyoloe_op.cpp
new file mode 100644
index 000000000..cfa937d1c
--- /dev/null
+++ b/deploy/serving/cpp/preprocess/ppyoloe_op.cpp
@@ -0,0 +1,258 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "core/general-server/op/ppyoloe_op.h"
+#include "core/predictor/framework/infer.h"
+#include "core/predictor/framework/memory.h"
+#include "core/predictor/framework/resource.h"
+#include "core/util/include/timer.h"
+#include <algorithm>
+#include <iostream>
+#include <memory>
+#include <sstream>
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+using baidu::paddle_serving::Timer;
+using baidu::paddle_serving::predictor::InferManager;
+using baidu::paddle_serving::predictor::MempoolWrapper;
+using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
+using baidu::paddle_serving::predictor::general_model::Request;
+using baidu::paddle_serving::predictor::general_model::Response;
+using baidu::paddle_serving::predictor::general_model::Tensor;
+
+int PPYOLOEOp::inference() {
+  VLOG(2) << "Going to run inference";
+  const std::vector<std::string> pre_node_names = pre_names();
+  if (pre_node_names.size() != 1) {
+    LOG(ERROR) << "This op(" << op_name()
+               << ") can only have one predecessor op, but received "
+               << pre_node_names.size();
+    return -1;
+  }
+  const std::string pre_name = pre_node_names[0];
+
+  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
+  if (!input_blob) {
+    LOG(ERROR) << "input_blob is nullptr, error";
+    return -1;
+  }
+  uint64_t log_id = input_blob->GetLogId();
+  VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
+
+  GeneralBlob *output_blob = mutable_data<GeneralBlob>();
+  if (!output_blob) {
+    LOG(ERROR) << "output_blob is nullptr, error";
+    return -1;
+  }
+  output_blob->SetLogId(log_id);
+
+  if (!input_blob) {
+    LOG(ERROR) << "(logid=" << log_id
+               << ") Failed mutable depended argument, op: " << pre_name;
+    return -1;
+  }
+
+  const TensorVector *in = &input_blob->tensor_vector;
+  TensorVector *out = &output_blob->tensor_vector;
+
+  int batch_size = input_blob->_batch_size;
+  output_blob->_batch_size = batch_size;
+  VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
+
+  Timer timeline;
+  int64_t start = timeline.TimeStampUS();
+  timeline.Start();
+
+  // only string input is supported: a base64-encoded image
+  char *total_input_ptr = static_cast<char *>(in->at(0).data.data());
+  std::string base64str = total_input_ptr;
+
+  cv::Mat img = Base2Mat(base64str);
+  cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
+
+  // preprocess
+  std::vector<float> input(1 * 3 * im_shape_h * im_shape_w, 0.0f);
+  preprocess_det(img, input.data(), scale_factor_h, scale_factor_w, im_shape_h,
+                 im_shape_w, mean_, scale_, is_scale_);
+
+  // create real_in
+  TensorVector *real_in = new TensorVector();
+  if (!real_in) {
+    LOG(ERROR) << "real_in is nullptr, error";
+    return -1;
+  }
+
+  int in_num = 0;
+  size_t databuf_size = 0;
+  void *databuf_data = NULL;
+  char *databuf_char = NULL;
+
+  // image
+  in_num = 1 * 3 * im_shape_h * im_shape_w;
+  databuf_size = in_num * sizeof(float);
+
+  databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+  if (!databuf_data) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+    return -1;
+  }
+
+  memcpy(databuf_data, input.data(), databuf_size);
+  databuf_char = reinterpret_cast<char *>(databuf_data);
+  paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
+  paddle::PaddleTensor tensor_in;
+  tensor_in.name = "image";
+  tensor_in.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in.shape = {1, 3, im_shape_h, im_shape_w};
+  tensor_in.lod = in->at(0).lod;
+  tensor_in.data = paddleBuf;
+  real_in->push_back(tensor_in);
+
+  // scale_factor
+  std::vector<float> scale_factor{scale_factor_h, scale_factor_w};
+  databuf_size = 2 * sizeof(float);
+
+  databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+  if (!databuf_data) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+    return -1;
+  }
+
+  memcpy(databuf_data, scale_factor.data(), databuf_size);
+  databuf_char = reinterpret_cast<char *>(databuf_data);
+  paddle::PaddleBuf paddleBuf_2(databuf_char, databuf_size);
+  paddle::PaddleTensor tensor_in_2;
+  tensor_in_2.name = "scale_factor";
+  tensor_in_2.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in_2.shape = {1, 2};
+  tensor_in_2.lod = in->at(0).lod;
+  tensor_in_2.data = paddleBuf_2;
+  real_in->push_back(tensor_in_2);
+
+  if (InferManager::instance().infer(engine_name().c_str(), real_in, out,
+                                     batch_size)) {
+    LOG(ERROR) << "(logid=" << log_id
+               << ") Failed to do infer in fluid model: " << engine_name();
+    return -1;
+  }
+
+  int64_t end = timeline.TimeStampUS();
+  CopyBlobInfo(input_blob, output_blob);
+  AddBlobInfo(output_blob, start);
+  AddBlobInfo(output_blob, end);
+  return 0;
+}
+
+void PPYOLOEOp::preprocess_det(const cv::Mat &img, float *data,
+                               float &scale_factor_h, float &scale_factor_w,
+                               int im_shape_h, int im_shape_w,
+                               const std::vector<float> &mean,
+                               const std::vector<float> &scale,
+                               const bool is_scale) {
+  // scale_factor
+  scale_factor_h =
+      static_cast<float>(im_shape_h) / static_cast<float>(img.rows);
+  scale_factor_w =
+      static_cast<float>(im_shape_w) / static_cast<float>(img.cols);
+
+  // Resize
+  cv::Mat resize_img;
+  cv::resize(img, resize_img, cv::Size(im_shape_w, im_shape_h), 0, 0,
+             cv::INTER_CUBIC);
+
+  // Normalize
+  double e = 1.0;
+  if (is_scale) {
+    e /= 255.0;
+  }
+  cv::Mat img_fp;
+  resize_img.convertTo(img_fp, CV_32FC3, e);
+  for (int h = 0; h < im_shape_h; h++) {
+    for (int w = 0; w < im_shape_w; w++) {
+      img_fp.at<cv::Vec3f>(h, w)[0] =
+          (img_fp.at<cv::Vec3f>(h, w)[0] - mean[0]) / scale[0];
+      img_fp.at<cv::Vec3f>(h, w)[1] =
+          (img_fp.at<cv::Vec3f>(h, w)[1] - mean[1]) / scale[1];
+      img_fp.at<cv::Vec3f>(h, w)[2] =
+          (img_fp.at<cv::Vec3f>(h, w)[2] - mean[2]) / scale[2];
+    }
+  }
+
+  // Permute HWC -> CHW
+  int rh = img_fp.rows;
+  int rw = img_fp.cols;
+  int rc = img_fp.channels();
+  for (int i = 0; i < rc; ++i) {
+    cv::extractChannel(img_fp, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw),
+                       i);
+  }
+}
+
+cv::Mat PPYOLOEOp::Base2Mat(std::string &base64_data) {
+  cv::Mat img;
+  std::string s_mat;
+  s_mat = base64Decode(base64_data.data(), base64_data.size());
+  std::vector<char> base64_img(s_mat.begin(), s_mat.end());
+  img = cv::imdecode(base64_img, cv::IMREAD_COLOR);
+  return img;
+}
+
+std::string PPYOLOEOp::base64Decode(const char *Data, int DataByte) {
+  const char DecodeTable[] = {
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0,
+      62,  // '+'
+      0, 0, 0,
+      63,  // '/'
+      52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  // '0'-'9'
+      0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+      10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  // 'A'-'Z'
+      0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+      37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,  // 'a'-'z'
+  };
+
+  std::string strDecode;
+  int nValue;
+  int i = 0;
+  while (i < DataByte) {
+    if (*Data != '\r' && *Data != '\n') {
+      nValue = DecodeTable[*Data++] << 18;
+      nValue += DecodeTable[*Data++] << 12;
+      strDecode += (nValue & 0x00FF0000) >> 16;
+      if (*Data != '=') {
+        nValue += DecodeTable[*Data++] << 6;
+        strDecode += (nValue & 0x0000FF00) >> 8;
+        if (*Data != '=') {
+          nValue += DecodeTable[*Data++];
+          strDecode += nValue & 0x000000FF;
+        }
+      }
+      i += 4;
+    } else {  // skip CR/LF
+      Data++;
+      i++;
+    }
+  }
+  return strDecode;
+}
+
+DEFINE_OP(PPYOLOEOp);
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/deploy/serving/cpp/preprocess/ppyoloe_op.h b/deploy/serving/cpp/preprocess/ppyoloe_op.h
new file mode 100644
index 000000000..87d81d242
--- /dev/null
+++ b/deploy/serving/cpp/preprocess/ppyoloe_op.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "core/general-server/general_model_service.pb.h"
+#include "core/general-server/op/general_infer_helper.h"
+#include "paddle_inference_api.h"  // NOLINT
+#include <string>
+#include <vector>
+
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"
+#include <chrono>
+#include <iomanip>
+#include <iostream>
+#include <ostream>
+#include <vector>
+
+#include <cstring>
+#include <fstream>
+#include <numeric>
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+class PPYOLOEOp
+    : public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
+public:
+  typedef std::vector<paddle::PaddleTensor> TensorVector;
+
+  DECLARE_OP(PPYOLOEOp);
+
+  int inference();
+
+private:
+  // ppyoloe, picodet preprocess
+  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
+  std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
+  bool is_scale_ = true;
+  int im_shape_h = 640;
+  int im_shape_w = 640;
+  float scale_factor_h = 1.0f;
+  float scale_factor_w = 1.0f;
+  void preprocess_det(const cv::Mat &img, float *data, float &scale_factor_h,
+                      float &scale_factor_w, int im_shape_h, int im_shape_w,
+                      const std::vector<float> &mean,
+                      const std::vector<float> &scale, const bool is_scale);
+
+  // read pics
+  cv::Mat Base2Mat(std::string &base64_data);
+  std::string base64Decode(const char *Data, int DataByte);
+};
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/deploy/serving/cpp/preprocess/serving_client_conf.prototxt b/deploy/serving/cpp/preprocess/serving_client_conf.prototxt
new file mode 100644
index 000000000..fb069003a
--- /dev/null
+++ b/deploy/serving/cpp/preprocess/serving_client_conf.prototxt
@@ -0,0 +1,20 @@
+feed_var {
+  name: "input"
+  alias_name: "input"
+  is_lod_tensor: false
+  feed_type: 20
+  shape: 1
+}
+fetch_var {
+  name: "multiclass_nms3_0.tmp_0"
+  alias_name: "multiclass_nms3_0.tmp_0"
+  is_lod_tensor: true
+  fetch_type: 1
+  shape: -1
+}
+fetch_var {
+  name: "multiclass_nms3_0.tmp_2"
+  alias_name: "multiclass_nms3_0.tmp_2"
+  is_lod_tensor: false
+  fetch_type: 2
+}
\ No newline at end of file
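
In the client config above, feed_type 20 declares the single feed as a string tensor (the string feed code in recent Paddle Serving releases, an assumption here), so each request carries the whole image as one base64 string that the custom OPs decode server-side. A quick sketch for producing such a string by hand (image path illustrative):

    python -c "import base64; print(base64.b64encode(open('demo/000000014439.jpg', 'rb').read()).decode('utf8'))" > img.b64
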
diff --git a/deploy/serving/cpp/preprocess/yolov3_op.cpp b/deploy/serving/cpp/preprocess/yolov3_op.cpp
new file mode 100644
index 000000000..34ca313c7
--- /dev/null
+++ b/deploy/serving/cpp/preprocess/yolov3_op.cpp
@@ -0,0 +1,280 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "core/general-server/op/yolov3_op.h"
+#include "core/predictor/framework/infer.h"
+#include "core/predictor/framework/memory.h"
+#include "core/predictor/framework/resource.h"
+#include "core/util/include/timer.h"
+#include <algorithm>
+#include <iostream>
+#include <memory>
+#include <sstream>
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+using baidu::paddle_serving::Timer;
+using baidu::paddle_serving::predictor::InferManager;
+using baidu::paddle_serving::predictor::MempoolWrapper;
+using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
+using baidu::paddle_serving::predictor::general_model::Request;
+using baidu::paddle_serving::predictor::general_model::Response;
+using baidu::paddle_serving::predictor::general_model::Tensor;
+
+int YOLOv3Op::inference() {
+  VLOG(2) << "Going to run inference";
+  const std::vector<std::string> pre_node_names = pre_names();
+  if (pre_node_names.size() != 1) {
+    LOG(ERROR) << "This op(" << op_name()
+               << ") can only have one predecessor op, but received "
+               << pre_node_names.size();
+    return -1;
+  }
+  const std::string pre_name = pre_node_names[0];
+
+  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
+  if (!input_blob) {
+    LOG(ERROR) << "input_blob is nullptr, error";
+    return -1;
+  }
+  uint64_t log_id = input_blob->GetLogId();
+  VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
+
+  GeneralBlob *output_blob = mutable_data<GeneralBlob>();
+  if (!output_blob) {
+    LOG(ERROR) << "output_blob is nullptr, error";
+    return -1;
+  }
+  output_blob->SetLogId(log_id);
+
+  if (!input_blob) {
+    LOG(ERROR) << "(logid=" << log_id
+               << ") Failed mutable depended argument, op: " << pre_name;
+    return -1;
+  }
+
+  const TensorVector *in = &input_blob->tensor_vector;
+  TensorVector *out = &output_blob->tensor_vector;
+
+  int batch_size = input_blob->_batch_size;
+  output_blob->_batch_size = batch_size;
+  VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
+
+  Timer timeline;
+  int64_t start = timeline.TimeStampUS();
+  timeline.Start();
+
+  // only string input is supported: a base64-encoded image
+  char *total_input_ptr = static_cast<char *>(in->at(0).data.data());
+  std::string base64str = total_input_ptr;
+
+  cv::Mat img = Base2Mat(base64str);
+  cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
+
+  // preprocess
+  std::vector<float> input(1 * 3 * im_shape_h * im_shape_w, 0.0f);
+  preprocess_det(img, input.data(), scale_factor_h, scale_factor_w, im_shape_h,
+                 im_shape_w, mean_, scale_, is_scale_);
+
+  // create real_in
+  TensorVector *real_in = new TensorVector();
+  if (!real_in) {
+    LOG(ERROR) << "real_in is nullptr, error";
+    return -1;
+  }
+
+  int in_num = 0;
+  size_t databuf_size = 0;
+  void *databuf_data = NULL;
+  char *databuf_char = NULL;
+
+  // im_shape (YOLOv3-style models take three feeds: im_shape, image, scale_factor)
+  std::vector<float> im_shape{static_cast<float>(im_shape_h),
+                              static_cast<float>(im_shape_w)};
+  databuf_size = 2 * sizeof(float);
+
+  databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+  if (!databuf_data) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+    return -1;
+  }
+
+  memcpy(databuf_data, im_shape.data(), databuf_size);
+  databuf_char = reinterpret_cast<char *>(databuf_data);
+  paddle::PaddleBuf paddleBuf_0(databuf_char, databuf_size);
+  paddle::PaddleTensor tensor_in_0;
+  tensor_in_0.name = "im_shape";
+  tensor_in_0.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in_0.shape = {1, 2};
+  tensor_in_0.lod = in->at(0).lod;
+  tensor_in_0.data = paddleBuf_0;
+  real_in->push_back(tensor_in_0);
+
+  // image
+  in_num = 1 * 3 * im_shape_h * im_shape_w;
+  databuf_size = in_num * sizeof(float);
+
+  databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+  if (!databuf_data) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+    return -1;
+  }
+
+  memcpy(databuf_data, input.data(), databuf_size);
+  databuf_char = reinterpret_cast<char *>(databuf_data);
+  paddle::PaddleBuf paddleBuf_1(databuf_char, databuf_size);
+  paddle::PaddleTensor tensor_in_1;
+  tensor_in_1.name = "image";
+  tensor_in_1.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in_1.shape = {1, 3, im_shape_h, im_shape_w};
+  tensor_in_1.lod = in->at(0).lod;
+  tensor_in_1.data = paddleBuf_1;
+  real_in->push_back(tensor_in_1);
+
+  // scale_factor
+  std::vector<float> scale_factor{scale_factor_h, scale_factor_w};
+  databuf_size = 2 * sizeof(float);
+
+  databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+  if (!databuf_data) {
+    LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+    return -1;
+  }
+
+  memcpy(databuf_data, scale_factor.data(), databuf_size);
+  databuf_char = reinterpret_cast<char *>(databuf_data);
+  paddle::PaddleBuf paddleBuf_2(databuf_char, databuf_size);
+  paddle::PaddleTensor tensor_in_2;
+  tensor_in_2.name = "scale_factor";
+  tensor_in_2.dtype = paddle::PaddleDType::FLOAT32;
+  tensor_in_2.shape = {1, 2};
+  tensor_in_2.lod = in->at(0).lod;
+  tensor_in_2.data = paddleBuf_2;
+  real_in->push_back(tensor_in_2);
+
+  if (InferManager::instance().infer(engine_name().c_str(), real_in, out,
+                                     batch_size)) {
+    LOG(ERROR) << "(logid=" << log_id
+               << ") Failed to do infer in fluid model: " << engine_name();
+    return -1;
+  }
+
+  int64_t end = timeline.TimeStampUS();
+  CopyBlobInfo(input_blob, output_blob);
+  AddBlobInfo(output_blob, start);
+  AddBlobInfo(output_blob, end);
+  return 0;
+}
+
+void YOLOv3Op::preprocess_det(const cv::Mat &img, float *data,
+                              float &scale_factor_h, float &scale_factor_w,
+                              int im_shape_h, int im_shape_w,
+                              const std::vector<float> &mean,
+                              const std::vector<float> &scale,
+                              const bool is_scale) {
+  // scale_factor
+  scale_factor_h =
+      static_cast<float>(im_shape_h) / static_cast<float>(img.rows);
+  scale_factor_w =
+      static_cast<float>(im_shape_w) / static_cast<float>(img.cols);
+
+  // Resize
+  cv::Mat resize_img;
+  cv::resize(img, resize_img, cv::Size(im_shape_w, im_shape_h), 0, 0,
+             cv::INTER_CUBIC);
+
+  // Normalize
+  double e = 1.0;
+  if (is_scale) {
+    e /= 255.0;
+  }
+  cv::Mat img_fp;
+  resize_img.convertTo(img_fp, CV_32FC3, e);
+  for (int h = 0; h < im_shape_h; h++) {
+    for (int w = 0; w < im_shape_w; w++) {
+      img_fp.at<cv::Vec3f>(h, w)[0] =
+          (img_fp.at<cv::Vec3f>(h, w)[0] - mean[0]) / scale[0];
+      img_fp.at<cv::Vec3f>(h, w)[1] =
+          (img_fp.at<cv::Vec3f>(h, w)[1] - mean[1]) / scale[1];
+      img_fp.at<cv::Vec3f>(h, w)[2] =
+          (img_fp.at<cv::Vec3f>(h, w)[2] - mean[2]) / scale[2];
+    }
+  }
+
+  // Permute HWC -> CHW
+  int rh = img_fp.rows;
+  int rw = img_fp.cols;
+  int rc = img_fp.channels();
+  for (int i = 0; i < rc; ++i) {
+    cv::extractChannel(img_fp, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw),
+                       i);
+  }
+}
+
+cv::Mat YOLOv3Op::Base2Mat(std::string &base64_data) {
+  cv::Mat img;
+  std::string s_mat;
+  s_mat = base64Decode(base64_data.data(), base64_data.size());
+  std::vector<char> base64_img(s_mat.begin(), s_mat.end());
+  img = cv::imdecode(base64_img, cv::IMREAD_COLOR);
+  return img;
+}
+
+std::string YOLOv3Op::base64Decode(const char *Data, int DataByte) {
+  const char DecodeTable[] = {
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0,
+      62,  // '+'
+      0, 0, 0,
+      63,  // '/'
+      52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  // '0'-'9'
+      0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+      10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  // 'A'-'Z'
+      0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+      37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,  // 'a'-'z'
+  };
+
+  std::string strDecode;
+  int nValue;
+  int i = 0;
+  while (i < DataByte) {
+    if (*Data != '\r' && *Data != '\n') {
+      nValue = DecodeTable[*Data++] << 18;
+      nValue += DecodeTable[*Data++] << 12;
+      strDecode += (nValue & 0x00FF0000) >> 16;
+      if (*Data != '=') {
+        nValue += DecodeTable[*Data++] << 6;
+        strDecode += (nValue & 0x0000FF00) >> 8;
+        if (*Data != '=') {
+          nValue += DecodeTable[*Data++];
+          strDecode += nValue & 0x000000FF;
+        }
+      }
+      i += 4;
+    } else {  // skip CR/LF
+      Data++;
+      i++;
+    }
+  }
+  return strDecode;
+}
+
+DEFINE_OP(YOLOv3Op);
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/deploy/serving/cpp/preprocess/yolov3_op.h b/deploy/serving/cpp/preprocess/yolov3_op.h
new file mode 100644
index 000000000..6445ccb88
--- /dev/null
+++ b/deploy/serving/cpp/preprocess/yolov3_op.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "core/general-server/general_model_service.pb.h"
+#include "core/general-server/op/general_infer_helper.h"
+#include "paddle_inference_api.h"  // NOLINT
+#include <string>
+#include <vector>
+
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"
+#include <chrono>
+#include <iomanip>
+#include <iostream>
+#include <ostream>
+#include <vector>
+
+#include <cstring>
+#include <fstream>
+#include <numeric>
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+class YOLOv3Op
+    : public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
+public:
+  typedef std::vector<paddle::PaddleTensor> TensorVector;
+
+  DECLARE_OP(YOLOv3Op);
+
+  int inference();
+
+private:
+  // yolov3, ppyolo preprocess
+  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
+  std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
+  bool is_scale_ = true;
+  int im_shape_h = 608;
+  int im_shape_w = 608;
+  float scale_factor_h = 1.0f;
+  float scale_factor_w = 1.0f;
+  void preprocess_det(const cv::Mat &img, float *data, float &scale_factor_h,
+                      float &scale_factor_w, int im_shape_h, int im_shape_w,
+                      const std::vector<float> &mean,
+                      const std::vector<float> &scale, const bool is_scale);
+
+  // read pics
+  cv::Mat Base2Mat(std::string &base64_data);
+  std::string base64Decode(const char *Data, int DataByte);
+};
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/deploy/serving/cpp/serving_client.py b/deploy/serving/cpp/serving_client.py
new file mode 100644
index 000000000..3f2c5b656
--- /dev/null
+++ b/deploy/serving/cpp/serving_client.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import glob
+import os
+from paddle_serving_client import Client
+from paddle_serving_client.proto import general_model_config_pb2 as m_config
+import google.protobuf.text_format
+
+import argparse
+
+parser = argparse.ArgumentParser(description="args for paddleserving")
+parser.add_argument(
+    "--serving_client", type=str, help="the directory of serving_client")
+parser.add_argument("--image_dir", type=str)
+parser.add_argument("--image_file", type=str)
+parser.add_argument(
+    "--threshold", type=float, default=0.5, help="Threshold of score.")
+args = parser.parse_args()
+
+
+def get_test_images(infer_dir, infer_img):
+    """
+    Get image path list in TEST mode
+    """
+    assert infer_img is not None or infer_dir is not None, \
+        "--image_file or --image_dir should be set"
+    assert infer_img is None or os.path.isfile(infer_img), \
+        "{} is not a file".format(infer_img)
+    assert infer_dir is None or os.path.isdir(infer_dir), \
+        "{} is not a directory".format(infer_dir)
+
+    # infer_img has a higher priority
+    if infer_img and os.path.isfile(infer_img):
+        return [infer_img]
+
+    images = set()
+    infer_dir = os.path.abspath(infer_dir)
+    assert os.path.isdir(infer_dir), \
+        "infer_dir {} is not a directory".format(infer_dir)
+    exts = ['jpg', 'jpeg', 'png', 'bmp']
+    exts += [ext.upper() for ext in exts]
+    for ext in exts:
+        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
+    images = list(images)
+
+    assert len(images) > 0, "no image found in {}".format(infer_dir)
+    print("Found {} inference images in total.".format(len(images)))
+
+    return images
+
+
+def postprocess(fetch_dict, draw_threshold=0.5):
+    # each bbox row is [class_id, score, x0, y0, x1, y1]
+    bboxes = fetch_dict["multiclass_nms3_0.tmp_0"]
+    bboxes_num = fetch_dict["multiclass_nms3_0.tmp_2"]
+    for bbox in bboxes:
+        if bbox[0] > -1 and bbox[1] > draw_threshold:
+            print(f"{int(bbox[0])} {bbox[1]} "
+                  f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
+    return fetch_dict
+
+
+def get_model_vars(client_config_dir):
+    # read the original serving_client_conf.prototxt
+    client_config_file = os.path.join(client_config_dir,
+                                      "serving_client_conf.prototxt")
+    with open(client_config_file, 'r') as f:
+        model_var = google.protobuf.text_format.Merge(
+            str(f.read()), m_config.GeneralModelConfig())
+    # replace feed_var so requests match the custom C++ preprocess OP,
+    # which expects a single base64-encoded string input
+    del model_var.feed_var[:]
+    feed_var = m_config.FeedVar()
+    feed_var.name = "input"
+    feed_var.alias_name = "input"
+    feed_var.is_lod_tensor = False
+    feed_var.feed_type = 20
+    feed_var.shape.extend([1])
+    model_var.feed_var.extend([feed_var])
+    with open(
+            os.path.join(client_config_dir,
+                         "serving_client_conf_cpp.prototxt"), "w") as f:
+        f.write(str(model_var))
+    # get feed_vars/fetch_vars
+    feed_vars = [var.name for var in model_var.feed_var]
+    fetch_vars = [var.name for var in model_var.fetch_var]
+    return feed_vars, fetch_vars
+
+
+if __name__ == '__main__':
+    url = "127.0.0.1:9997"
+    logid = 10000
+    img_list = get_test_images(args.image_dir, args.image_file)
+    feed_vars, fetch_vars = get_model_vars(args.serving_client)
+
+    client = Client()
+    client.load_client_config(
+        os.path.join(args.serving_client, "serving_client_conf_cpp.prototxt"))
+    client.connect([url])
+
+    for img_file in img_list:
+        with open(img_file, 'rb') as file:
+            image_data = file.read()
+        image = base64.b64encode(image_data).decode('utf8')
+        fetch_dict = client.predict(
+            feed={feed_vars[0]: image}, fetch=fetch_vars)
+        result = postprocess(fetch_dict, args.threshold)
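
A typical invocation of the client above, assuming a server built by build_server.sh is listening on 127.0.0.1:9997 (model and image paths are illustrative):

    python deploy/serving/cpp/serving_client.py \
        --serving_client output_inference/ppyoloe_crn_s_300e_coco/serving_client \
        --image_file demo/000000014439.jpg \
        --threshold 0.5
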
diff --git a/deploy/serving/python/pipeline_http_client.py b/deploy/serving/python/pipeline_http_client.py
index 33eb70dcf..9f5dbb8b7 100644
--- a/deploy/serving/python/pipeline_http_client.py
+++ b/deploy/serving/python/pipeline_http_client.py
@@ -56,19 +56,6 @@ def get_test_images(infer_dir, infer_img):
     return images
 
 
-def cv2_to_base64(image):
-    """cv2_to_base64
-
-    Convert an numpy array to a base64 object.
-
-    Args:
-        image: Input array.
-
-    Returns: Base64 output of the input.
-    """
-    return base64.b64encode(image).decode('utf8')
-
-
 if __name__ == "__main__":
     url = "http://127.0.0.1:18093/ppdet/prediction"
     logid = 10000
@@ -76,9 +63,10 @@ if __name__ == "__main__":
 
     for img_file in img_list:
         with open(img_file, 'rb') as file:
-            image_data1 = file.read()
+            image_data = file.read()
 
-        image = cv2_to_base64(image_data1)
+        # base64 encode
+        image = base64.b64encode(image_data).decode('utf8')
         data = {"key": ["image_0"], "value": [image], "logid": logid}
 
         # send requests
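
The same request can also be issued without the script; a sketch with curl against the pipeline service above (the endpoint and "image_0" key come from the code shown; base64 -w0 assumes GNU coreutils):

    img_b64=$(base64 -w0 demo/000000014439.jpg)
    curl -s http://127.0.0.1:18093/ppdet/prediction \
        -d "{\"key\": [\"image_0\"], \"value\": [\"${img_b64}\"], \"logid\": 10000}"
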
diff --git a/deploy/serving/python/web_service.py b/deploy/serving/python/web_service.py
index 12075bc87..cc49cab3e 100644
--- a/deploy/serving/python/web_service.py
+++ b/deploy/serving/python/web_service.py
@@ -207,7 +207,7 @@ class DetectorOp(Op):
         result = []
         for line in bbox:
             if line[0] > -1 and line[1] > draw_threshold:
-                result.append(f"{label_list[int(line[0])]} {line[1]} "
+                result.append(f"{int(line[0])} {line[1]} "
                               f"{line[2]} {line[3]} {line[4]} {line[5]}")
         return result
 
@@ -222,10 +222,11 @@ def get_model_vars(model_dir, service_config):
     # rewrite model_config
     service_config['op']['ppdet']['local_service_conf'][
         'model_config'] = serving_server_dir
-    f = open(
-        os.path.join(serving_server_dir, "serving_server_conf.prototxt"), 'r')
-    model_var = google.protobuf.text_format.Merge(
-        str(f.read()), m_config.GeneralModelConfig())
+    serving_server_conf = os.path.join(serving_server_dir,
+                                       "serving_server_conf.prototxt")
+    with open(serving_server_conf, 'r') as f:
+        model_var = google.protobuf.text_format.Merge(
+            str(f.read()), m_config.GeneralModelConfig())
     feed_vars = [var.name for var in model_var.feed_var]
     fetch_vars = [var.name for var in model_var.fetch_var]
     return feed_vars, fetch_vars
diff --git a/deploy/third_engine/onnx/infer.py b/deploy/third_engine/onnx/infer.py
index 322768759..b728b3365 100644
--- a/deploy/third_engine/onnx/infer.py
+++ b/deploy/third_engine/onnx/infer.py
@@ -45,7 +45,7 @@ SUPPORT_MODELS = {
 }
 
 parser = argparse.ArgumentParser(description=__doc__)
-parser.add_argument("-c", "--config", type=str, help="infer_cfg.yml")
+parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml")
 parser.add_argument(
     '--onnx_file', type=str, default="model.onnx", help="onnx model file path")
 parser.add_argument("--image_dir", type=str)
@@ -86,7 +86,7 @@ def get_test_images(infer_dir, infer_img):
 class PredictConfig(object):
     """set config of preprocess, postprocess and visualize
     Args:
-        model_dir (str): root path of infer_cfg.yml
+        infer_config (str): path of infer_cfg.yml
     """
 
     def __init__(self, infer_config):
@@ -145,7 +145,7 @@ def predict_image(infer_config, predictor, img_list):
         bboxes = np.array(outputs[0])
         for bbox in bboxes:
             if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
-                print(f"{infer_config.label_list[int(bbox[0])]} {bbox[1]} "
+                print(f"{int(bbox[0])} {bbox[1]} "
                       f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
 
 
@@ -156,6 +156,6 @@ if __name__ == '__main__':
     # load predictor
     predictor = InferenceSession(FLAGS.onnx_file)
     # load infer config
-    infer_config = PredictConfig(FLAGS.config)
+    infer_config = PredictConfig(FLAGS.infer_cfg)
     predict_image(infer_config, predictor, img_list)
-- 
GitLab
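
With the -c/--config flag renamed to --infer_cfg in the last file, existing ONNX inference commands change accordingly; an illustrative run (the model and image paths are assumptions):

    python deploy/third_engine/onnx/infer.py \
        --infer_cfg output_inference/ppyoloe_crn_s_300e_coco/infer_cfg.yml \
        --onnx_file ppyoloe_crn_s_300e_coco.onnx \
        --image_file demo/000000014439.jpg
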