// super_scale.cpp
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Tencent is pleased to support the open source community by making WeChat QRCode available.
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
// Modified by darkliang wangberlinT

#include "../../precomp.hpp"
#include "super_scale.hpp"

#ifdef HAVE_OPENCV_DNN

namespace cv {
namespace barcode {
// Upper bound that SuperScale::processImageScale applies to the requested scale factor.
static constexpr float MAX_SCALE = 4.0f;

/**
 * @brief Load the super-resolution network from Caffe model files.
 *
 * The network returned by dnn::readNetFromCaffe is stored in srnet_, and
 * net_loaded_ records whether that network is usable. Exceptions thrown by
 * dnn::readNetFromCaffe are not caught here and propagate to the caller; in
 * that case net_loaded_ keeps its previous value.
 *
 * @param proto_path  forwarded to dnn::readNetFromCaffe as the network description (prototxt) path.
 * @param model_path  forwarded to dnn::readNetFromCaffe as the trained weights (caffemodel) path.
 * @return 0 when a non-empty network was loaded, -1 when the loaded network has no layers.
 */
int SuperScale::init(const std::string &proto_path, const std::string &model_path)
{
    srnet_ = dnn::readNetFromCaffe(proto_path, model_path);
    // Only mark the network as usable when it actually contains layers, so that
    // processImageScale never calls forward() on an empty net.
    net_loaded_ = !srnet_.empty();
    return net_loaded_ ? 0 : -1;
}

void SuperScale::processImageScale(const Mat &src, Mat &dst, float scale, const bool &use_sr, int sr_max_size)
{
    scale = min(scale, MAX_SCALE);
    if (scale > .0 && scale < 1.0)
    {  // down sample
        resize(src, dst, Size(), scale, scale, INTER_AREA);
    }
    else if (scale > 1.5 && scale < 2.0)
    {
        resize(src, dst, Size(), scale, scale, INTER_CUBIC);
    }
    else if (scale >= 2.0)
    {
        int width = src.cols;
        int height = src.rows;
        if (use_sr && (int) sqrt(width * height * 1.0) < sr_max_size && net_loaded_)
        {
            superResolutionScale(src, dst);
            if (scale > 2.0)
            {
                processImageScale(dst, dst, scale / 2.0f, use_sr);
            }
        }
        else
        { resize(src, dst, Size(), scale, scale, INTER_CUBIC); }
    }
}

int SuperScale::superResolutionScale(const Mat &src, Mat &dst)
{
    Mat blob;
    dnn::blobFromImage(src, blob, 1.0 / 255, Size(src.cols, src.rows), {0.0f}, false, false);

    srnet_.setInput(blob);
    auto prob = srnet_.forward();

    dst = Mat(prob.size[2], prob.size[3], CV_8UC1);

    for (int row = 0; row < prob.size[2]; row++)
    {
        const float *prob_score = prob.ptr<float>(0, 0, row);
        auto *dst_row = dst.ptr<uchar>(row);
        for (int col = 0; col < prob.size[3]; col++)
        {
            dst_row[col] = saturate_cast<uchar>(prob_score[col] * 255.0f);
        }
    }
    return 0;
}
}  // namespace barcode
}  // namespace cv

#endif // HAVE_OPENCV_DNN