diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h
index 560dfd311f25d8411be41f77c62735a4f3330ec8..cdcc0039b0fa8f87b3d24abba3a1fb9004cc96bc 100644
--- a/paddle/operators/gemm_conv_op.h
+++ b/paddle/operators/gemm_conv_op.h
@@ -68,7 +68,7 @@ class GemmConvKernel : public framework::OpKernel {
     framework::DDim input_shape = {input->dims()[1], input->dims()[2],
                                    input->dims()[3]};
     framework::DDim filter_matrix_shape = {
-        output_channels, framework::product(filter.dims()) / output_channels};
+        filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
 
     framework::DDim output_matrix_shape = {output_channels,
@@ -99,24 +99,28 @@ class GemmConvGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
-    Tensor* filter = const_cast<Tensor*>(context.Input<Tensor>("Filter"));
     const Tensor* output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad =
+    Tensor* filter_grad_ =
         context.Output<Tensor>(framework::GradVarName("Filter"));
     input_grad->mutable_data<T>(context.GetPlace());
-    filter_grad->mutable_data<T>(context.GetPlace());
+    filter_grad_->mutable_data<T>(context.GetPlace());
+
+    // The filter and filter_grad will be reshaped in the calculations,
+    // so here use an assignment operation,
+    // that avoids modifying the variable in the Scope.
+    Tensor filter = *context.Input<Tensor>("Filter");
+    Tensor filter_grad = *filter_grad_;
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
 
-    auto filter_dims = filter->dims();
     int batch_size = input->dims()[0];
     int input_channels = input->dims()[1];
-    int filter_height = filter->dims()[filter->dims().size() - 2];
-    int filter_width = filter->dims()[filter->dims().size() - 1];
+    int filter_height = filter.dims()[filter.dims().size() - 2];
+    int filter_width = filter.dims()[filter.dims().size() - 1];
     int output_height = output_grad->dims()[2];
     int output_width = output_grad->dims()[3];
 
@@ -126,64 +130,65 @@ class GemmConvGradKernel : public framework::OpKernel {
     paddle::operators::math::Im2ColFunctor<
         paddle::operators::math::ColFormat::kCFO, Place, T>
         im2col;
-    Tensor col;
+
+    // use col_shape in the im2col and col2im calculation
     framework::DDim col_shape = {input_channels, filter_height, filter_width,
                                  output_height, output_width};
+    // use col_matrix_shape in the gemm calculation
+    framework::DDim col_matrix_shape = {
+        input_channels * filter_height * filter_width,
+        output_height * output_width};
+    Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
-
-    auto* device_context =
-        const_cast<platform::DeviceContext*>(context.device_context_);
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
+    Tensor col_matrix = col;
+    col_matrix.Resize(col_matrix_shape);
 
     framework::DDim input_shape = {input->dims()[1], input->dims()[2],
                                    input->dims()[3]};
-    framework::DDim filter_matrix_shape = {
-        filter->dims()[0],
-        filter->dims()[1] * filter->dims()[2] * filter->dims()[3]};
-    framework::DDim col_matrix_shape = {
-        input_channels * filter_height * filter_width,
-        output_height * output_width};
     framework::DDim output_matrix_shape = {
         output_grad->dims()[1],
         output_grad->dims()[2] * output_grad->dims()[3]};
-    filter->Resize(filter_matrix_shape);
-    filter_grad->Resize(filter_matrix_shape);
-    auto t1 = framework::EigenVector<T>::Flatten(*filter_grad);
+    framework::DDim filter_matrix_shape = {
+        filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]};
+    filter.Resize(filter_matrix_shape);
+    filter_grad.Resize(filter_matrix_shape);
+
+    auto t1 = framework::EigenVector<T>::Flatten(filter_grad);
     t1.device(context.GetEigenDevice<Place>()) = t1.constant(static_cast<T>(0));
     auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
     t2.device(context.GetEigenDevice<Place>()) = t2.constant(static_cast<T>(0));
 
+    auto* device_context =
+        const_cast<platform::DeviceContext*>(context.device_context_);
+
     // convolution backward input operator:  gemm + col2im
     // convolution backward weight operator: im2col + gemm
     for (int i = 0; i < batch_size; i++) {
       // gemm
       Tensor out_slice = output_grad->Slice<T>(i, i + 1);
       out_slice.Resize(output_matrix_shape);
-      col.Resize(col_matrix_shape);
-      math::matmul<Place, T>(*filter, true, out_slice, false, T(1.0), &col,
-                             T(0.0), device_context);
+      math::matmul<Place, T>(filter, true, out_slice, false, T(1.0),
+                             &col_matrix, T(0.0), device_context);
 
       // col2im
       Tensor in_grad_slice = input_grad->Slice<T>(i, i + 1);
       in_grad_slice.Resize(input_shape);
-      col.Resize(col_shape);
       col2im(in_grad_slice, col, strides[0], strides[1],
              paddings[0], paddings[1], device_context);
 
       // im2col
       Tensor in_slice = input->Slice<T>(i, i + 1);
       in_slice.Resize(input_shape);
-      col.Resize(col_shape);
       im2col(in_slice, col, strides[0], strides[1],
              paddings[0], paddings[1], device_context);
 
       // gemm
-      col.Resize(col_matrix_shape);
-      math::matmul<Place, T>(out_slice, false, col, true, T(1.0), filter_grad,
-                             T(1.0), device_context);
+      math::matmul<Place, T>(out_slice, false, col_matrix, true, T(1.0),
+                             &filter_grad, T(1.0), device_context);
     }
-
-    filter->Resize(filter_dims);
-    filter_grad->Resize(filter_dims);
   }
 };
 