diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index b55ff99be91bfea795d4b023edcc57e95b97ecc2..640f3a205b1fd3ec7fe19d5c6b6e3aef9ddf3968 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -45,9 +45,7 @@ Optimizer: beta1: 0.9 beta2: 0.999 lr: -# name: Cosine learning_rate: 0.001 -# warmup_epoch: 0 regularizer: name: 'L2' factor: 0 diff --git a/configs/det/det_r50_vd_db.yml b/configs/det/det_r50_vd_db.yml new file mode 100644 index 0000000000000000000000000000000000000000..491983f57a59f0a4105712743a69a56c8212a0e9 --- /dev/null +++ b/configs/det/det_r50_vd_db.yml @@ -0,0 +1,130 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/det_rc/det_r50_vd/ + save_epoch_step: 1200 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [5000,4000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: ResNet + layers: 50 + Neck: + name: DBFPN + out_channels: 256 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.001 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [0.5] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [640, 640] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 16 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 8 \ No newline at end of file diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index def72375142ccf9f0988c0821d041b837442f3d0..38f1e8691e6056ada01a2d5c19f70955e8117498 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -5,7 +5,7 @@ Global: print_batch_step: 10 save_model_dir: ./output/rec/mv3_none_bilstm_ctc/ save_epoch_step: 3 - # evaluation is run every 5000 iterations after the 4000th iteration + # evaluation is run every 2000 iterations eval_batch_step: [0, 2000] # if pretrained_model is saved in static mode, load_static_weights must set to True cal_metric_during_train: True @@ -13,7 +13,7 @@ Global: checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: doc/imgs_words_en/word_10.png # for data or label process character_dict_path: character_type: en @@ -21,7 +21,6 @@ Global: infer_mode: False use_space_char: False - Optimizer: name: Adam beta1: 0.9 diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml new file mode 100644 index 0000000000000000000000000000000000000000..33079ad48c94c217ef86ef3f245492a540559350 --- /dev/null +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -0,0 +1,95 @@ +Global: + use_gpu: True + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/mv3_none_none_ctc/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0005 + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: Rosetta + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + fc_decay: 0.0004 + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDateSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + batch_size_per_card: 256 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDateSet + data_dir: ./train_data/data_lmdb_release/validation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 8 diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml new file mode 100644 index 0000000000000000000000000000000000000000..08f68939d4f1e6de1c3688652bd86f6556a43384 --- /dev/null +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -0,0 +1,100 @@ +Global: + use_gpu: true + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/mv3_tps_bilstm_ctc/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0005 + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: STARNet + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 0.1 + model_name: small + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 96 + Head: + name: CTCHead + fc_decay: 0.0004 + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDateSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + batch_size_per_card: 256 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDateSet + data_dir: ./train_data/data_lmdb_release/validation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 4 diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 74937138aac1ae155ce8754d861d21a85b35e031..4ad2ff89ef1e72c58c426670742bc2ada27cfc4a 100644 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -5,7 +5,7 @@ Global: print_batch_step: 10 save_model_dir: ./output/rec/r34_vd_none_bilstm_ctc/ save_epoch_step: 3 - # evaluation is run every 5000 iterations after the 4000th iteration + # evaluation is run every 2000 iterations eval_batch_step: [0, 2000] # if pretrained_model is saved in static mode, load_static_weights must set to True cal_metric_during_train: True @@ -13,7 +13,7 @@ Global: checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: doc/imgs_words_en/word_10.png # for data or label process character_dict_path: character_type: en @@ -21,7 +21,6 @@ Global: infer_mode: False use_space_char: False - Optimizer: name: Adam beta1: 0.9 @@ -71,7 +70,7 @@ Train: - KeepKeys: keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order loader: - shuffle: False + shuffle: True batch_size_per_card: 256 drop_last: True num_workers: 8 diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml new file mode 100644 index 0000000000000000000000000000000000000000..9c1eeb304f41d46e49cee350e5d659dd1e0c8b0e --- /dev/null +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -0,0 +1,93 @@ +Global: + use_gpu: true + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/r34_vd_none_none_ctc/ + save_epoch_step: 3 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0005 + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: Rosetta + Backbone: + name: ResNet + layers: 34 + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + fc_decay: 0.0004 + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDateSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDateSet + data_dir: ./train_data/data_lmdb_release/validation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 32, 100] + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 4 diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index 269f1e411776d591610082cff2900ca6fc621752..aeded4926a6d09cf30210f2d348d2933461a06b1 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -5,7 +5,7 @@ Global: print_batch_step: 10 save_model_dir: ./output/rec/r34_vd_tps_bilstm_ctc/ save_epoch_step: 3 - # evaluation is run every 5000 iterations after the 4000th iteration + # evaluation is run every 2000 iterations eval_batch_step: [0, 2000] # if pretrained_model is saved in static mode, load_static_weights must set to True cal_metric_during_train: True @@ -13,7 +13,7 @@ Global: checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: doc/imgs_words_en/word_10.png # for data or label process character_dict_path: character_type: en @@ -21,7 +21,6 @@ Global: infer_mode: False use_space_char: False - Optimizer: name: Adam beta1: 0.9 @@ -34,7 +33,7 @@ Optimizer: Architecture: model_type: rec - algorithm: CRNN + algorithm: STARNet Transform: name: TPS num_fiducial: 20 diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py deleted file mode 100755 index b4b2021e02c9905623fd9fad5c9673543569c1c2..0000000000000000000000000000000000000000 --- a/ppocr/utils/character.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import string -import re -from .check import check_config_params -import sys - - -class CharacterOps(object): - """ Convert between text-label and text-index """ - - def __init__(self, config): - self.character_type = config['character_type'] - self.loss_type = config['loss_type'] - self.max_text_len = config['max_text_length'] - if self.character_type == "en": - self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" - dict_character = list(self.character_str) - elif self.character_type == "ch": - character_dict_path = config['character_dict_path'] - add_space = False - if 'use_space_char' in config: - add_space = config['use_space_char'] - self.character_str = "" - with open(character_dict_path, "rb") as fin: - lines = fin.readlines() - for line in lines: - line = line.decode('utf-8').strip("\n").strip("\r\n") - self.character_str += line - if add_space: - self.character_str += " " - dict_character = list(self.character_str) - elif self.character_type == "en_sensitive": - # same with ASTER setting (use 94 char). - self.character_str = string.printable[:-6] - dict_character = list(self.character_str) - else: - self.character_str = None - assert self.character_str is not None, \ - "Nonsupport type of the character: {}".format(self.character_str) - self.beg_str = "sos" - self.end_str = "eos" - if self.loss_type == "attention": - dict_character = [self.beg_str, self.end_str] + dict_character - elif self.loss_type == "srn": - dict_character = dict_character + [self.beg_str, self.end_str] - self.dict = {} - for i, char in enumerate(dict_character): - self.dict[char] = i - self.character = dict_character - - def encode(self, text): - """convert text-label into text-index. - input: - text: text labels of each image. [batch_size] - - output: - text: concatenated text index for CTCLoss. - [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] - length: length of each text. [batch_size] - """ - if self.character_type == "en": - text = text.lower() - - text_list = [] - for char in text: - if char not in self.dict: - continue - text_list.append(self.dict[char]) - text = np.array(text_list) - return text - - def decode(self, text_index, is_remove_duplicate=False): - """ convert text-index into text-label. """ - char_list = [] - char_num = self.get_char_num() - - if self.loss_type == "attention": - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - ignored_tokens = [beg_idx, end_idx] - else: - ignored_tokens = [char_num] - - for idx in range(len(text_index)): - if text_index[idx] in ignored_tokens: - continue - if is_remove_duplicate: - if idx > 0 and text_index[idx - 1] == text_index[idx]: - continue - char_list.append(self.character[int(text_index[idx])]) - text = ''.join(char_list) - return text - - def get_char_num(self): - return len(self.character) - - def get_beg_end_flag_idx(self, beg_or_end): - if self.loss_type == "attention": - if beg_or_end == "beg": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx"\ - % beg_or_end - return idx - else: - err = "error in get_beg_end_flag_idx when using the loss %s"\ - % (self.loss_type) - assert False, err - - -def cal_predicts_accuracy(char_ops, - preds, - preds_lod, - labels, - labels_lod, - is_remove_duplicate=False): - acc_num = 0 - img_num = 0 - for ino in range(len(labels_lod) - 1): - beg_no = preds_lod[ino] - end_no = preds_lod[ino + 1] - preds_text = preds[beg_no:end_no].reshape(-1) - preds_text = char_ops.decode(preds_text, is_remove_duplicate) - - beg_no = labels_lod[ino] - end_no = labels_lod[ino + 1] - labels_text = labels[beg_no:end_no].reshape(-1) - labels_text = char_ops.decode(labels_text, is_remove_duplicate) - img_num += 1 - - if preds_text == labels_text: - acc_num += 1 - acc = acc_num * 1.0 / img_num - return acc, acc_num, img_num - - -def cal_predicts_accuracy_srn(char_ops, - preds, - labels, - max_text_len, - is_debug=False): - acc_num = 0 - img_num = 0 - - char_num = char_ops.get_char_num() - - total_len = preds.shape[0] - img_num = int(total_len / max_text_len) - for i in range(img_num): - cur_label = [] - cur_pred = [] - for j in range(max_text_len): - if labels[j + i * max_text_len] != int(char_num-1): #0 - cur_label.append(labels[j + i * max_text_len][0]) - else: - break - - for j in range(max_text_len + 1): - if j < len(cur_label) and preds[j + i * max_text_len][ - 0] != cur_label[j]: - break - elif j == len(cur_label) and j == max_text_len: - acc_num += 1 - break - elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num-1): - acc_num += 1 - break - acc = acc_num * 1.0 / img_num - return acc, acc_num, img_num - - -def convert_rec_attention_infer_res(preds): - img_num = preds.shape[0] - target_lod = [0] - convert_ids = [] - for ino in range(img_num): - end_pos = np.where(preds[ino, :] == 1)[0] - if len(end_pos) <= 1: - text_list = preds[ino, 1:] - else: - text_list = preds[ino, 1:end_pos[1]] - target_lod.append(target_lod[ino] + len(text_list)) - convert_ids = convert_ids + list(text_list) - convert_ids = np.array(convert_ids) - convert_ids = convert_ids.reshape((-1, 1)) - return convert_ids, target_lod - - -def convert_rec_label_to_lod(ori_labels): - img_num = len(ori_labels) - target_lod = [0] - convert_ids = [] - for ino in range(img_num): - target_lod.append(target_lod[ino] + len(ori_labels[ino])) - convert_ids = convert_ids + list(ori_labels[ino]) - convert_ids = np.array(convert_ids) - convert_ids = convert_ids.reshape((-1, 1)) - return convert_ids, target_lod diff --git a/ppocr/utils/check.py b/ppocr/utils/check.py deleted file mode 100755 index 3a0b14061e2058464c4942e6cf4f891a2a77ba69..0000000000000000000000000000000000000000 --- a/ppocr/utils/check.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import sys - -import logging -logger = logging.getLogger(__name__) - - -def check_config_params(config, config_name, params): - for param in params: - if param not in config: - err = "param %s didn't find in %s!" % (param, config_name) - assert False, err - return