未验证 提交 17963299 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #37 from PaddlePaddle/develop

Sync develop
......@@ -42,7 +42,7 @@ We consider deploying deep learning inference service online to be a user-facing
- Any model trained by [PaddlePaddle](https://github.com/paddlepaddle/paddle) can be directly used or [Model Conversion Interface](./doc/SAVE.md) for online deployment of Paddle Serving.
- Support [Multi-model Pipeline Deployment](./doc/PIPELINE_SERVING.md), and provide the requirements of the REST interface and RPC interface itself, [Pipeline example](./python/examples/pipeline).
- Support the model zoos from the Paddle ecosystem, such as [PaddleDetection](./python/examples/detection), [PaddleOCR](./python/examples/ocr), [PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/tools/recserving/movie_recommender).
- Support the model zoos from the Paddle ecosystem, such as [PaddleDetection](./python/examples/detection), [PaddleOCR](./python/examples/ocr), [PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/recserving/movie_recommender).
- Provide a variety of pre-processing and post-processing to facilitate users in training, deployment and other stages of related code, bridging the gap between AI developers and application developers, please refer to
[Serving Examples](./python/examples/).
......
......@@ -44,7 +44,7 @@ Paddle Serving 旨在帮助深度学习开发者轻易部署在线预测服务
- 任何经过[PaddlePaddle](https://github.com/paddlepaddle/paddle)训练的模型,都可以经过直接保存或是[模型转换接口](./doc/SAVE_CN.md),用于Paddle Serving在线部署。
- 支持[多模型串联服务部署](./doc/PIPELINE_SERVING_CN.md), 同时提供Rest接口和RPC接口以满足您的需求,[Pipeline示例](./python/examples/pipeline)
- 支持Paddle生态的各大模型库, 例如[PaddleDetection](./python/examples/detection)[PaddleOCR](./python/examples/ocr)[PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/tools/recserving/movie_recommender)
- 支持Paddle生态的各大模型库, 例如[PaddleDetection](./python/examples/detection)[PaddleOCR](./python/examples/ocr)[PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/recserving/movie_recommender)
- 提供丰富多彩的前后处理,方便用户在训练、部署等各阶段复用相关代码,弥合AI开发者和应用开发者之间的鸿沟,详情参考[模型示例](./python/examples/)
<p align="center">
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_detection_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
/*
#include "opencv2/imgcodecs/legacy/constants_c.h"
#include "opencv2/imgproc/types_c.h"
*/
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralDetectionOp::inference() {
VLOG(2) << "Going to run inference";
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
<< ") can only have one predecessor op, but received "
<< pre_node_names.size();
return -1;
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
}
output_blob->SetLogId(log_id);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op:" << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector* out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
output_blob->_batch_size = batch_size;
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
std::vector<int> input_shape;
int in_num =0;
void* databuf_data = NULL;
char* databuf_char = NULL;
size_t databuf_size = 0;
std::string* input_ptr = static_cast<std::string*>(in->at(0).data.data());
std::string base64str = input_ptr[0];
float ratio_h{};
float ratio_w{};
cv::Mat img = Base2Mat(base64str);
cv::Mat srcimg;
cv::Mat resize_img;
cv::Mat resize_img_rec;
cv::Mat crop_img;
img.copyTo(srcimg);
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_det, this->scale_det,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
TensorVector* real_in = new TensorVector();
if (!real_in) {
LOG(ERROR) << "real_in is nullptr,error";
return -1;
}
for (int i = 0; i < in->size(); ++i) {
input_shape = {1, 3, resize_img.rows, resize_img.cols};
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
databuf_size = in_num*sizeof(float);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
memcpy(databuf_data,input.data(),databuf_size);
databuf_char = reinterpret_cast<char*>(databuf_data);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
paddle::PaddleTensor tensor_in;
tensor_in.name = in->at(i).name;
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
tensor_in.lod = in->at(i).lod;
tensor_in.data = paddleBuf;
real_in->push_back(tensor_in);
}
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
if (InferManager::instance().infer(
engine_name().c_str(), real_in, out, batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
std::vector<int> output_shape;
int out_num =0;
void* databuf_data_out = NULL;
char* databuf_char_out = NULL;
size_t databuf_size_out = 0;
//this is special add for PaddleOCR postprecess
int infer_outnum = out->size();
for (int k = 0;k <infer_outnum; ++k) {
int n2 = out->at(k).shape[2];
int n3 = out->at(k).shape[3];
int n = n2 * n3;
float* out_data = static_cast<float*>(out->at(k).data.data());
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
const double threshold = this->det_db_thresh_ * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, dilation_map, dila_ele);
boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
this->det_db_box_thresh_,
this->det_db_unclip_ratio_);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(img, boxes[i]);
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
this->resize_op_rec.Run(crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_);
this->normalize_op_.Run(&resize_img_rec, this->mean_rec, this->scale_rec,
this->is_scale_);
std::vector<float> output_rec(1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f);
this->permute_op_.Run(&resize_img_rec, output_rec.data());
// Inference.
output_shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
databuf_size_out = out_num*sizeof(float);
databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out);
if (!databuf_data_out) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
return -1;
}
memcpy(databuf_data_out,output_rec.data(),databuf_size_out);
databuf_char_out = reinterpret_cast<char*>(databuf_data_out);
paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out);
paddle::PaddleTensor tensor_out;
tensor_out.name = "image";
tensor_out.dtype = paddle::PaddleDType::FLOAT32;
tensor_out.shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
tensor_out.data = paddleBuf;
out->push_back(tensor_out);
}
}
out->erase(out->begin(),out->begin()+infer_outnum);
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
}
cv::Mat GeneralDetectionOp::Base2Mat(std::string &base64_data)
{
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR);//CV_LOAD_IMAGE_COLOR
return img;
}
std::string GeneralDetectionOp::base64Decode(const char* Data, int DataByte)
{
const char DecodeTable[] =
{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
std::string strDecode;
int nValue;
int i = 0;
while (i < DataByte)
{
if (*Data != '\r' && *Data != '\n')
{
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=')
{
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=')
{
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
}
else// 回车换行,跳过
{
Data++;
i++;
}
}
return strDecode;
}
cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box) {
cv::Mat image;
srcimage.copyTo(image);
std::vector<std::vector<int>> points = box;
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
cv::Mat img_crop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
for (int i = 0; i < points.size(); i++) {
points[i][0] -= left;
points[i][1] -= top;
}
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
pow(points[0][1] - points[1][1], 2)));
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
pow(points[0][1] - points[3][1], 2)));
cv::Point2f pts_std[4];
pts_std[0] = cv::Point2f(0., 0.);
pts_std[1] = cv::Point2f(img_crop_width, 0.);
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
pts_std[3] = cv::Point2f(0.f, img_crop_height);
cv::Point2f pointsf[4];
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M,
cv::Size(img_crop_width, img_crop_height),
cv::BORDER_REPLICATE);
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
cv::transpose(dst_img, srcCopy);
cv::flip(srcCopy, srcCopy, 0);
return srcCopy;
} else {
return dst_img;
}
}
DEFINE_OP(GeneralDetectionOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
\ No newline at end of file
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include <numeric>
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/tools/ocrtools/postprocess_op.h"
#include "core/predictor/tools/ocrtools/preprocess_op.h"
#include "paddle_inference_api.h" // NOLINT
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralDetectionOp
: public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralDetectionOp);
int inference();
private:
//config info
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
// pre-process
PaddleOCR::ResizeImgType0 resize_op_;
PaddleOCR::Normalize normalize_op_;
PaddleOCR::Permute permute_op_;
PaddleOCR::CrnnResizeImg resize_op_rec;
bool use_tensorrt_ = false;
bool use_fp16_ = false;
// post-process
PaddleOCR::PostProcessor post_processor_;
//det config info
int max_side_len_ = 960;
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.5;
double det_db_unclip_ratio_ = 2.0;
std::vector<float> mean_det = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_det = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;
//rec config info
std::vector<std::string> label_list_;
std::vector<float> mean_rec = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_rec = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box);
cv::Mat Base2Mat(std::string &base64_data);
std::string base64Decode(const char* Data, int DataByte);
std::vector<std::vector<std::vector<int>>> boxes;
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -13,8 +13,10 @@
// limitations under the License.
#pragma once
#include <string>
#include <algorithm>
#include <cctype>
#include <fstream>
#include <string>
#include "core/predictor/common/inner_common.h"
#include "core/predictor/common/macros.h"
......@@ -26,6 +28,38 @@ namespace predictor {
namespace butil = base;
#endif
enum class Precision {
kUnk = -1, // unknown type
kFloat32 = 0, // fp32
kInt8, // int8
kHalf, // fp16
kBfloat16, // bf16
};
static std::string PrecisionTypeString(const Precision data_type) {
switch (data_type) {
case Precision::kFloat32:
return "kFloat32";
case Precision::kInt8:
return "kInt8";
case Precision::kHalf:
return "kHalf";
case Precision::kBfloat16:
return "kBloat16";
default:
return "unUnk";
}
}
static std::string ToLower(const std::string& data) {
std::string result = data;
std::transform(
result.begin(), result.end(), result.begin(), [](unsigned char c) {
return tolower(c);
});
return result;
}
class TimerFlow {
public:
static const int MAX_SIZE = 1024;
......
......@@ -4,7 +4,7 @@
This document will use an example of text classification task based on IMDB dataset to show how to build a A/B Test framework using Paddle Serving. The structure relationship between the client and servers in the example is shown in the figure below.
<img src="abtest.png" style="zoom:33%;" />
<img src="abtest.png" style="zoom:25%;" />
Note that: A/B Test is only applicable to RPC mode, not web mode.
......@@ -88,7 +88,7 @@ with open('processed.data') as f:
cnt[tag]['total'] += 1
for tag, data in cnt.items():
print('[{}](total: {}) acc: {}'.format(tag, data['total'], float(data['acc']) / float(data['total'])))
print('[{}]<total: {}> acc: {}'.format(tag, data['total'], float(data['acc']) / float(data['total'])))
```
In the code, the function `client.add_variant(tag, clusters, variant_weight)` is to add a variant with label `tag` and flow weight `variant_weight`. In this example, a BOW variant with label of `bow` and flow weight of `10`, and an LSTM variant with label of `lstm` and a flow weight of `90` are added. The flow on the client side will be distributed to two variants according to the ratio of `10:90`.
......@@ -98,8 +98,8 @@ When making prediction on the client side, if the parameter `need_variant_tag=Tr
### Expected Results
Due to different network conditions, the results of each prediction may be slightly different.
``` python
[lstm](total: 1867) acc: 0.490091055169
[bow](total: 217) acc: 0.73732718894
[lstm]<total: 1867> acc: 0.490091055169
[bow]<total: 217> acc: 0.73732718894
```
<!--
......
......@@ -92,7 +92,7 @@ with open('processed.data') as f:
cnt[tag]['total'] += 1
for tag, data in cnt.items():
print('[{}](total: {}) acc: {}'.format(tag, data['total'], float(data['acc'])/float(data['total']) ))
print('[{}]<total: {}> acc: {}'.format(tag, data['total'], float(data['acc'])/float(data['total']) ))
```
代码中,`client.add_variant(tag, clusters, variant_weight)`是为了添加一个标签为`tag`、流量权重为`variant_weight`的variant。在这个样例中,添加了一个标签为`bow`、流量权重为`10`的BOW variant,以及一个标签为`lstm`、流量权重为`90`的LSTM variant。Client端的流量会根据`10:90`的比例分发到两个variant。
......@@ -101,6 +101,6 @@ Client端做预测时,若指定参数`need_variant_tag=True`,返回值则包
### 预期结果
由于网络情况的不同,可能每次预测的结果略有差异。
``` bash
[lstm](total: 1867) acc: 0.490091055169
[bow](total: 217) acc: 0.73732718894
[lstm]<total: 1867> acc: 0.490091055169
[bow]<total: 217> acc: 0.73732718894
```
......@@ -132,7 +132,7 @@ Please install pre-commit, which automatically reformat the changes to C/C++ and
Please remember to add related unit tests.
- For C/C++ code, please follow [`google-test` Primer](https://github.com/google/googletest/blob/master/googletest/docs/primer.md) .
- For C/C++ code, please follow [`google-test` Primer](https://github.com/google/googletest/blob/master/docs/primer.md) .
- For Python code, please use [Python's standard `unittest` package](http://pythontesting.net/framework/unittest/unittest-introduction/).
......
......@@ -7,11 +7,10 @@
There are two examples on CTR under python / examples, they are criteo_ctr, criteo_ctr_with_cube. The former is to save the entire model during training, including sparse parameters. The latter is to cut out the sparse parameters and save them into two parts, one is the sparse parameter and the other is the dense parameter. Because the scale of sparse parameters is very large in industrial cases, reaching the order of 10 ^ 9. Therefore, it is not practical to start large-scale sparse parameter prediction on one machine. Therefore, we introduced Baidu's industrial-grade product Cube to provide the sparse parameter service for many years to provide distributed sparse parameter services.
The local mode of Cube is different from distributed Cube, which is designed to be convenient for developers to use in experiments and demos.
<!--If there is a demand for distributed sparse parameter service, please continue reading [Distributed Cube User Guide](./Distributed_Cube) after reading this document (still developing).-->
<!--If there is a demand for distributed sparse parameter service, please continue reading [Quantization Storage on Cube Sparse Parameter Indexing](./CUBE_QUANT.md) after reading this document (still developing).-->
This document uses the original model without any compression algorithm. If there is a need for a quantitative model to go online, please read the [Quantization Storage on Cube Sparse Parameter Indexing](./CUBE_QUANT.md)
## Example
in directory python/example/criteo_ctr_with_cube, run
......
......@@ -6,7 +6,7 @@
在python/examples下有两个关于CTR的示例,他们分别是criteo_ctr, criteo_ctr_with_cube。前者是在训练时保存整个模型,包括稀疏参数。后者是将稀疏参数裁剪出来,保存成两个部分,一个是稀疏参数,另一个是稠密参数。由于在工业级的场景中,稀疏参数的规模非常大,达到10^9数量级。因此在一台机器上启动大规模稀疏参数预测是不实际的,因此我们引入百度多年来在稀疏参数索引领域的工业级产品Cube,提供分布式的稀疏参数服务。
<!--单机版Cube是分布式Cube的弱化版本,旨在方便开发者做实验和Demo时使用。如果有分布式稀疏参数服务的需求,请在读完此文档之后,继续阅读 [稀疏参数索引服务Cube使用指南](分布式Cube)(正在建设中)。-->
<!--单机版Cube是分布式Cube的弱化版本,旨在方便开发者做实验和Demo时使用。如果有分布式稀疏参数服务的需求,请在读完此文档之后,继续阅读 [稀疏参数索引服务Cube使用指南](CUBE_LOCAL_CN.md)(正在建设中)。-->
本文档使用的都是未经过任何压缩算法处理的原始模型,如果有量化模型上线需求,请阅读[Cube稀疏参数索引量化存储使用指南](./CUBE_QUANT_CN.md)
......
......@@ -70,7 +70,7 @@ The inference framework of the well-known deep learning platform only supports C
> Model conversion across deep learning platforms
Models trained on other deep learning platforms can be passed《[PaddlePaddle/X2Paddle工具](https://github.com/PaddlePaddle/X2Paddle)》.We convert multiple mainstream CV models to Paddle models. TensorFlow, Caffe, ONNX, PyTorch model conversion is tested.《[An End-to-end Tutorial from Training to Inference Service Deployment](TRAIN_TO_SERVICE.md)
Models trained on other deep learning platforms can be passed《[PaddlePaddle/X2Paddle工具](https://github.com/PaddlePaddle/X2Paddle)》.We convert multiple mainstream CV models to Paddle models. TensorFlow, Caffe, ONNX, PyTorch model conversion is tested.《[AIStudio教程-Paddle Serving服务化部署框架](https://www.paddlepaddle.org.cn/tutorials/projectdetail/1555945)
Because it is impossible to directly view the feed and fetch parameter information in the model file, it is not convenient for users to assemble the parameters. Therefore, Paddle Serving developed a tool to convert the Paddle model into Serving format and generate a prototxt file containing feed and fetch parameter information. The following figure is the generated prototxt file of the uci_housing example. For more conversion methods, refer to the document《[How to save a servable model of Paddle Serving?](SAVE.md)》.
```
......
......@@ -74,7 +74,7 @@ Paddle Serving提供了4种开发语言SDK,包括Python、C++、Java、Golang
其他深度学习平台训练的模型,可以通过《[PaddlePaddle/X2Paddle工具](https://github.com/PaddlePaddle/X2Paddle)》将多个主流的CV模型转为Paddle模型,测试过TensorFlow、Caffe、ONNX、PyTorch模型转换。
以IMDB评论情感分析任务为例通过9步展示,Paddle Serving从模型的训练到部署预测服务的全流程《[端到端完成从训练到部署全流程](TRAIN_TO_SERVICE_CN.md)
以IMDB评论情感分析任务为例通过9步展示,Paddle Serving从模型的训练到部署预测服务的全流程《[AIStudio教程-Paddle Serving服务化部署框架](https://www.paddlepaddle.org.cn/tutorials/projectdetail/1555945)
由于无法直接查看模型文件中feed和fetch参数信息,不方便用户拼装参数。因此,Paddle Serving开发一个工具将Paddle模型转成Serving的格式,生成包含feed和fetch参数信息的prototxt文件。下图是uci_housing示例的生成的prototxt文件,更多转换方法参考文档《[怎样保存用于Paddle Serving的模型](SAVE_CN.md)》。
```
......
......@@ -222,9 +222,7 @@ InvalidArgumentError: Device id must be less than GPU count, but received id is:
#### Q: python编译的GCC版本与serving的版本不匹配
**A:**:1)使用[GPU docker](https://github.com/PaddlePaddle/Serving/blob/develop/doc/RUN_IN_DOCKER.md#gpunvidia-docker)解决环境问题
2)修改anaconda的虚拟环境下安装的python的gcc版本[参考](https://www.jianshu.com/p/c498b3d86f77)
**A:**:1)使用[GPU docker](https://github.com/PaddlePaddle/Serving/blob/develop/doc/RUN_IN_DOCKER.md#gpunvidia-docker)解决环境问题;2)修改anaconda的虚拟环境下安装的python的gcc版本[改变python的GCC编译环境](https://www.jianshu.com/p/c498b3d86f77)
#### Q: paddle-serving是否支持本地离线安装
......
......@@ -78,7 +78,7 @@ https://paddle-serving.bj.bcebos.com/whl/paddle_serving_app-0.0.0-py2-none-any.w
```
## ARM user
for ARM user who uses [PaddleLite](https://github.com/PaddlePaddle/PaddleLite) can download the wheel packages as follows. And ARM user should use the xpu-beta docker [DOCKER IMAGES](./DOCKER_IMAGES.md)
for ARM user who uses [Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) can download the wheel packages as follows. And ARM user should use the xpu-beta docker [DOCKER IMAGES](./DOCKER_IMAGES.md)
**We only support Python 3.6 for Arm Users.**
### Wheel Package Links
......
......@@ -48,7 +48,7 @@ python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --po
### Nodes with multiple inputs
An example containing multiple input nodes is given in the [MODEL_ENSEMBLE_IN_PADDLE_SERVING](MODEL_ENSEMBLE_IN_PADDLE_SERVING.md). A example graph and the corresponding DAG definition code is as follows.
An example containing multiple input nodes is given in the [MODEL_ENSEMBLE_IN_PADDLE_SERVING](./deprecated/MODEL_ENSEMBLE_IN_PADDLE_SERVING.md). A example graph and the corresponding DAG definition code is as follows.
<center>
<img src='complex_dag.png' width = "480" height = "400" align="middle"/>
......
......@@ -47,7 +47,7 @@ python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --po
### 包含多个输入的节点
[Paddle Serving中的集成预测](MODEL_ENSEMBLE_IN_PADDLE_SERVING_CN.md)文档中给出了一个包含多个输入节点的样例,示意图和代码如下。
[Paddle Serving中的集成预测](./deprecated/MODEL_ENSEMBLE_IN_PADDLE_SERVING_CN.md)文档中给出了一个包含多个输入节点的样例,示意图和代码如下。
<center>
<img src='complex_dag.png' width = "480" height = "400" align="middle"/>
......
## Paddle Serving uses TensorRT
(English|[简体中文]((./TENSOR_RT_CN.md)))
(English|[简体中文](./TENSOR_RT_CN.md))
### Background
......
......@@ -14,7 +14,7 @@
**安装Git工具**: 详情参见[Git官网](https://git-scm.com/downloads)
**安装必要的C++库(可选)**:部分用户可能会在`import paddle`阶段遇见dll无法链接的问题,建议可以[安装Visual Studio社区版本](`https://visualstudio.microsoft.com/`) ,并且安装C++的相关组件。
**安装必要的C++库(可选)**:部分用户可能会在`import paddle`阶段遇见dll无法链接的问题,建议[安装Visual Studio社区版本](https://visualstudio.microsoft.com/) ,并且安装C++的相关组件。
**安装Paddle和Serving**:在Powershell,执行
......
......@@ -115,7 +115,7 @@ Server instance perspective
![Paddle Serving multi-service](../multi-service.png)
Paddle Serving instances can load multiple models at the same time, and each model uses a Service (and its configured workflow) to undertake services. You can refer to [service configuration file in Demo example](../tools/cpp_examples/demo-serving/conf/service.prototxt) to learn how to configure multiple services for the serving instance
Paddle Serving instances can load multiple models at the same time, and each model uses a Service (and its configured workflow) to undertake services. You can refer to [service configuration file in Demo example](../../tools/cpp_examples/demo-serving/conf/service.prototxt) to learn how to configure multiple services for the serving instance
#### 4.2.3 Hierarchical relationship of business scheduling
......@@ -124,7 +124,7 @@ From the client's perspective, a Paddle Serving service can be divided into thre
![Call hierarchy relationship](../multi-variants.png)
One Service corresponds to one inference model, and there is one endpoint under the model. Different versions of the model are implemented through multiple variant concepts under endpoint:
The same model prediction service can configure multiple variants, and each variant has its own downstream IP list. The client code can configure relative weights for each variant to achieve the relationship of adjusting the traffic ratio (refer to the description of variant_weight_list in [Client Configuration](./deprecated/CLIENT_CONFIGURE.md) section 3.2).
The same model prediction service can configure multiple variants, and each variant has its own downstream IP list. The client code can configure relative weights for each variant to achieve the relationship of adjusting the traffic ratio (refer to the description of variant_weight_list in [Client Configuration](../CLIENT_CONFIGURE.md) section 3.2).
![Client-side proxy function](../client-side-proxy.png)
......@@ -141,7 +141,7 @@ No matter how the communication protocol changes, the framework only needs to en
### 5.1 Data Compression Method
Baidu-rpc has built-in data compression methods such as snappy, gzip, zlib, which can be configured in the configuration file (refer to [Client Configuration](./deprecated/CLIENT_CONFIGURE.md) Section 3.1 for an introduction to compress_type)
Baidu-rpc has built-in data compression methods such as snappy, gzip, zlib, which can be configured in the configuration file (refer to [Client Configuration](../CLIENT_CONFIGURE.md) Section 3.1 for an introduction to compress_type)
### 5.2 C ++ SDK API Interface
......
......@@ -37,9 +37,24 @@ using paddle_infer::Tensor;
using paddle_infer::CreatePredictor;
DECLARE_int32(gpuid);
DECLARE_string(precision);
DECLARE_bool(use_calib);
static const int max_batch = 32;
static const int min_subgraph_size = 3;
static PrecisionType precision_type;
PrecisionType GetPrecision(const std::string& precision_data) {
std::string precision_type = predictor::ToLower(precision_data);
if (precision_type == "fp32") {
return PrecisionType::kFloat32;
} else if (precision_type == "int8") {
return PrecisionType::kInt8;
} else if (precision_type == "fp16") {
return PrecisionType::kHalf;
}
return PrecisionType::kFloat32;
}
// Engine Base
class PaddleEngineBase {
......@@ -107,9 +122,9 @@ class PaddleInferenceEngine : public PaddleEngineBase {
if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
// decrypt model
std::string model_buffer, params_buffer, key_buffer;
predictor::ReadBinaryFile(model_path + "encrypt_model", &model_buffer);
predictor::ReadBinaryFile(model_path + "encrypt_params", &params_buffer);
predictor::ReadBinaryFile(model_path + "key", &key_buffer);
predictor::ReadBinaryFile(model_path + "/encrypt_model", &model_buffer);
predictor::ReadBinaryFile(model_path + "/encrypt_params", &params_buffer);
predictor::ReadBinaryFile(model_path + "/key", &key_buffer);
auto cipher = paddle::MakeCipher("");
std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
......@@ -137,6 +152,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
// 2000MB GPU memory
config.EnableUseGpu(2000, FLAGS_gpuid);
}
precision_type = GetPrecision(FLAGS_precision);
if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
......@@ -145,14 +161,24 @@ class PaddleInferenceEngine : public PaddleEngineBase {
config.EnableTensorRtEngine(1 << 20,
max_batch,
min_subgraph_size,
Config::Precision::kFloat32,
precision_type,
false,
false);
FLAGS_use_calib);
LOG(INFO) << "create TensorRT predictor";
}
if (engine_conf.has_use_lite() && engine_conf.use_lite()) {
config.EnableLiteEngine(PrecisionType::kFloat32, true);
config.EnableLiteEngine(precision_type, true);
}
if ((!engine_conf.has_use_lite() && !engine_conf.has_use_gpu()) ||
(engine_conf.has_use_lite() && !engine_conf.use_lite() &&
engine_conf.has_use_gpu() && !engine_conf.use_gpu())) {
if (precision_type == PrecisionType::kInt8) {
config.EnableMkldnnQuantizer();
} else if (precision_type == PrecisionType::kHalf) {
config.EnableMkldnnBfloat16();
}
}
if (engine_conf.has_use_xpu() && engine_conf.use_xpu()) {
......@@ -171,7 +197,6 @@ class PaddleInferenceEngine : public PaddleEngineBase {
config.EnableMemoryOptim();
}
predictor::AutoLock lock(predictor::GlobalCreateMutex::instance());
_predictor = CreatePredictor(config);
if (NULL == _predictor.get()) {
......
......@@ -20,6 +20,8 @@ namespace paddle_serving {
namespace inference {
DEFINE_int32(gpuid, 0, "GPU device id to use");
DEFINE_string(precision, "fp32", "precision to deploy, default is fp32");
DEFINE_bool(use_calib, false, "calibration mode, default is false");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<PaddleInferenceEngine>,
......
......@@ -31,6 +31,7 @@ dirname is the folder path where the model is located. If the parameter is discr
The key is stored in the `key` file, and the encrypted model file and server-side configuration file are stored in the `encrypt_server` directory.
client-side configuration file are stored in the `encrypt_client` directory.
**Notice:** When encryption prediction is used, the model configuration and parameter folder loaded by server and client should be encrypt_server/ and encrypt_client/
## Start Encryption Service
CPU Service
```
......@@ -43,5 +44,5 @@ python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_
## Prediction
```
python test_client.py uci_housing_client/serving_client_conf.prototxt
python test_client.py encrypt_client/serving_client_conf.prototxt
```
......@@ -31,6 +31,8 @@ def serving_encryption():
密钥保存在`key`文件中,加密模型文件以及server端配置文件保存在`encrypt_server`目录下,client端配置文件保存在`encrypt_client`目录下。
**注意:** 当使用加密预测时,服务端和客户端启动加载的模型配置和参数文件夹是encrypt_server/和encrypt_client/
## 启动加密预测服务
CPU预测服务
```
......@@ -43,5 +45,5 @@ python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_
## 预测
```
python test_client.py uci_housing_client/serving_client_conf.prototxt
python test_client.py encrypt_client/
```
......@@ -15,31 +15,6 @@ tar -xzvf ocr_det.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.tar
tar xf test_imgs.tar
```
## C++ OCR Service
### Start Service
Select a startup mode according to CPU / GPU device
After the -- model parameter, the folder path of multiple model files is passed in to start the prediction service of multiple model concatenation.
```
#for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user
python -m paddle_serving_server_gpu.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
```
### Client Prediction
The pre-processing and post-processing is in the C + + server part, the image's Base64 encoded string is passed into the C + + server.
so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed.
for this case, `feed_type` should be 3(which means the data type is string),`shape` should be 1.
By passing in multiple client folder paths, the client can be started for multi model prediction.
```
python ocr_c_client_bytes.py ocr_det_client ocr_rec_client
```
## Web Service
......@@ -123,3 +98,30 @@ python rec_debugger_server.py gpu #for gpu user
```
python rec_web_client.py
```
## C++ OCR Service
**Notice:** If you need to concatenate det model and rec model, and do pre-processing and post-processing in Paddle Serving C++ framework, you need to use the C++ server compiled with WITH_OPENCV option,see the [COMPILE.md](../../../doc/COMPILE.md)
### Start Service
Select a startup mode according to CPU / GPU device
After the -- model parameter, the folder path of multiple model files is passed in to start the prediction service of multiple model concatenation.
```
#for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
```
### Client Prediction
The pre-processing and post-processing is in the C + + server part, the image's Base64 encoded string is passed into the C + + server.
so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed.
for this case, `feed_type` should be 3(which means the data type is string),`shape` should be 1.
By passing in multiple client folder paths, the client can be started for multi model prediction.
```
python ocr_cpp_client.py ocr_det_client ocr_rec_client
```
......@@ -15,31 +15,6 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.t
tar xf test_imgs.tar
```
## C++ OCR Service服务
### 启动服务
根据CPU/GPU设备选择一种启动方式
通过--model后,指定多个模型文件的文件夹路径来启动多模型串联的预测服务。
```
#for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user
python -m paddle_serving_server_gpu.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
```
### 启动客户端
由于需要在C++Server部分进行前后处理,传入C++Server的仅仅是图片的base64编码的字符串,故第一个模型的Client配置需要修改
`ocr_det_client/serving_client_conf.prototxt``feed_var`字段
对于本示例而言,`feed_type`应修改为3(数据类型为string),`shape`为1.
通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。
```
python ocr_c_client_bytes.py ocr_det_client ocr_rec_client
```
## Web Service服务
### 启动服务
......@@ -123,3 +98,29 @@ python rec_debugger_server.py gpu #for gpu user
```
python rec_web_client.py
```
## C++ OCR Service服务
**注意:** 若您需要使用Paddle Serving C++框架串联det模型和rec模型,并进行前后处理,您需要使用开启WITH_OPENCV选项编译的C++ Server,详见[COMPILE.md](../../../doc/COMPILE.md)
### 启动服务
根据CPU/GPU设备选择一种启动方式
通过--model后,指定多个模型文件的文件夹路径来启动多模型串联的预测服务。
```
#for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
```
### 启动客户端
由于需要在C++Server部分进行前后处理,传入C++Server的仅仅是图片的base64编码的字符串,故第一个模型的Client配置需要修改
`ocr_det_client/serving_client_conf.prototxt``feed_var`字段
对于本示例而言,`feed_type`应修改为3(数据类型为string),`shape`为1.
通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。
```
python ocr_cpp_client.py ocr_det_client ocr_rec_client
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
import sys
import numpy as np
import base64
import os
import cv2
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
client = Client()
# TODO:load_client need to load more than one client model.
# this need to figure out some details.
client.load_client_config(sys.argv[1:])
client.connect(["127.0.0.1:9293"])
import paddle
test_img_dir = "imgs/"
def cv2_to_base64(image):
return base64.b64encode(image) #data.tostring()).decode('utf8')
for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data = file.read()
image = cv2_to_base64(image_data)
fetch_map = client.predict(
feed={"image": image}, fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"], batch=True)
#print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
print(fetch_map)
......@@ -27,6 +27,12 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LocalPredictor")
logger.setLevel(logging.INFO)
precision_map = {
'int8': paddle_infer.PrecisionType.Int8,
'fp32': paddle_infer.PrecisionType.Float32,
'fp16': paddle_infer.PrecisionType.Half,
}
class LocalPredictor(object):
"""
......@@ -56,6 +62,8 @@ class LocalPredictor(object):
use_trt=False,
use_lite=False,
use_xpu=False,
precision="fp32",
use_calib=False,
use_feed_fetch_ops=False):
"""
Load model configs and create the paddle predictor by Paddle Inference API.
......@@ -71,6 +79,8 @@ class LocalPredictor(object):
use_trt: use nvidia TensorRT optimization, False default
use_lite: use Paddle-Lite Engint, False default
use_xpu: run predict on Baidu Kunlun, False default
precision: precision mode, "fp32" default
use_calib: use TensorRT calibration, False default
use_feed_fetch_ops: use feed/fetch ops, False default.
"""
client_config = "{}/serving_server_conf.prototxt".format(model_path)
......@@ -88,9 +98,11 @@ class LocalPredictor(object):
logger.info(
"LocalPredictor load_model_config params: model_path:{}, use_gpu:{},\
gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\
use_trt:{}, use_lite:{}, use_xpu: {}, use_feed_fetch_ops:{}".format(
model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
ir_optim, use_trt, use_lite, use_xpu, use_feed_fetch_ops))
use_trt:{}, use_lite:{}, use_xpu: {}, precision: {}, use_calib: {},\
use_feed_fetch_ops:{}"
.format(model_path, use_gpu, gpu_id, use_profile, thread_num,
mem_optim, ir_optim, use_trt, use_lite, use_xpu, precision,
use_calib, use_feed_fetch_ops))
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
......@@ -106,6 +118,9 @@ class LocalPredictor(object):
self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type
precision_type = paddle_infer.PrecisionType.Float32
if precision.lower() in precision_map:
precision_type = precision_map[precision.lower()]
if use_profile:
config.enable_profile()
if mem_optim:
......@@ -121,6 +136,7 @@ class LocalPredictor(object):
config.enable_use_gpu(100, gpu_id)
if use_trt:
config.enable_tensorrt_engine(
precision_mode=precision_type,
workspace_size=1 << 20,
max_batch_size=32,
min_subgraph_size=3,
......@@ -129,7 +145,7 @@ class LocalPredictor(object):
if use_lite:
config.enable_lite_engine(
precision_mode=paddle_infer.PrecisionType.Float32,
precision_mode=precision_type,
zero_copy=True,
passes_filter=[],
ops_filter=[])
......@@ -138,6 +154,11 @@ class LocalPredictor(object):
# 2MB l3 cache
config.enable_xpu(8 * 1024 * 1024)
if not use_gpu and not use_lite:
if precision_type == paddle_infer.PrecisionType.Int8:
config.enable_quantizer()
if precision.lower() == "bf16":
config.enable_mkldnn_bfloat16()
self.predictor = paddle_infer.create_predictor(config)
def predict(self, feed=None, fetch=None, batch=False, log_id=0):
......
......@@ -51,6 +51,16 @@ def serve_args():
"--name", type=str, default="None", help="Default service name")
parser.add_argument(
"--use_mkl", default=False, action="store_true", help="Use MKL")
parser.add_argument(
"--precision",
type=str,
default="fp32",
help="precision mode(fp32, int8, fp16, bf16)")
parser.add_argument(
"--use_calib",
default=False,
action="store_true",
help="Use TensorRT Calibration")
parser.add_argument(
"--mem_optim_off",
default=False,
......@@ -125,17 +135,16 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
for idx, single_model in enumerate(model):
infer_op_name = "general_infer"
if len(model) == 2 and idx == 0:
#Temporary support for OCR model,it will be completely revised later
#If you want to use this, C++ server must compile with WITH_OPENCV option.
if len(model) == 2 and idx == 0 and model[0] == 'ocr_det_model':
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
server = None
if use_multilang:
server = serving.MultiLangServer()
......@@ -148,6 +157,8 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size)
server.set_port(port)
server.set_precision(args.precision)
server.set_use_calib(args.use_calib)
server.use_encryption_model(use_encryption_model)
if args.product_name != None:
server.set_product_name(args.product_name)
......@@ -210,6 +221,8 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
server.use_mkl(use_mkl)
server.set_precision(args.precision)
server.set_use_calib(args.use_calib)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.set_max_body_size(max_body_size)
......@@ -297,7 +310,8 @@ class MainService(BaseHTTPRequestHandler):
key = base64.b64decode(post_data["key"].encode())
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
raise ValueError(
"The input of --model should be a dir not file.")
with open(single_model_config + "/key", "wb") as f:
f.write(key)
return True
......@@ -309,7 +323,8 @@ class MainService(BaseHTTPRequestHandler):
key = base64.b64decode(post_data["key"].encode())
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
raise ValueError(
"The input of --model should be a dir not file.")
with open(single_model_config + "/key", "rb") as f:
cur_key = f.read()
if key != cur_key:
......@@ -394,7 +409,9 @@ if __name__ == "__main__":
device=args.device,
use_lite=args.use_lite,
use_xpu=args.use_xpu,
ir_optim=args.ir_optim)
ir_optim=args.ir_optim,
precision=args.precision,
use_calib=args.use_calib)
web_service.run_rpc_service()
app_instance = Flask(__name__)
......
......@@ -42,24 +42,37 @@ from concurrent import futures
class Server(object):
def __init__(self):
"""
self.model_toolkit_conf:'list'=[] # The quantity of self.model_toolkit_conf is equal to the InferOp quantity/Engine--OP
self.model_conf:'collections.OrderedDict()' # Save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
self.workflow_fn:'str'="workflow.prototxt" # Only one for one Service/Workflow
self.resource_fn:'str'="resource.prototxt" # Only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
self.infer_service_fn:'str'="infer_service.prototxt" # Only one for one Service,Service--Workflow
self.model_toolkit_fn:'list'=[] # ["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
self.general_model_config_fn:'list'=[] # ["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
self.subdirectory:'list'=[] # The quantity is equal to the InferOp quantity, and name = node.name = engine.name
self.model_config_paths:'collections.OrderedDict()' # Save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
"""
self.server_handle_ = None
self.infer_service_conf = None
self.model_toolkit_conf = []#The quantity is equal to the InferOp quantity,Engine--OP
self.model_toolkit_conf = []
self.resource_conf = None
self.memory_optimization = False
self.ir_optimization = False
self.model_conf = collections.OrderedDict()# save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
self.workflow_fn = "workflow.prototxt"#only one for one Service,Workflow--Op
self.resource_fn = "resource.prototxt"#only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
self.infer_service_fn = "infer_service.prototxt"#only one for one Service,Service--Workflow
self.model_toolkit_fn = []#["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
self.general_model_config_fn = []#["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
self.subdirectory = []#The quantity is equal to the InferOp quantity, and name = node.name = engine.name
self.model_conf = collections.OrderedDict()
self.workflow_fn = "workflow.prototxt"
self.resource_fn = "resource.prototxt"
self.infer_service_fn = "infer_service.prototxt"
self.model_toolkit_fn = []
self.general_model_config_fn = []
self.subdirectory = []
self.cube_config_fn = "cube.conf"
self.workdir = ""
self.max_concurrency = 0
self.num_threads = 2
self.port = 8080
self.precision = "fp32"
self.use_calib = False
self.reload_interval_s = 10
self.max_body_size = 64 * 1024 * 1024
self.module_path = os.path.dirname(paddle_serving_server.__file__)
......@@ -71,12 +84,15 @@ class Server(object):
self.use_trt = False
self.use_lite = False
self.use_xpu = False
self.model_config_paths = collections.OrderedDict() # save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
self.model_config_paths = collections.OrderedDict()
self.product_name = None
self.container_id = None
def get_fetch_list(self,infer_node_idx = -1 ):
fetch_names = [var.alias_name for var in list(self.model_conf.values())[infer_node_idx].fetch_var]
def get_fetch_list(self, infer_node_idx=-1):
fetch_names = [
var.alias_name
for var in list(self.model_conf.values())[infer_node_idx].fetch_var
]
return fetch_names
def set_max_concurrency(self, concurrency):
......@@ -99,6 +115,12 @@ class Server(object):
def set_port(self, port):
self.port = port
def set_precision(self, precision="fp32"):
self.precision = precision
def set_use_calib(self, use_calib=False):
self.use_calib = use_calib
def set_reload_interval(self, interval):
self.reload_interval_s = interval
......@@ -172,6 +194,10 @@ class Server(object):
engine.use_trt = self.use_trt
engine.use_lite = self.use_lite
engine.use_xpu = self.use_xpu
engine.use_gpu = False
if self.device == "gpu":
engine.use_gpu = True
if os.path.exists('{}/__params__'.format(model_config_path)):
engine.combined_model = True
else:
......@@ -195,7 +221,8 @@ class Server(object):
self.workdir = workdir
if self.resource_conf == None:
self.resource_conf = server_sdk.ResourceConf()
for idx, op_general_model_config_fn in enumerate(self.general_model_config_fn):
for idx, op_general_model_config_fn in enumerate(
self.general_model_config_fn):
with open("{}/{}".format(workdir, op_general_model_config_fn),
"w") as fout:
fout.write(str(list(self.model_conf.values())[idx]))
......@@ -212,9 +239,11 @@ class Server(object):
if "quant" in node.name:
self.resource_conf.cube_quant_bits = 8
self.resource_conf.model_toolkit_path.extend([workdir])
self.resource_conf.model_toolkit_file.extend([self.model_toolkit_fn[idx]])
self.resource_conf.model_toolkit_file.extend(
[self.model_toolkit_fn[idx]])
self.resource_conf.general_model_path.extend([workdir])
self.resource_conf.general_model_file.extend([op_general_model_config_fn])
self.resource_conf.general_model_file.extend(
[op_general_model_config_fn])
#TODO:figure out the meaning of product_name and container_id.
if self.product_name != None:
self.resource_conf.auth_product_name = self.product_name
......@@ -237,15 +266,18 @@ class Server(object):
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
raise ValueError(
"The input of --model should be a dir not file.")
if isinstance(model_config_paths_args, list):
# If there is only one model path, use the default infer_op.
# Because there are several infer_op type, we need to find
# it from workflow_conf.
default_engine_types = [
'GeneralInferOp', 'GeneralDistKVInferOp',
'GeneralDistKVQuantInferOp','GeneralDetectionOp',
'GeneralInferOp',
'GeneralDistKVInferOp',
'GeneralDistKVQuantInferOp',
'GeneralDetectionOp',
]
# now only support single-workflow.
# TODO:support multi-workflow
......@@ -258,14 +290,22 @@ class Server(object):
)
f = open("{}/serving_server_conf.prototxt".format(
model_config_paths_args[model_config_paths_list_idx]), 'r')
self.model_conf[node.name] = google.protobuf.text_format.Merge(str(f.read()), m_config.GeneralModelConfig())
self.model_config_paths[node.name] = model_config_paths_args[model_config_paths_list_idx]
self.general_model_config_fn.append(node.name+"/general_model.prototxt")
self.model_toolkit_fn.append(node.name+"/model_toolkit.prototxt")
model_config_paths_args[model_config_paths_list_idx]),
'r')
self.model_conf[
node.name] = google.protobuf.text_format.Merge(
str(f.read()), m_config.GeneralModelConfig())
self.model_config_paths[
node.name] = model_config_paths_args[
model_config_paths_list_idx]
self.general_model_config_fn.append(
node.name + "/general_model.prototxt")
self.model_toolkit_fn.append(node.name +
"/model_toolkit.prototxt")
self.subdirectory.append(node.name)
model_config_paths_list_idx += 1
if model_config_paths_list_idx == len(model_config_paths_args):
if model_config_paths_list_idx == len(
model_config_paths_args):
break
#Right now, this is not useful.
elif isinstance(model_config_paths_args, dict):
......@@ -280,7 +320,8 @@ class Server(object):
self.model_conf[node.name] = google.protobuf.text_format.Merge(
str(f.read()), m_config.GeneralModelConfig())
else:
raise Exception("The type of model_config_paths must be str or list or "
raise Exception(
"The type of model_config_paths must be str or list or "
"dict({op: model_path}), not {}.".format(
type(model_config_paths_args)))
# check config here
......@@ -409,7 +450,7 @@ class Server(object):
resource_fn = "{}/{}".format(workdir, self.resource_fn)
self._write_pb_str(resource_fn, self.resource_conf)
for idx,single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
for idx, single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])
......@@ -443,6 +484,8 @@ class Server(object):
"-max_concurrency {} " \
"-num_threads {} " \
"-port {} " \
"-precision {} " \
"-use_calib {} " \
"-reload_interval_s {} " \
"-resource_path {} " \
"-resource_file {} " \
......@@ -456,6 +499,8 @@ class Server(object):
self.max_concurrency,
self.num_threads,
self.port,
self.precision,
self.use_calib,
self.reload_interval_s,
self.workdir,
self.resource_fn,
......@@ -471,6 +516,8 @@ class Server(object):
"-max_concurrency {} " \
"-num_threads {} " \
"-port {} " \
"-precision {} " \
"-use_calib {} " \
"-reload_interval_s {} " \
"-resource_path {} " \
"-resource_file {} " \
......@@ -485,6 +532,8 @@ class Server(object):
self.max_concurrency,
self.num_threads,
self.port,
self.precision,
self.use_calib,
self.reload_interval_s,
self.workdir,
self.resource_fn,
......@@ -498,6 +547,7 @@ class Server(object):
os.system(command)
class MultiLangServer(object):
def __init__(self):
self.bserver_ = Server()
......@@ -532,6 +582,12 @@ class MultiLangServer(object):
def set_port(self, port):
self.gport_ = port
def set_precision(self, precision="fp32"):
self.precision = precision
def set_use_calib(self, use_calib=False):
self.use_calib = use_calib
def set_reload_interval(self, interval):
self.bserver_.set_reload_interval(interval)
......@@ -553,22 +609,23 @@ class MultiLangServer(object):
def set_gpuid(self, gpuid=0):
self.bserver_.set_gpuid(gpuid)
def load_model_config(self, server_config_dir_paths, client_config_path=None):
def load_model_config(self,
server_config_dir_paths,
client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
pass
else:
raise Exception("The type of model_config_paths must be str or list"
", not {}.".format(
type(server_config_dir_paths)))
", not {}.".format(type(server_config_dir_paths)))
for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
raise ValueError(
"The input of --model should be a dir not file.")
self.bserver_.load_model_config(server_config_dir_paths)
if client_config_path is None:
......@@ -576,27 +633,30 @@ class MultiLangServer(object):
if isinstance(server_config_dir_paths, dict):
self.is_multi_model_ = True
client_config_path = []
for server_config_path_items in list(server_config_dir_paths.items()):
client_config_path.append( server_config_path_items[1] )
for server_config_path_items in list(
server_config_dir_paths.items()):
client_config_path.append(server_config_path_items[1])
elif isinstance(server_config_dir_paths, list):
self.is_multi_model_ = False
client_config_path = server_config_dir_paths
else:
raise Exception("The type of model_config_paths must be str or list or "
raise Exception(
"The type of model_config_paths must be str or list or "
"dict({op: model_path}), not {}.".format(
type(server_config_dir_paths)))
if isinstance(client_config_path, str):
client_config_path = [client_config_path]
elif isinstance(client_config_path, list):
pass
else:# dict is not support right now.
raise Exception("The type of client_config_path must be str or list or "
else: # dict is not support right now.
raise Exception(
"The type of client_config_path must be str or list or "
"dict({op: model_path}), not {}.".format(
type(client_config_path)))
if len(client_config_path) != len(server_config_dir_paths):
raise Warning("The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
.format( len(client_config_path), len(server_config_dir_paths) )
)
raise Warning(
"The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
.format(len(client_config_path), len(server_config_dir_paths)))
self.bclient_config_path_list = client_config_path
def prepare_server(self,
......
......@@ -27,6 +27,7 @@ import os
from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
......@@ -36,6 +37,7 @@ def port_is_available(port):
else:
return False
class WebService(object):
def __init__(self, name="default_service"):
self.name = name
......@@ -63,7 +65,9 @@ class WebService(object):
def run_service(self):
self._server.run_server()
def load_model_config(self, server_config_dir_paths, client_config_path=None):
def load_model_config(self,
server_config_dir_paths,
client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
......@@ -73,13 +77,15 @@ class WebService(object):
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
raise ValueError(
"The input of --model should be a dir not file.")
self.server_config_dir_paths = server_config_dir_paths
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
file_path_list = []
for single_model_config in self.server_config_dir_paths:
file_path_list.append( "{}/serving_server_conf.prototxt".format(single_model_config) )
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[0], 'r')
......@@ -109,7 +115,9 @@ class WebService(object):
mem_optim=True,
use_lite=False,
use_xpu=False,
ir_optim=False):
ir_optim=False,
precision="fp32",
use_calib=False):
device = "gpu"
if gpuid == -1:
if use_lite:
......@@ -140,13 +148,16 @@ class WebService(object):
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.set_device(device)
server.set_precision(precision)
server.set_use_calib(use_calib)
if use_lite:
server.set_lite()
if use_xpu:
server.set_xpu()
server.load_model_config(self.server_config_dir_paths)#brpc Server support server_config_dir_paths
server.load_model_config(self.server_config_dir_paths
) #brpc Server support server_config_dir_paths
if gpuid >= 0:
server.set_gpuid(gpuid)
server.prepare_server(workdir=workdir, port=port, device=device)
......@@ -159,6 +170,8 @@ class WebService(object):
workdir="",
port=9393,
device="gpu",
precision="fp32",
use_calib=False,
use_lite=False,
use_xpu=False,
ir_optim=False,
......@@ -188,7 +201,9 @@ class WebService(object):
mem_optim=mem_optim,
use_lite=use_lite,
use_xpu=use_xpu,
ir_optim=ir_optim))
ir_optim=ir_optim,
precision=precision,
use_calib=use_calib))
else:
for i, gpuid in enumerate(self.gpus):
self.rpc_service_list.append(
......@@ -200,7 +215,9 @@ class WebService(object):
mem_optim=mem_optim,
use_lite=use_lite,
use_xpu=use_xpu,
ir_optim=ir_optim))
ir_optim=ir_optim,
precision=precision,
use_calib=use_calib))
def _launch_web_service(self):
gpu_num = len(self.gpus)
......@@ -297,9 +314,13 @@ class WebService(object):
# default self.gpus = [0].
if len(self.gpus) == 0:
self.gpus.append(0)
self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=True, gpu_id=self.gpus[0])
self.client.load_model_config(
self.server_config_dir_paths[0],
use_gpu=True,
gpu_id=self.gpus[0])
else:
self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=False)
self.client.load_model_config(
self.server_config_dir_paths[0], use_gpu=False)
def run_web_service(self):
print("This API will be deprecated later. Please do not use it")
......
......@@ -238,6 +238,8 @@ class PipelineServer(object):
"devices": "",
"mem_optim": True,
"ir_optim": False,
"precision": "fp32",
"use_calib": False,
},
}
for op in self._used_op:
......@@ -394,6 +396,8 @@ class ServerYamlConfChecker(object):
"devices": "",
"mem_optim": True,
"ir_optim": False,
"precision": "fp32",
"use_calib": False,
}
conf_type = {
"model_config": str,
......@@ -403,6 +407,8 @@ class ServerYamlConfChecker(object):
"devices": str,
"mem_optim": bool,
"ir_optim": bool,
"precision": str,
"use_calib": bool,
}
conf_qualification = {"thread_num": (">=", 1), }
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
......
......@@ -8,11 +8,26 @@ echo "# #"
echo "# #"
echo "# #"
echo "################################################################"
export GOPATH=$HOME/go
export PATH=$PATH:$GOROOT/bin:$GOPATH/bin
export CUDA_INCLUDE_DIRS=/usr/local/cuda-10.2/include
export PYTHONROOT=/usr/local
export PYTHONIOENCODING=utf-8
build_path=/workspace/Serving/
error_words="Fail|DENIED|UNKNOWN|None"
log_dir=${build_path}logs/
data=/root/.cache/serving_data/
dir=`pwd`
RED_COLOR='\E[1;31m'
GREEN_COLOR='\E[1;32m'
YELOW_COLOR='\E[1;33m'
RES='\E[0m'
cuda_version=`cat /usr/local/cuda/version.txt`
if [ $? -ne 0 ]; then
cuda_version=11
fi
go env -w GO111MODULE=on
go env -w GOPROXY=https://goproxy.cn,direct
......@@ -21,27 +36,27 @@ go get -u github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2
go get -u github.com/golang/protobuf/protoc-gen-go@v1.4.3
go get -u google.golang.org/grpc@v1.33.0
build_path=/workspace/Serving/
build_whl_list=(build_gpu_server build_client build_cpu_server build_app)
rpc_model_list=(grpc_impl pipeline_imagenet bert_rpc_gpu bert_rpc_cpu ResNet50_rpc lac_rpc \
cnn_rpc bow_rpc lstm_rpc fit_a_line_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \
criteo_ctr_rpc_cpu criteo_ctr_rpc_gpu ocr_rpc yolov4_rpc_gpu)
http_model_list=(fit_a_line_http lac_http cnn_http bow_http lstm_http ResNet50_http bert_http)
build_whl_list=(build_cpu_server build_gpu_server build_client build_app)
rpc_model_list=(grpc_fit_a_line grpc_yolov4 pipeline_imagenet bert_rpc_gpu bert_rpc_cpu ResNet50_rpc \
lac_rpc cnn_rpc bow_rpc lstm_rpc fit_a_line_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \
criteo_ctr_rpc_cpu criteo_ctr_rpc_gpu ocr_rpc yolov4_rpc_gpu faster_rcnn_hrnetv2p_w18_1x_encrypt)
http_model_list=(fit_a_line_http lac_http cnn_http bow_http lstm_http ResNet50_http bert_http\
pipeline_ocr_cpu_http)
function setproxy(){
function setproxy() {
export http_proxy=${proxy}
export https_proxy=${proxy}
}
function unsetproxy(){
function unsetproxy() {
unset http_proxy
unset https_proxy
}
function kill_server_process(){
kill `ps -ef|grep $1 |awk '{print $2}'`
kill `ps -ef|grep serving |awk '{print $2}'`
function kill_server_process() {
kill `ps -ef | grep serving | awk '{print $2}'` > /dev/null 2>&1
kill `ps -ef | grep python | awk '{print $2}'` > /dev/null 2>&1
echo -e "${GREEN_COLOR}process killed...${RES}"
}
function check() {
......@@ -64,34 +79,93 @@ function check() {
}
function check_result() {
if [ $? -ne 0 ];then
echo -e "\033[4;31;42m$1 model runs failed, please check your pull request or modify test case! \033[0m"
exit 1
if [ $? == 0 ]; then
echo -e "${GREEN_COLOR}$1 execute normally${RES}"
if [ $1 == "server" ]; then
sleep $2
tail ${dir}server_log.txt | tee -a ${log_dir}server_total.txt
fi
if [ $1 == "client" ]; then
tail ${dir}client_log.txt | tee -a ${log_dir}client_total.txt
grep -E "${error_words}" ${dir}client_log.txt > /dev/null
if [ $? == 0 ]; then
echo -e "${RED_COLOR}$1 error command${RES}\n" | tee -a ${log_dir}server_total.txt ${log_dir}client_total.txt
error_log $2
else
echo -e "\033[4;37;42m$1 model runs successfully, congratulations! \033[0m"
echo -e "${GREEN_COLOR}$2${RES}\n" | tee -a ${log_dir}server_total.txt ${log_dir}client_total.txt
fi
fi
else
echo -e "${RED_COLOR}$1 error command${RES}\n" | tee -a ${log_dir}server_total.txt ${log_dir}client_total.txt
tail ${dir}client_log.txt | tee -a ${log_dir}client_total.txt
error_log $2
fi
}
function error_log() {
arg=${1//\//_}
echo "-----------------------------" | tee -a ${log_dir}error_models.txt
arg=${arg%% *}
arr=(${arg//_/ })
if [ ${arr[@]: -1} == 1 -o ${arr[@]: -1} == 2 ]; then
model=${arr[*]:0:${#arr[*]}-3}
deployment=${arr[*]: -3}
else
model=${arr[*]:0:${#arr[*]}-2}
deployment=${arr[*]: -2}
fi
echo "model: ${model// /_}" | tee -a ${log_dir}error_models.txt
echo "deployment: ${deployment// /_}" | tee -a ${log_dir}error_models.txt
echo "py_version: python3.6" | tee -a ${log_dir}error_models.txt
echo "cuda_version: ${cuda_version}" | tee -a ${log_dir}error_models.txt
echo "status: Failed" | tee -a ${log_dir}error_models.txt
echo -e "-----------------------------\n\n" | tee -a ${log_dir}error_models.txt
prefix=${arg//\//_}
for file in ${dir}*
do
cp ${file} ${log_dir}error/${prefix}_${file##*/}
done
}
function check_dir() {
if [ ! -d "$1" ]
then
mkdir -p $1
fi
}
function before_hook(){
function link_data() {
for file in $1*
do
if [ ! -h ${file##*/} ]
then
ln -s ${file} ./${file##*/}
fi
done
}
function before_hook() {
setproxy
unsetproxy
cd ${build_path}/python
pip3.6 install --upgrade pip
pip3.6 install requests
pip3.6 install -r requirements.txt
pip3.6 install numpy==1.16.4
python3.6 -m pip install --upgrade pip
python3.6 -m pip install requests
python3.6 -m pip install -r requirements.txt -i https://mirror.baidu.com/pypi/simple
python3.6 -m pip install numpy==1.16.4
python3.6 -m pip install paddlehub -i https://mirror.baidu.com/pypi/simple
echo "before hook configuration is successful.... "
}
function run_env(){
function run_env() {
setproxy
pip3.6 install --upgrade nltk==3.4
pip3.6 install --upgrade scipy==1.2.1
pip3.6 install --upgrade setuptools==41.0.0
pip3.6 install paddlehub ujson paddlepaddle==2.0.0
python3.6 -m pip install --upgrade nltk==3.4
python3.6 -m pip install --upgrade scipy==1.2.1
python3.6 -m pip install --upgrade setuptools==41.0.0
python3.6 -m pip install paddlehub ujson paddlepaddle==2.0.0
echo "run env configuration is successful.... "
}
function run_gpu_env(){
function run_gpu_env() {
cd ${build_path}
export LD_LIBRARY_PATH=/usr/local/lib64/python3.6/site-packages/paddle/libs/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/workspace/Serving/build_gpu/third_party/install/Paddle/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mklml/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mkldnn/lib/:$LD_LIBRARY_PATH
......@@ -99,14 +173,6 @@ function run_gpu_env(){
echo "run gpu env configuration is successful.... "
}
function run_cpu_env(){
cd ${build_path}
export LD_LIBRARY_PATH=/usr/local/lib64/python3.6/site-packages/paddle/libs/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/workspace/Serving/build_cpu/third_party/install/Paddle/lib/:$LD_LIBRARY_PATH
export SERVING_BIN=${build_path}/build_cpu/core/general-server/serving
echo "run cpu env configuration is successful.... "
}
function build_gpu_server() {
setproxy
cd ${build_path}
......@@ -124,18 +190,21 @@ function build_gpu_server() {
-DSERVER=ON \
-DTENSORRT_ROOT=/usr \
-DWITH_GPU=ON ..
make -j18
make -j18
make install -j18
pip3.6 uninstall paddle-serving-server-gpu -y
pip3.6 install ${build_path}/build/python/dist/*
make -j32
make -j32
make install -j32
python3.6 -m pip uninstall paddle-serving-server-gpu -y
python3.6 -m pip install ${build_path}/build/python/dist/*
cp ${build_path}/build/python/dist/* ../
cp -r ${build_path}/build/ ${build_path}/build_gpu
}
function build_client() {
function build_cpu_server(){
setproxy
cd ${build_path}
if [ -d build_cpu ];then
rm -rf build_cpu
fi
if [ -d build ];then
rm -rf build
fi
......@@ -143,20 +212,20 @@ function build_client() {
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
-DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \
-DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
-DCLIENT=ON ..
make -j18
make -j18
-DWITH_GPU=OFF \
-DSERVER=ON ..
make -j32
make -j32
make install -j32
cp ${build_path}/build/python/dist/* ../
pip3.6 uninstall paddle-serving-client -y
pip3.6 install ${build_path}/build/python/dist/*
python3.6 -m pip uninstall paddle-serving-server -y
python3.6 -m pip install ${build_path}/build/python/dist/*
cp -r ${build_path}/build/ ${build_path}/build_cpu
}
function build_cpu_server(){
function build_client() {
setproxy
cd ${build_path}
if [ -d build_cpu ];then
rm -rf build_cpu
fi
if [ -d build ];then
rm -rf build
fi
......@@ -164,21 +233,18 @@ function build_cpu_server(){
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
-DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \
-DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
-DWITH_GPU=OFF \
-DSERVER=ON ..
make -j18
make -j18
make install -j18
-DCLIENT=ON ..
make -j32
make -j32
cp ${build_path}/build/python/dist/* ../
pip3.6 uninstall paddle-serving-server -y
pip3.6 install ${build_path}/build/python/dist/*
cp -r ${build_path}/build/ ${build_path}/build_cpu
python3.6 -m pip uninstall paddle-serving-client -y
python3.6 -m pip install ${build_path}/build/python/dist/*
}
function build_app() {
setproxy
pip3.6 install paddlehub ujson Pillow
pip3.6 install paddlepaddle==2.0.0
python3.6 -m pip install paddlehub ujson Pillow
python3.6 -m pip install paddlepaddle==2.0.0
cd ${build_path}
if [ -d build ];then
rm -rf build
......@@ -190,458 +256,520 @@ function build_app() {
-DCMAKE_INSTALL_PREFIX=./output -DAPP=ON ..
make
cp ${build_path}/build/python/dist/* ../
pip3.6 uninstall paddle-serving-app -y
pip3.6 install ${build_path}/build/python/dist/*
python3.6 -m pip uninstall paddle-serving-app -y
python3.6 -m pip install ${build_path}/build/python/dist/*
}
function faster_rcnn_hrnetv2p_w18_1x_encrypt() {
dir=${log_dir}rpc_model/faster_rcnn_hrnetv2p_w18_1x/
cd ${build_path}/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x
check_dir ${dir}
data_dir=${data}detection/faster_rcnn_hrnetv2p_w18_1x/
link_data ${data_dir}
python3.6 encrypt.py
unsetproxy
echo -e "${GREEN_COLOR}faster_rcnn_hrnetv2p_w18_1x_ENCRYPTION_GPU_RPC server started${RES}" | tee -a ${log_dir}server_total.txt
python3.6 -m paddle_serving_server.serve --model encrypt_server/ --port 9494 --use_trt --gpu_ids 0 --use_encryption_model > ${dir}server_log.txt 2>&1 &
check_result server 3
echo -e "${GREEN_COLOR}faster_rcnn_hrnetv2p_w18_1x_ENCRYPTION_GPU_RPC client started${RES}" | tee -a ${log_dir}client_total.txt
python3.6 test_encryption.py 000000570688.jpg > ${dir}client_log.txt 2>&1
check_result client "faster_rcnn_hrnetv2p_w18_1x_ENCRYPTION_GPU_RPC server test completed"
kill_server_process
}
function bert_rpc_gpu(){
function pipeline_ocr_cpu_http() {
dir=${log_dir}rpc_model/pipeline_ocr_cpu_http/
check_dir ${dir}
cd ${build_path}/python/examples/pipeline/ocr
data_dir=${data}ocr/
link_data ${data_dir}
unsetproxy
echo -e "${GREEN_COLOR}pipeline_ocr_CPU_HTTP server started${RES}" | tee -a ${log_dir}server_total.txt
$py_version web_service.py > ${dir}server_log.txt 2>&1 &
check_result server 5
echo -e "${GREEN_COLOR}pipeline_ocr_CPU_HTTP client started${RES}" | tee -a ${log_dir}client_total.txt
timeout 15s $py_version pipeline_http_client.py > ${dir}client_log.txt 2>&1
check_result client "pipeline_ocr_CPU_HTTP server test completed"
kill_server_process
}
function bert_rpc_gpu() {
dir=${log_dir}rpc_model/bert_rpc_gpu/
check_dir ${dir}
run_gpu_env
unsetproxy
cd ${build_path}/python/examples/bert
data_dir=${data}bert/
link_data ${data_dir}
sh get_data.sh >/dev/null 2>&1
sed -i 's/9292/8860/g' bert_client.py
sed -i '$aprint(result)' bert_client.py
cp -r /root/.cache/dist_data/serving/bert/bert_seq128_* ./
ls -hlst
python3.6 -m paddle_serving_server_gpu.serve --model bert_seq128_model/ --port 8860 --gpu_ids 0 &
sleep 15
python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8860 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
check_result server 15
nvidia-smi
head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt
head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
check_result client "bert_GPU_RPC server test completed"
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
kill_server_process
}
function bert_rpc_cpu(){
run_cpu_env
function bert_rpc_cpu() {
dir=${log_dir}rpc_model/bert_rpc_cpu/
check_dir ${dir}
unsetproxy
cd ${build_path}/python/examples/bert
data_dir=${data}bert/
link_data ${data_dir}
sed -i 's/8860/8861/g' bert_client.py
python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8861 &
sleep 3
python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8861 > ${dir}server_log.txt 2>&1 &
check_result server 3
cp data-c.txt.1 data-c.txt
head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt
check_result $FUNCNAME
kill_server_process serving
}
function criteo_ctr_with_cube_rpc(){
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/criteo_ctr_with_cube
ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./
sed -i "s/9292/8888/g" test_server.py
sed -i "s/9292/8888/g" test_client.py
wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz >/dev/null 2>&1
tar xf ctr_cube_unittest.tar.gz
mv models/ctr_client_conf ./
mv models/ctr_serving_model_kv ./
mv models/data ./cube/
wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz >/dev/null 2>&1
tar xf cube_app.tar.gz
mv cube_app/cube* ./cube/
sh cube_prepare.sh > haha 2>&1 &
sleep 5
python3.6 test_server.py ctr_serving_model_kv &
sleep 5
python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data
check_result $FUNCNAME
kill `ps -ef|grep cube|awk '{print $2}'`
kill_server_process test_server
head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
check_result client "bert_CPU_RPC server test completed"
kill_server_process
}
function pipeline_imagenet(){
run_gpu_env
function pipeline_imagenet() {
dir=${log_dir}rpc_model/pipeline_imagenet/
check_dir ${dir}
unsetproxy
cd ${build_path}/python/examples/pipeline/imagenet
cp -r /root/.cache/dist_data/serving/imagenet/* ./
ls -a
python3.6 resnet50_web_service.py &
sleep 5
data_dir=${data}imagenet/
link_data ${data_dir}
python3.6 resnet50_web_service.py > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 pipeline_rpc_client.py
python3.6 pipeline_rpc_client.py > ${dir}client_log.txt 2>&1
check_result client "pipeline_imagenet_GPU_RPC server test completed"
nvidia-smi
# check_result $FUNCNAME
kill_server_process resnet50_web_service
kill_server_process
}
function ResNet50_rpc(){
run_gpu_env
function ResNet50_rpc() {
dir=${log_dir}rpc_model/ResNet50_rpc/
check_dir ${dir}
unsetproxy
cd ${build_path}/python/examples/imagenet
cp -r /root/.cache/dist_data/serving/imagenet/* ./
data_dir=${data}imagenet/
link_data ${data_dir}
sed -i 's/9696/8863/g' resnet50_rpc_client.py
python3.6 -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 8863 --gpu_ids 0 &
sleep 5
python3.6 -m paddle_serving_server.serve --model ResNet50_vd_model --port 8863 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt
python3.6 resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
check_result client "ResNet50_GPU_RPC server test completed"
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
kill_server_process
}
function ResNet101_rpc(){
run_gpu_env
function ResNet101_rpc() {
dir=${log_dir}rpc_model/ResNet101_rpc/
check_dir ${dir}
unsetproxy
cd ${build_path}/python/examples/imagenet
sed -i "22cclient.connect(['${host}:8864'])" image_rpc_client.py
python3.6 -m paddle_serving_server_gpu.serve --model ResNet101_vd_model --port 8864 --gpu_ids 0 &
sleep 5
data_dir=${data}imagenet/
link_data ${data_dir}
sed -i "22cclient.connect(['127.0.0.1:8864'])" image_rpc_client.py
python3.6 -m paddle_serving_server.serve --model ResNet101_vd_model --port 8864 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 image_rpc_client.py ResNet101_vd_client_config/serving_client_conf.prototxt
python3.6 image_rpc_client.py ResNet101_vd_client_config/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
check_result client "ResNet101_GPU_RPC server test completed"
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
sleep 5
kill_server_process
}
function cnn_rpc(){
function cnn_rpc() {
dir=${log_dir}rpc_model/cnn_rpc/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/imdb
cp -r /root/.cache/dist_data/serving/imdb/* ./
tar xf imdb_model.tar.gz && tar xf text_classification_data.tar.gz
data_dir=${data}imdb/
link_data ${data_dir}
sed -i 's/9292/8865/g' test_client.py
python3.6 -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 &
sleep 5
head test_data/part-0 | python3.6 test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab
check_result $FUNCNAME
kill_server_process serving
python3.6 -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 > ${dir}server_log.txt 2>&1 &
check_result server 5
head test_data/part-0 | python3.6 test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab > ${dir}client_log.txt 2>&1
check_result client "cnn_CPU_RPC server test completed"
kill_server_process
}
function bow_rpc(){
function bow_rpc() {
dir=${log_dir}rpc_model/bow_rpc/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/imdb
data_dir=${data}imdb/
link_data ${data_dir}
sed -i 's/8865/8866/g' test_client.py
python3.6 -m paddle_serving_server.serve --model imdb_bow_model/ --port 8866 &
sleep 5
head test_data/part-0 | python3.6 test_client.py imdb_bow_client_conf/serving_client_conf.prototxt imdb.vocab
check_result $FUNCNAME
kill_server_process serving
python3.6 -m paddle_serving_server.serve --model imdb_bow_model/ --port 8866 > ${dir}server_log.txt 2>&1 &
check_result server 5
head test_data/part-0 | python3.6 test_client.py imdb_bow_client_conf/serving_client_conf.prototxt imdb.vocab > ${dir}client_log.txt 2>&1
check_result client "bow_CPU_RPC server test completed"
kill_server_process
}
function lstm_rpc(){
function lstm_rpc() {
dir=${log_dir}rpc_model/lstm_rpc/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/imdb
data_dir=${data}imdb/
link_data ${data_dir}
sed -i 's/8866/8867/g' test_client.py
python3.6 -m paddle_serving_server.serve --model imdb_lstm_model/ --port 8867 &
sleep 5
head test_data/part-0 | python3.6 test_client.py imdb_lstm_client_conf/serving_client_conf.prototxt imdb.vocab
check_result $FUNCNAME
kill_server_process serving
python3.6 -m paddle_serving_server.serve --model imdb_lstm_model/ --port 8867 > ${dir}server_log.txt 2>&1 &
check_result server 5
head test_data/part-0 | python3.6 test_client.py imdb_lstm_client_conf/serving_client_conf.prototxt imdb.vocab > ${dir}client_log.txt 2>&1
check_result client "lstm_CPU_RPC server test completed"
kill_server_process
}
function lac_rpc(){
function lac_rpc() {
dir=${log_dir}rpc_model/lac_rpc/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/lac
python3.6 -m paddle_serving_app.package --get_model lac >/dev/null 2>&1
tar xf lac.tar.gz
data_dir=${data}lac/
link_data ${data_dir}
sed -i 's/9292/8868/g' lac_client.py
python3.6 -m paddle_serving_server.serve --model lac_model/ --port 8868 &
sleep 5
echo "我爱北京天安门" | python3.6 lac_client.py lac_client/serving_client_conf.prototxt lac_dict/
check_result $FUNCNAME
kill_server_process serving
python3.6 -m paddle_serving_server.serve --model lac_model/ --port 8868 > ${dir}server_log.txt 2>&1 &
check_result server 5
echo "我爱北京天安门" | python3.6 lac_client.py lac_client/serving_client_conf.prototxt lac_dict/ > ${dir}client_log.txt 2>&1
check_result client "lac_CPU_RPC server test completed"
kill_server_process
}
function fit_a_line_rpc(){
function fit_a_line_rpc() {
dir=${log_dir}rpc_model/fit_a_line_rpc/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/fit_a_line
sh get_data.sh >/dev/null 2>&1
data_dir=${data}fit_a_line/
link_data ${data_dir}
sed -i 's/9393/8869/g' test_client.py
python3.6 -m paddle_serving_server.serve --model uci_housing_model --port 8869 &
sleep 5
python3.6 test_client.py uci_housing_client/serving_client_conf.prototxt
check_result $FUNCNAME
kill_server_process serving
python3.6 -m paddle_serving_server.serve --model uci_housing_model --port 8869 > ${dir}server_log.txt 2>&1 &
check_result server 5
python3.6 test_client.py uci_housing_client/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
check_result client "fit_a_line_CPU_RPC server test completed"
kill_server_process
}
function faster_rcnn_model_rpc(){
function faster_rcnn_model_rpc() {
dir=${log_dir}rpc_model/faster_rcnn_rpc/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/faster_rcnn
cp -r /root/.cache/dist_data/serving/faster_rcnn/faster_rcnn_model.tar.gz ./
tar xf faster_rcnn_model.tar.gz
wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml >/dev/null 2>&1
mv faster_rcnn_model/pddet* ./
cd ${build_path}/python/examples/detection/faster_rcnn_r50_fpn_1x_coco
data_dir=${data}detection/faster_rcnn_r50_fpn_1x_coco/
link_data ${data_dir}
sed -i 's/9494/8870/g' test_client.py
python3.6 -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 8870 --gpu_id 0 --thread 2 &
python3.6 -m paddle_serving_server.serve --model serving_server --port 8870 --gpu_ids 0 --thread 2 --use_trt > ${dir}server_log.txt 2>&1 &
echo "faster rcnn running ..."
nvidia-smi
sleep 5
python3.6 test_client.py pddet_client_conf/serving_client_conf.prototxt infer_cfg.yml 000000570688.jpg
check_result server 10
python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
check_result client "faster_rcnn_GPU_RPC server test completed"
kill_server_process
}
function cascade_rcnn_rpc(){
function cascade_rcnn_rpc() {
dir=${log_dir}rpc_model/cascade_rcnn_rpc/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/cascade_rcnn
cp -r /root/.cache/dist_data/serving/cascade_rcnn/cascade_rcnn_r50_fpx_1x_serving.tar.gz ./
tar xf cascade_rcnn_r50_fpx_1x_serving.tar.gz
data_dir=${data}cascade_rcnn/
link_data ${data_dir}
sed -i "s/9292/8879/g" test_client.py
python3.6 -m paddle_serving_server_gpu.serve --model serving_server --port 8879 --gpu_id 0 --thread 2 &
sleep 5
python3.6 -m paddle_serving_server.serve --model serving_server --port 8879 --gpu_ids 0 --thread 2 > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 test_client.py
python3.6 test_client.py > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
check_result client "cascade_rcnn_GPU_RPC server test completed"
kill_server_process
}
function deeplabv3_rpc() {
dir=${log_dir}rpc_model/deeplabv3_rpc/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/deeplabv3
cp -r /root/.cache/dist_data/serving/deeplabv3/deeplabv3.tar.gz ./
tar xf deeplabv3.tar.gz
data_dir=${data}deeplabv3/
link_data ${data_dir}
sed -i "s/9494/8880/g" deeplabv3_client.py
python3.6 -m paddle_serving_server_gpu.serve --model deeplabv3_server --gpu_ids 0 --port 8880 --thread 2 &
sleep 5
python3.6 -m paddle_serving_server.serve --model deeplabv3_server --gpu_ids 0 --port 8880 --thread 2 > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 deeplabv3_client.py
python3.6 deeplabv3_client.py > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
check_result client "deeplabv3_GPU_RPC server test completed"
kill_server_process
}
function mobilenet_rpc() {
dir=${log_dir}rpc_model/mobilenet_rpc/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/mobilenet
python3.6 -m paddle_serving_app.package --get_model mobilenet_v2_imagenet >/dev/null 2>&1
tar xf mobilenet_v2_imagenet.tar.gz
sed -i "s/9393/8881/g" mobilenet_tutorial.py
python3.6 -m paddle_serving_server_gpu.serve --model mobilenet_v2_imagenet_model --gpu_ids 0 --port 8881 &
sleep 5
python3.6 -m paddle_serving_server.serve --model mobilenet_v2_imagenet_model --gpu_ids 0 --port 8881 > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 mobilenet_tutorial.py
python3.6 mobilenet_tutorial.py > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
check_result client "mobilenet_GPU_RPC server test completed"
kill_server_process
}
function unet_rpc() {
dir=${log_dir}rpc_model/unet_rpc/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/unet_for_image_seg
python3.6 -m paddle_serving_app.package --get_model unet >/dev/null 2>&1
tar xf unet.tar.gz
data_dir=${data}unet_for_image_seg/
link_data ${data_dir}
sed -i "s/9494/8882/g" seg_client.py
python3.6 -m paddle_serving_server_gpu.serve --model unet_model --gpu_ids 0 --port 8882 &
sleep 5
python3.6 -m paddle_serving_server.serve --model unet_model --gpu_ids 0 --port 8882 > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 seg_client.py
python3.6 seg_client.py > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
check_result client "unet_GPU_RPC server test completed"
kill_server_process
}
function resnetv2_rpc() {
dir=${log_dir}rpc_model/resnetv2_rpc/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/resnet_v2_50
cp /root/.cache/dist_data/serving/resnet_v2_50/resnet_v2_50_imagenet.tar.gz ./
tar xf resnet_v2_50_imagenet.tar.gz
data_dir=${data}resnet_v2_50/
link_data ${data_dir}
sed -i 's/9393/8883/g' resnet50_v2_tutorial.py
python3.6 -m paddle_serving_server_gpu.serve --model resnet_v2_50_imagenet_model --gpu_ids 0 --port 8883 &
sleep 10
python3.6 -m paddle_serving_server.serve --model resnet_v2_50_imagenet_model --gpu_ids 0 --port 8883 > ${dir}server_log.txt 2>&1 &
check_result server 10
nvidia-smi
python3.6 resnet50_v2_tutorial.py
python3.6 resnet50_v2_tutorial.py > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
check_result client "resnetv2_GPU_RPC server test completed"
kill_server_process
}
function ocr_rpc() {
dir=${log_dir}rpc_model/ocr_rpc/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/ocr
cp -r /root/.cache/dist_data/serving/ocr/test_imgs ./
data_dir=${data}ocr/
link_data ${data_dir}
python3.6 -m paddle_serving_app.package --get_model ocr_rec >/dev/null 2>&1
tar xf ocr_rec.tar.gz
sed -i 's/9292/8884/g' test_ocr_rec_client.py
python3.6 -m paddle_serving_server.serve --model ocr_rec_model --port 8884 &
sleep 5
python3.6 test_ocr_rec_client.py
# check_result $FUNCNAME
kill_server_process serving
python3.6 -m paddle_serving_server.serve --model ocr_rec_model --port 8884 > ${dir}server_log.txt 2>&1 &
check_result server 5
python3.6 test_ocr_rec_client.py > ${dir}client_log.txt 2>&1
check_result client "ocr_CPU_RPC server test completed"
kill_server_process
}
function criteo_ctr_rpc_cpu() {
dir=${log_dir}rpc_model/criteo_ctr_rpc_cpu/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/criteo_ctr
data_dir=${data}criteo_ctr/
link_data ${data_dir}
sed -i "s/9292/8885/g" test_client.py
ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./
wget https://paddle-serving.bj.bcebos.com/criteo_ctr_example/criteo_ctr_demo_model.tar.gz >/dev/null 2>&1
tar xf criteo_ctr_demo_model.tar.gz
mv models/ctr_client_conf .
mv models/ctr_serving_model .
python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8885 &
sleep 5
python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0
check_result $FUNCNAME
kill_server_process serving
python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8885 > ${dir}server_log.txt 2>&1 &
check_result server 5
python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0 > ${dir}client_log.txt 2>&1
check_result client "criteo_ctr_CPU_RPC server test completed"
kill_server_process
}
function criteo_ctr_rpc_gpu() {
dir=${log_dir}rpc_model/criteo_ctr_rpc_gpu/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/criteo_ctr
data_dir=${data}criteo_ctr/
link_data ${data_dir}
sed -i "s/8885/8886/g" test_client.py
wget https://paddle-serving.bj.bcebos.com/criteo_ctr_example/criteo_ctr_demo_model.tar.gz >/dev/null 2>&1
python3.6 -m paddle_serving_server_gpu.serve --model ctr_serving_model/ --port 8886 --gpu_ids 0 &
sleep 5
python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8886 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
check_result server 5
nvidia-smi
python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/
python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0 > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill `ps -ef|grep ctr|awk '{print $2}'`
kill_server_process serving
check_result client "criteo_ctr_GPU_RPC server test completed"
kill_server_process
}
function yolov4_rpc_gpu() {
dir=${log_dir}rpc_model/yolov4_rpc_gpu/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/yolov4
data_dir=${data}yolov4/
link_data ${data_dir}
sed -i "s/9393/8887/g" test_client.py
cp -r /root/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./
tar xf yolov4.tar.gz
python3.6 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 &
python3.6 -m paddle_serving_server.serve --model yolov4_model --port 8887 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
nvidia-smi
sleep 5
python3.6 test_client.py 000000570688.jpg
check_result server 5
python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
nvidia-smi
# check_result $FUNCNAME
kill_server_process serving
check_result client "yolov4_GPU_RPC server test completed"
kill_server_process
}
function senta_rpc_cpu() {
dir=${log_dir}rpc_model/senta_rpc_cpu/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/senta
data_dir=${data}senta/
link_data ${data_dir}
sed -i "s/9393/8887/g" test_client.py
cp -r /data/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./
tar xf yolov4.tar.gz
python3.6 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 &
python3.6 -m paddle_serving_server.serve --model yolov4_model --port 8887 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
nvidia-smi
sleep 5
python3.6 test_client.py 000000570688.jpg
check_result server 5
python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
nvidia-smi
check_result $FUNCNAME
kill_server_process serving
check_result client "senta_GPU_RPC server test completed"
kill_server_process
}
function fit_a_line_http() {
dir=${log_dir}http_model/fit_a_line_http/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/fit_a_line
sed -i "s/9292/8871/g" test_server.py
python3.6 test_server.py &
sleep 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://${host}:8871/uci/prediction
check_result $FUNCNAME
kill_server_process test_server
sed -i "s/9393/8871/g" test_server.py
python3.6 test_server.py > ${dir}server_log.txt 2>&1 &
check_result server 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://127.0.0.1:8871/uci/prediction > ${dir}client_log.txt 2>&1
check_result client "fit_a_line_CPU_HTTP server test completed"
kill_server_process
}
function lac_http() {
dir=${log_dir}http_model/lac_http/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/lac
python3.6 lac_web_service.py lac_model/ lac_workdir 8872 &
sleep 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' http://${host}:8872/lac/prediction
check_result $FUNCNAME
kill_server_process lac_web_service
python3.6 lac_web_service.py lac_model/ lac_workdir 8872 > ${dir}server_log.txt 2>&1 &
check_result server 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' http://127.0.0.1:8872/lac/prediction > ${dir}client_log.txt 2>&1
check_result client "lac_CPU_HTTP server test completed"
kill_server_process
}
function cnn_http() {
dir=${log_dir}http_model/cnn_http/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/imdb
python3.6 text_classify_service.py imdb_cnn_model/ workdir/ 8873 imdb.vocab &
sleep 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8873/imdb/prediction
check_result $FUNCNAME
kill_server_process text_classify_service
python3.6 text_classify_service.py imdb_cnn_model/ workdir/ 8873 imdb.vocab > ${dir}server_log.txt 2>&1 &
check_result server 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://127.0.0.1:8873/imdb/prediction > ${dir}client_log.txt 2>&1
check_result client "cnn_CPU_HTTP server test completed"
kill_server_process
}
function bow_http() {
dir=${log_dir}http_model/bow_http/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/imdb
python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8874 imdb.vocab &
sleep 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8874/imdb/prediction
check_result $FUNCNAME
kill_server_process text_classify_service
python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8874 imdb.vocab > ${dir}server_log.txt 2>&1 &
check_result server 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://127.0.0.1:8874/imdb/prediction > ${dir}client_log.txt 2>&1
check_result client "bow_CPU_HTTP server test completed"
kill_server_process
}
function lstm_http() {
dir=${log_dir}http_model/lstm_http/
check_dir ${dir}
unsetproxy
run_cpu_env
cd ${build_path}/python/examples/imdb
python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8875 imdb.vocab &
sleep 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8875/imdb/prediction
check_result $FUNCNAME
kill `ps -ef|grep imdb|awk '{print $2}'`
kill_server_process text_classify_service
python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8875 imdb.vocab > ${dir}server_log.txt 2>&1 &
check_result server 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://127.0.0.1:8875/imdb/prediction > ${dir}client_log.txt 2>&1
check_result client "lstm_CPU_HTTP server test completed"
kill_server_process
}
function ResNet50_http() {
dir=${log_dir}http_model/ResNet50_http/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/imagenet
python3.6 resnet50_web_service.py ResNet50_vd_model gpu 8876 &
sleep 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://${host}:8876/image/prediction
check_result $FUNCNAME
kill_server_process resnet50_web_service
python3.6 resnet50_web_service.py ResNet50_vd_model gpu 8876 > ${dir}server_log.txt 2>&1 &
check_result server 10
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://127.0.0.1:8876/image/prediction > ${dir}client_log.txt 2>&1
check_result client "ResNet50_GPU_HTTP server test completed"
kill_server_process
}
function bert_http(){
function bert_http() {
dir=${log_dir}http_model/ResNet50_http/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/bert
cp data-c.txt.1 data-c.txt
cp vocab.txt.1 vocab.txt
export CUDA_VISIBLE_DEVICES=0
python3.6 bert_web_service.py bert_seq128_model/ 8878 &
sleep 5
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "hello"}], "fetch":["pooled_output"]}' http://127.0.0.1:8878/bert/prediction
check_result $FUNCNAME
kill_server_process bert_web_service
python3.6 bert_web_service.py bert_seq128_model/ 8878 > ${dir}server_log.txt 2>&1 &
check_result server 5
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "hello"}], "fetch":["pooled_output"]}' http://127.0.0.1:8878/bert/prediction > ${dir}client_log.txt 2>&1
check_result client "bert_GPU_HTTP server test completed"
kill_server_process
}
grpc_impl(){
function grpc_fit_a_line() {
dir=${log_dir}rpc_model/grpc_fit_a_line/
check_dir ${dir}
unsetproxy
run_gpu_env
cd ${build_path}/python/examples/grpc_impl_example/fit_a_line
sh get_data.sh >/dev/null 2>&1
python3.6 test_server.py uci_housing_model/ &
sleep 5
echo "sync predict"
python3.6 test_sync_client.py
echo "async predict"
python3.6 test_asyn_client.py
echo "batch predict"
python3.6 test_batch_client.py
echo "timeout predict"
python3.6 test_timeout_client.py
# check_result $FUNCNAME
kill_server_process test_server
}
function build_all_whl(){
data_dir=${data}fit_a_line/
link_data ${data_dir}
python3.6 test_server.py uci_housing_model/ > ${dir}server_log.txt 2>&1 &
check_result server 5
echo "sync predict" > ${dir}client_log.txt 2>&1
python3.6 test_sync_client.py >> ${dir}client_log.txt 2>&1
check_result client "grpc_impl_example_fit_a_line_sync_CPU_gRPC server sync test completed"
echo "async predict" >> ${dir}client_log.txt 2>&1
python3.6 test_asyn_client.py >> ${dir}client_log.txt 2>&1
check_result client "grpc_impl_example_fit_a_line_asyn_CPU_gRPC server asyn test completed"
echo "batch predict" >> ${dir}client_log.txt 2>&1
python3.6 test_batch_client.py >> ${dir}client_log.txt 2>&1
check_result client "grpc_impl_example_fit_a_line_batch_CPU_gRPC server batch test completed"
echo "timeout predict" >> ${dir}client_log.txt 2>&1
python3.6 test_timeout_client.py >> ${dir}client_log.txt 2>&1
check_result client "grpc_impl_example_fit_a_line_timeout_CPU_gRPC server timeout test completed"
kill_server_process
}
function grpc_yolov4() {
dir=${log_dir}rpc_model/grpc_yolov4/
cd ${build_path}/python/examples/grpc_impl_example/yolov4
check_dir ${dir}
data_dir=${data}yolov4/
link_data ${data_dir}
echo -e "${GREEN_COLOR}grpc_impl_example_yolov4_GPU_gRPC server started${RES}"
python3.6 -m paddle_serving_server.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang > ${dir}server_log.txt 2>&1 &
check_result server 5
echo -e "${GREEN_COLOR}grpc_impl_example_yolov4_GPU_gRPC client started${RES}"
python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
check_result client "grpc_yolov4_GPU_GRPC server test completed"
kill_server_process
}
function build_all_whl() {
for whl in ${build_whl_list[@]}
do
echo "===========${whl} begin build==========="
......@@ -651,7 +779,7 @@ function build_all_whl(){
done
}
function run_rpc_models(){
function run_rpc_models() {
for model in ${rpc_model_list[@]}
do
echo "===========${model} run begin==========="
......@@ -661,7 +789,7 @@ function run_rpc_models(){
done
}
function run_http_models(){
function run_http_models() {
for model in ${http_model_list[@]}
do
echo "===========${model} run begin==========="
......@@ -671,7 +799,7 @@ function run_http_models(){
done
}
function end_hook(){
function end_hook() {
cd ${build_path}
kill_server_process
kill `ps -ef|grep python|awk '{print $2}'`
......@@ -679,7 +807,6 @@ function end_hook(){
echo "===========files==========="
ls -hlst
echo "=========== end ==========="
}
function main() {
......@@ -687,10 +814,19 @@ function main() {
build_all_whl
check
run_env
unsetproxy
run_gpu_env
check_dir ${log_dir}rpc_model/
check_dir ${log_dir}http_model/
check_dir ${log_dir}error/
run_rpc_models
# run_http_models
run_http_models
end_hook
if [ -f ${log_dir}error_models.txt ]; then
cat ${log_dir}error_models.txt
echo "error occurred!"
# exit 1
fi
}
main$@
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册