diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h
index 8ac92d3bd2a99617a1309c993749b069ea90eb36..b125698c6de22647ea59fce5ca2e713bb6b28bbb 100644
--- a/paddle/operators/gemm_conv_op.h
+++ b/paddle/operators/gemm_conv_op.h
@@ -82,19 +82,16 @@ class GemmConvKernel : public framework::OpKernel {
     int in_step = input_channels / groups;
     int out_step = output_channels / groups;
     for (int i = 0; i < batch_size; i++) {
-      Tensor in_slice_batch = input->Slice(i, i + 1).Resize(input_shape);
-      Tensor out_slice_batch =
-          output->Slice(i, i + 1).Resize(output_matrix_shape);
+      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
       for (int g = 0; g < groups; g++) {
         // im2col
-        Tensor in_slice =
-            in_slice_batch.Slice(g * in_step, (g + 1) * in_step);
+        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
         im2col(in_slice, col, strides[0], strides[1], paddings[0],
                paddings[1], device_context);
 
         // gemm
-        Tensor out_slice =
-            out_slice_batch.Slice(g * out_step, (g + 1) * out_step);
+        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
         Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
         math::matmul<Place, T>(filter_slice, false, col_matrix, false, T(1.0),
                                &out_slice, T(0.0), device_context);
@@ -125,12 +122,13 @@ class GemmConvGradKernel : public framework::OpKernel {
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    // int groups = context.Attr<int>("groups");
+    int groups = context.Attr<int>("groups");
 
     int batch_size = input->dims()[0];
     int input_channels = input->dims()[1];
     int filter_height = filter.dims()[filter.dims().size() - 2];
     int filter_width = filter.dims()[filter.dims().size() - 1];
+    int output_channels = output_grad->dims()[1];
     int output_height = output_grad->dims()[2];
     int output_width = output_grad->dims()[3];
 
@@ -141,11 +139,11 @@ class GemmConvGradKernel : public framework::OpKernel {
         paddle::operators::math::ColFormat::kCFO, Place, T>
         im2col;
     // use col_shape in the im2col and col2im calculation
-    framework::DDim col_shape = {input_channels, filter_height, filter_width,
-                                 output_height, output_width};
+    framework::DDim col_shape = {input_channels / groups, filter_height,
+                                 filter_width, output_height, output_width};
     // use col_matrix_shape in the gemm calculation
     framework::DDim col_matrix_shape = {
-        input_channels * filter_height * filter_width,
+        input_channels / groups * filter_height * filter_width,
         output_height * output_width};
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -176,26 +174,38 @@ class GemmConvGradKernel : public framework::OpKernel {
 
     // convolution backward input operator: gemm + col2im
     // convolution backward weight operator: im2col + gemm
+    int in_step = input_channels / groups;
+    int out_step = output_channels / groups;
     for (int i = 0; i < batch_size; i++) {
-      // gemm
-      Tensor out_slice =
+      Tensor out_grad_batch =
           output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-      math::matmul<Place, T>(filter, true, out_slice, false, T(1.0),
-                             &col_matrix, T(0.0), device_context);
-
-      // col2im
-      Tensor in_grad_slice = input_grad->Slice(i, i + 1).Resize(input_shape);
-      col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
-             paddings[1], device_context);
-
-      // im2col
-      Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape);
-      im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
-             device_context);
-
-      // gemm
-      math::matmul<Place, T>(out_slice, false, col_matrix, true, T(1.0),
-                             &filter_grad, T(1.0), device_context);
+      Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
+      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+      for (int g = 0; g < groups; g++) {
+        // gemm
+        Tensor out_grad_slice =
+            out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
+        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+        math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
+                               T(1.0), &col_matrix, T(0.0), device_context);
+
+        // col2im
+        Tensor in_grad_slice =
+            in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
+        col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
+               paddings[1], device_context);
+
+        // im2col
+        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+        im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
+               device_context);
+
+        // gemm
+        Tensor filter_grad_slice =
+            filter_grad.Slice(g * out_step, (g + 1) * out_step);
+        math::matmul<Place, T>(out_grad_slice, false, col_matrix, true, T(1.0),
+                               &filter_grad_slice, T(1.0), device_context);
+      }
     }
   }
 };
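
Note: the sketch below is a minimal standalone C++ reference for the channel-slicing arithmetic the diff introduces; it is not Paddle code, and the function name GroupedConv, the NCHW layout, stride 1 and zero padding are assumptions made purely for illustration. It spells out, with direct loops instead of im2col + GEMM, why each group reads in_step = input_channels / groups input channels and writes out_step = output_channels / groups output channels, and why each filter block only sees input_channels / groups channels, which is the reason col_shape and col_matrix_shape shrink by a factor of groups in the gradient kernel above.

#include <cstdio>
#include <vector>

// Naive grouped convolution for a single NCHW sample, stride 1, no padding.
// Shapes: input [C_in, H, W], filter [C_out, C_in/groups, KH, KW],
// output [C_out, H - KH + 1, W - KW + 1].
std::vector<float> GroupedConv(const std::vector<float>& input, int C_in,
                               int H, int W, const std::vector<float>& filter,
                               int C_out, int KH, int KW, int groups) {
  const int in_step = C_in / groups;    // input channels per group
  const int out_step = C_out / groups;  // output channels per group
  const int OH = H - KH + 1, OW = W - KW + 1;
  std::vector<float> output(C_out * OH * OW, 0.f);
  for (int g = 0; g < groups; ++g) {
    // Output channels of group g only read input channels of group g.
    for (int oc = g * out_step; oc < (g + 1) * out_step; ++oc) {
      for (int ic = 0; ic < in_step; ++ic) {
        const int in_c = g * in_step + ic;  // absolute input channel index
        for (int oh = 0; oh < OH; ++oh) {
          for (int ow = 0; ow < OW; ++ow) {
            for (int kh = 0; kh < KH; ++kh) {
              for (int kw = 0; kw < KW; ++kw) {
                output[(oc * OH + oh) * OW + ow] +=
                    input[(in_c * H + oh + kh) * W + ow + kw] *
                    filter[((oc * in_step + ic) * KH + kh) * KW + kw];
              }
            }
          }
        }
      }
    }
  }
  return output;
}

int main() {
  // 4 input channels split into 2 groups -> in_step = 2;
  // 6 output channels -> out_step = 3.
  const int C_in = 4, H = 5, W = 5, C_out = 6, KH = 3, KW = 3, groups = 2;
  std::vector<float> input(C_in * H * W, 1.f);
  std::vector<float> filter(C_out * (C_in / groups) * KH * KW, 1.f);
  std::vector<float> out =
      GroupedConv(input, C_in, H, W, filter, C_out, KH, KW, groups);
  // With all-ones tensors every output value is in_step * KH * KW = 18.
  std::printf("out[0] = %g, size = %zu\n", out[0], out.size());
  return 0;
}

Running this prints out[0] = 18 for the all-ones tensors, which is a quick sanity check that the per-group indexing (in_step, out_step, and the filter block of size C_in/groups) matches the slicing done by the kernels in the diff.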