diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 91028877b61418d3a851b7f78492b7cc2f940d14..6df2d0591fe1d5dcdb6104d22ec2126306bd4d86 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -69,6 +69,13 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") endif() + # conv_transpose_op contains several operators + if ("${TARGET}" STREQUAL "conv_transpose_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(conv2dtranspose);\n") + endif() + # save_restore_op contains several operators if ("${TARGET}" STREQUAL "save_restore_op") set(pybind_flag 1) @@ -124,7 +131,7 @@ set(DEPS_OPS pool_op pool_with_index_op lstm_op - conv3dtranspose_op) + conv_transpose_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc @@ -136,7 +143,7 @@ op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) op_library(lstm_op DEPS sequence2batch lstm_compute) -op_library(conv3dtranspose_op DEPS vol2col) +op_library(conv_transpose_op DEPS vol2col) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/conv2dtranspose_op.cc b/paddle/operators/conv2dtranspose_op.cc deleted file mode 100644 index c1b231906e2f172b6f9cee55f850d1a5ec6c3221..0000000000000000000000000000000000000000 --- a/paddle/operators/conv2dtranspose_op.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/operators/conv2dtranspose_op.h" - -namespace paddle { -namespace operators { - -void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of Conv2DTransposeOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of Conv2DTransposeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of Conv2DTransposeOp should not be null."); - - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - - for (size_t i = 0; i < paddings.size(); ++i) { - PADDLE_ENFORCE_EQ(paddings[i], 0, - "No Padding allowed in conv transpose op."); - } - - PADDLE_ENFORCE_EQ(in_dims.size(), 4, - "Conv2DTransposeOp input should be 4-D tensor."); - PADDLE_ENFORCE_EQ(filter_dims.size(), 4, - "Conv2DTransposeOp filter should be 4-D tensor."); - PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], - "input and kernel input dimension should be equal."); - - auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; - auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; - ctx->SetOutputDim("Output", - {in_dims[0], filter_dims[1], output_height, output_width}); -} - -Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( - framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "Input", - "(Tensor) The input tensor of convolution transpose operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of input channels, H and W is the height and width of image."); - AddInput("Filter", - "(Tensor) The filter tensor of convolution transpose operator." - "The format of the filter tensor is CMHW, where C is the number of " - "output image channels, M is the number of input image channels, " - "H and W is height and width of filter. " - "We enforce groups number == 1 and padding == 0 in " - "convolution transpose Scenario."); - AddOutput("Output", - "(Tensor) The output tensor of convolution transpose operator." - "The format of output tensor is also NCHW."); - AddAttr>("strides", - "strides of convolution transpose operator.") - .SetDefault({1, 1}); - AddAttr>("paddings", - "paddings of convolution transpose operator.") - .SetDefault({0, 0}); - AddComment(R"DOC( -The convolution transpose operation calculates the output based on the input, filter -and strides, paddings, groups parameters. The size of each dimension of the -parameters is checked in the infer-shape. -)DOC"); -} - -void Conv2DTransposeOpGrad::InferShape( - framework::InferShapeContext* ctx) const { - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - if (ctx->HasOutput(framework::GradVarName("Input"))) { - ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); - } - if (ctx->HasOutput(framework::GradVarName("Filter"))) { - ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); - } -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp, - ops::Conv2DTransposeOpMaker, conv2dtranspose_grad, - ops::Conv2DTransposeOpGrad); - -REGISTER_OP_CPU_KERNEL( - conv2dtranspose, - ops::GemmConv2DTransposeKernel); -REGISTER_OP_CPU_KERNEL( - conv2dtranspose_grad, - ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.cu b/paddle/operators/conv2dtranspose_op.cu deleted file mode 100644 index 761bc1959e69be94f43571728e6b92a322558b99..0000000000000000000000000000000000000000 --- a/paddle/operators/conv2dtranspose_op.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/operators/conv2dtranspose_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_GPU_KERNEL( - conv2dtranspose, - ops::GemmConv2DTransposeKernel); -REGISTER_OP_GPU_KERNEL( - conv2dtranspose_grad, - ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.h b/paddle/operators/conv2dtranspose_op.h deleted file mode 100644 index 8c70b3dcec1e26ab3d8a42d88040764c643b5ae6..0000000000000000000000000000000000000000 --- a/paddle/operators/conv2dtranspose_op.h +++ /dev/null @@ -1,254 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/im2col.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using DDim = framework::DDim; - -// Define Op classes in .h file so that other conv transpose -// operator implementations can reuse the code. -class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { - public: - Conv2DTransposeOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker); -}; - -class Conv2DTransposeOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override; -}; - -class Conv2DTransposeOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override; -}; - -template -class GemmConv2DTransposeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - // The filter will be reshaped, so it should not be constant pointer - Tensor filter = *context.Input("Filter"); - - Tensor* output = context.Output("Output"); - - std::vector strides = context.Attr>("strides"); - - // TODO(Zhuoyuan): Paddings can be added in future. - // groups will alway be disabled in conv2dtranspose. - - const int batch_size = input->dims()[0]; - const int m = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - - const int k_h = filter.dims()[2]; - const int k_w = filter.dims()[3]; - - const int c = output->dims()[1]; // output channels - const int o_h = output->dims()[2]; - const int o_w = output->dims()[3]; - - paddle::operators::math::Col2ImFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - col2im; - - // use col_shape in the im2col and col2im calculation - DDim col_shape = {c, k_h, k_w, h, w}; - - // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape = {c * k_h * k_w, h * w}; - - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. - Tensor col_matrix; - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - - DDim output_shape = {c, o_h, o_w}; - DDim input_matrix_shape = {m, h * w}; - - DDim filter_matrix_shape = {m, c * k_h * k_w}; - filter.Resize(filter_matrix_shape); - - // convolution transpose: gemm + col2im (similar to conv-backward on input) - - output->mutable_data(context.GetPlace()); - auto t = framework::EigenVector::Flatten(*output); - t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - - for (int i = 0; i < batch_size; i++) { - // batch with size (M, h * w) - Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - // filter size: (M, c * k_h * k_w) - - // output size: (c, o_h, o_w) - Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); - - // col_matrix = filter * input_batch - // of shape (c * k_h * k_w, h * w) - math::matmul(context.device_context(), filter, true, - input_batch, false, T(1.0), &col_matrix, T(0.0)); - col2im(context.device_context(), output_batch, col, strides[0], - strides[1], 0, 0, 0, 0); - } - } -}; - -template -class GemmConv2DTransposeGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - const Tensor* output_grad = - context.Input(framework::GradVarName("Output")); - - // For filter, we do not use const pointer b/c we will do reshape, - // but we should avoid modifying its value. - Tensor filter = *context.Input("Filter"); - - Tensor* input_grad = - context.Output(framework::GradVarName("Input")); - Tensor* filter_grad = - context.Output(framework::GradVarName("Filter")); - - std::vector strides = context.Attr>("strides"); - // Actually, no paddings and groups allowed in conv transpose. - std::vector paddings = context.Attr>("paddings"); - - const int batch_size = input->dims()[0]; - const int m = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - - const int k_h = filter.dims()[2]; - const int k_w = filter.dims()[3]; - - const int c = output_grad->dims()[1]; // output channels - const int o_h = output_grad->dims()[2]; - const int o_w = output_grad->dims()[3]; - - // Only im2col functor required for bp to get to the right shape - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - im2col; - - // use col_shape in the im2col and col2im calculation - DDim col_shape = {c, k_h, k_w, h, w}; - - // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; - - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. - - DDim output_shape = {c, o_h, o_w}; - DDim input_matrix_shape = {m, h * w}; - - DDim filter_matrix_shape = {m, c * k_h * k_w}; - filter.Resize(filter_matrix_shape); - - // convolution transpose grad on input: - // im2col + gemm (similar to conv-forward) - // input need to compute gradient - if (input_grad) { - Tensor col_matrix; - col_matrix.ShareDataWith(col); - DDim col_matrix_shape = {c * k_h * k_w, h * w}; - col_matrix.Resize(col_matrix_shape); - - input_grad->mutable_data(context.GetPlace()); - auto t = framework::EigenVector::Flatten(*input_grad); - t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - - for (int i = 0; i < batch_size; i++) { - // batch with size (c, o_h * o_w) - Tensor output_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_shape); - // filter of size (m, c * k_h * k_w) - - // batch with size (m, h, w) - Tensor input_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_matrix_shape); - - // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w) - im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); - - // gemm: dx = filter * dy - // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h) - math::matmul(context.device_context(), filter, false, - col_matrix, false, T(1.0), &input_grad_batch, - T(0.0)); - } - } - - // filter gradient required - if (filter_grad) { - Tensor col_matrix_f; - col_matrix_f.ShareDataWith(col); - DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; - col_matrix_f.Resize(col_matrix_shape_f); - - filter_grad->mutable_data(context.GetPlace()); - Tensor filter_grad_ = *filter_grad; - filter_grad_.Resize(filter_matrix_shape); - auto t = framework::EigenVector::Flatten(filter_grad_); - t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - - for (int i = 0; i < batch_size; ++i) { - // batch with size (c, o_h, o_w) - Tensor output_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_shape); - // input batch - Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - - // im2col: (c * h * w, k_h * k_w) - im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); - - // gemm: d_filter = x * y_grad^T - // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h) - math::matmul(context.device_context(), in_batch, false, - col_matrix_f, true, T(1.0), &filter_grad_, - T(1.0)); - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/conv3dtranspose_op.cc b/paddle/operators/conv_transpose_op.cc similarity index 54% rename from paddle/operators/conv3dtranspose_op.cc rename to paddle/operators/conv_transpose_op.cc index f67c2fff8afd5db4ebc6f4749984a96bbe3027c9..9dca2a8b1be7af222326479bd93660a82e01f675 100644 --- a/paddle/operators/conv3dtranspose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -12,18 +12,18 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv3dtranspose_op.h" +#include "paddle/operators/conv_transpose_op.h" namespace paddle { namespace operators { -void Conv3DTransposeOp::InferShape(framework::InferShapeContext* ctx) const { +void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of Conv3DTransposeOp should not be null."); + "Input(Input) of ConvTransposeOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of Conv3DTransposeOp should not be null."); + "Input(Filter) of ConvTransposeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of Conv3DTransposeOp should not be null."); + "Output(Output) of ConvTransposeOp should not be null."); auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); @@ -35,12 +35,20 @@ void Conv3DTransposeOp::InferShape(framework::InferShapeContext* ctx) const { "No Padding allowed in conv transpose op."); } - PADDLE_ENFORCE_EQ(in_dims.size(), 5, - "Conv3DTransposeOp input should be 5-D tensor."); - PADDLE_ENFORCE_EQ(filter_dims.size(), 5, - "Conv3DTransposeOp filter should be 5-D tensor."); - PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], - "input and kernel input dimension should be equal."); + PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, + "ConvTransposeOp intput should be 4-D or 5-D tensor."); + PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(), + "ConvTransposeOp input dimension and filter dimension " + "should be the same."); + PADDLE_ENFORCE(in_dims.size() - strides.size() == 2U, + "ConvTransposeOp input dimension and strides dimension should " + "be consistent."); + PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), + "ConvTransposeOp paddings dimension and Conv strides " + "dimension should be the same."); + PADDLE_ENFORCE_EQ( + in_dims[1], filter_dims[0], + "ConvTransposeOp input and kernel input dimension should be equal."); std::vector output_shape({in_dims[0], filter_dims[1]}); for (size_t i = 0; i < paddings.size(); ++i) { @@ -50,6 +58,37 @@ void Conv3DTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } +Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( + framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "(Tensor) The input tensor of convolution transpose operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of input channels, H and W is the height and width of image."); + AddInput("Filter", + "(Tensor) The filter tensor of convolution transpose operator." + "The format of the filter tensor is CMHW, where C is the number of " + "output image channels, M is the number of input image channels, " + "H and W is height and width of filter. " + "We enforce groups number == 1 and padding == 0 in " + "convolution transpose Scenario."); + AddOutput("Output", + "(Tensor) The output tensor of convolution transpose operator." + "The format of output tensor is also NCHW."); + AddAttr>("strides", + "strides of convolution transpose operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "paddings of convolution transpose operator.") + .SetDefault({0, 0}); + AddComment(R"DOC( +The convolution transpose operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +)DOC"); +} + Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { @@ -85,8 +124,7 @@ parameters is checked in the infer-shape. )DOC"); } -void Conv3DTransposeOpGrad::InferShape( - framework::InferShapeContext* ctx) const { +void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const { auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); if (ctx->HasOutput(framework::GradVarName("Input"))) { @@ -101,9 +139,19 @@ void Conv3DTransposeOpGrad::InferShape( } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv3dtranspose, ops::Conv3DTransposeOp, - ops::Conv3DTransposeOpMaker, conv3dtranspose_grad, - ops::Conv3DTransposeOpGrad); + +REGISTER_OP(conv2dtranspose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker, + conv2dtranspose_grad, ops::ConvTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2dtranspose, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); + +REGISTER_OP(conv3dtranspose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker, + conv3dtranspose_grad, ops::ConvTransposeOpGrad); REGISTER_OP_CPU_KERNEL( conv3dtranspose, diff --git a/paddle/operators/conv3dtranspose_op.cu b/paddle/operators/conv_transpose_op.cu similarity index 75% rename from paddle/operators/conv3dtranspose_op.cu rename to paddle/operators/conv_transpose_op.cu index 447646fd756d3f6d3b8429c91dbee2643a42c87b..2a054143151d4b04949d4913f2b373c173277960 100644 --- a/paddle/operators/conv3dtranspose_op.cu +++ b/paddle/operators/conv_transpose_op.cu @@ -12,10 +12,17 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv3dtranspose_op.h" +#include "paddle/operators/conv_transpose_op.h" namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + conv2dtranspose, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_GPU_KERNEL( + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); + REGISTER_OP_GPU_KERNEL( conv3dtranspose, ops::GemmConv3DTransposeKernel); diff --git a/paddle/operators/conv3dtranspose_op.h b/paddle/operators/conv_transpose_op.h similarity index 54% rename from paddle/operators/conv3dtranspose_op.h rename to paddle/operators/conv_transpose_op.h index fbab12731420b3756485c1f5d663a85604382ed9..ad0e96f519c669be6b65e5bb34894c286d6ffaef 100644 --- a/paddle/operators/conv3dtranspose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/im2col.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/vol2col.h" @@ -27,13 +28,19 @@ using DDim = framework::DDim; // Define Op classes in .h file so that other conv transpose // operator implementations can reuse the code. +class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Conv2DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { public: Conv3DTransposeOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker); }; -class Conv3DTransposeOp : public framework::OperatorWithKernel { +class ConvTransposeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -41,7 +48,7 @@ class Conv3DTransposeOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; }; -class Conv3DTransposeOpGrad : public framework::OperatorWithKernel { +class ConvTransposeOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -50,7 +57,7 @@ class Conv3DTransposeOpGrad : public framework::OperatorWithKernel { }; template -class GemmConv3DTransposeKernel : public framework::OpKernel { +class GemmConv2DTransposeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); @@ -61,6 +68,206 @@ class GemmConv3DTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); + // TODO(Zhuoyuan): Paddings can be added in future. + // groups will alway be disabled in conv2dtranspose. + + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; + + const int c = output->dims()[1]; // output channels + const int o_h = output->dims()[2]; + const int o_w = output->dims()[3]; + + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + col2im; + + // use col_shape in the im2col and col2im calculation + DDim col_shape = {c, k_h, k_w, h, w}; + + // use col_matrix_shape in the gemm calculation + DDim col_matrix_shape = {c * k_h * k_w, h * w}; + + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix; + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; + + DDim filter_matrix_shape = {m, c * k_h * k_w}; + filter.Resize(filter_matrix_shape); + + // convolution transpose: gemm + col2im (similar to conv-backward on input) + + output->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*output); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + // batch with size (M, h * w) + Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + // filter size: (M, c * k_h * k_w) + + // output size: (c, o_h, o_w) + Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); + + // col_matrix = filter * input_batch + // of shape (c * k_h * k_w, h * w) + math::matmul(context.device_context(), filter, true, + input_batch, false, T(1.0), &col_matrix, T(0.0)); + col2im(context.device_context(), output_batch, col, strides[0], + strides[1], 0, 0, 0, 0); + } + } +}; + +template +class GemmConv2DTransposeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + + // For filter, we do not use const pointer b/c we will do reshape, + // but we should avoid modifying its value. + Tensor filter = *context.Input("Filter"); + + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + + std::vector strides = context.Attr>("strides"); + // Actually, no paddings and groups allowed in conv transpose. + std::vector paddings = context.Attr>("paddings"); + + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; + + const int c = output_grad->dims()[1]; // output channels + const int o_h = output_grad->dims()[2]; + const int o_w = output_grad->dims()[3]; + + // Only im2col functor required for bp to get to the right shape + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + im2col; + + // use col_shape in the im2col and col2im calculation + DDim col_shape = {c, k_h, k_w, h, w}; + + // use col_matrix_shape in the gemm calculation + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; + + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; + + DDim filter_matrix_shape = {m, c * k_h * k_w}; + filter.Resize(filter_matrix_shape); + + // convolution transpose grad on input: + // im2col + gemm (similar to conv-forward) + // input need to compute gradient + if (input_grad) { + Tensor col_matrix; + col_matrix.ShareDataWith(col); + DDim col_matrix_shape = {c * k_h * k_w, h * w}; + col_matrix.Resize(col_matrix_shape); + + input_grad->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + // batch with size (c, o_h * o_w) + Tensor output_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_shape); + // filter of size (m, c * k_h * k_w) + + // batch with size (m, h, w) + Tensor input_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_matrix_shape); + + // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w) + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); + + // gemm: dx = filter * dy + // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h) + math::matmul(context.device_context(), filter, false, + col_matrix, false, T(1.0), &input_grad_batch, + T(0.0)); + } + } + + // filter gradient required + if (filter_grad) { + Tensor col_matrix_f; + col_matrix_f.ShareDataWith(col); + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; + col_matrix_f.Resize(col_matrix_shape_f); + + filter_grad->mutable_data(context.GetPlace()); + Tensor filter_grad_ = *filter_grad; + filter_grad_.Resize(filter_matrix_shape); + auto t = framework::EigenVector::Flatten(filter_grad_); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; ++i) { + // batch with size (c, o_h, o_w) + Tensor output_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_shape); + // input batch + Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + + // im2col: (c * h * w, k_h * k_w) + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); + + // gemm: d_filter = x * y_grad^T + // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h) + math::matmul(context.device_context(), in_batch, false, + col_matrix_f, true, T(1.0), &filter_grad_, + T(1.0)); + } + } + } +}; + +template +class GemmConv3DTransposeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped, so it should not be constant pointer + Tensor filter = *context.Input("Filter"); + Tensor* output = context.Output("Output"); + + std::vector strides = context.Attr>("strides"); + // TODO(chengduo): Paddings can be added in future. // groups will alway be disabled in conv3dtranspose. diff --git a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py index ce5e4424170435f4cc56ab1910e6f20a470d5a06..53604c58b70a534dff6b0a668d380fb8f10f53f6 100644 --- a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py +++ b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py @@ -44,7 +44,7 @@ class TestConv2dTransposeOp(OpTest): input_ = np.random.random(self.input_size).astype("float32") filter_ = np.random.random(self.filter_size).astype("float32") output = conv2dtranspose_forward_naive( - input_, filter_, conv2dtranspose_param).astype("float32") + input_, filter_, conv2dtranspose_param).astype('float32') # print 'deconv output py', output, output.shape self.inputs = {'Input': input_, 'Filter': filter_}