Unverified commit c455034f authored by MissPenguin, committed by GitHub

Merge pull request #2497 from JetHong/rel/fix_data_input_format

Rel/fix data input format
@@ -62,20 +62,21 @@ PostProcess:
   mode: fast # fast or slow two ways
 Metric:
   name: E2EMetric
-  gt_mat_dir:  # the dir of gt_mat
+  gt_mat_dir: ./train_data/total_text/gt  # the dir of gt_mat
   character_dict_path: ppocr/utils/ic15_dict.txt
   main_indicator: f_score_e2e

 Train:
   dataset:
     name: PGDataSet
-    label_file_list: [.././train_data/total_text/train/]
+    data_dir: ./train_data/total_text/train
+    label_file_list: [./train_data/total_text/train/]
     ratio_list: [1.0]
-    data_format: icdar # two data format: icdar/textnet
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
           channel_first: False
+      - E2ELabelEncode:
       - PGProcessTrain:
           batch_size: 14  # same as loader: batch_size_per_card
           min_crop_size: 24
@@ -92,13 +93,12 @@ Train:
 Eval:
   dataset:
     name: PGDataSet
-    data_dir: ./train_data/
+    data_dir: ./train_data/total_text/test
     label_file_list: [./train_data/total_text/test/]
     transforms:
       - DecodeImage: # load image
           img_mode: RGB
           channel_first: False
+      - E2ELabelEncode:
       - E2EResizeForTest:
           max_side_len: 768
       - NormalizeImage:
@@ -108,7 +108,7 @@ Eval:
           order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
+          keep_keys: [ 'image', 'shape', 'img_id']
   loader:
     shuffle: False
     drop_last: False
......
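With `data_dir`, `label_file_list` and the new `E2ELabelEncode` transform replacing the old `data_format` switch, PGDataSet now reads the standard PaddleOCR annotation format: one image path and one JSON list of boxes per line, separated by the configured delimiter (tab by default). A minimal sketch of writing and parsing such a line, with made-up file names and coordinates:

```python
import json

# Hypothetical label-file line: <image path relative to data_dir> \t <JSON list of boxes>.
line = "rgb/img11.jpg\t" + json.dumps([
    {"transcription": "HELLO", "points": [[10, 20], [90, 22], [88, 50], [8, 48]]},
    {"transcription": "###", "points": [[100, 20], [180, 22], [178, 50], [98, 48]]},
])

# Parsing mirrors PGDataSet.__getitem__ plus the E2ELabelEncode rewrite further down in this diff.
file_name, label = line.strip("\n").split("\t")
for box in json.loads(label):
    ignored = box["transcription"] in ["*", "###"]  # '*' and '###' mark unreadable text
    print(file_name, box["points"], box["transcription"], "ignored:", ignored)
```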
@@ -30,6 +30,7 @@ Details of the PGNet algorithm are given in the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
 Test set: Total-Text
 Test environment: NVIDIA Tesla V100-SXM2-16GB
 |PGNetA|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|Download|
 | --- | --- | --- | --- | --- | --- | --- | --- | --- |
 |Paper|85.30|86.80|86.1|-|-|61.7|38.20 (size=640)|-|
......
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
         return dict_character


-class E2ELabelEncode(BaseRecLabelEncode):
-    def __init__(self,
-                 max_text_length,
-                 character_dict_path=None,
-                 character_type='EN',
-                 use_space_char=False,
-                 **kwargs):
-        super(E2ELabelEncode,
-              self).__init__(max_text_length, character_dict_path,
-                             character_type, use_space_char)
-        self.pad_num = len(self.dict)  # the length to pad
+class E2ELabelEncode(object):
+    def __init__(self, **kwargs):
+        pass

     def __call__(self, data):
-        texts = data['strs']
-        temp_texts = []
-        for text in texts:
-            text = text.lower()
-            text = self.encode(text)
-            if text is None:
-                return None
-            text = text + [self.pad_num] * (self.max_text_len - len(text))
-            temp_texts.append(text)
-        data['strs'] = np.array(temp_texts)
+        import json
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=np.bool)
+        data['polys'] = boxes
+        data['texts'] = txts
+        data['ignore_tags'] = txt_tags
         return data
......
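The rewritten `E2ELabelEncode` no longer encodes characters against a dictionary; it only unpacks the JSON label into the keys the rest of the pipeline expects (`polys`, `texts`, `ignore_tags`). A small usage sketch, assuming the module path `ppocr/data/imaug/label_ops.py` from this repo and a NumPy version older than 1.24 (where `np.bool` still exists):

```python
import json
import numpy as np
from ppocr.data.imaug.label_ops import E2ELabelEncode  # module path assumed from this repo

encode = E2ELabelEncode()
sample = {
    "label": json.dumps([
        {"transcription": "PGNET", "points": [[0, 0], [40, 0], [40, 10], [0, 10]]},
        {"transcription": "###", "points": [[50, 0], [90, 0], [90, 10], [50, 10]]},
    ])
}
out = encode(sample)
print(out["polys"].shape)   # (2, 4, 2), float32 polygons
print(out["texts"])         # ['PGNET', '###']
print(out["ignore_tags"])   # [False  True]
```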
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
         return min_area_quad

-    def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
+    def check_and_validate_polys(self, polys, tags, im_size):
         """
         check so that the text poly is in the same direction,
         and also filter some invalid polygons
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
         :param tags:
         :return:
         """
-        (h, w) = xxx_todo_changeme
+        (h, w) = im_size
         if polys.shape[0] == 0:
             return polys, np.array([]), np.array([])
         polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
         input_size = 512
         im = data['image']
         text_polys = data['polys']
-        text_tags = data['tags']
-        text_strs = data['strs']
+        text_tags = data['ignore_tags']
+        text_strs = data['texts']
         h, w, _ = im.shape
         text_polys, text_tags, hv_tags = self.check_and_validate_polys(
             text_polys, text_tags, (h, w))
......
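The behavioural content of this file is just the rename of the unpacked tuple argument (`xxx_todo_changeme`, apparently left over from an automated 2to3-style conversion) to `im_size`, plus switching to the `ignore_tags`/`texts` keys produced by `E2ELabelEncode`. For reference, the coordinate clipping `check_and_validate_polys` applies with `im_size` looks roughly like this sketch (values made up; the y-clip line is assumed by analogy with the x-clip shown in the hunk):

```python
import numpy as np

# One polygon of four (x, y) points; sample values are made up.
polys = np.array([[[-5.0, 10.0], [120.0, 10.0], [120.0, 40.0], [-5.0, 40.0]]], dtype=np.float32)
h, w = 64, 100                                       # im_size = (h, w)
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)   # clip x into the image (shown in the hunk)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)   # clip y, assumed by analogy
print(polys)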
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
         dataset_config = config[mode]['dataset']
         loader_config = config[mode]['loader']

+        self.delimiter = dataset_config.get('delimiter', '\t')
         label_file_list = dataset_config.pop('label_file_list')
         data_source_num = len(label_file_list)
         ratio_list = dataset_config.get("ratio_list", [1.0])
         if isinstance(ratio_list, (float, int)):
             ratio_list = [float(ratio_list)] * int(data_source_num)
-        self.data_format = dataset_config.get('data_format', 'icdar')
         assert len(
             ratio_list
         ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+        self.data_dir = dataset_config['data_dir']
         self.do_shuffle = loader_config['shuffle']

         logger.info("Initialize indexs of datasets:%s" % label_file_list)
-        self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
-                                                   self.data_format)
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
         self.data_idx_order_list = list(range(len(self.data_lines)))
         if mode.lower() == "train":
             self.shuffle_data_random()
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
         random.shuffle(self.data_lines)
         return

-    def extract_polys(self, poly_txt_path):
-        """
-        Read text_polys, txt_tags, txts from give txt file.
-        """
-        text_polys, txt_tags, txts = [], [], []
-        with open(poly_txt_path) as f:
-            for line in f.readlines():
-                poly_str, txt = line.strip().split('\t')
-                poly = list(map(float, poly_str.split(',')))
-                text_polys.append(
-                    np.array(
-                        poly, dtype=np.float32).reshape(-1, 2))
-                txts.append(txt)
-                txt_tags.append(txt == '###')
-
-        return np.array(list(map(np.array, text_polys))), \
-            np.array(txt_tags, dtype=np.bool), txts
-
-    def extract_info_textnet(self, im_fn, img_dir=''):
-        """
-        Extract information from line in textnet format.
-        """
-        info_list = im_fn.split('\t')
-        img_path = ''
-        for ext in [
-                'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
-        ]:
-            if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
-                img_path = os.path.join(img_dir, info_list[0] + "." + ext)
-                break
-        if img_path == '':
-            print('Image {0} NOT found in {1}, and it will be ignored.'.format(
-                info_list[0], img_dir))
-
-        nBox = (len(info_list) - 1) // 9
-        wordBBs, txts, txt_tags = [], [], []
-        for n in range(0, nBox):
-            wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
-            txt = info_list[(n + 1) * 9]
-            wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
-                            [wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
-            txts.append(txt)
-            if txt == '###':
-                txt_tags.append(True)
-            else:
-                txt_tags.append(False)
-        return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
-
-    def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
+    def get_image_info_list(self, file_list, ratio_list):
         if isinstance(file_list, str):
             file_list = [file_list]
         data_lines = []
-        for idx, data_source in enumerate(file_list):
-            image_files = []
-            if data_format == 'icdar':
-                image_files = [(data_source, x) for x in
-                               os.listdir(os.path.join(data_source, 'rgb'))
-                               if x.split('.')[-1] in [
-                                   'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
-                                   'tiff', 'gif', 'JPG'
-                               ]]
-            elif data_format == 'textnet':
-                with open(data_source) as f:
-                    image_files = [(data_source, x.strip())
-                                   for x in f.readlines()]
-            else:
-                print("Unrecognized data format...")
-                exit(-1)
-            random.seed(self.seed)
-            image_files = random.sample(
-                image_files, round(len(image_files) * ratio_list[idx]))
-            data_lines.extend(image_files)
+        for idx, file in enumerate(file_list):
+            with open(file, "rb") as f:
+                lines = f.readlines()
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
+                data_lines.extend(lines)
         return data_lines

     def __getitem__(self, idx):
         file_idx = self.data_idx_order_list[idx]
-        data_path, data_line = self.data_lines[file_idx]
+        data_line = self.data_lines[file_idx]
         try:
-            if self.data_format == 'icdar':
-                im_path = os.path.join(data_path, 'rgb', data_line)
-                poly_path = os.path.join(data_path, 'poly',
-                                         data_line.split('.')[0] + '.txt')
-                text_polys, text_tags, text_strs = self.extract_polys(poly_path)
-            else:
-                image_dir = os.path.join(os.path.dirname(data_path), 'image')
-                im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
-                    data_line, image_dir)
-            img_id = int(data_line.split(".")[0][3:])
-            data = {
-                'img_path': im_path,
-                'polys': text_polys,
-                'tags': text_tags,
-                'strs': text_strs,
-                'img_id': img_id
-            }
+            data_line = data_line.decode('utf-8')
+            substr = data_line.strip("\n").split(self.delimiter)
+            file_name = substr[0]
+            label = substr[1]
+            img_path = os.path.join(self.data_dir, file_name)
+            if self.mode.lower() == 'eval':
+                img_id = int(data_line.split(".")[0][7:])
+            else:
+                img_id = 0
+            data = {'img_path': img_path, 'label': label, 'img_id': img_id}
+            if not os.path.exists(img_path):
+                raise Exception("{} does not exist!".format(img_path))
             with open(data['img_path'], 'rb') as f:
                 img = f.read()
                 data['image'] = img
             outs = transform(data, self.ops)
         except Exception as e:
             self.logger.error(
                 "When parsing line {}, error happened with msg: {}".format(
......
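`get_image_info_list` now simply reads raw label lines and, for training or a `ratio_list` entry below 1.0, sub-samples them with a fixed seed. A standalone sketch of that sampling, with synthetic byte lines standing in for `f.readlines()`:

```python
import random

# Synthetic lines playing the role of f.readlines(); the seed stands in for self.seed.
lines = [b"img_%d.jpg\t[]\n" % i for i in range(10)]
seed, ratio = 2021, 0.5

random.seed(seed)
sampled = random.sample(lines, round(len(lines) * ratio))
print(len(sampled))  # 5 of the 10 lines are kept
```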
@@ -35,11 +35,11 @@ class E2EMetric(object):
         self.reset()

     def __call__(self, preds, batch, **kwargs):
-        img_id = batch[5][0]
+        img_id = batch[2][0]
         e2e_info_list = [{
             'points': det_polyon,
-            'text': pred_str
-        } for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
+            'texts': pred_str
+        } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]

         result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
         self.results.append(result)
......
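The index change follows directly from the Eval `keep_keys` above: the dataloader yields batch elements in `keep_keys` order, so with `['image', 'shape', 'img_id']` the image id now sits at position 2 instead of 5. A placeholder illustration (shapes and values are made up):

```python
import numpy as np

# Placeholder tensors only; shapes are illustrative.
images = np.zeros((1, 3, 768, 768), dtype=np.float32)
shape_list = np.array([[720.0, 1280.0, 768 / 720, 768 / 1280]], dtype=np.float32)
img_ids = np.array([7], dtype=np.int64)

batch = [images, shape_list, img_ids]  # same order as keep_keys: ['image', 'shape', 'img_id']
img_id = batch[2][0]                   # what E2EMetric now reads
print(img_id)                          # 7
```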
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
         n = len(pred_dict)
         for i in range(n):
             points = pred_dict[i]['points']
-            text = pred_dict[i]['text']
+            text = pred_dict[i]['texts']
             point = ",".join(map(str, points.reshape(-1, )))
             det.append([point, text])
         return det
......
@@ -342,6 +342,7 @@ generate_pivot_list_curved(p_score,
     center_pos_yxs = []
     end_points_yxs = []
     instance_center_pos_yxs = []
+    pred_strs = []
     if instance_count > 0:
         for instance_id in range(1, instance_count):
             pos_list = []
@@ -367,12 +368,13 @@ generate_pivot_list_curved(p_score,

         if is_backbone:
             keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
             instance_center_pos_yxs.append(keep_yxs_list_with_id)
+            pred_strs.append(decoded_str)
         else:
             end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
             center_pos_yxs.extend(keep_yxs_list)

     if is_backbone:
-        return instance_center_pos_yxs
+        return pred_strs, instance_center_pos_yxs
     else:
         return center_pos_yxs, end_points_yxs
......
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
             src_w, src_h, self.valid_set)
         data = {
             'points': poly_list,
-            'strs': keep_str_list,
+            'texts': keep_str_list,
         }
         return data
@@ -85,32 +85,13 @@ class PGNet_PostProcess(object):
         p_char = p_char[0]
         src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
         is_curved = self.valid_set == "totaltext"
-        instance_yxs_list = generate_pivot_list_slow(
+        char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
             p_score,
             p_char,
             p_direction,
             score_thresh=self.score_thresh,
             is_backbone=True,
             is_curved=is_curved)
-        p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
-        char_seq_idx_set = []
-        for i in range(len(instance_yxs_list)):
-            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
-            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
-            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
-            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
-            feature_len = [len(feature_seq[0])]
-            featyre_seq = paddle.to_tensor(feature_seq)
-            feature_len = np.array([feature_len]).astype(np.int64)
-            length = paddle.to_tensor(feature_len)
-            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
-                input=featyre_seq, blank=36, input_length=length)
-            seq_pred_str = seq_pred[0].numpy().tolist()[0]
-            seq_len = seq_pred[1].numpy()[0][0]
-            temp_t = []
-            for c in seq_pred_str[:seq_len]:
-                temp_t.append(c)
-            char_seq_idx_set.append(temp_t)
         seq_strs = []
         for char_idx_set in char_seq_idx_set:
             pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
@@ -176,6 +157,6 @@ class PGNet_PostProcess(object):
             exit(-1)
         data = {
             'points': poly_list,
-            'strs': keep_str_list,
+            'texts': keep_str_list,
         }
         return data
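With the character decoding moved inside `generate_pivot_list_slow` (the old `paddle.fluid.layers.ctc_greedy_decoder` loop is dropped), the post-processor only has to map the returned index sequences through `Lexicon_Table`. A sketch of that final step, using a stand-in table ordered like `ppocr/utils/ic15_dict.txt` (digits then lowercase letters, consistent with `blank=36` in the removed code); the indices are hypothetical:

```python
# Stand-in lexicon: digits 0-9 then a-z, 36 classes in total.
lexicon_table = list("0123456789abcdefghijklmnopqrstuvwxyz")

# Hypothetical output of generate_pivot_list_slow for a single text instance.
char_seq_idx_set = [[25, 16, 23, 14, 29]]
seq_strs = ["".join(lexicon_table[pos] for pos in seq) for seq in char_seq_idx_set]
print(seq_strs)  # ['pgnet']
```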
@@ -122,7 +122,7 @@ class TextE2E(object):
         else:
             raise NotImplementedError
         post_result = self.postprocess_op(preds, shape_list)
-        points, strs = post_result['points'], post_result['strs']
+        points, strs = post_result['points'], post_result['texts']
         dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
         elapse = time.time() - starttime
         return dt_boxes, strs, elapse
......
@@ -103,7 +103,7 @@ def main():
         images = paddle.to_tensor(images)
         preds = model(images)
         post_result = post_process_class(preds, shape_list)
-        points, strs = post_result['points'], post_result['strs']
+        points, strs = post_result['points'], post_result['texts']
         # write resule
         dt_boxes_json = []
         for poly, str in zip(points, strs):
......