diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cc b/paddle/operators/conv2d_transpose_cudnn_op.cc index 042ccc2be88f7aebc8aba40e449dd0d0eeecd6c5..fce1357ce5af5f11ccc5941690431393301e6725 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cc +++ b/paddle/operators/conv2d_transpose_cudnn_op.cc @@ -44,7 +44,7 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp, REGISTER_OP_CPU_KERNEL( conv2d_transpose_cudnn, - ops::GemmConv2DTransposeKernel); + ops::GemmConvTransposeKernel); REGISTER_OP_CPU_KERNEL( conv2d_transpose_cudnn_grad, - ops::GemmConv2DTransposeGradKernel); + ops::GemmConvTransposeGradKernel); diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 3362124b3b1171a123468b352087baaf248e1d74..50081779a5ea3c81884007d4e4b7832dc4ea2bdd 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -187,17 +187,17 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker, REGISTER_OP_CPU_KERNEL( conv2d_transpose, - ops::GemmConv2DTransposeKernel); + ops::GemmConvTransposeKernel); REGISTER_OP_CPU_KERNEL( conv2d_transpose_grad, - ops::GemmConv2DTransposeGradKernel); + ops::GemmConvTransposeGradKernel); REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker, conv3d_transpose_grad, ops::ConvTransposeOpGrad); REGISTER_OP_CPU_KERNEL( conv3d_transpose, - ops::GemmConv3DTransposeKernel); + ops::GemmConvTransposeKernel); REGISTER_OP_CPU_KERNEL( conv3d_transpose_grad, - ops::GemmConv3DTransposeGradKernel); + ops::GemmConvTransposeGradKernel); diff --git a/paddle/operators/conv_transpose_op.cu b/paddle/operators/conv_transpose_op.cu index 95463ade159038c92543be1c5360dc13dca51998..401cddb379ced134b800d2a078fe130a2850fbb2 100644 --- a/paddle/operators/conv_transpose_op.cu +++ b/paddle/operators/conv_transpose_op.cu @@ -18,14 +18,14 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( conv2d_transpose, - ops::GemmConv2DTransposeKernel); + ops::GemmConvTransposeKernel); REGISTER_OP_GPU_KERNEL( conv2d_transpose_grad, - ops::GemmConv2DTransposeGradKernel); + ops::GemmConvTransposeGradKernel); REGISTER_OP_GPU_KERNEL( conv3d_transpose, - ops::GemmConv3DTransposeKernel); + ops::GemmConvTransposeKernel); REGISTER_OP_GPU_KERNEL( conv3d_transpose_grad, - ops::GemmConv3DTransposeGradKernel); + ops::GemmConvTransposeGradKernel); diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index f9db5990b321544611f9a16ecdfb122f67a23089..6c1a6220d784abf89ec789f94d9cff9e5414db04 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -57,7 +57,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel { }; template -class GemmConv2DTransposeKernel : public framework::OpKernel { +class GemmConvTransposeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); @@ -70,24 +70,31 @@ class GemmConv2DTransposeKernel : public framework::OpKernel { // groups will alway be disabled in conv2dtranspose. const int batch_size = static_cast(input->dims()[0]); - const int64_t m = input->dims()[1]; - const int64_t h = input->dims()[2]; - const int64_t w = input->dims()[3]; - const int64_t k_h = filter.dims()[2]; - const int64_t k_w = filter.dims()[3]; - - const int64_t c = output->dims()[1]; // output channels - const int64_t o_h = output->dims()[2]; - const int64_t o_w = output->dims()[3]; - - math::Col2ImFunctor col2im; - - // use col_shape in the im2col and col2im calculation - DDim col_shape = {c, k_h, k_w, h, w}; + // input_shape_vec: {h, w} or {d, h, w} + std::vector input_shape_vec = framework::vectorize(input->dims()); + input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2); + + // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + std::vector filter_shape_vec = framework::vectorize(filter.dims()); + filter_shape_vec.erase(filter_shape_vec.begin(), + filter_shape_vec.begin() + 2); + + // use col_shape in the im2col and col2im (or vol2col and col2vol) + // calculation + // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w} + std::vector col_shape_vec; + col_shape_vec.push_back(output->dims()[1]); + col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), + filter_shape_vec.end()); + col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(), + input_shape_vec.end()); + DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape = {c * k_h * k_w, h * w}; + // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) + DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -98,47 +105,61 @@ class GemmConv2DTransposeKernel : public framework::OpKernel { col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - DDim output_shape = {c, o_h, o_w}; - DDim input_matrix_shape = {m, h * w}; + // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) + DDim output_shape = + framework::slice_ddim(output->dims(), 1, output->dims().size()); - // filter size: (m, c * k_h * k_w) - DDim filter_matrix_shape = {m, c * k_h * k_w}; + // input matrix size: (m, h * w) or (m, d * h * w) + DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]}; + + // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w) + DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]}; filter.Resize(filter_matrix_shape); output->mutable_data(context.GetPlace()); math::SetConstant set_zero; set_zero(context.device_context(), output, static_cast(0)); - // convolution transpose: gemm + col2im (similar to conv-backward on input) + // convolution transpose: gemm + col2im or col2vol (similar to conv-backward + // on input) for (int i = 0; i < batch_size; i++) { - // batch with size (m, h * w) + // batch with size (m, h * w) or (m, d * h * w) Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - // output size: (c, o_h, o_w) + // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); // col_matrix = filter * input_batch - // of shape (c * k_h * k_w, h * w) + // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) math::matmul(context.device_context(), filter, true, input_batch, false, static_cast(1.0), &col_matrix, static_cast(0.0)); - // col2im: col_matrix -> dy - // from (c * k_h * k_w, h * w) to (c, o_h, o_w) - col2im(context.device_context(), output_batch, col, strides[0], - strides[1], 0, 0, 0, 0); + if (filter_shape_vec.size() == 2) { + // col2im: col_matrix -> dy + // from (c * k_h * k_w, h * w) to (c, o_h, o_w) + math::Col2ImFunctor col2im; + + col2im(context.device_context(), output_batch, col, strides[0], + strides[1], 0, 0, 0, 0); + } else if (filter_shape_vec.size() == 3) { + // col2vol: col_matrix -> dy + // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) + math::Col2VolFunctor col2vol; + col2vol(context.device_context(), output_batch, col, strides[0], + strides[1], strides[2], 0, 0, 0); + } } } }; template -class GemmConv2DTransposeGradKernel : public framework::OpKernel { +class GemmConvTransposeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); const Tensor* output_grad = context.Input(framework::GradVarName("Output")); - // For filter, we do not use const pointer b/c we will do reshape, // but we should avoid modifying its value. Tensor filter = *context.Input("Filter"); @@ -147,38 +168,50 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel { Tensor* filter_grad = context.Output(framework::GradVarName("Filter")); + if ((!input_grad) && (!filter_grad)) return; + std::vector strides = context.Attr>("strides"); // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); const int batch_size = static_cast(input->dims()[0]); - const int64_t m = input->dims()[1]; - const int64_t h = input->dims()[2]; - const int64_t w = input->dims()[3]; - const int64_t k_h = filter.dims()[2]; - const int64_t k_w = filter.dims()[3]; + // input_shape_vec: {h, w} or {d, h, w} + std::vector input_shape_vec = framework::vectorize(input->dims()); + input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2); + + // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + std::vector filter_shape_vec = framework::vectorize(filter.dims()); + filter_shape_vec.erase(filter_shape_vec.begin(), + filter_shape_vec.begin() + 2); + + // use col_shape in the im2col and col2im (or vol2col and col2vol) + // calculation + // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w} + std::vector col_shape_vec; + col_shape_vec.push_back(output_grad->dims()[1]); + col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), + filter_shape_vec.end()); + col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(), + input_shape_vec.end()); + DDim col_shape(framework::make_ddim(col_shape_vec)); - const int64_t c = output_grad->dims()[1]; // output channels - const int64_t o_h = output_grad->dims()[2]; - const int64_t o_w = output_grad->dims()[3]; - - // Only im2col functor required for bp to get to the right shape - math::Im2ColFunctor im2col; + // use col_matrix_shape in the gemm calculation + // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) + DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); - // use col_shape in the im2col and col2im calculation - DDim col_shape = {c, k_h, k_w, h, w}; + // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) + DDim output_shape = framework::slice_ddim(output_grad->dims(), 1, + output_grad->dims().size()); - DDim output_shape = {c, o_h, o_w}; - DDim input_matrix_shape = {m, h * w}; + // input matrix size: (m, h * w) or (m, d * h * w) + DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]}; - DDim filter_matrix_shape = {m, c * k_h * k_w}; + // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w) + DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]}; filter.Resize(filter_matrix_shape); - if ((!input_grad) && (!filter_grad)) { - return; - } - // convolution transpose grad on input: // im2col + gemm (similar to conv-forward) // input need to compute gradient @@ -190,7 +223,6 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel { // to call the matrix multiplication interface. Tensor col_matrix; col_matrix.ShareDataWith(col); - DDim col_matrix_shape = {c * k_h * k_w, h * w}; col_matrix.Resize(col_matrix_shape); Tensor filter_grad_; @@ -212,10 +244,21 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel { Tensor output_grad_batch = output_grad->Slice(i, i + 1).Resize(output_shape); - // im2col: dy -> col matrix - // from (c, o_h, o_w) to (c * k_h * k_w, h * w) - im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); + if (filter_shape_vec.size() == 2) { + // im2col: dy -> col matrix + // from (c, o_h, o_w) to (c * k_h * k_w, h * w) + math::Im2ColFunctor im2col; + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); + } else if (filter_shape_vec.size() == 3) { + // vol2col: dy -> col_matrix + // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) + math::Vol2ColFunctor vol2col; + vol2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } if (input_grad) { // batch with size (m, h, w) @@ -223,197 +266,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel { input_grad->Slice(i, i + 1).Resize(input_matrix_shape); // gemm: dx = filter * dy // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w) - math::matmul(context.device_context(), filter, false, - col_matrix, false, static_cast(1.0), - &input_grad_batch, static_cast(0.0)); - } - if (filter_grad) { - // input batch - Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - // gemm: d_filter = x * dy^T - // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w) - math::matmul(context.device_context(), in_batch, false, - col_matrix, true, static_cast(1.0), - &filter_grad_, static_cast(1.0)); - } - } - } - } -}; - -template -class GemmConv3DTransposeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - // The filter will be reshaped, so it should not be constant pointer - Tensor filter = *context.Input("Filter"); - Tensor* output = context.Output("Output"); - - std::vector strides = context.Attr>("strides"); - // TODO(chengduo): Paddings can be added in future. - // groups will alway be disabled in conv3dtranspose. - - const int batch_size = static_cast(input->dims()[0]); - const int64_t m = input->dims()[1]; - const int64_t d = input->dims()[2]; - const int64_t h = input->dims()[3]; - const int64_t w = input->dims()[4]; - - const int64_t k_d = filter.dims()[2]; - const int64_t k_h = filter.dims()[3]; - const int64_t k_w = filter.dims()[4]; - - const int64_t c = output->dims()[1]; // output channels - const int64_t o_d = output->dims()[2]; - const int64_t o_h = output->dims()[3]; - const int64_t o_w = output->dims()[4]; - - math::Col2VolFunctor col2vol; - - // use col_shape in the vol2col and col2vol calculation - DDim col_shape = {c, k_d, k_h, k_w, d, h, w}; - // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w}; - - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. - Tensor col_matrix; - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - - DDim output_shape = {c, o_d, o_h, o_w}; - DDim input_matrix_shape = {m, d * h * w}; - - // filter size: (m, c * k_d * k_h * k_w) - DDim filter_matrix_shape = {m, c * k_d * k_h * k_w}; - filter.Resize(filter_matrix_shape); - - output->mutable_data(context.GetPlace()); - math::SetConstant set_zero; - set_zero(context.device_context(), output, static_cast(0)); - - // convolution transpose: gemm + col2vol (similar to conv-backward on input) - for (int i = 0; i < batch_size; i++) { - // batch with size (m, d * h * w) - Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - - // output size: (c, o_d, o_h, o_w) - Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); - - // col_matrix = filter * input_batch - // of shape (c * k_d * k_h * k_w, d * h * w) - math::matmul(context.device_context(), filter, true, - input_batch, false, static_cast(1.0), - &col_matrix, static_cast(0.0)); - // col2vol: col_matrix -> dy - // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) - col2vol(context.device_context(), output_batch, col, strides[0], - strides[1], strides[2], 0, 0, 0); - } - } -}; - -template -class GemmConv3DTransposeGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - const Tensor* output_grad = - context.Input(framework::GradVarName("Output")); - - // For filter, we do not use const pointer b/c we will do reshape, - // but we should avoid modifying its value. - Tensor filter = *context.Input("Filter"); - Tensor* input_grad = - context.Output(framework::GradVarName("Input")); - Tensor* filter_grad = - context.Output(framework::GradVarName("Filter")); - - std::vector strides = context.Attr>("strides"); - // Actually, no paddings and groups allowed in conv transpose. - std::vector paddings = context.Attr>("paddings"); - - const int batch_size = static_cast(input->dims()[0]); - const int64_t m = input->dims()[1]; - const int64_t d = input->dims()[2]; - const int64_t h = input->dims()[3]; - const int64_t w = input->dims()[4]; - - const int64_t k_d = filter.dims()[2]; - const int64_t k_h = filter.dims()[3]; - const int64_t k_w = filter.dims()[4]; - - const int64_t c = output_grad->dims()[1]; // output channels - const int64_t o_d = output_grad->dims()[2]; - const int64_t o_h = output_grad->dims()[3]; - const int64_t o_w = output_grad->dims()[4]; - - // Only vol2col functor required for bp to get to the right shape - math::Vol2ColFunctor vol2col; - - // use col_shape in the vol2col and col2vol calculation - DDim col_shape = {c, k_d, k_h, k_w, d, h, w}; - - // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape_f = {c * d * h * w, k_d * k_h * k_w}; - - DDim output_shape = {c, o_d, o_h, o_w}; - DDim input_matrix_shape = {m, d * h * w}; - - DDim filter_matrix_shape = {m, c * k_d * k_h * k_w}; - filter.Resize(filter_matrix_shape); - - if ((!input_grad) && (!filter_grad)) { - return; - } - - // convolution transpose grad on input: - // vol2col + gemm (similar to conv-forward) - // input need to compute gradient - if (input_grad || filter_grad) { - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. - Tensor col_matrix; - col_matrix.ShareDataWith(col); - DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w}; - col_matrix.Resize(col_matrix_shape); - - Tensor filter_grad_; - math::SetConstant set_zero; - - if (input_grad) { - input_grad->mutable_data(context.GetPlace()); - set_zero(context.device_context(), input_grad, static_cast(0)); - } - if (filter_grad) { // filter size (m, c * k_d * k_h * k_w) - filter_grad->mutable_data(context.GetPlace()); - set_zero(context.device_context(), filter_grad, static_cast(0)); - filter_grad_ = *filter_grad; - filter_grad_.Resize(filter_matrix_shape); - } - - for (int i = 0; i < batch_size; i++) { - // batch with size (c, o_d * o_h * o_w) - Tensor output_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_shape); - - // vol2col: dy -> col_matrix - // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) - vol2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], paddings[2]); - - if (input_grad) { - // batch with size (m, d, h, w) - Tensor input_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_matrix_shape); - // gemm: dx = filter * dy + // or // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m, // d, h, w) math::matmul(context.device_context(), filter, false, @@ -424,6 +277,8 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel { // input batch Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); // gemm: d_filter = x * dy^T + // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w) + // or // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d * // k_h * k_w) math::matmul(context.device_context(), in_batch, false, @@ -434,6 +289,5 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel { } } }; - } // namespace operators } // namespace paddle