From 2918feb90eb3c0f3e35f486c209acf260c9a8471 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 12 Aug 2020 07:08:07 +0000 Subject: [PATCH] add stride config of backbone and fix load ckp --- ppocr/modeling/backbones/rec_mobilenet_v3.py | 32 ++++++++++++++------ ppocr/modeling/heads/rec_ctc_head.py | 3 +- ppocr/utils/save_load.py | 18 +++++------ 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py index e5040642..ff39a812 100755 --- a/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -31,16 +31,28 @@ __all__ = [ class MobileNetV3(): def __init__(self, params): - self.scale = params['scale'] - model_name = params['model_name'] + self.scale = params.get("scale", 0.5) + model_name = params.get("model_name", "small") + large_stride = params.get("large_stride", [1, 2, 2, 2]) + small_stride = params.get("small_stride", [2, 2, 2, 2]) + + assert isinstance(large_stride, list), "large_stride type must " \ + "be list but got {}".format(type(large_stride)) + assert isinstance(small_stride, list), "small_stride type must " \ + "be list but got {}".format(type(small_stride)) + assert len(large_stride) == 4, "large_stride length must be " \ + "4 but got {}".format(len(large_stride)) + assert len(small_stride) == 4, "small_stride length must be " \ + "4 but got {}".format(len(small_stride)) + self.inplanes = 16 if model_name == "large": self.cfg = [ # k, exp, c, se, nl, s, - [3, 16, 16, False, 'relu', 1], - [3, 64, 24, False, 'relu', (2, 1)], + [3, 16, 16, False, 'relu', large_stride[0]], + [3, 64, 24, False, 'relu', (large_stride[1], 1)], [3, 72, 24, False, 'relu', 1], - [5, 72, 40, True, 'relu', (2, 1)], + [5, 72, 40, True, 'relu', (large_stride[2], 1)], [5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1], [3, 240, 80, False, 'hard_swish', 1], @@ -49,7 +61,7 @@ class MobileNetV3(): [3, 184, 80, False, 'hard_swish', 1], [3, 480, 112, True, 'hard_swish', 1], [3, 672, 112, True, 'hard_swish', 1], - [5, 672, 160, True, 'hard_swish', (2, 1)], + [5, 672, 160, True, 'hard_swish', (large_stride[3], 1)], [5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hard_swish', 1], ] @@ -58,15 +70,15 @@ class MobileNetV3(): elif model_name == "small": self.cfg = [ # k, exp, c, se, nl, s, - [3, 16, 16, True, 'relu', (2, 1)], - [3, 72, 24, False, 'relu', (2, 1)], + [3, 16, 16, True, 'relu', (small_stride[0], 1)], + [3, 72, 24, False, 'relu', (small_stride[1], 1)], [3, 88, 24, False, 'relu', 1], - [5, 96, 40, True, 'hard_swish', (2, 1)], + [5, 96, 40, True, 'hard_swish', (small_stride[2], 1)], [5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hard_swish', 1], [5, 120, 48, True, 'hard_swish', 1], [5, 144, 48, True, 'hard_swish', 1], - [5, 288, 96, True, 'hard_swish', (2, 1)], + [5, 288, 96, True, 'hard_swish', (small_stride[3], 1)], [5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hard_swish', 1], ] diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 37b4b00f..6b8635e4 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -32,6 +32,7 @@ class CTCPredict(object): self.char_num = params['char_num'] self.encoder = SequenceEncoder(params) self.encoder_type = params['encoder_type'] + self.fc_decay = params.get("fc_decay", 0.0004) def __call__(self, inputs, labels=None, mode=None): encoder_features = self.encoder(inputs) @@ -39,7 +40,7 @@ class CTCPredict(object): encoder_features = fluid.layers.concat(encoder_features, axis=1) name = "ctc_fc" para_attr, bias_attr = get_para_bias_attr( - l2_decay=0.0004, k=encoder_features.shape[1], name=name) + l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name) predict = fluid.layers.fc(input=encoder_features, size=self.char_num + 1, param_attr=para_attr, diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 80f64dc5..f2346b3d 100755 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -114,15 +114,15 @@ def init_model(config, program, exe): fluid.load(program, path, exe) logger.info("Finish initing model from {}".format(path)) else: - raise ValueError( - "Model checkpoints {} does not exists," - "check if you lost the file prefix.".format(checkpoints + '.pdparams')) - - pretrain_weights = config['Global'].get('pretrain_weights') - if pretrain_weights: - path = pretrain_weights - load_params(exe, program, path) - logger.info("Finish initing model from {}".format(path)) + raise ValueError("Model checkpoints {} does not exists," + "check if you lost the file prefix.".format( + checkpoints + '.pdparams')) + else: + pretrain_weights = config['Global'].get('pretrain_weights') + if pretrain_weights: + path = pretrain_weights + load_params(exe, program, path) + logger.info("Finish initing model from {}".format(path)) def save_model(program, model_path): -- GitLab