提交 9955d74f 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix acnet

上级 ed3ce87d
mode: 'train' mode: 'train'
ARCHITECTURE: ARCHITECTURE:
name: "ResNet_ACNet" name: "ResNet50_ACNet"
pretrained_model: "" pretrained_model: ""
model_save_dir: "./output/" model_save_dir: "./output/"
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -18,13 +18,12 @@ from __future__ import print_function ...@@ -18,13 +18,12 @@ from __future__ import print_function
import math import math
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
__all__ = [ __all__ = [
"ResNet_ACNet", "ResNet18_ACNet", "ResNet34_ACNet", "ResNet50_ACNet", "ResNet18_ACNet", "ResNet34_ACNet", "ResNet50_ACNet", "ResNet101_ACNet",
"ResNet101_ACNet", "ResNet152_ACNet" "ResNet152_ACNet"
] ]
...@@ -41,7 +40,8 @@ class ResNetACNet(object): ...@@ -41,7 +40,8 @@ class ResNetACNet(object):
layers = self.layers layers = self.layers
supported_layers = [18, 34, 50, 101, 152] supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \ assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers) "supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18: if layers == 18:
depth = [2, 2, 2, 2] depth = [2, 2, 2, 2]
...@@ -240,7 +240,7 @@ class ResNetACNet(object): ...@@ -240,7 +240,7 @@ class ResNetACNet(object):
def shortcut(self, input, ch_out, stride, is_first, name): def shortcut(self, input, ch_out, stride, is_first, name):
""" shortcut """ """ shortcut """
ch_in = input.shape[1] ch_in = input.shape[1]
if ch_in != ch_out or stride != 1 or is_first == True: if ch_in != ch_out or stride != 1 or is_first is True:
return self.conv_bn_layer( return self.conv_bn_layer(
input=input, input=input,
num_filters=ch_out, num_filters=ch_out,
...@@ -304,7 +304,7 @@ class ResNetACNet(object): ...@@ -304,7 +304,7 @@ class ResNetACNet(object):
def ResNet18_ACNet(deploy=False): def ResNet18_ACNet(deploy=False):
"""ResNet18 + ACNet""" """ResNet18 + ACNet"""
model = ResNet_ACNet(layers=18, deploy=deploy) model = ResNetACNet(layers=18, deploy=deploy)
return model return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册