From 8219f20672dcb660174ab9c96f54d7214f248f7a Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 11:01:24 +0800 Subject: [PATCH] Refine gemm convolution kernel. --- paddle/operators/gemm_conv_op.h | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index cdcc0039b0..3b7ba685c8 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -58,7 +58,7 @@ class GemmConvKernel : public framework::OpKernel { input_channels * filter_height * filter_width, output_height * output_width}; Tensor col; - col.mutable_data(col_shape, context.GetPlace()); + col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. @@ -67,8 +67,8 @@ class GemmConvKernel : public framework::OpKernel { framework::DDim input_shape = {input->dims()[1], input->dims()[2], input->dims()[3]}; - framework::DDim filter_matrix_shape = { - filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]}; + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = {output_channels, @@ -80,14 +80,12 @@ class GemmConvKernel : public framework::OpKernel { // convolution operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // im2col - Tensor in_slice = input->Slice(i, i + 1); - in_slice.Resize(input_shape); + Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm - Tensor out_slice = output->Slice(i, i + 1); - out_slice.Resize(output_matrix_shape); + Tensor out_slice = output->Slice(i, i + 1).Resize(output_matrix_shape); math::matmul(filter, false, col_matrix, false, T(1.0), &out_slice, T(0.0), device_context); } @@ -138,7 +136,7 @@ class GemmConvGradKernel : public framework::OpKernel { input_channels * filter_height * filter_width, output_height * output_width}; Tensor col; - col.mutable_data(col_shape, context.GetPlace()); + col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. @@ -151,8 +149,8 @@ class GemmConvGradKernel : public framework::OpKernel { output_grad->dims()[1], output_grad->dims()[2] * output_grad->dims()[3]}; - framework::DDim filter_matrix_shape = { - filter.dims()[0], framework::product(filter.dims()) / filter.dims()[0]}; + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); filter_grad.Resize(filter_matrix_shape); @@ -168,20 +166,18 @@ class GemmConvGradKernel : public framework::OpKernel { // convolution backward weight operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // gemm - Tensor out_slice = output_grad->Slice(i, i + 1); - out_slice.Resize(output_matrix_shape); + Tensor out_slice = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); math::matmul(filter, true, out_slice, false, T(1.0), &col_matrix, T(0.0), device_context); // col2im - Tensor in_grad_slice = input_grad->Slice(i, i + 1); - in_grad_slice.Resize(input_shape); + Tensor in_grad_slice = input_grad->Slice(i, i + 1).Resize(input_shape); col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // im2col - Tensor in_slice = input->Slice(i, i + 1); - in_slice.Resize(input_shape); + Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); -- GitLab