提交 b4c5dac2 编写于 作者: T tink2123

commit for tmp

上级 24757ba6
Global: Global:
algorithm: CRNN algorithm: CRNN
use_gpu: false use_gpu: true
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
pretrain_weights: ./output/rec_CRNN/rec_mv3_none_bilstm_ctc/best_accuracy pretrain_weights: ./output/rec_CRNN/rec_mv3_none_bilstm_ctc/best_accuracy
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -105,7 +105,10 @@ class LMDBReader(object): ...@@ -105,7 +105,10 @@ class LMDBReader(object):
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1] == 1 or len(list(img.shape)) == 2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops)
yield norm_img yield norm_img
else: else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()
...@@ -182,7 +185,10 @@ class SimpleReader(object): ...@@ -182,7 +185,10 @@ class SimpleReader(object):
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1] == 1 or len(list(img.shape)) == 2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops)
yield norm_img yield norm_img
else: else:
with open(self.label_file_path, "rb") as fin: with open(self.label_file_path, "rb") as fin:
......
...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape): ...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape):
return padding_im return padding_im
def resize_norm_img_chinese(img, image_shape):
    """Resize, normalize and right-pad a text-line image.

    The image is resized to the target height while keeping its aspect
    ratio, divided by 255 and mapped through ``(x - 0.5) / 0.5``, then
    copied into the left part of a zero-filled canvas. The canvas width is
    ``target_height * max(10, width / height)``; the width entry of
    ``image_shape`` is ignored.

    Args:
        img: Image array indexed as ``(height, width[, channels])``.
        image_shape: ``(channels, height, width)`` target shape. Only the
            channel count and height are used.

    Returns:
        A ``float32`` array of shape ``(channels, height, canvas_width)``.
    """
    imgC, imgH, _ = image_shape
    # todo: change to 0 and modified image shape
    max_wh_ratio = 10
    h, w = img.shape[0], img.shape[1]
    ratio = w * 1.0 / h
    max_wh_ratio = max(max_wh_ratio, ratio)
    # Derive the canvas width from the configured target height instead of a
    # hard-coded 32 so the minimum aspect ratio holds for any imgH. The
    # result is unchanged when imgH == 32.
    imgW = int(imgH * max_wh_ratio)
    # Never let the resized content exceed the canvas width.
    resized_w = min(imgW, int(math.ceil(imgH * ratio)))
    # cv2.resize takes the destination size as (width, height).
    resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype('float32')
    if image_shape[0] == 1:
        # Single-channel target: add the leading channel axis.
        # NOTE(review): assumes the resized image is 2-D here -- confirm
        # against callers.
        resized_image = resized_image / 255
        resized_image = resized_image[np.newaxis, :]
    else:
        # (height, width, channels) -> (channels, height, width)
        resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    return padding_im
def get_img_data(value): def get_img_data(value):
"""get_img_data""" """get_img_data"""
if not value: if not value:
...@@ -67,7 +93,10 @@ def process_image(img, ...@@ -67,7 +93,10 @@ def process_image(img,
char_ops=None, char_ops=None,
loss_type=None, loss_type=None,
max_text_length=None): max_text_length=None):
norm_img = resize_norm_img(img, image_shape) if char_ops.character_type == "en":
norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
if label is not None: if label is not None:
char_num = char_ops.get_char_num() char_num = char_ops.get_char_num()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册