提交 5a376837 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix hrnet name

上级 27f5ac5a
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -74,7 +74,7 @@ class HRNet(): ...@@ -74,7 +74,7 @@ class HRNet():
tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3') tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
st4 = self.stage(tr3, num_modules_4, channels_4, name='st4') st4 = self.stage(tr3, num_modules_4, channels_4, name='st4')
#classification # classification
last_cls = self.last_cls_out(x=st4, name='cls_head') last_cls = self.last_cls_out(x=st4, name='cls_head')
y = last_cls[0] y = last_cls[0]
last_num_filters = [256, 512, 1024] last_num_filters = [256, 512, 1024]
...@@ -273,7 +273,7 @@ class HRNet(): ...@@ -273,7 +273,7 @@ class HRNet():
input=conv, input=conv,
num_channels=num_filters, num_channels=num_filters,
reduction_ratio=16, reduction_ratio=16,
name=name + '_fc') name="fc" + name)
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def bottleneck_block(self, def bottleneck_block(self,
...@@ -312,7 +312,7 @@ class HRNet(): ...@@ -312,7 +312,7 @@ class HRNet():
input=conv, input=conv,
num_channels=num_filters * 4, num_channels=num_filters * 4,
reduction_ratio=16, reduction_ratio=16,
name=name + '_fc') name="fc" + name)
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def squeeze_excitation(self, def squeeze_excitation(self,
...@@ -325,7 +325,7 @@ class HRNet(): ...@@ -325,7 +325,7 @@ class HRNet():
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc( squeeze = fluid.layers.fc(
input=pool, input=pool,
size=num_channels / reduction_ratio, size=int(num_channels / reduction_ratio),
act='relu', act='relu',
param_attr=fluid.param_attr.ParamAttr( param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), initializer=fluid.initializer.Uniform(-stdv, stdv),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册