Unverified · Commit efd21dd9 authored by will-jl944, committed by GitHub

Fix assertion bug (#4772)

* fix assertion bug

* fix assertion bug
Parent 1cb2d6c3
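The bug being fixed is Python's tuple-assertion pitfall: wrapping an assert's condition and message in a single pair of parentheses creates a two-element tuple, which is always truthy, so the check can never fail (recent CPython versions warn with `SyntaxWarning: assertion is always true`). The minimal sketch below is not from the patch itself; the `norm_type` value is hypothetical and only illustrates that the broken form passes silently while the corrected form raises.

```python
norm_type = "invalid"  # hypothetical value that should trip the check

# Broken: the outer parentheses turn (condition, message) into a non-empty
# tuple, which is always truthy, so this assert never raises.
assert (
    norm_type in ['bn', 'sync_bn', 'gn'],
    "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".format(norm_type))
print("broken assert passed silently")

# Fixed: condition and message are separate operands of assert; a backslash
# continuation replaces the wrapping parentheses, so the check actually fires.
try:
    assert norm_type in ['bn', 'sync_bn', 'gn'], \
        "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".format(norm_type)
except AssertionError as e:
    print("fixed assert raised:", e)
```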
@@ -48,10 +48,8 @@ class ConvNormLayer(nn.Layer):
         self.act = act
         norm_lr = 0. if freeze_norm else 1.
         if norm_type is not None:
-            assert (
-                norm_type in ['bn', 'sync_bn', 'gn'],
-                "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".
-                format(norm_type))
+            assert norm_type in ['bn', 'sync_bn', 'gn'], \
+                "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".format(norm_type)
             param_attr = ParamAttr(
                 initializer=Constant(1.0),
                 learning_rate=norm_lr,
@@ -277,10 +275,8 @@ class ShuffleUnit(nn.Layer):
         branch_channel = out_channel // 2
         self.stride = stride
         if self.stride == 1:
-            assert (
-                in_channel == branch_channel * 2,
-                "when stride=1, in_channel {} should equal to branch_channel*2 {}"
-                .format(in_channel, branch_channel * 2))
+            assert in_channel == branch_channel * 2, \
+                "when stride=1, in_channel {} should equal to branch_channel*2 {}".format(in_channel, branch_channel * 2)
         if stride > 1:
             self.branch1 = nn.Sequential(
                 ConvNormLayer(
@@ -500,11 +496,11 @@ class LiteHRNetModule(nn.Layer):
                  freeze_norm=False,
                  norm_decay=0.):
         super(LiteHRNetModule, self).__init__()
-        assert (num_branches == len(in_channels),
-                "num_branches {} should equal to num_in_channels {}"
-                .format(num_branches, len(in_channels)))
-        assert (module_type in ['LITE', 'NAIVE'],
-                "module_type should be one of ['LITE', 'NAIVE']")
+        assert num_branches == len(in_channels),\
+            "num_branches {} should equal to num_in_channels {}".format(num_branches, len(in_channels))
+        assert module_type in [
+            'LITE', 'NAIVE'
+        ], "module_type should be one of ['LITE', 'NAIVE']"
         self.num_branches = num_branches
         self.in_channels = in_channels
         self.multiscale_output = multiscale_output
@@ -699,10 +695,8 @@ class LiteHRNet(nn.Layer):
         super(LiteHRNet, self).__init__()
         if isinstance(return_idx, Integral):
             return_idx = [return_idx]
-        assert (
-            network_type in ["lite_18", "lite_30", "naive", "wider_naive"],
+        assert network_type in ["lite_18", "lite_30", "naive", "wider_naive"], \
             "the network_type should be one of [lite_18, lite_30, naive, wider_naive]"
-        )
         assert len(return_idx) > 0, "need one or more return index"
         self.freeze_at = freeze_at
         self.freeze_norm = freeze_norm
@@ -1592,8 +1592,7 @@ def smooth_l1(input, label, inside_weight=None, outside_weight=None,
 
 def channel_shuffle(x, groups):
     batch_size, num_channels, height, width = x.shape[0:4]
-    assert (num_channels % groups == 0,
-            'num_channels should be divisible by groups')
+    assert num_channels % groups == 0, 'num_channels should be divisible by groups'
     channels_per_group = num_channels // groups
     x = paddle.reshape(
         x=x, shape=[batch_size, groups, channels_per_group, height, width])