提交 be5fdae5 编写于 作者: J Jethong

fix config format

上级 50bcec46
...@@ -69,10 +69,9 @@ Metric: ...@@ -69,10 +69,9 @@ Metric:
Train: Train:
dataset: dataset:
name: PGDataSet name: PGDataSet
data_dir: ./train_data/ data_dir: ./train_data/train
label_file_list: [.././train_data/total_text/train/] label_file_list: [.././train_data/total_text/train/]
ratio_list: [1.0] ratio_list: [1.0]
data_format: icdar #two data format: icdar/textnet
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
...@@ -94,7 +93,7 @@ Train: ...@@ -94,7 +93,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: PGDataSet name: PGDataSet
data_dir: ./train_data/ data_dir: ./train_data/test
label_file_list: [./train_data/total_text/test/] label_file_list: [./train_data/total_text/test/]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
......
...@@ -78,7 +78,10 @@ class PGDataSet(Dataset): ...@@ -78,7 +78,10 @@ class PGDataSet(Dataset):
file_name = substr[0] file_name = substr[0]
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
img_id = int(data_line.split(".")[0][3:]) if self.mode.lower() == 'eval':
img_id = int(data_line.split(".")[0][7:])
else:
img_id = 0
data = {'img_path': img_path, 'label': label, 'img_id': img_id} data = {'img_path': img_path, 'label': label, 'img_id': img_id}
if not os.path.exists(img_path): if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path)) raise Exception("{} does not exist!".format(img_path))
......
...@@ -122,7 +122,7 @@ class TextE2E(object): ...@@ -122,7 +122,7 @@ class TextE2E(object):
else: else:
raise NotImplementedError raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
points, strs = post_result['points'], post_result['strs'] points, strs = post_result['points'], post_result['texts']
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape) dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
elapse = time.time() - starttime elapse = time.time() - starttime
return dt_boxes, strs, elapse return dt_boxes, strs, elapse
......
...@@ -103,7 +103,7 @@ def main(): ...@@ -103,7 +103,7 @@ def main():
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
preds = model(images) preds = model(images)
post_result = post_process_class(preds, shape_list) post_result = post_process_class(preds, shape_list)
points, strs = post_result['points'], post_result['strs'] points, strs = post_result['points'], post_result['texts']
# write resule # write resule
dt_boxes_json = [] dt_boxes_json = []
for poly, str in zip(points, strs): for poly, str in zip(points, strs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册