//
// Created by fujiayi on 2020/7/1.
//

#pragma once

#include <string>
#include <opencv2/opencv.hpp>
#include <paddle_api.h>
#include "ppredictor.h"

namespace ppredictor {

/**
A
authorfu 已提交
15
 * Config
A
authorfu 已提交
16 17
 */
struct OCR_Config {
A
authorfu 已提交
18
    int thread_num = 4; // Thread num
A
authorfu 已提交
19 20 21 22
    paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
};

/**
 * Result for one detected polygon.
 */
struct OCRPredictResult {
    // Integer indices of the recognized text (presumably positions in the
    // label dictionary — confirm against postprocess_rec_word_index).
    std::vector<int> word_index;
    // Polygon vertices; each inner vector holds the integer coordinates of one
    // point (presumably {x, y} — confirm against the implementation).
    std::vector<std::vector<int>> points;
    // Score attached to this result. Initialized so that a default-constructed
    // result never exposes an indeterminate value.
    float score = 0.0f;
};

/**
 * OCR predictor built from two models:
 * 1. The first model (det) selects polygons that show where the texts are.
 * 2. The polygons are cropped from the original image and fed to the second
 *    model (rec) for inference.
 */
class OCR_PPredictor : public PPredictor_Interface {
public:
    /**
     * @param config thread count and power mode to use
     */
    OCR_PPredictor(const OCR_Config &config);

    virtual ~OCR_PPredictor() {
    }

    /**
     * Initialize the predictors of the two models.
     * @param det_model_content content of the detection model (presumably the
     *        serialized model bytes — confirm in the implementation)
     * @param rec_model_content content of the recognition model (same caveat)
     * @return integer status code; its meaning is defined by the implementation
     */
    int init(const std::string &det_model_content, const std::string &rec_model_content);

    /**
     * Initialize the predictors of the two models from model paths.
     * @param det_model_path path of the detection model
     * @param rec_model_path path of the recognition model
     * @return integer status code; its meaning is defined by the implementation
     */
    int init_from_file(const std::string &det_model_path, const std::string &rec_model_path);

    /**
     * Return the OCR result.
     * @param dims dimensions of the input data
     * @param input_data input data buffer
     * @param input_len length of input_data (presumably an element count — confirm)
     * @param net_flag network flag
     * @param origin original image
     * @return the OCR results
     */
    virtual std::vector<OCRPredictResult>
    infer_ocr(const std::vector<int64_t> &dims, const float *input_data, int input_len,
              int net_flag, cv::Mat &origin);

    /**
     * @return the network type of this predictor
     */
    virtual NET_TYPE get_net_flag() const;

private:

    /**
     * Calculate polygons from the result image of the first model.
     * @param pred output buffer of the first model
     * @param pred_size size of pred (presumably an element count — confirm)
     * @param output_height height of the first model's output
     * @param output_width width of the first model's output
     * @param origin original image
     * @return the polygons, each one a list of integer points
     */
    std::vector<std::vector<std::vector<int>>>
    calc_filtered_boxes(const float *pred, int pred_size, int output_height, int output_width,
                        const cv::Mat &origin);

    /**
     * Run inference for the second model.
     *
     * @param boxes polygons to run the second model on
     * @param origin original image
     * @return the OCR results
     */
    std::vector<OCRPredictResult>
    infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes, const cv::Mat &origin);

    /**
     * Post-process the second model's output to extract the text.
     * @param res output of the second model
     * @return integer indices of the extracted text
     */
    std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);

    /**
     * Calculate the confidence of the second model's text result.
     * @param res output of the second model
     * @return the confidence value
     */
    float postprocess_rec_score(const PredictorOutput &res);

    std::unique_ptr<PPredictor> _det_predictor; // predictor for the first (det) model
    std::unique_ptr<PPredictor> _rec_predictor; // predictor for the second (rec) model
    OCR_Config _config;                         // configuration held by this predictor

};
}