提交 ac821bbe 编写于 作者: H HydrogenSulfate

add cpp serving infer(except PPShiTu)

上级 3fd426aa
#使用镜像:
#registry.baidubce.com/paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82
#编译Serving Server:
#client和app可以直接使用release版本
#server因为加入了自定义OP,需要重新编译
#默认编译时的${PWD}=PaddleClas/deploy/paddleserving/
python_name=${1:-'python'}
apt-get update
apt install -y libcurl4-openssl-dev libbz2-dev
wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && tar xf centos_ssl.tar && rm -rf centos_ssl.tar && mv libcrypto.so.1.0.2k /usr/lib/libcrypto.so.1.0.2k && mv libssl.so.1.0.2k /usr/lib/libssl.so.1.0.2k && ln -sf /usr/lib/libcrypto.so.1.0.2k /usr/lib/libcrypto.so.10 && ln -sf /usr/lib/libssl.so.1.0.2k /usr/lib/libssl.so.10 && ln -sf /usr/lib/libcrypto.so.10 /usr/lib/libcrypto.so && ln -sf /usr/lib/libssl.so.10 /usr/lib/libssl.so
# 安装go依赖
rm -rf /usr/local/go
wget -qO- https://paddle-ci.cdn.bcebos.com/go1.17.2.linux-amd64.tar.gz | tar -xz -C /usr/local
export GOROOT=/usr/local/go
export GOPATH=/root/gopath
export PATH=$PATH:$GOPATH/bin:$GOROOT/bin
go env -w GO111MODULE=on
go env -w GOPROXY=https://goproxy.cn,direct
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2
go install github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2
go install github.com/golang/protobuf/protoc-gen-go@v1.4.3
go install google.golang.org/grpc@v1.33.0
go env -w GO111MODULE=auto
# 下载opencv库
wget https://paddle-qa.bj.bcebos.com/PaddleServing/opencv3.tar.gz && tar -xvf opencv3.tar.gz && rm -rf opencv3.tar.gz
export OPENCV_DIR=$PWD/opencv3
# clone Serving
git clone https://github.com/PaddlePaddle/Serving.git -b develop --depth=1
cd Serving # PaddleClas/deploy/paddleserving/Serving
export Serving_repo_path=$PWD
git submodule update --init --recursive
${python_name} -m pip install -r python/requirements.txt
# set env
export PYTHON_INCLUDE_DIR=$(${python_name} -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())")
export PYTHON_LIBRARIES=$(${python_name} -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))")
export PYTHON_EXECUTABLE=`which ${python_name}`
export CUDA_PATH='/usr/local/cuda'
export CUDNN_LIBRARY='/usr/local/cuda/lib64/'
export CUDA_CUDART_LIBRARY='/usr/local/cuda/lib64/'
export TENSORRT_LIBRARY_PATH='/usr/local/TensorRT6-cuda10.1-cudnn7/targets/x86_64-linux-gnu/'
# cp 自定义OP代码
\cp ../preprocess/general_clas_op.* ${Serving_repo_path}/core/general-server/op
\cp ../preprocess/preprocess_op.* ${Serving_repo_path}/core/predictor/tools/pp_shitu_tools
# 编译Server, export SERVING_BIN
mkdir server-build-gpu-opencv && cd server-build-gpu-opencv
cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \
-DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \
-DCUDNN_LIBRARY=${CUDNN_LIBRARY} \
-DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \
-DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \
-DOPENCV_DIR=${OPENCV_DIR} \
-DWITH_OPENCV=ON \
-DSERVER=ON \
-DWITH_GPU=ON ..
make -j32
${python_name} -m pip install python/dist/paddle*
export SERVING_BIN=$PWD/core/general-server/serving
cd ../../
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_clas_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralClasOp::inference() {
VLOG(2) << "Going to run inference";
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
<< ") can only have one predecessor op, but received "
<< pre_node_names.size();
return -1;
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
}
output_blob->SetLogId(log_id);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op:" << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
output_blob->_batch_size = batch_size;
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
// only support string type
char *total_input_ptr = static_cast<char *>(in->at(0).data.data());
std::string base64str = total_input_ptr;
cv::Mat img = Base2Mat(base64str);
// RGB2BGR
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
// Resize
cv::Mat resize_img;
resize_op_.Run(img, resize_img, resize_short_size_);
// CenterCrop
crop_op_.Run(resize_img, crop_size_);
// Normalize
normalize_op_.Run(&resize_img, mean_, scale_, is_scale_);
// Permute
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
permute_op_.Run(&resize_img, input.data());
float maxValue = *max_element(input.begin(), input.end());
float minValue = *min_element(input.begin(), input.end());
TensorVector *real_in = new TensorVector();
if (!real_in) {
LOG(ERROR) << "real_in is nullptr,error";
return -1;
}
std::vector<int> input_shape;
int in_num = 0;
void *databuf_data = NULL;
char *databuf_char = NULL;
size_t databuf_size = 0;
input_shape = {1, 3, resize_img.rows, resize_img.cols};
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int>());
databuf_size = in_num * sizeof(float);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
memcpy(databuf_data, input.data(), databuf_size);
databuf_char = reinterpret_cast<char *>(databuf_data);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
paddle::PaddleTensor tensor_in;
tensor_in.name = in->at(0).name;
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
tensor_in.lod = in->at(0).lod;
tensor_in.data = paddleBuf;
real_in->push_back(tensor_in);
if (InferManager::instance().infer(engine_name().c_str(), real_in, out,
batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
}
cv::Mat GeneralClasOp::Base2Mat(std::string &base64_data) {
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR
return img;
}
std::string GeneralClasOp::base64Decode(const char *Data, int DataByte) {
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
std::string strDecode;
int nValue;
int i = 0;
while (i < DataByte) {
if (*Data != '\r' && *Data != '\n') {
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=') {
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=') {
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
} else // 回车换行,跳过
{
Data++;
i++;
}
}
return strDecode;
}
DEFINE_OP(GeneralClasOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/tools/pp_shitu_tools/preprocess_op.h"
#include "paddle_inference_api.h" // NOLINT
#include <string>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralClasOp
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralClasOp);
int inference();
private:
// clas preprocess
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
bool is_scale_ = true;
int resize_short_size_ = 256;
int crop_size_ = 224;
PaddleClas::ResizeImg resize_op_;
PaddleClas::Normalize normalize_op_;
PaddleClas::Permute permute_op_;
PaddleClas::CenterCropImg crop_op_;
// read pics
cv::Mat Base2Mat(std::string &base64_data);
std::string base64Decode(const char *Data, int DataByte);
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <math.h>
#include <numeric>
#include "preprocess_op.h"
namespace Feature {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale) {
(*im).convertTo(*im, CV_32FC3, scale);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
}
}
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size, int size) {
int resize_h = 0;
int resize_w = 0;
if (size > 0) {
resize_h = size;
resize_w = size;
} else {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
resize_h = round(float(h) * ratio);
resize_w = round(float(w) * ratio);
}
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace Feature
namespace PaddleClas {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale) {
double e = 1.0;
if (is_scale) {
e /= 255.0;
}
(*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / scale[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / scale[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / scale[2];
}
}
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
int resize_h = round(float(h) * ratio);
int resize_w = round(float(w) * ratio);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace PaddleClas
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
namespace Feature {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
int size = 0);
};
} // namespace Feature
namespace PaddleClas {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale = true);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
};
} // namespace PaddleClas
===========================serving_params===========================
model_name:MobileNetV3_large_x1_0
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/MobileNetV3_large_x1_0_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/MobileNetV3_large_x1_0_serving/
--serving_client:./deploy/paddleserving/MobileNetV3_large_x1_0_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPHGNet_small
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPHGNet_small_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPHGNet_small_serving/
--serving_client:./deploy/paddleserving/PPHGNet_small_client/
serving_dir:./deploy/paddleserving
web_service:classification_web_service.py
--use_gpu:0|null
pipline:pipeline_http_client.py
===========================serving_params===========================
model_name:PPHGNet_tiny
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPHGNet_tiny_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPHGNet_tiny_serving/
--serving_client:./deploy/paddleserving/PPHGNet_tiny_client/
serving_dir:./deploy/paddleserving
web_service:classification_web_service.py
--use_gpu:0|null
pipline:pipeline_http_client.py
===========================serving_params===========================
model_name:PPLCNet_x0_25
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_25_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_25_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_25_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_25_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNet_x0_35
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_35_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_35_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_35_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_35_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNet_x0_5
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_5_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_5_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_5_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_5_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNet_x0_75
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x0_75_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x0_75_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x0_75_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x0_75_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNet_x1_0
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_0_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x1_0_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x1_0_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x1_0_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNet_x1_5
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x1_5_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x1_5_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x1_5_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x1_5_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNet_x2_0
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x2_0_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x2_0_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x2_0_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNet_x2_5
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNet_x2_5_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNet_x2_5_serving/
--serving_client:./deploy/paddleserving/PPLCNet_x2_5_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:PPLCNetV2_base
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/PPLCNetV2_base_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/PPLCNetV2_base_serving/
--serving_client:./deploy/paddleserving/PPLCNetV2_base_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:ResNet50
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/ResNet50_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/ResNet50_serving/
--serving_client:./deploy/paddleserving/ResNet50_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:ResNet50_vd
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/ResNet50_vd_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/ResNet50_vd_serving/
--serving_client:./deploy/paddleserving/ResNet50_vd_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
===========================serving_params===========================
model_name:SwinTransformer_tiny_patch4_window7_224
python:python3.7
inference_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/SwinTransformer_tiny_patch4_window7_224_infer.tar
trans_model:-m paddle_serving_client.convert
--dirname:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_serving/
--serving_client:./deploy/paddleserving/SwinTransformer_tiny_patch4_window7_224_client/
serving_dir:./deploy/paddleserving
web_service:null
--use_gpu:0|null
pipline:test_cpp_serving_client.py
...@@ -199,10 +199,16 @@ fi ...@@ -199,10 +199,16 @@ fi
if [[ ${MODE} = "serving_infer" ]]; then if [[ ${MODE} = "serving_infer" ]]; then
# prepare serving env # prepare serving env
${python_name} -m pip install paddle_serving_client==0.9.0
${python_name} -m pip install paddle-serving-app==0.9.0
python_name=$(func_parser_value "${lines[2]}") python_name=$(func_parser_value "${lines[2]}")
${python_name} -m pip install install paddle-serving-server-gpu==0.7.0.post102 if [[ ${FILENAME} =~ "cpp" ]]; then
${python_name} -m pip install paddle_serving_client==0.7.0 pushd ./deploy/paddleserving
${python_name} -m pip install paddle-serving-app==0.7.0 bash build_server.sh
popd
else
${python_name} -m pip install install paddle-serving-server-gpu==0.9.0.post102
fi
if [[ ${model_name} =~ "ShiTu" ]]; then if [[ ${model_name} =~ "ShiTu" ]]; then
cls_inference_model_url=$(func_parser_value "${lines[3]}") cls_inference_model_url=$(func_parser_value "${lines[3]}")
cls_tar_name=$(func_get_url_file_name "${cls_inference_model_url}") cls_tar_name=$(func_get_url_file_name "${cls_inference_model_url}")
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册