diff --git a/README.md b/README.md
index 8184a3b0dfa7271c96b9b34fd7ed1a1900529a4c..4b63ce9c8aead852a76404f4268fc8131c060729 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,7 @@ We consider deploying deep learning inference service online to be a user-facing
- Any model trained by [PaddlePaddle](https://github.com/paddlepaddle/paddle) can be directly used or [Model Conversion Interface](./doc/SAVE.md) for online deployment of Paddle Serving.
- Support [Multi-model Pipeline Deployment](./doc/PIPELINE_SERVING.md), and provide the requirements of the REST interface and RPC interface itself, [Pipeline example](./python/examples/pipeline).
-- Support the model zoos from the Paddle ecosystem, such as [PaddleDetection](./python/examples/detection), [PaddleOCR](./python/examples/ocr), [PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/tools/recserving/movie_recommender).
+- Support the model zoos from the Paddle ecosystem, such as [PaddleDetection](./python/examples/detection), [PaddleOCR](./python/examples/ocr), [PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/recserving/movie_recommender).
- Provide a variety of pre-processing and post-processing to facilitate users in training, deployment and other stages of related code, bridging the gap between AI developers and application developers, please refer to
[Serving Examples](./python/examples/).
diff --git a/README_CN.md b/README_CN.md
index d166d7c0ffb558ae309afb1fec572ad79ab5f679..4ee2c9863dfbc6f4531d0ec00ca92aacc19e769e 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -44,7 +44,7 @@ Paddle Serving 旨在帮助深度学习开发者轻易部署在线预测服务
- 任何经过[PaddlePaddle](https://github.com/paddlepaddle/paddle)训练的模型,都可以经过直接保存或是[模型转换接口](./doc/SAVE_CN.md),用于Paddle Serving在线部署。
- 支持[多模型串联服务部署](./doc/PIPELINE_SERVING_CN.md), 同时提供Rest接口和RPC接口以满足您的需求,[Pipeline示例](./python/examples/pipeline)。
-- 支持Paddle生态的各大模型库, 例如[PaddleDetection](./python/examples/detection),[PaddleOCR](./python/examples/ocr),[PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/tools/recserving/movie_recommender)。
+- 支持Paddle生态的各大模型库, 例如[PaddleDetection](./python/examples/detection),[PaddleOCR](./python/examples/ocr),[PaddleRec](https://github.com/PaddlePaddle/PaddleRec/tree/master/recserving/movie_recommender)。
- 提供丰富多彩的前后处理,方便用户在训练、部署等各阶段复用相关代码,弥合AI开发者和应用开发者之间的鸿沟,详情参考[模型示例](./python/examples/)。
diff --git a/core/general-server/op/general_detection_op.cpp b/core/general-server/op/general_detection_op.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..f02465e0a70ce5ee86f71f8c194df34e545269df
--- /dev/null
+++ b/core/general-server/op/general_detection_op.cpp
@@ -0,0 +1,353 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "core/general-server/op/general_detection_op.h"
+#include
+#include
+#include
+#include
+#include "core/predictor/framework/infer.h"
+#include "core/predictor/framework/memory.h"
+#include "core/predictor/framework/resource.h"
+#include "core/util/include/timer.h"
+
+
+/*
+#include "opencv2/imgcodecs/legacy/constants_c.h"
+#include "opencv2/imgproc/types_c.h"
+*/
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+using baidu::paddle_serving::Timer;
+using baidu::paddle_serving::predictor::MempoolWrapper;
+using baidu::paddle_serving::predictor::general_model::Tensor;
+using baidu::paddle_serving::predictor::general_model::Response;
+using baidu::paddle_serving::predictor::general_model::Request;
+using baidu::paddle_serving::predictor::general_model::FetchInst;
+using baidu::paddle_serving::predictor::InferManager;
+using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
+
+int GeneralDetectionOp::inference() {
+ VLOG(2) << "Going to run inference";
+ const std::vector pre_node_names = pre_names();
+ if (pre_node_names.size() != 1) {
+ LOG(ERROR) << "This op(" << op_name()
+ << ") can only have one predecessor op, but received "
+ << pre_node_names.size();
+ return -1;
+ }
+ const std::string pre_name = pre_node_names[0];
+
+ const GeneralBlob *input_blob = get_depend_argument(pre_name);
+ if (!input_blob) {
+ LOG(ERROR) << "input_blob is nullptr,error";
+ return -1;
+ }
+ uint64_t log_id = input_blob->GetLogId();
+ VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
+
+ GeneralBlob *output_blob = mutable_data();
+ if (!output_blob) {
+ LOG(ERROR) << "output_blob is nullptr,error";
+ return -1;
+ }
+ output_blob->SetLogId(log_id);
+
+ if (!input_blob) {
+ LOG(ERROR) << "(logid=" << log_id
+ << ") Failed mutable depended argument, op:" << pre_name;
+ return -1;
+ }
+
+ const TensorVector *in = &input_blob->tensor_vector;
+ TensorVector* out = &output_blob->tensor_vector;
+
+ int batch_size = input_blob->_batch_size;
+ VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
+
+ output_blob->_batch_size = batch_size;
+
+ VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
+
+ std::vector input_shape;
+ int in_num =0;
+ void* databuf_data = NULL;
+ char* databuf_char = NULL;
+ size_t databuf_size = 0;
+
+ std::string* input_ptr = static_cast(in->at(0).data.data());
+ std::string base64str = input_ptr[0];
+ float ratio_h{};
+ float ratio_w{};
+
+
+ cv::Mat img = Base2Mat(base64str);
+ cv::Mat srcimg;
+ cv::Mat resize_img;
+
+ cv::Mat resize_img_rec;
+ cv::Mat crop_img;
+ img.copyTo(srcimg);
+
+ this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
+ this->use_tensorrt_);
+
+ this->normalize_op_.Run(&resize_img, this->mean_det, this->scale_det,
+ this->is_scale_);
+
+ std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
+ this->permute_op_.Run(&resize_img, input.data());
+
+
+ TensorVector* real_in = new TensorVector();
+ if (!real_in) {
+ LOG(ERROR) << "real_in is nullptr,error";
+ return -1;
+ }
+
+ for (int i = 0; i < in->size(); ++i) {
+ input_shape = {1, 3, resize_img.rows, resize_img.cols};
+ in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies());
+ databuf_size = in_num*sizeof(float);
+ databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+ if (!databuf_data) {
+ LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+ return -1;
+ }
+ memcpy(databuf_data,input.data(),databuf_size);
+ databuf_char = reinterpret_cast(databuf_data);
+ paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
+ paddle::PaddleTensor tensor_in;
+ tensor_in.name = in->at(i).name;
+ tensor_in.dtype = paddle::PaddleDType::FLOAT32;
+ tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
+ tensor_in.lod = in->at(i).lod;
+ tensor_in.data = paddleBuf;
+ real_in->push_back(tensor_in);
+ }
+
+ Timer timeline;
+ int64_t start = timeline.TimeStampUS();
+ timeline.Start();
+
+ if (InferManager::instance().infer(
+ engine_name().c_str(), real_in, out, batch_size)) {
+ LOG(ERROR) << "(logid=" << log_id
+ << ") Failed do infer in fluid model: " << engine_name().c_str();
+ return -1;
+ }
+ std::vector output_shape;
+ int out_num =0;
+ void* databuf_data_out = NULL;
+ char* databuf_char_out = NULL;
+ size_t databuf_size_out = 0;
+ //this is special add for PaddleOCR postprecess
+ int infer_outnum = out->size();
+ for (int k = 0;k at(k).shape[2];
+ int n3 = out->at(k).shape[3];
+ int n = n2 * n3;
+
+ float* out_data = static_cast(out->at(k).data.data());
+ std::vector pred(n, 0.0);
+ std::vector cbuf(n, ' ');
+
+ for (int i = 0; i < n; i++) {
+ pred[i] = float(out_data[i]);
+ cbuf[i] = (unsigned char)((out_data[i]) * 255);
+ }
+
+ cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
+ cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
+
+ const double threshold = this->det_db_thresh_ * 255;
+ const double maxvalue = 255;
+ cv::Mat bit_map;
+ cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
+ cv::Mat dilation_map;
+ cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
+ cv::dilate(bit_map, dilation_map, dila_ele);
+ boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
+ this->det_db_box_thresh_,
+ this->det_db_unclip_ratio_);
+
+ boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
+
+ for (int i = boxes.size() - 1; i >= 0; i--) {
+ crop_img = GetRotateCropImage(img, boxes[i]);
+
+ float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
+
+ this->resize_op_rec.Run(crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_);
+
+ this->normalize_op_.Run(&resize_img_rec, this->mean_rec, this->scale_rec,
+ this->is_scale_);
+
+ std::vector output_rec(1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f);
+
+ this->permute_op_.Run(&resize_img_rec, output_rec.data());
+
+ // Inference.
+ output_shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
+ out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies());
+ databuf_size_out = out_num*sizeof(float);
+ databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out);
+ if (!databuf_data_out) {
+ LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
+ return -1;
+ }
+ memcpy(databuf_data_out,output_rec.data(),databuf_size_out);
+ databuf_char_out = reinterpret_cast(databuf_data_out);
+ paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out);
+ paddle::PaddleTensor tensor_out;
+ tensor_out.name = "image";
+ tensor_out.dtype = paddle::PaddleDType::FLOAT32;
+ tensor_out.shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
+ tensor_out.data = paddleBuf;
+ out->push_back(tensor_out);
+ }
+ }
+ out->erase(out->begin(),out->begin()+infer_outnum);
+
+
+ int64_t end = timeline.TimeStampUS();
+ CopyBlobInfo(input_blob, output_blob);
+ AddBlobInfo(output_blob, start);
+ AddBlobInfo(output_blob, end);
+ return 0;
+}
+
+cv::Mat GeneralDetectionOp::Base2Mat(std::string &base64_data)
+{
+ cv::Mat img;
+ std::string s_mat;
+ s_mat = base64Decode(base64_data.data(), base64_data.size());
+ std::vector base64_img(s_mat.begin(), s_mat.end());
+ img = cv::imdecode(base64_img, cv::IMREAD_COLOR);//CV_LOAD_IMAGE_COLOR
+ return img;
+}
+
+std::string GeneralDetectionOp::base64Decode(const char* Data, int DataByte)
+{
+
+ const char DecodeTable[] =
+ {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 62, // '+'
+ 0, 0, 0,
+ 63, // '/'
+ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
+ 0, 0, 0, 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
+ 0, 0, 0, 0, 0, 0,
+ 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
+ 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
+ };
+
+ std::string strDecode;
+ int nValue;
+ int i = 0;
+ while (i < DataByte)
+ {
+ if (*Data != '\r' && *Data != '\n')
+ {
+ nValue = DecodeTable[*Data++] << 18;
+ nValue += DecodeTable[*Data++] << 12;
+ strDecode += (nValue & 0x00FF0000) >> 16;
+ if (*Data != '=')
+ {
+ nValue += DecodeTable[*Data++] << 6;
+ strDecode += (nValue & 0x0000FF00) >> 8;
+ if (*Data != '=')
+ {
+ nValue += DecodeTable[*Data++];
+ strDecode += nValue & 0x000000FF;
+ }
+ }
+ i += 4;
+ }
+ else// 回车换行,跳过
+ {
+ Data++;
+ i++;
+ }
+ }
+ return strDecode;
+}
+
+cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage,
+ std::vector> box) {
+ cv::Mat image;
+ srcimage.copyTo(image);
+ std::vector> points = box;
+
+ int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
+ int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
+ int left = int(*std::min_element(x_collect, x_collect + 4));
+ int right = int(*std::max_element(x_collect, x_collect + 4));
+ int top = int(*std::min_element(y_collect, y_collect + 4));
+ int bottom = int(*std::max_element(y_collect, y_collect + 4));
+
+ cv::Mat img_crop;
+ image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
+
+ for (int i = 0; i < points.size(); i++) {
+ points[i][0] -= left;
+ points[i][1] -= top;
+ }
+
+ int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
+ pow(points[0][1] - points[1][1], 2)));
+ int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
+ pow(points[0][1] - points[3][1], 2)));
+
+ cv::Point2f pts_std[4];
+ pts_std[0] = cv::Point2f(0., 0.);
+ pts_std[1] = cv::Point2f(img_crop_width, 0.);
+ pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
+ pts_std[3] = cv::Point2f(0.f, img_crop_height);
+
+ cv::Point2f pointsf[4];
+ pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
+ pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
+ pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
+ pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
+
+ cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
+
+ cv::Mat dst_img;
+ cv::warpPerspective(img_crop, dst_img, M,
+ cv::Size(img_crop_width, img_crop_height),
+ cv::BORDER_REPLICATE);
+
+ if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
+ cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
+ cv::transpose(dst_img, srcCopy);
+ cv::flip(srcCopy, srcCopy, 0);
+ return srcCopy;
+ } else {
+ return dst_img;
+ }
+}
+
+DEFINE_OP(GeneralDetectionOp);
+
+} // namespace serving
+} // namespace paddle_serving
+} // namespace baidu
\ No newline at end of file
diff --git a/core/general-server/op/general_detection_op.h b/core/general-server/op/general_detection_op.h
new file mode 100755
index 0000000000000000000000000000000000000000..272ed5ff40575d42ac3058ad1824285925fc252c
--- /dev/null
+++ b/core/general-server/op/general_detection_op.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include
+#include
+#include "core/general-server/general_model_service.pb.h"
+#include "core/general-server/op/general_infer_helper.h"
+#include "core/predictor/tools/ocrtools/postprocess_op.h"
+#include "core/predictor/tools/ocrtools/preprocess_op.h"
+#include "paddle_inference_api.h" // NOLINT
+
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"
+
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+class GeneralDetectionOp
+ : public baidu::paddle_serving::predictor::OpWithChannel {
+ public:
+ typedef std::vector TensorVector;
+
+ DECLARE_OP(GeneralDetectionOp);
+
+ int inference();
+
+ private:
+ //config info
+ bool use_gpu_ = false;
+ int gpu_id_ = 0;
+ int gpu_mem_ = 4000;
+ int cpu_math_library_num_threads_ = 4;
+ bool use_mkldnn_ = false;
+ // pre-process
+ PaddleOCR::ResizeImgType0 resize_op_;
+ PaddleOCR::Normalize normalize_op_;
+ PaddleOCR::Permute permute_op_;
+ PaddleOCR::CrnnResizeImg resize_op_rec;
+
+ bool use_tensorrt_ = false;
+ bool use_fp16_ = false;
+ // post-process
+ PaddleOCR::PostProcessor post_processor_;
+
+ //det config info
+ int max_side_len_ = 960;
+
+ double det_db_thresh_ = 0.3;
+ double det_db_box_thresh_ = 0.5;
+ double det_db_unclip_ratio_ = 2.0;
+
+ std::vector mean_det = {0.485f, 0.456f, 0.406f};
+ std::vector scale_det = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
+ bool is_scale_ = true;
+
+ //rec config info
+ std::vector label_list_;
+ std::vector mean_rec = {0.5f, 0.5f, 0.5f};
+ std::vector scale_rec = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
+ cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
+ std::vector> box);
+ cv::Mat Base2Mat(std::string &base64_data);
+ std::string base64Decode(const char* Data, int DataByte);
+ std::vector>> boxes;
+};
+
+} // namespace serving
+} // namespace paddle_serving
+} // namespace baidu
diff --git a/core/predictor/common/utils.h b/core/predictor/common/utils.h
index 052f90b166f04a28d0e7aeb427884921abdcab5e..4437bb94f2535281f83440c3dbae311423edf91c 100644
--- a/core/predictor/common/utils.h
+++ b/core/predictor/common/utils.h
@@ -13,8 +13,10 @@
// limitations under the License.
#pragma once
-#include
+#include
+#include
#include
+#include
#include "core/predictor/common/inner_common.h"
#include "core/predictor/common/macros.h"
@@ -26,6 +28,38 @@ namespace predictor {
namespace butil = base;
#endif
+enum class Precision {
+ kUnk = -1, // unknown type
+ kFloat32 = 0, // fp32
+ kInt8, // int8
+ kHalf, // fp16
+ kBfloat16, // bf16
+};
+
+static std::string PrecisionTypeString(const Precision data_type) {
+ switch (data_type) {
+ case Precision::kFloat32:
+ return "kFloat32";
+ case Precision::kInt8:
+ return "kInt8";
+ case Precision::kHalf:
+ return "kHalf";
+ case Precision::kBfloat16:
+ return "kBloat16";
+ default:
+ return "unUnk";
+ }
+}
+
+static std::string ToLower(const std::string& data) {
+ std::string result = data;
+ std::transform(
+ result.begin(), result.end(), result.begin(), [](unsigned char c) {
+ return tolower(c);
+ });
+ return result;
+}
+
class TimerFlow {
public:
static const int MAX_SIZE = 1024;
diff --git a/doc/ABTEST_IN_PADDLE_SERVING.md b/doc/ABTEST_IN_PADDLE_SERVING.md
index 09d13c12583ddfa0f1767cf46a309cac5ef86867..71cd267f76705583fed0ffbb57fda7a1039cbba6 100644
--- a/doc/ABTEST_IN_PADDLE_SERVING.md
+++ b/doc/ABTEST_IN_PADDLE_SERVING.md
@@ -4,7 +4,7 @@
This document will use an example of text classification task based on IMDB dataset to show how to build a A/B Test framework using Paddle Serving. The structure relationship between the client and servers in the example is shown in the figure below.
-
+
Note that: A/B Test is only applicable to RPC mode, not web mode.
@@ -88,7 +88,7 @@ with open('processed.data') as f:
cnt[tag]['total'] += 1
for tag, data in cnt.items():
- print('[{}](total: {}) acc: {}'.format(tag, data['total'], float(data['acc']) / float(data['total'])))
+ print('[{}] acc: {}'.format(tag, data['total'], float(data['acc']) / float(data['total'])))
```
In the code, the function `client.add_variant(tag, clusters, variant_weight)` is to add a variant with label `tag` and flow weight `variant_weight`. In this example, a BOW variant with label of `bow` and flow weight of `10`, and an LSTM variant with label of `lstm` and a flow weight of `90` are added. The flow on the client side will be distributed to two variants according to the ratio of `10:90`.
@@ -98,8 +98,8 @@ When making prediction on the client side, if the parameter `need_variant_tag=Tr
### Expected Results
Due to different network conditions, the results of each prediction may be slightly different.
``` python
-[lstm](total: 1867) acc: 0.490091055169
-[bow](total: 217) acc: 0.73732718894
+[lstm] acc: 0.490091055169
+[bow] acc: 0.73732718894
```
+
This document uses the original model without any compression algorithm. If there is a need for a quantitative model to go online, please read the [Quantization Storage on Cube Sparse Parameter Indexing](./CUBE_QUANT.md)
-
## Example
in directory python/example/criteo_ctr_with_cube, run
diff --git a/doc/CUBE_LOCAL_CN.md b/doc/CUBE_LOCAL_CN.md
index 9191fe8f54d3e9695d4da04adb82d3c3d33567b2..e1f424a3b6f7ef856b53ff811196688ea33870c1 100644
--- a/doc/CUBE_LOCAL_CN.md
+++ b/doc/CUBE_LOCAL_CN.md
@@ -6,7 +6,7 @@
在python/examples下有两个关于CTR的示例,他们分别是criteo_ctr, criteo_ctr_with_cube。前者是在训练时保存整个模型,包括稀疏参数。后者是将稀疏参数裁剪出来,保存成两个部分,一个是稀疏参数,另一个是稠密参数。由于在工业级的场景中,稀疏参数的规模非常大,达到10^9数量级。因此在一台机器上启动大规模稀疏参数预测是不实际的,因此我们引入百度多年来在稀疏参数索引领域的工业级产品Cube,提供分布式的稀疏参数服务。
-
+
本文档使用的都是未经过任何压缩算法处理的原始模型,如果有量化模型上线需求,请阅读[Cube稀疏参数索引量化存储使用指南](./CUBE_QUANT_CN.md)
diff --git a/doc/DESIGN_DOC.md b/doc/DESIGN_DOC.md
index d0c66a97946f700690097e1cb82d589476735edc..e5a4d5295c8be21087e72a5d761ac7a34de199f3 100644
--- a/doc/DESIGN_DOC.md
+++ b/doc/DESIGN_DOC.md
@@ -70,7 +70,7 @@ The inference framework of the well-known deep learning platform only supports C
> Model conversion across deep learning platforms
-Models trained on other deep learning platforms can be passed《[PaddlePaddle/X2Paddle工具](https://github.com/PaddlePaddle/X2Paddle)》.We convert multiple mainstream CV models to Paddle models. TensorFlow, Caffe, ONNX, PyTorch model conversion is tested.《[An End-to-end Tutorial from Training to Inference Service Deployment](TRAIN_TO_SERVICE.md)》
+Models trained on other deep learning platforms can be passed《[PaddlePaddle/X2Paddle工具](https://github.com/PaddlePaddle/X2Paddle)》.We convert multiple mainstream CV models to Paddle models. TensorFlow, Caffe, ONNX, PyTorch model conversion is tested.《[AIStudio教程-Paddle Serving服务化部署框架](https://www.paddlepaddle.org.cn/tutorials/projectdetail/1555945)》
Because it is impossible to directly view the feed and fetch parameter information in the model file, it is not convenient for users to assemble the parameters. Therefore, Paddle Serving developed a tool to convert the Paddle model into Serving format and generate a prototxt file containing feed and fetch parameter information. The following figure is the generated prototxt file of the uci_housing example. For more conversion methods, refer to the document《[How to save a servable model of Paddle Serving?](SAVE.md)》.
```
diff --git a/doc/DESIGN_DOC_CN.md b/doc/DESIGN_DOC_CN.md
index f13a149e5d3bb758a02a2204d8aa24b5a5d0520a..d0c069145b56164def78603ef0e8c3c89171bd7b 100644
--- a/doc/DESIGN_DOC_CN.md
+++ b/doc/DESIGN_DOC_CN.md
@@ -74,7 +74,7 @@ Paddle Serving提供了4种开发语言SDK,包括Python、C++、Java、Golang
其他深度学习平台训练的模型,可以通过《[PaddlePaddle/X2Paddle工具](https://github.com/PaddlePaddle/X2Paddle)》将多个主流的CV模型转为Paddle模型,测试过TensorFlow、Caffe、ONNX、PyTorch模型转换。
-以IMDB评论情感分析任务为例通过9步展示,Paddle Serving从模型的训练到部署预测服务的全流程《[端到端完成从训练到部署全流程](TRAIN_TO_SERVICE_CN.md)》
+以IMDB评论情感分析任务为例通过9步展示,Paddle Serving从模型的训练到部署预测服务的全流程《[AIStudio教程-Paddle Serving服务化部署框架](https://www.paddlepaddle.org.cn/tutorials/projectdetail/1555945)》
由于无法直接查看模型文件中feed和fetch参数信息,不方便用户拼装参数。因此,Paddle Serving开发一个工具将Paddle模型转成Serving的格式,生成包含feed和fetch参数信息的prototxt文件。下图是uci_housing示例的生成的prototxt文件,更多转换方法参考文档《[怎样保存用于Paddle Serving的模型](SAVE_CN.md)》。
```
diff --git a/doc/FAQ.md b/doc/FAQ.md
index cbcf514b938354589f0d5253cad74b295a5f677c..4a20c25588dc1008b86fc52b298735464f0465d7 100644
--- a/doc/FAQ.md
+++ b/doc/FAQ.md
@@ -14,9 +14,9 @@
0-int64
- 1-float32
+ 1-float32
- 2-int32
+ 2-int32
#### Q: paddle-serving是否支持windows和Linux环境下的多线程调用
@@ -222,9 +222,7 @@ InvalidArgumentError: Device id must be less than GPU count, but received id is:
#### Q: python编译的GCC版本与serving的版本不匹配
-**A:**:1)使用[GPU docker](https://github.com/PaddlePaddle/Serving/blob/develop/doc/RUN_IN_DOCKER.md#gpunvidia-docker)解决环境问题
-
- 2)修改anaconda的虚拟环境下安装的python的gcc版本[参考](https://www.jianshu.com/p/c498b3d86f77)
+**A:**:1)使用[GPU docker](https://github.com/PaddlePaddle/Serving/blob/develop/doc/RUN_IN_DOCKER.md#gpunvidia-docker)解决环境问题;2)修改anaconda的虚拟环境下安装的python的gcc版本[改变python的GCC编译环境](https://www.jianshu.com/p/c498b3d86f77)
#### Q: paddle-serving是否支持本地离线安装
diff --git a/doc/LATEST_PACKAGES.md b/doc/LATEST_PACKAGES.md
index d034fafa6e7cc588511beccff11d4beeaa5ba72e..0adb9f89ee5a381038388f6b16cb95e99b7fd328 100644
--- a/doc/LATEST_PACKAGES.md
+++ b/doc/LATEST_PACKAGES.md
@@ -78,7 +78,7 @@ https://paddle-serving.bj.bcebos.com/whl/paddle_serving_app-0.0.0-py2-none-any.w
```
## ARM user
-for ARM user who uses [PaddleLite](https://github.com/PaddlePaddle/PaddleLite) can download the wheel packages as follows. And ARM user should use the xpu-beta docker [DOCKER IMAGES](./DOCKER_IMAGES.md)
+for ARM user who uses [Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) can download the wheel packages as follows. And ARM user should use the xpu-beta docker [DOCKER IMAGES](./DOCKER_IMAGES.md)
**We only support Python 3.6 for Arm Users.**
### Wheel Package Links
diff --git a/doc/SERVER_DAG.md b/doc/SERVER_DAG.md
index dbf277ccbccc2a06838d65bfbf75e514b4d9a1ed..8441e7d7fa8ea00d2d7accedb3f7cd2baab745ad 100644
--- a/doc/SERVER_DAG.md
+++ b/doc/SERVER_DAG.md
@@ -48,7 +48,7 @@ python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --po
### Nodes with multiple inputs
-An example containing multiple input nodes is given in the [MODEL_ENSEMBLE_IN_PADDLE_SERVING](MODEL_ENSEMBLE_IN_PADDLE_SERVING.md). A example graph and the corresponding DAG definition code is as follows.
+An example containing multiple input nodes is given in the [MODEL_ENSEMBLE_IN_PADDLE_SERVING](./deprecated/MODEL_ENSEMBLE_IN_PADDLE_SERVING.md). A example graph and the corresponding DAG definition code is as follows.
diff --git a/doc/SERVER_DAG_CN.md b/doc/SERVER_DAG_CN.md
index 80d01f0287c5f721f093e96c7bcd1827f0601496..16e53bc6a98af6abd4a114137f2e72593242afcd 100644
--- a/doc/SERVER_DAG_CN.md
+++ b/doc/SERVER_DAG_CN.md
@@ -47,7 +47,7 @@ python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --po
### 包含多个输入的节点
-在[Paddle Serving中的集成预测](MODEL_ENSEMBLE_IN_PADDLE_SERVING_CN.md)文档中给出了一个包含多个输入节点的样例,示意图和代码如下。
+在[Paddle Serving中的集成预测](./deprecated/MODEL_ENSEMBLE_IN_PADDLE_SERVING_CN.md)文档中给出了一个包含多个输入节点的样例,示意图和代码如下。
diff --git a/doc/TENSOR_RT.md b/doc/TENSOR_RT.md
index 6e53a6ff029df6a46080d656a6dc9db95a9633e3..7504646fea750572cde472ebfb6178989b542ec1 100644
--- a/doc/TENSOR_RT.md
+++ b/doc/TENSOR_RT.md
@@ -1,6 +1,6 @@
## Paddle Serving uses TensorRT
-(English|[简体中文]((./TENSOR_RT_CN.md)))
+(English|[简体中文](./TENSOR_RT_CN.md))
### Background
diff --git a/doc/WINDOWS_TUTORIAL_CN.md b/doc/WINDOWS_TUTORIAL_CN.md
index 4184840f4e5646fcd998dfa33b80b8b9210b05d7..143d3b22ff0d2a6c9b35542ac301fd2a906a0962 100644
--- a/doc/WINDOWS_TUTORIAL_CN.md
+++ b/doc/WINDOWS_TUTORIAL_CN.md
@@ -14,7 +14,7 @@
**安装Git工具**: 详情参见[Git官网](https://git-scm.com/downloads)
-**安装必要的C++库(可选)**:部分用户可能会在`import paddle`阶段遇见dll无法链接的问题,建议可以[安装Visual Studio社区版本](`https://visualstudio.microsoft.com/`) ,并且安装C++的相关组件。
+**安装必要的C++库(可选)**:部分用户可能会在`import paddle`阶段遇见dll无法链接的问题,建议[安装Visual Studio社区版本](https://visualstudio.microsoft.com/) ,并且安装C++的相关组件。
**安装Paddle和Serving**:在Powershell,执行
diff --git a/doc/deprecated/DESIGN.md b/doc/deprecated/DESIGN.md
index d14bb0569b5b9e236367d1f4eb61d1774c100511..a6f17c044549908d739e52354d4e3aba92fa0c19 100644
--- a/doc/deprecated/DESIGN.md
+++ b/doc/deprecated/DESIGN.md
@@ -115,7 +115,7 @@ Server instance perspective

-Paddle Serving instances can load multiple models at the same time, and each model uses a Service (and its configured workflow) to undertake services. You can refer to [service configuration file in Demo example](../tools/cpp_examples/demo-serving/conf/service.prototxt) to learn how to configure multiple services for the serving instance
+Paddle Serving instances can load multiple models at the same time, and each model uses a Service (and its configured workflow) to undertake services. You can refer to [service configuration file in Demo example](../../tools/cpp_examples/demo-serving/conf/service.prototxt) to learn how to configure multiple services for the serving instance
#### 4.2.3 Hierarchical relationship of business scheduling
@@ -124,7 +124,7 @@ From the client's perspective, a Paddle Serving service can be divided into thre

One Service corresponds to one inference model, and there is one endpoint under the model. Different versions of the model are implemented through multiple variant concepts under endpoint:
-The same model prediction service can configure multiple variants, and each variant has its own downstream IP list. The client code can configure relative weights for each variant to achieve the relationship of adjusting the traffic ratio (refer to the description of variant_weight_list in [Client Configuration](./deprecated/CLIENT_CONFIGURE.md) section 3.2).
+The same model prediction service can configure multiple variants, and each variant has its own downstream IP list. The client code can configure relative weights for each variant to achieve the relationship of adjusting the traffic ratio (refer to the description of variant_weight_list in [Client Configuration](../CLIENT_CONFIGURE.md) section 3.2).

@@ -141,7 +141,7 @@ No matter how the communication protocol changes, the framework only needs to en
### 5.1 Data Compression Method
-Baidu-rpc has built-in data compression methods such as snappy, gzip, zlib, which can be configured in the configuration file (refer to [Client Configuration](./deprecated/CLIENT_CONFIGURE.md) Section 3.1 for an introduction to compress_type)
+Baidu-rpc has built-in data compression methods such as snappy, gzip, zlib, which can be configured in the configuration file (refer to [Client Configuration](../CLIENT_CONFIGURE.md) Section 3.1 for an introduction to compress_type)
### 5.2 C ++ SDK API Interface
diff --git a/paddle_inference/paddle/include/paddle_engine.h b/paddle_inference/paddle/include/paddle_engine.h
index 599d5e5e5477da72927f76c0189a82721db3c6b4..a2be5257aeedb984a9d30b9946c707fcf3ff824d 100644
--- a/paddle_inference/paddle/include/paddle_engine.h
+++ b/paddle_inference/paddle/include/paddle_engine.h
@@ -37,9 +37,24 @@ using paddle_infer::Tensor;
using paddle_infer::CreatePredictor;
DECLARE_int32(gpuid);
+DECLARE_string(precision);
+DECLARE_bool(use_calib);
static const int max_batch = 32;
static const int min_subgraph_size = 3;
+static PrecisionType precision_type;
+
+PrecisionType GetPrecision(const std::string& precision_data) {
+ std::string precision_type = predictor::ToLower(precision_data);
+ if (precision_type == "fp32") {
+ return PrecisionType::kFloat32;
+ } else if (precision_type == "int8") {
+ return PrecisionType::kInt8;
+ } else if (precision_type == "fp16") {
+ return PrecisionType::kHalf;
+ }
+ return PrecisionType::kFloat32;
+}
// Engine Base
class PaddleEngineBase {
@@ -107,9 +122,9 @@ class PaddleInferenceEngine : public PaddleEngineBase {
if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
// decrypt model
std::string model_buffer, params_buffer, key_buffer;
- predictor::ReadBinaryFile(model_path + "encrypt_model", &model_buffer);
- predictor::ReadBinaryFile(model_path + "encrypt_params", ¶ms_buffer);
- predictor::ReadBinaryFile(model_path + "key", &key_buffer);
+ predictor::ReadBinaryFile(model_path + "/encrypt_model", &model_buffer);
+ predictor::ReadBinaryFile(model_path + "/encrypt_params", ¶ms_buffer);
+ predictor::ReadBinaryFile(model_path + "/key", &key_buffer);
auto cipher = paddle::MakeCipher("");
std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
@@ -137,6 +152,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
// 2000MB GPU memory
config.EnableUseGpu(2000, FLAGS_gpuid);
}
+ precision_type = GetPrecision(FLAGS_precision);
if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
@@ -145,14 +161,24 @@ class PaddleInferenceEngine : public PaddleEngineBase {
config.EnableTensorRtEngine(1 << 20,
max_batch,
min_subgraph_size,
- Config::Precision::kFloat32,
+ precision_type,
false,
- false);
+ FLAGS_use_calib);
LOG(INFO) << "create TensorRT predictor";
}
if (engine_conf.has_use_lite() && engine_conf.use_lite()) {
- config.EnableLiteEngine(PrecisionType::kFloat32, true);
+ config.EnableLiteEngine(precision_type, true);
+ }
+
+ if ((!engine_conf.has_use_lite() && !engine_conf.has_use_gpu()) ||
+ (engine_conf.has_use_lite() && !engine_conf.use_lite() &&
+ engine_conf.has_use_gpu() && !engine_conf.use_gpu())) {
+ if (precision_type == PrecisionType::kInt8) {
+ config.EnableMkldnnQuantizer();
+ } else if (precision_type == PrecisionType::kHalf) {
+ config.EnableMkldnnBfloat16();
+ }
}
if (engine_conf.has_use_xpu() && engine_conf.use_xpu()) {
@@ -171,7 +197,6 @@ class PaddleInferenceEngine : public PaddleEngineBase {
config.EnableMemoryOptim();
}
-
predictor::AutoLock lock(predictor::GlobalCreateMutex::instance());
_predictor = CreatePredictor(config);
if (NULL == _predictor.get()) {
diff --git a/paddle_inference/paddle/src/paddle_engine.cpp b/paddle_inference/paddle/src/paddle_engine.cpp
index 94ed4b9ae92df3c8f407590f9c24f351bf7ec6a3..b6da2a5a0eeb31473e2eba5b1a5b58855dbb03c6 100644
--- a/paddle_inference/paddle/src/paddle_engine.cpp
+++ b/paddle_inference/paddle/src/paddle_engine.cpp
@@ -20,6 +20,8 @@ namespace paddle_serving {
namespace inference {
DEFINE_int32(gpuid, 0, "GPU device id to use");
+DEFINE_string(precision, "fp32", "precision to deploy, default is fp32");
+DEFINE_bool(use_calib, false, "calibration mode, default is false");
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine,
diff --git a/python/examples/encryption/README.md b/python/examples/encryption/README.md
index d8a04e29bf56439a24db7dadfdfe3ab5d9626e14..a08b8b84241fb699992d1a718f2bfbf986d8d180 100644
--- a/python/examples/encryption/README.md
+++ b/python/examples/encryption/README.md
@@ -31,6 +31,7 @@ dirname is the folder path where the model is located. If the parameter is discr
The key is stored in the `key` file, and the encrypted model file and server-side configuration file are stored in the `encrypt_server` directory.
client-side configuration file are stored in the `encrypt_client` directory.
+**Notice:** When encryption prediction is used, the model configuration and parameter folder loaded by server and client should be encrypt_server/ and encrypt_client/
## Start Encryption Service
CPU Service
```
@@ -43,5 +44,5 @@ python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_
## Prediction
```
-python test_client.py uci_housing_client/serving_client_conf.prototxt
+python test_client.py encrypt_client/serving_client_conf.prototxt
```
diff --git a/python/examples/encryption/README_CN.md b/python/examples/encryption/README_CN.md
index bb853ff37f914a5e2cfe1c6bbb097d17eb99a29a..f950796ec14dadfd7bf6744d94aba4959c838e7f 100644
--- a/python/examples/encryption/README_CN.md
+++ b/python/examples/encryption/README_CN.md
@@ -31,6 +31,8 @@ def serving_encryption():
密钥保存在`key`文件中,加密模型文件以及server端配置文件保存在`encrypt_server`目录下,client端配置文件保存在`encrypt_client`目录下。
+**注意:** 当使用加密预测时,服务端和客户端启动加载的模型配置和参数文件夹是encrypt_server/和encrypt_client/
+
## 启动加密预测服务
CPU预测服务
```
@@ -43,5 +45,5 @@ python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_
## 预测
```
-python test_client.py uci_housing_client/serving_client_conf.prototxt
+python test_client.py encrypt_client/
```
diff --git a/python/examples/ocr/README.md b/python/examples/ocr/README.md
index ba28075fffe62498f35834f10a9db3f20a445e29..605b4abe7c12cf2a9b0d8d0d02a2fe9c04b76723 100755
--- a/python/examples/ocr/README.md
+++ b/python/examples/ocr/README.md
@@ -15,31 +15,6 @@ tar -xzvf ocr_det.tar.gz
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.tar
tar xf test_imgs.tar
```
-## C++ OCR Service
-
-### Start Service
-Select a startup mode according to CPU / GPU device
-
-After the -- model parameter, the folder path of multiple model files is passed in to start the prediction service of multiple model concatenation.
-```
-#for cpu user
-python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
-#for gpu user
-python -m paddle_serving_server_gpu.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
-```
-
-### Client Prediction
-The pre-processing and post-processing is in the C + + server part, the image's Base64 encoded string is passed into the C + + server.
-
-so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed.
-
-for this case, `feed_type` should be 3(which means the data type is string),`shape` should be 1.
-
-By passing in multiple client folder paths, the client can be started for multi model prediction.
-```
-python ocr_c_client_bytes.py ocr_det_client ocr_rec_client
-```
-
## Web Service
@@ -123,3 +98,30 @@ python rec_debugger_server.py gpu #for gpu user
```
python rec_web_client.py
```
+
+## C++ OCR Service
+
+**Notice:** If you need to concatenate det model and rec model, and do pre-processing and post-processing in Paddle Serving C++ framework, you need to use the C++ server compiled with WITH_OPENCV option,see the [COMPILE.md](../../../doc/COMPILE.md)
+
+### Start Service
+Select a startup mode according to CPU / GPU device
+
+After the -- model parameter, the folder path of multiple model files is passed in to start the prediction service of multiple model concatenation.
+```
+#for cpu user
+python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
+#for gpu user
+python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
+```
+
+### Client Prediction
+The pre-processing and post-processing is in the C + + server part, the image's Base64 encoded string is passed into the C + + server.
+
+so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed.
+
+for this case, `feed_type` should be 3(which means the data type is string),`shape` should be 1.
+
+By passing in multiple client folder paths, the client can be started for multi model prediction.
+```
+python ocr_cpp_client.py ocr_det_client ocr_rec_client
+```
diff --git a/python/examples/ocr/README_CN.md b/python/examples/ocr/README_CN.md
index 895d26010059c3b3e03e4798271a3dbd1ec4c924..ad7ddcee21dd2f514d2ab8f63a732ee93349abac 100755
--- a/python/examples/ocr/README_CN.md
+++ b/python/examples/ocr/README_CN.md
@@ -15,31 +15,6 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/ocr/test_imgs.t
tar xf test_imgs.tar
```
-## C++ OCR Service服务
-
-### 启动服务
-根据CPU/GPU设备选择一种启动方式
-
-通过--model后,指定多个模型文件的文件夹路径来启动多模型串联的预测服务。
-```
-#for cpu user
-python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
-#for gpu user
-python -m paddle_serving_server_gpu.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
-```
-
-### 启动客户端
-由于需要在C++Server部分进行前后处理,传入C++Server的仅仅是图片的base64编码的字符串,故第一个模型的Client配置需要修改
-
-即`ocr_det_client/serving_client_conf.prototxt`中`feed_var`字段
-
-对于本示例而言,`feed_type`应修改为3(数据类型为string),`shape`为1.
-
-通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。
-```
-python ocr_c_client_bytes.py ocr_det_client ocr_rec_client
-```
-
## Web Service服务
### 启动服务
@@ -123,3 +98,29 @@ python rec_debugger_server.py gpu #for gpu user
```
python rec_web_client.py
```
+## C++ OCR Service服务
+
+**注意:** 若您需要使用Paddle Serving C++框架串联det模型和rec模型,并进行前后处理,您需要使用开启WITH_OPENCV选项编译的C++ Server,详见[COMPILE.md](../../../doc/COMPILE.md)
+
+### 启动服务
+根据CPU/GPU设备选择一种启动方式
+
+通过--model后,指定多个模型文件的文件夹路径来启动多模型串联的预测服务。
+```
+#for cpu user
+python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
+#for gpu user
+python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0
+```
+
+### 启动客户端
+由于需要在C++Server部分进行前后处理,传入C++Server的仅仅是图片的base64编码的字符串,故第一个模型的Client配置需要修改
+
+即`ocr_det_client/serving_client_conf.prototxt`中`feed_var`字段
+
+对于本示例而言,`feed_type`应修改为3(数据类型为string),`shape`为1.
+
+通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。
+```
+python ocr_cpp_client.py ocr_det_client ocr_rec_client
+```
diff --git a/python/examples/ocr/ocr_cpp_client.py b/python/examples/ocr/ocr_cpp_client.py
new file mode 100755
index 0000000000000000000000000000000000000000..fa9209aabc4a7e03fe9c69ac85cd496065b1ffc2
--- /dev/null
+++ b/python/examples/ocr/ocr_cpp_client.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+from paddle_serving_client import Client
+import sys
+import numpy as np
+import base64
+import os
+import cv2
+from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
+from paddle_serving_app.reader import Div, Normalize, Transpose
+
+client = Client()
+# TODO:load_client need to load more than one client model.
+# this need to figure out some details.
+client.load_client_config(sys.argv[1:])
+client.connect(["127.0.0.1:9293"])
+
+import paddle
+test_img_dir = "imgs/"
+
+def cv2_to_base64(image):
+ return base64.b64encode(image) #data.tostring()).decode('utf8')
+
+for img_file in os.listdir(test_img_dir):
+ with open(os.path.join(test_img_dir, img_file), 'rb') as file:
+ image_data = file.read()
+ image = cv2_to_base64(image_data)
+ fetch_map = client.predict(
+ feed={"image": image}, fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"], batch=True)
+ #print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
+ print(fetch_map)
diff --git a/python/paddle_serving_app/local_predict.py b/python/paddle_serving_app/local_predict.py
index e50d353261d25c2e2cfa3ed80a55c43e96eaddb7..c31c818eee26837f396ae22a3521a7e14f7320c9 100644
--- a/python/paddle_serving_app/local_predict.py
+++ b/python/paddle_serving_app/local_predict.py
@@ -27,6 +27,12 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LocalPredictor")
logger.setLevel(logging.INFO)
+precision_map = {
+ 'int8': paddle_infer.PrecisionType.Int8,
+ 'fp32': paddle_infer.PrecisionType.Float32,
+ 'fp16': paddle_infer.PrecisionType.Half,
+}
+
class LocalPredictor(object):
"""
@@ -56,6 +62,8 @@ class LocalPredictor(object):
use_trt=False,
use_lite=False,
use_xpu=False,
+ precision="fp32",
+ use_calib=False,
use_feed_fetch_ops=False):
"""
Load model configs and create the paddle predictor by Paddle Inference API.
@@ -71,6 +79,8 @@ class LocalPredictor(object):
use_trt: use nvidia TensorRT optimization, False default
use_lite: use Paddle-Lite Engint, False default
use_xpu: run predict on Baidu Kunlun, False default
+ precision: precision mode, "fp32" default
+ use_calib: use TensorRT calibration, False default
use_feed_fetch_ops: use feed/fetch ops, False default.
"""
client_config = "{}/serving_server_conf.prototxt".format(model_path)
@@ -88,9 +98,11 @@ class LocalPredictor(object):
logger.info(
"LocalPredictor load_model_config params: model_path:{}, use_gpu:{},\
gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\
- use_trt:{}, use_lite:{}, use_xpu: {}, use_feed_fetch_ops:{}".format(
- model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
- ir_optim, use_trt, use_lite, use_xpu, use_feed_fetch_ops))
+ use_trt:{}, use_lite:{}, use_xpu: {}, precision: {}, use_calib: {},\
+ use_feed_fetch_ops:{}"
+ .format(model_path, use_gpu, gpu_id, use_profile, thread_num,
+ mem_optim, ir_optim, use_trt, use_lite, use_xpu, precision,
+ use_calib, use_feed_fetch_ops))
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
@@ -106,6 +118,9 @@ class LocalPredictor(object):
self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type
+ precision_type = paddle_infer.PrecisionType.Float32
+ if precision.lower() in precision_map:
+ precision_type = precision_map[precision.lower()]
if use_profile:
config.enable_profile()
if mem_optim:
@@ -121,6 +136,7 @@ class LocalPredictor(object):
config.enable_use_gpu(100, gpu_id)
if use_trt:
config.enable_tensorrt_engine(
+ precision_mode=precision_type,
workspace_size=1 << 20,
max_batch_size=32,
min_subgraph_size=3,
@@ -129,7 +145,7 @@ class LocalPredictor(object):
if use_lite:
config.enable_lite_engine(
- precision_mode=paddle_infer.PrecisionType.Float32,
+ precision_mode=precision_type,
zero_copy=True,
passes_filter=[],
ops_filter=[])
@@ -138,6 +154,11 @@ class LocalPredictor(object):
# 2MB l3 cache
config.enable_xpu(8 * 1024 * 1024)
+ if not use_gpu and not use_lite:
+ if precision_type == paddle_infer.PrecisionType.Int8:
+ config.enable_quantizer()
+ if precision.lower() == "bf16":
+ config.enable_mkldnn_bfloat16()
self.predictor = paddle_infer.create_predictor(config)
def predict(self, feed=None, fetch=None, batch=False, log_id=0):
diff --git a/python/paddle_serving_server/serve.py b/python/paddle_serving_server/serve.py
index 02d7cec8dc65fc165b51396a624f5cb70a269aeb..79534e9a8cbbb2cc784cdbdfe62690ccceaf7228 100755
--- a/python/paddle_serving_server/serve.py
+++ b/python/paddle_serving_server/serve.py
@@ -51,6 +51,16 @@ def serve_args():
"--name", type=str, default="None", help="Default service name")
parser.add_argument(
"--use_mkl", default=False, action="store_true", help="Use MKL")
+ parser.add_argument(
+ "--precision",
+ type=str,
+ default="fp32",
+ help="precision mode(fp32, int8, fp16, bf16)")
+ parser.add_argument(
+ "--use_calib",
+ default=False,
+ action="store_true",
+ help="Use TensorRT Calibration")
parser.add_argument(
"--mem_optim_off",
default=False,
@@ -109,7 +119,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
if model == "":
print("You must specify your serving model")
exit(-1)
-
+
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
@@ -125,17 +135,16 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
for idx, single_model in enumerate(model):
infer_op_name = "general_infer"
- if len(model) == 2 and idx == 0:
+ #Temporary support for OCR model,it will be completely revised later
+ #If you want to use this, C++ server must compile with WITH_OPENCV option.
+ if len(model) == 2 and idx == 0 and model[0] == 'ocr_det_model':
infer_op_name = "general_detection"
- else:
- infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
-
+
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
-
server = None
if use_multilang:
server = serving.MultiLangServer()
@@ -148,6 +157,8 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size)
server.set_port(port)
+ server.set_precision(args.precision)
+ server.set_use_calib(args.use_calib)
server.use_encryption_model(use_encryption_model)
if args.product_name != None:
server.set_product_name(args.product_name)
@@ -199,7 +210,7 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
-
+
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
@@ -210,6 +221,8 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
server.use_mkl(use_mkl)
+ server.set_precision(args.precision)
+ server.set_use_calib(args.use_calib)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.set_max_body_size(max_body_size)
@@ -297,7 +310,8 @@ class MainService(BaseHTTPRequestHandler):
key = base64.b64decode(post_data["key"].encode())
for single_model_config in args.model:
if os.path.isfile(single_model_config):
- raise ValueError("The input of --model should be a dir not file.")
+ raise ValueError(
+ "The input of --model should be a dir not file.")
with open(single_model_config + "/key", "wb") as f:
f.write(key)
return True
@@ -309,7 +323,8 @@ class MainService(BaseHTTPRequestHandler):
key = base64.b64decode(post_data["key"].encode())
for single_model_config in args.model:
if os.path.isfile(single_model_config):
- raise ValueError("The input of --model should be a dir not file.")
+ raise ValueError(
+ "The input of --model should be a dir not file.")
with open(single_model_config + "/key", "rb") as f:
cur_key = f.read()
if key != cur_key:
@@ -394,7 +409,9 @@ if __name__ == "__main__":
device=args.device,
use_lite=args.use_lite,
use_xpu=args.use_xpu,
- ir_optim=args.ir_optim)
+ ir_optim=args.ir_optim,
+ precision=args.precision,
+ use_calib=args.use_calib)
web_service.run_rpc_service()
app_instance = Flask(__name__)
diff --git a/python/paddle_serving_server/server.py b/python/paddle_serving_server/server.py
index d96253b592f70956591c345606eeb0d01e1e4b43..c08ef838cda60ff147fae8ba1c3470c5f5b8f4d1 100755
--- a/python/paddle_serving_server/server.py
+++ b/python/paddle_serving_server/server.py
@@ -42,24 +42,37 @@ from concurrent import futures
class Server(object):
def __init__(self):
+ """
+ self.model_toolkit_conf:'list'=[] # The quantity of self.model_toolkit_conf is equal to the InferOp quantity/Engine--OP
+ self.model_conf:'collections.OrderedDict()' # Save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
+ self.workflow_fn:'str'="workflow.prototxt" # Only one for one Service/Workflow
+ self.resource_fn:'str'="resource.prototxt" # Only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
+ self.infer_service_fn:'str'="infer_service.prototxt" # Only one for one Service,Service--Workflow
+ self.model_toolkit_fn:'list'=[] # ["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
+ self.general_model_config_fn:'list'=[] # ["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
+ self.subdirectory:'list'=[] # The quantity is equal to the InferOp quantity, and name = node.name = engine.name
+ self.model_config_paths:'collections.OrderedDict()' # Save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
+ """
self.server_handle_ = None
self.infer_service_conf = None
- self.model_toolkit_conf = []#The quantity is equal to the InferOp quantity,Engine--OP
+ self.model_toolkit_conf = []
self.resource_conf = None
self.memory_optimization = False
self.ir_optimization = False
- self.model_conf = collections.OrderedDict()# save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
- self.workflow_fn = "workflow.prototxt"#only one for one Service,Workflow--Op
- self.resource_fn = "resource.prototxt"#only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
- self.infer_service_fn = "infer_service.prototxt"#only one for one Service,Service--Workflow
- self.model_toolkit_fn = []#["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
- self.general_model_config_fn = []#["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
- self.subdirectory = []#The quantity is equal to the InferOp quantity, and name = node.name = engine.name
+ self.model_conf = collections.OrderedDict()
+ self.workflow_fn = "workflow.prototxt"
+ self.resource_fn = "resource.prototxt"
+ self.infer_service_fn = "infer_service.prototxt"
+ self.model_toolkit_fn = []
+ self.general_model_config_fn = []
+ self.subdirectory = []
self.cube_config_fn = "cube.conf"
self.workdir = ""
self.max_concurrency = 0
self.num_threads = 2
self.port = 8080
+ self.precision = "fp32"
+ self.use_calib = False
self.reload_interval_s = 10
self.max_body_size = 64 * 1024 * 1024
self.module_path = os.path.dirname(paddle_serving_server.__file__)
@@ -71,12 +84,15 @@ class Server(object):
self.use_trt = False
self.use_lite = False
self.use_xpu = False
- self.model_config_paths = collections.OrderedDict() # save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
+ self.model_config_paths = collections.OrderedDict()
self.product_name = None
self.container_id = None
- def get_fetch_list(self,infer_node_idx = -1 ):
- fetch_names = [var.alias_name for var in list(self.model_conf.values())[infer_node_idx].fetch_var]
+ def get_fetch_list(self, infer_node_idx=-1):
+ fetch_names = [
+ var.alias_name
+ for var in list(self.model_conf.values())[infer_node_idx].fetch_var
+ ]
return fetch_names
def set_max_concurrency(self, concurrency):
@@ -99,6 +115,12 @@ class Server(object):
def set_port(self, port):
self.port = port
+ def set_precision(self, precision="fp32"):
+ self.precision = precision
+
+ def set_use_calib(self, use_calib=False):
+ self.use_calib = use_calib
+
def set_reload_interval(self, interval):
self.reload_interval_s = interval
@@ -172,6 +194,10 @@ class Server(object):
engine.use_trt = self.use_trt
engine.use_lite = self.use_lite
engine.use_xpu = self.use_xpu
+ engine.use_gpu = False
+ if self.device == "gpu":
+ engine.use_gpu = True
+
if os.path.exists('{}/__params__'.format(model_config_path)):
engine.combined_model = True
else:
@@ -195,9 +221,10 @@ class Server(object):
self.workdir = workdir
if self.resource_conf == None:
self.resource_conf = server_sdk.ResourceConf()
- for idx, op_general_model_config_fn in enumerate(self.general_model_config_fn):
+ for idx, op_general_model_config_fn in enumerate(
+ self.general_model_config_fn):
with open("{}/{}".format(workdir, op_general_model_config_fn),
- "w") as fout:
+ "w") as fout:
fout.write(str(list(self.model_conf.values())[idx]))
for workflow in self.workflow_conf.workflows:
for node in workflow.nodes:
@@ -212,9 +239,11 @@ class Server(object):
if "quant" in node.name:
self.resource_conf.cube_quant_bits = 8
self.resource_conf.model_toolkit_path.extend([workdir])
- self.resource_conf.model_toolkit_file.extend([self.model_toolkit_fn[idx]])
+ self.resource_conf.model_toolkit_file.extend(
+ [self.model_toolkit_fn[idx]])
self.resource_conf.general_model_path.extend([workdir])
- self.resource_conf.general_model_file.extend([op_general_model_config_fn])
+ self.resource_conf.general_model_file.extend(
+ [op_general_model_config_fn])
#TODO:figure out the meaning of product_name and container_id.
if self.product_name != None:
self.resource_conf.auth_product_name = self.product_name
@@ -237,15 +266,18 @@ class Server(object):
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
- raise ValueError("The input of --model should be a dir not file.")
-
+ raise ValueError(
+ "The input of --model should be a dir not file.")
+
if isinstance(model_config_paths_args, list):
# If there is only one model path, use the default infer_op.
# Because there are several infer_op type, we need to find
# it from workflow_conf.
default_engine_types = [
- 'GeneralInferOp', 'GeneralDistKVInferOp',
- 'GeneralDistKVQuantInferOp','GeneralDetectionOp',
+ 'GeneralInferOp',
+ 'GeneralDistKVInferOp',
+ 'GeneralDistKVQuantInferOp',
+ 'GeneralDetectionOp',
]
# now only support single-workflow.
# TODO:support multi-workflow
@@ -256,16 +288,24 @@ class Server(object):
raise Exception(
"You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
)
-
+
f = open("{}/serving_server_conf.prototxt".format(
- model_config_paths_args[model_config_paths_list_idx]), 'r')
- self.model_conf[node.name] = google.protobuf.text_format.Merge(str(f.read()), m_config.GeneralModelConfig())
- self.model_config_paths[node.name] = model_config_paths_args[model_config_paths_list_idx]
- self.general_model_config_fn.append(node.name+"/general_model.prototxt")
- self.model_toolkit_fn.append(node.name+"/model_toolkit.prototxt")
+ model_config_paths_args[model_config_paths_list_idx]),
+ 'r')
+ self.model_conf[
+ node.name] = google.protobuf.text_format.Merge(
+ str(f.read()), m_config.GeneralModelConfig())
+ self.model_config_paths[
+ node.name] = model_config_paths_args[
+ model_config_paths_list_idx]
+ self.general_model_config_fn.append(
+ node.name + "/general_model.prototxt")
+ self.model_toolkit_fn.append(node.name +
+ "/model_toolkit.prototxt")
self.subdirectory.append(node.name)
model_config_paths_list_idx += 1
- if model_config_paths_list_idx == len(model_config_paths_args):
+ if model_config_paths_list_idx == len(
+ model_config_paths_args):
break
#Right now, this is not useful.
elif isinstance(model_config_paths_args, dict):
@@ -278,11 +318,12 @@ class Server(object):
"that the input and output of multiple models are the same.")
f = open("{}/serving_server_conf.prototxt".format(path), 'r')
self.model_conf[node.name] = google.protobuf.text_format.Merge(
- str(f.read()), m_config.GeneralModelConfig())
+ str(f.read()), m_config.GeneralModelConfig())
else:
- raise Exception("The type of model_config_paths must be str or list or "
- "dict({op: model_path}), not {}.".format(
- type(model_config_paths_args)))
+ raise Exception(
+ "The type of model_config_paths must be str or list or "
+ "dict({op: model_path}), not {}.".format(
+ type(model_config_paths_args)))
# check config here
# print config here
@@ -409,7 +450,7 @@ class Server(object):
resource_fn = "{}/{}".format(workdir, self.resource_fn)
self._write_pb_str(resource_fn, self.resource_conf)
- for idx,single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
+ for idx, single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])
@@ -443,6 +484,8 @@ class Server(object):
"-max_concurrency {} " \
"-num_threads {} " \
"-port {} " \
+ "-precision {} " \
+ "-use_calib {} " \
"-reload_interval_s {} " \
"-resource_path {} " \
"-resource_file {} " \
@@ -456,6 +499,8 @@ class Server(object):
self.max_concurrency,
self.num_threads,
self.port,
+ self.precision,
+ self.use_calib,
self.reload_interval_s,
self.workdir,
self.resource_fn,
@@ -471,6 +516,8 @@ class Server(object):
"-max_concurrency {} " \
"-num_threads {} " \
"-port {} " \
+ "-precision {} " \
+ "-use_calib {} " \
"-reload_interval_s {} " \
"-resource_path {} " \
"-resource_file {} " \
@@ -485,6 +532,8 @@ class Server(object):
self.max_concurrency,
self.num_threads,
self.port,
+ self.precision,
+ self.use_calib,
self.reload_interval_s,
self.workdir,
self.resource_fn,
@@ -498,6 +547,7 @@ class Server(object):
os.system(command)
+
class MultiLangServer(object):
def __init__(self):
self.bserver_ = Server()
@@ -532,6 +582,12 @@ class MultiLangServer(object):
def set_port(self, port):
self.gport_ = port
+ def set_precision(self, precision="fp32"):
+ self.precision = precision
+
+ def set_use_calib(self, use_calib=False):
+ self.use_calib = use_calib
+
def set_reload_interval(self, interval):
self.bserver_.set_reload_interval(interval)
@@ -553,22 +609,23 @@ class MultiLangServer(object):
def set_gpuid(self, gpuid=0):
self.bserver_.set_gpuid(gpuid)
- def load_model_config(self, server_config_dir_paths, client_config_path=None):
+ def load_model_config(self,
+ server_config_dir_paths,
+ client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
pass
else:
raise Exception("The type of model_config_paths must be str or list"
- ", not {}.".format(
- type(server_config_dir_paths)))
-
+ ", not {}.".format(type(server_config_dir_paths)))
for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
- raise ValueError("The input of --model should be a dir not file.")
+ raise ValueError(
+ "The input of --model should be a dir not file.")
self.bserver_.load_model_config(server_config_dir_paths)
if client_config_path is None:
@@ -576,27 +633,30 @@ class MultiLangServer(object):
if isinstance(server_config_dir_paths, dict):
self.is_multi_model_ = True
client_config_path = []
- for server_config_path_items in list(server_config_dir_paths.items()):
- client_config_path.append( server_config_path_items[1] )
+ for server_config_path_items in list(
+ server_config_dir_paths.items()):
+ client_config_path.append(server_config_path_items[1])
elif isinstance(server_config_dir_paths, list):
self.is_multi_model_ = False
client_config_path = server_config_dir_paths
else:
- raise Exception("The type of model_config_paths must be str or list or "
- "dict({op: model_path}), not {}.".format(
- type(server_config_dir_paths)))
+ raise Exception(
+ "The type of model_config_paths must be str or list or "
+ "dict({op: model_path}), not {}.".format(
+ type(server_config_dir_paths)))
if isinstance(client_config_path, str):
client_config_path = [client_config_path]
elif isinstance(client_config_path, list):
pass
- else:# dict is not support right now.
- raise Exception("The type of client_config_path must be str or list or "
- "dict({op: model_path}), not {}.".format(
- type(client_config_path)))
+ else: # dict is not support right now.
+ raise Exception(
+ "The type of client_config_path must be str or list or "
+ "dict({op: model_path}), not {}.".format(
+ type(client_config_path)))
if len(client_config_path) != len(server_config_dir_paths):
- raise Warning("The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
- .format( len(client_config_path), len(server_config_dir_paths) )
- )
+ raise Warning(
+ "The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
+ .format(len(client_config_path), len(server_config_dir_paths)))
self.bclient_config_path_list = client_config_path
def prepare_server(self,
diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py
index fc059c9d4d689a5b4e5bf065582bc3543e252735..a7d117061d00406a1fb4f6583f01ecf366fbba67 100755
--- a/python/paddle_serving_server/web_service.py
+++ b/python/paddle_serving_server/web_service.py
@@ -27,6 +27,7 @@ import os
from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op
+
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
@@ -36,6 +37,7 @@ def port_is_available(port):
else:
return False
+
class WebService(object):
def __init__(self, name="default_service"):
self.name = name
@@ -63,7 +65,9 @@ class WebService(object):
def run_service(self):
self._server.run_server()
- def load_model_config(self, server_config_dir_paths, client_config_path=None):
+ def load_model_config(self,
+ server_config_dir_paths,
+ client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
@@ -73,14 +77,16 @@ class WebService(object):
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
- raise ValueError("The input of --model should be a dir not file.")
+ raise ValueError(
+ "The input of --model should be a dir not file.")
self.server_config_dir_paths = server_config_dir_paths
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
file_path_list = []
for single_model_config in self.server_config_dir_paths:
- file_path_list.append( "{}/serving_server_conf.prototxt".format(single_model_config) )
-
+ file_path_list.append("{}/serving_server_conf.prototxt".format(
+ single_model_config))
+
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
@@ -109,7 +115,9 @@ class WebService(object):
mem_optim=True,
use_lite=False,
use_xpu=False,
- ir_optim=False):
+ ir_optim=False,
+ precision="fp32",
+ use_calib=False):
device = "gpu"
if gpuid == -1:
if use_lite:
@@ -130,7 +138,7 @@ class WebService(object):
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
-
+
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
@@ -140,13 +148,16 @@ class WebService(object):
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.set_device(device)
+ server.set_precision(precision)
+ server.set_use_calib(use_calib)
if use_lite:
server.set_lite()
if use_xpu:
server.set_xpu()
- server.load_model_config(self.server_config_dir_paths)#brpc Server support server_config_dir_paths
+ server.load_model_config(self.server_config_dir_paths
+ ) #brpc Server support server_config_dir_paths
if gpuid >= 0:
server.set_gpuid(gpuid)
server.prepare_server(workdir=workdir, port=port, device=device)
@@ -159,6 +170,8 @@ class WebService(object):
workdir="",
port=9393,
device="gpu",
+ precision="fp32",
+ use_calib=False,
use_lite=False,
use_xpu=False,
ir_optim=False,
@@ -188,7 +201,9 @@ class WebService(object):
mem_optim=mem_optim,
use_lite=use_lite,
use_xpu=use_xpu,
- ir_optim=ir_optim))
+ ir_optim=ir_optim,
+ precision=precision,
+ use_calib=use_calib))
else:
for i, gpuid in enumerate(self.gpus):
self.rpc_service_list.append(
@@ -200,7 +215,9 @@ class WebService(object):
mem_optim=mem_optim,
use_lite=use_lite,
use_xpu=use_xpu,
- ir_optim=ir_optim))
+ ir_optim=ir_optim,
+ precision=precision,
+ use_calib=use_calib))
def _launch_web_service(self):
gpu_num = len(self.gpus)
@@ -297,9 +314,13 @@ class WebService(object):
# default self.gpus = [0].
if len(self.gpus) == 0:
self.gpus.append(0)
- self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=True, gpu_id=self.gpus[0])
+ self.client.load_model_config(
+ self.server_config_dir_paths[0],
+ use_gpu=True,
+ gpu_id=self.gpus[0])
else:
- self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=False)
+ self.client.load_model_config(
+ self.server_config_dir_paths[0], use_gpu=False)
def run_web_service(self):
print("This API will be deprecated later. Please do not use it")
diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py
index 9043540792730db6c9349243277a63a0565e01c1..7ea8858d2b47c1c20226c4f53805f3aa2fd75643 100644
--- a/python/pipeline/pipeline_server.py
+++ b/python/pipeline/pipeline_server.py
@@ -238,6 +238,8 @@ class PipelineServer(object):
"devices": "",
"mem_optim": True,
"ir_optim": False,
+ "precision": "fp32",
+ "use_calib": False,
},
}
for op in self._used_op:
@@ -394,6 +396,8 @@ class ServerYamlConfChecker(object):
"devices": "",
"mem_optim": True,
"ir_optim": False,
+ "precision": "fp32",
+ "use_calib": False,
}
conf_type = {
"model_config": str,
@@ -403,6 +407,8 @@ class ServerYamlConfChecker(object):
"devices": str,
"mem_optim": bool,
"ir_optim": bool,
+ "precision": str,
+ "use_calib": bool,
}
conf_qualification = {"thread_num": (">=", 1), }
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
diff --git a/tools/scripts/ipipe_py3.sh b/tools/scripts/ipipe_py3.sh
index 156b32d4acdc5fa63acd6e2c1d467dccb680d36b..1f34f69525a9475279f55b270e3f61d2d2feff14 100644
--- a/tools/scripts/ipipe_py3.sh
+++ b/tools/scripts/ipipe_py3.sh
@@ -8,11 +8,26 @@ echo "# #"
echo "# #"
echo "# #"
echo "################################################################"
-
export GOPATH=$HOME/go
export PATH=$PATH:$GOROOT/bin:$GOPATH/bin
export CUDA_INCLUDE_DIRS=/usr/local/cuda-10.2/include
export PYTHONROOT=/usr/local
+export PYTHONIOENCODING=utf-8
+
+build_path=/workspace/Serving/
+error_words="Fail|DENIED|UNKNOWN|None"
+log_dir=${build_path}logs/
+data=/root/.cache/serving_data/
+dir=`pwd`
+RED_COLOR='\E[1;31m'
+GREEN_COLOR='\E[1;32m'
+YELOW_COLOR='\E[1;33m'
+RES='\E[0m'
+cuda_version=`cat /usr/local/cuda/version.txt`
+
+if [ $? -ne 0 ]; then
+ cuda_version=11
+fi
go env -w GO111MODULE=on
go env -w GOPROXY=https://goproxy.cn,direct
@@ -21,90 +36,141 @@ go get -u github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2
go get -u github.com/golang/protobuf/protoc-gen-go@v1.4.3
go get -u google.golang.org/grpc@v1.33.0
-build_path=/workspace/Serving/
-build_whl_list=(build_gpu_server build_client build_cpu_server build_app)
-rpc_model_list=(grpc_impl pipeline_imagenet bert_rpc_gpu bert_rpc_cpu ResNet50_rpc lac_rpc \
-cnn_rpc bow_rpc lstm_rpc fit_a_line_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \
-criteo_ctr_rpc_cpu criteo_ctr_rpc_gpu ocr_rpc yolov4_rpc_gpu)
-http_model_list=(fit_a_line_http lac_http cnn_http bow_http lstm_http ResNet50_http bert_http)
-
+build_whl_list=(build_cpu_server build_gpu_server build_client build_app)
+rpc_model_list=(grpc_fit_a_line grpc_yolov4 pipeline_imagenet bert_rpc_gpu bert_rpc_cpu ResNet50_rpc \
+lac_rpc cnn_rpc bow_rpc lstm_rpc fit_a_line_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \
+criteo_ctr_rpc_cpu criteo_ctr_rpc_gpu ocr_rpc yolov4_rpc_gpu faster_rcnn_hrnetv2p_w18_1x_encrypt)
+http_model_list=(fit_a_line_http lac_http cnn_http bow_http lstm_http ResNet50_http bert_http\
+pipeline_ocr_cpu_http)
-function setproxy(){
- export http_proxy=${proxy}
- export https_proxy=${proxy}
+function setproxy() {
+ export http_proxy=${proxy}
+ export https_proxy=${proxy}
}
-function unsetproxy(){
- unset http_proxy
- unset https_proxy
+function unsetproxy() {
+ unset http_proxy
+ unset https_proxy
}
-function kill_server_process(){
- kill `ps -ef|grep $1 |awk '{print $2}'`
- kill `ps -ef|grep serving |awk '{print $2}'`
+function kill_server_process() {
+ kill `ps -ef | grep serving | awk '{print $2}'` > /dev/null 2>&1
+ kill `ps -ef | grep python | awk '{print $2}'` > /dev/null 2>&1
+ echo -e "${GREEN_COLOR}process killed...${RES}"
}
function check() {
cd ${build_path}
if [ ! -f paddle_serving_app* ]; then
- echo "paddle_serving_app is compiled failed, please check your pull request"
- exit 1
+ echo "paddle_serving_app is compiled failed, please check your pull request"
+ exit 1
elif [ ! -f paddle_serving_server-* ]; then
- echo "paddle_serving_server-cpu is compiled failed, please check your pull request"
- exit 1
+ echo "paddle_serving_server-cpu is compiled failed, please check your pull request"
+ exit 1
elif [ ! -f paddle_serving_server_* ]; then
- echo "paddle_serving_server_gpu is compiled failed, please check your pull request"
- exit 1
+ echo "paddle_serving_server_gpu is compiled failed, please check your pull request"
+ exit 1
elif [ ! -f paddle_serving_client* ]; then
- echo "paddle_serving_server_client is compiled failed, please check your pull request"
- exit 1
+ echo "paddle_serving_server_client is compiled failed, please check your pull request"
+ exit 1
else
- echo "paddle serving build passed"
+ echo "paddle serving build passed"
fi
}
function check_result() {
- if [ $? -ne 0 ];then
- echo -e "\033[4;31;42m$1 model runs failed, please check your pull request or modify test case! \033[0m"
- exit 1
+ if [ $? == 0 ]; then
+ echo -e "${GREEN_COLOR}$1 execute normally${RES}"
+ if [ $1 == "server" ]; then
+ sleep $2
+ tail ${dir}server_log.txt | tee -a ${log_dir}server_total.txt
+ fi
+ if [ $1 == "client" ]; then
+ tail ${dir}client_log.txt | tee -a ${log_dir}client_total.txt
+ grep -E "${error_words}" ${dir}client_log.txt > /dev/null
+ if [ $? == 0 ]; then
+ echo -e "${RED_COLOR}$1 error command${RES}\n" | tee -a ${log_dir}server_total.txt ${log_dir}client_total.txt
+ error_log $2
+ else
+ echo -e "${GREEN_COLOR}$2${RES}\n" | tee -a ${log_dir}server_total.txt ${log_dir}client_total.txt
+ fi
+ fi
else
- echo -e "\033[4;37;42m$1 model runs successfully, congratulations! \033[0m"
+ echo -e "${RED_COLOR}$1 error command${RES}\n" | tee -a ${log_dir}server_total.txt ${log_dir}client_total.txt
+ tail ${dir}client_log.txt | tee -a ${log_dir}client_total.txt
+ error_log $2
fi
}
-function before_hook(){
- setproxy
- cd ${build_path}/python
- pip3.6 install --upgrade pip
- pip3.6 install requests
- pip3.6 install -r requirements.txt
- pip3.6 install numpy==1.16.4
- echo "before hook configuration is successful.... "
+function error_log() {
+ arg=${1//\//_}
+ echo "-----------------------------" | tee -a ${log_dir}error_models.txt
+ arg=${arg%% *}
+ arr=(${arg//_/ })
+ if [ ${arr[@]: -1} == 1 -o ${arr[@]: -1} == 2 ]; then
+ model=${arr[*]:0:${#arr[*]}-3}
+ deployment=${arr[*]: -3}
+ else
+ model=${arr[*]:0:${#arr[*]}-2}
+ deployment=${arr[*]: -2}
+ fi
+ echo "model: ${model// /_}" | tee -a ${log_dir}error_models.txt
+ echo "deployment: ${deployment// /_}" | tee -a ${log_dir}error_models.txt
+ echo "py_version: python3.6" | tee -a ${log_dir}error_models.txt
+ echo "cuda_version: ${cuda_version}" | tee -a ${log_dir}error_models.txt
+ echo "status: Failed" | tee -a ${log_dir}error_models.txt
+ echo -e "-----------------------------\n\n" | tee -a ${log_dir}error_models.txt
+ prefix=${arg//\//_}
+ for file in ${dir}*
+ do
+ cp ${file} ${log_dir}error/${prefix}_${file##*/}
+ done
+}
+
+function check_dir() {
+ if [ ! -d "$1" ]
+ then
+ mkdir -p $1
+ fi
}
-function run_env(){
- setproxy
- pip3.6 install --upgrade nltk==3.4
- pip3.6 install --upgrade scipy==1.2.1
- pip3.6 install --upgrade setuptools==41.0.0
- pip3.6 install paddlehub ujson paddlepaddle==2.0.0
- echo "run env configuration is successful.... "
+function link_data() {
+ for file in $1*
+ do
+ if [ ! -h ${file##*/} ]
+ then
+ ln -s ${file} ./${file##*/}
+ fi
+ done
}
-function run_gpu_env(){
- cd ${build_path}
- export LD_LIBRARY_PATH=/usr/local/lib64/python3.6/site-packages/paddle/libs/:$LD_LIBRARY_PATH
- export LD_LIBRARY_PATH=/workspace/Serving/build_gpu/third_party/install/Paddle/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mklml/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mkldnn/lib/:$LD_LIBRARY_PATH
- export SERVING_BIN=${build_path}/build_gpu/core/general-server/serving
- echo "run gpu env configuration is successful.... "
+function before_hook() {
+ setproxy
+ unsetproxy
+ cd ${build_path}/python
+ python3.6 -m pip install --upgrade pip
+ python3.6 -m pip install requests
+ python3.6 -m pip install -r requirements.txt -i https://mirror.baidu.com/pypi/simple
+ python3.6 -m pip install numpy==1.16.4
+ python3.6 -m pip install paddlehub -i https://mirror.baidu.com/pypi/simple
+ echo "before hook configuration is successful.... "
}
-function run_cpu_env(){
- cd ${build_path}
- export LD_LIBRARY_PATH=/usr/local/lib64/python3.6/site-packages/paddle/libs/:$LD_LIBRARY_PATH
- export LD_LIBRARY_PATH=/workspace/Serving/build_cpu/third_party/install/Paddle/lib/:$LD_LIBRARY_PATH
- export SERVING_BIN=${build_path}/build_cpu/core/general-server/serving
- echo "run cpu env configuration is successful.... "
+function run_env() {
+ setproxy
+ python3.6 -m pip install --upgrade nltk==3.4
+ python3.6 -m pip install --upgrade scipy==1.2.1
+ python3.6 -m pip install --upgrade setuptools==41.0.0
+ python3.6 -m pip install paddlehub ujson paddlepaddle==2.0.0
+ echo "run env configuration is successful.... "
+}
+
+function run_gpu_env() {
+ cd ${build_path}
+ export LD_LIBRARY_PATH=/usr/local/lib64/python3.6/site-packages/paddle/libs/:$LD_LIBRARY_PATH
+ export LD_LIBRARY_PATH=/workspace/Serving/build_gpu/third_party/install/Paddle/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mklml/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mkldnn/lib/:$LD_LIBRARY_PATH
+ export SERVING_BIN=${build_path}/build_gpu/core/general-server/serving
+ echo "run gpu env configuration is successful.... "
}
function build_gpu_server() {
@@ -124,573 +190,643 @@ function build_gpu_server() {
-DSERVER=ON \
-DTENSORRT_ROOT=/usr \
-DWITH_GPU=ON ..
- make -j18
- make -j18
- make install -j18
- pip3.6 uninstall paddle-serving-server-gpu -y
- pip3.6 install ${build_path}/build/python/dist/*
+ make -j32
+ make -j32
+ make install -j32
+ python3.6 -m pip uninstall paddle-serving-server-gpu -y
+ python3.6 -m pip install ${build_path}/build/python/dist/*
cp ${build_path}/build/python/dist/* ../
cp -r ${build_path}/build/ ${build_path}/build_gpu
}
-function build_client() {
- setproxy
- cd ${build_path}
- if [ -d build ];then
- rm -rf build
- fi
- mkdir build && cd build
- cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
- -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \
- -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
- -DCLIENT=ON ..
- make -j18
- make -j18
- cp ${build_path}/build/python/dist/* ../
- pip3.6 uninstall paddle-serving-client -y
- pip3.6 install ${build_path}/build/python/dist/*
+function build_cpu_server(){
+ setproxy
+ cd ${build_path}
+ if [ -d build_cpu ];then
+ rm -rf build_cpu
+ fi
+ if [ -d build ];then
+ rm -rf build
+ fi
+ mkdir build && cd build
+ cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
+ -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \
+ -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
+ -DWITH_GPU=OFF \
+ -DSERVER=ON ..
+ make -j32
+ make -j32
+ make install -j32
+ cp ${build_path}/build/python/dist/* ../
+ python3.6 -m pip uninstall paddle-serving-server -y
+ python3.6 -m pip install ${build_path}/build/python/dist/*
+ cp -r ${build_path}/build/ ${build_path}/build_cpu
}
-function build_cpu_server(){
- setproxy
- cd ${build_path}
- if [ -d build_cpu ];then
- rm -rf build_cpu
- fi
- if [ -d build ];then
- rm -rf build
- fi
- mkdir build && cd build
- cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
- -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \
- -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
- -DWITH_GPU=OFF \
- -DSERVER=ON ..
- make -j18
- make -j18
- make install -j18
- cp ${build_path}/build/python/dist/* ../
- pip3.6 uninstall paddle-serving-server -y
- pip3.6 install ${build_path}/build/python/dist/*
- cp -r ${build_path}/build/ ${build_path}/build_cpu
+function build_client() {
+ setproxy
+ cd ${build_path}
+ if [ -d build ];then
+ rm -rf build
+ fi
+ mkdir build && cd build
+ cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
+ -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \
+ -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
+ -DCLIENT=ON ..
+ make -j32
+ make -j32
+ cp ${build_path}/build/python/dist/* ../
+ python3.6 -m pip uninstall paddle-serving-client -y
+ python3.6 -m pip install ${build_path}/build/python/dist/*
}
function build_app() {
- setproxy
- pip3.6 install paddlehub ujson Pillow
- pip3.6 install paddlepaddle==2.0.0
- cd ${build_path}
- if [ -d build ];then
- rm -rf build
- fi
- mkdir build && cd build
- cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
- -DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython3.6.so \
- -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
- -DCMAKE_INSTALL_PREFIX=./output -DAPP=ON ..
- make
- cp ${build_path}/build/python/dist/* ../
- pip3.6 uninstall paddle-serving-app -y
- pip3.6 install ${build_path}/build/python/dist/*
-}
-
-function bert_rpc_gpu(){
- run_gpu_env
- unsetproxy
- cd ${build_path}/python/examples/bert
- sh get_data.sh >/dev/null 2>&1
- sed -i 's/9292/8860/g' bert_client.py
- sed -i '$aprint(result)' bert_client.py
- cp -r /root/.cache/dist_data/serving/bert/bert_seq128_* ./
- ls -hlst
- python3.6 -m paddle_serving_server_gpu.serve --model bert_seq128_model/ --port 8860 --gpu_ids 0 &
- sleep 15
- nvidia-smi
- head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function bert_rpc_cpu(){
- run_cpu_env
- unsetproxy
- cd ${build_path}/python/examples/bert
- sed -i 's/8860/8861/g' bert_client.py
- python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8861 &
- sleep 3
- cp data-c.txt.1 data-c.txt
- head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function criteo_ctr_with_cube_rpc(){
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/criteo_ctr_with_cube
- ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./
- sed -i "s/9292/8888/g" test_server.py
- sed -i "s/9292/8888/g" test_client.py
- wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz >/dev/null 2>&1
- tar xf ctr_cube_unittest.tar.gz
- mv models/ctr_client_conf ./
- mv models/ctr_serving_model_kv ./
- mv models/data ./cube/
- wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz >/dev/null 2>&1
- tar xf cube_app.tar.gz
- mv cube_app/cube* ./cube/
- sh cube_prepare.sh > haha 2>&1 &
- sleep 5
- python3.6 test_server.py ctr_serving_model_kv &
- sleep 5
- python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data
- check_result $FUNCNAME
- kill `ps -ef|grep cube|awk '{print $2}'`
- kill_server_process test_server
-}
-
-function pipeline_imagenet(){
- run_gpu_env
- unsetproxy
- cd ${build_path}/python/examples/pipeline/imagenet
- cp -r /root/.cache/dist_data/serving/imagenet/* ./
- ls -a
- python3.6 resnet50_web_service.py &
- sleep 5
- nvidia-smi
- python3.6 pipeline_rpc_client.py
- nvidia-smi
- # check_result $FUNCNAME
- kill_server_process resnet50_web_service
-}
-
-function ResNet50_rpc(){
- run_gpu_env
- unsetproxy
- cd ${build_path}/python/examples/imagenet
- cp -r /root/.cache/dist_data/serving/imagenet/* ./
- sed -i 's/9696/8863/g' resnet50_rpc_client.py
- python3.6 -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 8863 --gpu_ids 0 &
- sleep 5
- nvidia-smi
- python3.6 resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function ResNet101_rpc(){
- run_gpu_env
- unsetproxy
- cd ${build_path}/python/examples/imagenet
- sed -i "22cclient.connect(['${host}:8864'])" image_rpc_client.py
- python3.6 -m paddle_serving_server_gpu.serve --model ResNet101_vd_model --port 8864 --gpu_ids 0 &
- sleep 5
- nvidia-smi
- python3.6 image_rpc_client.py ResNet101_vd_client_config/serving_client_conf.prototxt
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
- sleep 5
-}
-
-function cnn_rpc(){
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/imdb
- cp -r /root/.cache/dist_data/serving/imdb/* ./
- tar xf imdb_model.tar.gz && tar xf text_classification_data.tar.gz
- sed -i 's/9292/8865/g' test_client.py
- python3.6 -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 &
- sleep 5
- head test_data/part-0 | python3.6 test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function bow_rpc(){
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/imdb
- sed -i 's/8865/8866/g' test_client.py
- python3.6 -m paddle_serving_server.serve --model imdb_bow_model/ --port 8866 &
- sleep 5
- head test_data/part-0 | python3.6 test_client.py imdb_bow_client_conf/serving_client_conf.prototxt imdb.vocab
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function lstm_rpc(){
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/imdb
- sed -i 's/8866/8867/g' test_client.py
- python3.6 -m paddle_serving_server.serve --model imdb_lstm_model/ --port 8867 &
- sleep 5
- head test_data/part-0 | python3.6 test_client.py imdb_lstm_client_conf/serving_client_conf.prototxt imdb.vocab
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function lac_rpc(){
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/lac
- python3.6 -m paddle_serving_app.package --get_model lac >/dev/null 2>&1
- tar xf lac.tar.gz
- sed -i 's/9292/8868/g' lac_client.py
- python3.6 -m paddle_serving_server.serve --model lac_model/ --port 8868 &
- sleep 5
- echo "我爱北京天安门" | python3.6 lac_client.py lac_client/serving_client_conf.prototxt lac_dict/
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function fit_a_line_rpc(){
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/fit_a_line
- sh get_data.sh >/dev/null 2>&1
- sed -i 's/9393/8869/g' test_client.py
- python3.6 -m paddle_serving_server.serve --model uci_housing_model --port 8869 &
- sleep 5
- python3.6 test_client.py uci_housing_client/serving_client_conf.prototxt
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function faster_rcnn_model_rpc(){
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/faster_rcnn
- cp -r /root/.cache/dist_data/serving/faster_rcnn/faster_rcnn_model.tar.gz ./
- tar xf faster_rcnn_model.tar.gz
- wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml >/dev/null 2>&1
- mv faster_rcnn_model/pddet* ./
- sed -i 's/9494/8870/g' test_client.py
- python3.6 -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 8870 --gpu_id 0 --thread 2 &
- echo "faster rcnn running ..."
- nvidia-smi
- sleep 5
- python3.6 test_client.py pddet_client_conf/serving_client_conf.prototxt infer_cfg.yml 000000570688.jpg
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
-}
-
-function cascade_rcnn_rpc(){
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/cascade_rcnn
- cp -r /root/.cache/dist_data/serving/cascade_rcnn/cascade_rcnn_r50_fpx_1x_serving.tar.gz ./
- tar xf cascade_rcnn_r50_fpx_1x_serving.tar.gz
- sed -i "s/9292/8879/g" test_client.py
- python3.6 -m paddle_serving_server_gpu.serve --model serving_server --port 8879 --gpu_id 0 --thread 2 &
- sleep 5
- nvidia-smi
- python3.6 test_client.py
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
+ setproxy
+ python3.6 -m pip install paddlehub ujson Pillow
+ python3.6 -m pip install paddlepaddle==2.0.0
+ cd ${build_path}
+ if [ -d build ];then
+ rm -rf build
+ fi
+ mkdir build && cd build
+ cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \
+ -DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython3.6.so \
+ -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \
+ -DCMAKE_INSTALL_PREFIX=./output -DAPP=ON ..
+ make
+ cp ${build_path}/build/python/dist/* ../
+ python3.6 -m pip uninstall paddle-serving-app -y
+ python3.6 -m pip install ${build_path}/build/python/dist/*
+}
+
+function faster_rcnn_hrnetv2p_w18_1x_encrypt() {
+ dir=${log_dir}rpc_model/faster_rcnn_hrnetv2p_w18_1x/
+ cd ${build_path}/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x
+ check_dir ${dir}
+ data_dir=${data}detection/faster_rcnn_hrnetv2p_w18_1x/
+ link_data ${data_dir}
+ python3.6 encrypt.py
+ unsetproxy
+ echo -e "${GREEN_COLOR}faster_rcnn_hrnetv2p_w18_1x_ENCRYPTION_GPU_RPC server started${RES}" | tee -a ${log_dir}server_total.txt
+ python3.6 -m paddle_serving_server.serve --model encrypt_server/ --port 9494 --use_trt --gpu_ids 0 --use_encryption_model > ${dir}server_log.txt 2>&1 &
+ check_result server 3
+ echo -e "${GREEN_COLOR}faster_rcnn_hrnetv2p_w18_1x_ENCRYPTION_GPU_RPC client started${RES}" | tee -a ${log_dir}client_total.txt
+ python3.6 test_encryption.py 000000570688.jpg > ${dir}client_log.txt 2>&1
+ check_result client "faster_rcnn_hrnetv2p_w18_1x_ENCRYPTION_GPU_RPC server test completed"
+ kill_server_process
+}
+
+function pipeline_ocr_cpu_http() {
+ dir=${log_dir}rpc_model/pipeline_ocr_cpu_http/
+ check_dir ${dir}
+ cd ${build_path}/python/examples/pipeline/ocr
+ data_dir=${data}ocr/
+ link_data ${data_dir}
+ unsetproxy
+ echo -e "${GREEN_COLOR}pipeline_ocr_CPU_HTTP server started${RES}" | tee -a ${log_dir}server_total.txt
+ $py_version web_service.py > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ echo -e "${GREEN_COLOR}pipeline_ocr_CPU_HTTP client started${RES}" | tee -a ${log_dir}client_total.txt
+ timeout 15s $py_version pipeline_http_client.py > ${dir}client_log.txt 2>&1
+ check_result client "pipeline_ocr_CPU_HTTP server test completed"
+ kill_server_process
+}
+
+function bert_rpc_gpu() {
+ dir=${log_dir}rpc_model/bert_rpc_gpu/
+ check_dir ${dir}
+ run_gpu_env
+ unsetproxy
+ cd ${build_path}/python/examples/bert
+ data_dir=${data}bert/
+ link_data ${data_dir}
+ sh get_data.sh >/dev/null 2>&1
+ sed -i 's/9292/8860/g' bert_client.py
+ sed -i '$aprint(result)' bert_client.py
+ ls -hlst
+ python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8860 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
+ check_result server 15
+ nvidia-smi
+ head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
+ check_result client "bert_GPU_RPC server test completed"
+ nvidia-smi
+ kill_server_process
+}
+
+function bert_rpc_cpu() {
+ dir=${log_dir}rpc_model/bert_rpc_cpu/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/bert
+ data_dir=${data}bert/
+ link_data ${data_dir}
+ sed -i 's/8860/8861/g' bert_client.py
+ python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8861 > ${dir}server_log.txt 2>&1 &
+ check_result server 3
+ cp data-c.txt.1 data-c.txt
+ head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
+ check_result client "bert_CPU_RPC server test completed"
+ kill_server_process
+}
+
+function pipeline_imagenet() {
+ dir=${log_dir}rpc_model/pipeline_imagenet/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/pipeline/imagenet
+ data_dir=${data}imagenet/
+ link_data ${data_dir}
+ python3.6 resnet50_web_service.py > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 pipeline_rpc_client.py > ${dir}client_log.txt 2>&1
+ check_result client "pipeline_imagenet_GPU_RPC server test completed"
+ nvidia-smi
+ kill_server_process
+}
+
+function ResNet50_rpc() {
+ dir=${log_dir}rpc_model/ResNet50_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imagenet
+ data_dir=${data}imagenet/
+ link_data ${data_dir}
+ sed -i 's/9696/8863/g' resnet50_rpc_client.py
+ python3.6 -m paddle_serving_server.serve --model ResNet50_vd_model --port 8863 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
+ check_result client "ResNet50_GPU_RPC server test completed"
+ nvidia-smi
+ kill_server_process
+}
+
+function ResNet101_rpc() {
+ dir=${log_dir}rpc_model/ResNet101_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imagenet
+ data_dir=${data}imagenet/
+ link_data ${data_dir}
+ sed -i "22cclient.connect(['127.0.0.1:8864'])" image_rpc_client.py
+ python3.6 -m paddle_serving_server.serve --model ResNet101_vd_model --port 8864 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 image_rpc_client.py ResNet101_vd_client_config/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
+ check_result client "ResNet101_GPU_RPC server test completed"
+ nvidia-smi
+ kill_server_process
+}
+
+function cnn_rpc() {
+ dir=${log_dir}rpc_model/cnn_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imdb
+ data_dir=${data}imdb/
+ link_data ${data_dir}
+ sed -i 's/9292/8865/g' test_client.py
+ python3.6 -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ head test_data/part-0 | python3.6 test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab > ${dir}client_log.txt 2>&1
+ check_result client "cnn_CPU_RPC server test completed"
+ kill_server_process
+}
+
+function bow_rpc() {
+ dir=${log_dir}rpc_model/bow_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imdb
+ data_dir=${data}imdb/
+ link_data ${data_dir}
+ sed -i 's/8865/8866/g' test_client.py
+ python3.6 -m paddle_serving_server.serve --model imdb_bow_model/ --port 8866 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ head test_data/part-0 | python3.6 test_client.py imdb_bow_client_conf/serving_client_conf.prototxt imdb.vocab > ${dir}client_log.txt 2>&1
+ check_result client "bow_CPU_RPC server test completed"
+ kill_server_process
+}
+
+function lstm_rpc() {
+ dir=${log_dir}rpc_model/lstm_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imdb
+ data_dir=${data}imdb/
+ link_data ${data_dir}
+ sed -i 's/8866/8867/g' test_client.py
+ python3.6 -m paddle_serving_server.serve --model imdb_lstm_model/ --port 8867 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ head test_data/part-0 | python3.6 test_client.py imdb_lstm_client_conf/serving_client_conf.prototxt imdb.vocab > ${dir}client_log.txt 2>&1
+ check_result client "lstm_CPU_RPC server test completed"
+ kill_server_process
+}
+
+function lac_rpc() {
+ dir=${log_dir}rpc_model/lac_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/lac
+ data_dir=${data}lac/
+ link_data ${data_dir}
+ sed -i 's/9292/8868/g' lac_client.py
+ python3.6 -m paddle_serving_server.serve --model lac_model/ --port 8868 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ echo "我爱北京天安门" | python3.6 lac_client.py lac_client/serving_client_conf.prototxt lac_dict/ > ${dir}client_log.txt 2>&1
+ check_result client "lac_CPU_RPC server test completed"
+ kill_server_process
+}
+
+function fit_a_line_rpc() {
+ dir=${log_dir}rpc_model/fit_a_line_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/fit_a_line
+ data_dir=${data}fit_a_line/
+ link_data ${data_dir}
+ sed -i 's/9393/8869/g' test_client.py
+ python3.6 -m paddle_serving_server.serve --model uci_housing_model --port 8869 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ python3.6 test_client.py uci_housing_client/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1
+ check_result client "fit_a_line_CPU_RPC server test completed"
+ kill_server_process
+}
+
+function faster_rcnn_model_rpc() {
+ dir=${log_dir}rpc_model/faster_rcnn_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/detection/faster_rcnn_r50_fpn_1x_coco
+ data_dir=${data}detection/faster_rcnn_r50_fpn_1x_coco/
+ link_data ${data_dir}
+ sed -i 's/9494/8870/g' test_client.py
+ python3.6 -m paddle_serving_server.serve --model serving_server --port 8870 --gpu_ids 0 --thread 2 --use_trt > ${dir}server_log.txt 2>&1 &
+ echo "faster rcnn running ..."
+ nvidia-smi
+ check_result server 10
+ python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "faster_rcnn_GPU_RPC server test completed"
+ kill_server_process
+}
+
+function cascade_rcnn_rpc() {
+ dir=${log_dir}rpc_model/cascade_rcnn_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/cascade_rcnn
+ data_dir=${data}cascade_rcnn/
+ link_data ${data_dir}
+ sed -i "s/9292/8879/g" test_client.py
+ python3.6 -m paddle_serving_server.serve --model serving_server --port 8879 --gpu_ids 0 --thread 2 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 test_client.py > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "cascade_rcnn_GPU_RPC server test completed"
+ kill_server_process
}
function deeplabv3_rpc() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/deeplabv3
- cp -r /root/.cache/dist_data/serving/deeplabv3/deeplabv3.tar.gz ./
- tar xf deeplabv3.tar.gz
- sed -i "s/9494/8880/g" deeplabv3_client.py
- python3.6 -m paddle_serving_server_gpu.serve --model deeplabv3_server --gpu_ids 0 --port 8880 --thread 2 &
- sleep 5
- nvidia-smi
- python3.6 deeplabv3_client.py
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/deeplabv3_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/deeplabv3
+ data_dir=${data}deeplabv3/
+ link_data ${data_dir}
+ sed -i "s/9494/8880/g" deeplabv3_client.py
+ python3.6 -m paddle_serving_server.serve --model deeplabv3_server --gpu_ids 0 --port 8880 --thread 2 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 deeplabv3_client.py > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "deeplabv3_GPU_RPC server test completed"
+ kill_server_process
}
function mobilenet_rpc() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/mobilenet
- python3.6 -m paddle_serving_app.package --get_model mobilenet_v2_imagenet >/dev/null 2>&1
- tar xf mobilenet_v2_imagenet.tar.gz
- sed -i "s/9393/8881/g" mobilenet_tutorial.py
- python3.6 -m paddle_serving_server_gpu.serve --model mobilenet_v2_imagenet_model --gpu_ids 0 --port 8881 &
- sleep 5
- nvidia-smi
- python3.6 mobilenet_tutorial.py
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/mobilenet_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/mobilenet
+ python3.6 -m paddle_serving_app.package --get_model mobilenet_v2_imagenet >/dev/null 2>&1
+ tar xf mobilenet_v2_imagenet.tar.gz
+ sed -i "s/9393/8881/g" mobilenet_tutorial.py
+ python3.6 -m paddle_serving_server.serve --model mobilenet_v2_imagenet_model --gpu_ids 0 --port 8881 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 mobilenet_tutorial.py > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "mobilenet_GPU_RPC server test completed"
+ kill_server_process
}
function unet_rpc() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/unet_for_image_seg
- python3.6 -m paddle_serving_app.package --get_model unet >/dev/null 2>&1
- tar xf unet.tar.gz
- sed -i "s/9494/8882/g" seg_client.py
- python3.6 -m paddle_serving_server_gpu.serve --model unet_model --gpu_ids 0 --port 8882 &
- sleep 5
- nvidia-smi
- python3.6 seg_client.py
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/unet_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/unet_for_image_seg
+ data_dir=${data}unet_for_image_seg/
+ link_data ${data_dir}
+ sed -i "s/9494/8882/g" seg_client.py
+ python3.6 -m paddle_serving_server.serve --model unet_model --gpu_ids 0 --port 8882 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 seg_client.py > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "unet_GPU_RPC server test completed"
+ kill_server_process
}
function resnetv2_rpc() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/resnet_v2_50
- cp /root/.cache/dist_data/serving/resnet_v2_50/resnet_v2_50_imagenet.tar.gz ./
- tar xf resnet_v2_50_imagenet.tar.gz
- sed -i 's/9393/8883/g' resnet50_v2_tutorial.py
- python3.6 -m paddle_serving_server_gpu.serve --model resnet_v2_50_imagenet_model --gpu_ids 0 --port 8883 &
- sleep 10
- nvidia-smi
- python3.6 resnet50_v2_tutorial.py
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/resnetv2_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/resnet_v2_50
+ data_dir=${data}resnet_v2_50/
+ link_data ${data_dir}
+ sed -i 's/9393/8883/g' resnet50_v2_tutorial.py
+ python3.6 -m paddle_serving_server.serve --model resnet_v2_50_imagenet_model --gpu_ids 0 --port 8883 > ${dir}server_log.txt 2>&1 &
+ check_result server 10
+ nvidia-smi
+ python3.6 resnet50_v2_tutorial.py > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "resnetv2_GPU_RPC server test completed"
+ kill_server_process
}
function ocr_rpc() {
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/ocr
- cp -r /root/.cache/dist_data/serving/ocr/test_imgs ./
- python3.6 -m paddle_serving_app.package --get_model ocr_rec >/dev/null 2>&1
- tar xf ocr_rec.tar.gz
- sed -i 's/9292/8884/g' test_ocr_rec_client.py
- python3.6 -m paddle_serving_server.serve --model ocr_rec_model --port 8884 &
- sleep 5
- python3.6 test_ocr_rec_client.py
- # check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/ocr_rpc/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/ocr
+ data_dir=${data}ocr/
+ link_data ${data_dir}
+ python3.6 -m paddle_serving_app.package --get_model ocr_rec >/dev/null 2>&1
+ tar xf ocr_rec.tar.gz
+ sed -i 's/9292/8884/g' test_ocr_rec_client.py
+ python3.6 -m paddle_serving_server.serve --model ocr_rec_model --port 8884 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ python3.6 test_ocr_rec_client.py > ${dir}client_log.txt 2>&1
+ check_result client "ocr_CPU_RPC server test completed"
+ kill_server_process
}
function criteo_ctr_rpc_cpu() {
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/criteo_ctr
- sed -i "s/9292/8885/g" test_client.py
- ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./
- wget https://paddle-serving.bj.bcebos.com/criteo_ctr_example/criteo_ctr_demo_model.tar.gz >/dev/null 2>&1
- tar xf criteo_ctr_demo_model.tar.gz
- mv models/ctr_client_conf .
- mv models/ctr_serving_model .
- python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8885 &
- sleep 5
- python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0
- check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/criteo_ctr_rpc_cpu/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/criteo_ctr
+ data_dir=${data}criteo_ctr/
+ link_data ${data_dir}
+ sed -i "s/9292/8885/g" test_client.py
+ python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8885 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0 > ${dir}client_log.txt 2>&1
+ check_result client "criteo_ctr_CPU_RPC server test completed"
+ kill_server_process
}
function criteo_ctr_rpc_gpu() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/criteo_ctr
- sed -i "s/8885/8886/g" test_client.py
- wget https://paddle-serving.bj.bcebos.com/criteo_ctr_example/criteo_ctr_demo_model.tar.gz >/dev/null 2>&1
- python3.6 -m paddle_serving_server_gpu.serve --model ctr_serving_model/ --port 8886 --gpu_ids 0 &
- sleep 5
- nvidia-smi
- python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/
- nvidia-smi
- check_result $FUNCNAME
- kill `ps -ef|grep ctr|awk '{print $2}'`
- kill_server_process serving
+ dir=${log_dir}rpc_model/criteo_ctr_rpc_gpu/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/criteo_ctr
+ data_dir=${data}criteo_ctr/
+ link_data ${data_dir}
+ sed -i "s/8885/8886/g" test_client.py
+ python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8886 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ nvidia-smi
+ python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0 > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "criteo_ctr_GPU_RPC server test completed"
+ kill_server_process
}
function yolov4_rpc_gpu() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/yolov4
- sed -i "s/9393/8887/g" test_client.py
- cp -r /root/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./
- tar xf yolov4.tar.gz
- python3.6 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 &
- nvidia-smi
- sleep 5
- python3.6 test_client.py 000000570688.jpg
- nvidia-smi
- # check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/yolov4_rpc_gpu/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/yolov4
+ data_dir=${data}yolov4/
+ link_data ${data_dir}
+ sed -i "s/9393/8887/g" test_client.py
+ python3.6 -m paddle_serving_server.serve --model yolov4_model --port 8887 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
+ nvidia-smi
+ check_result server 5
+ python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "yolov4_GPU_RPC server test completed"
+ kill_server_process
}
function senta_rpc_cpu() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/senta
- sed -i "s/9393/8887/g" test_client.py
- cp -r /data/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./
- tar xf yolov4.tar.gz
- python3.6 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 &
- nvidia-smi
- sleep 5
- python3.6 test_client.py 000000570688.jpg
- nvidia-smi
- check_result $FUNCNAME
- kill_server_process serving
+ dir=${log_dir}rpc_model/senta_rpc_cpu/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/senta
+ data_dir=${data}senta/
+ link_data ${data_dir}
+ sed -i "s/9393/8887/g" test_client.py
+ python3.6 -m paddle_serving_server.serve --model yolov4_model --port 8887 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
+ nvidia-smi
+ check_result server 5
+ python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
+ nvidia-smi
+ check_result client "senta_GPU_RPC server test completed"
+ kill_server_process
}
function fit_a_line_http() {
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/fit_a_line
- sed -i "s/9292/8871/g" test_server.py
- python3.6 test_server.py &
- sleep 10
- curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://${host}:8871/uci/prediction
- check_result $FUNCNAME
- kill_server_process test_server
+ dir=${log_dir}http_model/fit_a_line_http/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/fit_a_line
+ sed -i "s/9393/8871/g" test_server.py
+ python3.6 test_server.py > ${dir}server_log.txt 2>&1 &
+ check_result server 10
+ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://127.0.0.1:8871/uci/prediction > ${dir}client_log.txt 2>&1
+ check_result client "fit_a_line_CPU_HTTP server test completed"
+ kill_server_process
}
function lac_http() {
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/lac
- python3.6 lac_web_service.py lac_model/ lac_workdir 8872 &
- sleep 10
- curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' http://${host}:8872/lac/prediction
- check_result $FUNCNAME
- kill_server_process lac_web_service
+ dir=${log_dir}http_model/lac_http/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/lac
+ python3.6 lac_web_service.py lac_model/ lac_workdir 8872 > ${dir}server_log.txt 2>&1 &
+ check_result server 10
+ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' http://127.0.0.1:8872/lac/prediction > ${dir}client_log.txt 2>&1
+ check_result client "lac_CPU_HTTP server test completed"
+ kill_server_process
}
function cnn_http() {
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/imdb
- python3.6 text_classify_service.py imdb_cnn_model/ workdir/ 8873 imdb.vocab &
- sleep 10
- curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8873/imdb/prediction
- check_result $FUNCNAME
- kill_server_process text_classify_service
+ dir=${log_dir}http_model/cnn_http/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imdb
+ python3.6 text_classify_service.py imdb_cnn_model/ workdir/ 8873 imdb.vocab > ${dir}server_log.txt 2>&1 &
+ check_result server 10
+ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://127.0.0.1:8873/imdb/prediction > ${dir}client_log.txt 2>&1
+ check_result client "cnn_CPU_HTTP server test completed"
+ kill_server_process
}
function bow_http() {
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/imdb
- python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8874 imdb.vocab &
- sleep 10
- curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8874/imdb/prediction
- check_result $FUNCNAME
- kill_server_process text_classify_service
+ dir=${log_dir}http_model/bow_http/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imdb
+ python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8874 imdb.vocab > ${dir}server_log.txt 2>&1 &
+ check_result server 10
+ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://127.0.0.1:8874/imdb/prediction > ${dir}client_log.txt 2>&1
+ check_result client "bow_CPU_HTTP server test completed"
+ kill_server_process
}
function lstm_http() {
- unsetproxy
- run_cpu_env
- cd ${build_path}/python/examples/imdb
- python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8875 imdb.vocab &
- sleep 10
- curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8875/imdb/prediction
- check_result $FUNCNAME
- kill `ps -ef|grep imdb|awk '{print $2}'`
- kill_server_process text_classify_service
+ dir=${log_dir}http_model/lstm_http/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imdb
+ python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8875 imdb.vocab > ${dir}server_log.txt 2>&1 &
+ check_result server 10
+ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://127.0.0.1:8875/imdb/prediction > ${dir}client_log.txt 2>&1
+ check_result client "lstm_CPU_HTTP server test completed"
+ kill_server_process
}
function ResNet50_http() {
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/imagenet
- python3.6 resnet50_web_service.py ResNet50_vd_model gpu 8876 &
- sleep 10
- curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://${host}:8876/image/prediction
- check_result $FUNCNAME
- kill_server_process resnet50_web_service
-}
-
-function bert_http(){
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/bert
- cp data-c.txt.1 data-c.txt
- cp vocab.txt.1 vocab.txt
- export CUDA_VISIBLE_DEVICES=0
- python3.6 bert_web_service.py bert_seq128_model/ 8878 &
- sleep 5
- curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "hello"}], "fetch":["pooled_output"]}' http://127.0.0.1:8878/bert/prediction
- check_result $FUNCNAME
- kill_server_process bert_web_service
-}
-
-grpc_impl(){
- unsetproxy
- run_gpu_env
- cd ${build_path}/python/examples/grpc_impl_example/fit_a_line
- sh get_data.sh >/dev/null 2>&1
- python3.6 test_server.py uci_housing_model/ &
- sleep 5
- echo "sync predict"
- python3.6 test_sync_client.py
- echo "async predict"
- python3.6 test_asyn_client.py
- echo "batch predict"
- python3.6 test_batch_client.py
- echo "timeout predict"
- python3.6 test_timeout_client.py
-# check_result $FUNCNAME
- kill_server_process test_server
-}
-
-function build_all_whl(){
- for whl in ${build_whl_list[@]}
- do
- echo "===========${whl} begin build==========="
- $whl
- sleep 3
- echo "===========${whl} build over ==========="
- done
-}
-
-function run_rpc_models(){
- for model in ${rpc_model_list[@]}
- do
- echo "===========${model} run begin==========="
- $model
- sleep 3
- echo "===========${model} run end ==========="
- done
-}
-
-function run_http_models(){
- for model in ${http_model_list[@]}
- do
- echo "===========${model} run begin==========="
- $model
- sleep 3
- echo "===========${model} run end ==========="
- done
-}
-
-function end_hook(){
- cd ${build_path}
- kill_server_process
- kill `ps -ef|grep python|awk '{print $2}'`
- sleep 5
- echo "===========files==========="
- ls -hlst
- echo "=========== end ==========="
-
+ dir=${log_dir}http_model/ResNet50_http/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/imagenet
+ python3.6 resnet50_web_service.py ResNet50_vd_model gpu 8876 > ${dir}server_log.txt 2>&1 &
+ check_result server 10
+ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://127.0.0.1:8876/image/prediction > ${dir}client_log.txt 2>&1
+ check_result client "ResNet50_GPU_HTTP server test completed"
+ kill_server_process
+}
+
+function bert_http() {
+ dir=${log_dir}http_model/ResNet50_http/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/bert
+ cp data-c.txt.1 data-c.txt
+ cp vocab.txt.1 vocab.txt
+ export CUDA_VISIBLE_DEVICES=0
+ python3.6 bert_web_service.py bert_seq128_model/ 8878 > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "hello"}], "fetch":["pooled_output"]}' http://127.0.0.1:8878/bert/prediction > ${dir}client_log.txt 2>&1
+ check_result client "bert_GPU_HTTP server test completed"
+ kill_server_process
+}
+
+function grpc_fit_a_line() {
+ dir=${log_dir}rpc_model/grpc_fit_a_line/
+ check_dir ${dir}
+ unsetproxy
+ cd ${build_path}/python/examples/grpc_impl_example/fit_a_line
+ data_dir=${data}fit_a_line/
+ link_data ${data_dir}
+ python3.6 test_server.py uci_housing_model/ > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ echo "sync predict" > ${dir}client_log.txt 2>&1
+ python3.6 test_sync_client.py >> ${dir}client_log.txt 2>&1
+ check_result client "grpc_impl_example_fit_a_line_sync_CPU_gRPC server sync test completed"
+ echo "async predict" >> ${dir}client_log.txt 2>&1
+ python3.6 test_asyn_client.py >> ${dir}client_log.txt 2>&1
+ check_result client "grpc_impl_example_fit_a_line_asyn_CPU_gRPC server asyn test completed"
+ echo "batch predict" >> ${dir}client_log.txt 2>&1
+ python3.6 test_batch_client.py >> ${dir}client_log.txt 2>&1
+ check_result client "grpc_impl_example_fit_a_line_batch_CPU_gRPC server batch test completed"
+ echo "timeout predict" >> ${dir}client_log.txt 2>&1
+ python3.6 test_timeout_client.py >> ${dir}client_log.txt 2>&1
+ check_result client "grpc_impl_example_fit_a_line_timeout_CPU_gRPC server timeout test completed"
+ kill_server_process
+}
+
+function grpc_yolov4() {
+ dir=${log_dir}rpc_model/grpc_yolov4/
+ cd ${build_path}/python/examples/grpc_impl_example/yolov4
+ check_dir ${dir}
+ data_dir=${data}yolov4/
+ link_data ${data_dir}
+ echo -e "${GREEN_COLOR}grpc_impl_example_yolov4_GPU_gRPC server started${RES}"
+ python3.6 -m paddle_serving_server.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang > ${dir}server_log.txt 2>&1 &
+ check_result server 5
+ echo -e "${GREEN_COLOR}grpc_impl_example_yolov4_GPU_gRPC client started${RES}"
+ python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1
+ check_result client "grpc_yolov4_GPU_GRPC server test completed"
+ kill_server_process
+}
+
+function build_all_whl() {
+ for whl in ${build_whl_list[@]}
+ do
+ echo "===========${whl} begin build==========="
+ $whl
+ sleep 3
+ echo "===========${whl} build over ==========="
+ done
+}
+
+function run_rpc_models() {
+ for model in ${rpc_model_list[@]}
+ do
+ echo "===========${model} run begin==========="
+ $model
+ sleep 3
+ echo "===========${model} run end ==========="
+ done
+}
+
+function run_http_models() {
+ for model in ${http_model_list[@]}
+ do
+ echo "===========${model} run begin==========="
+ $model
+ sleep 3
+ echo "===========${model} run end ==========="
+ done
+}
+
+function end_hook() {
+ cd ${build_path}
+ kill_server_process
+ kill `ps -ef|grep python|awk '{print $2}'`
+ sleep 5
+ echo "===========files==========="
+ ls -hlst
+ echo "=========== end ==========="
}
function main() {
- before_hook
- build_all_whl
- check
- run_env
- run_rpc_models
-# run_http_models
- end_hook
+ before_hook
+ build_all_whl
+ check
+ run_env
+ unsetproxy
+ run_gpu_env
+ check_dir ${log_dir}rpc_model/
+ check_dir ${log_dir}http_model/
+ check_dir ${log_dir}error/
+ run_rpc_models
+ run_http_models
+ end_hook
+ if [ -f ${log_dir}error_models.txt ]; then
+ cat ${log_dir}error_models.txt
+ echo "error occurred!"
+ # exit 1
+ fi
}
-
main$@