未验证 提交 ac4cef10 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #2437 from JetHong/pgnet-readme

fix eval score
...@@ -61,6 +61,7 @@ PostProcess: ...@@ -61,6 +61,7 @@ PostProcess:
score_thresh: 0.5 score_thresh: 0.5
Metric: Metric:
name: E2EMetric name: E2EMetric
gt_mat_dir: # the dir of gt_mat
character_dict_path: ppocr/utils/ic15_dict.txt character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e main_indicator: f_score_e2e
...@@ -106,7 +107,7 @@ Eval: ...@@ -106,7 +107,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ] keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
- [一、简介](#简介) - [一、简介](#简介)
- [二、环境配置](#环境配置) - [二、环境配置](#环境配置)
- [三、快速使用](#快速使用) - [三、快速使用](#快速使用)
- [四、模型训练、评估、推理](#快速训练) - [四、模型训练、评估、推理](#模型训练、评估、推理)
<a name="简介"></a> <a name="简介"></a>
## 一、简介 ## 一、简介
...@@ -20,7 +20,9 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang ...@@ -20,7 +20,9 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
![](../pgnet_framework.png) ![](../pgnet_framework.png)
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。 输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。 其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
其检测识别效果图如下: 其检测识别效果图如下:
![](../imgs_results/e2e_res_img293_pgnet.png) ![](../imgs_results/e2e_res_img293_pgnet.png)
![](../imgs_results/e2e_res_img295_pgnet.png) ![](../imgs_results/e2e_res_img295_pgnet.png)
...@@ -61,12 +63,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im ...@@ -61,12 +63,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: 可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg) ![](../imgs_results/e2e_res_img623_pgnet.jpg)
<a name="快速训练"></a> <a name="模型训练、评估、推理"></a>
## 四、模型训练、评估、推理 ## 四、模型训练、评估、推理
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。 本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
### 准备数据 ### 准备数据
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md)数据集到PaddleOCR/train_data/目录,数据集组织结构: 下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
``` ```
/PaddleOCR/train_data/total_text/train/ /PaddleOCR/train_data/total_text/train/
|- rgb/ # total_text数据集的训练数据 |- rgb/ # total_text数据集的训练数据
......
...@@ -64,9 +64,6 @@ class PGDataSet(Dataset): ...@@ -64,9 +64,6 @@ class PGDataSet(Dataset):
for line in f.readlines(): for line in f.readlines():
poly_str, txt = line.strip().split('\t') poly_str, txt = line.strip().split('\t')
poly = list(map(float, poly_str.split(','))) poly = list(map(float, poly_str.split(',')))
if self.mode.lower() == "eval":
while len(poly) < 100:
poly.append(-1)
text_polys.append( text_polys.append(
np.array( np.array(
poly, dtype=np.float32).reshape(-1, 2)) poly, dtype=np.float32).reshape(-1, 2))
...@@ -139,10 +136,6 @@ class PGDataSet(Dataset): ...@@ -139,10 +136,6 @@ class PGDataSet(Dataset):
try: try:
if self.data_format == 'icdar': if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line) im_path = os.path.join(data_path, 'rgb', data_line)
if self.mode.lower() == "eval":
poly_path = os.path.join(data_path, 'poly_gt',
data_line.split('.')[0] + '.txt')
else:
poly_path = os.path.join(data_path, 'poly', poly_path = os.path.join(data_path, 'poly',
data_line.split('.')[0] + '.txt') data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path) text_polys, text_tags, text_strs = self.extract_polys(poly_path)
...@@ -150,12 +143,14 @@ class PGDataSet(Dataset): ...@@ -150,12 +143,14 @@ class PGDataSet(Dataset):
image_dir = os.path.join(os.path.dirname(data_path), 'image') image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet( im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir) data_line, image_dir)
img_id = int(data_line.split(".")[0][3:])
data = { data = {
'img_path': im_path, 'img_path': im_path,
'polys': text_polys, 'polys': text_polys,
'tags': text_tags, 'tags': text_tags,
'strs': text_strs 'strs': text_strs,
'img_id': img_id
} }
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
......
...@@ -24,52 +24,23 @@ from ppocr.utils.e2e_utils.extract_textpoint import get_dict ...@@ -24,52 +24,23 @@ from ppocr.utils.e2e_utils.extract_textpoint import get_dict
class E2EMetric(object): class E2EMetric(object):
def __init__(self, def __init__(self,
gt_mat_dir,
character_dict_path, character_dict_path,
main_indicator='f_score_e2e', main_indicator='f_score_e2e',
**kwargs): **kwargs):
self.gt_mat_dir = gt_mat_dir
self.label_list = get_dict(character_dict_path) self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list) self.max_index = len(self.label_list)
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.reset() self.reset()
def __call__(self, preds, batch, **kwargs): def __call__(self, preds, batch, **kwargs):
temp_gt_polyons_batch = batch[2] img_id = batch[5][0]
temp_gt_strs_batch = batch[3]
ignore_tags_batch = batch[4]
gt_polyons_batch = []
gt_strs_batch = []
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
for temp_list in temp_gt_polyons_batch:
t = []
for index in temp_list:
if index[0] != -1 and index[1] != -1:
t.append(index)
gt_polyons_batch.append(t)
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
for temp_list in temp_gt_strs_batch:
t = ""
for index in temp_list:
if index < self.max_index:
t += self.label_list[index]
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': gt_str,
'ignore': ignore_tag
} for gt_polyon, gt_str, ignore_tag in
zip(gt_polyons, gt_strs, ignore_tags)]
# prepare det
e2e_info_list = [{ e2e_info_list = [{
'points': det_polyon, 'points': det_polyon,
'text': pred_str 'text': pred_str
} for det_polyon, pred_str in zip(pred['points'], pred['strs'])] } for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
result = get_socre(gt_info_list, e2e_info_list) result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result) self.results.append(result)
def get_metric(self): def get_metric(self):
......
...@@ -138,6 +138,7 @@ class PGPostProcess(object): ...@@ -138,6 +138,7 @@ class PGPostProcess(object):
continue continue
keep_str_list.append(keep_str) keep_str_list.append(keep_str)
detected_poly = np.round(detected_poly).astype('int32')
if self.valid_set == 'partvgg': if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2 middle_point = len(detected_poly) // 2
detected_poly = detected_poly[ detected_poly = detected_poly[
......
...@@ -13,10 +13,11 @@ ...@@ -13,10 +13,11 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import scipy.io as io
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
def get_socre(gt_dict, pred_dict): def get_socre(gt_dir, img_id, pred_dict):
allInputs = 1 allInputs = 1
def input_reading_mod(pred_dict): def input_reading_mod(pred_dict):
...@@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict): ...@@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
det.append([point, text]) det.append([point, text])
return det return det
def gt_reading_mod(gt_dict): def gt_reading_mod(gt_dir, gt_id):
"""This helper reads groundtruths from mat files""" gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
gt = [] gt = gt['polygt']
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points']
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "" and "#" not in text:
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt return gt
def detection_filtering(detections, groundtruths, threshold=0.5): def detection_filtering(detections, groundtruths, threshold=0.5):
...@@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict): ...@@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict):
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'): and (input_id != 'Deteval_result_non_curved.txt'):
detections = input_reading_mod(pred_dict) detections = input_reading_mod(pred_dict)
groundtruths = gt_reading_mod(gt_dict) groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
detections = detection_filtering( detections = detection_filtering(
detections, detections,
groundtruths) # filters detections overlapping with DC area groundtruths) # filters detections overlapping with DC area
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册