未验证 提交 c14724de 编写于 作者: L littletomatodonkey 提交者: GitHub

Merge pull request #168 from littletomatodonkey/fix_lr

fix lr mult in resnet_vd
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -49,7 +49,8 @@ class ResNet(): ...@@ -49,7 +49,8 @@ class ResNet():
layers = self.layers layers = self.layers
supported_layers = [18, 34, 50, 101, 152, 200] supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \ assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers) "supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18: if layers == 18:
depth = [2, 2, 2, 2] depth = [2, 2, 2, 2]
...@@ -159,7 +160,9 @@ class ResNet(): ...@@ -159,7 +160,9 @@ class ResNet():
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights" + self.postfix_name), param_attr=ParamAttr(
name=name + "_weights" + self.postfix_name,
learning_rate=lr_mult),
bias_attr=False) bias_attr=False)
if name == "conv1": if name == "conv1":
bn_name = "bn_" + name bn_name = "bn_" + name
...@@ -168,8 +171,12 @@ class ResNet(): ...@@ -168,8 +171,12 @@ class ResNet():
return fluid.layers.batch_norm( return fluid.layers.batch_norm(
input=conv, input=conv,
act=act, act=act,
param_attr=ParamAttr(name=bn_name + '_scale' + self.postfix_name), param_attr=ParamAttr(
bias_attr=ParamAttr(bn_name + '_offset' + self.postfix_name), name=bn_name + '_scale' + self.postfix_name,
learning_rate=lr_mult),
bias_attr=ParamAttr(
bn_name + '_offset' + self.postfix_name,
learning_rate=lr_mult),
moving_mean_name=bn_name + '_mean' + self.postfix_name, moving_mean_name=bn_name + '_mean' + self.postfix_name,
moving_variance_name=bn_name + '_variance' + self.postfix_name) moving_variance_name=bn_name + '_variance' + self.postfix_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册