diff --git a/test/net/test_ocr.cpp b/test/net/test_ocr.cpp
index 661b6e5cbf30b625e50a1f68d07ffc85f53e06bf..ef0c45b9d6e65f2033ddde1f6ec076bfd5e4093c 100644
--- a/test/net/test_ocr.cpp
+++ b/test/net/test_ocr.cpp
@@ -17,76 +17,91 @@ limitations under the License. */
 #include "../test_helper.h"
 #include "../test_include.h"
 
-void load_images(const char *image_dir, const char *images_list,
-                 std::vector<std::string> *image_names,
-                 std::vector<std::pair<int, int>> *image_shapes) {
-  int channel, height, width;
-  std::string filename;
-  std::ifstream if_list(images_list, std::ios::in);
-  while (!if_list.eof()) {
-    if_list >> channel >> height >> width >> filename;
-    image_shapes->push_back(std::make_pair(height, width));
-    image_names->push_back(filename);
-  }
-  if_list.close();
-}
+const int max_run_times = 10;
 
 int main(int argc, char **argv) {
-  if (argc < 5) {
+  if (argc < 3) {
     std::cerr
-        << "Usage: ./test_ocr model_dir image_dir images_list output_name."
+        << "Usage: ./test_ocr [detect_model_dir|recog_model_dir] image_path"
         << std::endl;
     return 1;
   }
-  char *model_dir = argv[1];
-  char *image_dir = argv[2];
-  char *images_list = argv[3];
-  char *output_name = argv[4];
+  std::string model_dir = argv[1];
+  std::string image_path = argv[2];
+
+  // init input, output params
+  std::vector<float> input_vec;
+  std::vector<int64_t> input_shape;
+  std::vector<std::string> output_fetch_nodes;
+  int PRINT_NODE_ELEM_NUM = 10;
+
+  bool is_det_model = model_dir.find("detect") != string::npos;
+  if (is_det_model) {
+    input_shape.emplace_back(1);
+    input_shape.emplace_back(3);
+    input_shape.emplace_back(512);
+    input_shape.emplace_back(512);
+    output_fetch_nodes.emplace_back("sigmoid_0.tmp_0");
+    output_fetch_nodes.emplace_back("tmp_5");
+  } else {
+    input_shape.emplace_back(1);
+    input_shape.emplace_back(3);
+    input_shape.emplace_back(48);
+    input_shape.emplace_back(512);
+    output_fetch_nodes.emplace_back("top_k_1.tmp_0");
+    output_fetch_nodes.emplace_back("cast_330.tmp_0");
+  }
+  std::shared_ptr<paddle_mobile::framework::Tensor> outputs[output_fetch_nodes.size()];
+  // init paddle instance
   paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
   paddle_mobile.SetThreadNum(1);
-  auto isok = paddle_mobile.Load(std::string(model_dir) + "/model",
-                                 std::string(model_dir) + "/params", true,
-                                 false, 1, true);
-  // auto isok = paddle_mobile.Load(std::string(model_dir), false,
-  //                                false, 1, true);
+  auto load_success = paddle_mobile.Load(std::string(model_dir) + "/model",
+                                         std::string(model_dir) + "/params",
+                                         true, false, 1, true);
 
-  DLOG << "pass init model";
-  std::vector<std::string> image_names;
-  std::vector<std::pair<int, int>> image_shapes;
-  load_images(image_dir, images_list, &image_names, &image_shapes);
-  DLOG << "pass load images";
+  // input image raw tensor, generated by
+  // [scripts](tools/python/imagetools/img2nchw.py)
+  std::cout << "image_path: " << image_path << std::endl;
+  std::cout << "input_shape: " << input_shape[0] << ", " << input_shape[1]
+            << ", " << input_shape[2] << ", " << input_shape[3] << std::endl;
+  GetInput<float>(image_path, &input_vec, input_shape);
 
-  for (int i = 0; i < image_names.size(); i++) {
-    std::string file_name = image_names[i];
-    std::vector<float> input_vec;
-    std::vector<int64_t> dims{1, 3, 48, 512};
-    dims[2] = image_shapes[i].first;
-    dims[3] = image_shapes[i].second;
-    // load input image
-    std::string img_path = std::string(image_dir) + "/" + file_name;
-    std::cerr << "img_path: " << img_path << std::endl;
-    std::cerr << "shape = [" << dims[0] << ", " << dims[1] << ", " << dims[2]
-              << ", " << dims[3] << "]" << std::endl;
-    GetInput<float>(img_path, &input_vec, dims);
-    // framework::Tensor input(input_vec, framework::make_ddim(dims));
-    // predict
-    // for (int j = 0; j < 10000; ++j) {
-    auto time3 = paddle_mobile::time();
-    paddle_mobile.Predict(input_vec, dims);
-    auto output_topk = paddle_mobile.Fetch(output_name);
-    auto time4 = paddle_mobile::time();
-    std::cerr << "average predict elapsed: "
-              << paddle_mobile::time_diff(time3, time4) << "ms" << std::endl;
-    // }
+  // model predict
+  auto pred_start_time = paddle_mobile::time();
+  for (int run_idx = 0; run_idx < max_run_times; ++run_idx) {
+    paddle_mobile.Predict(input_vec, input_shape);
+    for (int out_idx = 0; out_idx < output_fetch_nodes.size(); ++out_idx) {
+      auto fetch_name = output_fetch_nodes[out_idx];
+      outputs[out_idx] = paddle_mobile.Fetch(fetch_name);
+    }
+  }
+  auto pred_end_time = paddle_mobile::time();
 
-    // print result
-    std::cerr << output_name << std::endl;
-    std::cerr << output_topk->data<float>()[0];
-    for (int j = 1; j < output_topk->numel(); ++j) {
-      std::cerr << " " << output_topk->data<float>()[j];
+  // inference time
+  double pred_time =
+      paddle_mobile::time_diff(pred_start_time, pred_end_time) / max_run_times;
+  std::cout << "predict time(ms): " << pred_time << std::endl;
+
+  // output result
+  for (int out_idx = 0; out_idx < output_fetch_nodes.size(); ++out_idx) {
+    std::string node_id = output_fetch_nodes[out_idx];
+    auto node_lod_tensor = outputs[out_idx];
+    int node_elem_num = node_lod_tensor->numel();
+    float *node_ptr = node_lod_tensor->data<float>();
+    std::cout << "==== output_fetch_nodes[" << out_idx
+              << "] =====" << std::endl;
+    std::cout << "node_id: " << node_id << std::endl;
+    std::cout << "node_elem_num: " << node_elem_num << std::endl;
+    std::cout << "PRINT_NODE_ELEM_NUM: " << PRINT_NODE_ELEM_NUM << std::endl;
+    PRINT_NODE_ELEM_NUM =
+        (node_elem_num > PRINT_NODE_ELEM_NUM) ? PRINT_NODE_ELEM_NUM : 0;
+    for (int eidx = 0; eidx < PRINT_NODE_ELEM_NUM; ++eidx) {
+      std::cout << node_id << "[" << eidx << "]: " << node_ptr[eidx]
+                << std::endl;
     }
-    std::cerr << std::endl;
+    std::cout << std::endl;
   }
+  return 0;
 }