提交 2918feb9 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add stride config of backbone and fix load ckp

上级 ebfd4475
...@@ -31,16 +31,28 @@ __all__ = [ ...@@ -31,16 +31,28 @@ __all__ = [
class MobileNetV3(): class MobileNetV3():
def __init__(self, params): def __init__(self, params):
self.scale = params['scale'] self.scale = params.get("scale", 0.5)
model_name = params['model_name'] model_name = params.get("model_name", "small")
large_stride = params.get("large_stride", [1, 2, 2, 2])
small_stride = params.get("small_stride", [2, 2, 2, 2])
assert isinstance(large_stride, list), "large_stride type must " \
"be list but got {}".format(type(large_stride))
assert isinstance(small_stride, list), "small_stride type must " \
"be list but got {}".format(type(small_stride))
assert len(large_stride) == 4, "large_stride length must be " \
"4 but got {}".format(len(large_stride))
assert len(small_stride) == 4, "small_stride length must be " \
"4 but got {}".format(len(small_stride))
self.inplanes = 16 self.inplanes = 16
if model_name == "large": if model_name == "large":
self.cfg = [ self.cfg = [
# k, exp, c, se, nl, s, # k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1], [3, 16, 16, False, 'relu', large_stride[0]],
[3, 64, 24, False, 'relu', (2, 1)], [3, 64, 24, False, 'relu', (large_stride[1], 1)],
[3, 72, 24, False, 'relu', 1], [3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', (2, 1)], [5, 72, 40, True, 'relu', (large_stride[2], 1)],
[5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1], [5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 1], [3, 240, 80, False, 'hard_swish', 1],
...@@ -49,7 +61,7 @@ class MobileNetV3(): ...@@ -49,7 +61,7 @@ class MobileNetV3():
[3, 184, 80, False, 'hard_swish', 1], [3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1], [3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1], [3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', (2, 1)], [5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
[5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1], [5, 960, 160, True, 'hard_swish', 1],
] ]
...@@ -58,15 +70,15 @@ class MobileNetV3(): ...@@ -58,15 +70,15 @@ class MobileNetV3():
elif model_name == "small": elif model_name == "small":
self.cfg = [ self.cfg = [
# k, exp, c, se, nl, s, # k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', (2, 1)], [3, 16, 16, True, 'relu', (small_stride[0], 1)],
[3, 72, 24, False, 'relu', (2, 1)], [3, 72, 24, False, 'relu', (small_stride[1], 1)],
[3, 88, 24, False, 'relu', 1], [3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', (2, 1)], [5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
[5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1], [5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1], [5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1], [5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', (2, 1)], [5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
[5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1], [5, 576, 96, True, 'hard_swish', 1],
] ]
......
...@@ -32,6 +32,7 @@ class CTCPredict(object): ...@@ -32,6 +32,7 @@ class CTCPredict(object):
self.char_num = params['char_num'] self.char_num = params['char_num']
self.encoder = SequenceEncoder(params) self.encoder = SequenceEncoder(params)
self.encoder_type = params['encoder_type'] self.encoder_type = params['encoder_type']
self.fc_decay = params.get("fc_decay", 0.0004)
def __call__(self, inputs, labels=None, mode=None): def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs) encoder_features = self.encoder(inputs)
...@@ -39,7 +40,7 @@ class CTCPredict(object): ...@@ -39,7 +40,7 @@ class CTCPredict(object):
encoder_features = fluid.layers.concat(encoder_features, axis=1) encoder_features = fluid.layers.concat(encoder_features, axis=1)
name = "ctc_fc" name = "ctc_fc"
para_attr, bias_attr = get_para_bias_attr( para_attr, bias_attr = get_para_bias_attr(
l2_decay=0.0004, k=encoder_features.shape[1], name=name) l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name)
predict = fluid.layers.fc(input=encoder_features, predict = fluid.layers.fc(input=encoder_features,
size=self.char_num + 1, size=self.char_num + 1,
param_attr=para_attr, param_attr=para_attr,
......
...@@ -114,10 +114,10 @@ def init_model(config, program, exe): ...@@ -114,10 +114,10 @@ def init_model(config, program, exe):
fluid.load(program, path, exe) fluid.load(program, path, exe)
logger.info("Finish initing model from {}".format(path)) logger.info("Finish initing model from {}".format(path))
else: else:
raise ValueError( raise ValueError("Model checkpoints {} does not exists,"
"Model checkpoints {} does not exists," "check if you lost the file prefix.".format(
"check if you lost the file prefix.".format(checkpoints + '.pdparams')) checkpoints + '.pdparams'))
else:
pretrain_weights = config['Global'].get('pretrain_weights') pretrain_weights = config['Global'].get('pretrain_weights')
if pretrain_weights: if pretrain_weights:
path = pretrain_weights path = pretrain_weights
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册