From 96b4035dd132d419f463bd0341baa2c4a773b8b6 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 10 Oct 2017 16:08:23 +0800 Subject: [PATCH] Add conv3d_gemm_op --- paddle/operators/CMakeLists.txt | 5 +- paddle/operators/conv3d_op.cc | 117 +++++++++++++++ paddle/operators/conv3d_op.cu | 22 +++ paddle/operators/conv3d_op.h | 259 ++++++++++++++++++++++++++++++++ 4 files changed, 402 insertions(+), 1 deletion(-) create mode 100644 paddle/operators/conv3d_op.cc create mode 100644 paddle/operators/conv3d_op.cu create mode 100644 paddle/operators/conv3d_op.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7dae8fe2f9..576cd2530d 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -112,7 +112,8 @@ set(DEPS_OPS cond_op cross_entropy_op softmax_with_cross_entropy_op - sum_op) + sum_op + conv3d_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc @@ -121,6 +122,8 @@ op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) +op_library(conv3d_op DEPS vol2col) + list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/conv3d_op.cc b/paddle/operators/conv3d_op.cc new file mode 100644 index 0000000000..2b34a2671d --- /dev/null +++ b/paddle/operators/conv3d_op.cc @@ -0,0 +1,117 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/conv3d_op.h" + +namespace paddle { +namespace operators { + +int OutputSizeConv3d(int input_size, int filter_size, int padding, int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of Conv3DOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of Conv3DOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of Conv3DOp should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + int groups = ctx->Attrs().Get("groups"); + int input_channels = in_dims[1]; + int output_channels = filter_dims[0]; + + PADDLE_ENFORCE_EQ(in_dims.size(), 5, "Conv3DOp input should be 5-D."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 5, "Conv3DOp filter should be 5-D."); + PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + output_channels % groups, 0, + "The number of output channels should be divided by groups."); + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < paddings.size(); ++i) { + output_shape.push_back(OutputSizeConv3d(in_dims[i + 2], filter_dims[i], + paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); +} + +void Conv3DOpGrad::InferShape(framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} + +Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "The input tensor of convolution operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is the " + "number of channels, D, H and W is the depth, height and width of " + "image."); + AddInput("Filter", + "The filter tensor of convolution operator." + "The format of the filter tensor is MCDHW, where M is the number of " + "output image channels, C is the number of input image channels, " + "D, H and W is depth, height and width of filter. " + "If the groups attribute is greater than 1, C equal the number of " + "input image channels divided by the groups."); + AddOutput("Output", + "The output tensor of convolution operator." + "The format of output tensor is also NCDHW."); + AddAttr>("strides", "strides of convolution operator.") + .SetDefault({1, 1, 1}); + AddAttr>("paddings", "paddings of convolution operator.") + .SetDefault({0, 0, 0}); + AddAttr( + "groups", + "group size of convolution operator. " + "Refer to grouped convolution in Alex Krizhevsky's paper: " + "when group=2, the first half of the filters are only connected to the " + "first half of the input channels, and the second half only connected " + "to the second half.") + .SetDefault(1); + AddComment(R"DOC( +The convolution operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +)DOC"); +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv3d, ops::Conv3DOp, ops::Conv3DOpMaker, conv3d_grad, + ops::Conv3DOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv3d, ops::GemmConv3DKernel); +REGISTER_OP_CPU_KERNEL( + conv3d_grad, ops::GemmConvGrad3DKernel); diff --git a/paddle/operators/conv3d_op.cu b/paddle/operators/conv3d_op.cu new file mode 100644 index 0000000000..ec6121d5d5 --- /dev/null +++ b/paddle/operators/conv3d_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/conv3d_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + conv3d, ops::GemmConv3DKernel); +REGISTER_OP_GPU_KERNEL( + conv3d_grad, ops::GemmConvGrad3DKernel); diff --git a/paddle/operators/conv3d_op.h b/paddle/operators/conv3d_op.h new file mode 100644 index 0000000000..a22cb34f67 --- /dev/null +++ b/paddle/operators/conv3d_op.h @@ -0,0 +1,259 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/vol2col.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class Conv3DOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class Conv3DOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Conv3DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + +template +class GemmConv3DKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + Tensor filter = *context.Input("Filter"); + Tensor* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + int groups = context.Attr("groups"); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_depth = filter.dims()[filter.dims().size() - 3]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output->dims()[1]; + int output_depth = output->dims()[2]; + int output_height = output->dims()[3]; + int output_width = output->dims()[4]; + + paddle::operators::math::Vol2ColFunctor vol2col; + // use col_shape in the vol2col calculation + framework::DDim col_shape = {input_channels / groups, + filter_depth, + filter_height, + filter_width, + output_depth, + output_height, + output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels / groups * filter_depth * filter_height * filter_width, + output_depth * output_height * output_width}; + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); + + framework::DDim input_shape = {input->dims()[1], input->dims()[2], + input->dims()[3], input->dims()[4]}; + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + + framework::DDim output_matrix_shape = { + output_channels, output_depth * output_height * output_width}; + + // convolution operator: vol2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + // vol2col + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + vol2col(context.device_context(), in_slice, col, strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); + } + } + } +}; + +template +class GemmConvGrad3DKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + + // The filter and filter_grad will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + Tensor filter = *context.Input("Filter"); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + int groups = context.Attr("groups"); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_depth = filter.dims()[filter.dims().size() - 3]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output_grad->dims()[1]; + int output_depth = output_grad->dims()[2]; + int output_height = output_grad->dims()[3]; + int output_width = output_grad->dims()[4]; + + paddle::operators::math::Col2VolFunctor col2vol; + paddle::operators::math::Vol2ColFunctor vol2col; + // use col_shape in the vol2col and col2vol calculation + framework::DDim col_shape = {input_channels / groups, + filter_depth, + filter_height, + filter_width, + output_depth, + output_height, + output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels / groups * filter_depth * filter_height * filter_width, + output_depth * output_height * output_width}; + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); + + framework::DDim input_shape = {input->dims()[1], input->dims()[2], + input->dims()[3], input->dims()[4]}; + framework::DDim output_matrix_shape = {output_grad->dims()[1], + output_grad->dims()[2] * + output_grad->dims()[3] * + output_grad->dims()[4]}; + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + + // convolution backward input operator: gemm + col2vol + // convolution backward weight operator: vol2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; + + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = + filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); + + // col2vol + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + col2vol(context.device_context(), in_grad_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + } + } + + if (filter_grad) { + filter_grad->mutable_data(context.GetPlace()); + Tensor filter_grad_ = *filter_grad; + filter_grad_.Resize(filter_matrix_shape); + auto t = framework::EigenVector::Flatten(filter_grad_); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // vol2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); + } + } + } + } +}; + +} // namespace operators +} // namespace paddle -- GitLab