提交 23431930 编写于 作者: S sunyanfang01

fix the blazenet

上级 7a614e76
...@@ -90,6 +90,7 @@ class BlazeNet(object): ...@@ -90,6 +90,7 @@ class BlazeNet(object):
name='double_blaze_{}'.format(k)) name='double_blaze_{}'.format(k))
elif len(v) == 4: elif len(v) == 4:
layers.append(conv) layers.append(conv)
fluid.layers.Print(layers[-1])
conv = self.BlazeBlock( conv = self.BlazeBlock(
conv, conv,
v[0], v[0],
...@@ -102,6 +103,8 @@ class BlazeNet(object): ...@@ -102,6 +103,8 @@ class BlazeNet(object):
if not self.with_extra_blocks: if not self.with_extra_blocks:
return layers[-1] return layers[-1]
fluid.layers.Print(layers[-2])
fluid.layers.Print(layers[-1])
return layers[-2], layers[-1] return layers[-2], layers[-1]
else: else:
conv1 = self._conv_norm( conv1 = self._conv_norm(
...@@ -152,7 +155,7 @@ class BlazeNet(object): ...@@ -152,7 +155,7 @@ class BlazeNet(object):
stride=stride, stride=stride,
padding=2, padding=2,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=True, use_cudnn=False,
name=name + "1_dw") name=name + "1_dw")
else: else:
conv_dw_1 = self._conv_norm( conv_dw_1 = self._conv_norm(
...@@ -162,7 +165,7 @@ class BlazeNet(object): ...@@ -162,7 +165,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=True, use_cudnn=False,
name=name + "1_dw_1") name=name + "1_dw_1")
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
input=conv_dw_1, input=conv_dw_1,
...@@ -171,7 +174,7 @@ class BlazeNet(object): ...@@ -171,7 +174,7 @@ class BlazeNet(object):
stride=stride, stride=stride,
padding=1, padding=1,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=True, use_cudnn=False,
name=name + "1_dw_2") name=name + "1_dw_2")
conv_pw = self._conv_norm( conv_pw = self._conv_norm(
...@@ -191,7 +194,7 @@ class BlazeNet(object): ...@@ -191,7 +194,7 @@ class BlazeNet(object):
num_filters=out_channels, num_filters=out_channels,
stride=1, stride=1,
padding=2, padding=2,
use_cudnn=True, use_cudnn=False,
name=name + "2_dw") name=name + "2_dw")
else: else:
conv_dw_1 = self._conv_norm( conv_dw_1 = self._conv_norm(
...@@ -201,7 +204,7 @@ class BlazeNet(object): ...@@ -201,7 +204,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=out_channels, num_groups=out_channels,
use_cudnn=True, use_cudnn=False,
name=name + "2_dw_1") name=name + "2_dw_1")
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
input=conv_dw_1, input=conv_dw_1,
...@@ -210,7 +213,7 @@ class BlazeNet(object): ...@@ -210,7 +213,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=out_channels, num_groups=out_channels,
use_cudnn=True, use_cudnn=False,
name=name + "2_dw_2") name=name + "2_dw_2")
conv_pw = self._conv_norm( conv_pw = self._conv_norm(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册