From b4c5dac27a45cbcd474b8e1e215ac1656c9e101a Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 3 Jun 2020 00:10:02 +0800 Subject: [PATCH] commit for tmp --- configs/rec/rec_mv3_none_bilstm_ctc.yml | 3 ++- ppocr/data/rec/dataset_traversal.py | 10 ++++++-- ppocr/data/rec/img_tools.py | 31 ++++++++++++++++++++++++- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 951c83cc..cea72bba 100755 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,6 +1,6 @@ Global: algorithm: CRNN - use_gpu: false + use_gpu: true epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -17,6 +17,7 @@ Global: pretrain_weights: ./output/rec_CRNN/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index 9d0d2e96..a8a090f1 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -105,7 +105,10 @@ class LMDBReader(object): img = cv2.imread(single_img) if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - norm_img = process_image(img, self.image_shape) + norm_img = process_image( + img=img, + image_shape=self.image_shape, + char_ops=self.char_ops) yield norm_img else: lmdb_sets = self.load_hierarchical_lmdb_dataset() @@ -182,7 +185,10 @@ class SimpleReader(object): img = cv2.imread(single_img) if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - norm_img = process_image(img, self.image_shape) + norm_img = process_image( + img=img, + image_shape=self.image_shape, + char_ops=self.char_ops) yield norm_img else: with open(self.label_file_path, "rb") as fin: diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index a27f108c..df1d3dd5 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape): return padding_im +def resize_norm_img_chinese(img, image_shape): + imgC, imgH, imgW = image_shape + # todo: change to 0 and modified image shape + max_wh_ratio = 10 + h, w = img.shape[0], img.shape[1] + ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, ratio) + imgW = int(32 * max_wh_ratio) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def get_img_data(value): """get_img_data""" if not value: @@ -67,7 +93,10 @@ def process_image(img, char_ops=None, loss_type=None, max_text_length=None): - norm_img = resize_norm_img(img, image_shape) + if char_ops.character_type == "en": + norm_img = resize_norm_img(img, image_shape) + else: + norm_img = resize_norm_img_chinese(img, image_shape) norm_img = norm_img[np.newaxis, :] if label is not None: char_num = char_ops.get_char_num() -- GitLab