提交 9e55291d 编写于 作者: Z Zhang Ting 提交者: Aurelius84

[cherry-pick] fix bias_attr's bug of conv and conv_transpose, test=release/1.6 (#20704) (#20716)

上级 c2f86f95
......@@ -2801,7 +2801,10 @@ def conv2d(input,
"data_format": data_format,
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
if data_format == 'NCHW':
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
else:
pre_act = helper.append_bias_op(pre_bias, dim_start=3, dim_end=4)
return helper.append_activation(pre_act)
......@@ -3049,7 +3052,10 @@ def conv3d(input,
"data_format": data_format,
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
if data_format == 'NCDHW':
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
else:
pre_act = helper.append_bias_op(pre_bias, dim_start=4, dim_end=5)
return helper.append_activation(pre_act)
......@@ -5148,7 +5154,10 @@ def conv2d_transpose(input,
'data_format': data_format
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
if data_format == 'NCHW':
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
else:
pre_act = helper.append_bias_op(pre_bias, dim_start=3, dim_end=4)
out = helper.append_activation(pre_act)
return out
......@@ -5423,7 +5432,10 @@ def conv3d_transpose(input,
'data_format': data_format
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
if data_format == 'NCHW':
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
else:
pre_act = helper.append_bias_op(pre_bias, dim_start=4, dim_end=5)
out = helper.append_activation(pre_act)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册