From 80c188785c852aee2b875134887de8e363c26d69 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 3 Jun 2020 17:09:14 +0800 Subject: [PATCH] fix eval --- ppocr/data/rec/dataset_traversal.py | 17 ++++++++++++----- ppocr/data/rec/img_tools.py | 7 ++++--- ppocr/modeling/architectures/rec_model.py | 6 +++--- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index 8d6b4d1e..21fdc352 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -41,6 +41,7 @@ class LMDBReader(object): self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] self.mode = params['mode'] + self.drop_last = False if "tps" in params: self.tps = True if params['mode'] == 'train': @@ -112,7 +113,8 @@ class LMDBReader(object): img=img, image_shape=self.image_shape, char_ops=self.char_ops, - tps=self.tps) + tps=self.tps, + infer_mode=True) yield norm_img else: lmdb_sets = self.load_hierarchical_lmdb_dataset() @@ -132,9 +134,13 @@ class LMDBReader(object): if sample_info is None: continue img, label = sample_info - outs = process_image(img, self.image_shape, label, - self.char_ops, self.loss_type, - self.max_text_length) + outs = process_image( + img=img, + image_shape=self.image_shape, + label=label, + char_ops=self.char_ops, + loss_type=self.loss_type, + max_text_length=self.max_text_length) if outs is None: continue yield outs @@ -154,7 +160,7 @@ class LMDBReader(object): if len(batch_outs) != 0: yield batch_outs - if self.mode != 'train' and self.infer_img is None: + if self.infer_img is None: return batch_iter_reader return sample_iter_reader @@ -174,6 +180,7 @@ class SimpleReader(object): self.max_text_length = params['max_text_length'] self.mode = params['mode'] self.infer_img = params['infer_img'] + self.drop_last = False if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] self.drop_last = params['drop_last'] diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index 71efea17..6d7b66e9 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -93,11 +93,12 @@ def process_image(img, char_ops=None, loss_type=None, max_text_length=None, - tps=None): - if char_ops.character_type == "en": + tps=None, + infer_mode=False): + if not infer_mode or char_ops.character_type == "en": norm_img = resize_norm_img(img, image_shape) else: - if tps: + if tps != None and char_ops.character_type == "ch": image_shape = [3, 32, 320] norm_img = resize_norm_img(img, image_shape) else: diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index d729f5d9..af651e9a 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -31,6 +31,7 @@ class RecModel(object): char_num = global_params['char_ops'].get_char_num() global_params['char_num'] = char_num self.char_type = global_params['character_type'] + self.infer_img = global_params['infer_img'] if "TPS" in params: tps_params = deepcopy(params["TPS"]) tps_params.update(global_params) @@ -87,7 +88,7 @@ class RecModel(object): use_double_buffer=True, iterable=False) else: - if self.char_type == "ch": + if self.char_type == "ch" and self.infer_img: image_shape[-1] = -1 if self.tps != None: logger.info( @@ -96,8 +97,7 @@ class RecModel(object): "We set default shape=[3,32,320], it may affect the inference effect" ) image_shape[-1] = 320 - image = fluid.data( - name='image', shape=image_shape, dtype='float32') + image = fluid.data(name='image', shape=image_shape, dtype='float32') labels = None loader = None return image, labels, loader -- GitLab