Commit 50bcec46 authored by: J Jethong

fix data input format

Parent 579b66b0
@@ -69,6 +69,7 @@ Metric:
 Train:
   dataset:
     name: PGDataSet
+    data_dir: ./train_data/
     label_file_list: [.././train_data/total_text/train/]
    ratio_list: [1.0]
    data_format: icdar #two data format: icdar/textnet
@@ -76,6 +77,7 @@ Train:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
+      - E2ELabelEncode:
      - PGProcessTrain:
          batch_size: 14 # same as loader: batch_size_per_card
          min_crop_size: 24
@@ -98,7 +100,6 @@ Eval:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
-      - E2ELabelEncode:
      - E2EResizeForTest:
          max_side_len: 768
      - NormalizeImage:
@@ -108,7 +109,7 @@ Eval:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
-          keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
+          keep_keys: [ 'image', 'shape', 'img_id']
  loader:
    shuffle: False
    drop_last: False
......
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
         return dict_character


-class E2ELabelEncode(BaseRecLabelEncode):
-    def __init__(self,
-                 max_text_length,
-                 character_dict_path=None,
-                 character_type='EN',
-                 use_space_char=False,
-                 **kwargs):
-        super(E2ELabelEncode,
-              self).__init__(max_text_length, character_dict_path,
-                             character_type, use_space_char)
-        self.pad_num = len(self.dict)  # the length to pad
+class E2ELabelEncode(object):
+    def __init__(self, **kwargs):
+        pass

     def __call__(self, data):
-        texts = data['strs']
-        temp_texts = []
-        for text in texts:
-            text = text.lower()
-            text = self.encode(text)
-            if text is None:
-                return None
-            text = text + [self.pad_num] * (self.max_text_len - len(text))
-            temp_texts.append(text)
-        data['strs'] = np.array(temp_texts)
+        import json
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=np.bool)
+        data['polys'] = boxes
+        data['texts'] = txts
+        data['ignore_tags'] = txt_tags
         return data
......
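The rewritten E2ELabelEncode above no longer inherits from BaseRecLabelEncode: instead of encoding per-character text labels, it parses the JSON annotation string carried in data['label'] and emits 'polys', 'texts', and 'ignore_tags' for the downstream PGProcessTrain step. A minimal sketch of the label format this implies, with made-up points and transcriptions for illustration:

import json
import numpy as np

# Hypothetical annotation for one image: a list of regions, each with
# polygon "points" and a "transcription" ('###' or '*' marks ignored text).
label = json.dumps([
    {"points": [[10, 20], [110, 20], [110, 60], [10, 60]], "transcription": "PGNet"},
    {"points": [[150, 30], [250, 30], [250, 70], [150, 70]], "transcription": "###"},
])

parsed = json.loads(label)
boxes = np.array([obj["points"] for obj in parsed], dtype=np.float32)
ignore = np.array([obj["transcription"] in ["*", "###"] for obj in parsed], dtype=bool)
print(boxes.shape)   # (2, 4, 2)
print(ignore)        # [False  True]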
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
         return min_area_quad

-    def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
+    def check_and_validate_polys(self, polys, tags, im_size):
         """
         check so that the text poly is in the same direction,
         and also filter some invalid polygons
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
         :param tags:
         :return:
         """
-        (h, w) = xxx_todo_changeme
+        (h, w) = im_size
         if polys.shape[0] == 0:
             return polys, np.array([]), np.array([])
         polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
         input_size = 512
         im = data['image']
         text_polys = data['polys']
-        text_tags = data['tags']
-        text_strs = data['strs']
+        text_tags = data['ignore_tags']
+        text_strs = data['texts']
         h, w, _ = im.shape
         text_polys, text_tags, hv_tags = self.check_and_validate_polys(
             text_polys, text_tags, (h, w))
......
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
         dataset_config = config[mode]['dataset']
         loader_config = config[mode]['loader']

+        self.delimiter = dataset_config.get('delimiter', '\t')
         label_file_list = dataset_config.pop('label_file_list')
         data_source_num = len(label_file_list)
         ratio_list = dataset_config.get("ratio_list", [1.0])
         if isinstance(ratio_list, (float, int)):
             ratio_list = [float(ratio_list)] * int(data_source_num)
-        self.data_format = dataset_config.get('data_format', 'icdar')
         assert len(
             ratio_list
         ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+        self.data_dir = dataset_config['data_dir']
         self.do_shuffle = loader_config['shuffle']

         logger.info("Initialize indexs of datasets:%s" % label_file_list)
-        self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
-                                                   self.data_format)
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
         self.data_idx_order_list = list(range(len(self.data_lines)))
         if mode.lower() == "train":
             self.shuffle_data_random()
@@ -55,108 +55,37 @@ class PGDataSet(Dataset):
         random.shuffle(self.data_lines)
         return

-    def extract_polys(self, poly_txt_path):
-        """
-        Read text_polys, txt_tags, txts from give txt file.
-        """
-        text_polys, txt_tags, txts = [], [], []
-        with open(poly_txt_path) as f:
-            for line in f.readlines():
-                poly_str, txt = line.strip().split('\t')
-                poly = list(map(float, poly_str.split(',')))
-                text_polys.append(
-                    np.array(
-                        poly, dtype=np.float32).reshape(-1, 2))
-                txts.append(txt)
-                txt_tags.append(txt == '###')
-        return np.array(list(map(np.array, text_polys))), \
-            np.array(txt_tags, dtype=np.bool), txts
-
-    def extract_info_textnet(self, im_fn, img_dir=''):
-        """
-        Extract information from line in textnet format.
-        """
-        info_list = im_fn.split('\t')
-        img_path = ''
-        for ext in [
-                'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
-        ]:
-            if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
-                img_path = os.path.join(img_dir, info_list[0] + "." + ext)
-                break
-        if img_path == '':
-            print('Image {0} NOT found in {1}, and it will be ignored.'.format(
-                info_list[0], img_dir))
-        nBox = (len(info_list) - 1) // 9
-        wordBBs, txts, txt_tags = [], [], []
-        for n in range(0, nBox):
-            wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
-            txt = info_list[(n + 1) * 9]
-            wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
-                            [wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
-            txts.append(txt)
-            if txt == '###':
-                txt_tags.append(True)
-            else:
-                txt_tags.append(False)
-        return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
-
-    def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
+    def get_image_info_list(self, file_list, ratio_list):
         if isinstance(file_list, str):
             file_list = [file_list]
         data_lines = []
-        for idx, data_source in enumerate(file_list):
-            image_files = []
-            if data_format == 'icdar':
-                image_files = [(data_source, x) for x in
-                               os.listdir(os.path.join(data_source, 'rgb'))
-                               if x.split('.')[-1] in [
-                                   'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
-                                   'tiff', 'gif', 'JPG'
-                               ]]
-            elif data_format == 'textnet':
-                with open(data_source) as f:
-                    image_files = [(data_source, x.strip())
-                                   for x in f.readlines()]
-            else:
-                print("Unrecognized data format...")
-                exit(-1)
-            random.seed(self.seed)
-            image_files = random.sample(
-                image_files, round(len(image_files) * ratio_list[idx]))
-            data_lines.extend(image_files)
+        for idx, file in enumerate(file_list):
+            with open(file, "rb") as f:
+                lines = f.readlines()
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
+                data_lines.extend(lines)
         return data_lines

     def __getitem__(self, idx):
         file_idx = self.data_idx_order_list[idx]
-        data_path, data_line = self.data_lines[file_idx]
+        data_line = self.data_lines[file_idx]
         try:
-            if self.data_format == 'icdar':
-                im_path = os.path.join(data_path, 'rgb', data_line)
-                poly_path = os.path.join(data_path, 'poly',
-                                         data_line.split('.')[0] + '.txt')
-                text_polys, text_tags, text_strs = self.extract_polys(poly_path)
-            else:
-                image_dir = os.path.join(os.path.dirname(data_path), 'image')
-                im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
-                    data_line, image_dir)
+            data_line = data_line.decode('utf-8')
+            substr = data_line.strip("\n").split(self.delimiter)
+            file_name = substr[0]
+            label = substr[1]
+            img_path = os.path.join(self.data_dir, file_name)
             img_id = int(data_line.split(".")[0][3:])
-
-            data = {
-                'img_path': im_path,
-                'polys': text_polys,
-                'tags': text_tags,
-                'strs': text_strs,
-                'img_id': img_id
-            }
+            data = {'img_path': img_path, 'label': label, 'img_id': img_id}
+            if not os.path.exists(img_path):
+                raise Exception("{} does not exist!".format(img_path))
             with open(data['img_path'], 'rb') as f:
                 img = f.read()
                 data['image'] = img
             outs = transform(data, self.ops)
         except Exception as e:
             self.logger.error(
                 "When parsing line {}, error happened with msg: {}".format(
......
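With these dataset changes, PGDataSet no longer scans icdar/textnet directory layouts; each line of a label file is "<image file name><delimiter><json annotation>", the image path is resolved against data_dir, and img_id is read from the file name after its first three characters (matching Total-Text names such as img11.jpg). A small sketch of parsing one such line, mirroring __getitem__ above; the file name, delimiter, and data_dir value are illustrative:

import os

data_dir = "./train_data/"   # as in the config hunk above
delimiter = "\t"             # default taken by dataset_config.get('delimiter', '\t')

# One hypothetical, already-decoded line from a label file:
data_line = 'img11.jpg\t[{"points": [[10, 20], [110, 20], [110, 60], [10, 60]], "transcription": "PGNet"}]\n'

substr = data_line.strip("\n").split(delimiter)
file_name, label = substr[0], substr[1]
img_path = os.path.join(data_dir, file_name)
img_id = int(data_line.split(".")[0][3:])   # "img11.jpg" -> 11
print(img_path, img_id)                     # ./train_data/img11.jpg 11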
@@ -35,11 +35,11 @@ class E2EMetric(object):
         self.reset()

     def __call__(self, preds, batch, **kwargs):
-        img_id = batch[5][0]
+        img_id = batch[2][0]
         e2e_info_list = [{
             'points': det_polyon,
-            'text': pred_str
-        } for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
+            'texts': pred_str
+        } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]

         result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
         self.results.append(result)
......
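The new batch index in E2EMetric follows from the Eval keep_keys change above: with keep_keys: ['image', 'shape', 'img_id'] the collated eval batch carries img_id at position 2 instead of 5. A tiny sanity check under that assumption:

# Eval batch order mirrors KeepKeys in the config; img_id is now the third element.
keep_keys = ['image', 'shape', 'img_id']
assert keep_keys.index('img_id') == 2   # hence img_id = batch[2][0] in E2EMetric.__call__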
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
     n = len(pred_dict)
     for i in range(n):
         points = pred_dict[i]['points']
-        text = pred_dict[i]['text']
+        text = pred_dict[i]['texts']
         point = ",".join(map(str, points.reshape(-1, )))
         det.append([point, text])
     return det
......
@@ -21,6 +21,7 @@ import math

 import numpy as np
 from itertools import groupby
+from cv2.ximgproc import thinning as thin
 from skimage.morphology._skeletonize import thin
......
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
             src_w, src_h, self.valid_set)
         data = {
             'points': poly_list,
-            'strs': keep_str_list,
+            'texts': keep_str_list,
         }
         return data
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
         exit(-1)
         data = {
             'points': poly_list,
-            'strs': keep_str_list,
+            'texts': keep_str_list,
         }
         return data