Commit 67d33cdc authored by J Jethong

fix data input format

Parent 010e94d6
......@@ -69,6 +69,7 @@ Metric:
Train:
dataset:
name: PGDataSet
data_dir: ./train_data/
label_file_list: [.././train_data/total_text/train/]
ratio_list: [1.0]
data_format: icdar #two data format: icdar/textnet
......@@ -76,6 +77,7 @@ Train:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- E2ELabelEncode:
- PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24
......@@ -98,7 +100,6 @@ Eval:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- E2ELabelEncode:
- E2EResizeForTest:
max_side_len: 768
- NormalizeImage:
......@@ -108,7 +109,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
keep_keys: [ 'image', 'shape', 'img_id']
loader:
shuffle: False
drop_last: False
......
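For context, with the new data_dir key and the E2ELabelEncode transform added to the Train pipeline, each entry of label_file_list is now expected to be a plain-text annotation file rather than a dataset directory: one image per line, with the image name and a JSON list of boxes separated by the dataset delimiter (tab by default). A minimal sketch of one such line, where the file name, transcription, and points are illustrative placeholders:

# one label line: <image name>\t<JSON annotation>
line = 'img11.jpg\t[{"transcription": "HELLO", "points": [[10, 20], [110, 20], [110, 60], [10, 60]]}]'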
......@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character
class E2ELabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
self.pad_num = len(self.dict) # the length to pad
class E2ELabelEncode(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
texts = data['strs']
temp_texts = []
for text in texts:
text = text.lower()
text = self.encode(text)
if text is None:
return None
text = text + [self.pad_num] * (self.max_text_len - len(text))
temp_texts.append(text)
data['strs'] = np.array(temp_texts)
import json
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['texts'] = txts
data['ignore_tags'] = txt_tags
return data
......
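A minimal usage sketch of the rewritten encoder, assuming numpy is imported as np in the module; the sample label below is illustrative:

# Illustrative only: pass a dict holding a JSON 'label' string, read back the new keys.
sample = {'label': '[{"transcription": "###", "points": [[0, 0], [50, 0], [50, 20], [0, 20]]}]'}
out = E2ELabelEncode()(sample)
# out['polys']       -> float32 array of shape (1, 4, 2)
# out['texts']       -> ['###']
# out['ignore_tags'] -> array([True]); '###' transcriptions are marked as ignored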
......@@ -88,7 +88,7 @@ class PGProcessTrain(object):
return min_area_quad
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
def check_and_validate_polys(self, polys, tags, im_size):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
......@@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags:
:return:
"""
(h, w) = xxx_todo_changeme
(h, w) = im_size
if polys.shape[0] == 0:
return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
......@@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size = 512
im = data['image']
text_polys = data['polys']
text_tags = data['tags']
text_strs = data['strs']
text_tags = data['ignore_tags']
text_strs = data['texts']
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w))
......
......@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
self.data_format = dataset_config.get('data_format', 'icdar')
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
self.data_format)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
......@@ -55,108 +55,37 @@ class PGDataSet(Dataset):
random.shuffle(self.data_lines)
return
def extract_polys(self, poly_txt_path):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys, txt_tags, txts = [], [], []
with open(poly_txt_path) as f:
for line in f.readlines():
poly_str, txt = line.strip().split('\t')
poly = list(map(float, poly_str.split(',')))
text_polys.append(
np.array(
poly, dtype=np.float32).reshape(-1, 2))
txts.append(txt)
txt_tags.append(txt == '###')
return np.array(list(map(np.array, text_polys))), \
np.array(txt_tags, dtype=np.bool), txts
def extract_info_textnet(self, im_fn, img_dir=''):
"""
Extract information from line in textnet format.
"""
info_list = im_fn.split('\t')
img_path = ''
for ext in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
]:
if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
img_path = os.path.join(img_dir, info_list[0] + "." + ext)
break
if img_path == '':
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
info_list[0], img_dir))
nBox = (len(info_list) - 1) // 9
wordBBs, txts, txt_tags = [], [], []
for n in range(0, nBox):
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
txt = info_list[(n + 1) * 9]
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, data_source in enumerate(file_list):
image_files = []
if data_format == 'icdar':
image_files = [(data_source, x) for x in
os.listdir(os.path.join(data_source, 'rgb'))
if x.split('.')[-1] in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
'tiff', 'gif', 'JPG'
]]
elif data_format == 'textnet':
with open(data_source) as f:
image_files = [(data_source, x.strip())
for x in f.readlines()]
else:
print("Unrecognized data format...")
exit(-1)
random.seed(self.seed)
image_files = random.sample(
image_files, round(len(image_files) * ratio_list[idx]))
data_lines.extend(image_files)
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_path, data_line = self.data_lines[file_idx]
data_line = self.data_lines[file_idx]
try:
if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line)
poly_path = os.path.join(data_path, 'poly',
data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
else:
image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir)
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
img_id = int(data_line.split(".")[0][3:])
data = {
'img_path': im_path,
'polys': text_polys,
'tags': text_tags,
'strs': text_strs,
'img_id': img_id
}
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
......
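Note that the new img_id parsing assumes Total-Text-style image names such as img12.jpg; a small sketch of what the expression yields, with a placeholder line:

data_line = 'img12.jpg\t[{"transcription": "HELLO", "points": [[10, 20], [110, 20], [110, 60], [10, 60]]}]'
img_id = int(data_line.split(".")[0][3:])  # stem 'img12' minus the 3-character 'img' prefix -> 12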
......@@ -35,11 +35,11 @@ class E2EMetric(object):
self.reset()
def __call__(self, preds, batch, **kwargs):
img_id = batch[5][0]
img_id = batch[2][0]
e2e_info_list = [{
'points': det_polyon,
'text': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
'texts': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result)
......
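The new batch index follows from the Eval keep_keys change above: with keep_keys set to ['image', 'shape', 'img_id'], the image id now sits at position 2 of the evaluation batch rather than position 5.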
......@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['text']
text = pred_dict[i]['texts']
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
......
......@@ -21,6 +21,7 @@ import math
import numpy as np
from itertools import groupby
from cv2.ximgproc import thinning as thin
from skimage.morphology._skeletonize import thin
......
......@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'strs': keep_str_list,
'texts': keep_str_list,
}
return data
......@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
'texts': keep_str_list,
}
return data