Unverified commit 9bb5cf29 authored by xiaoting, committed by GitHub

update serving_cpp (#5523)

* update serving_cpp

* update prepare for serving and onnx

* update status check

* update serving prepare

* fix serving doc

* fix serving doc

* fix typo

* update paddle2onnx doc

* mv download_bin to server_cpp

* fix onnx doc

* Update prepare.sh

* modified cpp serving preprocess

* update prepare.sh

* sleep longer for serving

* add prototxt
Parent ecb1660b
......@@ -30,6 +30,7 @@
#include <include/preprocess_op.h>
namespace MobileNetV3 {
void Permute::Run(const cv::Mat *im, float *data) {
......
from paddle_serving_server import Server
server = Server()
server.download_bin()
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_clas_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
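// GeneralClasOp::inference() decodes the base64-encoded image string passed in
// from the preceding op, runs the classification preprocessing pipeline on it
// (resize short side, center crop, normalize, HWC -> CHW permute), and feeds
// the resulting FLOAT32 tensor to the inference engine.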
int GeneralClasOp::inference() {
VLOG(2) << "Going to run inference";
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
<< ") can only have one predecessor op, but received "
<< pre_node_names.size();
return -1;
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
}
output_blob->SetLogId(log_id);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op:" << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
output_blob->_batch_size = batch_size;
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
// only support string type
char* total_input_ptr = static_cast<char*>(in->at(0).data.data());
std::string base64str = total_input_ptr;
cv::Mat img = Base2Mat(base64str);
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
cv::Mat resize_img;
// preprocess: resize the short side to resize_short_size_, center-crop to
// crop_size_, normalize with mean_/scale_, then permute HWC -> CHW
resize_op_.Run(img, resize_img, resize_short_size_);
crop_op_.Run(resize_img, crop_size_);
normalize_op_.Run(&resize_img, mean_, scale_,
is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
permute_op_.Run(&resize_img, input.data());
TensorVector* real_in = new TensorVector();
if (!real_in) {
LOG(ERROR) << "real_in is nullptr,error";
return -1;
}
std::vector<int> input_shape;
int in_num = 0;
void* databuf_data = NULL;
char* databuf_char = NULL;
size_t databuf_size = 0;
input_shape = {1, 3, resize_img.rows, resize_img.cols};
in_num = std::accumulate(
input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
databuf_size = in_num * sizeof(float);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
memcpy(databuf_data, input.data(), databuf_size);
databuf_char = reinterpret_cast<char*>(databuf_data);
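// Package the preprocessed data as a FLOAT32 PaddleTensor with shape
// {1, 3, H, W} so it can be handed to the inference engine.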
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
paddle::PaddleTensor tensor_in;
tensor_in.name = in->at(0).name;
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
tensor_in.lod = in->at(0).lod;
tensor_in.data = paddleBuf;
real_in->push_back(tensor_in);
if (InferManager::instance().infer(
engine_name().c_str(), real_in, out, batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
}
cv::Mat GeneralClasOp::Base2Mat(std::string& base64_data) {
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR
return img;
}
std::string GeneralClasOp::base64Decode(const char* Data, int DataByte) {
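// DecodeTable maps an ASCII byte to its 6-bit Base64 value
// ('A'-'Z' -> 0-25, 'a'-'z' -> 26-51, '0'-'9' -> 52-61, '+' -> 62, '/' -> 63);
// bytes outside the Base64 alphabet map to 0.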
const char
DecodeTable[] =
{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, // 'a'-'z'
};
std::string strDecode;
int nValue;
int i = 0;
while (i < DataByte) {
if (*Data != '\r' && *Data != '\n') {
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=') {
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=') {
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
} else // skip CR/LF characters
{
Data++;
i++;
}
}
return strDecode;
}
DEFINE_OP(GeneralClasOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "paddle_inference_api.h" // NOLINT
#include "core/predictor/tools/pp_shitu_tools/preprocess_op.h"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralClasOp
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralClasOp);
int inference();
private:
// clas preprocess
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;
int resize_short_size_ = 256;
int crop_size_ = 224;
MobileNetV3::ResizeImg resize_op_;
MobileNetV3::Normalize normalize_op_;
MobileNetV3::Permute permute_op_;
MobileNetV3::CenterCropImg crop_op_;
// read pics
cv::Mat Base2Mat(std::string &base64_data);
std::string base64Decode(const char* Data, int DataByte);
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <math.h>
#include <numeric>
#include "preprocess_op.h"
namespace Feature {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale) {
(*im).convertTo(*im, CV_32FC3, scale);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
}
}
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size, int size) {
int resize_h = 0;
int resize_w = 0;
if (size > 0) {
resize_h = size;
resize_w = size;
} else {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
resize_h = round(float(h) * ratio);
resize_w = round(float(w) * ratio);
}
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace Feature
namespace MobileNetV3 {
void Permute::Run(const cv::Mat *im, float *data) {
int rh = im->rows;
int rw = im->cols;
int rc = im->channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i);
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale) {
double e = 1.0;
if (is_scale) {
e /= 255.0;
}
(*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
}
}
}
void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
img = img(rect);
}
void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_short_size) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
int resize_h = round(float(h) * ratio);
int resize_w = round(float(w) * ratio);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
} // namespace MobileNetV3
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
using namespace std;
namespace Feature {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &std, float scale);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
int size = 0);
};
} // namespace Feature
namespace MobileNetV3 {
class Normalize {
public:
virtual void Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale = true);
};
// RGB -> CHW
class Permute {
public:
virtual void Run(const cv::Mat *im, float *data);
};
class CenterCropImg {
public:
virtual void Run(cv::Mat &im, const int crop_size = 224);
};
class ResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
};
} // namespace MobileNetV3
\ No newline at end of file
feed_var {
name: "input"
alias_name: "input"
is_lod_tensor: false
feed_type: 20
shape: 1
}
fetch_var {
name: "softmax_1.tmp_0"
alias_name: "softmax_1.tmp_0"
is_lod_tensor: false
fetch_type: 1
shape: 1000
}
import numpy as np
import PIL
from PIL import Image
def get_new_size(img_size, resize_size):
if isinstance(resize_size, int) or len(
resize_size) == 1: # specified size only for the smallest edge
w, h = img_size
short, long = (w, h) if w <= h else (h, w)
requested_new_short = resize_size if isinstance(
resize_size, int) else resize_size[0]
new_short, new_long = requested_new_short, int(requested_new_short *
long / short)
new_w, new_h = (new_short, new_long) if w <= h else (new_long,
new_short)
else: # specified both h and w
new_w, new_h = resize_size[1], resize_size[0]
return (new_w, new_h)
class ResizeImage(object):
""" resize image """
def __init__(self, resize_size=None, interpolation=Image.BILINEAR):
self.resize_size = resize_size
self.interpolation = interpolation
def __call__(self, img):
size = get_new_size(img.size, self.resize_size)
img = img.resize(size, self.interpolation)
return img
class CenterCropImage(object):
""" crop image """
def __init__(self, size):
if type(size) is int:
self.size = (size, size)
else:
self.size = size # (h, w)
def __call__(self, img):
return center_crop(img, self.size)
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, img):
if isinstance(img, Image.Image):
img = np.array(img)
img = (img * self.scale - self.mean) / self.std
return img
class ToCHW(object):
def __init__(self):
pass
def __call__(self, img):
img = img.transpose((2, 0, 1))
return img
def center_crop(img, size, is_color=True):
if isinstance(img, Image.Image):
img = np.array(img)
if isinstance(size, (list, tuple)):
size = size[0]
h, w = img.shape[:2]
h_start = (h - size) // 2
w_start = (w - size) // 2
h_end, w_end = h_start + size, w_start + size
if is_color:
img = img[h_start:h_end, w_start:w_end, :]
else:
img = img[h_start:h_end, w_start:w_end]
return img
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
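# --- Usage sketch (illustrative; not part of the original file) ---
# Compose the transforms above into the standard MobileNetV3 preprocessing
# pipeline. The values mirror the server-side defaults (short side 256,
# center crop 224, ImageNet mean/std); "demo.jpg" is a hypothetical path.
if __name__ == "__main__":
    preprocess = Compose([
        ResizeImage(resize_size=256),   # resize the short side to 256
        CenterCropImage(size=224),      # center-crop to 224 x 224
        NormalizeImage(scale=1.0 / 255.0,
                       mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225]),
        ToCHW(),                        # HWC -> CHW for the model input
    ])
    chw_img = preprocess(Image.open("demo.jpg"))
    print(chw_img.shape)                # expected: (3, 224, 224)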
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import base64
from PIL import Image
import io
import os
from preprocess_ops import ResizeImage, CenterCropImage, NormalizeImage, ToCHW, Compose
from paddle_serving_client import Client
import argparse
parser = argparse.ArgumentParser(description="args for paddleserving")
parser.add_argument("--image_dir", type=str, default="../../lite_data/test")
args = parser.parse_args()
url = "127.0.0.1:9997"
logid = 10000
img_path = args.image_dir
client = Client()
client.load_client_config(
"serving_client/serving_client_conf.prototxt")
client.connect([url])
def cv2_to_base64(image):
return base64.b64encode(image).decode(
'utf8')
def preprocess(img_file):
with open(img_file, 'rb') as file:
image_data = file.read()
image = cv2_to_base64(image_data)
feed = {"input": image}
fetch = ["softmax_1.tmp_0"]
return feed, fetch
def postprocess(fetch_map):
score_list = fetch_map["softmax_1.tmp_0"]
fetch_dict = {"class_id": [], "prob": []}
for score in score_list:
score = score.tolist()
max_score = max(score)
fetch_dict["class_id"].append(score.index(max_score))
fetch_dict["prob"].append(max_score)
fetch_dict["class_id"] = str(fetch_dict["class_id"])
fetch_dict["prob"] = str(fetch_dict["prob"])
return fetch_dict
for img_file in os.listdir(img_path):
res_list = []
feed, fetch = preprocess(os.path.join(img_path, img_file))
fetch_map = client.predict(feed=feed, fetch=fetch)
result = postprocess(fetch_map)
print(result)
......@@ -5,9 +5,10 @@ trans_model:-m paddle_serving_client.convert
--dirname:./inference/mobilenet_v3_small_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--serving_server:./deploy/serving_python/serving_server/
--serving_client:./deploy/serving_python/serving_client/
serving_dir:./deploy/serving_python
--model:./deploy/serving_python/serving_server/
--port:9993
--serving_server:./deploy/serving_cpp/serving_server/
--serving_client:./deploy/serving_cpp/serving_client/
serving_dir:./deploy/serving_cpp
--model:serving_server
--op:GeneralClasOp
--port:9997
cpp_client:serving_client.py
\ No newline at end of file
......@@ -104,6 +104,13 @@ elif [ ${MODE} = "cpp_infer" ];then
bash tools/build.sh
elif [ ${MODE} = "paddle2onnx_infer" ];then
# install paddle2onnx
python_name_list=$(func_parser_value "${lines[2]}")
IFS='|'
array=(${python_name_list})
python_name=${array[0]}
${python_name} -m pip install paddle2onnx
${python_name} -m pip install onnxruntime==1.9.0
# get data
tar -xf ./test_images/lite_data.tar
# get model
......
......@@ -58,14 +58,15 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker}"
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
# python inference
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu_value}")
set_model_dir=$(func_set_params "${model_key}" "${save_file_value}")
set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
infer_model_cmd="${python} ${inference_py} ${set_img_dir} ${set_model_dir} > ${_save_log_path} 2>&1 "
eval $infer_model_cmd
status_check $last_status "${infer_model_cmd}" "${status_log}"
last_status=${PIPESTATUS[0]}
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
}
......
......@@ -24,16 +24,20 @@ serving_server_value=$(func_parser_value "${lines[7]}")
serving_client_key=$(func_parser_key "${lines[8]}")
serving_client_value=$(func_parser_value "${lines[8]}")
serving_dir_value=$(func_parser_value "${lines[9]}")
run_model_path_key=$(func_parser_value "${lines[10]}")
run_model_path_value=$(func_parser_value "${lines[11]}")
port_key=$(func_parser_value "${lines[12]}")
port_value=$(func_parser_key "${lines[13]}")
cpp_client_value=$(func_parser_value "${lines[14]}")
run_model_path_key=$(func_parser_key "${lines[10]}")
run_model_path_value=$(func_parser_value "${lines[10]}")
op_key=$(func_parser_key "${lines[11]}")
op_value=$(func_parser_value "${lines[11]}")
port_key=$(func_parser_key "${lines[12]}")
port_value=$(func_parser_value "${lines[12]}")
cpp_client_value=$(func_parser_value "${lines[13]}")
LOG_PATH="./log/${model_name}/${MODE}"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results_serving.log"
status_log="../../log/${model_name}/${MODE}/results_serving_infer_cpp.log"
function func_serving(){
IFS='|'
......@@ -50,25 +54,29 @@ function func_serving(){
python_list=(${python_list})
python=${python_list[0]}
trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval $trans_model_cmd}
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
cp "deploy/serving_cpp/preprocess/serving_client_conf.prototxt" ${serving_client_value}
cd ${serving_dir_value}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
echo $PWD
unset https_proxy
unset http_proxy
web_service_cmd="${python} ${web_service_py} &"
eval $web_service_cmd
sleep 2s
_save_log_path="../../log/${model_name}/${MODE}/server_infer_gpu_batchsize_1.log"
_save_log_path="../../log/${model_name}/${MODE}/servering_infer_cpp_gpu_batchsize_1.log"
# phrase 2: run server
cpp_server_cmd="${python} -m paddle_serving_server.serve ${run_model_path_key} ${run_model_path_value} ${port_key} ${port_value} > ${_save_log_path} 2>&1 "
cpp_server_cmd="${python} -m paddle_serving_server.serve ${run_model_path_key} ${run_model_path_value} ${op_key} ${op_value} ${port_key} ${port_value} > serving_log.log & "
eval $cpp_server_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${cpp_server_cmd}" "${status_log}" "${model_name}"
sleep 5s
client_cmd="${python} ${cpp_client_value} > ${_save_log_path} 2>&1 "
eval $client_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${client_cmd}" "${status_log}" "${model_name}"
# eval "cat ${_save_log_path}"
cd ../../
status_check $last_status "${cpp_server_cmd}" "${status_log}"
ps ux | grep -i 'paddle_serving_server' | awk '{print $2}' | xargs kill -s 9
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
}
......
......@@ -32,7 +32,7 @@ image_dir_value=$(func_parser_value "${lines[12]}")
LOG_PATH="./log/${model_name}/${MODE}"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results_serving.log"
status_log="../../log/${model_name}/${MODE}/serving_infer_python_gpu_batchsize_1.log"
function func_serving(){
IFS='|'
......@@ -48,23 +48,28 @@ function func_serving(){
python_list=(${python_list})
python=${python_list[0]}
trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval $trans_model_cmd}
eval ${trans_model_cmd}
last_status=${PIPESTATUS[0]}
cd ${serving_dir_value}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
echo $PWD
unset https_proxy
unset http_proxy
_save_log_path="../../log/${model_name}/${MODE}/serving_infer_python_gpu_batchsize_1.log"
web_service_cmd="${python} ${web_service_py} &"
eval $web_service_cmd
sleep 2s
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
sleep 5s
_save_log_path="../../log/${model_name}/${MODE}/server_infer_gpu_batchsize_1.log"
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
pipeline_cmd="${python} ${pipeline_py} ${set_image_dir} > ${_save_log_path} 2>&1 "
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
eval "cat ${_save_log_path}"
cd ../../
status_check $last_status "${pipeline_cmd}" "${status_log}"
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
}
......
......@@ -46,13 +46,13 @@ Paddle2ONNX supports converting the PaddlePaddle model format into the ONNX model format
- Install Paddle2ONNX
```
python3 -m pip install paddle2onnx
python -m pip install paddle2onnx
```
- Install ONNXRuntime
```
# Version 1.9.0 is recommended; adjust the version number to suit your environment
python3 -m pip install onnxruntime==1.9.0
python -m pip install onnxruntime==1.9.0
```
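A minimal sketch for sanity-checking the exported ONNX model with ONNXRuntime. The model path assumes the conversion output used later in this document (`./inference/mobilenet_v3_small_onnx/model.onnx`), and the dummy input skips real preprocessing; the repository's `deploy/onnx_python/infer.py` remains the reference script:

```python
import numpy as np
import onnxruntime as ort

# Load the exported ONNX model (path assumed from the paddle2onnx step).
sess = ort.InferenceSession("./inference/mobilenet_v3_small_onnx/model.onnx")
input_name = sess.get_inputs()[0].name

# Dummy 1x3x224x224 input just to confirm the graph runs end to end;
# real inference should apply the resize/center-crop/normalize preprocessing.
dummy = np.random.rand(1, 3, 224, 224).astype("float32")
logits = sess.run(None, {input_name: dummy})[0]
print(logits.shape, logits.argmax(axis=1))
```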
<a name="2.2"></a>
......
......@@ -7,12 +7,12 @@
- [2.1 Command breakdown](#2.1)
- [2.2 Mapping between the config file and run commands](#2.2)
- [3. Basic training and inference test development](#3)
- [2.1 Prepare the command to be tested](#3.1)
- [2.2 Prepare data and environment](#3.2)
- [2.3 Prepare the required scripts](#3.3)
- [2.4 Fill in the config file](#3.4)
- [2.5 Verify the config is correct](#3.5)
- [2.6 Write the documentation](#3.6)
- [3.1 Prepare the command to be tested](#3.1)
- [3.2 Prepare data and environment](#3.2)
- [3.3 Prepare the required scripts](#3.3)
- [3.4 Fill in the config file](#3.4)
- [3.5 Verify the config is correct](#3.5)
- [3.6 Write the documentation](#3.6)
- [4. FAQ](#4)
<a name="1"></a>
......@@ -229,7 +229,8 @@ Run failed with command - paddle2onnx --model_dir=./inference/mobilenet_v3_small
Taking the `Linux GPU/CPU offline quantization training and inference functional test` of mobilenet_v3_small as an example, the commands are as follows.
```bash
bash test_tipc/test_paddle2onnx.sh ./test_tipc/configs/mobilenet_v3_small/paddle2onnx_infer_python.txt paddle2onnx
bash test_tipc/prepare.sh ./test_tipc/configs/mobilenet_v3_small/paddle2onnx_infer_python.txt paddle2onnx_infer
bash test_tipc/test_paddle2onnx.sh ./test_tipc/configs/mobilenet_v3_small/paddle2onnx_infer_python.txt paddle2onnx_infer
```
The following output indicates that the commands ran successfully.
......@@ -237,7 +238,7 @@ bash test_tipc/test_paddle2onnx.sh ./test_tipc/configs/mobilenet_v3_small/paddle
```bash
Run successfully with command - paddle2onnx --model_dir=./inference/mobilenet_v3_small_infer/ --model_filename=inference.pdmodel --params_filename=inference.pdiparams --save_file=./inference/mobilenet_v3_small_onnx/model.onnx --opset_version=10 --enable_onnx_checker=True!
Run successfully with command - python3.7 deploy/onnx_python/infer.py --img_path=./lite_data/test/demo.jpg --onnx_file=./inference/mobilenet_v3_small_onnx/model.onnx > ./log/mobilenet_v3_small//paddle2onnx_infer_cpu.log 2>&1 !
Run successfully with command - python3.7 deploy/onnx_python/infer.py --img_path=./lite_data/test/demo.jpg --onnx_file=./inference/mobilenet_v3_small_onnx/model.onnx > ./log/mobilenet_v3_small/paddle2onnx_infer_cpu.log 2>&1 !
```
**[Verification]**
......
......@@ -45,7 +45,7 @@ The client launch command for a Paddle Serving C++ service is generally written as a Python program
python run_script
```
For example:
- For a script that takes its arguments via argparse: `python3.7 resnet50_client.py`
- `python`: replace with `python3.7`
- `run_script`: replace with `resnet50_client.py`
- `run_script`:替换为 `resnet50_client.py`
......@@ -147,7 +147,16 @@ python3.7 serving_client.py
For details see [Section 3.2 of the paper reproduction guide](../../../docs/lwfx/ArticleReproduction_CV.md); for code, see the `script for preparing a small ImageNet-based dataset`: [prepare.py](https://github.com/littletomatodonkey/AlexNet-Prod/blob/tipc/pipeline/Step2/prepare.py).
2. Environment: once PaddlePaddle is installed you can develop the offline quantization training and inference test
2. Environment: once PaddlePaddle and PaddleServing are installed you can develop the serving deployment test
To run model preprocessing on the C++ side, you need to implement a custom op and recompile PaddleServing. Follow the steps below to place the custom op inside the Serving repo:
```
cp deploy/serving_cpp/preprocess/general_clas_op.* {Serving_repo_path}/core/general-server/op
cp deploy/serving_cpp/preprocess/preprocess_op.* {Serving_repo_path}/core/predictor/tools/pp_shitu_tools
```
Recompile Serving by following the [compilation guide](https://github.com/PaddlePaddle/Serving/blob/v0.8.3/doc/Compile_CN.md), and set the SERVING_BIN environment variable.
**[Notes]**
......@@ -205,7 +214,11 @@ Run failed with command - python3.7 serving_client.py > ../../log/mobilenet_v3_s
**[Hands-on]**
Taking the `Linux GPU/CPU offline quantization training and inference functional test` of mobilenet_v3_small as an example, the commands are as follows.
Taking the `Linux GPU/CPU serving deployment test` of mobilenet_v3_small as an example, the commands are as follows.
```bash
bash test_tipc/prepare.sh test_tipc/configs/mobilenet_v3_small/serving_infer_cpp.txt serving_infer
```
```bash
bash test_tipc/test_serving_infer_cpp.sh test_tipc/configs/mobilenet_v3_small/serving_infer_cpp.txt serving_infer
......@@ -265,8 +278,9 @@ test_tipc
<a name="4"></a>
## 4. FAQ
If the request fails, a proxy setting may be interfering; the proxy can be unset with the commands below.
```
```bash
unset http_proxy
unset https_proxy
```
......@@ -41,7 +41,11 @@ Paddle Serving deployment mainly consists of the following steps:
<a name="2.1"></a>
### 2.1 Prepare test data
Prepare the test data and the corresponding labels for the later [inference stage](#2.7)
To quickly verify the inference process, prepare a small dataset (8-16 images each for the training and validation sets; the compressed size should stay within `20M` so the total basic train/infer time stays under ten minutes) and place it in the `lite_data` folder.
For details see [Section 3.2 of the paper reproduction guide](../../../docs/lwfx/ArticleReproduction_CV.md); for code, see the `script for preparing a small ImageNet-based dataset`: [prepare.py](https://github.com/littletomatodonkey/AlexNet-Prod/blob/tipc/pipeline/Step2/prepare.py).
This tutorial uses `./images/demo.jpg` as the test case.
......@@ -95,9 +99,9 @@ cd models/tutorials/tipc/serving_python
<a name="2.3"></a>
### 2.3 Prepare the serving model
#### 2.3.1 Download the MobilenetV3 inference model
#### 2.3.1 Prepare the MobilenetV3 inference model
Refer to [MobilenetV3](../../mobilenetv3_prod/Step6/README.md#2) and download the inference model
Refer to [MobilenetV3](../../mobilenetv3_prod/Step6/README.md#2) and make sure the inference model is in the current directory.
#### 2.3.2 Prepare the serving model
......@@ -106,7 +110,7 @@ cd models/tutorials/tipc/serving_python
To make the model easy to deploy as a service, the static-graph model (model structure file: \*.pdmodel and model parameter file: \*.pdiparams) needs to be converted into a serving model with paddle_serving_client.convert as follows:
```bash
python3 -m paddle_serving_client.convert --dirname {static-graph model path} --model_filename {model structure file} --params_filename {model parameter file} --serving_server {output path for the converted server-side model and config} --serving_client {output path for the converted client-side model and config}
python -m paddle_serving_client.convert --dirname {static-graph model path} --model_filename {model structure file} --params_filename {model parameter file} --serving_server {output path for the converted server-side model and config} --serving_client {output path for the converted client-side model and config}
```
In the command above, the "converted server-side model and config files" will be used in the subsequent serving deployment. The `paddle_serving_client.convert` command is a conversion utility built into the `paddle_serving_client` wheel and needs no modification.
......@@ -115,7 +119,7 @@ python3 -m paddle_serving_client.convert --dirname {static-graph model path} --mod
For the MobileNetV3 network, an example command for converting the inference model into a serving model is shown below; after conversion, the **serving_server** and **serving_client** folders are generated locally. The rest of this tutorial mainly uses the model in the serving_server folder.
```bash
python3 -m paddle_serving_client.convert \
python -m paddle_serving_client.convert \
--dirname ./mobilenet_v3_small_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
......@@ -312,7 +316,7 @@ img_path = "./images/demo.jpg"
Once the serving engine initialization, data preprocessing, and prediction post-processing have been developed, the model prediction service can be started with the following command:
```bash
python3 web_service.py &
python web_service.py &
```
**[Hands-on]**
......@@ -332,7 +336,7 @@ python3 web_service.py &
The client accesses the service with the following command:
```bash
python3 pipeline_http_client.py
python pipeline_http_client.py
```
A successful request produces output like the figure below:
......
......@@ -7,12 +7,12 @@
- [2.1 Command breakdown](#2.1)
- [2.2 Mapping between the config file and run commands](#2.2)
- [3. Basic training and inference test development](#3)
- [2.1 Prepare the command to be tested](#3.1)
- [2.2 Prepare data and environment](#3.2)
- [2.3 Prepare the required scripts](#3.3)
- [2.4 Fill in the config file](#3.4)
- [2.5 Verify the config is correct](#3.5)
- [2.6 Write the documentation](#3.6)
- [3.1 Prepare the command to be tested](#3.1)
- [3.2 Prepare data and environment](#3.2)
- [3.3 Prepare the required scripts](#3.3)
- [3.4 Fill in the config file](#3.4)
- [3.5 Verify the config is correct](#3.5)
- [3.6 Write the documentation](#3.6)
- [4. FAQ](#4)
<a name="1"></a>
......@@ -167,7 +167,7 @@ python3.7 pipeline_http_client.py --image_dir=../../lite_data/test/
For details see [Section 3.2 of the paper reproduction guide](../../../docs/lwfx/ArticleReproduction_CV.md); for code, see the `script for preparing a small ImageNet-based dataset`: [prepare.py](https://github.com/littletomatodonkey/AlexNet-Prod/blob/tipc/pipeline/Step2/prepare.py).
2. Environment: once PaddlePaddle is installed you can develop the offline quantization training and inference test
2. Environment: once PaddlePaddle is installed you can develop the Python serving deployment test
**[Notes]**
......@@ -226,7 +226,7 @@ Run failed with command - python3.7 pipeline_http_client.py > ../../log/mobilene
**[Hands-on]**
Taking the `Linux GPU/CPU offline quantization training and inference functional test` of mobilenet_v3_small as an example, the command is as follows.
Taking the `Linux GPU/CPU Python serving deployment functional test` of mobilenet_v3_small as an example, the command is as follows.
```bash
bash test_tipc/test_serving_infer_python.sh test_tipc/configs/mobilenet_v3_small/serving_infer_python.txt serving_infer
......@@ -251,7 +251,7 @@ Run successfully with command - python3.7 pipeline_http_client.py > ../../log/mo
Write the TIPC feature overview document and the test workflow documentation, namely:
1. TIPC feature overview document: test_tipc/README.md
2. Linux GPU/CPU offline quantization training and inference functional test documentation: test_tipc/docs/test_serving_infer_python.md
2. Linux GPU/CPU Python serving deployment functional test documentation: test_tipc/docs/test_serving_infer_python.md
The two document templates are located as described below; copy them directly into your own repo and adapt them to your model.
......