Commit 3a9016d7 authored by Shuai Yuan, committed by Jiaying Zhao

Update ocr test case (#1605)

* Update ocr test case

* format code

* format code
Parent 9070d4cf
@@ -17,76 +17,91 @@ limitations under the License. */
 #include "../test_helper.h"
 #include "../test_include.h"
-void load_images(const char *image_dir, const char *images_list,
-                 std::vector<std::string> *image_names,
-                 std::vector<std::pair<int, int>> *image_shapes) {
-  int channel, height, width;
-  std::string filename;
-  std::ifstream if_list(images_list, std::ios::in);
-  while (!if_list.eof()) {
-    if_list >> channel >> height >> width >> filename;
-    image_shapes->push_back(std::make_pair(height, width));
-    image_names->push_back(filename);
-  }
-  if_list.close();
-}
+const int max_run_times = 10;
 int main(int argc, char **argv) {
-  if (argc < 5) {
+  if (argc < 3) {
     std::cerr
-        << "Usage: ./test_ocr model_dir image_dir images_list output_name."
+        << "Usage: ./test_ocr [detect_model_dir|recog_model_dir] image_path"
         << std::endl;
     return 1;
   }
-  char *model_dir = argv[1];
-  char *image_dir = argv[2];
-  char *images_list = argv[3];
-  char *output_name = argv[4];
+  std::string model_dir = argv[1];
+  std::string image_path = argv[2];
+  // init input, output params
+  std::vector<float> input_vec;
+  std::vector<int64_t> input_shape;
+  std::vector<std::string> output_fetch_nodes;
+  int PRINT_NODE_ELEM_NUM = 10;
+  bool is_det_model = model_dir.find("detect") != std::string::npos;
+  if (is_det_model) {
+    // detection model: NCHW 1x3x512x512 input
+    input_shape.emplace_back(1);
+    input_shape.emplace_back(3);
+    input_shape.emplace_back(512);
+    input_shape.emplace_back(512);
+    output_fetch_nodes.emplace_back("sigmoid_0.tmp_0");
+    output_fetch_nodes.emplace_back("tmp_5");
+  } else {
+    // recognition model: NCHW 1x3x48x512 input
+    input_shape.emplace_back(1);
+    input_shape.emplace_back(3);
+    input_shape.emplace_back(48);
+    input_shape.emplace_back(512);
+    output_fetch_nodes.emplace_back("top_k_1.tmp_0");
+    output_fetch_nodes.emplace_back("cast_330.tmp_0");
+  }
+  std::vector<std::shared_ptr<framework::LoDTensor>> outputs(
+      output_fetch_nodes.size());
+  // init paddle instance
   paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
   paddle_mobile.SetThreadNum(1);
-  auto isok = paddle_mobile.Load(std::string(model_dir) + "/model",
-                                 std::string(model_dir) + "/params", true,
-                                 false, 1, true);
-  // auto isok = paddle_mobile.Load(std::string(model_dir), false,
-  // false, 1, true);
+  auto load_success = paddle_mobile.Load(std::string(model_dir) + "/model",
+                                         std::string(model_dir) + "/params",
+                                         true, false, 1, true);
   DLOG << "pass init model";
-  std::vector<std::string> image_names;
-  std::vector<std::pair<int, int>> image_shapes;
-  load_images(image_dir, images_list, &image_names, &image_shapes);
-  DLOG << "pass load images";
+  // input image raw tensor, generated by
+  // [scripts](tools/python/imagetools/img2nchw.py)
+  std::cout << "image_path: " << image_path << std::endl;
+  std::cout << "input_shape: " << input_shape[0] << ", " << input_shape[1]
+            << ", " << input_shape[2] << ", " << input_shape[3] << std::endl;
+  GetInput<float>(image_path, &input_vec, input_shape);
-  for (int i = 0; i < image_names.size(); i++) {
-    std::string file_name = image_names[i];
-    std::vector<float> input_vec;
-    std::vector<int64_t> dims{1, 3, 48, 512};
-    dims[2] = image_shapes[i].first;
-    dims[3] = image_shapes[i].second;
-    // load input image
-    std::string img_path = std::string(image_dir) + "/" + file_name;
-    std::cerr << "img_path: " << img_path << std::endl;
-    std::cerr << "shape = [" << dims[0] << ", " << dims[1] << ", " << dims[2]
-              << ", " << dims[3] << "]" << std::endl;
-    GetInput<float>(img_path, &input_vec, dims);
-    // framework::Tensor input(input_vec, framework::make_ddim(dims));
-    // predict
-    // for (int j = 0; j < 10000; ++j) {
-    auto time3 = paddle_mobile::time();
-    paddle_mobile.Predict(input_vec, dims);
-    auto output_topk = paddle_mobile.Fetch(output_name);
-    auto time4 = paddle_mobile::time();
-    std::cerr << "average predict elapsed: "
-              << paddle_mobile::time_diff(time3, time4) << "ms" << std::endl;
-    // }
+  // model predict
+  auto pred_start_time = paddle_mobile::time();
+  for (int run_idx = 0; run_idx < max_run_times; ++run_idx) {
+    paddle_mobile.Predict(input_vec, input_shape);
+    for (int out_idx = 0; out_idx < output_fetch_nodes.size(); ++out_idx) {
+      auto fetch_name = output_fetch_nodes[out_idx];
+      outputs[out_idx] = paddle_mobile.Fetch(fetch_name);
+    }
+  }
+  auto pred_end_time = paddle_mobile::time();
-    // print result
-    std::cerr << output_name << std::endl;
-    std::cerr << output_topk->data<float>()[0];
-    for (int j = 1; j < output_topk->numel(); ++j) {
-      std::cerr << " " << output_topk->data<float>()[j];
+  // inference time: average per-run latency over max_run_times runs
+  double pred_time =
+      paddle_mobile::time_diff(pred_start_time, pred_end_time) / max_run_times;
+  std::cout << "predict time(ms): " << pred_time << std::endl;
+  // output result
+  for (int out_idx = 0; out_idx < output_fetch_nodes.size(); ++out_idx) {
+    std::string node_id = output_fetch_nodes[out_idx];
+    auto node_lod_tensor = outputs[out_idx];
+    int node_elem_num = node_lod_tensor->numel();
+    float *node_ptr = node_lod_tensor->data<float>();
+    std::cout << "==== output_fetch_nodes[" << out_idx
+              << "] =====" << std::endl;
+    std::cout << "node_id: " << node_id << std::endl;
+    std::cout << "node_elem_num: " << node_elem_num << std::endl;
+    std::cout << "PRINT_NODE_ELEM_NUM: " << PRINT_NODE_ELEM_NUM << std::endl;
+    // print at most PRINT_NODE_ELEM_NUM leading elements of this output
+    int print_elem_num = (node_elem_num > PRINT_NODE_ELEM_NUM)
+                             ? PRINT_NODE_ELEM_NUM
+                             : node_elem_num;
+    for (int eidx = 0; eidx < print_elem_num; ++eidx) {
+      std::cout << node_id << "[" << eidx << "]: " << node_ptr[eidx]
+                << std::endl;
     }
-    std::cerr << std::endl;
+    std::cout << std::endl;
   }
   return 0;
 }
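
Note on the new input: the updated test takes a model directory and a single image path (Usage: ./test_ocr [detect_model_dir|recog_model_dir] image_path); a directory name containing "detect" selects the 1x3x512x512 detection input, anything else the 1x3x48x512 recognition input. The image file is the raw tensor consumed by GetInput<float>, generated by tools/python/imagetools/img2nchw.py. As a minimal sketch of what such a file is assumed to contain, a headerless dump of float32 values in NCHW order, a standalone reader could look like the following; ReadRawNchw is a hypothetical helper, not part of paddle-mobile:

#include <cstdint>
#include <fstream>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical standalone reader (assumption: headerless little-endian
// float32 NCHW dump whose length matches the product of the dims).
std::vector<float> ReadRawNchw(const std::string &path,
                               const std::vector<int64_t> &shape) {
  // element count = product of dims, e.g. 1 * 3 * 512 * 512
  const int64_t numel =
      std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                      std::multiplies<int64_t>());
  std::vector<float> data(static_cast<size_t>(numel));
  std::ifstream file(path, std::ios::binary);
  file.read(reinterpret_cast<char *>(data.data()),
            static_cast<std::streamsize>(numel * sizeof(float)));
  return data;
}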
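
For reference, the removed load_images helper expected images_list to hold one whitespace-separated record per image, read as channel height width filename, for example (filenames made up):

3 48 512 word_img_0.dat
3 48 512 word_img_1.dat

Its while (!if_list.eof()) loop also pushed the final record twice whenever the list ended with a trailing newline, because eofbit is only set by the failed extraction on the next pass; dropping the list in favor of a single image path removes that pitfall along with the helper.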