diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h
index ad0e96f519c669be6b65e5bb34894c286d6ffaef..cc2cfe4e6e28af1fabd225d0229994a1d5e0165a 100644
--- a/paddle/operators/conv_transpose_op.h
+++ b/paddle/operators/conv_transpose_op.h
@@ -63,29 +63,25 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     const Tensor* input = context.Input<Tensor>("Input");
     // The filter will be reshaped, so it should not be constant pointer
     Tensor filter = *context.Input<Tensor>("Filter");
-
     Tensor* output = context.Output<Tensor>("Output");
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-
     // TODO(Zhuoyuan): Paddings can be added in future.
     // groups will alway be disabled in conv2dtranspose.
 
-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t h = input->dims()[2];
+    const int64_t w = input->dims()[3];
 
-    const int k_h = filter.dims()[2];
-    const int k_w = filter.dims()[3];
+    const int64_t k_h = filter.dims()[2];
+    const int64_t k_w = filter.dims()[3];
 
-    const int c = output->dims()[1];  // output channels
-    const int o_h = output->dims()[2];
-    const int o_w = output->dims()[3];
+    const int64_t c = output->dims()[1];  // output channels
+    const int64_t o_h = output->dims()[2];
+    const int64_t o_w = output->dims()[3];
 
-    paddle::operators::math::Col2ImFunctor<
-        paddle::operators::math::ColFormat::kCFO, Place, T>
-        col2im;
+    math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
 
     // use col_shape in the im2col and col2im calculation
     DDim col_shape = {c, k_h, k_w, h, w};
@@ -105,19 +101,18 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     DDim output_shape = {c, o_h, o_w};
     DDim input_matrix_shape = {m, h * w};
 
+    // filter size: (m, c * k_h * k_w)
     DDim filter_matrix_shape = {m, c * k_h * k_w};
     filter.Resize(filter_matrix_shape);
 
-    // convolution transpose: gemm + col2im (similar to conv-backward on input)
-
     output->mutable_data<T>(context.GetPlace());
-    auto t = framework::EigenVector<T>::Flatten(*output);
-    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+    math::SetConstant<Place, T> set_zero;
+    set_zero(context.device_context(), output, static_cast<T>(0));
 
+    // convolution transpose: gemm + col2im (similar to conv-backward on input)
     for (int i = 0; i < batch_size; i++) {
-      // batch with size (M, h * w)
+      // batch with size (m, h * w)
       Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-      // filter size: (M, c * k_h * k_w)
 
       // output size: (c, o_h, o_w)
       Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
@@ -125,7 +120,11 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
       // col_matrix = filter * input_batch
       // of shape (c * k_h * k_w, h * w)
       math::matmul<Place, T>(context.device_context(), filter, true,
-                             input_batch, false, T(1.0), &col_matrix, T(0.0));
+                             input_batch, false, static_cast<T>(1.0),
+                             &col_matrix, static_cast<T>(0.0));
+
+      // col2im: col_matrix -> dy
+      // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
       col2im(context.device_context(), output_batch, col, strides[0],
              strides[1], 0, 0, 0, 0);
     }
@@ -143,7 +142,6 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     // For filter, we do not use const pointer b/c we will do reshape,
     // but we should avoid modifying its value.
     Tensor filter = *context.Input<Tensor>("Filter");
-
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
     Tensor* filter_grad =
@@ -153,35 +151,24 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
 
-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t h = input->dims()[2];
+    const int64_t w = input->dims()[3];
 
-    const int k_h = filter.dims()[2];
-    const int k_w = filter.dims()[3];
+    const int64_t k_h = filter.dims()[2];
+    const int64_t k_w = filter.dims()[3];
 
-    const int c = output_grad->dims()[1];  // output channels
-    const int o_h = output_grad->dims()[2];
-    const int o_w = output_grad->dims()[3];
+    const int64_t c = output_grad->dims()[1];  // output channels
+    const int64_t o_h = output_grad->dims()[2];
+    const int64_t o_w = output_grad->dims()[3];
 
     // Only im2col functor required for bp to get to the right shape
-    paddle::operators::math::Im2ColFunctor<
-        paddle::operators::math::ColFormat::kCFO, Place, T>
-        im2col;
+    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
 
     // use col_shape in the im2col and col2im calculation
     DDim col_shape = {c, k_h, k_w, h, w};
 
-    // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
-
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-
     DDim output_shape = {c, o_h, o_w};
     DDim input_matrix_shape = {m, h * w};
 
@@ -191,67 +178,60 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     // convolution transpose grad on input:
     // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
-    if (input_grad) {
+    if (input_grad || filter_grad) {
+      Tensor col;
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      // col_matrix shares the same piece of data with col,
+      // but will be reshaped into a two-dimensional matrix shape
+      // to call the matrix multiplication interface.
       Tensor col_matrix;
       col_matrix.ShareDataWith(col);
       DDim col_matrix_shape = {c * k_h * k_w, h * w};
       col_matrix.Resize(col_matrix_shape);
 
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*input_grad);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; i++) {
-        // batch with size (c, o_h * o_w)
-        Tensor output_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_shape);
-        // filter of size (m, c * k_h * k_w)
-
-        // batch with size (m, h, w)
-        Tensor input_grad_batch =
-            input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
+      Tensor filter_grad_;
+      math::SetConstant<Place, T> set_zero;
 
-        // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
-        im2col(context.device_context(), output_grad_batch, col, strides[0],
-               strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
-
-        // gemm: dx = filter * dy
-        // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
-        math::matmul<Place, T>(context.device_context(), filter, false,
-                               col_matrix, false, T(1.0), &input_grad_batch,
-                               T(0.0));
+      if (input_grad) {
+        input_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), input_grad, static_cast<T>(0));
+      }
+      if (filter_grad) {  // filter size (m, c, k_h, k_w)
+        filter_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), filter_grad, static_cast<T>(0));
+        filter_grad_ = *filter_grad;
+        filter_grad_.Resize(filter_matrix_shape);
       }
-    }
 
-    // filter gradient required
-    if (filter_grad) {
-      Tensor col_matrix_f;
-      col_matrix_f.ShareDataWith(col);
-      DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
-      col_matrix_f.Resize(col_matrix_shape_f);
-
-      filter_grad->mutable_data<T>(context.GetPlace());
-      Tensor filter_grad_ = *filter_grad;
-      filter_grad_.Resize(filter_matrix_shape);
-      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; ++i) {
-        // batch with size (c, o_h, o_w)
+      for (int i = 0; i < batch_size; i++) {
+        // batch with size (c, o_h * o_w)
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
-        // input batch
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
 
-        // im2col: (c * h * w, k_h * k_w)
+        // im2col: dy -> col matrix
+        // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
         im2col(context.device_context(), output_grad_batch, col, strides[0],
                strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
 
-        // gemm: d_filter = x * y_grad^T
-        // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
-        math::matmul<Place, T>(context.device_context(), in_batch, false,
-                               col_matrix_f, true, T(1.0), &filter_grad_,
-                               T(1.0));
+        if (input_grad) {
+          // batch with size (m, h, w)
+          Tensor input_grad_batch =
+              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: dx = filter * dy
+          // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
+          math::matmul<Place, T>(context.device_context(), filter, false,
+                                 col_matrix, false, static_cast<T>(1.0),
+                                 &input_grad_batch, static_cast<T>(0.0));
+        }
+        if (filter_grad) {
+          // input batch
+          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: d_filter = x * dy^T
+          // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w)
+          math::matmul<Place, T>(context.device_context(), in_batch, false,
                                 col_matrix, true, static_cast<T>(1.0),
+                                 &filter_grad_, static_cast<T>(1.0));
+        }
      }
     }
   }
@@ -267,30 +247,28 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
     Tensor* output = context.Output<Tensor>("Output");
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-
     // TODO(chengduo): Paddings can be added in future.
     // groups will alway be disabled in conv3dtranspose.
 
-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int d = input->dims()[2];
-    const int h = input->dims()[3];
-    const int w = input->dims()[4];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t d = input->dims()[2];
+    const int64_t h = input->dims()[3];
+    const int64_t w = input->dims()[4];
 
-    const int k_d = filter.dims()[2];
-    const int k_h = filter.dims()[3];
-    const int k_w = filter.dims()[4];
+    const int64_t k_d = filter.dims()[2];
+    const int64_t k_h = filter.dims()[3];
+    const int64_t k_w = filter.dims()[4];
 
-    const int c = output->dims()[1];  // output channels
-    const int o_d = output->dims()[2];
-    const int o_h = output->dims()[3];
-    const int o_w = output->dims()[4];
+    const int64_t c = output->dims()[1];  // output channels
+    const int64_t o_d = output->dims()[2];
+    const int64_t o_h = output->dims()[3];
+    const int64_t o_w = output->dims()[4];
 
     paddle::operators::math::Col2VolFunctor<Place, T> col2vol;
 
     // use col_shape in the vol2col and col2vol calculation
     DDim col_shape = {c, k_d, k_h, k_w, d, h, w};
-
     // use col_matrix_shape in the gemm calculation
     DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
 
@@ -306,19 +284,18 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
     DDim output_shape = {c, o_d, o_h, o_w};
     DDim input_matrix_shape = {m, d * h * w};
 
+    // filter size: (m, c * k_d * k_h * k_w)
     DDim filter_matrix_shape = {m, c * k_d * k_h * k_w};
     filter.Resize(filter_matrix_shape);
 
-    // convolution transpose: gemm + col2vol (similar to conv-backward on input)
-
     output->mutable_data<T>(context.GetPlace());
-    auto t = framework::EigenVector<T>::Flatten(*output);
-    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+    math::SetConstant<Place, T> set_zero;
+    set_zero(context.device_context(), output, static_cast<T>(0));
 
+    // convolution transpose: gemm + col2vol (similar to conv-backward on input)
     for (int i = 0; i < batch_size; i++) {
-      // batch with size (M, d * h * w)
+      // batch with size (m, d * h * w)
       Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-      // filter size: (M, c * k_d * k_h * k_w)
 
       // output size: (c, o_d, o_h, o_w)
       Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
@@ -326,7 +303,10 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
       // col_matrix = filter * input_batch
       // of shape (c * k_d * k_h * k_w, d * h * w)
       math::matmul<Place, T>(context.device_context(), filter, true,
-                             input_batch, false, T(1.0), &col_matrix, T(0.0));
+                             input_batch, false, static_cast<T>(1.0),
+                             &col_matrix, static_cast<T>(0.0));
+      // col2vol: col_matrix -> dy
+      // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
       col2vol(context.device_context(), output_batch, col, strides[0],
               strides[1], strides[2], 0, 0, 0);
     }
@@ -344,7 +324,6 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // For filter, we do not use const pointer b/c we will do reshape,
     // but we should avoid modifying its value.
     Tensor filter = *context.Input<Tensor>("Filter");
-
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
     Tensor* filter_grad =
@@ -354,20 +333,20 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
 
-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int d = input->dims()[2];
-    const int h = input->dims()[3];
-    const int w = input->dims()[4];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t d = input->dims()[2];
+    const int64_t h = input->dims()[3];
+    const int64_t w = input->dims()[4];
 
-    const int k_d = filter.dims()[2];
-    const int k_h = filter.dims()[3];
-    const int k_w = filter.dims()[4];
+    const int64_t k_d = filter.dims()[2];
+    const int64_t k_h = filter.dims()[3];
+    const int64_t k_w = filter.dims()[4];
 
-    const int c = output_grad->dims()[1];  // output channels
-    const int o_d = output_grad->dims()[2];
-    const int o_h = output_grad->dims()[3];
-    const int o_w = output_grad->dims()[4];
+    const int64_t c = output_grad->dims()[1];  // output channels
+    const int64_t o_d = output_grad->dims()[2];
+    const int64_t o_h = output_grad->dims()[3];
+    const int64_t o_w = output_grad->dims()[4];
 
     // Only vol2col functor required for bp to get to the right shape
     paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
@@ -378,12 +357,6 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // use col_matrix_shape in the gemm calculation
     DDim col_matrix_shape_f = {c * d * h * w, k_d * k_h * k_w};
 
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-
     DDim output_shape = {c, o_d, o_h, o_w};
     DDim input_matrix_shape = {m, d * h * w};
 
@@ -393,70 +366,62 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // convolution transpose grad on input:
     // vol2col + gemm (similar to conv-forward)
     // input need to compute gradient
-    if (input_grad) {
+    if (input_grad || filter_grad) {
+      Tensor col;
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      // col_matrix shares the same piece of data with col,
+      // but will be reshaped into a two-dimensional matrix shape
+      // to call the matrix multiplication interface.
       Tensor col_matrix;
       col_matrix.ShareDataWith(col);
       DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
       col_matrix.Resize(col_matrix_shape);
 
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*input_grad);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+      Tensor filter_grad_;
+      math::SetConstant<Place, T> set_zero;
 
-      for (int i = 0; i < batch_size; i++) {
-        // batch with size (c, o_d * o_h * o_w)
-        Tensor output_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_shape);
-        // filter of size (m, c * k_d * k_h * k_w)
-
-        // batch with size (m, d, h, w)
-        Tensor input_grad_batch =
-            input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
-
-        // vol2col: dy from (c, o_d, o_h, o_w) -> (c * k_d * k_h * k_w, d * h *
-        // w)
-        vol2col(context.device_context(), output_grad_batch, col, strides[0],
-                strides[1], strides[2], paddings[0], paddings[1], paddings[2]);
-
-        // gemm: dx = filter * dy
-        // (m, c *k_d * k_h * k_w) * (c * k_d * k_h * k_w, d* h * w) -> (m, c,
-        // d, h, w)
-        math::matmul<Place, T>(context.device_context(), filter, false,
-                               col_matrix, false, T(1.0), &input_grad_batch,
-                               T(0.0));
+      if (input_grad) {
+        input_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), input_grad, static_cast<T>(0));
+      }
+      if (filter_grad) {  // filter size (m, c * k_d * k_h * k_w)
+        filter_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), filter_grad, static_cast<T>(0));
+        filter_grad_ = *filter_grad;
+        filter_grad_.Resize(filter_matrix_shape);
       }
-    }
 
-    // filter gradient required
-    if (filter_grad) {
-      Tensor col_matrix_f;
-      col_matrix_f.ShareDataWith(col);
-      DDim col_matrix_shape_f = {c * d * h * w, k_d * k_h * k_w};
-      col_matrix_f.Resize(col_matrix_shape_f);
-
-      filter_grad->mutable_data<T>(context.GetPlace());
-      Tensor filter_grad_ = *filter_grad;
-      filter_grad_.Resize(filter_matrix_shape);
-      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; ++i) {
-        // batch with size (c, o_d, o_h, o_w)
+      for (int i = 0; i < batch_size; i++) {
+        // batch with size (c, o_d * o_h * o_w)
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
-        // input batch
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
 
-        // vol2col: (c * d * h * w, k_d * k_h * k_w)
+        // vol2col: dy -> col_matrix
+        // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
         vol2col(context.device_context(), output_grad_batch, col, strides[0],
                 strides[1], strides[2], paddings[0], paddings[1], paddings[2]);
 
-        // gemm: d_filter = x * y_grad^T
-        // (m, c * d * h * w) * (k_d * k_h * k_w, c * d * h * w) -> (m, c, d, h,
-        // w)
-        math::matmul<Place, T>(context.device_context(), in_batch, false,
-                               col_matrix_f, true, T(1.0), &filter_grad_,
-                               T(1.0));
+        if (input_grad) {
+          // batch with size (m, d, h, w)
+          Tensor input_grad_batch =
+              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: dx = filter * dy
+          // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
+          // d, h, w)
+          math::matmul<Place, T>(context.device_context(), filter, false,
+                                 col_matrix, false, static_cast<T>(1.0),
+                                 &input_grad_batch, static_cast<T>(0.0));
+        }
+        if (filter_grad) {
+          // input batch
+          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: d_filter = x * dy^T
+          // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
+          // k_h * k_w)
+          math::matmul<Place, T>(context.device_context(), in_batch, false,
+                                 col_matrix, true, static_cast<T>(1.0),
+                                 &filter_grad_, static_cast<T>(1.0));
+        }
      }
     }
   }
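
Note (reviewer sketch, not part of the patch): both the 2-D and 3-D kernels above compute the transposed convolution per batch sample as a GEMM followed by a col2im/col2vol scatter-add, i.e. output = col2im(filter^T * input), mirroring the backward-on-input pass of an ordinary convolution. The small standalone program below is a minimal sketch of that equivalence for the 2-D case, assuming a single input/output channel, stride 1 and zero padding; all names in it are illustrative and it does not use any Paddle APIs. It checks that the gemm + col2im result matches the direct scatter-add definition of a transposed convolution.

// Illustrative sketch only -- not part of the patch and not Paddle code.
// Verifies: col2im(filter^T * input) == direct scatter-add transposed conv,
// for m = c = 1 channels, stride 1, zero padding.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int m = 1, c = 1;       // input / output channels
  const int h = 2, w = 2;       // input spatial size
  const int k_h = 2, k_w = 2;   // kernel size
  const int o_h = h + k_h - 1;  // output size for stride 1, no padding
  const int o_w = w + k_w - 1;

  std::vector<float> input = {1, 2, 3, 4};    // shape (m, h, w), flattened
  std::vector<float> filter = {1, 0, -1, 2};  // shape (m, c, k_h, k_w), flattened

  // gemm: col(c * k_h * k_w, h * w) = filter^T(c * k_h * k_w, m) * input(m, h * w)
  std::vector<float> col(c * k_h * k_w * h * w, 0.f);
  for (int ck = 0; ck < c * k_h * k_w; ++ck) {
    for (int hw = 0; hw < h * w; ++hw) {
      for (int i = 0; i < m; ++i) {
        col[ck * h * w + hw] +=
            filter[i * c * k_h * k_w + ck] * input[i * h * w + hw];
      }
    }
  }

  // col2im: scatter-add col(c, k_h, k_w, h, w) into out(c, o_h, o_w):
  // out[cc][y + kh][x + kw] += col[cc][kh][kw][y][x]   (stride 1, no padding)
  std::vector<float> out_gemm(c * o_h * o_w, 0.f);
  for (int cc = 0; cc < c; ++cc)
    for (int kh = 0; kh < k_h; ++kh)
      for (int kw = 0; kw < k_w; ++kw)
        for (int y = 0; y < h; ++y)
          for (int x = 0; x < w; ++x)
            out_gemm[(cc * o_h + y + kh) * o_w + x + kw] +=
                col[(((cc * k_h + kh) * k_w + kw) * h + y) * w + x];

  // Reference: direct scatter-add definition of the transposed convolution.
  std::vector<float> out_ref(c * o_h * o_w, 0.f);
  for (int i = 0; i < m; ++i)
    for (int cc = 0; cc < c; ++cc)
      for (int kh = 0; kh < k_h; ++kh)
        for (int kw = 0; kw < k_w; ++kw)
          for (int y = 0; y < h; ++y)
            for (int x = 0; x < w; ++x)
              out_ref[(cc * o_h + y + kh) * o_w + x + kw] +=
                  input[(i * h + y) * w + x] *
                  filter[((i * c + cc) * k_h + kh) * k_w + kw];

  for (std::size_t idx = 0; idx < out_ref.size(); ++idx) {
    assert(out_gemm[idx] == out_ref[idx]);
  }
  std::printf("gemm + col2im matches the direct transposed convolution\n");
  return 0;
}

The index arithmetic in the col2im loop is the same (c, k_h, k_w, h, w) -> (c, o_h, o_w) mapping described in the kernel comments above, with stride fixed to 1 and padding to 0 for brevity.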