提交 051fe64a 编写于 作者: J Jethong

fix config add e2e_ch.md e2e_res_img623_pg

上级 84528168
Global: Global:
use_gpu: False use_gpu: True
epoch_num: 600 epoch_num: 600
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 2 print_batch_step: 10
save_model_dir: ./output/pg_r50_vd_tt/ save_model_dir: ./output/pg_r50_vd_tt/
save_epoch_step: 1 save_epoch_step: 10
# evaluation is run every 5000 iterationss after the 4000th iteration # evaluation is run every 0 iterationss after the 1000th iteration
eval_batch_step: [ 0, 1000 ] eval_batch_step: [ 0, 1000 ]
# if pretrained_model is saved in static mode, load_static_weights must set to True # 1. If pretrained_model is saved in static mode, such as classification pretrained model
load_static_weights: False # from static branch, load_static_weights must be set as True.
# 2. If you want to finetune the pretrained models we provide in the docs,
# you should set load_static_weights as False.
load_static_weights: True
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: pretrained_model:
checkpoints: checkpoints:
...@@ -19,7 +22,7 @@ Global: ...@@ -19,7 +22,7 @@ Global:
Architecture: Architecture:
model_type: e2e model_type: e2e
algorithm: PG algorithm: PGNet
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet
...@@ -34,28 +37,16 @@ Architecture: ...@@ -34,28 +37,16 @@ Architecture:
Loss: Loss:
name: PGLoss name: PGLoss
#Optimizer:
# name: Adam
# beta1: 0.9
# beta2: 0.999
# lr:
# name: Cosine
# learning_rate: 0.001
# warmup_epoch: 1
# regularizer:
# name: 'L2'
# factor: 0
Optimizer: Optimizer:
name: RMSProp name: Adam
beta1: 0.9
beta2: 0.999
lr: lr:
name: Piecewise
learning_rate: 0.001 learning_rate: 0.001
decay_epochs: [ 40, 80, 120, 160, 200 ]
values: [ 0.001, 0.00033, 0.0001, 0.000033, 0.00001 ]
regularizer: regularizer:
name: 'L2' name: 'L2'
factor: 0.00005 factor: 0
PostProcess: PostProcess:
name: PGPostProcess name: PGPostProcess
...@@ -65,45 +56,45 @@ PostProcess: ...@@ -65,45 +56,45 @@ PostProcess:
Metric: Metric:
name: E2EMetric name: E2EMetric
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
main_indicator: f_score_e2e main_indicator: f_score_e2e
Train: Train:
dataset: dataset:
name: PGDateSet name: PGDateSet
label_file_list: label_file_list: [./train_data/total_text/train/]
ratio_list: ratio_list: [1.0]
data_format: textnet # textnet/partvgg data_format: icdar
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- PGProcessTrain: - PGProcessTrain:
batch_size: 14 batch_size: 14
data_format: icdar
tcl_len: 64
min_crop_size: 24 min_crop_size: 24
min_text_size: 4 min_text_size: 4
max_text_size: 512 max_text_size: 512
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
- KeepKeys: - KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: True drop_last: True
batch_size_per_card: 1 batch_size_per_card: 14
num_workers: 8 num_workers: 16
Eval: Eval:
dataset: dataset:
name: PGDateSet name: PGDataSet
data_dir: ./train_data/ data_dir: ./train_data/
label_file_list: label_file_list: [./train_data/total_text/test/]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- E2ELabelEncode: - E2ELabelEncode:
label_list: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
max_len: 50
- E2EResizeForTest: - E2EResizeForTest:
valid_set: totaltext valid_set: totaltext
max_side_len: 768 max_side_len: 768
......
# 端到端文字识别
本节以partvgg/totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
## 数据准备
支持两种不同的数据形式textnet / icdar ,分别为四点标注数据和十四点标注数据,十四点标注数据效果要比四点标注效果好
###数据形式为textnet
解压数据集和下载标注文件后,PaddleOCR/train_data/part_vgg_synth/train/ 有一个文件夹和一个文件,分别是:
```
/PaddleOCR/train_data/part_vgg_synth/train/
└─ image/ partvgg数据集的训练数据
└─ train_annotation_info.txt partvgg数据集的测试标注
```
提供的标注文件格式如下,中间用"\t"分隔:
```
" 图像文件名 图像标注信息--四点标注 图像标注信息--识别标注
119_nile_110_31 140.2 222.5 266.0 194.6 278.7 251.8 152.9 279.7 Path: 32.9 133.1 106.0 130.8 106.4 143.8 33.3 146.1 were 21.8 81.9 106.9 80.4 107.7 123.2 22.6 124.7 why
```
标注文件txt当中,其中每一行代表一组数据,以第一行为例。第一个代表同级目录image/下面的文件名, 后面每9个代表一组标注信息,前8个代表文本框的四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
###数据形式为icdar
解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是:
```
/PaddleOCR/train_data/total_text/train/
└─ rgb/ total_text数据集的训练数据
└─ poly/ total_text数据集的测试标注
```
提供的标注文件格式如下,中间用"\t"分隔:
```
" 图像标注信息--十四点标注数据 图像标注信息--识别标注
1004.0,689.0,1019.0,698.0,1034.0,708.0,1049.0,718.0,1064.0,728.0,1079.0,738.0,1095.0,748.0,1094.0,774.0,1079.0,765.0,1065.0,756.0,1050.0,747.0,1036.0,738.0,1021.0,729.0,1007.0,721.0 EST
1102.0,755.0,1116.0,764.0,1131.0,773.0,1146.0,783.0,1161.0,792.0,1176.0,801.0,1191.0,811.0,1193.0,837.0,1178.0,828.0,1164.0,819.0,1150.0,810.0,1135.0,801.0,1121.0,792.0,1107.0,784.0 1972
```
标注文件当中,其中每一个txt文件代表一组数据,文件名同级目录rgb/下面的文件名。以第一行为例,前面28个代表文本框的十四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。
## 快速启动训练
首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet_vd系列,
您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。
```shell
cd PaddleOCR/
下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
# 解压预训练模型文件,以ResNet50_vd为例
tar -xf ./pretrain_models/ResNet50_vd_ssld_pretrained.tar ./pretrain_models/
# 注:正确解压backbone预训练权重文件后,文件夹下包含众多以网络层命名的权重文件,格式如下:
./pretrain_models/ResNet50_vd_ssld_pretrained/
└─ conv_last_bn_mean
└─ conv_last_bn_offset
└─ conv_last_bn_scale
└─ conv_last_bn_variance
└─ ......
```
#### 启动训练
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```shell
# 单机单卡训练 e2e 模型
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \
-o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \
-o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True
```
上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
```
#### 断点训练
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
```
**注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
## 指标评估
PaddleOCR计算三个OCR端到端相关的指标,分别是:Precision、Recall、Hmean。
运行如下代码,根据配置文件`e2e_r50_vd_pg.yml``save_res_path`指定的测试集检测结果文件,计算评估指标。
评估时设置后处理参数`max_side_len=768`,使用不同数据集、不同模型训练,可调整参数进行优化
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
```shell
python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
```
## 测试端到端效果
测试单张图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
测试文件夹下所有图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
...@@ -34,7 +34,7 @@ import paddle.distributed as dist ...@@ -34,7 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDateSet from ppocr.data.pgnet_dataset import PGDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -55,8 +55,7 @@ signal.signal(signal.SIGTERM, term_mp) ...@@ -55,8 +55,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
support_dict = ['SimpleDataSet', 'LMDBDateSet', 'PGDateSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
......
...@@ -35,9 +35,10 @@ class ClsLabelEncode(object): ...@@ -35,9 +35,10 @@ class ClsLabelEncode(object):
class E2ELabelEncode(object): class E2ELabelEncode(object):
def __init__(self, label_list, **kwargs): def __init__(self, Lexicon_Table, max_len, **kwargs):
self.label_list = label_list self.Lexicon_Table = Lexicon_Table
self.max_len = 50 self.max_len = max_len
self.pad_num = len(self.Lexicon_Table)
def __call__(self, data): def __call__(self, data):
text_label_index_list, temp_text = [], [] text_label_index_list, temp_text = [], []
...@@ -46,9 +47,10 @@ class E2ELabelEncode(object): ...@@ -46,9 +47,10 @@ class E2ELabelEncode(object):
text = text.upper() text = text.upper()
temp_text = [] temp_text = []
for c_ in text: for c_ in text:
if c_ in self.label_list: if c_ in self.Lexicon_Table:
temp_text.append(self.label_list.index(c_)) temp_text.append(self.Lexicon_Table.index(c_))
temp_text = temp_text + [36] * (self.max_len - len(temp_text)) temp_text = temp_text + [self.pad_num] * (self.max_len -
len(temp_text))
text_label_index_list.append(temp_text) text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list) data['strs'] = np.array(text_label_index_list)
return data return data
......
...@@ -197,7 +197,6 @@ class DetResizeForTest(object): ...@@ -197,7 +197,6 @@ class DetResizeForTest(object):
sys.exit(0) sys.exit(0)
ratio_h = resize_h / float(h) ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w) ratio_w = resize_w / float(w)
# return img, np.array([h, w])
return img, [ratio_h, ratio_w] return img, [ratio_h, ratio_w]
def resize_image_type2(self, img): def resize_image_type2(self, img):
...@@ -206,7 +205,6 @@ class DetResizeForTest(object): ...@@ -206,7 +205,6 @@ class DetResizeForTest(object):
resize_w = w resize_w = w
resize_h = h resize_h = h
# Fix the longer side
if resize_h > resize_w: if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h ratio = float(self.resize_long) / resize_h
else: else:
...@@ -245,10 +243,8 @@ class E2EResizeForTest(object): ...@@ -245,10 +243,8 @@ class E2EResizeForTest(object):
return data return data
def resize_image_for_totaltext(self, im, max_side_len=512): def resize_image_for_totaltext(self, im, max_side_len=512):
"""
"""
h, w, _ = im.shape
h, w, _ = im.shape
resize_w = w resize_w = w
resize_h = h resize_h = h
ratio = 1.25 ratio = 1.25
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import math import math
import cv2 import cv2
import numpy as np import numpy as np
import os
__all__ = ['PGProcessTrain'] __all__ = ['PGProcessTrain']
...@@ -23,15 +22,11 @@ __all__ = ['PGProcessTrain'] ...@@ -23,15 +22,11 @@ __all__ = ['PGProcessTrain']
class PGProcessTrain(object): class PGProcessTrain(object):
def __init__(self, def __init__(self,
batch_size=14, batch_size=14,
data_format='icdar',
tcl_len=64,
min_crop_size=24, min_crop_size=24,
min_text_size=10, min_text_size=10,
max_text_size=512, max_text_size=512,
**kwargs): **kwargs):
self.batch_size = batch_size self.batch_size = batch_size
self.data_format = data_format
self.tcl_len = tcl_len
self.min_crop_size = min_crop_size self.min_crop_size = min_crop_size
self.min_text_size = min_text_size self.min_text_size = min_text_size
self.max_text_size = max_text_size self.max_text_size = max_text_size
...@@ -60,24 +55,22 @@ class PGProcessTrain(object): ...@@ -60,24 +55,22 @@ class PGProcessTrain(object):
""" """
point_num = poly.shape[0] point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32) min_area_quad = np.zeros((4, 2), dtype=np.float32)
if True: rect = cv2.minAreaRect(poly.astype(
rect = cv2.minAreaRect(poly.astype( np.int32)) # (center (x,y), (width, height), angle of rotation)
np.int32)) # (center (x,y), (width, height), angle of rotation) box = np.array(cv2.boxPoints(rect))
center_point = rect[0]
box = np.array(cv2.boxPoints(rect)) first_point_idx = 0
min_dist = 1e4
first_point_idx = 0 for i in range(4):
min_dist = 1e4 dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
for i in range(4): np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ if dist < min_dist:
np.linalg.norm(box[(i + 3) % 4] - poly[-1]) min_dist = dist
if dist < min_dist: first_point_idx = i
min_dist = dist for i in range(4):
first_point_idx = i min_area_quad[i] = box[(first_point_idx + i) % 4]
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad return min_area_quad
...@@ -235,8 +228,6 @@ class PGProcessTrain(object): ...@@ -235,8 +228,6 @@ class PGProcessTrain(object):
ys, xs = np.where(tmp_image > 0) ys, xs = np.where(tmp_image > 0)
xy_text = np.array(list(zip(xs, ys)), dtype='float32') xy_text = np.array(list(zip(xs, ys)), dtype='float32')
# left_center_pt = np.array(key_point_xys[0]).reshape(1, 2)
# right_center_pt = np.array(key_point_xys[-1]).reshape(1, 2)
left_center_pt = ( left_center_pt = (
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2) (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
right_center_pt = ( right_center_pt = (
...@@ -317,16 +308,6 @@ class PGProcessTrain(object): ...@@ -317,16 +308,6 @@ class PGProcessTrain(object):
average_height = max(sum(height_list) / len(height_list), 1.0) average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height return average_height
def encode(self, text):
text_list = []
for char in text:
if char not in self.dict:
continue
text_list.append([self.dict[char]])
if len(text_list) == 0:
return None
return text_list
def generate_tcl_ctc_label(self, def generate_tcl_ctc_label(self,
h, h,
w, w,
...@@ -390,8 +371,6 @@ class PGProcessTrain(object): ...@@ -390,8 +371,6 @@ class PGProcessTrain(object):
text_label = text_strs[poly_idx] text_label = text_strs[poly_idx]
text_label = self.prepare_text_label(text_label, text_label = self.prepare_text_label(text_label,
self.Lexicon_Table) self.Lexicon_Table)
# text = text.decode('utf-8')
# text_label_index_list = self.encode(text)
text_label_index_list = [[self.Lexicon_Table.index(c_)] text_label_index_list = [[self.Lexicon_Table.index(c_)]
for c_ in text_label for c_ in text_label
...@@ -402,22 +381,18 @@ class PGProcessTrain(object): ...@@ -402,22 +381,18 @@ class PGProcessTrain(object):
tcl_poly = self.poly2tcl(poly, tcl_ratio) tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly) tcl_quads = self.poly2quads(tcl_poly)
poly_quads = self.poly2quads(poly) poly_quads = self.poly2quads(poly)
# stcl map
stcl_quads, quad_index = self.shrink_poly_along_width( stcl_quads, quad_index = self.shrink_poly_along_width(
tcl_quads, tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width, shrink_ratio_of_width=shrink_ratio_of_width,
expand_height_ratio=1.0 / tcl_ratio) expand_height_ratio=1.0 / tcl_ratio)
# generate tcl map
cv2.fillPoly(score_map, cv2.fillPoly(score_map,
np.round(stcl_quads).astype(np.int32), 1.0) np.round(stcl_quads).astype(np.int32), 1.0)
cv2.fillPoly(score_map_big, cv2.fillPoly(score_map_big,
np.round(stcl_quads / ds_ratio).astype(np.int32), np.round(stcl_quads / ds_ratio).astype(np.int32),
1.0) 1.0)
# generate tbo map
# tbo_tcl_poly = poly2tcl(poly, 0.5)
# tbo_tcl_quads = poly2quads(tbo_tcl_poly)
# for idx, quad in enumerate(tbo_tcl_quads):
for idx, quad in enumerate(stcl_quads): for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32) quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly( quad_mask = cv2.fillPoly(
...@@ -432,7 +407,6 @@ class PGProcessTrain(object): ...@@ -432,7 +407,6 @@ class PGProcessTrain(object):
score_label_map_text_label_list.append(text_pos_list_) score_label_map_text_label_list.append(text_pos_list_)
label_idx += 1 label_idx += 1
# cv2.fillPoly(score_label_map, np.round(poly_quads[np.newaxis, :, :]).astype(np.int32), label_idx)
cv2.fillPoly(score_label_map, cv2.fillPoly(score_label_map,
np.round(poly_quads).astype(np.int32), label_idx) np.round(poly_quads).astype(np.int32), label_idx)
score_label_map_text_label_list.append(text_label_index_list) score_label_map_text_label_list.append(text_label_index_list)
...@@ -641,8 +615,6 @@ class PGProcessTrain(object): ...@@ -641,8 +615,6 @@ class PGProcessTrain(object):
d = a1 * b2 - a2 * b1 d = a1 * b2 - a2 * b1
if d == 0: if d == 0:
# print("line1", line1)
# print("line2", line2)
print('Cross point does not exist') print('Cross point does not exist')
return np.array([0, 0], dtype=np.float32) return np.array([0, 0], dtype=np.float32)
else: else:
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,9 +18,9 @@ from .imaug import transform, create_operators ...@@ -18,9 +18,9 @@ from .imaug import transform, create_operators
import random import random
class PGDateSet(Dataset): class PGDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None): def __init__(self, config, mode, logger, seed=None):
super(PGDateSet, self).__init__() super(PGDataSet, self).__init__()
self.logger = logger self.logger = logger
self.seed = seed self.seed = seed
...@@ -81,7 +81,9 @@ class PGDateSet(Dataset): ...@@ -81,7 +81,9 @@ class PGDateSet(Dataset):
""" """
info_list = im_fn.split('\t') info_list = im_fn.split('\t')
img_path = '' img_path = ''
for ext in ['.jpg', '.png', '.jpeg', '.JPG']: for ext in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
]:
if os.path.exists(os.path.join(img_dir, info_list[0] + ext)): if os.path.exists(os.path.join(img_dir, info_list[0] + ext)):
img_path = os.path.join(img_dir, info_list[0] + ext) img_path = os.path.join(img_dir, info_list[0] + ext)
break break
...@@ -111,11 +113,12 @@ class PGDateSet(Dataset): ...@@ -111,11 +113,12 @@ class PGDateSet(Dataset):
for idx, data_source in enumerate(file_list): for idx, data_source in enumerate(file_list):
image_files = [] image_files = []
if data_format == 'icdar': if data_format == 'icdar':
image_files = [ image_files = [(data_source, x) for x in
(data_source, x) os.listdir(os.path.join(data_source, 'rgb'))
for x in os.listdir(os.path.join(data_source, 'rgb')) if x.split('.')[-1] in [
if x.split('.')[-1] in ['jpg', 'png', 'jpeg', 'JPG'] 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
] 'tiff', 'gif', 'JPG'
]]
elif data_format == 'textnet': elif data_format == 'textnet':
with open(data_source) as f: with open(data_source) as f:
image_files = [(data_source, x.strip()) image_files = [(data_source, x.strip())
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,20 +21,12 @@ import paddle ...@@ -21,20 +21,12 @@ import paddle
import numpy as np import numpy as np
import copy import copy
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss from .det_basic_loss import DiceLoss
class PGLoss(nn.Layer): class PGLoss(nn.Layer):
""" def __init__(self, eps=1e-6, **kwargs):
Differentiable Binarization (DB) Loss Function
args:
param (dict): the super paramter for DB Loss
"""
def __init__(self, alpha=5, beta=10, eps=1e-6, **kwargs):
super(PGLoss, self).__init__() super(PGLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.dice_loss = DiceLoss(eps=eps) self.dice_loss = DiceLoss(eps=eps)
def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists): def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists):
...@@ -86,27 +78,30 @@ class PGLoss(nn.Layer): ...@@ -86,27 +78,30 @@ class PGLoss(nn.Layer):
return pos_lists_, pos_masks_, label_lists_ return pos_lists_, pos_masks_, label_lists_
def pre_process(self, label_list, pos_list, pos_mask): def pre_process(self, label_list, pos_list, pos_mask):
max_len = 30 # the max texts in a single image
max_str_len = 50 # the max len in a single text
pad_num = 36 # padding num
label_list = label_list.numpy() label_list = label_list.numpy()
b, h, w, c = label_list.shape batch, _, _, _ = label_list.shape
pos_list = pos_list.numpy() pos_list = pos_list.numpy()
pos_mask = pos_mask.numpy() pos_mask = pos_mask.numpy()
pos_list_t = [] pos_list_t = []
pos_mask_t = [] pos_mask_t = []
label_list_t = [] label_list_t = []
for i in range(b): for i in range(batch):
for j in range(30): for j in range(max_len):
if pos_mask[i, j].any(): if pos_mask[i, j].any():
pos_list_t.append(pos_list[i][j]) pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j]) pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j]) label_list_t.append(label_list[i][j])
pos_list, pos_mask, label_list = self.org_tcl_rois( pos_list, pos_mask, label_list = self.org_tcl_rois(
b, pos_list_t, pos_mask_t, label_list_t) batch, pos_list_t, pos_mask_t, label_list_t)
label = [] label = []
tt = [l.tolist() for l in label_list] tt = [l.tolist() for l in label_list]
for i in range(64): for i in range(batch):
k = 0 k = 0
for j in range(50): for j in range(max_str_len):
if tt[i][j][0] != 36: if tt[i][j][0] != pad_num:
k += 1 k += 1
else: else:
break break
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -22,12 +22,9 @@ from ppocr.utils.e2e_metric.Deteval import * ...@@ -22,12 +22,9 @@ from ppocr.utils.e2e_metric.Deteval import *
class E2EMetric(object): class E2EMetric(object):
def __init__(self, main_indicator='f_score_e2e', **kwargs): def __init__(self, Lexicon_Table, main_indicator='f_score_e2e', **kwargs):
self.label_list = [ self.label_list = Lexicon_Table
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', self.max_index = len(self.label_list)
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.reset() self.reset()
...@@ -40,12 +37,12 @@ class E2EMetric(object): ...@@ -40,12 +37,12 @@ class E2EMetric(object):
for temp_list in temp_gt_strs_batch: for temp_list in temp_gt_strs_batch:
t = "" t = ""
for index in temp_list: for index in temp_list:
if index < 36: if index < self.max_index:
t += self.label_list[index] t += self.label_list[index]
gt_strs_batch.append(t) gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip( for pred, gt_polyons, gt_strs, ignore_tags in zip(
preds, gt_polyons_batch, gt_strs_batch, ignore_tags_batch): [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
# prepare gt # prepare gt
gt_info_list = [{ gt_info_list = [{
'points': gt_polyon, 'points': gt_polyon,
...@@ -57,7 +54,7 @@ class E2EMetric(object): ...@@ -57,7 +54,7 @@ class E2EMetric(object):
e2e_info_list = [{ e2e_info_list = [{
'points': det_polyon, 'points': det_polyon,
'text': pred_str 'text': pred_str
} for det_polyon, pred_str in zip(pred['points'], preds['strs'])] } for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
result = get_socre(gt_info_list, e2e_info_list) result = get_socre(gt_info_list, e2e_info_list)
self.results.append(result) self.results.append(result)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -62,8 +62,6 @@ class ConvBNLayer(nn.Layer): ...@@ -62,8 +62,6 @@ class ConvBNLayer(nn.Layer):
moving_variance_name=bn_name + '_variance') moving_variance_name=bn_name + '_variance')
def forward(self, inputs): def forward(self, inputs):
# if self.is_vd_mode:
# inputs = self._pool2d_avg(inputs)
y = self._conv(inputs) y = self._conv(inputs)
y = self._batch_norm(y) y = self._batch_norm(y)
return y return y
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -179,7 +179,7 @@ class PGHead(nn.Layer): ...@@ -179,7 +179,7 @@ class PGHead(nn.Layer):
name="conv_f_char{}".format(5)) name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2D( self.conv3 = nn.Conv2D(
in_channels=256, in_channels=256,
out_channels=6625, out_channels=37,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -60,8 +60,6 @@ class ConvBNLayer(nn.Layer): ...@@ -60,8 +60,6 @@ class ConvBNLayer(nn.Layer):
use_global_stats=False) use_global_stats=False)
def forward(self, inputs): def forward(self, inputs):
# if self.is_vd_mode:
# inputs = self._pool2d_avg(inputs)
y = self._conv(inputs) y = self._conv(inputs)
y = self._batch_norm(y) y = self._batch_norm(y)
return y return y
...@@ -112,7 +110,6 @@ class PGFPN(nn.Layer): ...@@ -112,7 +110,6 @@ class PGFPN(nn.Layer):
num_inputs = [2048, 2048, 1024, 512, 256] num_inputs = [2048, 2048, 1024, 512, 256]
num_outputs = [256, 256, 192, 192, 128] num_outputs = [256, 256, 192, 192, 128]
self.out_channels = 128 self.out_channels = 128
# print(in_channels)
self.conv_bn_layer_1 = ConvBNLayer( self.conv_bn_layer_1 = ConvBNLayer(
in_channels=3, in_channels=3,
out_channels=32, out_channels=32,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -23,14 +23,9 @@ __dir__ = os.path.dirname(__file__) ...@@ -23,14 +23,9 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..')) sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
from ppocr.utils.e2e_utils.extract_textpoint import * from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.ski_thin import *
from ppocr.utils.e2e_utils.visual import * from ppocr.utils.e2e_utils.visual import *
import paddle import paddle
import cv2
import time
class PGPostProcess(object): class PGPostProcess(object):
...@@ -115,7 +110,6 @@ class PGPostProcess(object): ...@@ -115,7 +110,6 @@ class PGPostProcess(object):
if len(yx_center_line) == 1: if len(yx_center_line) == 1:
yx_center_line.append(yx_center_line[-1]) yx_center_line.append(yx_center_line[-1])
# expand corresponding offset for total-text.
offset_expand = 1.0 offset_expand = 1.0
if self.valid_set == 'totaltext': if self.valid_set == 'totaltext':
offset_expand = 1.2 offset_expand = 1.2
...@@ -137,7 +131,6 @@ class PGPostProcess(object): ...@@ -137,7 +131,6 @@ class PGPostProcess(object):
[ratio_w, ratio_h]).reshape(-1, 2) [ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair) point_pair_list.append(point_pair)
# for visualization
all_point_list.append([ all_point_list.append([
int(round(x * 4.0 / ratio_w)), int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h)) int(round(y * 4.0 / ratio_h))
...@@ -145,7 +138,6 @@ class PGPostProcess(object): ...@@ -145,7 +138,6 @@ class PGPostProcess(object):
all_point_pair_list.append(point_pair.round().astype(np.int32) all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist()) .tolist())
# ndarry: (x, 2)
detected_poly, pair_length_info = point_pair2poly(point_pair_list) detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width( detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2) detected_poly, shrink_ratio_of_width=0.2)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,42 +11,27 @@ ...@@ -11,42 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
try: # python2
range = xrange
except Exception:
# python3
range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
"""
# if len(sys.argv) != 4:
# print('\n usage: test.py pred_dir gt_dir savefile')
# sys.exit()
def get_socre(gt_dict, pred_dict): def get_socre(gt_dict, pred_dict):
# allInputs = listdir(input_dir)
allInputs = 1 allInputs = 1
def input_reading_mod(pred_dict, input): def input_reading_mod(pred_dict):
"""This helper reads input from txt files""" """This helper reads input from txt files"""
det = [] det = []
n = len(pred_dict) n = len(pred_dict)
for i in range(n): for i in range(n):
points = pred_dict[i]['points'] points = pred_dict[i]['points']
text = pred_dict[i]['text'] text = pred_dict[i]['text']
# for i in range(len(points)):
point = ",".join(map(str, points.reshape(-1, ))) point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text]) det.append([point, text])
return det return det
def gt_reading_mod(gt_dict, gt_id): def gt_reading_mod(gt_dict):
"""This helper reads groundtruths from mat files""" """This helper reads groundtruths from mat files"""
# gt_id = gt_id.split('.')[0]
gt = [] gt = []
n = len(gt_dict) n = len(gt_dict)
for i in range(n): for i in range(n):
...@@ -74,23 +59,12 @@ def get_socre(gt_dict, pred_dict): ...@@ -74,23 +59,12 @@ def get_socre(gt_dict, pred_dict):
def detection_filtering(detections, groundtruths, threshold=0.5): def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths): for gt_id, gt in enumerate(groundtruths):
print
"liushanshan gt[1] = {}".format(gt[1])
print
"liushanshan gt[2] = {}".format(gt[2])
print
"liushanshan gt[3] = {}".format(gt[3])
print
"liushanshan gt[4] = {}".format(gt[4])
print
"liushanshan gt[5] = {}".format(gt[5])
if (gt[5] == '#') and (gt[1].shape[1] > 1): if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1]))) gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3]))) gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections): for det_id, detection in enumerate(detections):
detection_orig = detection detection_orig = detection
detection = [float(x) for x in detection[0].split(',')] detection = [float(x) for x in detection[0].split(',')]
# detection = detection.split(',')
detection = list(map(int, detection)) detection = list(map(int, detection))
det_x = detection[0::2] det_x = detection[0::2]
det_y = detection[1::2] det_y = detection[1::2]
...@@ -105,18 +79,10 @@ def get_socre(gt_dict, pred_dict): ...@@ -105,18 +79,10 @@ def get_socre(gt_dict, pred_dict):
""" """
sigma = inter_area / gt_area sigma = inter_area / gt_area
""" """
# print(area_of_intersection(det_x, det_y, gt_x, gt_y))
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2) area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y): def tau_calculation(det_x, det_y, gt_x, gt_y):
"""
tau = inter_area / det_area
"""
# print "liushanshan det_x {}".format(det_x)
# print "liushanshan det_y {}".format(det_y)
# print "liushanshan area {}".format(area(det_x, det_y))
# print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
if area(det_x, det_y) == 0.0: if area(det_x, det_y) == 0.0:
return 0 return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
...@@ -141,10 +107,8 @@ def get_socre(gt_dict, pred_dict): ...@@ -141,10 +107,8 @@ def get_socre(gt_dict, pred_dict):
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'): and (input_id != 'Deteval_result_non_curved.txt'):
print(input_id) detections = input_reading_mod(pred_dict)
detections = input_reading_mod(pred_dict, input_id) groundtruths = gt_reading_mod(gt_dict)
# print "liushanshan detections = {}".format(detections)
groundtruths = gt_reading_mod(gt_dict, input_id)
detections = detection_filtering( detections = detection_filtering(
detections, detections,
groundtruths) # filters detections overlapping with DC area groundtruths) # filters detections overlapping with DC area
...@@ -187,10 +151,6 @@ def get_socre(gt_dict, pred_dict): ...@@ -187,10 +151,6 @@ def get_socre(gt_dict, pred_dict):
global_tau.append(local_tau_table) global_tau.append(local_tau_table)
global_pred_str.append(local_pred_str) global_pred_str.append(local_pred_str)
global_gt_str.append(local_gt_str) global_gt_str.append(local_gt_str)
print
"liushanshan global_pred_str = {}".format(global_pred_str)
print
"liushanshan global_gt_str = {}".format(global_gt_str)
global_accumulative_recall = 0 global_accumulative_recall = 0
global_accumulative_precision = 0 global_accumulative_precision = 0
...@@ -236,17 +196,11 @@ def get_socre(gt_dict, pred_dict): ...@@ -236,17 +196,11 @@ def get_socre(gt_dict, pred_dict):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start # recg start
print
"liushanshan one to one det_id = {}".format(matched_det_id)
print
"liushanshan one to one gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]] 0]]
print
"liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
else: else:
...@@ -290,20 +244,10 @@ def get_socre(gt_dict, pred_dict): ...@@ -290,20 +244,10 @@ def get_socre(gt_dict, pred_dict):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1 det_flag[0, qualified_tau_candidates] = 1
# recg start # recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][ pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]] qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
else: else:
...@@ -315,19 +259,11 @@ def get_socre(gt_dict, pred_dict): ...@@ -315,19 +259,11 @@ def get_socre(gt_dict, pred_dict):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1 det_flag[0, qualified_tau_candidates] = 1
# recg start # recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][ pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]] qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
else: else:
...@@ -377,25 +313,14 @@ def get_socre(gt_dict, pred_dict): ...@@ -377,25 +313,14 @@ def get_socre(gt_dict, pred_dict):
gt_flag[0, qualified_sigma_candidates] = 1 gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1 det_flag[0, det_id] = 1
# recg start # recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id] pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0]) gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len): for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[ ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx] idx]
if not global_gt_str[idy].has_key(ele_gt_id): if ele_gt_id not in global_gt_str[idy]:
continue continue
gt_str_cur = global_gt_str[idy][ele_gt_id] gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
break break
...@@ -409,24 +334,14 @@ def get_socre(gt_dict, pred_dict): ...@@ -409,24 +334,14 @@ def get_socre(gt_dict, pred_dict):
det_flag[0, det_id] = 1 det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1 gt_flag[0, qualified_sigma_candidates] = 1
# recg start # recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id] pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0]) gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len): for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if not global_gt_str[idy].has_key(ele_gt_id): if ele_gt_id not in global_gt_str[idy]:
continue continue
gt_str_cur = global_gt_str[idy][ele_gt_id] gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
break break
...@@ -434,9 +349,6 @@ def get_socre(gt_dict, pred_dict): ...@@ -434,9 +349,6 @@ def get_socre(gt_dict, pred_dict):
if pred_str_cur.lower() == gt_str_cur.lower(): if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1 hit_str_num += 1
break break
else:
print
'no match'
# recg end # recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
...@@ -448,7 +360,6 @@ def get_socre(gt_dict, pred_dict): ...@@ -448,7 +360,6 @@ def get_socre(gt_dict, pred_dict):
single_data = {} single_data = {}
for idx in range(len(global_sigma)): for idx in range(len(global_sigma)):
# print(allInputs[idx])
local_sigma_table = global_sigma[idx] local_sigma_table = global_sigma[idx]
local_tau_table = global_tau[idx] local_tau_table = global_tau[idx]
...@@ -504,8 +415,6 @@ def get_socre(gt_dict, pred_dict): ...@@ -504,8 +415,6 @@ def get_socre(gt_dict, pred_dict):
except ZeroDivisionError: except ZeroDivisionError:
local_f_score = 0 local_f_score = 0
# temp = ('%s: Recall=%.4f, Precision=%.4f, f_score=%.4f\n' % (
# allInputs[idx], local_recall, local_precision, local_f_score))
single_data['sigma'] = global_sigma single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str single_data['global_pred_str'] = global_pred_str
...@@ -575,17 +484,9 @@ def combine_results(all_data): ...@@ -575,17 +484,9 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start # recg start
print
"liushanshan one to one det_id = {}".format(matched_det_id)
print
"liushanshan one to one gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]] 0]]
print
"liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
else: else:
...@@ -629,20 +530,9 @@ def combine_results(all_data): ...@@ -629,20 +530,9 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1 det_flag[0, qualified_tau_candidates] = 1
# recg start # recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][ pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]] qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
else: else:
...@@ -654,19 +544,9 @@ def combine_results(all_data): ...@@ -654,19 +544,9 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1 det_flag[0, qualified_tau_candidates] = 1
# recg start # recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][ pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]] qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
else: else:
...@@ -716,11 +596,6 @@ def combine_results(all_data): ...@@ -716,11 +596,6 @@ def combine_results(all_data):
gt_flag[0, qualified_sigma_candidates] = 1 gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1 det_flag[0, det_id] = 1
# recg start # recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id] pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0]) gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len): for idx in range(gt_len):
...@@ -729,12 +604,6 @@ def combine_results(all_data): ...@@ -729,12 +604,6 @@ def combine_results(all_data):
if ele_gt_id not in global_gt_str[idy]: if ele_gt_id not in global_gt_str[idy]:
continue continue
gt_str_cur = global_gt_str[idy][ele_gt_id] gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
break break
...@@ -748,24 +617,13 @@ def combine_results(all_data): ...@@ -748,24 +617,13 @@ def combine_results(all_data):
det_flag[0, det_id] = 1 det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1 gt_flag[0, qualified_sigma_candidates] = 1
# recg start # recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id] pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0]) gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len): for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if not global_gt_str[idy].has_key(ele_gt_id): if ele_gt_id not in global_gt_str[idy]:
continue continue
gt_str_cur = global_gt_str[idy][ele_gt_id] gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
break break
...@@ -773,9 +631,6 @@ def combine_results(all_data): ...@@ -773,9 +631,6 @@ def combine_results(all_data):
if pred_str_cur.lower() == gt_str_cur.lower(): if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1 hit_str_num += 1
break break
else:
print
'no match'
# recg end # recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -16,14 +16,12 @@ from __future__ import absolute_import ...@@ -16,14 +16,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import cv2 import cv2
import time
import math import math
import numpy as np import numpy as np
from itertools import groupby from itertools import groupby
from ppocr.utils.e2e_utils.ski_thin import thin from skimage.morphology._skeletonize import thin
def softmax(logits): def softmax(logits):
...@@ -518,28 +516,6 @@ def generate_pivot_list_tt_inference(p_score, ...@@ -518,28 +516,6 @@ def generate_pivot_list_tt_inference(p_score,
continue continue
pos_list_sorted = sort_and_expand_with_direction_v2( pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map) pos_list, f_direction, p_tcl_map)
# pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id) pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id) all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs return all_pos_yxs
if __name__ == '__main__':
np.random.seed(0)
import time
logits_map = np.random.random([10, 20, 33])
# a list of [x, y]
instance_gather_info_1 = [(2, 3), (2, 4), (3, 5)]
instance_gather_info_2 = [(15, 6), (15, 7), (18, 8)]
instance_gather_info_3 = [(8, 8), (8, 8), (8, 8)]
gather_info_list = [
instance_gather_info_1, instance_gather_info_2, instance_gather_info_3
]
time0 = time.time()
res = ctc_decoder_for_image(
gather_info_list, logits_map, keep_blank_in_idxs=True)
print(res)
print('cost {}'.format(time.time() - time0))
print('--' * 20)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy import ndimage as ndi
G123_LUT = np.array(
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0
],
dtype=np.bool)
G123P_LUT = np.array(
[
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
dtype=np.bool)
def thin(image, max_iter=None):
"""
Perform morphological thinning of a binary image.
Parameters
----------
image : binary (M, N) ndarray
The image to be thinned.
max_iter : int, number of iterations, optional
Regardless of the value of this parameter, the thinned image
is returned immediately if an iteration produces no change.
If this parameter is specified it thus sets an upper bound on
the number of iterations performed.
Returns
-------
out : ndarray of bool
Thinned image.
See also
--------
skeletonize, medial_axis
Notes
-----
This algorithm [1]_ works by making multiple passes over the image,
removing pixels matching a set of criteria designed to thin
connected regions while preserving eight-connected components and
2 x 2 squares [2]_. In each of the two sub-iterations the algorithm
correlates the intermediate skeleton image with a neighborhood mask,
then looks up each neighborhood in a lookup table indicating whether
the central pixel should be deleted in that sub-iteration.
References
----------
.. [1] Z. Guo and R. W. Hall, "Parallel thinning with
two-subiteration algorithms," Comm. ACM, vol. 32, no. 3,
pp. 359-373, 1989. :DOI:`10.1145/62065.62074`
.. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning
Methodologies-A Comprehensive Survey," IEEE Transactions on
Pattern Analysis and Machine Intelligence, Vol 14, No. 9,
p. 879, 1992. :DOI:`10.1109/34.161346`
Examples
--------
>>> square = np.zeros((7, 7), dtype=np.uint8)
>>> square[1:-1, 2:-2] = 1
>>> square[0, 1] = 1
>>> square
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
>>> skel = thin(square)
>>> skel.astype(np.uint8)
array([[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
"""
# convert image to uint8 with values in {0, 1}
skel = np.asanyarray(image, dtype=bool).astype(np.uint8)
# neighborhood mask
mask = np.array([[8, 4, 2], [16, 0, 1], [32, 64, 128]], dtype=np.uint8)
# iterate until convergence, up to the iteration limit
max_iter = max_iter or np.inf
n_iter = 0
n_pts_old, n_pts_new = np.inf, np.sum(skel)
while n_pts_old != n_pts_new and n_iter < max_iter:
n_pts_old = n_pts_new
# perform the two "subiterations" described in the paper
for lut in [G123_LUT, G123P_LUT]:
# correlate image with neighborhood mask
N = ndi.correlate(skel, mask, mode='constant')
# take deletion decision from this subiteration's LUT
D = np.take(lut, N)
# perform deletion
skel[D] = 0
n_pts_new = np.sum(skel) # count points after thinning
n_iter += 1
return skel.astype(np.bool)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -50,7 +50,7 @@ def resize_image(im, max_side_len=512): ...@@ -50,7 +50,7 @@ def resize_image(im, max_side_len=512):
def resize_image_min(im, max_side_len=512): def resize_image_min(im, max_side_len=512):
""" """
""" """
print('--> Using resize_image_min') # print('--> Using resize_image_min')
h, w, _ = im.shape h, w, _ = im.shape
resize_w = w resize_w = w
......
...@@ -45,8 +45,14 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name): ...@@ -45,8 +45,14 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name):
for box, str in zip(dt_boxes, strs): for box, str in zip(dt_boxes, strs):
box = box.astype(np.int32).reshape((-1, 1, 2)) box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(src_im, str, org=(int(box[0, 0, 0]), int(box[0, 0, 1])), cv2.putText(
fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.7, color=(0, 255, 0), thickness=1) src_im,
str,
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
save_det_path = os.path.dirname(config['Global'][ save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/e2e_results/" 'save_res_path']) + "/e2e_results/"
if not os.path.exists(save_det_path): if not os.path.exists(save_det_path):
...@@ -55,6 +61,7 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name): ...@@ -55,6 +61,7 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name):
cv2.imwrite(save_path, src_im) cv2.imwrite(save_path, src_im)
logger.info("The e2e Image saved in {}".format(save_path)) logger.info("The e2e Image saved in {}".format(save_path))
def main(): def main():
global_config = config['Global'] global_config = config['Global']
...@@ -111,4 +118,4 @@ def main(): ...@@ -111,4 +118,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
main() main()
\ No newline at end of file
...@@ -375,7 +375,7 @@ def preprocess(is_train=False): ...@@ -375,7 +375,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PG' 'CLS', 'PGNet'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册