未验证 提交 bad9f6cd 编写于 作者: D dyning 提交者: GitHub

Merge pull request #520 from littletomatodonkey/fix_mv3

add stride config of backbone and fix load ckp
......@@ -247,10 +247,12 @@ class SimpleReader(object):
print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.")
self.num_workers = 1
if self.batch_size * get_device_num() > img_num:
if self.batch_size * get_device_num(
) * self.num_workers > img_num:
raise Exception(
"The number of the whole data ({}) is smaller than the batch_size * devices_num ({})".
format(img_num, self.batch_size * get_device_num()))
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
format(img_num, self.batch_size * get_device_num() *
self.num_workers))
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
substr = label_infor.decode('utf-8').strip("\n").split("\t")
......
......@@ -31,16 +31,28 @@ __all__ = [
class MobileNetV3():
def __init__(self, params):
self.scale = params['scale']
model_name = params['model_name']
self.scale = params.get("scale", 0.5)
model_name = params.get("model_name", "small")
large_stride = params.get("large_stride", [1, 2, 2, 2])
small_stride = params.get("small_stride", [2, 2, 2, 2])
assert isinstance(large_stride, list), "large_stride type must " \
"be list but got {}".format(type(large_stride))
assert isinstance(small_stride, list), "small_stride type must " \
"be list but got {}".format(type(small_stride))
assert len(large_stride) == 4, "large_stride length must be " \
"4 but got {}".format(len(large_stride))
assert len(small_stride) == 4, "small_stride length must be " \
"4 but got {}".format(len(small_stride))
self.inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', (2, 1)],
[3, 16, 16, False, 'relu', large_stride[0]],
[3, 64, 24, False, 'relu', (large_stride[1], 1)],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', (2, 1)],
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 1],
......@@ -49,7 +61,7 @@ class MobileNetV3():
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', (2, 1)],
[5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
]
......@@ -58,15 +70,15 @@ class MobileNetV3():
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', (2, 1)],
[3, 72, 24, False, 'relu', (2, 1)],
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', (2, 1)],
[5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', (2, 1)],
[5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
]
......
......@@ -32,6 +32,7 @@ class CTCPredict(object):
self.char_num = params['char_num']
self.encoder = SequenceEncoder(params)
self.encoder_type = params['encoder_type']
self.fc_decay = params.get("fc_decay", 0.0004)
def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs)
......@@ -39,7 +40,7 @@ class CTCPredict(object):
encoder_features = fluid.layers.concat(encoder_features, axis=1)
name = "ctc_fc"
para_attr, bias_attr = get_para_bias_attr(
l2_decay=0.0004, k=encoder_features.shape[1], name=name)
l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name)
predict = fluid.layers.fc(input=encoder_features,
size=self.char_num + 1,
param_attr=para_attr,
......
......@@ -114,15 +114,15 @@ def init_model(config, program, exe):
fluid.load(program, path, exe)
logger.info("Finish initing model from {}".format(path))
else:
raise ValueError(
"Model checkpoints {} does not exists,"
"check if you lost the file prefix.".format(checkpoints + '.pdparams'))
pretrain_weights = config['Global'].get('pretrain_weights')
if pretrain_weights:
path = pretrain_weights
load_params(exe, program, path)
logger.info("Finish initing model from {}".format(path))
raise ValueError("Model checkpoints {} does not exists,"
"check if you lost the file prefix.".format(
checkpoints + '.pdparams'))
else:
pretrain_weights = config['Global'].get('pretrain_weights')
if pretrain_weights:
path = pretrain_weights
load_params(exe, program, path)
logger.info("Finish initing model from {}".format(path))
def save_model(program, model_path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册