未验证 提交 efd21dd9 编写于 作者: W will-jl944 提交者: GitHub

Fix assertion bug (#4772)

* fix assertion bug

* fix assertion bug
上级 1cb2d6c3
...@@ -48,10 +48,8 @@ class ConvNormLayer(nn.Layer): ...@@ -48,10 +48,8 @@ class ConvNormLayer(nn.Layer):
self.act = act self.act = act
norm_lr = 0. if freeze_norm else 1. norm_lr = 0. if freeze_norm else 1.
if norm_type is not None: if norm_type is not None:
assert ( assert norm_type in ['bn', 'sync_bn', 'gn'], \
norm_type in ['bn', 'sync_bn', 'gn'], "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".format(norm_type)
"norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".
format(norm_type))
param_attr = ParamAttr( param_attr = ParamAttr(
initializer=Constant(1.0), initializer=Constant(1.0),
learning_rate=norm_lr, learning_rate=norm_lr,
...@@ -277,10 +275,8 @@ class ShuffleUnit(nn.Layer): ...@@ -277,10 +275,8 @@ class ShuffleUnit(nn.Layer):
branch_channel = out_channel // 2 branch_channel = out_channel // 2
self.stride = stride self.stride = stride
if self.stride == 1: if self.stride == 1:
assert ( assert in_channel == branch_channel * 2, \
in_channel == branch_channel * 2, "when stride=1, in_channel {} should equal to branch_channel*2 {}".format(in_channel, branch_channel * 2)
"when stride=1, in_channel {} should equal to branch_channel*2 {}"
.format(in_channel, branch_channel * 2))
if stride > 1: if stride > 1:
self.branch1 = nn.Sequential( self.branch1 = nn.Sequential(
ConvNormLayer( ConvNormLayer(
...@@ -500,11 +496,11 @@ class LiteHRNetModule(nn.Layer): ...@@ -500,11 +496,11 @@ class LiteHRNetModule(nn.Layer):
freeze_norm=False, freeze_norm=False,
norm_decay=0.): norm_decay=0.):
super(LiteHRNetModule, self).__init__() super(LiteHRNetModule, self).__init__()
assert (num_branches == len(in_channels), assert num_branches == len(in_channels),\
"num_branches {} should equal to num_in_channels {}" "num_branches {} should equal to num_in_channels {}".format(num_branches, len(in_channels))
.format(num_branches, len(in_channels))) assert module_type in [
assert (module_type in ['LITE', 'NAIVE'], 'LITE', 'NAIVE'
"module_type should be one of ['LITE', 'NAIVE']") ], "module_type should be one of ['LITE', 'NAIVE']"
self.num_branches = num_branches self.num_branches = num_branches
self.in_channels = in_channels self.in_channels = in_channels
self.multiscale_output = multiscale_output self.multiscale_output = multiscale_output
...@@ -699,10 +695,8 @@ class LiteHRNet(nn.Layer): ...@@ -699,10 +695,8 @@ class LiteHRNet(nn.Layer):
super(LiteHRNet, self).__init__() super(LiteHRNet, self).__init__()
if isinstance(return_idx, Integral): if isinstance(return_idx, Integral):
return_idx = [return_idx] return_idx = [return_idx]
assert ( assert network_type in ["lite_18", "lite_30", "naive", "wider_naive"], \
network_type in ["lite_18", "lite_30", "naive", "wider_naive"],
"the network_type should be one of [lite_18, lite_30, naive, wider_naive]" "the network_type should be one of [lite_18, lite_30, naive, wider_naive]"
)
assert len(return_idx) > 0, "need one or more return index" assert len(return_idx) > 0, "need one or more return index"
self.freeze_at = freeze_at self.freeze_at = freeze_at
self.freeze_norm = freeze_norm self.freeze_norm = freeze_norm
......
...@@ -1592,8 +1592,7 @@ def smooth_l1(input, label, inside_weight=None, outside_weight=None, ...@@ -1592,8 +1592,7 @@ def smooth_l1(input, label, inside_weight=None, outside_weight=None,
def channel_shuffle(x, groups): def channel_shuffle(x, groups):
batch_size, num_channels, height, width = x.shape[0:4] batch_size, num_channels, height, width = x.shape[0:4]
assert (num_channels % groups == 0, assert num_channels % groups == 0, 'num_channels should be divisible by groups'
'num_channels should be divisible by groups')
channels_per_group = num_channels // groups channels_per_group = num_channels // groups
x = paddle.reshape( x = paddle.reshape(
x=x, shape=[batch_size, groups, channels_per_group, height, width]) x=x, shape=[batch_size, groups, channels_per_group, height, width])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册