提交 85706e16 编写于 作者: L LDOUBLEV

add deploy lite demo

上级 6f456775
......@@ -18,3 +18,4 @@ output/
*.idea
*.log
.clang-format
ARM_ABI = arm8
export ARM_ABI
include ../Makefile.def
LITE_ROOT=../../../
THIRD_PARTY_DIR=${LITE_ROOT}/third_party
OPENCV_VERSION=opencv4.1.0
OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a
OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of static libaray: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_light_api_shared.so`
# 2. Undo comment below line using `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
ocr_db_crnn: fetch_opencv ocr_db_crnn.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
ocr_db_crnn.o: ocr_db_crnn.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
fetch_opencv:
@ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR}
@ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \
(echo "fetch opencv libs" && \
wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz)
@ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \
tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR}
.PHONY: clean
clean:
rm -f ocr_db_crnn.o
rm -f ocr_db_crnnn
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include <chrono>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h" // NOLINT
#include "utils/db_post_process.cpp"
#include "utils/crnn_process.cpp"
#include <cstring>
#include <fstream>
using namespace paddle::lite_api; // NOLINT
struct Object {
cv::Rect rec;
int class_id;
float prob;
};
int64_t ShapeProduction(const shape_t& shape) {
int64_t res = 1;
for (auto i : shape) res *= i;
return res;
}
// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up
void neon_mean_scale(const float* din,
float* dout,
int size,
const std::vector<float> mean,
const std::vector<float> scale) {
if (mean.size() != 3 || scale.size() != 3) {
std::cerr << "[ERROR] mean or scale size must equal to 3\n";
exit(1);
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
float32x4_t vmean1 = vdupq_n_f32(mean[1]);
float32x4_t vmean2 = vdupq_n_f32(mean[2]);
float32x4_t vscale0 = vdupq_n_f32(scale[0]);
float32x4_t vscale1 = vdupq_n_f32(scale[1]);
float32x4_t vscale2 = vdupq_n_f32(scale[2]);
float* dout_c0 = dout;
float* dout_c1 = dout + size;
float* dout_c2 = dout + size * 2;
int i = 0;
for (; i < size - 3; i += 4) {
float32x4x3_t vin3 = vld3q_f32(din);
float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
vst1q_f32(dout_c0, vs0);
vst1q_f32(dout_c1, vs1);
vst1q_f32(dout_c2, vs2);
din += 12;
dout_c0 += 4;
dout_c1 += 4;
dout_c2 += 4;
}
for (; i < size; i++) {
*(dout_c0++) = (*(din++) - mean[0]) * scale[0];
*(dout_c1++) = (*(din++) - mean[1]) * scale[1];
*(dout_c2++) = (*(din++) - mean[2]) * scale[2];
}
}
// resize image to a size multiple of 32 which is required by the network
cv::Mat resize_img_type0(const cv::Mat img, int max_size_len, float *ratio_h, float *ratio_w){
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
int max_wh = w >=h ? w : h;
if (max_wh > max_size_len){
if (h > w){
ratio = float(max_size_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
}
}
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
if (resize_h % 32 == 0)
resize_h = resize_h;
else if (resize_h / 32 < 1)
resize_h = 32;
else
resize_h = (resize_h / 32 - 1) * 32;
if (resize_w % 32 == 0)
resize_w = resize_w;
else if (resize_w /32 < 1)
resize_w = 32;
else
resize_w = (resize_w/32 - 1)*32;
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
*ratio_h = float(resize_h) / float(h);
*ratio_w = float(resize_w) / float(w);
return resize_img;
}
using namespace std;
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img, std::string rec_model_file){
MobileConfig config;
config.set_model_from_file(rec_model_file);
std::shared_ptr<PaddlePredictor> predictor_crnn =
CreatePaddlePredictor<MobileConfig>(config);
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat crop_img;
cv::Mat resize_img;
std::string dict_path = "./ppocr_keys_v1.txt";
auto charactor_dict = read_dict(dict_path);
std::cout << "The predicted text is :" << std::endl;
int index = 0;
for (int i=boxes.size()-1; i >= 0; i--) {
crop_img = get_rotate_crop_image(srcimg, boxes[i]);
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
resize_img = crnn_resize_img(crop_img, wh_ratio);
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
std::unique_ptr <Tensor> input_tensor0(std::move(predictor_crnn->GetInput(0)));
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
auto *data0 = input_tensor0->mutable_data<float>();
neon_mean_scale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
//// Run CRNN predictor
predictor_crnn->Run();
// Get output and run postprocess
std::unique_ptr<const Tensor> output_tensor0(
std::move(predictor_crnn->GetOutput(0)));
auto *rec_idx = output_tensor0->data<int>();
auto rec_idx_lod = output_tensor0->lod();
auto shape_out = output_tensor0->shape();
std::vector<int> pred_idx;
for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) {
pred_idx.push_back(int(rec_idx[n]));
}
if (pred_idx.size() < 1e-3)
continue;
std::cout << std::endl;
index += 1;
std::cout << index << "\t";
for (int n = 0; n < pred_idx.size(); n++) {
std::cout << charactor_dict[pred_idx[n]];
}
////get score
std::unique_ptr<const Tensor> output_tensor1(std::move(predictor_crnn->GetOutput(1)));
auto *predict_batch = output_tensor1->data<float>();
auto predict_shape = output_tensor1->shape();
auto predict_lod = output_tensor1->lod();
int argmax_idx;
int blank = predict_shape[1];
float score = 0.f;
int count = 0;
float max_value = 0.0f;
for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
argmax_idx = int(argmax(&predict_batch[n * predict_shape[1]], &predict_batch[(n + 1) * predict_shape[1]]));
max_value = float(
*std::max_element(&predict_batch[n * predict_shape[1]], &predict_batch[(n + 1) * predict_shape[1]]));
if (blank - 1 - argmax_idx > 1e-5) {
score += max_value;
count += 1;
}
}
score /= count;
std::cout << "\tscore: " << score << std::endl;
}
}
std::vector<std::vector<std::vector<int>>> RunDetModel(std::string model_file, cv::Mat img) {
// Set MobileConfig
MobileConfig config;
config.set_model_from_file(model_file);
std::shared_ptr<PaddlePredictor> predictor =
CreatePaddlePredictor<MobileConfig>(config);
// Read img
int max_side_len = 960;
float ratio_h{};
float ratio_w{};
cv::Mat srcimg;
img.copyTo(srcimg);
img = resize_img_type0(img, max_side_len, &ratio_h, &ratio_w);
cv::Mat img_fp;
img.convertTo(img_fp, CV_32FC3, 1.0 / 255.f);
// Prepare input data from image
std::unique_ptr<Tensor> input_tensor0(std::move(predictor->GetInput(0)));
input_tensor0->Resize({1, 3, img_fp.rows, img_fp.cols});
auto* data0 = input_tensor0->mutable_data<float>();
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> scale = {1/0.229f, 1/0.224f, 1/0.225f};
const float* dimg = reinterpret_cast<const float*>(img_fp.data);
neon_mean_scale(dimg, data0, img_fp.rows * img_fp.cols, mean, scale);
// Run predictor
predictor->Run();
// Get output and post process
std::unique_ptr<const Tensor> output_tensor(std::move(predictor->GetOutput(0)));
auto* outptr = output_tensor->data<float>();
auto shape_out = output_tensor->shape();
int64_t out_numl = 1;
double sum = 0;
for (auto i : shape_out) {
out_numl *= i;
}
// Save output
float pred[shape_out[2]][shape_out[3]];
unsigned char cbuf[shape_out[2]][shape_out[3]];
for (int i=0; i< int(shape_out[2]*shape_out[3]); i++){
pred[int(i/int(shape_out[3]))][int(i%shape_out[3])] = float(outptr[i]);
cbuf[int(i/int(shape_out[3]))][int(i%shape_out[3])] = (unsigned char) ((outptr[i])*255);
}
cv::Mat cbuf_map(shape_out[2], shape_out[3], CV_8UC1, (unsigned char*)cbuf);
cv::Mat pred_map(shape_out[2], shape_out[3], CV_32F, (float *)pred);
const double threshold = 0.3*255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
auto boxes = boxes_from_bitmap(pred_map, bit_map);
std::vector<std::vector<std::vector<int>>> filter_boxes = filter_tag_det_res(boxes, ratio_h, ratio_w, srcimg);
//// visualization
cv::Point rook_points[filter_boxes.size()][4];
for (int n=0; n<filter_boxes.size(); n++){
for (int m=0; m< filter_boxes[0].size(); m++){
rook_points[n][m] = cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1]));
}
}
cv::Mat img_vis;
srcimg.copyTo(img_vis);
for (int n=0; n<boxes.size(); n++){
const cv::Point* ppt[1] = { rook_points[n] };
int npt[] = { 4 };
cv::polylines(img_vis, ppt, npt,1,1,CV_RGB(0,255,0),2,8,0);
}
cv::imwrite("./imgs_vis/vis.jpg", img_vis);
std::cout << "The detection visualized image saved in ./imgs_vis/" <<std::endl;
return filter_boxes;
}
int main(int argc, char** argv) {
if (argc < 4) {
std::cerr << "[ERROR] usage: " << argv[0] << " det_model_file rec_model_file image_path\n";
exit(1);
}
std::string det_model_file = argv[1];
std::string rec_model_file = argv[2];
std::string img_path = argv[3];
auto start = std::chrono::system_clock::now();
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
auto boxes = RunDetModel(det_model_file, srcimg);
RunRecModel(boxes, srcimg, rec_model_file);
auto end = std::chrono::system_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
std::cout << "花费了"
<< double(duration.count()) * std::chrono::microseconds::period::num /std::chrono::microseconds::period::den
<< "秒" << std::endl;
return 0;
}
# PaddleOCR 移动端部署
本教程介绍如何在移动端部署PaddleOCR超轻量中文检测、识别模型。
## 运行准备
- 电脑(编译Paddle-Lite)
- 安卓手机(armv7或armv8)
## 1. 准备环境
### 1.1 准备交叉编译环境
交叉编译环境用于编译Paddle-Lite和PaddleOCR的C++ demo。
支持多种开发环境,不同开发环境的编译流程请参考对应文档。:
1. [Docker](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#docker)
2. [Linux](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#android)
3. [MAC OS](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#id13)
4. [Windows](https://paddle-lite.readthedocs.io/zh/latest/demo_guides/x86.html#windows)
### 1.2 准备预编译库
预编译库有两种获取方式:
- 1. 直接下载,下载[链接](https://paddle-lite.readthedocs.io/zh/latest/user_guides/release_lib.html#android-toolchain-gcc).
注意选择with_extra=ON,with_cv=ON的下载链接。
- 2. 编译Paddle-Lite得到,Paddle-Lite的编译方式如下:
```
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
./lite/tools/build_android.sh --arch=armv8 --with_cv=ON --with_extra=ON
```
注意:编译Paddle-Lite获得预编译库时,需要打开--with_cv=ON --with_extra=ON两个选项,--arch表示arm版本,这里指定为armv8,
更多编译命令
介绍请参考[链接](https://paddle-lite.readthedocs.io/zh/latest/user_guides/Compile/Android.html#id2)
直接下载预编译库并解压后,可以得到'inference_lite_lib.android.armv8/'文件夹,通过编译Paddle-Lite得到的预编译库位于
'Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/'文件夹下。
预编译库的文件目录如下:
```
inference_lite_lib.android.armv8/
|-- cxx C++ 预测库和头文件
| |-- include C++ 头文件
| | |-- paddle_api.h
| | |-- paddle_image_preprocess.h
| | |-- paddle_lite_factory_helper.h
| | |-- paddle_place.h
| | |-- paddle_use_kernels.h
| | |-- paddle_use_ops.h
| | `-- paddle_use_passes.h
| `-- lib C++预测库
| |-- libpaddle_api_light_bundled.a C++静态库
| `-- libpaddle_light_api_shared.so C++动态库
|-- java Java预测库
| |-- jar
| | `-- PaddlePredictor.jar
| |-- so
| | `-- libpaddle_lite_jni.so
| `-- src
|-- demo C++和Java示例代码
| |-- cxx C++ 预测库demo
| `-- java Java 预测库demo
```
## 2 开始运行
### 2.1 模型优化
Paddle-Lite 提供了多种策略来自动优化原始的模型,其中包括量化、子图融合、混合调度、Kernel优选等方法,使用Paddle_lite的opt工具可以自动
对模inference型进行优化,优化后的模型更轻量,模型运行速度更快。
模型优化需要使用Paddle-Lite的opt可执行文件,可以通过编译Paddle-Lite源码获得,编译步骤如下:
```
# 如果准备环境中已经clone了Paddle-Lite,则不用重新clone Paddle-Lite
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
# 启动编译
./lite/tools/build.sh build_optimize_tool
```
编译完成后,opt文件位于'build.opt/lite/api/'下,可通过如下方式查看opt的运行选项和使用方式;
```
cd build.opt/lite/api/
./opt
```
|选项|说明|
|:-:|:-:|
|--model_dir|待优化的PaddlePaddle模型(非combined形式)的路径|
|--model_file|待优化的PaddlePaddle模型(combined形式)的网络结构文件路径|
|--param_file|待优化的PaddlePaddle模型(combined形式)的权重文件路径|
|--optimize_out_type|输出模型类型,目前支持两种类型:protobuf和naive_buffer,其中naive_buffer是一种更轻量级的序列化/反序列化实现。若您需要在mobile端执行模型预测,请将此选项设置为naive_buffer。默认为protobuf|
|--optimize_out|优化模型的输出路径|
|--valid_targets|指定模型可执行的backend,默认为arm。目前可支持x86、arm、opencl、npu、xpu,可以同时指定多个backend(以空格分隔),Model Optimize Tool将会自动选择最佳方式。如果需要支持华为NPU(Kirin 810/990 Soc搭载的达芬奇架构NPU),应当设置为npu, arm|
|--record_tailoring_info|当使用 根据模型裁剪库文件 功能时,则设置该选项为true,以记录优化后模型含有的kernel和OP信息,默认为false|
--model_dir适用于待优化的模型是非combined方式,PaddleOCR的inference模型是combined方式,即模型结构和模型参数使用单独一个文件存储。
下面以PaddleOCR的超轻量中文模型为例,介绍使用编译好的opt文件完成inference模型到Paddle-Lite优化模型的转换。
```
# 下载PaddleOCR的超轻量文inference模型,并解压
wget https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar && tar xf ch_det_mv3_db_infer.tar
wget https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar && tar xf ch_rec_mv3_crnn_infer.tar
# 转换检测模型
./opt --model_file=./ch_det_mv3_db/model --param_file=./ch_det_mv3_db/params --optimize_out_type=naive_buffer --optimize_out=./ch_det_mv3_db_opt --valid_targets=arm
# 转换识别模型
./opt --model_file=./ch_rec_mv3_crnn/model --param_file=./ch_rec_mv3_crnn/params --optimize_out_type=naive_buffer --optimize_out=./ch_rec_mv3_crnn_opt --valid_targets=arm
```
转换成功后,当前目录下会多出ch_det_mv3_db_opt.nb, ch_rec_mv3_crnn_opt.nb结尾的文件,即是转换成功的模型文件。
### 2.2 与手机联调
首先需要进行一些准备工作。
1. 准备一台arm8的安卓手机,如果编译的预测库和opt文件是armv7,则需要arm7的手机。
2. 打开手机的USB调试选项,选择文件传输模式,连接电脑
3. 电脑上安装adb工具,用于调试。在电脑终端中输入'adb devices',如果有类似以下输出,则表示安装成功。
```
List of devices attached
744be294 device
```
4. 准备预测库、模型和预测文件,在预测库inference_lite_lib.android.armv8/demo/cxx/下新建一个ocr/文件夹,并将转换后的nb模型、
PaddleOCR repo中PaddleOCR/deploy/lite/ 下的所有文件放在新建的ocr文件夹下。执行完成后,ocr文件夹下将有如下文件格式:
```
demo/cxx/ocr/
|-- debug/ 新建debug文件夹存放模型文件
| |--ch_det_mv3_db_opt.nb 优化后的检测模型文件
| |--ch_rec_mv3_crnn_opt.nb 优化后的识别模型文件
|-- utils/
| |-- clipper.cpp Clipper库的cpp文件
| |-- clipper.hpp Clipper库的hpp文件
| |-- crnn_process.cpp 识别模型CRNN的预处理和后处理cpp文件
| |-- db_post_process.cpp 检测模型DB的后处理cpp文件
|-- Makefile 编译文件
|-- ocr_db_crnn.cc C++预测文件
```
5. 编译C++预测文件,准备测试图像,准备字典文件
```
cd demo/cxx/ocr/
# 执行编译
make
# 将编译的可执行文件移动到debug文件夹中
mv ocr_db_crnn ./debug/
```
准备测试图像,以PaddleOCR/doc/imgs/12.jpg为例,将测试的图像复制到demo/cxx/ocr/debug/文件夹下。
准备字典文件,将PaddleOCR/ppocr/utils/ppocr_keys_v1.txt复制到demo/cxx/ocr/debug/文件夹下。
上述步骤完成后就可以使用adb将文件push到手机上运行,步骤如下:
```
adb push debug /data/local/tmp/
adb shell
cd /data/local/tmp/debug
export LD_LIBRARY_PATH=/data/local/tmp/debug:$LD_LIBRARY_PATH
./ocr_db_crnn ch_det_mv3_db_opt.nb ch_rec_mv3_crnn_opt.nb ./12.jpg
```
如果对代码做了修改,则需要重新编译并push到手机上。
此差异已折叠。
/*******************************************************************************
* *
* Author : Angus Johnson *
* Version : 6.4.2 *
* Date : 27 February 2017 *
* Website : http://www.angusj.com *
* Copyright : Angus Johnson 2010-2017 *
* *
* License: *
* Use, modification & distribution is subject to Boost Software License Ver 1. *
* http://www.boost.org/LICENSE_1_0.txt *
* *
* Attributions: *
* The code in this library is an extension of Bala Vatti's clipping algorithm: *
* "A generic solution to polygon clipping" *
* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
* http://portal.acm.org/citation.cfm?id=129906 *
* *
* Computer graphics and geometric modeling: implementation and algorithms *
* By Max K. Agoston *
* Springer; 1 edition (January 4, 2005) *
* http://books.google.com/books?q=vatti+clipping+agoston *
* *
* See also: *
* "Polygon Offsetting by Computing Winding Numbers" *
* Paper no. DETC2005-85513 pp. 565-575 *
* ASME 2005 International Design Engineering Technical Conferences *
* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
* September 24-28, 2005 , Long Beach, California, USA *
* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
* *
*******************************************************************************/
#ifndef clipper_hpp
#define clipper_hpp
#define CLIPPER_VERSION "6.4.2"
//use_int32: When enabled 32bit ints are used instead of 64bit ints. This
//improve performance but coordinate values are limited to the range +/- 46340
//#define use_int32
//use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance.
//#define use_xyz
//use_lines: Enables line clipping. Adds a very minor cost to performance.
#define use_lines
//use_deprecated: Enables temporary support for the obsolete functions
//#define use_deprecated
#include <vector>
#include <list>
#include <set>
#include <stdexcept>
#include <cstring>
#include <cstdlib>
#include <ostream>
#include <functional>
#include <queue>
namespace ClipperLib {
enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
enum PolyType { ptSubject, ptClip };
//By far the most widely used winding rules for polygon filling are
//EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
//Others rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
//see http://glprogramming.com/red/chapter11.html
enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
#ifdef use_int32
typedef int cInt;
static cInt const loRange = 0x7FFF;
static cInt const hiRange = 0x7FFF;
#else
typedef signed long long cInt;
static cInt const loRange = 0x3FFFFFFF;
static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
typedef signed long long long64; //used by Int128 class
typedef unsigned long long ulong64;
#endif
struct IntPoint {
cInt X;
cInt Y;
#ifdef use_xyz
cInt Z;
IntPoint(cInt x = 0, cInt y = 0, cInt z = 0): X(x), Y(y), Z(z) {};
#else
IntPoint(cInt x = 0, cInt y = 0): X(x), Y(y) {};
#endif
friend inline bool operator== (const IntPoint& a, const IntPoint& b)
{
return a.X == b.X && a.Y == b.Y;
}
friend inline bool operator!= (const IntPoint& a, const IntPoint& b)
{
return a.X != b.X || a.Y != b.Y;
}
};
//------------------------------------------------------------------------------
typedef std::vector< IntPoint > Path;
typedef std::vector< Path > Paths;
inline Path& operator <<(Path& poly, const IntPoint& p) {poly.push_back(p); return poly;}
inline Paths& operator <<(Paths& polys, const Path& p) {polys.push_back(p); return polys;}
std::ostream& operator <<(std::ostream &s, const IntPoint &p);
std::ostream& operator <<(std::ostream &s, const Path &p);
std::ostream& operator <<(std::ostream &s, const Paths &p);
struct DoublePoint
{
double X;
double Y;
DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
};
//------------------------------------------------------------------------------
#ifdef use_xyz
typedef void (*ZFillCallback)(IntPoint& e1bot, IntPoint& e1top, IntPoint& e2bot, IntPoint& e2top, IntPoint& pt);
#endif
enum InitOptions {ioReverseSolution = 1, ioStrictlySimple = 2, ioPreserveCollinear = 4};
enum JoinType {jtSquare, jtRound, jtMiter};
enum EndType {etClosedPolygon, etClosedLine, etOpenButt, etOpenSquare, etOpenRound};
class PolyNode;
typedef std::vector< PolyNode* > PolyNodes;
class PolyNode
{
public:
PolyNode();
virtual ~PolyNode(){};
Path Contour;
PolyNodes Childs;
PolyNode* Parent;
PolyNode* GetNext() const;
bool IsHole() const;
bool IsOpen() const;
int ChildCount() const;
private:
//PolyNode& operator =(PolyNode& other);
unsigned Index; //node index in Parent.Childs
bool m_IsOpen;
JoinType m_jointype;
EndType m_endtype;
PolyNode* GetNextSiblingUp() const;
void AddChild(PolyNode& child);
friend class Clipper; //to access Index
friend class ClipperOffset;
};
class PolyTree: public PolyNode
{
public:
~PolyTree(){ Clear(); };
PolyNode* GetFirst() const;
void Clear();
int Total() const;
private:
//PolyTree& operator =(PolyTree& other);
PolyNodes AllNodes;
friend class Clipper; //to access AllNodes
};
bool Orientation(const Path &poly);
double Area(const Path &poly);
int PointInPolygon(const IntPoint &pt, const Path &path);
void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd);
void CleanPolygon(const Path& in_poly, Path& out_poly, double distance = 1.415);
void CleanPolygon(Path& poly, double distance = 1.415);
void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance = 1.415);
void CleanPolygons(Paths& polys, double distance = 1.415);
void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed);
void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed);
void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution);
void PolyTreeToPaths(const PolyTree& polytree, Paths& paths);
void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths);
void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths);
void ReversePath(Path& p);
void ReversePaths(Paths& p);
struct IntRect { cInt left; cInt top; cInt right; cInt bottom; };
//enums that are used internally ...
enum EdgeSide { esLeft = 1, esRight = 2};
//forward declarations (for stuff used internally) ...
struct TEdge;
struct IntersectNode;
struct LocalMinimum;
struct OutPt;
struct OutRec;
struct Join;
typedef std::vector < OutRec* > PolyOutList;
typedef std::vector < TEdge* > EdgeList;
typedef std::vector < Join* > JoinList;
typedef std::vector < IntersectNode* > IntersectList;
//------------------------------------------------------------------------------
//ClipperBase is the ancestor to the Clipper class. It should not be
//instantiated directly. This class simply abstracts the conversion of sets of
//polygon coordinates into edge objects that are stored in a LocalMinima list.
class ClipperBase
{
public:
ClipperBase();
virtual ~ClipperBase();
virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
virtual void Clear();
IntRect GetBounds();
bool PreserveCollinear() {return m_PreserveCollinear;};
void PreserveCollinear(bool value) {m_PreserveCollinear = value;};
protected:
void DisposeLocalMinimaList();
TEdge* AddBoundsToLML(TEdge *e, bool IsClosed);
virtual void Reset();
TEdge* ProcessBound(TEdge* E, bool IsClockwise);
void InsertScanbeam(const cInt Y);
bool PopScanbeam(cInt &Y);
bool LocalMinimaPending();
bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
OutRec* CreateOutRec();
void DisposeAllOutRecs();
void DisposeOutRec(PolyOutList::size_type index);
void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
void DeleteFromAEL(TEdge *e);
void UpdateEdgeIntoAEL(TEdge *&e);
typedef std::vector<LocalMinimum> MinimaList;
MinimaList::iterator m_CurrentLM;
MinimaList m_MinimaList;
bool m_UseFullRange;
EdgeList m_edges;
bool m_PreserveCollinear;
bool m_HasOpenPaths;
PolyOutList m_PolyOuts;
TEdge *m_ActiveEdges;
typedef std::priority_queue<cInt> ScanbeamList;
ScanbeamList m_Scanbeam;
};
//------------------------------------------------------------------------------
class Clipper : public virtual ClipperBase
{
public:
Clipper(int initOptions = 0);
bool Execute(ClipType clipType,
Paths &solution,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType,
Paths &solution,
PolyFillType subjFillType,
PolyFillType clipFillType);
bool Execute(ClipType clipType,
PolyTree &polytree,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType,
PolyTree &polytree,
PolyFillType subjFillType,
PolyFillType clipFillType);
bool ReverseSolution() { return m_ReverseOutput; };
void ReverseSolution(bool value) {m_ReverseOutput = value;};
bool StrictlySimple() {return m_StrictSimple;};
void StrictlySimple(bool value) {m_StrictSimple = value;};
//set the callback function for z value filling on intersections (otherwise Z is 0)
#ifdef use_xyz
void ZFillFunction(ZFillCallback zFillFunc);
#endif
protected:
virtual bool ExecuteInternal();
private:
JoinList m_Joins;
JoinList m_GhostJoins;
IntersectList m_IntersectList;
ClipType m_ClipType;
typedef std::list<cInt> MaximaList;
MaximaList m_Maxima;
TEdge *m_SortedEdges;
bool m_ExecuteLocked;
PolyFillType m_ClipFillType;
PolyFillType m_SubjFillType;
bool m_ReverseOutput;
bool m_UsingPolyTree;
bool m_StrictSimple;
#ifdef use_xyz
ZFillCallback m_ZFill; //custom callback
#endif
void SetWindingCount(TEdge& edge);
bool IsEvenOddFillType(const TEdge& edge) const;
bool IsEvenOddAltFillType(const TEdge& edge) const;
void InsertLocalMinimaIntoAEL(const cInt botY);
void InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge);
void AddEdgeToSEL(TEdge *edge);
bool PopEdgeFromSEL(TEdge *&edge);
void CopyAELToSEL();
void DeleteFromSEL(TEdge *e);
void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
bool IsContributing(const TEdge& edge) const;
bool IsTopHorz(const cInt XPos);
void DoMaxima(TEdge *e);
void ProcessHorizontals();
void ProcessHorizontal(TEdge *horzEdge);
void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutPt* AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutRec* GetOutRec(int idx);
void AppendPolygon(TEdge *e1, TEdge *e2);
void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
OutPt* AddOutPt(TEdge *e, const IntPoint &pt);
OutPt* GetLastOutPt(TEdge *e);
bool ProcessIntersections(const cInt topY);
void BuildIntersectList(const cInt topY);
void ProcessIntersectList();
void ProcessEdgesAtTopOfScanbeam(const cInt topY);
void BuildResult(Paths& polys);
void BuildResult2(PolyTree& polytree);
void SetHoleState(TEdge *e, OutRec *outrec);
void DisposeIntersectNodes();
bool FixupIntersectionOrder();
void FixupOutPolygon(OutRec &outrec);
void FixupOutPolyline(OutRec &outrec);
bool IsHole(TEdge *e);
bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
void FixHoleLinkage(OutRec &outrec);
void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
void ClearJoins();
void ClearGhostJoins();
void AddGhostJoin(OutPt *op, const IntPoint offPt);
bool JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2);
void JoinCommonEdges();
void DoSimplePolygons();
void FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec);
void FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec);
void FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec);
#ifdef use_xyz
void SetZ(IntPoint& pt, TEdge& e1, TEdge& e2);
#endif
};
//------------------------------------------------------------------------------
class ClipperOffset
{
public:
ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
~ClipperOffset();
void AddPath(const Path& path, JoinType joinType, EndType endType);
void AddPaths(const Paths& paths, JoinType joinType, EndType endType);
void Execute(Paths& solution, double delta);
void Execute(PolyTree& solution, double delta);
void Clear();
double MiterLimit;
double ArcTolerance;
private:
Paths m_destPolys;
Path m_srcPoly;
Path m_destPoly;
std::vector<DoublePoint> m_normals;
double m_delta, m_sinA, m_sin, m_cos;
double m_miterLim, m_StepsPerRad;
IntPoint m_lowest;
PolyNode m_polyNodes;
void FixOrientations();
void DoOffset(double delta);
void OffsetPoint(int j, int& k, JoinType jointype);
void DoSquare(int j, int k);
void DoMiter(int j, int k, double r);
void DoRound(int j, int k);
};
//------------------------------------------------------------------------------
class clipperException : public std::exception
{
public:
clipperException(const char* description): m_descr(description) {}
virtual ~clipperException() throw() {}
virtual const char* what() const throw() {return m_descr.c_str();}
private:
std::string m_descr;
};
//------------------------------------------------------------------------------
} //ClipperLib namespace
#endif //clipper_hpp
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "math.h"
#include <iostream>
#include <cstring>
#include <fstream>
#define character_type "ch"
#define max_dict_length 6624
const std::vector<int> rec_image_shape {3, 32, 320};
cv::Mat crnn_resize_norm_img(cv::Mat img, float wh_ratio){
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgW = rec_image_shape[2];
imgH = rec_image_shape[1];
if (character_type=="ch")
imgW = int(32*wh_ratio);
float ratio = float(img.cols)/float(img.rows);
int resize_w, resize_h;
if (ceilf(imgH*ratio)>imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH*ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH),0.f, 0.f, cv::INTER_CUBIC);
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
for (int h=0; h< resize_img.rows; h++){
for (int w=0; w< resize_img.cols; w++){
resize_img.at<cv::Vec3f>(h, w)[0] = (resize_img.at<cv::Vec3f>(h, w)[0] - 0.5) *2;
resize_img.at<cv::Vec3f>(h, w)[1] = (resize_img.at<cv::Vec3f>(h, w)[1] - 0.5) *2;
resize_img.at<cv::Vec3f>(h, w)[2] = (resize_img.at<cv::Vec3f>(h, w)[2] - 0.5) *2;
}
}
cv::Mat dist;
cv::copyMakeBorder(resize_img, dist, 0, 0, 0, int(imgW-resize_w), cv::BORDER_CONSTANT, {0, 0, 0});
return dist;
}
cv::Mat crnn_resize_img(cv::Mat img, float wh_ratio){
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgW = rec_image_shape[2];
imgH = rec_image_shape[1];
if (character_type=="ch")
imgW = int(32*wh_ratio);
float ratio = float(img.cols)/float(img.rows);
int resize_w, resize_h;
if (ceilf(imgH*ratio)>imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH*ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH),0.f, 0.f, cv::INTER_LINEAR);
return resize_img;
}
std::basic_string<char, std::char_traits<char>, std::allocator<char>> * read_dict(std::string path){
std::ifstream ifs;
std::string charactors[max_dict_length];
ifs.open(path);
if (!ifs.is_open())
{
std::cout<<"open file "<<path<<" failed"<<std::endl;
}
else
{
std::string con = "";
int count = 0;
while (ifs)
{
getline(ifs, charactors[count]);
count++;
}
}
return charactors;
}
cv::Mat get_rotate_crop_image(cv::Mat srcimage, std::vector<std::vector<int>> box){
cv::Mat image;
srcimage.copyTo(image);
std::vector<std::vector<int>> points = box;
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
int left = int(*std::min_element(x_collect, x_collect+4));
int right = int(*std::max_element(x_collect, x_collect+4));
int top = int(*std::min_element(y_collect, y_collect+4));
int bottom = int(*std::max_element(y_collect, y_collect+4));
cv::Mat img_crop;
image(cv::Rect(left, top, right-left, bottom-top)).copyTo(img_crop);
for(int i=0; i<points.size(); i++){
points[i][0] -= left;
points[i][1] -= top;
}
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
pow(points[0][1] - points[1][1], 2)));
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
pow(points[0][1] - points[3][1], 2)));
cv::Point2f pts_std[4];
pts_std[0] = cv::Point2f(0., 0.);
pts_std[1] = cv::Point2f(img_crop_width, 0.);
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
pts_std[3] = cv::Point2f(0.f, img_crop_height);
cv::Point2f pointsf[4];
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M, cv::Size(img_crop_width, img_crop_height), cv::BORDER_REPLICATE);
if (float(dst_img.rows) >= float(dst_img.cols)*1.5){
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
cv::transpose(dst_img, srcCopy);
cv::flip(srcCopy, srcCopy, 0);
return srcCopy;
}
else{
return dst_img;
}
}
template<class ForwardIterator>
inline size_t argmax(ForwardIterator first, ForwardIterator last)
{
return std::distance(first, std::max_element(first, last));
}
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include <math.h>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "clipper.hpp"
#include "clipper.cpp"
void getcontourarea(float ** box, float unclip_ratio, float & distance){
int pts_num=4;
float area = 0.0f;
float dist = 0.0f;
for (int i=0; i<pts_num; i++){
area += box[i][0] * box[(i+1)%pts_num][1] - box[i][1] * box[(i + 1) % pts_num][0];
dist += sqrtf( (box[i][0] - box[(i + 1) % pts_num][0]) * (box[i][0] - box[(i + 1) % pts_num][0]) + (box[i][1] - box[(i + 1) % pts_num][1]) * (box[i][1] - box[(i + 1) % pts_num][1]) );
}
area = fabs(float(area/2.0));
distance = area * unclip_ratio / dist;
}
cv::RotatedRect unclip(float ** box){
float unclip_ratio = 2.0;
float distance = 1.0;
getcontourarea(box, unclip_ratio, distance);
ClipperLib::ClipperOffset offset;
ClipperLib::Path p;
p << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1])) << ClipperLib::IntPoint(int(box[1][0]), int(box[1][1])) <<
ClipperLib::IntPoint(int(box[2][0]), int(box[2][1])) << ClipperLib::IntPoint(int(box[3][0]), int(box[3][1]));
offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);
ClipperLib::Paths soln;
offset.Execute(soln, distance);
std::vector<cv::Point2f> points;
for (int j=0; j<soln.size(); j++){
for (int i=0; i< soln[soln.size()-1].size(); i++){
points.emplace_back(soln[j][i].X, soln[j][i].Y);
}
}
cv::RotatedRect res = cv::minAreaRect(points);
return res;
}
float ** Mat2Vec(cv::Mat mat)
{
auto **array = new float*[mat.rows];
for (int i = 0; i<mat.rows; ++i)
array[i] = new float[mat.cols];
for (int i = 0; i < mat.rows; ++i)
{
for (int j = 0; j < mat.cols; ++j)
{
array[i][j] = mat.at<float>(i, j);
}
}
return array;
}
void quickSort(float ** s, int l, int r)
{
if (l < r)
{
int i = l, j = r;
float x = s[l][0];
float * xp = s[l];
while (i < j)
{
while(i < j && s[j][0]>= x)
j--;
if(i < j)
std::swap(s[i++], s[j]);
while(i < j && s[i][0]< x)
i++;
if(i < j)
std::swap(s[j--], s[i]);
}
s[i] = xp;
quickSort(s, l, i - 1);
quickSort(s, i + 1, r);
}
}
void quickSort_vector(std::vector<std::vector<int>> & box, int l, int r, int axis){
if (l < r){
int i = l, j = r;
int x = box[l][axis];
std::vector<int> xp (box[l]);
while (i < j)
{
while(i < j && box[j][axis]>= x)
j--;
if(i < j)
std::swap(box[i++], box[j]);
while(i < j && box[i][axis]< x)
i++;
if(i < j)
std::swap(box[j--], box[i]);
}
box[i] = xp;
quickSort_vector(box, l, i - 1, axis);
quickSort_vector(box, i + 1, r, axis);
}
}
std::vector<std::vector<int>> order_points_clockwise(std::vector<std::vector<int>> pts){
std::vector<std::vector<int>> box = pts;
quickSort_vector(box, 0, int(box.size()-1), 0);
std::vector<std::vector<int>> leftmost = {box[0], box[1]};
std::vector<std::vector<int>> rightmost = {box[2], box[3]};
if (leftmost[0][1]>leftmost[1][1])
std::swap(leftmost[0], leftmost[1]);
if (rightmost[0][1]> rightmost[1][1])
std::swap(rightmost[0], rightmost[1]);
std::vector<std::vector<int>> rect = {leftmost[0], rightmost[0], rightmost[1], leftmost[1]};
return rect;
}
float ** get_mini_boxes(cv::RotatedRect box, float & ssid){
ssid = box.size.width>=box.size.height?box.size.height:box.size.width;
cv::Mat points;
cv::boxPoints(box, points);
// sorted box points
auto array = Mat2Vec(points);
quickSort(array, 0, 3);
float * idx1=array[0], *idx2=array[1], *idx3=array[2], *idx4=array[3];
if (array[3][1]<=array[2][1]) {
idx2 = array[3];
idx3 = array[2];
}
else{
idx2 = array[2];
idx3 = array[3];
}
if (array[1][1]<=array[0][1]){
idx1 = array[1];
idx4 = array[0];
}
else{
idx1 = array[0];
idx4 = array[1];
}
array[0] = idx1;
array[1] = idx2;
array[2] = idx3;
array[3] = idx4;
return array;
}
template<class T>
T clamp(T x, T min, T max)
{
if (x > max)
return max;
if (x < min)
return min;
return x;
}
float clampf(float x, float min, float max){
if (x > max)
return max;
if (x < min)
return min;
return x;
}
float box_score_fast(float ** box_array, cv::Mat pred){
auto array=box_array;
int width = pred.cols;
int height = pred.rows;
float box_x[4]={array[0][0], array[1][0], array[2][0], array[3][0]};
float box_y[4]={array[0][1], array[1][1], array[2][1], array[3][1]};
int xmin = clamp(int(std::floorf(*(std::min_element(box_x, box_x+4)))), 0, width - 1);
int xmax = clamp(int(std::ceilf(*(std::max_element(box_x, box_x+4)))), 0, width - 1);
int ymin = clamp(int(std::floorf(*(std::min_element(box_y, box_y+4)))), 0, height - 1);
int ymax = clamp(int(std::ceilf(*(std::max_element(box_y, box_y+4)))), 0, height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point root_point[4];
root_point[0] = cv::Point(int(array[0][0])-xmin, int(array[0][1])-ymin);
root_point[1] = cv::Point(int(array[1][0])-xmin, int(array[1][1])-ymin);
root_point[2] = cv::Point(int(array[2][0])-xmin, int(array[2][1])-ymin);
root_point[3] = cv::Point(int(array[3][0])-xmin, int(array[3][1])-ymin);
const cv::Point* ppt[1] = {root_point};
int npt[] = {4};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax-xmin+1,ymax-ymin+1)).copyTo(croppedImg);
auto score = cv::mean(croppedImg, mask)[0];
return score;
}
std::vector<std::vector<std::vector<int>>> boxes_from_bitmap(const cv::Mat pred, const cv::Mat bitmap) {
const int min_size=3;
const int max_candidates = 1000;
const float box_thresh=0.5;
int width = bitmap.cols;
int height = bitmap.rows;
std::vector<std::vector<cv::Point>> contours;
std::vector<cv::Vec4i> hierarchy;
cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST, cv::CHAIN_APPROX_SIMPLE);
int num_contours = contours.size() >= max_candidates ? max_candidates : contours.size();
std::vector<std::vector<std::vector<int>>> boxes;
for (int _i = 0; _i < num_contours; _i++) {
float ssid;
cv::RotatedRect box = cv::minAreaRect(contours[_i]);
auto array = get_mini_boxes(box, ssid);
auto box_for_unclip = array;
//end get_mini_box
if (ssid< min_size) {
continue;
}
float score;
score = box_score_fast(array, pred);
//end box_score_fast
if (score < box_thresh)
continue;
// start for unclip
cv::RotatedRect points = unclip(box_for_unclip);
// end for unclip
cv::RotatedRect clipbox = points;
auto cliparray = get_mini_boxes(clipbox, ssid);
if (ssid < min_size+2) continue;
int dest_width=pred.cols;
int dest_height=pred.rows;
std::vector<std::vector<int>> intcliparray;
for (int num_pt=0; num_pt<4; num_pt++){
std::vector<int> a { int( clampf(roundf(cliparray[num_pt][0]/float(width)*float(dest_width)), 0, float(dest_width)) ),
int( clampf(roundf(cliparray[num_pt][1]/float(height)*float(dest_height)), 0, float(dest_height)) )};
intcliparray.push_back(a);
}
boxes.push_back(intcliparray);
}//end for
return boxes;
}
int _max(int a, int b){
return a>=b?a:b;
}
int _min(int a, int b){
return a>=b?b:a;
}
std::vector<std::vector<std::vector<int>>> filter_tag_det_res(std::vector<std::vector<std::vector<int>>> boxes,
float ratio_h, float ratio_w, cv::Mat srcimg){
int oriimg_h = srcimg.rows;
int oriimg_w = srcimg.cols;
std::vector<std::vector<std::vector<int>>> root_points;
for (int n=0; n<boxes.size(); n++){
boxes[n] = order_points_clockwise(boxes[n]);
for (int m=0; m< boxes[0].size(); m++){
boxes[n][m][0] /= ratio_w;
boxes[n][m][1] /= ratio_h;
boxes[n][m][0] = int(_min(_max(boxes[n][m][0], 0), oriimg_w-1));
boxes[n][m][1] = int(_min(_max(boxes[n][m][1], 0), oriimg_h-1));
}
}
for(int n=0; n<boxes.size(); n++){
int rect_width, rect_height;
rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) + pow(boxes[n][0][1] - boxes[n][1][1], 2)));
rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) + pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <=10 || rect_height<=10)
continue;
root_points.push_back(boxes[n]);
}
return root_points;
}
/*
using namespace std;
// read data from txt file
cv::Mat readtxt2(std::string path, int imgw, int imgh, int imgc) {
std::cout << "read data file from txt file! " << std::endl;
ifstream in(path);
string line;
int count = 0;
int i = 0, j = 0;
std::vector<float> img_mean = {0.485, 0.456, 0.406};
std::vector<float> img_std = {0.229, 0.224, 0.225};
float trainData[imgh][imgw*imgc];
while (getline(in, line)) {
stringstream ss(line);
double x;
while (ss >> x) {
// trainData[i][j] = float(x) * img_std[j % 3] + img_mean[j % 3];
trainData[i][j] = float(x);
j++;
}
i++;
j = 0;
}
cv::Mat pred_map(imgh, imgw*imgc, CV_32FC1, (float *) trainData);
cv::Mat reshape_img = pred_map.reshape(imgc, imgh);
return reshape_img;
}
*/
//using namespace std;
//
//void writetxt(vector<vector<float>> data, std::string save_path){
//
// ofstream fout(save_path);
//
// for (int i = 0; i < data.size(); i++) {
// for (int j=0; j< data[0].size(); j++){
// fout << data[i][j] << " ";
// }
// fout << endl;
// }
// fout << endl;
// fout.close();
//}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册