提交 be3a1644 编写于 作者: T tink2123

fix inference in tps

上级 b722eb56
...@@ -12,8 +12,10 @@ Global: ...@@ -12,8 +12,10 @@ Global:
test_batch_size_per_card: 256 test_batch_size_per_card: 256
image_shape: [3, 32, 100] image_shape: [3, 32, 100]
max_text_length: 25 max_text_length: 25
character_type: en character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: attention loss_type: attention
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
......
...@@ -14,6 +14,7 @@ Global: ...@@ -14,6 +14,7 @@ Global:
max_text_length: 25 max_text_length: 25
character_type: en character_type: en
loss_type: ctc loss_type: ctc
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
......
...@@ -41,6 +41,8 @@ class LMDBReader(object): ...@@ -41,6 +41,8 @@ class LMDBReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
if "tps" in params:
self.tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
self.drop_last = params['drop_last'] self.drop_last = params['drop_last']
...@@ -109,7 +111,8 @@ class LMDBReader(object): ...@@ -109,7 +111,8 @@ class LMDBReader(object):
norm_img = process_image( norm_img = process_image(
img=img, img=img,
image_shape=self.image_shape, image_shape=self.image_shape,
char_ops=self.char_ops) char_ops=self.char_ops,
tps=self.tps)
yield norm_img yield norm_img
else: else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()
......
...@@ -92,9 +92,14 @@ def process_image(img, ...@@ -92,9 +92,14 @@ def process_image(img,
label=None, label=None,
char_ops=None, char_ops=None,
loss_type=None, loss_type=None,
max_text_length=None): max_text_length=None,
tps=None):
if char_ops.character_type == "en": if char_ops.character_type == "en":
norm_img = resize_norm_img(img, image_shape) norm_img = resize_norm_img(img, image_shape)
else:
if tps:
image_shape = [3, 32, 320]
norm_img = resize_norm_img(img, image_shape)
else: else:
norm_img = resize_norm_img_chinese(img, image_shape) norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
......
...@@ -30,6 +30,7 @@ class RecModel(object): ...@@ -30,6 +30,7 @@ class RecModel(object):
global_params = params['Global'] global_params = params['Global']
char_num = global_params['char_ops'].get_char_num() char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num global_params['char_num'] = char_num
self.char_type = global_params['character_type']
if "TPS" in params: if "TPS" in params:
tps_params = deepcopy(params["TPS"]) tps_params = deepcopy(params["TPS"])
tps_params.update(global_params) tps_params.update(global_params)
...@@ -60,8 +61,8 @@ class RecModel(object): ...@@ -60,8 +61,8 @@ class RecModel(object):
def create_feed(self, mode): def create_feed(self, mode):
image_shape = deepcopy(self.image_shape) image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1) image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train": if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "attention": if self.loss_type == "attention":
label_in = fluid.data( label_in = fluid.data(
name='label_in', name='label_in',
...@@ -86,6 +87,17 @@ class RecModel(object): ...@@ -86,6 +87,17 @@ class RecModel(object):
use_double_buffer=True, use_double_buffer=True,
iterable=False) iterable=False)
else: else:
if self.char_type == "ch":
image_shape[-1] = -1
if self.tps != None:
logger.info(
"WARNRNG!!!\n"
"TPS does not support variable shape in chinese!"
"We set default shape=[3,32,320], it may affect the inference effect"
)
image_shape[-1] = 320
image = fluid.data(
name='image', shape=image_shape, dtype='float32')
labels = None labels = None
loader = None loader = None
return image, labels, loader return image, labels, loader
......
...@@ -112,7 +112,7 @@ class TextRecognizer(object): ...@@ -112,7 +112,7 @@ class TextRecognizer(object):
else: else:
preds = rec_idx_batch[rno, 1:end_pos[1]] preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]]) score = np.mean(predict_batch[rno, 1:end_pos[1]])
#todo: why index has 2 offset #attenton index has 2 offset: beg and end
preds = preds - 2 preds = preds - 2
preds_text = self.char_ops.decode(preds) preds_text = self.char_ops.decode(preds)
rec_res.append([preds_text, score]) rec_res.append([preds_text, score])
...@@ -138,7 +138,7 @@ if __name__ == "__main__": ...@@ -138,7 +138,7 @@ if __name__ == "__main__":
except: except:
logger.info( logger.info(
"ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n" "ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n"
"Please set --rec_image_shape=input_shape and --rec_char_type='ch' ") "Please set --rec_image_shape=input_shape and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册