From ac821bbe4d7d93b4ff435449e1fb4c79cc79f828 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 12 Jun 2022 00:32:41 +0800 Subject: [PATCH] add cpp serving infer(except PPShiTu) --- deploy/paddleserving/build_server.sh | 74 +++ .../preprocess/general_clas_op.cpp | 206 +++++++++ .../preprocess/general_clas_op.h | 70 +++ .../preprocess/preprocess_op.cpp | 149 ++++++ .../paddleserving/preprocess/preprocess_op.h | 81 ++++ ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + ...ormal_normal_serving_cpp_linux_gpu_cpu.txt | 14 + test_tipc/prepare.sh | 12 +- test_tipc/test_serving_infer.sh | 435 ++++++++++-------- 22 files changed, 1048 insertions(+), 189 deletions(-) create mode 100644 deploy/paddleserving/build_server.sh create mode 100644 deploy/paddleserving/preprocess/general_clas_op.cpp create mode 100644 deploy/paddleserving/preprocess/general_clas_op.h create mode 100644 deploy/paddleserving/preprocess/preprocess_op.cpp create mode 100644 deploy/paddleserving/preprocess/preprocess_op.h create mode 100644 test_tipc/config/MobileNetV3/MobileNetV3_large_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPHGNet/PPHGNet_small_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPHGNet/PPHGNet_tiny_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x0_25_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x0_35_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x0_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x0_75_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x1_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x2_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNet/PPLCNet_x2_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/PPLCNetV2/PPLCNetV2_base_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/ResNet/ResNet50_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt create mode 100644 test_tipc/config/SwinTransformer/SwinTransformer_tiny_patch4_window7_224_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt diff --git a/deploy/paddleserving/build_server.sh b/deploy/paddleserving/build_server.sh new file mode 100644 index 
00000000..290656ed --- /dev/null +++ b/deploy/paddleserving/build_server.sh @@ -0,0 +1,74 @@ +# Docker image used: +#registry.baidubce.com/paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 + +# Build the Serving server: + +# The client and app packages can use the release versions directly. + +# The server must be rebuilt because custom OPs have been added. + +# By default, ${PWD} at build time is PaddleClas/deploy/paddleserving/ + +python_name=${1:-'python'} + +apt-get update +apt install -y libcurl4-openssl-dev libbz2-dev +wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && tar xf centos_ssl.tar && rm -rf centos_ssl.tar && mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k && mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k && ln -sf /usr/lib/libcrypto.so.1.0.2k /usr/lib/libcrypto.so.10 && ln -sf /usr/lib/libssl.so.1.0.2k /usr/lib/libssl.so.10 && ln -sf /usr/lib/libcrypto.so.10 /usr/lib/libcrypto.so && ln -sf /usr/lib/libssl.so.10 /usr/lib/libssl.so + +# Install Go dependencies +rm -rf /usr/local/go +wget -qO- https://paddle-ci.cdn.bcebos.com/go1.17.2.linux-amd64.tar.gz | tar -xz -C /usr/local +export GOROOT=/usr/local/go +export GOPATH=/root/gopath +export PATH=$PATH:$GOPATH/bin:$GOROOT/bin +go env -w GO111MODULE=on +go env -w GOPROXY=https://goproxy.cn,direct +go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2 +go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2 +go install github.com/golang/protobuf/protoc-gen-go@v1.4.3 +go install google.golang.org/grpc@v1.33.0 +go env -w GO111MODULE=auto + +# Download the OpenCV library +wget https://paddle-qa.bj.bcebos.com/PaddleServing/opencv3.tar.gz && tar -xvf opencv3.tar.gz && rm -rf opencv3.tar.gz +export OPENCV_DIR=$PWD/opencv3 + +# Clone Serving +git clone https://github.com/PaddlePaddle/Serving.git -b develop --depth=1 +cd Serving # PaddleClas/deploy/paddleserving/Serving +export Serving_repo_path=$PWD +git submodule update --init --recursive +${python_name} -m pip install -r python/requirements.txt + +# set env +export PYTHON_INCLUDE_DIR=$(${python_name} -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") +export PYTHON_LIBRARIES=$(${python_name} -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))") +export PYTHON_EXECUTABLE=`which ${python_name}` + +export CUDA_PATH='/usr/local/cuda' +export CUDNN_LIBRARY='/usr/local/cuda/lib64/' +export CUDA_CUDART_LIBRARY='/usr/local/cuda/lib64/' +export TENSORRT_LIBRARY_PATH='/usr/local/TensorRT6-cuda10.1-cudnn7/targets/x86_64-linux-gnu/' + +# Copy the custom OP code +\cp ../preprocess/general_clas_op.* ${Serving_repo_path}/core/general-server/op +\cp ../preprocess/preprocess_op.* ${Serving_repo_path}/core/predictor/tools/pp_shitu_tools + +# Build the server and export SERVING_BIN +mkdir server-build-gpu-opencv && cd server-build-gpu-opencv +cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \ -DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ -DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \ -DCUDNN_LIBRARY=${CUDNN_LIBRARY} \ -DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \ -DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \ -DOPENCV_DIR=${OPENCV_DIR} \ -DWITH_OPENCV=ON \ -DSERVER=ON \ -DWITH_GPU=ON ..
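+# Note: the CUDA, cuDNN and TensorRT paths exported above assume the +# registry.baidubce.com/paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 image +# named at the top of this script; adjust them before running cmake if your +# environment differs. The make step below builds the serving binary, which is +# then exported as SERVING_BIN so this locally built server (with the custom OPs) +# is used instead of a pip-installed release.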
+make -j32 + +${python_name} -m pip install python/dist/paddle* +export SERVING_BIN=$PWD/core/general-server/serving +cd ../../ \ No newline at end of file diff --git a/deploy/paddleserving/preprocess/general_clas_op.cpp b/deploy/paddleserving/preprocess/general_clas_op.cpp new file mode 100644 index 00000000..e0ab48fa --- /dev/null +++ b/deploy/paddleserving/preprocess/general_clas_op.cpp @@ -0,0 +1,206 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "core/general-server/op/general_clas_op.h" +#include "core/predictor/framework/infer.h" +#include "core/predictor/framework/memory.h" +#include "core/predictor/framework/resource.h" +#include "core/util/include/timer.h" +#include <algorithm> +#include <iostream> +#include <memory> +#include <sstream> + +namespace baidu { +namespace paddle_serving { +namespace serving { + +using baidu::paddle_serving::Timer; +using baidu::paddle_serving::predictor::MempoolWrapper; +using baidu::paddle_serving::predictor::general_model::Tensor; +using baidu::paddle_serving::predictor::general_model::Response; +using baidu::paddle_serving::predictor::general_model::Request; +using baidu::paddle_serving::predictor::InferManager; +using baidu::paddle_serving::predictor::PaddleGeneralModelConfig; + +int GeneralClasOp::inference() { + VLOG(2) << "Going to run inference"; + const std::vector<std::string> pre_node_names = pre_names(); + if (pre_node_names.size() != 1) { + LOG(ERROR) << "This op(" << op_name() + << ") can only have one predecessor op, but received " + << pre_node_names.size(); + return -1; + } + const std::string pre_name = pre_node_names[0]; + + const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name); + if (!input_blob) { + LOG(ERROR) << "input_blob is nullptr, error"; + return -1; + } + uint64_t log_id = input_blob->GetLogId(); + VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name; + + GeneralBlob *output_blob = mutable_data<GeneralBlob>(); + if (!output_blob) { + LOG(ERROR) << "output_blob is nullptr, error"; + return -1; + } + output_blob->SetLogId(log_id); + + if (!input_blob) { + LOG(ERROR) << "(logid=" << log_id + << ") Failed mutable depended argument, op:" << pre_name; + return -1; + } + + const TensorVector *in = &input_blob->tensor_vector; + TensorVector *out = &output_blob->tensor_vector; + + int batch_size = input_blob->_batch_size; + output_blob->_batch_size = batch_size; + VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size; + + Timer timeline; + int64_t start = timeline.TimeStampUS(); + timeline.Start(); + + // only string-type (base64) input is supported + char *total_input_ptr = static_cast<char *>(in->at(0).data.data()); + std::string base64str = total_input_ptr; + + cv::Mat img = Base2Mat(base64str); + + // BGR -> RGB + cv::cvtColor(img, img, cv::COLOR_BGR2RGB); + + // Resize + cv::Mat resize_img; + resize_op_.Run(img, resize_img, resize_short_size_); + + // CenterCrop + crop_op_.Run(resize_img, crop_size_); + + // Normalize + normalize_op_.Run(&resize_img, mean_, scale_, is_scale_); + +
+ // Permute + std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); + permute_op_.Run(&resize_img, input.data()); + float maxValue = *max_element(input.begin(), input.end()); + float minValue = *min_element(input.begin(), input.end()); + + TensorVector *real_in = new TensorVector(); + if (!real_in) { + LOG(ERROR) << "real_in is nullptr, error"; + return -1; + } + + std::vector<int> input_shape; + int in_num = 0; + void *databuf_data = NULL; + char *databuf_char = NULL; + size_t databuf_size = 0; + + input_shape = {1, 3, resize_img.rows, resize_img.cols}; + in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, + std::multiplies<int>()); + + databuf_size = in_num * sizeof(float); + databuf_data = MempoolWrapper::instance().malloc(databuf_size); + if (!databuf_data) { + LOG(ERROR) << "Malloc failed, size: " << databuf_size; + return -1; + } + + memcpy(databuf_data, input.data(), databuf_size); + databuf_char = reinterpret_cast<char *>(databuf_data); + paddle::PaddleBuf paddleBuf(databuf_char, databuf_size); + paddle::PaddleTensor tensor_in; + tensor_in.name = in->at(0).name; + tensor_in.dtype = paddle::PaddleDType::FLOAT32; + tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols}; + tensor_in.lod = in->at(0).lod; + tensor_in.data = paddleBuf; + real_in->push_back(tensor_in); + + if (InferManager::instance().infer(engine_name().c_str(), real_in, out, + batch_size)) { + LOG(ERROR) << "(logid=" << log_id + << ") Failed do infer in fluid model: " << engine_name().c_str(); + return -1; + } + + int64_t end = timeline.TimeStampUS(); + CopyBlobInfo(input_blob, output_blob); + AddBlobInfo(output_blob, start); + AddBlobInfo(output_blob, end); + return 0; +} + +cv::Mat GeneralClasOp::Base2Mat(std::string &base64_data) { + cv::Mat img; + std::string s_mat; + s_mat = base64Decode(base64_data.data(), base64_data.size()); + std::vector<char> base64_img(s_mat.begin(), s_mat.end()); + img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR + return img; +} + +std::string GeneralClasOp::base64Decode(const char *Data, int DataByte) { + const char DecodeTable[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' + }; + + std::string strDecode; + int nValue; + int i = 0; + while (i < DataByte) { + if (*Data != '\r' && *Data != '\n') { + nValue = DecodeTable[*Data++] << 18; + nValue += DecodeTable[*Data++] << 12; + strDecode += (nValue & 0x00FF0000) >> 16; + if (*Data != '=') { + nValue += DecodeTable[*Data++] << 6; + strDecode += (nValue & 0x0000FF00) >> 8; + if (*Data != '=') { + nValue += DecodeTable[*Data++]; + strDecode += nValue & 0x000000FF; + } + } + i += 4; + } else // skip carriage return / line feed + { + Data++; + i++; + } + } + return strDecode; +} +DEFINE_OP(GeneralClasOp); +} // namespace serving +} // namespace paddle_serving +} // namespace baidu diff --git a/deploy/paddleserving/preprocess/general_clas_op.h b/deploy/paddleserving/preprocess/general_clas_op.h new file mode 100644 index 00000000..69b7a8e0 --- /dev/null +++ b/deploy/paddleserving/preprocess/general_clas_op.h @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "core/general-server/general_model_service.pb.h" +#include "core/general-server/op/general_infer_helper.h" +#include "core/predictor/tools/pp_shitu_tools/preprocess_op.h" +#include "paddle_inference_api.h" // NOLINT +#include <string> +#include <vector> + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include <chrono> +#include <iomanip> +#include <iostream> +#include <ostream> +#include <vector> + +#include <cstring> +#include <fstream> +#include <numeric> + +namespace baidu { +namespace paddle_serving { +namespace serving { + +class GeneralClasOp + : public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> { +public: + typedef std::vector<paddle::PaddleTensor> TensorVector; + + DECLARE_OP(GeneralClasOp); + + int inference(); + +private: + // clas preprocess + std::vector<float> mean_ = {0.485f, 0.456f, 0.406f}; + std::vector<float> scale_ = {0.229f, 0.224f, 0.225f}; + bool is_scale_ = true; + + int resize_short_size_ = 256; + int crop_size_ = 224; + + PaddleClas::ResizeImg resize_op_; + PaddleClas::Normalize normalize_op_; + PaddleClas::Permute permute_op_; + PaddleClas::CenterCropImg crop_op_; + + // read pics + cv::Mat Base2Mat(std::string &base64_data); + std::string base64Decode(const char *Data, int DataByte); +}; + +} // namespace serving +} // namespace paddle_serving +} // namespace baidu diff --git a/deploy/paddleserving/preprocess/preprocess_op.cpp b/deploy/paddleserving/preprocess/preprocess_op.cpp new file mode 100644 index 00000000..9c79342c --- /dev/null +++ b/deploy/paddleserving/preprocess/preprocess_op.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
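+ +// preprocess_op.cpp implements the preprocessing helpers declared in +// preprocess_op.h: short-side resize, center crop, per-channel normalization +// and HWC -> CHW permutation. The Feature namespace is used by the PP-ShiTu +// recognition pipeline and the PaddleClas namespace by the classification +// pipeline; the two differ mainly in how Normalize expresses its scale factor.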
+ +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include <chrono> +#include <iomanip> +#include <iostream> +#include <ostream> +#include <vector> + +#include <cstring> +#include <fstream> +#include <math.h> +#include <numeric> + +#include "preprocess_op.h" + +namespace Feature { + +void Permute::Run(const cv::Mat *im, float *data) { + int rh = im->rows; + int rw = im->cols; + int rc = im->channels(); + for (int i = 0; i < rc; ++i) { + cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i); + } +} + +void Normalize::Run(cv::Mat *im, const std::vector<float> &mean, + const std::vector<float> &std, float scale) { + (*im).convertTo(*im, CV_32FC3, scale); + for (int h = 0; h < im->rows; h++) { + for (int w = 0; w < im->cols; w++) { + im->at<cv::Vec3f>(h, w)[0] = + (im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0]; + im->at<cv::Vec3f>(h, w)[1] = + (im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1]; + im->at<cv::Vec3f>(h, w)[2] = + (im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2]; + } + } +} + +void CenterCropImg::Run(cv::Mat &img, const int crop_size) { + int resize_w = img.cols; + int resize_h = img.rows; + int w_start = int((resize_w - crop_size) / 2); + int h_start = int((resize_h - crop_size) / 2); + cv::Rect rect(w_start, h_start, crop_size, crop_size); + img = img(rect); +} + +void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + int resize_short_size, int size) { + int resize_h = 0; + int resize_w = 0; + if (size > 0) { + resize_h = size; + resize_w = size; + } else { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + if (h < w) { + ratio = float(resize_short_size) / float(h); + } else { + ratio = float(resize_short_size) / float(w); + } + resize_h = round(float(h) * ratio); + resize_w = round(float(w) * ratio); + } + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); +} + +} // namespace Feature + +namespace PaddleClas { +void Permute::Run(const cv::Mat *im, float *data) { + int rh = im->rows; + int rw = im->cols; + int rc = im->channels(); + for (int i = 0; i < rc; ++i) { + cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i); + } +} + +void Normalize::Run(cv::Mat *im, const std::vector<float> &mean, + const std::vector<float> &scale, const bool is_scale) { + double e = 1.0; + if (is_scale) { + e /= 255.0; + } + (*im).convertTo(*im, CV_32FC3, e); + for (int h = 0; h < im->rows; h++) { + for (int w = 0; w < im->cols; w++) { + im->at<cv::Vec3f>(h, w)[0] = + (im->at<cv::Vec3f>(h, w)[0] - mean[0]) / scale[0]; + im->at<cv::Vec3f>(h, w)[1] = + (im->at<cv::Vec3f>(h, w)[1] - mean[1]) / scale[1]; + im->at<cv::Vec3f>(h, w)[2] = + (im->at<cv::Vec3f>(h, w)[2] - mean[2]) / scale[2]; + } + } +} + +void CenterCropImg::Run(cv::Mat &img, const int crop_size) { + int resize_w = img.cols; + int resize_h = img.rows; + int w_start = int((resize_w - crop_size) / 2); + int h_start = int((resize_h - crop_size) / 2); + cv::Rect rect(w_start, h_start, crop_size, crop_size); + img = img(rect); +} + +void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + int resize_short_size) { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + if (h < w) { + ratio = float(resize_short_size) / float(h); + } else { + ratio = float(resize_short_size) / float(w); + } + + int resize_h = round(float(h) * ratio); + int resize_w = round(float(w) * ratio); + + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); +} + +} // namespace PaddleClas diff --git a/deploy/paddleserving/preprocess/preprocess_op.h b/deploy/paddleserving/preprocess/preprocess_op.h new file mode 100644 index 00000000..0ea9d2e1 --- /dev/null +++ b/deploy/paddleserving/preprocess/preprocess_op.h @@ -0,0 +1,81 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include <chrono> +#include <iomanip> +#include <iostream> +#include <ostream> +#include <vector> + +#include <cstring> +#include <fstream> +#include <numeric> + +namespace Feature { + +class Normalize { +public: + virtual void Run(cv::Mat *im, const std::vector<float> &mean, + const std::vector<float> &std, float scale); +}; + +// HWC -> CHW +class Permute { +public: + virtual void Run(const cv::Mat *im, float *data); +}; + +class CenterCropImg { +public: + virtual void Run(cv::Mat &im, const int crop_size = 224); +}; + +class ResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len, + int size = 0); +}; + +} // namespace Feature + +namespace PaddleClas { + +class Normalize { +public: + virtual void Run(cv::Mat *im, const std::vector<float> &mean, + const std::vector<float> &scale, const bool is_scale = true); +}; + +// HWC -> CHW +class Permute { +public: + virtual void Run(const cv::Mat *im, float *data); +}; + +class CenterCropImg { +public: + virtual void Run(cv::Mat &im, const int crop_size = 224); +}; + +class ResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len); +}; + +} // namespace PaddleClas diff --git a/test_tipc/config/MobileNetV3/MobileNetV3_large_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/MobileNetV3/MobileNetV3_large_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..0a3b777e --- /dev/null +++ b/test_tipc/config/MobileNetV3/MobileNetV3_large_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:MobileNetV3_large_x1_0 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/MobileNetV3_large_x1_0_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/MobileNetV3_large_x1_0_serving/ +--serving_client:./deploy/paddleserving/MobileNetV3_large_x1_0_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPHGNet/PPHGNet_small_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPHGNet/PPHGNet_small_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..b576cb90 --- /dev/null +++ b/test_tipc/config/PPHGNet/PPHGNet_small_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPHGNet_small +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_infer.tar
+trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPHGNet_small_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPHGNet_small_serving/ +--serving_client:./deploy/paddleserving/PPHGNet_small_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPHGNet/PPHGNet_tiny_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPHGNet/PPHGNet_tiny_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..6b4f5e41 --- /dev/null +++ b/test_tipc/config/PPHGNet/PPHGNet_tiny_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPHGNet_tiny +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPHGNet_tiny_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPHGNet_tiny_serving/ +--serving_client:./deploy/paddleserving/PPHGNet_tiny_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNet/PPLCNet_x0_25_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x0_25_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..a36cb9ac --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x0_25_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x0_25 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_25_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x0_25_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x0_25_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x0_25_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNet/PPLCNet_x0_35_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x0_35_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..f5673d9b --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x0_35_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x0_35 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_35_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x0_35_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x0_35_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x0_35_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git
a/test_tipc/config/PPLCNet/PPLCNet_x0_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x0_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..487d270e --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x0_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x0_5 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_5_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x0_5_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x0_5_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x0_5_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNet/PPLCNet_x0_75_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x0_75_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..51b2a816 --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x0_75_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x0_75 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_75_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x0_75_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x0_75_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x0_75_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..39aac394 --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x1_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x1_0 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_0_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x1_0_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x1_0_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x1_0_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNet/PPLCNet_x1_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x1_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..20cdbd0b --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x1_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x1_5 +python:python3.7 
+inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_5_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x1_5_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x1_5_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x1_5_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNet/PPLCNet_x2_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x2_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..f2a25a3d --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x2_0_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x2_0 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x2_0_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x2_0_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x2_0_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNet/PPLCNet_x2_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNet/PPLCNet_x2_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..e8c340ea --- /dev/null +++ b/test_tipc/config/PPLCNet/PPLCNet_x2_5_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNet_x2_5 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNet_x2_5_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNet_x2_5_serving/ +--serving_client:./deploy/paddleserving/PPLCNet_x2_5_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/PPLCNetV2/PPLCNetV2_base_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..2c355a9d --- /dev/null +++ b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:PPLCNetV2_base +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/PPLCNetV2_base_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/PPLCNetV2_base_serving/ +--serving_client:./deploy/paddleserving/PPLCNetV2_base_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null 
+pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/ResNet/ResNet50_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/ResNet/ResNet50_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..a4344caf --- /dev/null +++ b/test_tipc/config/ResNet/ResNet50_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:ResNet50 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/ResNet50_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/ResNet50_serving/ +--serving_client:./deploy/paddleserving/ResNet50_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..2de32443 --- /dev/null +++ b/test_tipc/config/ResNet/ResNet50_vd_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:ResNet50_vd +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/ResNet50_vd_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/ResNet50_vd_serving/ +--serving_client:./deploy/paddleserving/ResNet50_vd_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/config/SwinTransformer/SwinTransformer_tiny_patch4_window7_224_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/config/SwinTransformer/SwinTransformer_tiny_patch4_window7_224_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt new file mode 100644 index 00000000..5b5a2a1a --- /dev/null +++ b/test_tipc/config/SwinTransformer/SwinTransformer_tiny_patch4_window7_224_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt @@ -0,0 +1,14 @@ +===========================serving_params=========================== +model_name:SwinTransformer_tiny_patch4_window7_224 +python:python3.7 +inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/SwinTransformer_tiny_patch4_window7_224_infer.tar +trans_model:-m paddle_serving_client.convert +--dirname:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--serving_server:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_serving/ +--serving_client:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_client/ +serving_dir:./deploy/paddleserving +web_service:null +--use_gpu:0|null +pipline:test_cpp_serving_client.py diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 5a18c386..8a4ca02a 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -199,10 +199,16 @@ fi if [[ ${MODE} = "serving_infer" ]]; then # prepare serving env python_name=$(func_parser_value "${lines[2]}") + ${python_name} -m pip install paddle_serving_client==0.9.0 +
${python_name} -m pip install paddle-serving-app==0.9.0 - ${python_name} -m pip install install paddle-serving-server-gpu==0.7.0.post102 - ${python_name} -m pip install paddle_serving_client==0.7.0 - ${python_name} -m pip install paddle-serving-app==0.7.0 + if [[ ${FILENAME} =~ "cpp" ]]; then + pushd ./deploy/paddleserving + bash build_server.sh ${python_name} + popd + else + ${python_name} -m pip install paddle-serving-server-gpu==0.9.0.post102 + fi if [[ ${model_name} =~ "ShiTu" ]]; then cls_inference_model_url=$(func_parser_value "${lines[3]}") cls_tar_name=$(func_get_url_file_name "${cls_inference_model_url}") diff --git a/test_tipc/test_serving_infer.sh b/test_tipc/test_serving_infer.sh index 3eef6b1c..f501f1bc 100644 --- a/test_tipc/test_serving_infer.sh +++ b/test_tipc/test_serving_infer.sh @@ -50,8 +50,13 @@ function func_serving_cls(){ set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}") set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}") - trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}" - eval $trans_model_cmd + for python_ in ${python[*]}; do + if [[ ${python_} =~ "python" ]]; then + trans_model_cmd="${python_} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}" + eval $trans_model_cmd + break + fi + done # modify the alias_name of fetch_var to "outputs" server_fetch_var_line_cmd="sed -i '/fetch_var/,/is_lod_tensor/s/alias_name: .*/alias_name: \"prediction\"/' ${serving_server_value}/serving_server_conf.prototxt" @@ -69,103 +74,124 @@ unset https_proxy unset http_proxy - # modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt - set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}" - eval ${set_web_service_feet_var_cmd} - - model_config=21 - serving_server_dir_name=$(func_get_url_file_name "$serving_server_value") - set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml" - eval ${set_model_config_cmd} - - for python in ${python[*]}; do - if [[ ${python} = "cpp" ]]; then - for use_gpu in ${web_use_gpu_list[*]}; do - if [ ${use_gpu} = "null" ]; then - web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293" - eval $web_service_cmd - sleep 5s - _save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log" - pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/" + if [[ ${FILENAME} =~ "cpp" ]]; then + for item in ${python[*]}; do + if [[ ${item} =~ "python" ]]; then + python_=${item} + break + fi + done + set_pipeline_py_feed_var_cmd="sed -i '/feed/,/img}/s/{.*: img}/{${feed_var_name}: img}/' ${pipeline_py}" + echo $PWD - $set_pipeline_py_feed_var_cmd + eval ${set_pipeline_py_feed_var_cmd} + serving_server_dir_name=$(func_get_url_file_name "$serving_server_value") + set_client_config_line_cmd="sed -i '/client/,/serving_server_conf.prototxt/s/.\/.*\/serving_server_conf.prototxt/.\/${serving_server_dir_name}\/serving_server_conf.prototxt/' ${pipeline_py}" + echo $PWD - $set_client_config_line_cmd + eval
${set_client_config_line_cmd} + for use_gpu in ${web_use_gpu_list[*]}; do + if [ ${use_gpu} = "null" ]; then + web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --port 9292 &" + echo $PWD - $web_service_cpp_cmd + eval ${web_service_cpp_cmd} + sleep 5s + _save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_batchsize_1.log" + pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 " + echo ${pipeline_cmd} + eval ${pipeline_cmd} + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}" + ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9 + sleep 5s + else + web_service_cpp_cmd="${python_} -m paddle_serving_server.serve --model ${serving_server_dir_name} --port 9292 --gpu_id=${use_gpu} &" + echo $PWD - $web_service_cpp_cmd + eval ${web_service_cpp_cmd} + sleep 5s + + _save_log_path="${LOG_PATH}/server_infer_cpp_gpu_pipeline_batchsize_1.log" + pipeline_cmd="${python_} test_cpp_serving_client.py > ${_save_log_path} 2>&1 " + + echo $PWD - $pipeline_cmd + eval $pipeline_cmd + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}" + ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9 + sleep 5s + fi + done + else + # python serving + # modify the input_name in "classification_web_service.py" to be consistent with feed_var.name in prototxt + set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}" + eval ${set_web_service_feed_var_cmd} + + model_config=21 + serving_server_dir_name=$(func_get_url_file_name "$serving_server_value") + set_model_config_cmd="sed -i '${model_config}s/model_config: .*/model_config: ${serving_server_dir_name}/' config.yml" + eval ${set_model_config_cmd} + + for use_gpu in ${web_use_gpu_list[*]}; do + if [[ ${use_gpu} = "null" ]]; then + device_type_line=24 + set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml" + eval $set_device_type_cmd + + devices_line=27 + set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml" + eval $set_devices_cmd + + web_service_cmd="${python_} ${web_service_py} &" + eval $web_service_cmd + sleep 5s + for pipeline in ${pipeline_py[*]}; do + _save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log" + pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1 " eval $pipeline_cmd - status_check $last_status "${pipeline_cmd}" "${status_log}" + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" sleep 5s - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 - else - web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293 --gpu_id=0" - eval $web_service_cmd - sleep 5s - _save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log" - pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/" - eval $pipeline_cmd - status_check $last_status "${pipeline_cmd}" "${status_log}" - sleep 5s - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 + done + ps ux | grep -E 'web_service|pipeline' | awk '{print 
$2}' | xargs kill -s 9 + elif [ ${use_gpu} -eq 0 ]; then + if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then + continue fi - done - else - # python serving - for use_gpu in ${web_use_gpu_list[*]}; do - if [[ ${use_gpu} = "null" ]]; then - device_type_line=24 - set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml" - eval $set_device_type_cmd - - devices_line=27 - set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml" - eval $set_devices_cmd - - web_service_cmd="${python} ${web_service_py} &" - eval $web_service_cmd - sleep 5s - for pipeline in ${pipeline_py[*]}; do - _save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log" - pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 " - eval $pipeline_cmd - last_status=${PIPESTATUS[0]} - eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" - sleep 5s - done - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 - elif [ ${use_gpu} -eq 0 ]; then - if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then - continue - fi - if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then - continue - fi - if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then - continue - fi - - device_type_line=24 - set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml" - eval $set_device_type_cmd - - devices_line=27 - set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml" - eval $set_devices_cmd - - web_service_cmd="${python} ${web_service_py} & " - eval $web_service_cmd - sleep 5s - for pipeline in ${pipeline_py[*]}; do - _save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log" - pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1" - eval $pipeline_cmd - last_status=${PIPESTATUS[0]} - eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" - sleep 5s - done - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 - else - echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!" + if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then + continue fi - done - fi - done + if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then + continue + fi + + device_type_line=24 + set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml" + eval $set_device_type_cmd + + devices_line=27 + set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml" + eval $set_devices_cmd + + web_service_cmd="${python_} ${web_service_py} & " + eval $web_service_cmd + sleep 5s + for pipeline in ${pipeline_py[*]}; do + _save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log" + pipeline_cmd="${python_} ${pipeline} > ${_save_log_path} 2>&1" + eval $pipeline_cmd + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" + sleep 5s + done + ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 + else + echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!" 
+ fi + done + fi } @@ -200,6 +226,12 @@ function func_serving_rec(){ pipeline_py=$(func_parser_value "${lines[17]}") IFS='|' + for python_ in ${python[*]}; do + if [[ ${python_} =~ "python" ]]; then + python_interp=${python_} + break + fi + done # pdserving cd ./deploy @@ -208,7 +240,7 @@ function func_serving_rec(){ set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}") set_serving_server=$(func_set_params "${cls_serving_server_key}" "${cls_serving_server_value}") set_serving_client=$(func_set_params "${cls_serving_client_key}" "${cls_serving_client_value}") - cls_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}" + cls_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}" eval $cls_trans_model_cmd set_dirname=$(func_set_params "${det_infer_model_dir_key}" "${det_infer_model_dir_value}") @@ -216,7 +248,7 @@ function func_serving_rec(){ set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}") set_serving_server=$(func_set_params "${det_serving_server_key}" "${det_serving_server_value}") set_serving_client=$(func_set_params "${det_serving_client_key}" "${det_serving_client_value}") - det_trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}" + det_trans_model_cmd="${python_interp} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}" eval $det_trans_model_cmd # modify the alias_name of fetch_var to "outputs" @@ -235,99 +267,130 @@ function func_serving_rec(){ unset https_proxy unset http_proxy - # modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt - set_web_service_feet_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}" - eval ${set_web_service_feet_var_cmd} - - for python in ${python[*]}; do - if [[ ${python} = "cpp" ]]; then - for use_gpu in ${web_use_gpu_list[*]}; do - if [ ${use_gpu} = "null" ]; then - web_service_cpp_cmd="${python} web_service_py" - - eval $web_service_cmd - sleep 5s - _save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log" - pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/" + if [[ ${FILENAME} =~ "cpp" ]]; then + det_serving_client_dir_name=$(func_get_url_file_name "$det_serving_client_value") + set_det_client_config_line_cmd="sed -i '/MainbodyDetect/,/serving_client_conf.prototxt/s/models\/.*\/serving_client_conf.prototxt/models\/${det_serving_client_dir_name}\/serving_client_conf.prototxt/' ${pipeline_py}" + echo $PWD - $set_det_client_config_line_cmd + eval ${set_det_client_config_line_cmd} + + cls_serving_client_dir_name=$(func_get_url_file_name "$cls_serving_client_value") + set_cls_client_config_line_cmd="sed -i '/ObjectRecognition/,/serving_client_conf.prototxt/s/models\/.*\/serving_client_conf.prototxt/models\/${cls_serving_client_dir_name}\/serving_client_conf.prototxt/' ${pipeline_py}" + echo $PWD - $set_cls_client_config_line_cmd + eval ${set_cls_client_config_line_cmd} + + set_pipeline_py_feed_var_cmd="sed -i '/ObjectRecognition/,/feed={\"x\": batch_imgs}/s/{.*: batch_imgs}/{${feed_var_name}: batch_imgs}/' ${pipeline_py}" 
+ echo $PWD - $set_pipeline_py_feed_var_cmd + eval ${set_pipeline_py_feed_var_cmd} + + for use_gpu in ${web_use_gpu_list[*]}; do + if [ ${use_gpu} = "null" ]; then + + det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value") + web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../models/${det_serving_server_dir_name} --port 9293 >>log_mainbody_detection.txt &" + + cls_serving_server_dir_name=$(func_get_url_file_name "$cls_serving_server_value") + web_service_cpp_cmd2="${python_interp} -m paddle_serving_server.serve --model ../../models/${cls_serving_server_dir_name} --port 9294 >>log_feature_extraction.txt &" + echo $PWD - $web_service_cpp_cmd + eval $web_service_cpp_cmd + echo $PWD - $web_service_cpp_cmd2 + eval $web_service_cpp_cmd2 + sleep 5s + _save_log_path="${LOG_PATH}/server_infer_cpp_cpu_batchsize_1.log" + pipeline_cmd="${python_interp} test_cpp_serving_client.py > ${_save_log_path} 2>&1 " + echo ${pipeline_cmd} + eval ${pipeline_cmd} + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}" + ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9 + sleep 5s + else + det_serving_server_dir_name=$(func_get_url_file_name "$det_serving_server_value") + web_service_cpp_cmd="${python_interp} -m paddle_serving_server.serve --model ../../models/${det_serving_server_dir_name} --port 9293 --gpu_id=${use_gpu} >>log_mainbody_detection.txt &" + + cls_serving_server_dir_name=$(func_get_url_file_name "$cls_serving_server_value") + web_service_cpp_cmd2="${python_interp} -m paddle_serving_server.serve --model ../../models/${cls_serving_server_dir_name} --port 9294 --gpu_id=${use_gpu} >>log_feature_extraction.txt &" + echo $PWD - $web_service_cpp_cmd + eval $web_service_cpp_cmd + echo $PWD - $web_service_cpp_cmd2 + eval $web_service_cpp_cmd2 + sleep 5s + _save_log_path="${LOG_PATH}/server_infer_cpp_gpu_batchsize_1.log" + pipeline_cmd="${python_interp} test_cpp_serving_client.py > ${_save_log_path} 2>&1 " + echo ${pipeline_cmd} + eval ${pipeline_cmd} + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check ${last_status} "${pipeline_cmd}" "${status_log}" "${model_name}" + ps ux | grep -E 'serving_server|pipeline' | awk '{print $2}' | xargs kill -s 9 + sleep 5s + fi + done + else + # modify the input_name in "recognition_web_service.py" to be consistent with feed_var.name in prototxt + set_web_service_feed_var_cmd="sed -i '/preprocess/,/input_imgs}/s/{.*: input_imgs}/{${feed_var_name}: input_imgs}/' ${web_service_py}" + eval ${set_web_service_feed_var_cmd} + # python serving + for use_gpu in ${web_use_gpu_list[*]}; do + if [[ ${use_gpu} = "null" ]]; then + device_type_line=24 + set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml" + eval $set_device_type_cmd + + devices_line=27 + set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml" + eval $set_devices_cmd + + web_service_cmd="${python} ${web_service_py} &" + eval $web_service_cmd + sleep 5s + for pipeline in ${pipeline_py[*]}; do + _save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log" + pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 " eval $pipeline_cmd - status_check $last_status "${pipeline_cmd}" "${status_log}" - sleep 5s - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 - else - web_service_cpp_cmd="${python} web_service_py" - 
eval $web_service_cmd + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" sleep 5s - _save_log_path="${LOG_PATH}/server_infer_cpp_cpu_pipeline_usemkldnn_False_threads_4_batchsize_1.log" - pipeline_cmd="${python} ocr_cpp_client.py ppocr_det_mobile_2.0_client/ ppocr_rec_mobile_2.0_client/" - eval $pipeline_cmd - status_check $last_status "${pipeline_cmd}" "${status_log}" - sleep 5s - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 + done + ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 + elif [ ${use_gpu} -eq 0 ]; then + if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then + continue fi - done - else - # python serving - for use_gpu in ${web_use_gpu_list[*]}; do - if [[ ${use_gpu} = "null" ]]; then - device_type_line=24 - set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 0/' config.yml" - eval $set_device_type_cmd - - devices_line=27 - set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"\"/' config.yml" - eval $set_devices_cmd - - web_service_cmd="${python} ${web_service_py} &" - eval $web_service_cmd - sleep 5s - for pipeline in ${pipeline_py[*]}; do - _save_log_path="${LOG_PATH}/server_infer_cpu_${pipeline%_client*}_batchsize_1.log" - pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1 " - eval $pipeline_cmd - last_status=${PIPESTATUS[0]} - eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" - sleep 5s - done - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 - elif [ ${use_gpu} -eq 0 ]; then - if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then - continue - fi - if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then - continue - fi - if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then - continue - fi - - device_type_line=24 - set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml" - eval $set_device_type_cmd - - devices_line=27 - set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml" - eval $set_devices_cmd - - web_service_cmd="${python} ${web_service_py} & " - eval $web_service_cmd - sleep 10s - for pipeline in ${pipeline_py[*]}; do - _save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log" - pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1" - eval $pipeline_cmd - last_status=${PIPESTATUS[0]} - eval "cat ${_save_log_path}" - status_check $last_status "${pipeline_cmd}" "${status_log}" - sleep 10s - done - ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 - else - echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!" 
+ if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then + continue fi - done - fi - done + if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then + continue + fi + + device_type_line=24 + set_device_type_cmd="sed -i '${device_type_line}s/device_type: .*/device_type: 1/' config.yml" + eval $set_device_type_cmd + + devices_line=27 + set_devices_cmd="sed -i '${devices_line}s/devices: .*/devices: \"${use_gpu}\"/' config.yml" + eval $set_devices_cmd + + web_service_cmd="${python} ${web_service_py} & " + eval $web_service_cmd + sleep 10s + for pipeline in ${pipeline_py[*]}; do + _save_log_path="${LOG_PATH}/server_infer_gpu_${pipeline%_client*}_batchsize_1.log" + pipeline_cmd="${python} ${pipeline} > ${_save_log_path} 2>&1" + eval $pipeline_cmd + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" + sleep 10s + done + ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9 + else + echo "Does not support hardware [${use_gpu}] other than CPU and GPU Currently!" + fi + done + fi } -- GitLab