diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index 6c72362195b788c7b59ea55916afd21f792654ce..560dfd311f25d8411be41f77c62735a4f3330ec8 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -29,61 +29,68 @@ class GemmConvKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); - Tensor* filter = const_cast(context.Input("Filter")); + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + Tensor filter = *context.Input("Filter"); Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - auto filter_dims = filter->dims(); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; - int filter_height = filter->dims()[filter->dims().size() - 2]; - int filter_width = filter->dims()[filter->dims().size() - 1]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output->dims()[1]; int output_height = output->dims()[2]; int output_width = output->dims()[3]; paddle::operators::math::Im2ColFunctor< paddle::operators::math::ColFormat::kCFO, Place, T> im2col; + // use col_shape in the im2col calculation framework::DDim col_shape = {input_channels, filter_height, filter_width, output_height, output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels * filter_height * filter_width, + output_height * output_width}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); - - auto* device_context = - const_cast(context.device_context_); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); framework::DDim input_shape = {input->dims()[1], input->dims()[2], input->dims()[3]}; framework::DDim filter_matrix_shape = { - filter->dims()[0], - filter->dims()[1] * filter->dims()[2] * filter->dims()[3]}; - framework::DDim col_matrix_shape = { - input_channels * filter_height * filter_width, - output_height * output_width}; - framework::DDim output_matrix_shape = { - output->dims()[1], output->dims()[2] * output->dims()[3]}; - filter->Resize(filter_matrix_shape); + output_channels, framework::product(filter.dims()) / output_channels}; + filter.Resize(filter_matrix_shape); + + framework::DDim output_matrix_shape = {output_channels, + output_height * output_width}; + + auto* device_context = + const_cast(context.device_context_); // convolution operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // im2col Tensor in_slice = input->Slice(i, i + 1); in_slice.Resize(input_shape); - col.Resize(col_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm Tensor out_slice = output->Slice(i, i + 1); out_slice.Resize(output_matrix_shape); - col.Resize(col_matrix_shape); - math::matmul(*filter, false, col, false, T(1.0), &out_slice, - T(0.0), device_context); + math::matmul(filter, false, col_matrix, false, T(1.0), + &out_slice, T(0.0), device_context); } - filter->Resize(filter_dims); } };