From 80c97e560dbc2e687ccc3c57882c6b86ea158bac Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Fri, 18 Oct 2019 14:52:36 +0800 Subject: [PATCH] fix bias_attr's bug of conv and conv_transpose, test=develop (#20704) --- python/paddle/fluid/layers/nn.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) mode change 100644 => 100755 python/paddle/fluid/layers/nn.py diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py old mode 100644 new mode 100755 index fd9af1a015b..dd657d2c1d8 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2801,7 +2801,10 @@ def conv2d(input, "data_format": data_format, }) - pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + if data_format == 'NCHW': + pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + else: + pre_act = helper.append_bias_op(pre_bias, dim_start=3, dim_end=4) return helper.append_activation(pre_act) @@ -3049,7 +3052,10 @@ def conv3d(input, "data_format": data_format, }) - pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + if data_format == 'NCDHW': + pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + else: + pre_act = helper.append_bias_op(pre_bias, dim_start=4, dim_end=5) return helper.append_activation(pre_act) @@ -5148,7 +5154,10 @@ def conv2d_transpose(input, 'data_format': data_format }) - pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + if data_format == 'NCHW': + pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + else: + pre_act = helper.append_bias_op(pre_bias, dim_start=3, dim_end=4) out = helper.append_activation(pre_act) return out @@ -5423,7 +5432,10 @@ def conv3d_transpose(input, 'data_format': data_format }) - pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + if data_format == 'NCHW': + pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + else: + pre_act = helper.append_bias_op(pre_bias, dim_start=4, dim_end=5) out = helper.append_activation(pre_act) return out -- GitLab