Commit 4146a043 authored by H HexToString

fix ocr core dump

Parent d4a7b2a0
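This commit touches GeneralDetectionOp and GeneralReaderOp, alongside formatting cleanups. The crash it fixes comes from how P_STRING feed data was stored: the old reader sized the output databuf for the raw characters only (data_len += length, elem_size[i] = sizeof(uint8_t)) and then assigned std::string objects into that raw buffer (dst_ptr[k] = tensor.data(k)), corrupting memory. The patched reader instead copies each string's bytes plus a terminating '\0' into the buffer (data_len += length + 1, memcpy of c_str()), and GeneralDetectionOp's read of the incoming blob (in->at(0)) is adjusted to match. Below is a minimal sketch of that packing convention; PackStrings is an illustrative name, not a function in the repo.

```cpp
#include <cstring>
#include <string>
#include <vector>

// Pack strings into one flat char buffer, each terminated by '\0', mirroring
// how the patched reader sizes (length + 1 per string) and fills the UINT8
// databuf for P_STRING inputs. Assumes no embedded '\0' inside the strings.
std::vector<char> PackStrings(const std::vector<std::string>& inputs) {
  size_t total = 0;
  for (const auto& s : inputs) total += s.size() + 1;  // +1 for the '\0'
  std::vector<char> buf(total);
  size_t offset = 0;
  for (const auto& s : inputs) {
    std::memcpy(buf.data() + offset, s.c_str(), s.size() + 1);  // copies '\0'
    offset += s.size() + 1;
  }
  return buf;
}
```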
@@ -22,7 +22,6 @@
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
/*
#include "opencv2/imgcodecs/legacy/constants_c.h"
#include "opencv2/imgproc/types_c.h"
@@ -52,18 +51,18 @@ int GeneralDetectionOp::inference() {
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
const GeneralBlob* input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
GeneralBlob* output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
return -1;
}
output_blob->SetLogId(log_id);
@@ -73,7 +72,7 @@ int GeneralDetectionOp::inference() {
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
const TensorVector* in = &input_blob->tensor_vector;
TensorVector* out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
@@ -81,38 +80,39 @@ int GeneralDetectionOp::inference() {
output_blob->_batch_size = batch_size;
VLOG(2) << "(logid=" << log_id << ") infer batch size: " << batch_size;
std::vector<int> input_shape;
int in_num =0;
int in_num = 0;
void* databuf_data = NULL;
char* databuf_char = NULL;
size_t databuf_size = 0;
// now only supports a single string
char* total_input_ptr = static_cast<char*>(in->at(0).data.data());
std::string base64str = total_input_ptr;
std::string* input_ptr = static_cast<std::string*>(in->at(0).data.data());
std::string base64str = input_ptr[0];
float ratio_h{};
float ratio_w{};
cv::Mat img = Base2Mat(base64str);
cv::Mat srcimg;
cv::Mat resize_img;
cv::Mat resize_img_rec;
cv::Mat crop_img;
img.copyTo(srcimg);
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
this->resize_op_.Run(img,
resize_img,
this->max_side_len_,
ratio_h,
ratio_w,
this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_det, this->scale_det,
this->is_scale_);
this->normalize_op_.Run(
&resize_img, this->mean_det, this->scale_det, this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
TensorVector* real_in = new TensorVector();
if (!real_in) {
LOG(ERROR) << "real_in is nullptr,error";
@@ -121,14 +121,15 @@ int GeneralDetectionOp::inference() {
for (int i = 0; i < in->size(); ++i) {
input_shape = {1, 3, resize_img.rows, resize_img.cols};
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
databuf_size = in_num*sizeof(float);
in_num = std::accumulate(
input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
databuf_size = in_num * sizeof(float);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
memcpy(databuf_data,input.data(),databuf_size);
memcpy(databuf_data, input.data(), databuf_size);
databuf_char = reinterpret_cast<char*>(databuf_data);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
paddle::PaddleTensor tensor_in;
@@ -143,21 +144,23 @@ int GeneralDetectionOp::inference() {
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
if (InferManager::instance().infer(
engine_name().c_str(), real_in, out, batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
delete real_in;
std::vector<int> output_shape;
int out_num =0;
int out_num = 0;
void* databuf_data_out = NULL;
char* databuf_char_out = NULL;
size_t databuf_size_out = 0;
//this is a special addition for PaddleOCR postprocess
int infer_outnum = out->size();
for (int k = 0;k <infer_outnum; ++k) {
// this is a special addition for PaddleOCR postprocess
int infer_outnum = out->size();
for (int k = 0; k < infer_outnum; ++k) {
int n2 = out->at(k).shape[2];
int n3 = out->at(k).shape[3];
int n = n2 * n3;
@@ -171,17 +174,19 @@ int GeneralDetectionOp::inference() {
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
const double threshold = this->det_db_thresh_ * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, dilation_map, dila_ele);
boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
boxes = post_processor_.BoxesFromBitmap(pred_map,
dilation_map,
this->det_db_box_thresh_,
this->det_db_unclip_ratio_);
@@ -192,25 +197,28 @@ int GeneralDetectionOp::inference() {
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
this->resize_op_rec.Run(crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_);
this->resize_op_rec.Run(
crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_);
this->normalize_op_.Run(&resize_img_rec, this->mean_rec, this->scale_rec,
this->is_scale_);
this->normalize_op_.Run(
&resize_img_rec, this->mean_rec, this->scale_rec, this->is_scale_);
std::vector<float> output_rec(1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f);
std::vector<float> output_rec(
1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f);
this->permute_op_.Run(&resize_img_rec, output_rec.data());
// Inference.
output_shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
databuf_size_out = out_num*sizeof(float);
out_num = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
databuf_size_out = out_num * sizeof(float);
databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out);
if (!databuf_data_out) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
return -1;
LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
return -1;
}
memcpy(databuf_data_out,output_rec.data(),databuf_size_out);
memcpy(databuf_data_out, output_rec.data(), databuf_size_out);
databuf_char_out = reinterpret_cast<char*>(databuf_data_out);
paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out);
paddle::PaddleTensor tensor_out;
@@ -221,9 +229,8 @@ int GeneralDetectionOp::inference() {
out->push_back(tensor_out);
}
}
out->erase(out->begin(),out->begin()+infer_outnum);
out->erase(out->begin(), out->begin() + infer_outnum);
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
@@ -231,68 +238,62 @@ int GeneralDetectionOp::inference() {
return 0;
}
cv::Mat GeneralDetectionOp::Base2Mat(std::string &base64_data)
{
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR);//CV_LOAD_IMAGE_COLOR
return img;
cv::Mat GeneralDetectionOp::Base2Mat(std::string& base64_data) {
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR
return img;
}
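One hedged aside on the decode path above (not part of the commit): cv::imdecode returns an empty cv::Mat when the bytes are not a valid image, so a caller of Base2Mat inside inference() could guard on that before resizing, roughly:

```cpp
cv::Mat img = Base2Mat(base64str);
if (img.empty()) {
  // decode failed: the base64 payload did not contain a valid image
  LOG(ERROR) << "(logid=" << log_id << ") Base2Mat returned an empty image";
  return -1;
}
```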
std::string GeneralDetectionOp::base64Decode(const char* Data, int DataByte)
{
const char DecodeTable[] =
{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
std::string strDecode;
int nValue;
int i = 0;
while (i < DataByte)
{
if (*Data != '\r' && *Data != '\n')
{
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=')
{
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=')
{
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
}
else// CR or LF, skip it
{
Data++;
i++;
}
}
return strDecode;
std::string GeneralDetectionOp::base64Decode(const char* Data, int DataByte) {
const char
DecodeTable[] =
{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, // 'a'-'z'
};
std::string strDecode;
int nValue;
int i = 0;
while (i < DataByte) {
if (*Data != '\r' && *Data != '\n') {
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=') {
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=') {
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
} else  // CR or LF, skip it
{
Data++;
i++;
}
}
return strDecode;
}
cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box) {
cv::Mat GeneralDetectionOp::GetRotateCropImage(
const cv::Mat& srcimage, std::vector<std::vector<int>> box) {
cv::Mat image;
srcimage.copyTo(image);
std::vector<std::vector<int>> points = box;
@@ -332,7 +333,9 @@ cv::Mat GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage,
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M,
cv::warpPerspective(img_crop,
dst_img,
M,
cv::Size(img_crop_width, img_crop_height),
cv::BORDER_REPLICATE);
@@ -350,4 +353,4 @@ DEFINE_OP(GeneralDetectionOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
\ No newline at end of file
} // namespace baidu
@@ -77,9 +77,6 @@ int GeneralReaderOp::inference() {
uint64_t log_id = req->log_id();
int input_var_num = 0;
std::vector<int64_t> elem_type;
std::vector<int64_t> elem_size;
std::vector<int64_t> databuf_size;
GeneralBlob *res = mutable_data<GeneralBlob>();
if (!res) {
@@ -119,40 +116,44 @@ int GeneralReaderOp::inference() {
}
*/
// package tensor
elem_type.resize(var_num);
elem_size.resize(var_num);
databuf_size.resize(var_num);
// prepare basic information for input
// specify the memory needed for output tensor_vector
// fill the data into output general_blob
int data_len = 0;
int64_t elem_type = 0;
int64_t elem_size = 0;
int64_t databuf_size = 0;
for (int i = 0; i < var_num; ++i) {
paddle::PaddleTensor lod_tensor;
paddle::PaddleTensor paddleTensor;
const Tensor &tensor = req->insts(0).tensor_array(i);
data_len = 0;
elem_type[i] = tensor.elem_type();
VLOG(2) << "var[" << i << "] has elem type: " << elem_type[i];
if (elem_type[i] == P_INT64) { // int64
elem_size[i] = sizeof(int64_t);
lod_tensor.dtype = paddle::PaddleDType::INT64;
elem_type = 0;
elem_size = 0;
databuf_size = 0;
elem_type = tensor.elem_type();
VLOG(2) << "var[" << i << "] has elem type: " << elem_type;
if (elem_type == P_INT64) { // int64
elem_size = sizeof(int64_t);
paddleTensor.dtype = paddle::PaddleDType::INT64;
data_len = tensor.int64_data_size();
} else if (elem_type[i] == P_FLOAT32) {
elem_size[i] = sizeof(float);
lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
} else if (elem_type == P_FLOAT32) {
elem_size = sizeof(float);
paddleTensor.dtype = paddle::PaddleDType::FLOAT32;
data_len = tensor.float_data_size();
} else if (elem_type[i] == P_INT32) {
elem_size[i] = sizeof(int32_t);
lod_tensor.dtype = paddle::PaddleDType::INT32;
} else if (elem_type == P_INT32) {
elem_size = sizeof(int32_t);
paddleTensor.dtype = paddle::PaddleDType::INT32;
data_len = tensor.int_data_size();
} else if (elem_type[i] == P_STRING) {
} else if (elem_type == P_STRING) {
// use paddle::PaddleDType::UINT8 for String.
elem_size[i] = sizeof(uint8_t);
lod_tensor.dtype = paddle::PaddleDType::UINT8;
elem_size = sizeof(char);
paddleTensor.dtype = paddle::PaddleDType::UINT8;
// this is for vector<String>, because databuf_size !=
// vector<String>.size() * sizeof(char);
// data_len needs +1 per string because of the trailing '\0'
// now only supports a single string
for (int idx = 0; idx < tensor.data_size(); idx++) {
data_len += tensor.data()[idx].length();
data_len += tensor.data()[idx].length() + 1;
}
}
// implement lod tensor here
@@ -160,29 +161,29 @@ int GeneralReaderOp::inference() {
// TODO(HexToString): support 2-D lod
if (tensor.lod_size() > 0) {
VLOG(2) << "(logid=" << log_id << ") var[" << i << "] is lod_tensor";
lod_tensor.lod.resize(1);
paddleTensor.lod.resize(1);
for (int k = 0; k < tensor.lod_size(); ++k) {
lod_tensor.lod[0].push_back(tensor.lod(k));
paddleTensor.lod[0].push_back(tensor.lod(k));
}
}
for (int k = 0; k < tensor.shape_size(); ++k) {
int dim = tensor.shape(k);
VLOG(2) << "(logid=" << log_id << ") shape for var[" << i << "]: " << dim;
lod_tensor.shape.push_back(dim);
paddleTensor.shape.push_back(dim);
}
lod_tensor.name = model_config->_feed_name[i];
out->push_back(lod_tensor);
paddleTensor.name = model_config->_feed_name[i];
out->push_back(paddleTensor);
VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i
<< "]: " << data_len;
databuf_size[i] = data_len * elem_size[i];
out->at(i).data.Resize(data_len * elem_size[i]);
databuf_size = data_len * elem_size;
out->at(i).data.Resize(databuf_size);
if (out->at(i).lod.size() > 0) {
VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] has lod_tensor and len=" << out->at(i).lod[0].back();
}
if (elem_type[i] == P_INT64) {
if (elem_type == P_INT64) {
int64_t *dst_ptr = static_cast<int64_t *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << tensor.int64_data(0);
@@ -190,14 +191,14 @@ int GeneralReaderOp::inference() {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
memcpy(dst_ptr, tensor.int64_data().data(), databuf_size[i]);
memcpy(dst_ptr, tensor.int64_data().data(), databuf_size);
/*
int elem_num = tensor.int64_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = tensor.int64_data(k);
}
*/
} else if (elem_type[i] == P_FLOAT32) {
} else if (elem_type == P_FLOAT32) {
float *dst_ptr = static_cast<float *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << tensor.float_data(0);
@@ -205,12 +206,12 @@ int GeneralReaderOp::inference() {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
memcpy(dst_ptr, tensor.float_data().data(), databuf_size[i]);
memcpy(dst_ptr, tensor.float_data().data(), databuf_size);
/*int elem_num = tensor.float_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = tensor.float_data(k);
}*/
} else if (elem_type[i] == P_INT32) {
} else if (elem_type == P_INT32) {
int32_t *dst_ptr = static_cast<int32_t *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << tensor.int_data(0);
@@ -218,15 +219,9 @@ int GeneralReaderOp::inference() {
LOG(ERROR) << "dst_ptr is nullptr";
return -1;
}
memcpy(dst_ptr, tensor.int_data().data(), databuf_size[i]);
/*
int elem_num = tensor.int_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = tensor.int_data(k);
}
*/
} else if (elem_type[i] == P_STRING) {
std::string *dst_ptr = static_cast<std::string *>(out->at(i).data.data());
memcpy(dst_ptr, tensor.int_data().data(), databuf_size);
} else if (elem_type == P_STRING) {
char *dst_ptr = static_cast<char *>(out->at(i).data.data());
VLOG(2) << "(logid=" << log_id << ") first element data in var[" << i
<< "] is " << tensor.data(0);
if (!dst_ptr) {
@@ -234,8 +229,12 @@ int GeneralReaderOp::inference() {
return -1;
}
int elem_num = tensor.data_size();
int offset = 0;
for (int k = 0; k < elem_num; ++k) {
dst_ptr[k] = tensor.data(k);
memcpy(dst_ptr + offset,
tensor.data(k).c_str(),
strlen(tensor.data(k).c_str()) + 1);
offset += strlen(tensor.data(k).c_str()) + 1;
}
}
}
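To complement the packing sketch near the top, this is how a consumer could walk the '\0'-separated strings back out of the reader's UINT8 databuf; GeneralDetectionOp only ever reads the first entry, in line with the single-string limitation noted in the comments. UnpackStrings and its count parameter are illustrative, not repo APIs.

```cpp
#include <string>
#include <vector>

// Rebuild the strings from a buffer produced by the packing scheme above.
// Each std::string constructor call reads up to (not including) the next '\0'.
std::vector<std::string> UnpackStrings(const char* databuf, int count) {
  std::vector<std::string> out;
  size_t offset = 0;
  for (int k = 0; k < count; ++k) {
    std::string s(databuf + offset);
    offset += s.size() + 1;  // skip past the '\0' separator
    out.push_back(s);
  }
  return out;
}
```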