提交 45983e33 编写于 作者: C chenguowei01

update hrnet.py

上级 7295d092
...@@ -645,12 +645,14 @@ class FuseLayers(fluid.dygraph.Layer): ...@@ -645,12 +645,14 @@ class FuseLayers(fluid.dygraph.Layer):
residual_func_idx = 0 residual_func_idx = 0
for i in range(self._actual_ch): for i in range(self._actual_ch):
residual = input[i] residual = input[i]
residual_shape = residual.shape[-2:]
for j in range(len(self._in_channels)): for j in range(len(self._in_channels)):
if j > i: if j > i:
y = self.residual_func_list[residual_func_idx](input[j]) y = self.residual_func_list[residual_func_idx](input[j])
residual_func_idx += 1 residual_func_idx += 1
y = fluid.layers.resize_bilinear(input=y, scale=2**(j - i)) y = fluid.layers.resize_bilinear(
input=y, out_shape=residual_shape)
residual = fluid.layers.elementwise_add( residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None) x=residual, y=y, act=None)
elif j < i: elif j < i:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册