From eafbbc11a0bb1f347f7917552d46c2944b5f3bb2 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Thu, 26 Oct 2017 10:21:05 +0800
Subject: [PATCH] write conv2d and conv3d together

---
 paddle/operators/CMakeLists.txt               |  11 +-
 paddle/operators/conv2d_op.cc                 | 111 --------
 paddle/operators/conv3d_op.cu                 |  22 --
 paddle/operators/conv3d_op.h                  | 263 ------------------
 paddle/operators/conv_cudnn_op.cc             |   7 +-
 paddle/operators/conv_cudnn_op.cu             |   2 +-
 paddle/operators/{conv3d_op.cc => conv_op.cc} | 100 +++++--
 paddle/operators/{conv2d_op.cu => conv_op.cu} |   7 +-
 paddle/operators/{conv2d_op.h => conv_op.h}   | 224 ++++++++++++++-
 9 files changed, 315 insertions(+), 432 deletions(-)
 delete mode 100644 paddle/operators/conv2d_op.cc
 delete mode 100644 paddle/operators/conv3d_op.cu
 delete mode 100644 paddle/operators/conv3d_op.h
 rename paddle/operators/{conv3d_op.cc => conv_op.cc} (61%)
 rename paddle/operators/{conv2d_op.cu => conv_op.cu} (78%)
 rename paddle/operators/{conv2d_op.h => conv_op.h} (51%)

diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 4d1fb3b96e3..39250480db3 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -69,6 +69,13 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
   endif()
 
+  # conv_op contains several operators
+  if ("${TARGET}" STREQUAL "conv_op")
+    set(pybind_flag 1)
+    # It is enough to add just one operator to pybind
+    file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
+  endif()
+
   # save_restore_op contains several operators
   if ("${TARGET}" STREQUAL "save_restore_op")
     set(pybind_flag 1)
@@ -123,7 +130,7 @@ set(DEPS_OPS
     sum_op
     pool_op
     pool_with_index_op
-    conv3d_op
+    conv_op
     lstm_op)
 
@@ -133,7 +140,7 @@ op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
 op_library(cross_entropy_op DEPS cross_entropy)
 op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(sum_op DEPS net_op)
-op_library(conv3d_op DEPS vol2col)
+op_library(conv_op DEPS vol2col)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc
deleted file mode 100644
index 1acb8415d06..00000000000
--- a/paddle/operators/conv2d_op.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/operators/conv2d_op.h"
-
-namespace paddle {
-namespace operators {
-
-void Conv2DOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(Input) of Conv2DOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Filter"),
-                 "Input(Filter) of Conv2DOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Output"),
-                 "Output(Output) of Conv2DOp should not be null.");
-
-  auto in_dims = ctx->GetInputDim("Input");
-  auto filter_dims = ctx->GetInputDim("Filter");
-  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
-  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
-  int groups = ctx->Attrs().Get<int>("groups");
-  int input_channels = in_dims[1];
-  int output_channels = filter_dims[0];
-
-  PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
-  PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
-  PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
-                    "The number of input channels should be equal to filter "
-                    "channels * groups.");
-  PADDLE_ENFORCE_EQ(
-      output_channels % groups, 0,
-      "The number of output channels should be divided by groups.");
-
-  auto output_height =
-      OutputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
-  auto output_width =
-      OutputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
-  ctx->SetOutputDim("Output",
-                    {in_dims[0], filter_dims[0], output_height, output_width});
-}
-
-Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
-                             framework::OpAttrChecker* op_checker)
-    : OpProtoAndCheckerMaker(proto, op_checker) {
-  AddInput(
-      "Input",
-      "The input tensor of convolution operator. "
-      "The format of input tensor is NCHW. Where N is batch size, C is the "
-      "number of channels, H and W is the height and width of image.");
-  AddInput("Filter",
-           "The filter tensor of convolution operator."
-           "The format of the filter tensor is MCHW, where M is the number of "
-           "output image channels, C is the number of input image channels, "
-           "H and W is height and width of filter. "
-           "If the groups attribute is greater than 1, C equal the number of "
-           "input image channels divided by the groups.");
-  AddOutput("Output",
-            "The output tensor of convolution operator."
-            "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
-      .SetDefault({1, 1});
-  AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
-      .SetDefault({0, 0});
-  AddAttr<int>(
-      "groups",
-      "group size of convolution operator. "
-      "Refer to grouped convolution in Alex Krizhevsky's paper: "
-      "when group=2, the first half of the filters are only connected to the "
-      "first half of the input channels, and the second half only connected "
-      "to the second half.")
-      .SetDefault(1);
-  AddComment(R"DOC(
-The convolution operation calculates the output based on the input, filter
-and strides, paddings, groups parameters. The size of each dimension of the
-parameters is checked in the infer-shape.
-)DOC"); -} - -void Conv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const { - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - if (ctx->HasOutput(framework::GradVarName("Input"))) { - ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); - } - if (ctx->HasOutput(framework::GradVarName("Filter"))) { - ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); - } -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, - ops::Conv2DOpGrad); - -REGISTER_OP_CPU_KERNEL( - conv2d, ops::GemmConv2DKernel); -REGISTER_OP_CPU_KERNEL( - conv2d_grad, ops::GemmConvGrad2DKernel); diff --git a/paddle/operators/conv3d_op.cu b/paddle/operators/conv3d_op.cu deleted file mode 100644 index ec6279f9bbb..00000000000 --- a/paddle/operators/conv3d_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/operators/conv3d_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_GPU_KERNEL( - conv3d, ops::GemmConv3DKernel); -REGISTER_OP_GPU_KERNEL( - conv3d_grad, ops::GemmConvGrad3DKernel); diff --git a/paddle/operators/conv3d_op.h b/paddle/operators/conv3d_op.h deleted file mode 100644 index c5aaf019f3b..00000000000 --- a/paddle/operators/conv3d_op.h +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-
-#pragma once
-
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/math_function.h"
-#include "paddle/operators/math/vol2col.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-class Conv3DOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override;
-};
-
-class Conv3DOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override;
-};
-
-class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  Conv3DOpMaker(framework::OpProto* proto,
-                framework::OpAttrChecker* op_checker);
-};
-
-template <typename Place, typename T>
-class GemmConv3DKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    // The filter will be reshaped in the calculations,
-    // so here use an assignment operation,
-    // that avoids modifying the variable in the Scope.
-    Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor* output = context.Output<Tensor>("Output");
-    output->mutable_data<T>(context.GetPlace());
-
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int groups = context.Attr<int>("groups");
-
-    int batch_size = input->dims()[0];
-    int input_channels = input->dims()[1];
-    int filter_depth = filter.dims()[filter.dims().size() - 3];
-    int filter_height = filter.dims()[filter.dims().size() - 2];
-    int filter_width = filter.dims()[filter.dims().size() - 1];
-    int output_channels = output->dims()[1];
-    int output_depth = output->dims()[2];
-    int output_height = output->dims()[3];
-    int output_width = output->dims()[4];
-
-    paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
-    // use col_shape in the vol2col calculation
-    framework::DDim col_shape = {input_channels / groups,
-                                 filter_depth,
-                                 filter_height,
-                                 filter_width,
-                                 output_depth,
-                                 output_height,
-                                 output_width};
-    // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {
-        input_channels / groups * filter_depth * filter_height * filter_width,
-        output_depth * output_height * output_width};
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
-    col_matrix.Resize(col_matrix_shape);
-
-    framework::DDim input_shape = {
-        input->dims()[1], input->dims()[2], input->dims()[3],
-        input->dims()[4]};  // channel, depth, height, width
-    framework::DDim filter_matrix_shape = {
-        filter.dims()[0],
-        filter.numel() / filter.dims()[0]};  // filter_out_channel,
-    // filter_in_channel*filter_depth*filter_height*filter_width
-    filter.Resize(filter_matrix_shape);
-
-    framework::DDim output_matrix_shape = {
-        output_channels, output_depth * output_height * output_width};
-
-    // convolution operator: vol2col + gemm
-    int in_step = input_channels / groups;
-    int out_step = output_channels / groups;
-    for (int i = 0; i < batch_size; i++) {
-      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-      for (int g = 0; g < groups; g++) {
-        // vol2col
-        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-        vol2col(context.device_context(), in_slice, col, strides[0], strides[1],
-                strides[2], paddings[0], paddings[1], paddings[2]);
-
-        // gemm
-        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-        math::matmul<Place, T>(context.device_context(), filter_slice, false,
-                               col_matrix, false, T(1.0), &out_slice, T(0.0));
-      }
-    }
-  }
-};
-
-template <typename Place, typename T>
-class GemmConvGrad3DKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* output_grad =
-        context.Input<Tensor>(framework::GradVarName("Output"));
-    Tensor* input_grad =
-        context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad =
-        context.Output<Tensor>(framework::GradVarName("Filter"));
-
-    // The filter and filter_grad will be reshaped in the calculations,
-    // so here use an assignment operation,
-    // that avoids modifying the variable in the Scope.
-    Tensor filter = *context.Input<Tensor>("Filter");
-
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int groups = context.Attr<int>("groups");
-
-    int batch_size = input->dims()[0];
-    int input_channels = input->dims()[1];
-    int filter_depth = filter.dims()[filter.dims().size() - 3];
-    int filter_height = filter.dims()[filter.dims().size() - 2];
-    int filter_width = filter.dims()[filter.dims().size() - 1];
-    int output_channels = output_grad->dims()[1];
-    int output_depth = output_grad->dims()[2];
-    int output_height = output_grad->dims()[3];
-    int output_width = output_grad->dims()[4];
-
-    paddle::operators::math::Col2VolFunctor<Place, T> col2vol;
-    paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
-    // use col_shape in the vol2col and col2vol calculation
-    framework::DDim col_shape = {input_channels / groups,
-                                 filter_depth,
-                                 filter_height,
-                                 filter_width,
-                                 output_depth,
-                                 output_height,
-                                 output_width};
-    // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {
-        input_channels / groups * filter_depth * filter_height * filter_width,
-        output_depth * output_height * output_width};
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
-    col_matrix.Resize(col_matrix_shape);
-
-    framework::DDim input_shape = {
-        input->dims()[1], input->dims()[2], input->dims()[3],
-        input->dims()[4]};  // channel, depth, height, width
-    framework::DDim output_matrix_shape = {output_grad->dims()[1],
-                                           output_grad->dims()[2] *
-                                               output_grad->dims()[3] *
-                                               output_grad->dims()[4]};
-
-    framework::DDim filter_matrix_shape = {
-        filter.dims()[0],
-        filter.numel() / filter.dims()[0]};  // filter_out_channel,
-    // filter_in_channel*filter_depth*filter_height*filter_width
-    filter.Resize(filter_matrix_shape);
-
-    // convolution backward input operator: gemm + col2vol
-    // convolution backward weight operator: vol2col + gemm
-    int in_step = input_channels / groups;
-    int out_step = output_channels / groups;
-
-    if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*input_grad);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; i++) {
-        Tensor out_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-        Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
-        for (int g = 0; g < groups; g++) {
-          // gemm
-          Tensor out_grad_slice =
-              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul<Place, T>(context.device_context(), filter_slice, true,
-                                 out_grad_slice, false, T(1.0), &col_matrix,
-                                 T(0.0));
-
-          // col2vol
-          Tensor in_grad_slice =
-              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
-          col2vol(context.device_context(), in_grad_slice, col, strides[0],
-                  strides[1], strides[2], paddings[0], paddings[1],
-                  paddings[2]);
-        }
-      }
-    }
-
-    if (filter_grad) {
-      filter_grad->mutable_data<T>(context.GetPlace());
-      Tensor filter_grad_ = *filter_grad;
-      filter_grad_.Resize(filter_matrix_shape);
-      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; i++) {
-        Tensor out_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-        for (int g = 0; g < groups; g++) {
-          // vol2col
-          Tensor out_grad_slice =
-              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-          vol2col(context.device_context(), in_slice, col, strides[0],
-                  strides[1], strides[2], paddings[0], paddings[1],
-                  paddings[2]);
-
-          // gemm
-          Tensor filter_grad_slice =
-              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul<Place, T>(context.device_context(), out_grad_slice,
-                                 false, col_matrix, true, T(1.0),
-                                 &filter_grad_slice, T(1.0));
-        }
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc
index 4288f300dd5..37bba3a1a1e 100644
--- a/paddle/operators/conv_cudnn_op.cc
+++ b/paddle/operators/conv_cudnn_op.cc
@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/conv2d_op.h"
+#include "paddle/operators/conv_op.h"
 
 namespace paddle {
 namespace operators {
@@ -38,8 +38,9 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(conv_cudnn, ops::Conv2DOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
-            ops::Conv2DOpGrad);
+REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
+            ops::ConvOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     conv_cudnn, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu
index 366d0323b84..e34d5937407 100644
--- a/paddle/operators/conv_cudnn_op.cu
+++ b/paddle/operators/conv_cudnn_op.cu
@@ -15,7 +15,7 @@
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/memory/memory.h"
-#include "paddle/operators/conv2d_op.h"
+#include "paddle/operators/conv_op.h"
 #include "paddle/platform/assert.h"
 #include "paddle/platform/cudnn_helper.h"
 
diff --git a/paddle/operators/conv3d_op.cc b/paddle/operators/conv_op.cc
similarity index 61%
rename from paddle/operators/conv3d_op.cc
rename to paddle/operators/conv_op.cc
index fb3f1265f3a..5e264d730c4 100644
--- a/paddle/operators/conv3d_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -12,23 +12,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/conv3d_op.h"
+#include "paddle/operators/conv_op.h"
 
 namespace paddle {
 namespace operators {
 
-int OutputSizeConv3d(int input_size, int filter_size, int padding, int stride) {
-  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
-  return output_size;
-}
-
-void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const {
+void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(Input) of Conv3DOp should not be null.");
+                 "Input(Input) of ConvOp should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("Filter"),
-                 "Input(Filter) of Conv3DOp should not be null.");
+                 "Input(Filter) of ConvOp should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Output"),
-                 "Output(Output) of Conv3DOp should not be null.");
+                 "Output(Output) of ConvOp should not be null.");
 
   auto in_dims = ctx->GetInputDim("Input");
   auto filter_dims = ctx->GetInputDim("Filter");
@@ -38,33 +33,65 @@ void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const {
   int input_channels = in_dims[1];
   int output_channels = filter_dims[0];
 
-  PADDLE_ENFORCE_EQ(in_dims.size(), 5, "Conv3DOp input should be 5-D tensor.");
-  PADDLE_ENFORCE_EQ(filter_dims.size(), 5,
-                    "Conv3DOp filter should be 5-D tensor.");
+  PADDLE_ENFORCE_EQ(
+      in_dims.size(), filter_dims.size(),
+      "Conv input dimension and filter dimension should be the same.");
+  PADDLE_ENFORCE(
+      in_dims.size() - strides.size() == 2U,
+      "Conv input dimension and strides dimension should be consistent.");
+  PADDLE_ENFORCE_EQ(
+      paddings.size(), strides.size(),
+      "Conv paddings dimension and Conv strides dimension should be the same.");
   PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
                     "The number of input channels should be equal to filter "
-                    "(channels * groups).");
+                    "channels * groups.");
   PADDLE_ENFORCE_EQ(
      output_channels % groups, 0,
      "The number of output channels should be divided by groups.");
 
   std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
   for (size_t i = 0; i < paddings.size(); ++i) {
-    output_shape.push_back(OutputSizeConv3d(in_dims[i + 2], filter_dims[i + 2],
-                                            paddings[i], strides[i]));
+    output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
+                                      paddings[i], strides[i]));
   }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }
 
-void Conv3DOpGrad::InferShape(framework::InferShapeContext* ctx) const {
-  auto in_dims = ctx->GetInputDim("Input");
-  auto filter_dims = ctx->GetInputDim("Filter");
-  if (ctx->HasOutput(framework::GradVarName("Input"))) {
-    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
-  }
-  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
-    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
-  }
+Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
+                             framework::OpAttrChecker* op_checker)
+    : OpProtoAndCheckerMaker(proto, op_checker) {
+  AddInput(
+      "Input",
+      "The input tensor of convolution operator. "
+      "The format of input tensor is NCHW. Where N is batch size, C is the "
+      "number of channels, H and W is the height and width of image.");
+  AddInput("Filter",
+           "The filter tensor of convolution operator."
+           "The format of the filter tensor is MCHW, where M is the number of "
+           "output image channels, C is the number of input image channels, "
+           "H and W is height and width of filter. "
+           "If the groups attribute is greater than 1, C equal the number of "
+           "input image channels divided by the groups.");
+  AddOutput("Output",
+            "The output tensor of convolution operator."
+            "The format of output tensor is also NCHW.");
+  AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
+      .SetDefault({1, 1});
+  AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
+      .SetDefault({0, 0});
+  AddAttr<int>(
+      "groups",
+      "group size of convolution operator. "
+      "Refer to grouped convolution in Alex Krizhevsky's paper: "
+      "when group=2, the first half of the filters are only connected to the "
+      "first half of the input channels, and the second half only connected "
+      "to the second half.")
+      .SetDefault(1);
+  AddComment(R"DOC(
+The convolution operation calculates the output based on the input, filter
+and strides, paddings, groups parameters. The size of each dimension of the
+parameters is checked in the infer-shape.
+)DOC"); } Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, @@ -125,12 +152,31 @@ Example: )DOC"); } +void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv3d, ops::Conv3DOp, ops::Conv3DOpMaker, conv3d_grad, - ops::Conv3DOpGrad); +REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, + ops::ConvOpGrad); +namespace ops = paddle::operators; +REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, + ops::ConvOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2d, ops::GemmConv2DKernel); +REGISTER_OP_CPU_KERNEL( + conv2d_grad, ops::GemmConvGrad2DKernel); REGISTER_OP_CPU_KERNEL( conv3d, ops::GemmConv3DKernel); diff --git a/paddle/operators/conv2d_op.cu b/paddle/operators/conv_op.cu similarity index 78% rename from paddle/operators/conv2d_op.cu rename to paddle/operators/conv_op.cu index c697c9466d3..d8c0bd9326b 100644 --- a/paddle/operators/conv2d_op.cu +++ b/paddle/operators/conv_op.cu @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv2d_op.h" +#include "paddle/operators/conv_op.h" namespace ops = paddle::operators; @@ -20,3 +20,8 @@ REGISTER_OP_GPU_KERNEL( conv2d, ops::GemmConv2DKernel); REGISTER_OP_GPU_KERNEL( conv2d_grad, ops::GemmConvGrad2DKernel); + +REGISTER_OP_GPU_KERNEL( + conv3d, ops::GemmConv3DKernel); +REGISTER_OP_GPU_KERNEL( + conv3d_grad, ops::GemmConvGrad3DKernel); diff --git a/paddle/operators/conv2d_op.h b/paddle/operators/conv_op.h similarity index 51% rename from paddle/operators/conv2d_op.h rename to paddle/operators/conv_op.h index 0621389a79e..e39b1ffeb6d 100644 --- a/paddle/operators/conv2d_op.h +++ b/paddle/operators/conv_op.h @@ -18,6 +18,7 @@ limitations under the License. 
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/im2col.h"
 #include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/vol2col.h"
 
 namespace paddle {
 namespace operators {
@@ -40,14 +41,20 @@ class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
                 framework::OpAttrChecker* op_checker);
 };
 
-class Conv2DOp : public framework::OperatorWithKernel {
+class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Conv3DOpMaker(framework::OpProto* proto,
+                framework::OpAttrChecker* op_checker);
+};
+
+class ConvOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override;
 };
 
-class Conv2DOpGrad : public framework::OperatorWithKernel {
+class ConvOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
@@ -251,5 +258,218 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
   }
 };
 
+template <typename Place, typename T>
+class GemmConv3DKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    // The filter will be reshaped in the calculations,
+    // so here use an assignment operation,
+    // that avoids modifying the variable in the Scope.
+    Tensor filter = *context.Input<Tensor>("Filter");
+    Tensor* output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    int groups = context.Attr<int>("groups");
+
+    int batch_size = input->dims()[0];
+    int input_channels = input->dims()[1];
+    int filter_depth = filter.dims()[filter.dims().size() - 3];
+    int filter_height = filter.dims()[filter.dims().size() - 2];
+    int filter_width = filter.dims()[filter.dims().size() - 1];
+    int output_channels = output->dims()[1];
+    int output_depth = output->dims()[2];
+    int output_height = output->dims()[3];
+    int output_width = output->dims()[4];
+
+    paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
+    // use col_shape in the vol2col calculation
+    framework::DDim col_shape = {input_channels / groups,
+                                 filter_depth,
+                                 filter_height,
+                                 filter_width,
+                                 output_depth,
+                                 output_height,
+                                 output_width};
+    // use col_matrix_shape in the gemm calculation
+    framework::DDim col_matrix_shape = {
+        input_channels / groups * filter_depth * filter_height * filter_width,
+        output_depth * output_height * output_width};
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
+    Tensor col_matrix = col;
+    col_matrix.Resize(col_matrix_shape);
+
+    framework::DDim input_shape = {
+        input->dims()[1], input->dims()[2], input->dims()[3],
+        input->dims()[4]};  // channel, depth, height, width
+    framework::DDim filter_matrix_shape = {
+        filter.dims()[0],
+        filter.numel() / filter.dims()[0]};  // filter_out_channel,
+    // filter_in_channel*filter_depth*filter_height*filter_width
+    filter.Resize(filter_matrix_shape);
+
+    framework::DDim output_matrix_shape = {
+        output_channels, output_depth * output_height * output_width};
+
+    // convolution operator: vol2col + gemm
+    int in_step = input_channels / groups;
+    int out_step = output_channels / groups;
+    for (int i = 0; i < batch_size; i++) {
+      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+      for (int g = 0; g < groups; g++) {
+        // vol2col
+        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+        vol2col(context.device_context(), in_slice, col, strides[0], strides[1],
+                strides[2], paddings[0], paddings[1], paddings[2]);
+
+        // gemm
+        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
+        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+        math::matmul<Place, T>(context.device_context(), filter_slice, false,
+                               col_matrix, false, T(1.0), &out_slice, T(0.0));
+      }
+    }
+  }
+};
+
+template <typename Place, typename T>
+class GemmConvGrad3DKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad =
+        context.Output<Tensor>(framework::GradVarName("Filter"));
+
+    // The filter and filter_grad will be reshaped in the calculations,
+    // so here use an assignment operation,
+    // that avoids modifying the variable in the Scope.
+    Tensor filter = *context.Input<Tensor>("Filter");
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    int groups = context.Attr<int>("groups");
+
+    int batch_size = input->dims()[0];
+    int input_channels = input->dims()[1];
+    int filter_depth = filter.dims()[filter.dims().size() - 3];
+    int filter_height = filter.dims()[filter.dims().size() - 2];
+    int filter_width = filter.dims()[filter.dims().size() - 1];
+    int output_channels = output_grad->dims()[1];
+    int output_depth = output_grad->dims()[2];
+    int output_height = output_grad->dims()[3];
+    int output_width = output_grad->dims()[4];
+
+    paddle::operators::math::Col2VolFunctor<Place, T> col2vol;
+    paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
+    // use col_shape in the vol2col and col2vol calculation
+    framework::DDim col_shape = {input_channels / groups,
+                                 filter_depth,
+                                 filter_height,
+                                 filter_width,
+                                 output_depth,
+                                 output_height,
+                                 output_width};
+    // use col_matrix_shape in the gemm calculation
+    framework::DDim col_matrix_shape = {
+        input_channels / groups * filter_depth * filter_height * filter_width,
+        output_depth * output_height * output_width};
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
+    Tensor col_matrix = col;
+    col_matrix.Resize(col_matrix_shape);
+
+    framework::DDim input_shape = {
+        input->dims()[1], input->dims()[2], input->dims()[3],
+        input->dims()[4]};  // channel, depth, height, width
+    framework::DDim output_matrix_shape = {output_grad->dims()[1],
+                                           output_grad->dims()[2] *
+                                               output_grad->dims()[3] *
+                                               output_grad->dims()[4]};
+
+    framework::DDim filter_matrix_shape = {
+        filter.dims()[0],
+        filter.numel() / filter.dims()[0]};  // filter_out_channel,
+    // filter_in_channel*filter_depth*filter_height*filter_width
+    filter.Resize(filter_matrix_shape);
+
+    // convolution backward input operator: gemm + col2vol
+    // convolution backward weight operator: vol2col + gemm
+    int in_step = input_channels / groups;
+    int out_step = output_channels / groups;
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
+          // gemm
+          Tensor out_grad_slice =
+              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+          math::matmul<Place, T>(context.device_context(), filter_slice, true,
+                                 out_grad_slice, false, T(1.0), &col_matrix,
+                                 T(0.0));
+
+          // col2vol
+          Tensor in_grad_slice =
+              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
+          col2vol(context.device_context(), in_grad_slice, col, strides[0],
+                  strides[1], strides[2], paddings[0], paddings[1],
+                  paddings[2]);
+        }
+      }
+    }
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
+          // vol2col
+          Tensor out_grad_slice =
+              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+          vol2col(context.device_context(), in_slice, col, strides[0],
+                  strides[1], strides[2], paddings[0], paddings[1],
+                  paddings[2]);
+
+          // gemm
+          Tensor filter_grad_slice =
+              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
+          math::matmul<Place, T>(context.device_context(), out_grad_slice,
+                                 false, col_matrix, true, T(1.0),
+                                 &filter_grad_slice, T(1.0));
+        }
+      }
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
-- 
GitLab
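
The removed OutputSizeConv3d helper and the OutputSize call in the unified ConvOp::InferShape use the same arithmetic per spatial axis: output = (input_size - filter_size + 2 * padding) / stride + 1. Because InferShape loops over paddings.size() trailing axes, one function covers both NCHW (two spatial axes) and NCDHW (three). A minimal standalone sketch of that rule (not part of the patch; names and values below are illustrative):

#include <cstdio>
#include <vector>

// Same rule as the OutputSize()/OutputSizeConv3d() helpers in the patch:
// one spatial dimension of the convolution output.
int ConvOutputSize(int input_size, int filter_size, int padding, int stride) {
  return (input_size - filter_size + 2 * padding) / stride + 1;
}

int main() {
  // conv2d-style axis: 32-wide input, 3-wide filter, padding 1, stride 1
  // -> (32 - 3 + 2) / 1 + 1 = 32, i.e. "same" spatial size.
  std::printf("2d axis: %d\n", ConvOutputSize(32, 3, 1, 1));

  // conv3d-style depth axis: depth 16, filter depth 3, padding 0, stride 2
  // -> (16 - 3 + 0) / 2 + 1 = 7.
  std::printf("3d depth axis: %d\n", ConvOutputSize(16, 3, 0, 2));

  // The unified InferShape applies the rule once per trailing spatial axis,
  // exactly like the loop over paddings.size() in conv_op.cc.
  std::vector<int> in = {16, 32, 32}, filter = {3, 3, 3};
  for (size_t i = 0; i < in.size(); ++i)
    std::printf("axis %zu: %d\n", i, ConvOutputSize(in[i], filter[i], 0, 1));
  return 0;
}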
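
Likewise, the per-group GEMM bookkeeping in GemmConv3DKernel (col_matrix sharing col's storage, the filter reshape, and the in_step/out_step group slicing) can be checked in isolation. A small sketch with illustrative values (not from the patch) of the matrix shapes that vol2col + gemm produce:

#include <cstdio>

int main() {
  // Illustrative configuration: conv3d with two groups.
  int input_channels = 4, output_channels = 8, groups = 2;
  int filter_d = 3, filter_h = 3, filter_w = 3;
  int output_d = 7, output_h = 14, output_w = 14;

  int in_step = input_channels / groups;    // input channels per group
  int out_step = output_channels / groups;  // output channels per group

  // col_matrix: [C/g * kD * kH * kW, oD * oH * oW], matching
  // col_matrix_shape in the kernels.
  int k = in_step * filter_d * filter_h * filter_w;
  int n = output_d * output_h * output_w;

  // filter_slice is [out_step, k] after filter.Resize(filter_matrix_shape),
  // so each group's forward GEMM is out_slice = filter_slice * col_matrix.
  std::printf("gemm per group: [%d x %d] * [%d x %d] -> [%d x %d]\n",
              out_step, k, k, n, out_step, n);
  return 0;
}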