diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index df16738130ae71fb1eea17ca7e114da01def97ee..c5ecccf9023d7fba7d0edec4c687c9515da034e6 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -69,6 +69,13 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
   endif()
 
+  # conv_transpose_op contains several operators
+  if ("${TARGET}" STREQUAL "conv_transpose_op")
+    set(pybind_flag 1)
+    # It is enough to add just one operator to pybind
+    file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
+  endif()
+
   # pool_cudnn_op contains several operators
   if ("${TARGET}" STREQUAL "pool_cudnn_op")
     set(pybind_flag 1)
@@ -139,6 +146,8 @@ set(DEPS_OPS
     sum_op
     pool_op
     pool_with_index_op
+    lstm_op
+    conv_transpose_op
     nccl_op
     sequence_conv_op
     sequence_pool_op
@@ -159,10 +168,12 @@ endif()
 op_library(sequence_conv_op DEPS context_project)
 op_library(sequence_pool_op DEPS sequence_pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
+op_library(conv_transpose_op DEPS vol2col)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
            DEPS net_op tensor_array)
 op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
+
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
   op_library(${src})
diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cc b/paddle/operators/conv2d_transpose_cudnn_op.cc
index 8ce94e0f04f14e1eae7e7d01280601cc72dea8c4..fce1357ce5af5f11ccc5941690431393301e6725 100644
--- a/paddle/operators/conv2d_transpose_cudnn_op.cc
+++ b/paddle/operators/conv2d_transpose_cudnn_op.cc
@@ -12,7 +12,7 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#include "paddle/operators/conv2d_transpose_op.h"
+#include "paddle/operators/conv_transpose_op.h"
 
 namespace paddle {
 namespace operators {
@@ -38,13 +38,13 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(conv2d_transpose_cudnn, ops::Conv2DTransposeOp,
+REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
             ops::CudnnConv2DTransposeOpMaker, conv2d_transpose_cudnn_grad,
-            ops::Conv2DTransposeOpGrad);
+            ops::ConvTransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose_cudnn,
-    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose_cudnn_grad,
-    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu
index 61fcfb3bd8fa57f2c45fbf3a980dbe41041cff18..1aa8d110759a7d99c26cf7baaf6d4ce4b92975b9 100644
--- a/paddle/operators/conv2d_transpose_cudnn_op.cu
+++ b/paddle/operators/conv2d_transpose_cudnn_op.cu
@@ -15,7 +15,7 @@
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/memory/memory.h"
-#include "paddle/operators/conv2d_transpose_op.h"
+#include "paddle/operators/conv_transpose_op.h"
 #include "paddle/platform/assert.h"
 #include "paddle/platform/cudnn_helper.h"
diff --git a/paddle/operators/conv2d_transpose_op.cc b/paddle/operators/conv2d_transpose_op.cc
deleted file mode 100644
index 8f5d18cddf45d1129040454adbc95a511ccf0583..0000000000000000000000000000000000000000
--- a/paddle/operators/conv2d_transpose_op.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/operators/conv2d_transpose_op.h"
-
-namespace paddle {
-namespace operators {
-
-void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(Input) of Conv2DTransposeOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Filter"),
-                 "Input(Filter) of Conv2DTransposeOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Output"),
-                 "Output(Output) of Conv2DTransposeOp should not be null.");
-
-  auto in_dims = ctx->GetInputDim("Input");
-  auto filter_dims = ctx->GetInputDim("Filter");
-  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
-  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
-
-  for (size_t i = 0; i < paddings.size(); ++i) {
-    PADDLE_ENFORCE_EQ(paddings[i], 0,
-                      "No Padding allowed in conv transpose op.");
-  }
-
-  PADDLE_ENFORCE_EQ(in_dims.size(), 4,
-                    "Conv2DTransposeOp input should be 4-D tensor.");
-  PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
-                    "Conv2DTransposeOp filter should be 4-D tensor.");
-  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
-                    "input and kernel input dimension should be equal.");
-
-  auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
-  auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
-  ctx->SetOutputDim("Output",
-                    {in_dims[0], filter_dims[1], output_height, output_width});
-}
-
-Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
-    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
-    : OpProtoAndCheckerMaker(proto, op_checker) {
-  AddInput(
-      "Input",
-      "(Tensor) The input tensor of convolution transpose operator. "
-      "The format of input tensor is NCHW, where N is batch size, C is the "
-      "number of input channels, H is the height of the image, and "
-      "W is the width of the image.");
-  AddInput("Filter",
-           "(Tensor) The filter tensor of convolution transpose operator."
-           "The format of the filter tensor is CMHW, where C is the number of "
-           "output image channels, M is the number of input image channels, "
-           "H is the height of the filter, and W is the width of the filter. "
-           "We enforce groups number == 1 and padding == 0 in "
-           "the convolution transpose scenario.");
-  AddOutput("Output",
-            "(Tensor) The output tensor of convolution transpose operator."
-            "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>("strides",
-                            "strides of convolution transpose operator.")
-      .SetDefault({1, 1});
-  AddAttr<std::vector<int>>("paddings",
-                            "paddings of convolution transpose operator.")
-      .SetDefault({0, 0});
-  AddComment(R"DOC(
-Convolution Transpose Operator.
-
-The convolution transpose operation calculates the output based on the input,
-filter, strides, paddings, and groups parameters. The size of each dimension
-of the parameters is checked in the infer-shape method.
-
-)DOC");
-}
-
-void Conv2DTransposeOpGrad::InferShape(
-    framework::InferShapeContext* ctx) const {
-  auto in_dims = ctx->GetInputDim("Input");
-  auto filter_dims = ctx->GetInputDim("Filter");
-  if (ctx->HasOutput(framework::GradVarName("Input"))) {
-    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
-  }
-  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
-    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
-  }
-}
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp,
-            ops::Conv2DTransposeOpMaker, conv2d_transpose_grad,
-            ops::Conv2DTransposeOpGrad);
-
-REGISTER_OP_CPU_KERNEL(
-    conv2d_transpose,
-    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    conv2d_transpose_grad,
-    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv2d_transpose_op.h b/paddle/operators/conv2d_transpose_op.h
deleted file mode 100644
index cab7788227690621a0e5b744197b86c515bbef72..0000000000000000000000000000000000000000
--- a/paddle/operators/conv2d_transpose_op.h
+++ /dev/null
@@ -1,254 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/im2col.h"
-#include "paddle/operators/math/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-using DDim = framework::DDim;
-
-// Define Op classes in .h file so that other conv transpose
-// operator implementations can reuse the code.
-class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  Conv2DTransposeOpMaker(framework::OpProto* proto,
-                         framework::OpAttrChecker* op_checker);
-};
-
-class Conv2DTransposeOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override;
-};
-
-class Conv2DTransposeOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override;
-};
-
-template <typename Place, typename T>
-class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    // The filter will be reshaped, so it should not be constant pointer
-    Tensor filter = *context.Input<Tensor>("Filter");
-
-    Tensor* output = context.Output<Tensor>("Output");
-
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-
-    // TODO(Zhuoyuan): Paddings can be added in future.
-    // groups will alway be disabled in conv2d_transpose.
-
-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
-
-    const int k_h = filter.dims()[2];
-    const int k_w = filter.dims()[3];
-
-    const int c = output->dims()[1];  // output channels
-    const int o_h = output->dims()[2];
-    const int o_w = output->dims()[3];
-
-    paddle::operators::math::Col2ImFunctor<
-        paddle::operators::math::ColFormat::kCFO, Place, T>
-        col2im;
-
-    // use col_shape in the im2col and col2im calculation
-    DDim col_shape = {c, k_h, k_w, h, w};
-
-    // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape = {c * k_h * k_w, h * w};
-
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-    Tensor col_matrix;
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-
-    DDim output_shape = {c, o_h, o_w};
-    DDim input_matrix_shape = {m, h * w};
-
-    DDim filter_matrix_shape = {m, c * k_h * k_w};
-    filter.Resize(filter_matrix_shape);
-
-    // convolution transpose: gemm + col2im (similar to conv-backward on input)
-
-    output->mutable_data<T>(context.GetPlace());
-    auto t = framework::EigenVector<T>::Flatten(*output);
-    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-    for (int i = 0; i < batch_size; i++) {
-      // batch with size (M, h * w)
-      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-      // filter size: (M, c * k_h * k_w)
-
-      // output size: (c, o_h, o_w)
-      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
-
-      // col_matrix = filter * input_batch
-      // of shape (c * k_h * k_w, h * w)
-      math::matmul<Place, T>(context.device_context(), filter, true,
-                             input_batch, false, T(1.0), &col_matrix, T(0.0));
-      col2im(context.device_context(), output_batch, col, strides[0],
-             strides[1], 0, 0, 0, 0);
-    }
-  }
-};
-
-template <typename Place, typename T>
-class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* output_grad =
-        context.Input<Tensor>(framework::GradVarName("Output"));
-
-    // For filter, we do not use const pointer b/c we will do reshape,
-    // but we should avoid modifying its value.
-    Tensor filter = *context.Input<Tensor>("Filter");
-
-    Tensor* input_grad =
-        context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad =
-        context.Output<Tensor>(framework::GradVarName("Filter"));
-
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    // Actually, no paddings and groups allowed in conv transpose.
-    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-
-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
-
-    const int k_h = filter.dims()[2];
-    const int k_w = filter.dims()[3];
-
-    const int c = output_grad->dims()[1];  // output channels
-    const int o_h = output_grad->dims()[2];
-    const int o_w = output_grad->dims()[3];
-
-    // Only im2col functor required for bp to get to the right shape
-    paddle::operators::math::Im2ColFunctor<
-        paddle::operators::math::ColFormat::kCFO, Place, T>
-        im2col;
-
-    // use col_shape in the im2col and col2im calculation
-    DDim col_shape = {c, k_h, k_w, h, w};
-
-    // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
-
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-
-    DDim output_shape = {c, o_h, o_w};
-    DDim input_matrix_shape = {m, h * w};
-
-    DDim filter_matrix_shape = {m, c * k_h * k_w};
-    filter.Resize(filter_matrix_shape);
-
-    // convolution transpose grad on input:
-    // im2col + gemm (similar to conv-forward)
-    // input need to compute gradient
-    if (input_grad) {
-      Tensor col_matrix;
-      col_matrix.ShareDataWith(col);
-      DDim col_matrix_shape = {c * k_h * k_w, h * w};
-      col_matrix.Resize(col_matrix_shape);
-
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*input_grad);
-      t.device(context.GetEigenDevice<Place>()) =
-          t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; i++) {
-        // batch with size (c, o_h * o_w)
-        Tensor output_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_shape);
-        // filter of size (m, c * k_h * k_w)
-
-        // batch with size (m, h, w)
-        Tensor input_grad_batch =
-            input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
-
-        // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
-        im2col(context.device_context(), output_grad_batch, col, strides[0],
-               strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
-
-        // gemm: dx = filter * dy
-        // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
-        math::matmul<Place, T>(context.device_context(), filter, false,
-                               col_matrix, false, T(1.0), &input_grad_batch,
-                               T(0.0));
-      }
-    }
-
-    // filter gradient required
-    if (filter_grad) {
-      Tensor col_matrix_f;
-      col_matrix_f.ShareDataWith(col);
-      DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
-      col_matrix_f.Resize(col_matrix_shape_f);
-
-      filter_grad->mutable_data<T>(context.GetPlace());
-      Tensor filter_grad_ = *filter_grad;
-      filter_grad_.Resize(filter_matrix_shape);
-      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
-      t.device(context.GetEigenDevice<Place>()) =
-          t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; ++i) {
-        // batch with size (c, o_h, o_w)
-        Tensor output_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_shape);
-        // input batch
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-
-        // im2col: (c * h * w, k_h * k_w)
-        im2col(context.device_context(), output_grad_batch, col, strides[0],
-               strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
-
-        // gemm: d_filter = x * y_grad^T
-        // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
-        math::matmul<Place, T>(context.device_context(), in_batch, false,
-                               col_matrix_f, true, T(1.0), &filter_grad_,
-                               T(1.0));
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..50081779a5ea3c81884007d4e4b7832dc4ea2bdd
--- /dev/null
+++ b/paddle/operators/conv_transpose_op.cc
@@ -0,0 +1,203 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/conv_transpose_op.h"
+
+namespace paddle {
+namespace operators {
+
+void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("Input"),
+                 "Input(Input) of ConvTransposeOp should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("Filter"),
+                 "Input(Filter) of ConvTransposeOp should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Output"),
+                 "Output(Output) of ConvTransposeOp should not be null.");
+
+  auto in_dims = ctx->GetInputDim("Input");
+  auto filter_dims = ctx->GetInputDim("Filter");
+  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
+  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    PADDLE_ENFORCE_EQ(paddings[i], 0,
+                      "No Padding allowed in conv transpose op.");
+  }
+
+  PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
+                 "ConvTransposeOp input should be 4-D or 5-D tensor.");
+  PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
+                    "ConvTransposeOp input dimension and filter dimension "
+                    "should be the same.");
+  PADDLE_ENFORCE(in_dims.size() - strides.size() == 2U,
+                 "ConvTransposeOp input dimension and strides dimension "
+                 "should be consistent.");
+  PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
+                    "ConvTransposeOp paddings dimension and strides "
+                    "dimension should be the same.");
+  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
+                    "In ConvTransposeOp, the input channel number should be "
+                    "equal to the number of filters.");
+
+  std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    output_shape.push_back((in_dims[i + 2] - 1) * strides[i] +
+                           filter_dims[i + 2]);
+  }
+  ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
+}
+
+Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
+    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+    : OpProtoAndCheckerMaker(proto, op_checker) {
+  AddInput(
+      "Input",
+      "(Tensor) The input tensor of convolution transpose operator. "
+      "The format of input tensor is NCHW, where N is the batch size, C is "
+      "the number of input channels, H is the height of the feature, and "
+      "W is the width of the feature.");
+  AddInput("Filter",
+           "(Tensor) The filter tensor of convolution transpose operator. "
+           "The format of the filter tensor is CMHW, where C is the number of "
+           "output image channels, M is the number of input image channels, "
+           "H is the height of the filter, and W is the width of the filter. "
" + "We enforce groups number == 1 and padding == 0 in " + "the convolution transpose scenario."); + AddOutput("Output", + "(Tensor) The output tensor of convolution transpose operator. " + "The format of output tensor is also NCHW."); + AddAttr>( + "strides", + "(vector defalut:{1, 1}), strides of convolution transpose operator.") + .SetDefault({1, 1}); + AddAttr>( + "paddings", + "(vector defalut:{0, 0}), paddings of convolution transpose operator.") + .SetDefault({0, 0}); + AddComment(R"DOC( +Convolution2D Transpose Operator. + +The convolution transpose operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. + +Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch +size, C is the number of channels, H is the height of the feature, and +W is the width of the feature. Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. +The input(X) size and output(Out) size may be different. +Example: + Input: + Input shape: (N, C_in, H_in, W_in) + Filter shape: (C_in, C_out, H_f, W_f) + Output: + Output shape: (N, C_out, H_out, W_out) + where + H_out = (H_in - 1) * strides[0] - 2 * paddings[0] + filter_size[0]; + W_out = (W_in - 1) * strides[1] - 2 * paddings[1] + filter_size[1]; +)DOC"); +} + +Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( + framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(Tensor) The input tensor of convolution transpose operator." + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the number of channels, D is the depth of the feature, H is the " + "height of the feature, and " + "W is the width of the feature."); + AddInput("Filter", + "(Tensor) The filter tensor of convolution transpose operator." + "The format of the filter tensor is CMDHW, where C is the number of " + "output image channels, M is the number of input image channels, D " + "is the depth of the filter, H is the height of the filter, and " + "W is the width of the filter." + "We enforce groups number == 1 and padding == 0 in " + "the convolution3d transpose scenario."); + AddOutput("Output", + "(Tensor) The output tensor of convolution transpose operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D is the depth of the feature, H is the " + "height of the feature, and W is the width of the feature."); + AddAttr>( + "strides", + "(vector defalut:{1, 1, 1}), strides of convolution transpose operator.") + .SetDefault({1, 1, 1}); + AddAttr>( + "paddings", + "(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.") + .SetDefault({0, 0, 0}); + AddComment(R"DOC( +Convolution3D Transpose Operator. + +The convolution transpose operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. + +Input(Input, Filter) and output(Output) are in NCDHW format. Where N is batch +size, C is the number of channels, D is the depth of the feature, +H is the height of the feature, and W is the width of the feature. +Parameters(ksize, strides, paddings) are three elements. +These three elements represent depth, height and width, respectively. +The input(X) size and output(Out) size may be different. 
+
+Example:
+  Input:
+    Input shape: (N, C_in, D_in, H_in, W_in)
+    Filter shape: (C_in, C_out, D_f, H_f, W_f)
+  Output:
+    Output shape: (N, C_out, D_out, H_out, W_out)
+  where
+    D_out = (D_in - 1) * strides[0] - 2 * paddings[0] + filter_size[0];
+    H_out = (H_in - 1) * strides[1] - 2 * paddings[1] + filter_size[1];
+    W_out = (W_in - 1) * strides[2] - 2 * paddings[2] + filter_size[2];
+)DOC");
+}
+
+void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
+  auto in_dims = ctx->GetInputDim("Input");
+  auto filter_dims = ctx->GetInputDim("Filter");
+  if (ctx->HasOutput(framework::GradVarName("Input"))) {
+    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+  }
+  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp,
+            ops::Conv2DTransposeOpMaker, conv2d_transpose_grad,
+            ops::ConvTransposeOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    conv2d_transpose,
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    conv2d_transpose_grad,
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
+
+REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp,
+            ops::Conv3DTransposeOpMaker, conv3d_transpose_grad,
+            ops::ConvTransposeOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    conv3d_transpose,
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    conv3d_transpose_grad,
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv2d_transpose_op.cu b/paddle/operators/conv_transpose_op.cu
similarity index 63%
rename from paddle/operators/conv2d_transpose_op.cu
rename to paddle/operators/conv_transpose_op.cu
index 931ac9eed294c4fe7c726d8cc2c4d9a39ec12828..401cddb379ced134b800d2a078fe130a2850fbb2 100644
--- a/paddle/operators/conv2d_transpose_op.cu
+++ b/paddle/operators/conv_transpose_op.cu
@@ -12,13 +12,20 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#include "paddle/operators/conv2d_transpose_op.h"
+#include "paddle/operators/conv_transpose_op.h"
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_GPU_KERNEL(
     conv2d_transpose,
-    ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     conv2d_transpose_grad,
-    ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
+
+REGISTER_OP_GPU_KERNEL(
+    conv3d_transpose,
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    conv3d_transpose_grad,
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6c1a6220d784abf89ec789f94d9cff9e5414db04
--- /dev/null
+++ b/paddle/operators/conv_transpose_op.h
@@ -0,0 +1,293 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/im2col.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/vol2col.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+
+// Define Op classes in .h file so that other conv transpose
+// operator implementations can reuse the code.
+class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Conv2DTransposeOpMaker(framework::OpProto* proto,
+                         framework::OpAttrChecker* op_checker);
+};
+
+class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Conv3DTransposeOpMaker(framework::OpProto* proto,
+                         framework::OpAttrChecker* op_checker);
+};
+
+class ConvTransposeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+class ConvTransposeOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+template <typename Place, typename T>
+class GemmConvTransposeKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    // The filter will be reshaped, so it should not be a constant pointer.
+    Tensor filter = *context.Input<Tensor>("Filter");
+    Tensor* output = context.Output<Tensor>("Output");
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    // TODO(Zhuoyuan): Paddings can be added in future.
+    // groups will always be disabled in conv2d_transpose.
+
+    const int batch_size = static_cast<int>(input->dims()[0]);
+
+    // input_shape_vec: {h, w} or {d, h, w}
+    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
+    input_shape_vec.erase(input_shape_vec.begin(),
+                          input_shape_vec.begin() + 2);
+
+    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    std::vector<int64_t> filter_shape_vec =
+        framework::vectorize(filter.dims());
+    filter_shape_vec.erase(filter_shape_vec.begin(),
+                           filter_shape_vec.begin() + 2);
+
+    // use col_shape in the im2col and col2im (or vol2col and col2vol)
+    // calculation
+    // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
+    std::vector<int64_t> col_shape_vec;
+    col_shape_vec.push_back(output->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+                         filter_shape_vec.end());
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+                         input_shape_vec.end());
+    DDim col_shape(framework::make_ddim(col_shape_vec));
+
+    // use col_matrix_shape in the gemm calculation
+    // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
+    DDim col_matrix_shape =
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
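+    // (ShareDataWith below aliases col's buffer; Resize only changes the
+    // shape metadata, so the gemm result lands directly in col's memory.)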
+    Tensor col_matrix;
+    col_matrix.ShareDataWith(col);
+    col_matrix.Resize(col_matrix_shape);
+
+    // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
+    DDim output_shape =
+        framework::slice_ddim(output->dims(), 1, output->dims().size());
+
+    // input matrix size: (m, h * w) or (m, d * h * w)
+    DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};
+
+    // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
+    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
+    filter.Resize(filter_matrix_shape);
+
+    output->mutable_data<T>(context.GetPlace());
+    math::SetConstant<Place, T> set_zero;
+    set_zero(context.device_context(), output, static_cast<T>(0));
+
+    // convolution transpose: gemm + col2im or col2vol (similar to
+    // conv-backward on input)
+    for (int i = 0; i < batch_size; i++) {
+      // batch with size (m, h * w) or (m, d * h * w)
+      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+
+      // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
+      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
+
+      // col_matrix = filter * input_batch
+      // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
+      math::matmul<Place, T>(context.device_context(), filter, true,
+                             input_batch, false, static_cast<T>(1.0),
+                             &col_matrix, static_cast<T>(0.0));
+
+      if (filter_shape_vec.size() == 2) {
+        // col2im: col_matrix -> dy
+        // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
+        math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
+
+        col2im(context.device_context(), output_batch, col, strides[0],
+               strides[1], 0, 0, 0, 0);
+      } else if (filter_shape_vec.size() == 3) {
+        // col2vol: col_matrix -> dy
+        // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
+        math::Col2VolFunctor<Place, T> col2vol;
+        col2vol(context.device_context(), output_batch, col, strides[0],
+                strides[1], strides[2], 0, 0, 0);
+      }
+    }
+  }
+};
+
+template <typename Place, typename T>
+class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+    // For filter, we do not use const pointer b/c we will do reshape,
+    // but we should avoid modifying its value.
+    Tensor filter = *context.Input<Tensor>("Filter");
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad =
+        context.Output<Tensor>(framework::GradVarName("Filter"));
+
+    if ((!input_grad) && (!filter_grad)) return;
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    // Actually, no paddings and groups allowed in conv transpose.
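+    // (InferShape enforces every padding to be zero, so the values read
+    // below are zeros that are simply passed through to im2col/vol2col.)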
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+
+    const int batch_size = static_cast<int>(input->dims()[0]);
+
+    // input_shape_vec: {h, w} or {d, h, w}
+    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
+    input_shape_vec.erase(input_shape_vec.begin(),
+                          input_shape_vec.begin() + 2);
+
+    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    std::vector<int64_t> filter_shape_vec =
+        framework::vectorize(filter.dims());
+    filter_shape_vec.erase(filter_shape_vec.begin(),
+                           filter_shape_vec.begin() + 2);
+
+    // use col_shape in the im2col and col2im (or vol2col and col2vol)
+    // calculation
+    // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
+    std::vector<int64_t> col_shape_vec;
+    col_shape_vec.push_back(output_grad->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+                         filter_shape_vec.end());
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+                         input_shape_vec.end());
+    DDim col_shape(framework::make_ddim(col_shape_vec));
+
+    // use col_matrix_shape in the gemm calculation
+    // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
+    DDim col_matrix_shape =
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+
+    // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
+    DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
+                                              output_grad->dims().size());
+
+    // input matrix size: (m, h * w) or (m, d * h * w)
+    DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};
+
+    // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
+    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
+    filter.Resize(filter_matrix_shape);
+
+    // convolution transpose grad on input:
+    // im2col + gemm (similar to conv-forward)
+    // input need to compute gradient
+    if (input_grad || filter_grad) {
+      Tensor col;
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      // col_matrix shares the same piece of data with col,
+      // but will be reshaped into a two-dimensional matrix shape
+      // to call the matrix multiplication interface.
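+      // (The single col buffer is reused by both gradient paths: the input
+      // gradient multiplies filter x col_matrix, and the filter gradient
+      // multiplies in_batch x col_matrix^T.)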
+      Tensor col_matrix;
+      col_matrix.ShareDataWith(col);
+      col_matrix.Resize(col_matrix_shape);
+
+      Tensor filter_grad_;
+      math::SetConstant<Place, T> set_zero;
+
+      if (input_grad) {
+        input_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), input_grad, static_cast<T>(0));
+      }
+      if (filter_grad) {  // filter size (m, c, k_h, k_w)
+        filter_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), filter_grad, static_cast<T>(0));
+        filter_grad_ = *filter_grad;
+        filter_grad_.Resize(filter_matrix_shape);
+      }
+
+      for (int i = 0; i < batch_size; i++) {
+        // batch with size (c, o_h * o_w)
+        Tensor output_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_shape);
+
+        if (filter_shape_vec.size() == 2) {
+          // im2col: dy -> col matrix
+          // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
+          math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+          im2col(context.device_context(), output_grad_batch, col, strides[0],
+                 strides[1], paddings[0], paddings[0], paddings[1],
+                 paddings[1]);
+        } else if (filter_shape_vec.size() == 3) {
+          // vol2col: dy -> col_matrix
+          // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
+          math::Vol2ColFunctor<Place, T> vol2col;
+          vol2col(context.device_context(), output_grad_batch, col,
+                  strides[0], strides[1], strides[2], paddings[0],
+                  paddings[1], paddings[2]);
+        }
+
+        if (input_grad) {
+          // batch with size (m, h, w)
+          Tensor input_grad_batch =
+              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: dx = filter * dy
+          // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
+          // or
+          // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w)
+          //   -> (m, d * h * w)
+          math::matmul<Place, T>(context.device_context(), filter, false,
+                                 col_matrix, false, static_cast<T>(1.0),
+                                 &input_grad_batch, static_cast<T>(0.0));
+        }
+        if (filter_grad) {
+          // input batch
+          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: d_filter = x * dy^T
+          // (m, h * w) * (h * w, c * k_h * k_w) -> (m, c * k_h * k_w)
+          // or
+          // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w)
+          //   -> (m, c * k_d * k_h * k_w)
+          math::matmul<Place, T>(context.device_context(), in_batch, false,
+                                 col_matrix, true, static_cast<T>(1.0),
+                                 &filter_grad_, static_cast<T>(1.0));
+        }
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py b/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py
index 999a0bdc629010d96a8db31b317ba7a65bf35773..54349c018c4a53b8767d6cd4f94d99c719dc0237 100644
--- a/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py
+++ b/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py
@@ -58,36 +58,37 @@ class TestConv2dTransposeOp(OpTest):
         print 'check output here for', self.op_type
         self.check_output()
 
-    def init_test_case(self):
-        self.pad = [0, 0]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        f_c = self.input_size[1]
-        self.filter_size = [f_c, 6, 3, 3]
-
-    def init_op_type(self):
-        self.op_type = "conv2d_transpose"
-
     def test_check_grad_no_input(self):
         self.check_grad(
             ['Filter'],
             'Output',
-            max_relative_error=0.05,
+            max_relative_error=0.02,
             no_grad_set=set(['Input']))
 
     def test_check_grad_no_filter(self):
         self.check_grad(
             ['Input'],
             'Output',
-            max_relative_error=0.05,
+            max_relative_error=0.02,
             no_grad_set=set(['Filter']))
 
     def test_check_grad(self):
         self.check_grad(
-            set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
+            set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
+
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+    def init_op_type(self):
+        self.op_type = "conv2d_transpose"
+
 
+# ------------ test_cudnn ------------
 class TestCudnn(TestConv2dTransposeOp):
     def init_op_type(self):
         self.op_type = "conv2d_transpose_cudnn"
diff --git a/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py b/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..132fe7931438a30cf02e4ad2894c0838e48ffc9f
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py
@@ -0,0 +1,97 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def conv3dtranspose_forward_naive(input_, filter_, conv3dtranspose_param):
+    # [2, 3, 5, 5, 5]
+    in_n, in_c, in_d, in_h, in_w = input_.shape
+    # [3, 6, 3, 3, 3]
+    f_c, out_c, f_d, f_h, f_w = filter_.shape
+    assert in_c == f_c
+
+    stride, pad = conv3dtranspose_param['stride'], conv3dtranspose_param['pad']
+    out_d = (in_d - 1) * stride[0] + f_d
+    out_h = (in_h - 1) * stride[1] + f_h
+    out_w = (in_w - 1) * stride[2] + f_w
+
+    out = np.zeros((in_n, out_c, out_d, out_h, out_w))
+
+    for n in range(in_n):
+        for d in range(in_d):
+            for i in range(in_h):
+                for j in range(in_w):
+                    input_masked = input_[n, :, d, i, j]  # (c)
+                    input_masked = np.reshape(input_masked, (in_c, 1, 1, 1))
+                    input_masked = np.tile(input_masked, (1, f_d, f_h, f_w))
+
+                    for k in range(out_c):
+                        tmp_out = np.sum(
+                            input_masked * filter_[:, k, :, :, :], axis=0)
+                        d1, d2 = d * stride[0], d * stride[0] + f_d
+                        i1, i2 = i * stride[1], i * stride[1] + f_h
+                        j1, j2 = j * stride[2], j * stride[2] + f_w
+                        out[n, k, d1:d2, i1:i2, j1:j2] += tmp_out
+
+    return out
+
+
+class TestConv3dTransposeOp(OpTest):
+    def setUp(self):
+        # init as conv transpose
+        self.init_op_type()
+
+        # [2, 3, 5, 5, 5] -> kernel [3, 6, 3, 3, 3] -> output [2, 6, 7, 7, 7]
+        self.init_test_case()
+
+        conv3dtranspose_param = {'stride': self.stride, 'pad': self.pad}
+        input_ = np.random.random(self.input_size).astype("float32")
+        filter_ = np.random.random(self.filter_size).astype("float32")
+        output = conv3dtranspose_forward_naive(
+            input_, filter_, conv3dtranspose_param).astype("float32")
+        # print 'deconv output py', output, output.shape
+
+        self.inputs = {'Input': input_, 'Filter': filter_}
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            # 'dilations': self.dilations
+        }
+        self.outputs = {'Output': output}
+
+    def test_check_output(self):
+        print 'check output here'
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
+
+    def test_check_grad_no_filter(self):
+        self.check_grad(
+            ['Input'],
+            'Output',
+            max_relative_error=0.02,
+            no_grad_set=set(['Filter']))
+
+    def test_check_grad_no_input(self):
+        self.check_grad(
+            ['Filter'],
+            'Output',
+            max_relative_error=0.02,
+            no_grad_set=set(['Input']))
+
+    def init_test_case(self):
+        self.pad = [0, 0, 0]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+
+    def init_op_type(self):
+        self.op_type = "conv3d_transpose"
+
+
+if __name__ == '__main__':
+    unittest.main()
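
Note: the naive 3-D reference above takes the scatter view of transposed convolution: every input voxel scatters one filter-sized patch into the output per output channel. The 2-D analogue follows the same pattern; the standalone sketch below is only an illustration of that scatter view (the in-tree 2-D test defines its own conv2dtranspose_forward_naive, and the function name here is hypothetical):

import numpy as np

def conv2dtranspose_forward_sketch(input_, filter_, stride):
    # input_: (N, C_in, H_in, W_in); filter_: (C_in, C_out, H_f, W_f)
    in_n, in_c, in_h, in_w = input_.shape
    f_c, out_c, f_h, f_w = filter_.shape
    assert in_c == f_c
    # same shape rule as the operator's InferShape (padding is enforced to 0)
    out_h = (in_h - 1) * stride[0] + f_h
    out_w = (in_w - 1) * stride[1] + f_w
    out = np.zeros((in_n, out_c, out_h, out_w))
    for n in range(in_n):
        for i in range(in_h):
            for j in range(in_w):
                # every input pixel scatters one (f_h, f_w) patch per channel
                patch = np.tile(input_[n, :, i, j].reshape(in_c, 1, 1),
                                (1, f_h, f_w))
                for k in range(out_c):
                    i1, j1 = i * stride[0], j * stride[1]
                    out[n, k, i1:i1 + f_h, j1:j1 + f_w] += np.sum(
                        patch * filter_[:, k, :, :], axis=0)
    return out

With an input of shape (2, 3, 5, 5), a (3, 6, 3, 3) filter, and unit strides, this yields an output of shape (2, 6, 7, 7), matching H_out = (H_in - 1) * strides[0] + H_f from the operator documentation.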