diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index ec3e9d867ce659b729c021c3a02acead73cacf52..49608bd9a7e1342f54d2734f7b8f5ba33dbc7f86 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -247,10 +247,12 @@ class SimpleReader(object): print("multiprocess is not fully compatible with Windows." "num_workers will be 1.") self.num_workers = 1 - if self.batch_size * get_device_num() > img_num: + if self.batch_size * get_device_num( + ) * self.num_workers > img_num: raise Exception( - "The number of the whole data ({}) is smaller than the batch_size * devices_num ({})". - format(img_num, self.batch_size * get_device_num())) + "The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})". + format(img_num, self.batch_size * get_device_num() * + self.num_workers)) for img_id in range(process_id, img_num, self.num_workers): label_infor = label_infor_list[img_id_list[img_id]] substr = label_infor.decode('utf-8').strip("\n").split("\t") diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py index e504064242bfe6b0f433126ef4efb191b550a2eb..ff39a81210b7b71914f3c447b5e0035ac03db73b 100755 --- a/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -31,16 +31,28 @@ __all__ = [ class MobileNetV3(): def __init__(self, params): - self.scale = params['scale'] - model_name = params['model_name'] + self.scale = params.get("scale", 0.5) + model_name = params.get("model_name", "small") + large_stride = params.get("large_stride", [1, 2, 2, 2]) + small_stride = params.get("small_stride", [2, 2, 2, 2]) + + assert isinstance(large_stride, list), "large_stride type must " \ + "be list but got {}".format(type(large_stride)) + assert isinstance(small_stride, list), "small_stride type must " \ + "be list but got {}".format(type(small_stride)) + assert len(large_stride) == 4, "large_stride length must be " \ + "4 but got {}".format(len(large_stride)) + assert len(small_stride) == 4, "small_stride length must be " \ + "4 but got {}".format(len(small_stride)) + self.inplanes = 16 if model_name == "large": self.cfg = [ # k, exp, c, se, nl, s, - [3, 16, 16, False, 'relu', 1], - [3, 64, 24, False, 'relu', (2, 1)], + [3, 16, 16, False, 'relu', large_stride[0]], + [3, 64, 24, False, 'relu', (large_stride[1], 1)], [3, 72, 24, False, 'relu', 1], - [5, 72, 40, True, 'relu', (2, 1)], + [5, 72, 40, True, 'relu', (large_stride[2], 1)], [5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1], [3, 240, 80, False, 'hard_swish', 1], @@ -49,7 +61,7 @@ class MobileNetV3(): [3, 184, 80, False, 'hard_swish', 1], [3, 480, 112, True, 'hard_swish', 1], [3, 672, 112, True, 'hard_swish', 1], - [5, 672, 160, True, 'hard_swish', (2, 1)], + [5, 672, 160, True, 'hard_swish', (large_stride[3], 1)], [5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hard_swish', 1], ] @@ -58,15 +70,15 @@ class MobileNetV3(): elif model_name == "small": self.cfg = [ # k, exp, c, se, nl, s, - [3, 16, 16, True, 'relu', (2, 1)], - [3, 72, 24, False, 'relu', (2, 1)], + [3, 16, 16, True, 'relu', (small_stride[0], 1)], + [3, 72, 24, False, 'relu', (small_stride[1], 1)], [3, 88, 24, False, 'relu', 1], - [5, 96, 40, True, 'hard_swish', (2, 1)], + [5, 96, 40, True, 'hard_swish', (small_stride[2], 1)], [5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hard_swish', 1], [5, 120, 48, True, 'hard_swish', 1], [5, 144, 48, True, 'hard_swish', 1], - [5, 288, 96, True, 'hard_swish', (2, 1)], + [5, 288, 96, True, 'hard_swish', (small_stride[3], 1)], [5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hard_swish', 1], ] diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 37b4b00f8a16b219c978c685ea8a0d37234b40e4..6b8635e4647f186390179b880e132641342df0d6 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -32,6 +32,7 @@ class CTCPredict(object): self.char_num = params['char_num'] self.encoder = SequenceEncoder(params) self.encoder_type = params['encoder_type'] + self.fc_decay = params.get("fc_decay", 0.0004) def __call__(self, inputs, labels=None, mode=None): encoder_features = self.encoder(inputs) @@ -39,7 +40,7 @@ class CTCPredict(object): encoder_features = fluid.layers.concat(encoder_features, axis=1) name = "ctc_fc" para_attr, bias_attr = get_para_bias_attr( - l2_decay=0.0004, k=encoder_features.shape[1], name=name) + l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name) predict = fluid.layers.fc(input=encoder_features, size=self.char_num + 1, param_attr=para_attr, diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 80f64dc5835d1a33a9f746f71715f9ea202310da..f2346b3d22bbc4cf2e6d3ac54eaa1b378df6338b 100755 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -114,15 +114,15 @@ def init_model(config, program, exe): fluid.load(program, path, exe) logger.info("Finish initing model from {}".format(path)) else: - raise ValueError( - "Model checkpoints {} does not exists," - "check if you lost the file prefix.".format(checkpoints + '.pdparams')) - - pretrain_weights = config['Global'].get('pretrain_weights') - if pretrain_weights: - path = pretrain_weights - load_params(exe, program, path) - logger.info("Finish initing model from {}".format(path)) + raise ValueError("Model checkpoints {} does not exists," + "check if you lost the file prefix.".format( + checkpoints + '.pdparams')) + else: + pretrain_weights = config['Global'].get('pretrain_weights') + if pretrain_weights: + path = pretrain_weights + load_params(exe, program, path) + logger.info("Finish initing model from {}".format(path)) def save_model(program, model_path):