diff --git a/core/general-server/op/general_detection_op.cpp b/core/general-server/op/general_detection_op.cpp index f02465e0a70ce5ee86f71f8c194df34e545269df..7c33ec8efa8c6e89a7a778def6342415d19ffa94 100755 --- a/core/general-server/op/general_detection_op.cpp +++ b/core/general-server/op/general_detection_op.cpp @@ -22,7 +22,6 @@ #include "core/predictor/framework/resource.h" #include "core/util/include/timer.h" - /* #include "opencv2/imgcodecs/legacy/constants_c.h" #include "opencv2/imgproc/types_c.h" @@ -52,18 +51,18 @@ int GeneralDetectionOp::inference() { } const std::string pre_name = pre_node_names[0]; - const GeneralBlob *input_blob = get_depend_argument(pre_name); + const GeneralBlob* input_blob = get_depend_argument(pre_name); if (!input_blob) { LOG(ERROR) << "input_blob is nullptr,error"; - return -1; + return -1; } uint64_t log_id = input_blob->GetLogId(); VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name; - GeneralBlob *output_blob = mutable_data(); + GeneralBlob* output_blob = mutable_data(); if (!output_blob) { LOG(ERROR) << "output_blob is nullptr,error"; - return -1; + return -1; } output_blob->SetLogId(log_id); @@ -73,7 +72,7 @@ int GeneralDetectionOp::inference() { return -1; } - const TensorVector *in = &input_blob->tensor_vector; + const TensorVector* in = &input_blob->tensor_vector; TensorVector* out = &output_blob->tensor_vector; int batch_size = input_blob->_batch_size; @@ -81,38 +80,39 @@ int GeneralDetectionOp::inference() { output_blob->_batch_size = batch_size; - VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size; - std::vector input_shape; - int in_num =0; + int in_num = 0; void* databuf_data = NULL; char* databuf_char = NULL; size_t databuf_size = 0; + // now only support single string + char* total_input_ptr = static_cast(in->at(0).data.data()); + std::string base64str = total_input_ptr; - std::string* input_ptr = static_cast(in->at(0).data.data()); - std::string base64str = input_ptr[0]; float ratio_h{}; float ratio_w{}; - cv::Mat img = Base2Mat(base64str); cv::Mat srcimg; cv::Mat resize_img; - + cv::Mat resize_img_rec; cv::Mat crop_img; img.copyTo(srcimg); - this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w, + this->resize_op_.Run(img, + resize_img, + this->max_side_len_, + ratio_h, + ratio_w, this->use_tensorrt_); - this->normalize_op_.Run(&resize_img, this->mean_det, this->scale_det, - this->is_scale_); + this->normalize_op_.Run( + &resize_img, this->mean_det, this->scale_det, this->is_scale_); std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); this->permute_op_.Run(&resize_img, input.data()); - TensorVector* real_in = new TensorVector(); if (!real_in) { LOG(ERROR) << "real_in is nullptr,error"; @@ -121,14 +121,15 @@ int GeneralDetectionOp::inference() { for (int i = 0; i < in->size(); ++i) { input_shape = {1, 3, resize_img.rows, resize_img.cols}; - in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); - databuf_size = in_num*sizeof(float); + in_num = std::accumulate( + input_shape.begin(), input_shape.end(), 1, std::multiplies()); + databuf_size = in_num * sizeof(float); databuf_data = MempoolWrapper::instance().malloc(databuf_size); if (!databuf_data) { - LOG(ERROR) << "Malloc failed, size: " << databuf_size; - return -1; + LOG(ERROR) << "Malloc failed, size: " << databuf_size; + return -1; } - memcpy(databuf_data,input.data(),databuf_size); + memcpy(databuf_data, input.data(), databuf_size); databuf_char = reinterpret_cast(databuf_data); paddle::PaddleBuf paddleBuf(databuf_char, databuf_size); paddle::PaddleTensor tensor_in; @@ -143,21 +144,23 @@ int GeneralDetectionOp::inference() { Timer timeline; int64_t start = timeline.TimeStampUS(); timeline.Start(); - + if (InferManager::instance().infer( engine_name().c_str(), real_in, out, batch_size)) { LOG(ERROR) << "(logid=" << log_id << ") Failed do infer in fluid model: " << engine_name().c_str(); return -1; } + delete real_in; + std::vector output_shape; - int out_num =0; + int out_num = 0; void* databuf_data_out = NULL; char* databuf_char_out = NULL; size_t databuf_size_out = 0; - //this is special add for PaddleOCR postprecess - int infer_outnum = out->size(); - for (int k = 0;k size(); + for (int k = 0; k < infer_outnum; ++k) { int n2 = out->at(k).shape[2]; int n3 = out->at(k).shape[3]; int n = n2 * n3; @@ -171,17 +174,19 @@ int GeneralDetectionOp::inference() { cbuf[i] = (unsigned char)((out_data[i]) * 255); } - cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data()); - cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data()); + cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data()); + cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data()); const double threshold = this->det_db_thresh_ * 255; const double maxvalue = 255; cv::Mat bit_map; cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); cv::Mat dilation_map; - cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2)); + cv::Mat dila_ele = + cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2)); cv::dilate(bit_map, dilation_map, dila_ele); - boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map, + boxes = post_processor_.BoxesFromBitmap(pred_map, + dilation_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_); @@ -192,25 +197,28 @@ int GeneralDetectionOp::inference() { float wh_ratio = float(crop_img.cols) / float(crop_img.rows); - this->resize_op_rec.Run(crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_); + this->resize_op_rec.Run( + crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_); - this->normalize_op_.Run(&resize_img_rec, this->mean_rec, this->scale_rec, - this->is_scale_); + this->normalize_op_.Run( + &resize_img_rec, this->mean_rec, this->scale_rec, this->is_scale_); - std::vector output_rec(1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f); + std::vector output_rec( + 1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f); this->permute_op_.Run(&resize_img_rec, output_rec.data()); // Inference. output_shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols}; - out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - databuf_size_out = out_num*sizeof(float); + out_num = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); + databuf_size_out = out_num * sizeof(float); databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out); if (!databuf_data_out) { - LOG(ERROR) << "Malloc failed, size: " << databuf_size_out; - return -1; + LOG(ERROR) << "Malloc failed, size: " << databuf_size_out; + return -1; } - memcpy(databuf_data_out,output_rec.data(),databuf_size_out); + memcpy(databuf_data_out, output_rec.data(), databuf_size_out); databuf_char_out = reinterpret_cast(databuf_data_out); paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out); paddle::PaddleTensor tensor_out; @@ -221,9 +229,8 @@ int GeneralDetectionOp::inference() { out->push_back(tensor_out); } } - out->erase(out->begin(),out->begin()+infer_outnum); + out->erase(out->begin(), out->begin() + infer_outnum); - int64_t end = timeline.TimeStampUS(); CopyBlobInfo(input_blob, output_blob); AddBlobInfo(output_blob, start); @@ -231,68 +238,62 @@ int GeneralDetectionOp::inference() { return 0; } -cv::Mat GeneralDetectionOp::Base2Mat(std::string &base64_data) -{ - cv::Mat img; - std::string s_mat; - s_mat = base64Decode(base64_data.data(), base64_data.size()); - std::vector base64_img(s_mat.begin(), s_mat.end()); - img = cv::imdecode(base64_img, cv::IMREAD_COLOR);//CV_LOAD_IMAGE_COLOR - return img; +cv::Mat GeneralDetectionOp::Base2Mat(std::string& base64_data) { + cv::Mat img; + std::string s_mat; + s_mat = base64Decode(base64_data.data(), base64_data.size()); + std::vector base64_img(s_mat.begin(), s_mat.end()); + img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR + return img; } -std::string GeneralDetectionOp::base64Decode(const char* Data, int DataByte) -{ - - const char DecodeTable[] = - { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, - 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, - 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' - }; - - std::string strDecode; - int nValue; - int i = 0; - while (i < DataByte) - { - if (*Data != '\r' && *Data != '\n') - { - nValue = DecodeTable[*Data++] << 18; - nValue += DecodeTable[*Data++] << 12; - strDecode += (nValue & 0x00FF0000) >> 16; - if (*Data != '=') - { - nValue += DecodeTable[*Data++] << 6; - strDecode += (nValue & 0x0000FF00) >> 8; - if (*Data != '=') - { - nValue += DecodeTable[*Data++]; - strDecode += nValue & 0x000000FF; - } - } - i += 4; - } - else// 回车换行,跳过 - { - Data++; - i++; - } - } - return strDecode; +std::string GeneralDetectionOp::base64Decode(const char* Data, int DataByte) { + const char + DecodeTable[] = + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, // 'a'-'z' + }; + + std::string strDecode; + int nValue; + int i = 0; + while (i < DataByte) { + if (*Data != '\r' && *Data != '\n') { + nValue = DecodeTable[*Data++] << 18; + nValue += DecodeTable[*Data++] << 12; + strDecode += (nValue & 0x00FF0000) >> 16; + if (*Data != '=') { + nValue += DecodeTable[*Data++] << 6; + strDecode += (nValue & 0x0000FF00) >> 8; + if (*Data != '=') { + nValue += DecodeTable[*Data++]; + strDecode += nValue & 0x000000FF; + } + } + i += 4; + } else // 回车换行,跳过 + { + Data++; + i++; + } + } + return strDecode; } -cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage, - std::vector> box) { +cv::Mat GeneralDetectionOp::GetRotateCropImage( + const cv::Mat& srcimage, std::vector> box) { cv::Mat image; srcimage.copyTo(image); std::vector> points = box; @@ -332,7 +333,9 @@ cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage, cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); cv::Mat dst_img; - cv::warpPerspective(img_crop, dst_img, M, + cv::warpPerspective(img_crop, + dst_img, + M, cv::Size(img_crop_width, img_crop_height), cv::BORDER_REPLICATE); @@ -350,4 +353,4 @@ DEFINE_OP(GeneralDetectionOp); } // namespace serving } // namespace paddle_serving -} // namespace baidu \ No newline at end of file +} // namespace baidu diff --git a/core/general-server/op/general_reader_op.cpp b/core/general-server/op/general_reader_op.cpp index 4b4e25cb075f56449359f7af0c064fcb83c2dd07..3e1091dd844f0afd71c8556586f82aafc42c5097 100644 --- a/core/general-server/op/general_reader_op.cpp +++ b/core/general-server/op/general_reader_op.cpp @@ -77,9 +77,6 @@ int GeneralReaderOp::inference() { uint64_t log_id = req->log_id(); int input_var_num = 0; - std::vector elem_type; - std::vector elem_size; - std::vector databuf_size; GeneralBlob *res = mutable_data(); if (!res) { @@ -119,40 +116,44 @@ int GeneralReaderOp::inference() { } */ // package tensor - - elem_type.resize(var_num); - elem_size.resize(var_num); - databuf_size.resize(var_num); // prepare basic information for input // specify the memory needed for output tensor_vector // fill the data into output general_blob int data_len = 0; + int64_t elem_type = 0; + int64_t elem_size = 0; + int64_t databuf_size = 0; for (int i = 0; i < var_num; ++i) { - paddle::PaddleTensor lod_tensor; + paddle::PaddleTensor paddleTensor; const Tensor &tensor = req->insts(0).tensor_array(i); data_len = 0; - elem_type[i] = tensor.elem_type(); - VLOG(2) << "var[" << i << "] has elem type: " << elem_type[i]; - if (elem_type[i] == P_INT64) { // int64 - elem_size[i] = sizeof(int64_t); - lod_tensor.dtype = paddle::PaddleDType::INT64; + elem_type = 0; + elem_size = 0; + databuf_size = 0; + elem_type = tensor.elem_type(); + VLOG(2) << "var[" << i << "] has elem type: " << elem_type; + if (elem_type == P_INT64) { // int64 + elem_size = sizeof(int64_t); + paddleTensor.dtype = paddle::PaddleDType::INT64; data_len = tensor.int64_data_size(); - } else if (elem_type[i] == P_FLOAT32) { - elem_size[i] = sizeof(float); - lod_tensor.dtype = paddle::PaddleDType::FLOAT32; + } else if (elem_type == P_FLOAT32) { + elem_size = sizeof(float); + paddleTensor.dtype = paddle::PaddleDType::FLOAT32; data_len = tensor.float_data_size(); - } else if (elem_type[i] == P_INT32) { - elem_size[i] = sizeof(int32_t); - lod_tensor.dtype = paddle::PaddleDType::INT32; + } else if (elem_type == P_INT32) { + elem_size = sizeof(int32_t); + paddleTensor.dtype = paddle::PaddleDType::INT32; data_len = tensor.int_data_size(); - } else if (elem_type[i] == P_STRING) { + } else if (elem_type == P_STRING) { // use paddle::PaddleDType::UINT8 as for String. - elem_size[i] = sizeof(uint8_t); - lod_tensor.dtype = paddle::PaddleDType::UINT8; + elem_size = sizeof(char); + paddleTensor.dtype = paddle::PaddleDType::UINT8; // this is for vector, cause the databuf_size != // vector.size()*sizeof(char); + // data_len should be +1 cause '\0' + // now only support single string for (int idx = 0; idx < tensor.data_size(); idx++) { - data_len += tensor.data()[idx].length(); + data_len += tensor.data()[idx].length() + 1; } } // implement lod tensor here @@ -160,29 +161,29 @@ int GeneralReaderOp::inference() { // TODO(HexToString): support 2-D lod if (tensor.lod_size() > 0) { VLOG(2) << "(logid=" << log_id << ") var[" << i << "] is lod_tensor"; - lod_tensor.lod.resize(1); + paddleTensor.lod.resize(1); for (int k = 0; k < tensor.lod_size(); ++k) { - lod_tensor.lod[0].push_back(tensor.lod(k)); + paddleTensor.lod[0].push_back(tensor.lod(k)); } } for (int k = 0; k < tensor.shape_size(); ++k) { int dim = tensor.shape(k); VLOG(2) << "(logid=" << log_id << ") shape for var[" << i << "]: " << dim; - lod_tensor.shape.push_back(dim); + paddleTensor.shape.push_back(dim); } - lod_tensor.name = model_config->_feed_name[i]; - out->push_back(lod_tensor); + paddleTensor.name = model_config->_feed_name[i]; + out->push_back(paddleTensor); VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i << "]: " << data_len; - databuf_size[i] = data_len * elem_size[i]; - out->at(i).data.Resize(data_len * elem_size[i]); + databuf_size = data_len * elem_size; + out->at(i).data.Resize(databuf_size); if (out->at(i).lod.size() > 0) { VLOG(2) << "(logid=" << log_id << ") var[" << i << "] has lod_tensor and len=" << out->at(i).lod[0].back(); } - if (elem_type[i] == P_INT64) { + if (elem_type == P_INT64) { int64_t *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.int64_data(0); @@ -190,14 +191,14 @@ int GeneralReaderOp::inference() { LOG(ERROR) << "dst_ptr is nullptr"; return -1; } - memcpy(dst_ptr, tensor.int64_data().data(), databuf_size[i]); + memcpy(dst_ptr, tensor.int64_data().data(), databuf_size); /* int elem_num = tensor.int64_data_size(); for (int k = 0; k < elem_num; ++k) { dst_ptr[k] = tensor.int64_data(k); } */ - } else if (elem_type[i] == P_FLOAT32) { + } else if (elem_type == P_FLOAT32) { float *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.float_data(0); @@ -205,12 +206,12 @@ int GeneralReaderOp::inference() { LOG(ERROR) << "dst_ptr is nullptr"; return -1; } - memcpy(dst_ptr, tensor.float_data().data(), databuf_size[i]); + memcpy(dst_ptr, tensor.float_data().data(), databuf_size); /*int elem_num = tensor.float_data_size(); for (int k = 0; k < elem_num; ++k) { dst_ptr[k] = tensor.float_data(k); }*/ - } else if (elem_type[i] == P_INT32) { + } else if (elem_type == P_INT32) { int32_t *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.int_data(0); @@ -218,15 +219,9 @@ int GeneralReaderOp::inference() { LOG(ERROR) << "dst_ptr is nullptr"; return -1; } - memcpy(dst_ptr, tensor.int_data().data(), databuf_size[i]); - /* - int elem_num = tensor.int_data_size(); - for (int k = 0; k < elem_num; ++k) { - dst_ptr[k] = tensor.int_data(k); - } - */ - } else if (elem_type[i] == P_STRING) { - std::string *dst_ptr = static_cast(out->at(i).data.data()); + memcpy(dst_ptr, tensor.int_data().data(), databuf_size); + } else if (elem_type == P_STRING) { + char *dst_ptr = static_cast(out->at(i).data.data()); VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i << "] is " << tensor.data(0); if (!dst_ptr) { @@ -234,8 +229,12 @@ int GeneralReaderOp::inference() { return -1; } int elem_num = tensor.data_size(); + int offset = 0; for (int k = 0; k < elem_num; ++k) { - dst_ptr[k] = tensor.data(k); + memcpy(dst_ptr + offset, + tensor.data(k).c_str(), + strlen(tensor.data(k).c_str()) + 1); + offset += strlen(tensor.data(k).c_str()) + 1; } } }