diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8b568637a189ac47438b84e89fc55ddc643ab297
--- /dev/null
+++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
@@ -0,0 +1,126 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 800
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec_mobile_pp-OCRv2_enhanced_ctc_loss
+  save_epoch_step: 3
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: true
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+  character_type: ch
+  max_text_length: 25
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_mobile_pp-OCRv2_enhanced_ctc_loss.txt
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Piecewise
+    decay_epochs: [700, 800]
+    values: [0.001, 0.0001]
+    warmup_epoch: 5
+  regularizer:
+    name: L2
+    factor: 2.0e-05
+
+
+Architecture:
+  model_type: rec
+  algorithm: CRNN
+  Transform:
+  Backbone:
+    name: MobileNetV1Enhance
+    scale: 0.5
+  Neck:
+    name: SequenceEncoder
+    encoder_type: rnn
+    hidden_size: 64
+  Head:
+    name: CTCHead
+    mid_channels: 96
+    fc_decay: 0.00002
+    return_feats: true
+
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+    - CTCLoss:
+        use_focal_loss: false
+        weight: 1.0
+    - CenterLoss:
+        weight: 0.05
+        num_classes: 6625
+        feat_dim: 96
+        init_center: false
+        center_file_path: "./train_center.pkl"
+    # you can also try to add ACE loss on your own dataset
+    # - ACELoss:
+    #     weight: 0.1
+
+PostProcess:
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/
+    label_file_list:
+      - ./train_data/train_list.txt
+    transforms:
+      - DecodeImage:
+          img_mode: BGR
+          channel_first: false
+      - RecAug:
+      - CTCLabelEncode:
+      - RecResizeImg:
+          image_shape: [3, 32, 320]
+      - KeepKeys:
+          keep_keys:
+            - image
+            - label
+            - length
+            - label_ace
+  loader:
+    shuffle: true
+    batch_size_per_card: 128
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+      - ./train_data/val_list.txt
+    transforms:
+      - DecodeImage:
+          img_mode: BGR
+          channel_first: false
+      - CTCLabelEncode:
+      - RecResizeImg:
+          image_shape: [3, 32, 320]
+      - KeepKeys:
+          keep_keys:
+            - image
+            - label
+            - length
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 128
+    num_workers: 8
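Note on the Optimizer block above: Piecewise holds the learning rate at 0.001 until epoch 700 and at 0.0001 from epoch 700 to 800, after a 5-epoch warmup. A rough sketch of the intended semantics (illustrative only, not PaddleOCR's optimizer implementation):

    def lr_at(epoch, decay_epochs=(700, 800), values=(0.001, 0.0001),
              warmup_epoch=5):
        # linear warmup to the first value, then piecewise-constant decay
        if epoch < warmup_epoch:
            return values[0] * epoch / warmup_epoch
        for boundary, value in zip(decay_epochs, values):
            if epoch < boundary:
                return value
        return values[-1]
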
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index f761eaf6691aadb5b8e0452341cfe06927677dc9..ebf52ec4e1d8713fd4da407318b14e682952606d 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
         data['length'] = np.array(len(text))
         text = text + [0] * (self.max_text_len - len(text))
         data['label'] = np.array(text)
+
+        label = [0] * len(self.character)
+        for x in text:
+            label[x] += 1
+        data['label_ace'] = np.array(label)
         return data
 
     def add_special_char(self, dict_character):
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index a6c2a9f6d16f5d777eb3003bb63b49cad7259acd..f3f4cd49332b605ec3a0e65e688d965fd91a5cdf 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -52,7 +52,6 @@ def build_loss(config):
         'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
         'TableAttentionLoss', 'SARLoss', 'AsterLoss'
     ]
-    config = copy.deepcopy(config)
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('loss only support {}'.format(
         support_dict))
diff --git a/ppocr/losses/ace_loss.py b/ppocr/losses/ace_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c868520e5bd7b398c7f248c416b70427baee0a6
--- /dev/null
+++ b/ppocr/losses/ace_loss.py
@@ -0,0 +1,53 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+
+class ACELoss(nn.Layer):
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.loss_func = nn.CrossEntropyLoss(
+            weight=None,
+            ignore_index=0,
+            reduction='none',
+            soft_label=True,
+            axis=-1)
+
+    def __call__(self, predicts, batch):
+        if isinstance(predicts, (list, tuple)):
+            predicts = predicts[-1]
+        B, N = predicts.shape[:2]
+        div = paddle.to_tensor([N]).astype('float32')
+
+        # aggregate the per-step class distributions over the time axis
+        predicts = nn.functional.softmax(predicts, axis=-1)
+        aggregation_preds = paddle.sum(predicts, axis=1)
+        aggregation_preds = paddle.divide(aggregation_preds, div)
+
+        length = batch[2].astype("float32")
+        batch = batch[3].astype("float32")
+        # attribute the unlabelled time steps to the blank class (index 0);
+        # this also replaces the padding zeros counted in CTCLabelEncode
+        batch[:, 0] = paddle.subtract(div, length)
+        # normalize the character counts into the target distribution
+        batch = paddle.divide(batch, div)
+
+        loss = self.loss_func(aggregation_preds, batch)
+
+        return {"loss_ace": loss}
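For orientation: ACELoss implements Aggregation Cross-Entropy (Xie et al., "Aggregation Cross-Entropy for Sequence Recognition", CVPR 2019). Rather than aligning per-step predictions to the label sequence, it compares the time-aggregated class distribution with the normalized character counts built in CTCLabelEncode above. A minimal numpy sketch of the quantity being minimized (toy shapes, not repo code):

    import numpy as np

    T, C = 8, 5                    # time steps, classes (index 0 = blank)
    probs = np.random.rand(T, C)
    probs /= probs.sum(axis=1, keepdims=True)   # per-step softmax outputs

    text = [2, 3, 3]                            # ground-truth indices
    counts = np.bincount(text, minlength=C).astype(float)
    counts[0] = T - len(text)      # unlabelled steps count as blank
    target = counts / T            # normalized count distribution

    pred = probs.sum(axis=0) / T   # aggregated prediction distribution
    ace = -np.sum(target * np.log(pred))        # soft-label cross-entropy
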
diff --git a/ppocr/losses/center_loss.py b/ppocr/losses/center_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..72149df19f9e9a864ec31239177dc574648da3d5
--- /dev/null
+++ b/ppocr/losses/center_loss.py
@@ -0,0 +1,89 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import pickle
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class CenterLoss(nn.Layer):
+    """
+    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
+    """
+
+    def __init__(self,
+                 num_classes=6625,
+                 feat_dim=96,
+                 init_center=False,
+                 center_file_path=None):
+        super().__init__()
+        self.num_classes = num_classes
+        self.feat_dim = feat_dim
+        self.centers = paddle.randn(
+            shape=[self.num_classes, self.feat_dim]).astype(
+                "float64")  # randomly initialized class centers
+
+        if init_center:
+            assert os.path.exists(
+                center_file_path
+            ), f"center path({center_file_path}) must exist when init_center is set as True."
+            with open(center_file_path, 'rb') as f:
+                char_dict = pickle.load(f)
+                for key in char_dict.keys():
+                    self.centers[key] = paddle.to_tensor(char_dict[key])
+
+    def __call__(self, predicts, batch):
+        assert isinstance(predicts, (list, tuple))
+        features, predicts = predicts
+
+        feats_reshape = paddle.reshape(
+            features, [-1, features.shape[-1]]).astype("float64")
+        label = paddle.argmax(predicts, axis=2)
+        label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
+
+        batch_size = feats_reshape.shape[0]
+
+        # squared norm of the features: ||x||^2, expanded to [B, num_classes]
+        dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
+        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
+
+        # squared norm of the centers: ||c||^2
+        dist2 = paddle.sum(paddle.square(self.centers), axis=1,
+                           keepdim=True)  # shape [num_classes, 1]
+        dist2 = paddle.expand(dist2,
+                              [self.num_classes, batch_size]).astype("float64")
+        dist2 = paddle.transpose(dist2, [1, 0])
+
+        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c
+        distmat = paddle.add(dist1, dist2)
+        tmp = paddle.matmul(feats_reshape,
+                            paddle.transpose(self.centers, [1, 0]))
+        distmat = distmat - 2.0 * tmp
+
+        # one-hot mask selecting, per sample, the column of its predicted class
+        classes = paddle.arange(self.num_classes).astype("int64")
+        label = paddle.expand(
+            paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
+        mask = paddle.equal(
+            paddle.expand(classes, [batch_size, self.num_classes]),
+            label).astype("float64")
+        dist = paddle.multiply(distmat, mask)
+        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
+        return {'loss_center': loss}
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index f3bb36cf5ac751e6c27e4aa29a46fc5f913f7d05..72f706e37d6eb0c640cc30de80afe00bce82fd13 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -15,6 +15,10 @@
 import paddle
 import paddle.nn as nn
 
+from .rec_ctc_loss import CTCLoss
+from .center_loss import CenterLoss
+from .ace_loss import ACELoss
+
 from .distillation_loss import DistillationCTCLoss
 from .distillation_loss import DistillationDMLLoss
 from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 73d3ae2ad2499607f897a102f6ea25e4cb7f297f..06aa7fa8458a5deece75f1393fe7300e8227d3ca 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
                     if isinstance(loss, dict):
                         for key in loss:
                             loss_dict["{}_{}_{}_{}_{}".format(key, pair[
-                                0], pair[1], map_name, idx)] = loss[key]
+                                0], pair[1], self.maps_name, idx)] = loss[key]
                     else:
                         loss_dict["{}_{}_{}".format(self.name, self.maps_name[
                             _c], idx)] = loss
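The distance matrix in CenterLoss is built from the standard expansion of the squared Euclidean distance rather than by materializing all pairwise differences. A quick numpy check of that identity (toy sizes, not repo code):

    import numpy as np

    B, D, K = 4, 96, 10            # batch, feat_dim, num_classes
    feats = np.random.randn(B, D)
    centers = np.random.randn(K, D)

    # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c
    dist1 = np.square(feats).sum(axis=1, keepdims=True)   # [B, 1]
    dist2 = np.square(centers).sum(axis=1)                # [K]
    distmat = dist1 + dist2 - 2.0 * feats @ centers.T     # [B, K]

    brute = np.square(feats[:, None, :] - centers[None, :, :]).sum(-1)
    assert np.allclose(distmat, brute)
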
diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py
index 6c0b56ff84db4ff23786fb781d461bf9fbc86ef2..5d09802b46d7ddfa802461760b917267155b3923 100755
--- a/ppocr/losses/rec_ctc_loss.py
+++ b/ppocr/losses/rec_ctc_loss.py
@@ -21,16 +21,25 @@ from paddle import nn
 
 
 class CTCLoss(nn.Layer):
-    def __init__(self, **kwargs):
+    def __init__(self, use_focal_loss=False, **kwargs):
         super(CTCLoss, self).__init__()
         self.loss_func = nn.CTCLoss(blank=0, reduction='none')
+        self.use_focal_loss = use_focal_loss
 
     def forward(self, predicts, batch):
+        if isinstance(predicts, (list, tuple)):
+            predicts = predicts[-1]
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
         preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
         labels = batch[1].astype("int32")
         label_lengths = batch[2].astype('int64')
         loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
+        if self.use_focal_loss:
+            # focal-style reweighting: down-weight easy (low-loss) samples
+            weight = paddle.exp(-loss)
+            weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
+            weight = paddle.square(weight)
+            loss = paddle.multiply(loss, weight)
         loss = loss.mean()
         return {'loss': loss}
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index 9c38d31fa0abcf39a583e5edcebfc8f336f41c46..35d33d5f56b3b378286565cbfa9755f43343b278 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -38,6 +38,7 @@ class CTCHead(nn.Layer):
                  out_channels,
                  fc_decay=0.0004,
                  mid_channels=None,
+                 return_feats=False,
                  **kwargs):
         super(CTCHead, self).__init__()
         if mid_channels is None:
@@ -66,14 +67,22 @@ class CTCHead(nn.Layer):
             bias_attr=bias_attr2)
         self.out_channels = out_channels
         self.mid_channels = mid_channels
+        self.return_feats = return_feats
 
     def forward(self, x, targets=None):
         if self.mid_channels is None:
             predicts = self.fc(x)
         else:
-            predicts = self.fc1(x)
-            predicts = self.fc2(predicts)
-
+            x = self.fc1(x)
+            predicts = self.fc2(x)
+
+        if self.return_feats:
+            result = (x, predicts)
+        else:
+            result = predicts
+
         if not self.training:
             predicts = F.softmax(predicts, axis=2)
-        return predicts
+            result = predicts
+
+        return result
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index b17fea2c9f09c9067d8f2bb47ab5a9fde4d06c9b..3a4ebf52a3bd91ffd509b113103dab900588b0bd 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -18,6 +18,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 import copy
+import platform
 
 __all__ = ['build_post_process']
 
@@ -28,7 +29,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
     TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
-from .pse_postprocess import PSEPostProcess
+
+if platform.system() != "Windows":
+    # PSE is not supported on Windows yet
+    from .pse_postprocess import PSEPostProcess
 
 
 def build_post_process(config, global_config=None):
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 16f7f76596649e59e3818f1c785feff6535ef499..c06159ca55600e7afe01a68ab43acd1919cf742c 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -111,6 +111,8 @@ class CTCLabelDecode(BaseRecLabelDecode):
                          character_type, use_space_char)
 
     def __call__(self, preds, label=None, *args, **kwargs):
+        if isinstance(preds, tuple):
+            preds = preds[-1]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
         preds_idx = preds.argmax(axis=2)
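Since exp(-loss) approximates the probability the model assigns to the ground-truth sequence, the use_focal_loss branch above scales each sample by the focal modulating factor (1 - p)^2, i.e. gamma = 2. A toy numpy illustration (values invented):

    import numpy as np

    ctc = np.array([0.1, 1.0, 4.0])     # per-sample CTC losses
    p = np.exp(-ctc)                    # ~ sequence probability
    print(ctc * (1.0 - p) ** 2)         # easy samples are down-weighted
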
arch_config["algorithm"] == "SAR": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, 160], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 332cffd5395f8f511089b0bfde762820af7bbe8c..dad70281ef7604f110d29963103068bba1c8fd9d 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -68,6 +68,13 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "SAR": + postprocess_params = { + 'name': 'SARLabelDecode', + "character_type": args.rec_char_type, + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -194,6 +201,41 @@ class TextRecognizer(object): return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) + def resize_norm_img_sar(self, img, image_shape, + width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. + width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype('float32') + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -216,11 +258,19 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - if self.rec_algorithm != "SRN": + if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR": norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) + elif self.rec_algorithm == "SAR": + norm_img, _, _, valid_ratio = self.resize_norm_img_sar( + img_list[indices[ino]], self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + valid_ratio = np.expand_dims(valid_ratio, axis=0) + valid_ratios = [] + valid_ratios.append(valid_ratio) + norm_img_batch.append(norm_img) else: norm_img = self.process_image_srn( img_list[indices[ino]], self.rec_image_shape, 8, 25) @@ -266,6 +316,25 @@ class TextRecognizer(object): if self.benchmark: self.autolog.times.stamp() preds = {"predict": outputs[2]} + elif self.rec_algorithm == "SAR": + valid_ratios = np.concatenate(valid_ratios) + inputs = [ + norm_img_batch, 
diff --git a/tools/program.py b/tools/program.py
index 4df87c16868260f7e09979b4dcfa76bccef72a79..430631bfea8a01d590f93dc5ed4e4829c1cc62e9 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -394,6 +394,18 @@ def preprocess(is_train=False):
     config = load_config(FLAGS.config)
     merge_config(FLAGS.opt)
 
+    if is_train:
+        # save_config
+        save_model_dir = config['Global']['save_model_dir']
+        os.makedirs(save_model_dir, exist_ok=True)
+        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+            yaml.dump(
+                dict(config), f, default_flow_style=False, sort_keys=False)
+        log_file = '{}/train.log'.format(save_model_dir)
+    else:
+        log_file = None
+    logger = get_logger(name='root', log_file=log_file)
+
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
     check_gpu(use_gpu)
@@ -403,22 +415,16 @@ def preprocess(is_train=False):
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
         'SEED'
     ]
+    windows_not_support_list = ['PSE']
+    if platform.system() == "Windows" and alg in windows_not_support_list:
+        logger.warning('{} is not supported on Windows yet'.format(alg))
+        sys.exit()
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)
 
     config['Global']['distributed'] = dist.get_world_size() != 1
-    if is_train:
-        # save_config
-        save_model_dir = config['Global']['save_model_dir']
-        os.makedirs(save_model_dir, exist_ok=True)
-        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
-            yaml.dump(
-                dict(config), f, default_flow_style=False, sort_keys=False)
-        log_file = '{}/train.log'.format(save_model_dir)
-    else:
-        log_file = None
-    logger = get_logger(name='root', log_file=log_file)
+
     if config['Global']['use_visualdl']:
         from visualdl import LogWriter
         save_model_dir = config['Global']['save_model_dir']
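Taken together, return_feats=True changes the head's output contract: CTCHead returns (feats, logits), CenterLoss unpacks both, while CTCLoss and CTCLabelDecode keep working by taking only the last element. A toy check of that contract with the dimensions from the config above (paddle required; tensor values are illustrative):

    import paddle

    B, T, feat_dim, num_classes = 2, 25, 96, 6625
    feats = paddle.randn([B, T, feat_dim])        # fc1 output
    logits = paddle.randn([B, T, num_classes])    # fc2 output
    predicts = (feats, logits)                    # what CTCHead now returns

    # CTCLoss / CTCLabelDecode path: keep only the logits
    ctc_in = predicts[-1] if isinstance(predicts, (list, tuple)) else predicts
    assert ctc_in.shape == [B, T, num_classes]

    # CenterLoss path: consume both; the last dims must match the
    # feat_dim / num_classes values set in the yml
    features, logits2 = predicts
    assert features.shape[-1] == feat_dim and logits2.shape[-1] == num_classes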