From c9d8cb4e90597409257da63c3d788ad067382772 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Mon, 11 Sep 2017 21:25:30 +0800
Subject: [PATCH] Convolution op and forward calculation.

---
 paddle/operators/conv_op.cc                        |  96 ++++++++++++++++
 paddle/operators/conv_op.cu                        |  22 ++++
 paddle/operators/gemm_conv_op.h                    | 103 ++++++++++++++++++
 paddle/pybind/pybind.cc                            |   1 +
 .../paddle/v2/framework/tests/CMakeLists.txt       |   1 +
 .../v2/framework/tests/test_conv2d_op.py           |  62 +++++++++++
 6 files changed, 285 insertions(+)
 create mode 100644 paddle/operators/conv_op.cc
 create mode 100644 paddle/operators/conv_op.cu
 create mode 100644 paddle/operators/gemm_conv_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_conv2d_op.py

diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
new file mode 100644
index 00000000000..873366394da
--- /dev/null
+++ b/paddle/operators/conv_op.cc
@@ -0,0 +1,96 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/gemm_conv_op.h"
+
+namespace paddle {
+namespace operators {
+
+int outputSize(int input_size, int filter_size, int padding, int stride) {
+  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+  return output_size;
+}
+
+class Conv2DOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto *in = ctx.Input<Tensor>("Input");
+    auto *filter = ctx.Input<Tensor>("Filter");
+    auto *out = ctx.Output<Tensor>("Output");
+    PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
+    PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
+                      "Conv2DOp filter should be 4-D.");
+
+    std::vector<int> strides = Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = Attr<std::vector<int>>("paddings");
+    auto output_height =
+        outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
+    auto output_width =
+        outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
+    out->Resize(
+        {in->dims()[0], filter->dims()[0], output_height, output_width});
+  }
+};
+
+class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Conv2DOpMaker(framework::OpProto *proto,
+                framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "Input",
+        "The input tensor of convolution operator. "
+        "The format of the input tensor is NCHW, where N is the batch size, "
+        "C is the number of channels, and H and W are the height and width "
+        "of the image.");
+    AddInput(
+        "Filter",
+        "The filter tensor of convolution operator. "
+        "The format of the filter tensor is MCHW, where M is the number of "
+        "output image channels, C is the number of input image channels, "
+        "and H and W are the height and width of the filter.");
+    AddOutput("Output",
+              "The output tensor of convolution operator. "
+ "The format of output tensor is also NCHW."); + AddComment(R"DOC( +The convolution operation calculates the output based on +the input, filter and strides, paddings parameters. +)DOC"); + AddAttr>("strides", "strides of convolution operator."); + AddAttr>("paddings", "paddings of convolution operator."); + } +}; + +class Conv2DOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOppMaker, conv2d_grad, + ops::Conv2DOpGrad); + +REGISTER_OP_CPU_KERNEL(conv2d, + ops::GemmConvKernel); +REGISTER_OP_CPU_KERNEL( + conv2d_grad, ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_op.cu b/paddle/operators/conv_op.cu new file mode 100644 index 00000000000..a15adecda46 --- /dev/null +++ b/paddle/operators/conv_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/gemm_conv_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(conv2d, + ops::GemmConvKernel); +REGISTER_OP_GPU_KERNEL( + conv2d_grad, ops::GemmConvGradKernel); diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h new file mode 100644 index 00000000000..16ea5ff74c5 --- /dev/null +++ b/paddle/operators/gemm_conv_op.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#pragma once
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/im2col.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class GemmConvKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    // The filter is resized to a matrix below, so cast away its constness.
+    Tensor* filter = const_cast<Tensor*>(context.Input<Tensor>("Filter"));
+    Tensor* output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+    paddle::framework::Tensor col;
+    paddle::framework::Tensor in_slice;
+    paddle::framework::Tensor out_slice;
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+
+    int batch_size = input->dims()[0];
+    int input_channels = input->dims()[1];
+    int filter_height = filter->dims()[filter->dims().size() - 2];
+    int filter_width = filter->dims()[filter->dims().size() - 1];
+    int output_height = output->dims()[2];
+    int output_width = output->dims()[3];
+
+    paddle::operators::math::Im2ColFunctor<
+        paddle::operators::math::ColFormat::kCFO, Place, T>
+        im2col;
+    framework::DDim col_shape = {input_channels, filter_height, filter_width,
+                                 output_height, output_width};
+    col.mutable_data<T>(col_shape, context.GetPlace());
+
+    auto* device_context =
+        const_cast<platform::DeviceContext*>(context.device_context_);
+
+    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
+                                   input->dims()[3]};
+    framework::DDim filter_matrix_shape = {
+        filter->dims()[0],
+        filter->dims()[1] * filter->dims()[2] * filter->dims()[3]};
+    framework::DDim col_matrix_shape = {
+        input_channels * filter_height * filter_width,
+        output_height * output_width};
+    framework::DDim output_matrix_shape = {
+        output->dims()[1], output->dims()[2] * output->dims()[3]};
+    filter->Resize(filter_matrix_shape);
+
+    // convolution operator: im2col + gemm
+    for (int i = 0; i < batch_size; i++) {
+      // im2col
+      in_slice = input->Slice(i, i + 1);
+      in_slice.Resize(input_shape);
+      col.Resize(col_shape);
+      im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
+             device_context);
+
+      // gemm
+      out_slice = output->Slice(i, i + 1);
+      out_slice.Resize(output_matrix_shape);
+      col.Resize(col_matrix_shape);
+      math::matmul<Place, T>(*filter, false, col, false, T(1.0), &out_slice,
+                             T(0.0), device_context);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class GemmConvGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+// Gradient computation is not implemented yet; this block is disabled.
+#if 0
+    auto input = context.Input<Tensor>("Input");
+    auto filter = context.Input<Tensor>("Filter");
+    auto output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 53985933ed1..ef72c86cbdb 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -51,6 +51,7 @@ USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
 USE_OP(top_k);
 USE_OP(squared_l2_distance);
+USE_OP(conv2d);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index ef910f939be..11290e042df 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -35,3 +35,4 @@ py_test(test_lookup_table SRCS test_lookup_table.py)
 py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
 py_test(mnist SRCS mnist.py)
 py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
+py_test(test_conv2d SRCS test_conv2d_op.py)
diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py
new file mode 100644
index 00000000000..d2015d0ce51
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_conv2d_op.py
@@ -0,0 +1,62 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+
+
+class TestConv2dOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "conv2d"
+        batch_size = 2
+        input_channels = 3
+        input_height = 5
+        input_width = 5
+        output_channels = 6
+        filter_height = 3
+        filter_width = 3
+        stride = 1
+        padding = 0
+        output_height = (input_height - filter_height + 2 * padding
+                         ) / stride + 1
+        output_width = (input_width - filter_width + 2 * padding) / stride + 1
+        input = np.random.random((batch_size, input_channels, input_height,
+                                  input_width)).astype("float32")
+        filter = np.random.random(
+            (output_channels, input_channels, filter_height,
+             filter_width)).astype("float32")
+        output = np.ndarray(
+            (batch_size, output_channels, output_height, output_width))
+
+        for batchid in xrange(batch_size):
+            for channelid in xrange(output_channels):
+                for rowid in xrange(output_height):
+                    for colid in xrange(output_width):
+                        start_h = (rowid * stride) - padding
+                        start_w = (colid * stride) - padding
+                        output_value = 0.0
+                        for inchannelid in xrange(input_channels):
+                            for frowid in xrange(filter_height):
+                                for fcolid in xrange(filter_width):
+                                    input_value = 0.0
+                                    inrowid = start_h + frowid
+                                    incolid = start_w + fcolid
+                                    if ((inrowid >= 0 and
+                                         inrowid < input_height) and
+                                        (incolid >= 0 and
+                                         incolid < input_width)):
+                                        input_value = input[batchid][
+                                            inchannelid][inrowid][incolid]
+                                    filter_value = filter[channelid][
+                                        inchannelid][frowid][fcolid]
+                                    output_value += input_value * filter_value
+                        output[batchid][channelid][rowid][colid] = output_value
+
+        self.inputs = {'Input': input, 'Filter': filter}
+        self.outputs = {'Output': output}
+        self.attrs = {'strides': [1, 1], 'paddings': [0, 0]}
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
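
For reference, the forward pass above can be reproduced outside the framework. Below is a minimal NumPy sketch of the same im2col + GEMM scheme that GemmConvKernel implements: each (C, H, W) image is unfolded into a (C*fh*fw, out_h*out_w) column matrix, the (M, C, fh, fw) filter is flattened to (M, C*fh*fw), and one matrix product per image yields the (M, out_h, out_w) output. The helper names here (im2col_nchw, conv2d_forward) are illustrative, not part of the patch or the PaddlePaddle API.

import numpy as np


def im2col_nchw(img, filter_h, filter_w, stride, padding):
    """Unfold one (C, H, W) image into a (C*filter_h*filter_w, out_h*out_w)
    matrix, mirroring the kCFO column layout used by Im2ColFunctor."""
    c, h, w = img.shape
    out_h = (h - filter_h + 2 * padding) // stride + 1
    out_w = (w - filter_w + 2 * padding) // stride + 1
    padded = np.pad(img, ((0, 0), (padding, padding), (padding, padding)),
                    mode='constant')
    col = np.empty((c * filter_h * filter_w, out_h * out_w), dtype=img.dtype)
    row = 0
    for ci in range(c):
        for fi in range(filter_h):
            for fj in range(filter_w):
                # All input values that filter cell (ci, fi, fj) touches,
                # one per output position.
                patch = padded[ci,
                               fi:fi + stride * out_h:stride,
                               fj:fj + stride * out_w:stride]
                col[row] = patch.reshape(-1)
                row += 1
    return col


def conv2d_forward(x, filt, stride=1, padding=0):
    """x: (N, C, H, W), filt: (M, C, fh, fw) -> (N, M, out_h, out_w)."""
    n = x.shape[0]
    m, _, fh, fw = filt.shape
    out_h = (x.shape[2] - fh + 2 * padding) // stride + 1
    out_w = (x.shape[3] - fw + 2 * padding) // stride + 1
    # Flatten the filter to (M, C*fh*fw), as filter->Resize(filter_matrix_shape)
    # does in the kernel, then run one GEMM per image in the batch.
    w_mat = filt.reshape(m, -1)
    out = np.empty((n, m, out_h, out_w), dtype=x.dtype)
    for i in range(n):
        col = im2col_nchw(x[i], fh, fw, stride, padding)
        out[i] = np.dot(w_mat, col).reshape(m, out_h, out_w)
    return out


if __name__ == '__main__':
    # Same shapes as TestConv2dOp: (5 - 3 + 0) / 1 + 1 = 3 per spatial dim.
    x = np.random.random((2, 3, 5, 5)).astype('float32')
    f = np.random.random((6, 3, 3, 3)).astype('float32')
    print(conv2d_forward(x, f).shape)  # (2, 6, 3, 3)

With the sizes used by TestConv2dOp (input 2x3x5x5, filter 6x3x3x3, stride 1, padding 0), outputSize gives (5 - 3 + 0) / 1 + 1 = 3, so this sketch, the naive loop in the test, and the operator all produce a 2x6x3x3 output.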