// structure_table.cpp
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/structure_table.h>

#include <algorithm>
#include <chrono>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

namespace PaddleOCR {

void StructureTableRecognizer::Run(
    std::vector<cv::Mat> img_list,
    std::vector<std::vector<std::string>> &structure_html_tags,
    std::vector<float> &structure_scores,
文幕地方's avatar
文幕地方 已提交
23
    std::vector<std::vector<std::vector<int>>> &structure_boxes,
文幕地方's avatar
文幕地方 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
    std::vector<double> &times) {
  std::chrono::duration<float> preprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> inference_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> postprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();

  int img_num = img_list.size();
  for (int beg_img_no = 0; beg_img_no < img_num;
       beg_img_no += this->table_batch_num_) {
    // preprocess
    auto preprocess_start = std::chrono::steady_clock::now();
    int end_img_no = min(img_num, beg_img_no + this->table_batch_num_);
    int batch_num = end_img_no - beg_img_no;
    std::vector<cv::Mat> norm_img_batch;
    std::vector<int> width_list;
    std::vector<int> height_list;
    for (int ino = beg_img_no; ino < end_img_no; ino++) {
      cv::Mat srcimg;
      img_list[ino].copyTo(srcimg);
      cv::Mat resize_img;
      cv::Mat pad_img;
      this->resize_op_.Run(srcimg, resize_img, this->table_max_len_);
      this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                              this->is_scale_);
      this->pad_op_.Run(resize_img, pad_img, this->table_max_len_);
      norm_img_batch.push_back(pad_img);
      width_list.push_back(srcimg.cols);
      height_list.push_back(srcimg.rows);
    }

    std::vector<float> input(
        batch_num * 3 * this->table_max_len_ * this->table_max_len_, 0.0f);
    this->permute_op_.Run(norm_img_batch, input.data());
    auto preprocess_end = std::chrono::steady_clock::now();
    preprocess_diff += preprocess_end - preprocess_start;
    // inference.
    auto input_names = this->predictor_->GetInputNames();
    auto input_t = this->predictor_->GetInputHandle(input_names[0]);
    input_t->Reshape(
        {batch_num, 3, this->table_max_len_, this->table_max_len_});
    auto inference_start = std::chrono::steady_clock::now();
    input_t->CopyFromCpu(input.data());
    this->predictor_->Run();
    auto output_names = this->predictor_->GetOutputNames();
    auto output_tensor0 = this->predictor_->GetOutputHandle(output_names[0]);
    auto output_tensor1 = this->predictor_->GetOutputHandle(output_names[1]);
    std::vector<int> predict_shape0 = output_tensor0->shape();
    std::vector<int> predict_shape1 = output_tensor1->shape();

    int out_num0 = std::accumulate(predict_shape0.begin(), predict_shape0.end(),
                                   1, std::multiplies<int>());
    int out_num1 = std::accumulate(predict_shape1.begin(), predict_shape1.end(),
                                   1, std::multiplies<int>());
    std::vector<float> loc_preds;
    std::vector<float> structure_probs;
    loc_preds.resize(out_num0);
    structure_probs.resize(out_num1);

    output_tensor0->CopyToCpu(loc_preds.data());
    output_tensor1->CopyToCpu(structure_probs.data());
    auto inference_end = std::chrono::steady_clock::now();
    inference_diff += inference_end - inference_start;
    // postprocess
    auto postprocess_start = std::chrono::steady_clock::now();
    std::vector<std::vector<std::string>> structure_html_tag_batch;
    std::vector<float> structure_score_batch;
文幕地方's avatar
文幕地方 已提交
92
    std::vector<std::vector<std::vector<int>>> structure_boxes_batch;
文幕地方's avatar
文幕地方 已提交
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
    this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch,
                              predict_shape0, predict_shape1,
                              structure_html_tag_batch, structure_boxes_batch,
                              width_list, height_list);
    for (int m = 0; m < predict_shape0[0]; m++) {

      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<table>");
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<body>");
      structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
                                         "<html>");
      structure_html_tag_batch[m].push_back("</table>");
      structure_html_tag_batch[m].push_back("</body>");
      structure_html_tag_batch[m].push_back("</html>");
      structure_html_tags.push_back(structure_html_tag_batch[m]);
      structure_scores.push_back(structure_score_batch[m]);
      structure_boxes.push_back(structure_boxes_batch[m]);
    }
    auto postprocess_end = std::chrono::steady_clock::now();
    postprocess_diff += postprocess_end - postprocess_start;
    times.push_back(double(preprocess_diff.count() * 1000));
    times.push_back(double(inference_diff.count() * 1000));
    times.push_back(double(postprocess_diff.count() * 1000));
  }
}

void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
  AnalysisConfig config;
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");

  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
    if (this->use_tensorrt_) {
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (this->precision_ == "fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      }
      if (this->precision_ == "int8") {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
    }
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

  // false for zero copy tensor
  config.SwitchUseFeedFetchOps(false);
  // true for multiple input
  config.SwitchSpecifyInputNames(true);

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = CreatePredictor(config);
}
} // namespace PaddleOCR