diff --git a/core/general-server/op/general_detection_op.cpp b/core/general-server/op/general_detection_op.cpp index f02465e0a70ce5ee86f71f8c194df34e545269df..7c33ec8efa8c6e89a7a778def6342415d19ffa94 100755 --- a/core/general-server/op/general_detection_op.cpp +++ b/core/general-server/op/general_detection_op.cpp @@ -22,7 +22,6 @@ #include "core/predictor/framework/resource.h" #include "core/util/include/timer.h" - /* #include "opencv2/imgcodecs/legacy/constants_c.h" #include "opencv2/imgproc/types_c.h" @@ -52,18 +51,18 @@ int GeneralDetectionOp::inference() { } const std::string pre_name = pre_node_names[0]; - const GeneralBlob *input_blob = get_depend_argument(pre_name); + const GeneralBlob* input_blob = get_depend_argument(pre_name); if (!input_blob) { LOG(ERROR) << "input_blob is nullptr,error"; - return -1; + return -1; } uint64_t log_id = input_blob->GetLogId(); VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name; - GeneralBlob *output_blob = mutable_data(); + GeneralBlob* output_blob = mutable_data(); if (!output_blob) { LOG(ERROR) << "output_blob is nullptr,error"; - return -1; + return -1; } output_blob->SetLogId(log_id); @@ -73,7 +72,7 @@ int GeneralDetectionOp::inference() { return -1; } - const TensorVector *in = &input_blob->tensor_vector; + const TensorVector* in = &input_blob->tensor_vector; TensorVector* out = &output_blob->tensor_vector; int batch_size = input_blob->_batch_size; @@ -81,38 +80,39 @@ int GeneralDetectionOp::inference() { output_blob->_batch_size = batch_size; - VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size; - std::vector input_shape; - int in_num =0; + int in_num = 0; void* databuf_data = NULL; char* databuf_char = NULL; size_t databuf_size = 0; + // now only support single string + char* total_input_ptr = static_cast(in->at(0).data.data()); + std::string base64str = total_input_ptr; - std::string* input_ptr = static_cast(in->at(0).data.data()); - std::string base64str = input_ptr[0]; 
float ratio_h{}; float ratio_w{}; - cv::Mat img = Base2Mat(base64str); cv::Mat srcimg; cv::Mat resize_img; - + cv::Mat resize_img_rec; cv::Mat crop_img; img.copyTo(srcimg); - this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w, + this->resize_op_.Run(img, + resize_img, + this->max_side_len_, + ratio_h, + ratio_w, this->use_tensorrt_); - this->normalize_op_.Run(&resize_img, this->mean_det, this->scale_det, - this->is_scale_); + this->normalize_op_.Run( + &resize_img, this->mean_det, this->scale_det, this->is_scale_); std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); this->permute_op_.Run(&resize_img, input.data()); - TensorVector* real_in = new TensorVector(); if (!real_in) { LOG(ERROR) << "real_in is nullptr,error"; @@ -121,14 +121,15 @@ int GeneralDetectionOp::inference() { for (int i = 0; i < in->size(); ++i) { input_shape = {1, 3, resize_img.rows, resize_img.cols}; - in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); - databuf_size = in_num*sizeof(float); + in_num = std::accumulate( + input_shape.begin(), input_shape.end(), 1, std::multiplies()); + databuf_size = in_num * sizeof(float); databuf_data = MempoolWrapper::instance().malloc(databuf_size); if (!databuf_data) { - LOG(ERROR) << "Malloc failed, size: " << databuf_size; - return -1; + LOG(ERROR) << "Malloc failed, size: " << databuf_size; + return -1; } - memcpy(databuf_data,input.data(),databuf_size); + memcpy(databuf_data, input.data(), databuf_size); databuf_char = reinterpret_cast(databuf_data); paddle::PaddleBuf paddleBuf(databuf_char, databuf_size); paddle::PaddleTensor tensor_in; @@ -143,21 +144,23 @@ int GeneralDetectionOp::inference() { Timer timeline; int64_t start = timeline.TimeStampUS(); timeline.Start(); - + if (InferManager::instance().infer( engine_name().c_str(), real_in, out, batch_size)) { LOG(ERROR) << "(logid=" << log_id << ") Failed do infer in fluid model: " << engine_name().c_str(); return -1; } + delete 
real_in; + std::vector output_shape; - int out_num =0; + int out_num = 0; void* databuf_data_out = NULL; char* databuf_char_out = NULL; size_t databuf_size_out = 0; - //this is special add for PaddleOCR postprecess - int infer_outnum = out->size(); - for (int k = 0;k size(); + for (int k = 0; k < infer_outnum; ++k) { int n2 = out->at(k).shape[2]; int n3 = out->at(k).shape[3]; int n = n2 * n3; @@ -171,17 +174,19 @@ int GeneralDetectionOp::inference() { cbuf[i] = (unsigned char)((out_data[i]) * 255); } - cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data()); - cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data()); + cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data()); + cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data()); const double threshold = this->det_db_thresh_ * 255; const double maxvalue = 255; cv::Mat bit_map; cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); cv::Mat dilation_map; - cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2)); + cv::Mat dila_ele = + cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2)); cv::dilate(bit_map, dilation_map, dila_ele); - boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map, + boxes = post_processor_.BoxesFromBitmap(pred_map, + dilation_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_); @@ -192,25 +197,28 @@ int GeneralDetectionOp::inference() { float wh_ratio = float(crop_img.cols) / float(crop_img.rows); - this->resize_op_rec.Run(crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_); + this->resize_op_rec.Run( + crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_); - this->normalize_op_.Run(&resize_img_rec, this->mean_rec, this->scale_rec, - this->is_scale_); + this->normalize_op_.Run( + &resize_img_rec, this->mean_rec, this->scale_rec, this->is_scale_); - std::vector output_rec(1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f); + std::vector output_rec( + 1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f); 
this->permute_op_.Run(&resize_img_rec, output_rec.data()); // Inference. output_shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols}; - out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - databuf_size_out = out_num*sizeof(float); + out_num = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); + databuf_size_out = out_num * sizeof(float); databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out); if (!databuf_data_out) { - LOG(ERROR) << "Malloc failed, size: " << databuf_size_out; - return -1; + LOG(ERROR) << "Malloc failed, size: " << databuf_size_out; + return -1; } - memcpy(databuf_data_out,output_rec.data(),databuf_size_out); + memcpy(databuf_data_out, output_rec.data(), databuf_size_out); databuf_char_out = reinterpret_cast(databuf_data_out); paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out); paddle::PaddleTensor tensor_out; @@ -221,9 +229,8 @@ int GeneralDetectionOp::inference() { out->push_back(tensor_out); } } - out->erase(out->begin(),out->begin()+infer_outnum); + out->erase(out->begin(), out->begin() + infer_outnum); - int64_t end = timeline.TimeStampUS(); CopyBlobInfo(input_blob, output_blob); AddBlobInfo(output_blob, start); @@ -231,68 +238,62 @@ int GeneralDetectionOp::inference() { return 0; } -cv::Mat GeneralDetectionOp::Base2Mat(std::string &base64_data) -{ - cv::Mat img; - std::string s_mat; - s_mat = base64Decode(base64_data.data(), base64_data.size()); - std::vector base64_img(s_mat.begin(), s_mat.end()); - img = cv::imdecode(base64_img, cv::IMREAD_COLOR);//CV_LOAD_IMAGE_COLOR - return img; +cv::Mat GeneralDetectionOp::Base2Mat(std::string& base64_data) { + cv::Mat img; + std::string s_mat; + s_mat = base64Decode(base64_data.data(), base64_data.size()); + std::vector base64_img(s_mat.begin(), s_mat.end()); + img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR + return img; } -std::string 
GeneralDetectionOp::base64Decode(const char* Data, int DataByte) -{ - - const char DecodeTable[] = - { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, - 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, - 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' - }; - - std::string strDecode; - int nValue; - int i = 0; - while (i < DataByte) - { - if (*Data != '\r' && *Data != '\n') - { - nValue = DecodeTable[*Data++] << 18; - nValue += DecodeTable[*Data++] << 12; - strDecode += (nValue & 0x00FF0000) >> 16; - if (*Data != '=') - { - nValue += DecodeTable[*Data++] << 6; - strDecode += (nValue & 0x0000FF00) >> 8; - if (*Data != '=') - { - nValue += DecodeTable[*Data++]; - strDecode += nValue & 0x000000FF; - } - } - i += 4; - } - else// 回车换行,跳过 - { - Data++; - i++; - } - } - return strDecode; +std::string GeneralDetectionOp::base64Decode(const char* Data, int DataByte) { + const char + DecodeTable[] = + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, // 'a'-'z' + }; + + std::string strDecode; + int nValue; + int i = 0; + while (i < DataByte) { + if (*Data != '\r' && *Data != '\n') { + nValue = DecodeTable[*Data++] << 18; + nValue += DecodeTable[*Data++] << 12; + strDecode += (nValue & 0x00FF0000) >> 16; + if 
(*Data != '=') { + nValue += DecodeTable[*Data++] << 6; + strDecode += (nValue & 0x0000FF00) >> 8; + if (*Data != '=') { + nValue += DecodeTable[*Data++]; + strDecode += nValue & 0x000000FF; + } + } + i += 4; + } else // 回车换行,跳过 + { + Data++; + i++; + } + } + return strDecode; } -cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box) { +cv::Mat GeneralDetectionOp::GetRotateCropImage( + const cv::Mat& srcimage, std::vector> box) { cv::Mat image; srcimage.copyTo(image); std::vector> points = box; @@ -332,7 +333,9 @@ cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage, cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); cv::Mat dst_img; - cv::warpPerspective(img_crop, dst_img, M, + cv::warpPerspective(img_crop, + dst_img, + M, cv::Size(img_crop_width, img_crop_height), cv::BORDER_REPLICATE); @@ -350,4 +353,4 @@ DEFINE_OP(GeneralDetectionOp); } // namespace serving } // namespace paddle_serving -} // namespace baidu \ No newline at end of file +} // namespace baidu diff --git a/core/general-server/op/general_reader_op.cpp b/core/general-server/op/general_reader_op.cpp index 4b4e25cb075f56449359f7af0c064fcb83c2dd07..3e1091dd844f0afd71c8556586f82aafc42c5097 100644 --- a/core/general-server/op/general_reader_op.cpp +++ b/core/general-server/op/general_reader_op.cpp @@ -77,9 +77,6 @@ int GeneralReaderOp::inference() { uint64_t log_id = req->log_id(); int input_var_num = 0; - std::vector elem_type; - std::vector elem_size; - std::vector databuf_size; GeneralBlob *res = mutable_data(); if (!res) { @@ -119,40 +116,44 @@ int GeneralReaderOp::inference() { } */ // package tensor - - elem_type.resize(var_num); - elem_size.resize(var_num); - databuf_size.resize(var_num); // prepare basic information for input // specify the memory needed for output tensor_vector // fill the data into output general_blob int data_len = 0; + int64_t elem_type = 0; + int64_t elem_size = 0; + int64_t databuf_size = 0; for (int i = 0; i < 
var_num; ++i) { - paddle::PaddleTensor lod_tensor; + paddle::PaddleTensor paddleTensor; const Tensor &tensor = req->insts(0).tensor_array(i); data_len = 0; - elem_type[i] = tensor.elem_type(); - VLOG(2) << "var[" << i << "] has elem type: " << elem_type[i]; - if (elem_type[i] == P_INT64) { // int64 - elem_size[i] = sizeof(int64_t); - lod_tensor.dtype = paddle::PaddleDType::INT64; + elem_type = 0; + elem_size = 0; + databuf_size = 0; + elem_type = tensor.elem_type(); + VLOG(2) << "var[" << i << "] has elem type: " << elem_type; + if (elem_type == P_INT64) { // int64 + elem_size = sizeof(int64_t); + paddleTensor.dtype = paddle::PaddleDType::INT64; data_len = tensor.int64_data_size(); - } else if (elem_type[i] == P_FLOAT32) { - elem_size[i] = sizeof(float); - lod_tensor.dtype = paddle::PaddleDType::FLOAT32; + } else if (elem_type == P_FLOAT32) { + elem_size = sizeof(float); + paddleTensor.dtype = paddle::PaddleDType::FLOAT32; data_len = tensor.float_data_size(); - } else if (elem_type[i] == P_INT32) { - elem_size[i] = sizeof(int32_t); - lod_tensor.dtype = paddle::PaddleDType::INT32; + } else if (elem_type == P_INT32) { + elem_size = sizeof(int32_t); + paddleTensor.dtype = paddle::PaddleDType::INT32; data_len = tensor.int_data_size(); - } else if (elem_type[i] == P_STRING) { + } else if (elem_type == P_STRING) { // use paddle::PaddleDType::UINT8 as for String. 
- elem_size[i] = sizeof(uint8_t); - lod_tensor.dtype = paddle::PaddleDType::UINT8; + elem_size = sizeof(char); + paddleTensor.dtype = paddle::PaddleDType::UINT8; // this is for vector, cause the databuf_size != // vector.size()*sizeof(char); + // data_len should be +1 cause '\0' + // now only support single string for (int idx = 0; idx < tensor.data_size(); idx++) { - data_len += tensor.data()[idx].length(); + data_len += tensor.data()[idx].length() + 1; } } // implement lod tensor here @@ -160,29 +161,29 @@ int GeneralReaderOp::inference() { // TODO(HexToString): support 2-D lod if (tensor.lod_size() > 0) { VLOG(2) << "(logid=" << log_id << ") var[" << i << "] is lod_tensor"; - lod_tensor.lod.resize(1); + paddleTensor.lod.resize(1); for (int k = 0; k < tensor.lod_size(); ++k) { - lod_tensor.lod[0].push_back(tensor.lod(k)); + paddleTensor.lod[0].push_back(tensor.lod(k)); } } for (int k = 0; k < tensor.shape_size(); ++k) { int dim = tensor.shape(k); VLOG(2) << "(logid=" << log_id << ") shape for var[" << i << "]: " << dim; - lod_tensor.shape.push_back(dim); + paddleTensor.shape.push_back(dim); } - lod_tensor.name = model_config->_feed_name[i]; - out->push_back(lod_tensor); + paddleTensor.name = model_config->_feed_name[i]; + out->push_back(paddleTensor); VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i << "]: " << data_len; - databuf_size[i] = data_len * elem_size[i]; - out->at(i).data.Resize(data_len * elem_size[i]); + databuf_size = data_len * elem_size; + out->at(i).data.Resize(databuf_size); if (out->at(i).lod.size() > 0) { VLOG(2) << "(logid=" << log_id << ") var[" << i << "] has lod_tensor and len=" << out->at(i).lod[0].back(); } - if (elem_type[i] == P_INT64) { + if (elem_type == P_INT64) { int64_t *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.int64_data(0); @@ -190,14 +191,14 @@ int GeneralReaderOp::inference() { LOG(ERROR) << "dst_ptr is nullptr"; 
return -1; } - memcpy(dst_ptr, tensor.int64_data().data(), databuf_size[i]); + memcpy(dst_ptr, tensor.int64_data().data(), databuf_size); /* int elem_num = tensor.int64_data_size(); for (int k = 0; k < elem_num; ++k) { dst_ptr[k] = tensor.int64_data(k); } */ - } else if (elem_type[i] == P_FLOAT32) { + } else if (elem_type == P_FLOAT32) { float *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.float_data(0); @@ -205,12 +206,12 @@ int GeneralReaderOp::inference() { LOG(ERROR) << "dst_ptr is nullptr"; return -1; } - memcpy(dst_ptr, tensor.float_data().data(), databuf_size[i]); + memcpy(dst_ptr, tensor.float_data().data(), databuf_size); /*int elem_num = tensor.float_data_size(); for (int k = 0; k < elem_num; ++k) { dst_ptr[k] = tensor.float_data(k); }*/ - } else if (elem_type[i] == P_INT32) { + } else if (elem_type == P_INT32) { int32_t *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.int_data(0); @@ -218,15 +219,9 @@ int GeneralReaderOp::inference() { LOG(ERROR) << "dst_ptr is nullptr"; return -1; } - memcpy(dst_ptr, tensor.int_data().data(), databuf_size[i]); - /* - int elem_num = tensor.int_data_size(); - for (int k = 0; k < elem_num; ++k) { - dst_ptr[k] = tensor.int_data(k); - } - */ - } else if (elem_type[i] == P_STRING) { - std::string *dst_ptr = static_cast(out->at(i).data.data()); + memcpy(dst_ptr, tensor.int_data().data(), databuf_size); + } else if (elem_type == P_STRING) { + char *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.data(0); if (!dst_ptr) { @@ -234,8 +229,12 @@ int GeneralReaderOp::inference() { return -1; } int elem_num = tensor.data_size(); + int offset = 0; for (int k = 0; k < elem_num; ++k) { - dst_ptr[k] = tensor.data(k); + memcpy(dst_ptr + offset, + tensor.data(k).c_str(), + 
strlen(tensor.data(k).c_str()) + 1); + offset += strlen(tensor.data(k).c_str()) + 1; } } } diff --git a/doc/LOD.md b/doc/LOD.md new file mode 100644 index 0000000000000000000000000000000000000000..4e20c495334e1c8609c3a3d480b2540ae9811ad4 --- /dev/null +++ b/doc/LOD.md @@ -0,0 +1,32 @@ +# Lod Introduction + +(English|[简体中文](LOD_CN.md)) + +## Principle + +LoD(Level-of-Detail) Tensor is an advanced feature of paddle and an extension of tensor. LoD Tensor improves training efficiency by sacrificing flexibility. + +**Notice:** For most users, there is no need to pay attention to the usage of LoD Tensor. Currently, serving only supports the usage of one-dimensional LOD. + + +## Use + +**Prerequisite:** Your prediction model needs to support variable length tensor input. + + +Take the visual task as an example. In the visual task, we often need to process video and image. These elements are high-dimensional objects. +Suppose that a mini-batch contains three videos, which contain three frames, one frame and two frames respectively. +If each frame has the same size: 640x480, the mini batch can be expressed as: +``` +3 1 2 +口口口 口 口口 +``` +The size of the bottom tensor is (3 + 1 + 2) x640x480, and each 口 represents a 640x480 image. + +Then, the shape of tensor is [6,640,480], lod=[0,3,4,6]. + +Where 0 is the starting value and 3-0 = 3; 4-3=1; 6-4 = 2, these three values just represent your variable length information. + +The last element 6 in LOD should be equal to the total length of the first dimension in shape. + +The variable length information recorded in LOD and the first dimension information of shape in tensor should be aligned in the above way.
diff --git a/doc/LOD_CN.md b/doc/LOD_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..ff04bd3d7e6fcda2e30e48dd10c985b97b461685 --- /dev/null +++ b/doc/LOD_CN.md @@ -0,0 +1,29 @@ +# Lod字段说明 + +(简体中文|[English](LOD.md)) + +## 概念 + +LoD(Level-of-Detail) Tensor是Paddle的高级特性,是对Tensor的一种扩充。LoDTensor通过牺牲灵活性来提升训练的效率。 + +**注:** 对于大部分用户来说,无需关注LoDTensor的用法,目前Serving中仅支持一维Lod的用法。 + + +## 使用 + +**前提:** 首先您的预测模型需要支持变长Tensor的输入。 + + +以视觉任务为例,在视觉任务中,时常需要处理视频和图像这些元素是高维的对象,假设现存的一个mini-batch包含3个视频,分别有3个,1个和2个帧。 +每个帧都具有相同大小:640x480,则这个mini-batch可以被表示为: +``` +3 1 2 +口口口 口 口口 +``` +最底层tensor大小为(3+1+2)x640x480,每一个 口 表示一个640x480的图像。 + +那么此时,Tensor的shape为[6,640,480],lod=[0,3,4,6]. + +其中0为起始值,3-0=3;4-3=1;6-4=2,这三个值正好表示您的变长信息,lod中的最后一个元素6,应等于shape中第一维度的总长度。 + +lod中记录的变长信息与Tensor中shape的第一维度的信息应按照上述方式对齐。 diff --git a/doc/LOW_PRECISION_DEPLOYMENT.md b/doc/LOW_PRECISION_DEPLOYMENT.md index cb08a88f2f3b2435f3b270575652217b1d956fbf..fb3bd208f2f52399afff1f96228543685f3cf389 100644 --- a/doc/LOW_PRECISION_DEPLOYMENT.md +++ b/doc/LOW_PRECISION_DEPLOYMENT.md @@ -17,7 +17,7 @@ python -m paddle_serving_client.convert --dirname ResNet50_quant ``` Start RPC service, specify the GPU id and precision mode ``` -python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0 --use_gpu --use_trt --precision int8 +python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0 --use_trt --precision int8 ``` Request the serving service with Client ``` @@ -27,7 +27,7 @@ from paddle_serving_app.reader import RGB2BGR, Transpose, Div, Normalize client = Client() client.load_client_config( - "resnet_v2_50_imagenet_client/serving_client_conf.prototxt") + "serving_client/serving_client_conf.prototxt") client.connect(["127.0.0.1:9393"]) seq = Sequential([ @@ -37,11 +37,11 @@ seq = Sequential([ image_file = "daisy.jpg" img = seq(image_file) -fetch_map = client.predict(feed={"image": img}, fetch=["score"]) -print(fetch_map["score"].reshape(-1)) +fetch_map = 
client.predict(feed={"image": img}, fetch=["save_infer_model/scale_0.tmp_0"]) +print(fetch_map["save_infer_model/scale_0.tmp_0"].reshape(-1)) ``` ## Reference * [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) * [Deploy the quantized model Using Paddle Inference on Intel CPU](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_x86_cpu_int8.html) -* [Deploy the quantized model Using Paddle Inference on Nvidia GPU](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html) \ No newline at end of file +* [Deploy the quantized model Using Paddle Inference on Nvidia GPU](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html) diff --git a/doc/LOW_PRECISION_DEPLOYMENT_CN.md b/doc/LOW_PRECISION_DEPLOYMENT_CN.md index e543db94396eecbe64a61d7a9362369d02ab42de..f77f4e241f3f4b95574d22b9ca55788b5abc968e 100644 --- a/doc/LOW_PRECISION_DEPLOYMENT_CN.md +++ b/doc/LOW_PRECISION_DEPLOYMENT_CN.md @@ -16,7 +16,7 @@ python -m paddle_serving_client.convert --dirname ResNet50_quant ``` 启动rpc服务, 设定所选GPU id、部署模型精度 ``` -python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0 --use_gpu --use_trt --precision int8 +python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0 --use_trt --precision int8 ``` 使用client进行请求 ``` @@ -43,4 +43,4 @@ print(fetch_map["score"].reshape(-1)) ## 参考文档 * [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) * PaddleInference Intel CPU部署量化模型[文档](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_x86_cpu_int8.html) -* PaddleInference NV GPU部署量化模型[文档](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html) \ No newline at end of file +* PaddleInference NV GPU部署量化模型[文档](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html) diff --git a/python/examples/low_precision/resnet50/README.md b/python/examples/low_precision/resnet50/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..9e1ff16c676b067437183e6e19446e8a526feed5 --- /dev/null +++ b/python/examples/low_precision/resnet50/README.md @@ -0,0 +1,28 @@ +# resnet50 int8 example +(English|[简体中文](./README_CN.md)) + +## Obtain the quantized model through PaddleSlim tool +To train low-precision models, please refer to [PaddleSlim](https://paddleslim.readthedocs.io/zh_CN/latest/tutorials/quant/overview.html). + +## Deploy the quantized model from PaddleSlim using Paddle Serving with Nvidia TensorRT int8 mode + +Firstly, download the [Resnet50 int8 model](https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/resnet50/ResNet50_quant.tar.gz) and convert it to Paddle Serving's saved model format. +``` +wget https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/resnet50/ResNet50_quant.tar.gz +tar zxvf ResNet50_quant.tar.gz + +python -m paddle_serving_client.convert --dirname ResNet50_quant +``` +Start RPC service, specify the GPU id and precision mode +``` +python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0 --use_trt --precision int8 +``` +Request the serving service with Client +``` +python resnet50_client.py +``` + +## Reference +* [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) +* [Deploy the quantized model Using Paddle Inference on Intel CPU](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_x86_cpu_int8.html) +* [Deploy the quantized model Using Paddle Inference on Nvidia GPU](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html) diff --git a/python/examples/low_precision/resnet50/README_CN.md b/python/examples/low_precision/resnet50/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..1c1a3be1de1690e9736d994016ac05cfba12bcab --- /dev/null +++ b/python/examples/low_precision/resnet50/README_CN.md @@ -0,0 +1,27 @@ +# resnet50 int8示例 +(简体中文|[English](./README.md)) + +## 通过PaddleSlim量化生成低精度模型
+详细见[PaddleSlim量化](https://paddleslim.readthedocs.io/zh_CN/latest/tutorials/quant/overview.html) + +## 使用TensorRT int8加载PaddleSlim Int8量化模型进行部署 +首先下载Resnet50 [PaddleSlim量化模型](https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/resnet50/ResNet50_quant.tar.gz),并转换为Paddle Serving支持的部署模型格式。 +``` +wget https://paddle-inference-dist.bj.bcebos.com/inference_demo/python/resnet50/ResNet50_quant.tar.gz +tar zxvf ResNet50_quant.tar.gz + +python -m paddle_serving_client.convert --dirname ResNet50_quant +``` +启动rpc服务, 设定所选GPU id、部署模型精度 +``` +python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0 --use_trt --precision int8 +``` +使用client进行请求 +``` +python resnet50_client.py +``` + +## 参考文档 +* [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) +* PaddleInference Intel CPU部署量化模型[文档](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_x86_cpu_int8.html) +* PaddleInference NV GPU部署量化模型[文档](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html) diff --git a/python/examples/low_precision/resnet50/daisy.jpg b/python/examples/low_precision/resnet50/daisy.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7edeca63e5f32e68550ef720d81f59df58a8eabc Binary files /dev/null and b/python/examples/low_precision/resnet50/daisy.jpg differ diff --git a/python/examples/low_precision/resnet50/resnet50_client.py b/python/examples/low_precision/resnet50/resnet50_client.py new file mode 100644 index 0000000000000000000000000000000000000000..999b143c8a9aaf42784cbe225a8417b86a054c64 --- /dev/null +++ b/python/examples/low_precision/resnet50/resnet50_client.py @@ -0,0 +1,32 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import Sequential, File2Image, Resize, CenterCrop +from paddle_serving_app.reader import RGB2BGR, Transpose, Div, Normalize + +client = Client() +client.load_client_config( + "serving_client/serving_client_conf.prototxt") +client.connect(["127.0.0.1:9393"])  # must match the --port given to paddle_serving_server.serve (9393 in the README) + +seq = Sequential([ + File2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)), + Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True) +]) + +image_file = "daisy.jpg" +img = seq(image_file) +fetch_map = client.predict(feed={"image": img}, fetch=["save_infer_model/scale_0.tmp_0"]) +print(fetch_map["save_infer_model/scale_0.tmp_0"].reshape(-1)) diff --git a/python/paddle_serving_server/server.py b/python/paddle_serving_server/server.py index 079ccde87b9bfcdea7ac94781cc90284b0faf4ae..34bf66f9ba73709dd5dfe9c34158ac0fd9a2d4b9 100755 --- a/python/paddle_serving_server/server.py +++ b/python/paddle_serving_server/server.py @@ -386,8 +386,6 @@ class Server(object): return if not os.path.exists(self.server_path): - os.system("touch {}/{}.is_download".format(self.module_path, - folder_name)) print('Frist time run, downloading PaddleServing components ...') r = os.system('wget ' + bin_url + ' --no-check-certificate') @@ -403,9 +401,10 @@ class Server(object): tar = tarfile.open(tar_name) tar.extractall() tar.close() + open(download_flag, "a").close() except: - if os.path.exists(exe_path): - os.remove(exe_path) + if os.path.exists(self.server_path): + os.remove(self.server_path) raise SystemExit(
'Decompressing failed, please check your permission of {} or disk space left.' .format(self.module_path)) diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py index 7ea8858d2b47c1c20226c4f53805f3aa2fd75643..0afa3872a82a3f140b811a3bb3a40f0f7bdd373a 100644 --- a/python/pipeline/pipeline_server.py +++ b/python/pipeline/pipeline_server.py @@ -56,7 +56,6 @@ class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer): resp = pipeline_service_pb2.Response() resp.err_no = channel.ChannelDataErrcode.NO_SERVICE.value resp.err_msg = "Failed to inference: Service name error." - resp.result = "" return resp resp = self._dag_executor.call(request) return resp diff --git a/tools/scripts/ipipe_py3.sh b/tools/scripts/ipipe_py3.sh index 9ae4012a245910b349d3d02471a79c038e494f73..d6b3193e720792cc81422011a69e3950ad888d4f 100644 --- a/tools/scripts/ipipe_py3.sh +++ b/tools/scripts/ipipe_py3.sh @@ -323,7 +323,7 @@ function bert_rpc_cpu() { link_data ${data_dir} sed -i 's/8860/8861/g' bert_client.py python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8861 > ${dir}server_log.txt 2>&1 & - check_result server 3 + check_result server 5 cp data-c.txt.1 data-c.txt head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1 check_result client "bert_CPU_RPC server test completed" @@ -338,7 +338,7 @@ function pipeline_imagenet() { data_dir=${data}imagenet/ link_data ${data_dir} python3.6 resnet50_web_service.py > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 8 nvidia-smi python3.6 pipeline_rpc_client.py > ${dir}client_log.txt 2>&1 check_result client "pipeline_imagenet_GPU_RPC server test completed" @@ -355,7 +355,7 @@ function ResNet50_rpc() { link_data ${data_dir} sed -i 's/9696/8863/g' resnet50_rpc_client.py python3.6 -m paddle_serving_server.serve --model ResNet50_vd_model --port 8863 --gpu_ids 0 > ${dir}server_log.txt 2>&1 & - 
check_result server 5 + check_result server 8 nvidia-smi python3.6 resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1 check_result client "ResNet50_GPU_RPC server test completed" @@ -372,7 +372,7 @@ function ResNet101_rpc() { link_data ${data_dir} sed -i "22cclient.connect(['127.0.0.1:8864'])" image_rpc_client.py python3.6 -m paddle_serving_server.serve --model ResNet101_vd_model --port 8864 --gpu_ids 0 > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 8 nvidia-smi python3.6 image_rpc_client.py ResNet101_vd_client_config/serving_client_conf.prototxt > ${dir}client_log.txt 2>&1 check_result client "ResNet101_GPU_RPC server test completed" @@ -482,7 +482,7 @@ function cascade_rcnn_rpc() { link_data ${data_dir} sed -i "s/9292/8879/g" test_client.py python3.6 -m paddle_serving_server.serve --model serving_server --port 8879 --gpu_ids 0 --thread 2 > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 8 nvidia-smi python3.6 test_client.py > ${dir}client_log.txt 2>&1 nvidia-smi @@ -499,7 +499,7 @@ function deeplabv3_rpc() { link_data ${data_dir} sed -i "s/9494/8880/g" deeplabv3_client.py python3.6 -m paddle_serving_server.serve --model deeplabv3_server --gpu_ids 0 --port 8880 --thread 2 > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 10 nvidia-smi python3.6 deeplabv3_client.py > ${dir}client_log.txt 2>&1 nvidia-smi @@ -516,7 +516,7 @@ function mobilenet_rpc() { tar xf mobilenet_v2_imagenet.tar.gz sed -i "s/9393/8881/g" mobilenet_tutorial.py python3.6 -m paddle_serving_server.serve --model mobilenet_v2_imagenet_model --gpu_ids 0 --port 8881 > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 8 nvidia-smi python3.6 mobilenet_tutorial.py > ${dir}client_log.txt 2>&1 nvidia-smi @@ -533,7 +533,7 @@ function unet_rpc() { link_data ${data_dir} sed -i "s/9494/8882/g" seg_client.py python3.6 -m paddle_serving_server.serve --model 
unet_model --gpu_ids 0 --port 8882 > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 8 nvidia-smi python3.6 seg_client.py > ${dir}client_log.txt 2>&1 nvidia-smi @@ -599,7 +599,7 @@ function criteo_ctr_rpc_gpu() { link_data ${data_dir} sed -i "s/8885/8886/g" test_client.py python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8886 --gpu_ids 0 > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 8 nvidia-smi python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0 > ${dir}client_log.txt 2>&1 nvidia-smi @@ -617,7 +617,7 @@ function yolov4_rpc_gpu() { sed -i "s/9393/8887/g" test_client.py python3.6 -m paddle_serving_server.serve --model yolov4_model --port 8887 --gpu_ids 0 > ${dir}server_log.txt 2>&1 & nvidia-smi - check_result server 5 + check_result server 8 python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1 nvidia-smi check_result client "yolov4_GPU_RPC server test completed" @@ -634,7 +634,7 @@ function senta_rpc_cpu() { sed -i "s/9393/8887/g" test_client.py python3.6 -m paddle_serving_server.serve --model yolov4_model --port 8887 --gpu_ids 0 > ${dir}server_log.txt 2>&1 & nvidia-smi - check_result server 5 + check_result server 8 python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1 nvidia-smi check_result client "senta_GPU_RPC server test completed" @@ -724,7 +724,7 @@ function bert_http() { cp vocab.txt.1 vocab.txt export CUDA_VISIBLE_DEVICES=0 python3.6 bert_web_service.py bert_seq128_model/ 8878 > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 8 curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "hello"}], "fetch":["pooled_output"]}' http://127.0.0.1:8878/bert/prediction > ${dir}client_log.txt 2>&1 check_result client "bert_GPU_HTTP server test completed" kill_server_process @@ -762,7 +762,7 @@ function grpc_yolov4() { link_data ${data_dir} echo -e 
"${GREEN_COLOR}grpc_impl_example_yolov4_GPU_gRPC server started${RES}" python3.6 -m paddle_serving_server.serve --model yolov4_model --port 9393 --gpu_ids 0 --use_multilang > ${dir}server_log.txt 2>&1 & - check_result server 5 + check_result server 10 echo -e "${GREEN_COLOR}grpc_impl_example_yolov4_GPU_gRPC client started${RES}" python3.6 test_client.py 000000570688.jpg > ${dir}client_log.txt 2>&1 check_result client "grpc_yolov4_GPU_GRPC server test completed"