diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 1ca4ba29d7f1b5e4aeecf7d352f68c1717f288a4..91028877b61418d3a851b7f78492b7cc2f940d14 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -123,7 +123,8 @@ set(DEPS_OPS
     sum_op
     pool_op
     pool_with_index_op
-    lstm_op)
+    lstm_op
+    conv3dtranspose_op)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
     DEPS framework_proto tensor net_op)
@@ -135,6 +136,7 @@ op_library(sum_op DEPS net_op)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
+op_library(conv3dtranspose_op DEPS vol2col)
 
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
diff --git a/paddle/operators/conv3dtranspose_op.cc b/paddle/operators/conv3dtranspose_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f830e98f1bcabf72899507a02384c4b6edbef4b0
--- /dev/null
+++ b/paddle/operators/conv3dtranspose_op.cc
@@ -0,0 +1,113 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/conv3dtranspose_op.h"
+
+namespace paddle {
+namespace operators {
+
+void Conv3DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("Input"),
+                 "Input(Input) of Conv3DTransposeOp should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("Filter"),
+                 "Input(Filter) of Conv3DTransposeOp should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Output"),
+                 "Output(Output) of Conv3DTransposeOp should not be null.");
+
+  auto in_dims = ctx->GetInputDim("Input");
+  auto filter_dims = ctx->GetInputDim("Filter");
+  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
+  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    PADDLE_ENFORCE_EQ(paddings[i], 0,
+                      "No padding allowed in conv transpose op.");
+  }
+
+  PADDLE_ENFORCE_EQ(in_dims.size(), 5,
+                    "Conv3DTransposeOp input should be 5-D tensor.");
+  PADDLE_ENFORCE_EQ(filter_dims.size(), 5,
+                    "Conv3DTransposeOp filter should be 5-D tensor.");
+  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
+                    "input and kernel input dimension should be equal.");
+
+  std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
+  for (size_t i = 0; i < strides.size(); ++i) {
+    output_shape.push_back((in_dims[i + 2] - 1) * strides[i] +
+                           filter_dims[i + 2]);
+  }
+  ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
+}
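+
+// Worked example of the shape rule above (illustrative numbers, not taken
+// from this patch): one spatial dimension with input size 4, stride 2, and
+// filter size 3 gives, under the zero padding enforced here,
+//   (4 - 1) * 2 + 3 = 9
+// output elements, i.e. output = (input - 1) * stride + filter.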
+ "The format of the filter tensor is CMDHW, where C is the number of " + "output image channels, M is the number of input image channels, " + "D, H and W is depth, height and width of filter. " + "We enforce groups number == 1 and padding == 0 in " + "convolution transpose Scenario."); + AddOutput("Output", + "(Tensor) The output tensor of convolution transpose operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of feature."); + AddAttr>("strides", + "strides of convolution transpose operator.") + .SetDefault({1, 1, 1}); + AddAttr>("paddings", + "paddings of convolution transpose operator.") + .SetDefault({0, 0, 0}); + AddComment(R"DOC( +The convolution transpose operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +)DOC"); +} + +void Conv3DTransposeOpGrad::InferShape( + framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv3dtranspose, ops::Conv3DTransposeOp, + ops::Conv3DTransposeOpMaker, conv3dtranspose_grad, + ops::Conv3DTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv3dtranspose, + ops::GemmConv3DTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv3dtranspose_grad, + ops::GemmConv3DTransposeGradKernel); diff --git a/paddle/operators/conv3dtranspose_op.cu b/paddle/operators/conv3dtranspose_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..447646fd756d3f6d3b8429c91dbee2643a42c87b --- /dev/null +++ b/paddle/operators/conv3dtranspose_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv3dtranspose_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + conv3dtranspose, + ops::GemmConv3DTransposeKernel); +REGISTER_OP_GPU_KERNEL( + conv3dtranspose_grad, + ops::GemmConv3DTransposeGradKernel); diff --git a/paddle/operators/conv3dtranspose_op.h b/paddle/operators/conv3dtranspose_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fbab12731420b3756485c1f5d663a85604382ed9 --- /dev/null +++ b/paddle/operators/conv3dtranspose_op.h @@ -0,0 +1,259 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/vol2col.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+
+// Define Op classes in the .h file so that other conv transpose
+// operator implementations can reuse the code.
+class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Conv3DTransposeOpMaker(framework::OpProto* proto,
+                         framework::OpAttrChecker* op_checker);
+};
+
+class Conv3DTransposeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+class Conv3DTransposeOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+template <typename Place, typename T>
+class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    // The filter will be reshaped, so it should not be a constant pointer.
+    Tensor filter = *context.Input<Tensor>("Filter");
+
+    Tensor* output = context.Output<Tensor>("Output");
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+
+    // TODO(chengduo): Paddings can be added in future.
+    // groups will always be disabled in conv3dtranspose.
+
+    const int batch_size = input->dims()[0];
+    const int m = input->dims()[1];
+    const int d = input->dims()[2];
+    const int h = input->dims()[3];
+    const int w = input->dims()[4];
+
+    const int k_d = filter.dims()[2];
+    const int k_h = filter.dims()[3];
+    const int k_w = filter.dims()[4];
+
+    const int c = output->dims()[1];  // output channels
+    const int o_d = output->dims()[2];
+    const int o_h = output->dims()[3];
+    const int o_w = output->dims()[4];
+
+    paddle::operators::math::Col2VolFunctor<Place, T> col2vol;
+
+    // use col_shape in the vol2col and col2vol calculation
+    DDim col_shape = {c, k_d, k_h, k_w, d, h, w};
+
+    // use col_matrix_shape in the gemm calculation
+    DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
+
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix below shares the same piece of data with col,
+    // but is reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
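+    // A shape walk-through of one batch element, with illustrative sizes
+    // (not taken from this patch): m = 2, c = 3, (k_d, k_h, k_w) = (3, 3, 3)
+    // and (d, h, w) = (4, 4, 4). The gemm below computes
+    //   filter^T (81, 2) x input_batch (2, 64) = col_matrix (81, 64),
+    // i.e. (c * k_d * k_h * k_w, d * h * w), which col2vol then scatters
+    // into the (c, o_d, o_h, o_w) output volume.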
+    Tensor col_matrix;
+    col_matrix.ShareDataWith(col);
+    col_matrix.Resize(col_matrix_shape);
+
+    DDim output_shape = {c, o_d, o_h, o_w};
+    DDim input_matrix_shape = {m, d * h * w};
+
+    DDim filter_matrix_shape = {m, c * k_d * k_h * k_w};
+    filter.Resize(filter_matrix_shape);
+
+    // convolution transpose: gemm + col2vol
+    // (similar to conv-backward on input)
+
+    output->mutable_data<T>(context.GetPlace());
+    auto t = framework::EigenVector<T>::Flatten(*output);
+    t.device(context.GetEigenDevice<Place>()) =
+        t.constant(static_cast<T>(0));
+
+    for (int i = 0; i < batch_size; i++) {
+      // batch with size (m, d * h * w)
+      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+      // filter size: (m, c * k_d * k_h * k_w)
+
+      // output size: (c, o_d, o_h, o_w)
+      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
+
+      // col_matrix = filter^T * input_batch,
+      // of shape (c * k_d * k_h * k_w, d * h * w)
+      math::matmul<Place, T>(context.device_context(), filter, true,
+                             input_batch, false, T(1.0), &col_matrix, T(0.0));
+      col2vol(context.device_context(), output_batch, col, strides[0],
+              strides[1], strides[2], 0, 0, 0);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+
+    // For filter, we do not use a const pointer because we will reshape it,
+    // but we should avoid modifying its value.
+    Tensor filter = *context.Input<Tensor>("Filter");
+
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad =
+        context.Output<Tensor>(framework::GradVarName("Filter"));
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    // Actually, no paddings or groups are allowed in conv transpose.
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+
+    const int batch_size = input->dims()[0];
+    const int m = input->dims()[1];
+    const int d = input->dims()[2];
+    const int h = input->dims()[3];
+    const int w = input->dims()[4];
+
+    const int k_d = filter.dims()[2];
+    const int k_h = filter.dims()[3];
+    const int k_w = filter.dims()[4];
+
+    const int c = output_grad->dims()[1];  // output channels
+    const int o_d = output_grad->dims()[2];
+    const int o_h = output_grad->dims()[3];
+    const int o_w = output_grad->dims()[4];
+
+    // Only the vol2col functor is required for backprop to get to the
+    // right shape.
+    paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
+
+    // use col_shape in the vol2col and col2vol calculation
+    DDim col_shape = {c, k_d, k_h, k_w, d, h, w};
+
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // In each gradient branch below, col is reshaped (via ShareDataWith)
+    // into a two-dimensional matrix shape to call the matrix multiplication
+    // interface.
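+    // Backward sketch (a restatement of the comments in this kernel, with
+    // the shapes assumed above): since conv transpose forward is conv
+    // backward-on-input, its gradients reuse the conv *forward* lowering.
+    // vol2col expands the output gradient (c, o_d, o_h, o_w) into
+    // col (c * k_d * k_h * k_w, d * h * w); one gemm against the filter
+    // yields d(input), and one gemm against the input yields d(filter).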
+
+    DDim output_shape = {c, o_d, o_h, o_w};
+    DDim input_matrix_shape = {m, d * h * w};
+
+    DDim filter_matrix_shape = {m, c * k_d * k_h * k_w};
+    filter.Resize(filter_matrix_shape);
+
+    // convolution transpose grad on input:
+    // vol2col + gemm (similar to conv-forward)
+    // input needs to compute gradient
+    if (input_grad) {
+      Tensor col_matrix;
+      col_matrix.ShareDataWith(col);
+      DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
+      col_matrix.Resize(col_matrix_shape);
+
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) =
+          t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        // batch with size (c, o_d, o_h, o_w)
+        Tensor output_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_shape);
+        // filter of size (m, c * k_d * k_h * k_w)
+
+        // batch with size (m, d * h * w)
+        Tensor input_grad_batch =
+            input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
+
+        // vol2col: dy from (c, o_d, o_h, o_w)
+        // -> (c * k_d * k_h * k_w, d * h * w)
+        vol2col(context.device_context(), output_grad_batch, col, strides[0],
+                strides[1], strides[2], paddings[0], paddings[1],
+                paddings[2]);
+
+        // gemm: dx = filter * dy
+        // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w)
+        // -> (m, d * h * w)
+        math::matmul<Place, T>(context.device_context(), filter, false,
+                               col_matrix, false, T(1.0), &input_grad_batch,
+                               T(0.0));
+      }
+    }
+
+    // filter gradient required
+    if (filter_grad) {
+      Tensor col_matrix_f;
+      col_matrix_f.ShareDataWith(col);
+      DDim col_matrix_shape_f = {c * k_d * k_h * k_w, d * h * w};
+      col_matrix_f.Resize(col_matrix_shape_f);
+
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) =
+          t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; ++i) {
+        // batch with size (c, o_d, o_h, o_w)
+        Tensor output_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_shape);
+        // input batch
+        Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+
+        // vol2col: dy from (c, o_d, o_h, o_w)
+        // -> (c * k_d * k_h * k_w, d * h * w)
+        vol2col(context.device_context(), output_grad_batch, col, strides[0],
+                strides[1], strides[2], paddings[0], paddings[1],
+                paddings[2]);
+
+        // gemm: d_filter = x * y_grad^T
+        // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w)
+        // -> (m, c * k_d * k_h * k_w)
+        math::matmul<Place, T>(context.device_context(), in_batch, false,
+                               col_matrix_f, true, T(1.0), &filter_grad_,
+                               T(1.0));
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle