From 6c0129af951d3b209300d3635b5cb934f03ab3bb Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Tue, 19 Sep 2017 11:15:29 +0800
Subject: [PATCH] Refine the GemmConvGrad2DKernel.

---
 paddle/operators/gemm_conv2d_op.h | 69 ++++++++++++++-----------------
 1 file changed, 32 insertions(+), 37 deletions(-)

diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h
index 96f4c06005..08b7df1dfe 100644
--- a/paddle/operators/gemm_conv2d_op.h
+++ b/paddle/operators/gemm_conv2d_op.h
@@ -109,18 +109,13 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad_ =
+    Tensor* filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));
 
     // The filter and filter_grad will be reshaped in the calculations,
     // so here use an assignment operation,
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor filter_grad;
-    if (filter_grad_) {
-      filter_grad_->mutable_data<T>(context.GetPlace());
-      filter_grad = *filter_grad_;
-    }
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
@@ -165,20 +160,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
                                    filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
 
-    if (filter_grad_) {
-      filter_grad.Resize(filter_matrix_shape);
-      auto t1 = framework::EigenVector<T>::Flatten(filter_grad);
-      t1.device(context.GetEigenDevice<Place>()) =
-          t1.constant(static_cast<T>(0));
-    }
-
-    if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
-      t2.device(context.GetEigenDevice<Place>()) =
-          t2.constant(static_cast<T>(0));
-    }
-
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
 
@@ -186,22 +167,21 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
 
     // convolution backward weight operator: im2col + gemm
     int in_step = input_channels / groups;
     int out_step = output_channels / groups;
-    Tensor in_grad_batch;
-    Tensor in_batch;
-    for (int i = 0; i < batch_size; i++) {
-      Tensor out_grad_batch =
-          output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
-      if (input_grad) {
-        in_grad_batch = input_grad->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      if (filter_grad_) {
-        in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      for (int g = 0; g < groups; g++) {
-        Tensor out_grad_slice =
-            out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
-        if (input_grad) {
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_grad_batch =
+            input_grad->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {  // gemm
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
           Tensor filter_slice =
               filter.Slice<T>(g * out_step, (g + 1) * out_step);
           math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
@@ -213,16 +193,31 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
           col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
                  paddings[1], device_context);
         }
+      }
+    }
 
-        if (filter_grad_) {
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {  // im2col
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
           Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
           im2col(in_slice, col, strides[0], strides[1], paddings[0],
                  paddings[1], device_context);
 
           // gemm
           Tensor filter_grad_slice =
-              filter_grad.Slice<T>(g * out_step, (g + 1) * out_step);
+              filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
           math::matmul<Place, T>(out_grad_slice, false, col_matrix, true,
                                  T(1.0), &filter_grad_slice, T(1.0),
                                  device_context);
         }
-- 
GitLab
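
For readers outside the Paddle tree, below is a minimal standalone sketch of the two-pass structure this patch arrives at: the input gradient is computed as gemm + col2im and the filter gradient as im2col + gemm, each in its own null-guarded pass that zero-initializes its result, so the per-iteration branching of the old loop disappears and either gradient can be skipped entirely. The sketch assumes a single batch, a single group, stride 1, and no padding; the names (Conv2DGrad, im2col, col2im, gemm) are hypothetical stand-ins, not the Paddle API.

#include <algorithm>
#include <cstdio>
#include <vector>

// col buffer layout: (C*K*K) x (Ho*Wo); stride 1, no padding.
static void im2col(const std::vector<float>& im, std::vector<float>& col,
                   int C, int H, int W, int K) {
  int Ho = H - K + 1, Wo = W - K + 1;
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < K; ++kh)
      for (int kw = 0; kw < K; ++kw)
        for (int ho = 0; ho < Ho; ++ho)
          for (int wo = 0; wo < Wo; ++wo)
            col[(((c * K + kh) * K + kw) * Ho + ho) * Wo + wo] =
                im[(c * H + kh + ho) * W + kw + wo];
}

// Adjoint of im2col: scatter-add the col buffer back into image layout.
static void col2im(std::vector<float>& im, const std::vector<float>& col,
                   int C, int H, int W, int K) {
  int Ho = H - K + 1, Wo = W - K + 1;
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < K; ++kh)
      for (int kw = 0; kw < K; ++kw)
        for (int ho = 0; ho < Ho; ++ho)
          for (int wo = 0; wo < Wo; ++wo)
            im[(c * H + kh + ho) * W + kw + wo] +=
                col[(((c * K + kh) * K + kw) * Ho + ho) * Wo + wo];
}

// Naive reference gemm: C(m x n) += op(A)(m x k) * op(B)(k x n).
static void gemm(bool transA, bool transB, int m, int n, int k,
                 const std::vector<float>& A, const std::vector<float>& B,
                 std::vector<float>& C) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float s = 0.f;
      for (int p = 0; p < k; ++p)
        s += (transA ? A[p * m + i] : A[i * k + p]) *
             (transB ? B[j * k + p] : B[p * n + j]);
      C[i * n + j] += s;
    }
}

// Two independent, null-guarded passes, mirroring the refactored kernel.
// input: C x H x W; filter: Cout x (C*K*K); out_grad: Cout x (Ho*Wo).
static void Conv2DGrad(const std::vector<float>& input,
                       const std::vector<float>& filter,
                       const std::vector<float>& out_grad,
                       std::vector<float>* input_grad,
                       std::vector<float>* filter_grad,
                       int C, int H, int W, int Cout, int K) {
  int Ho = H - K + 1, Wo = W - K + 1;
  int col_rows = C * K * K, col_cols = Ho * Wo;
  std::vector<float> col(col_rows * col_cols);

  if (input_grad) {  // gemm + col2im
    input_grad->assign(C * H * W, 0.f);
    std::fill(col.begin(), col.end(), 0.f);
    // col_grad (col_rows x col_cols) = filter^T * out_grad
    gemm(true, false, col_rows, col_cols, Cout, filter, out_grad, col);
    col2im(*input_grad, col, C, H, W, K);
  }

  if (filter_grad) {  // im2col + gemm
    filter_grad->assign(Cout * col_rows, 0.f);
    im2col(input, col, C, H, W, K);  // overwrites every element of col
    // filter_grad (Cout x col_rows) = out_grad * col^T
    gemm(false, true, Cout, col_rows, col_cols, out_grad, col, *filter_grad);
  }
}

int main() {
  // Toy sizes: 1 input channel, 4x4 image, 2 output channels, 3x3 kernel.
  int C = 1, H = 4, W = 4, Cout = 2, K = 3;
  std::vector<float> input(C * H * W, 1.f);
  std::vector<float> filter(Cout * C * K * K, 0.5f);
  std::vector<float> out_grad(Cout * (H - K + 1) * (W - K + 1), 1.f);
  std::vector<float> d_input, d_filter;
  Conv2DGrad(input, filter, out_grad, &d_input, &d_filter, C, H, W, Cout, K);
  std::printf("d_input[0] = %g, d_filter[0] = %g\n", d_input[0], d_filter[0]);
  return 0;
}

Splitting the passes also means the im2col buffer is only filled when the filter gradient is actually requested, which is what the separate if (input_grad) / if (filter_grad) guards in the patched kernel achieve.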