diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md index c6570b89aedfaac1aef9b00e889b0b3ed21d8d65..264b998f50df016da0741d97d4b26f759ee90900 100644 --- a/doc/howto/dev/new_op_cn.md +++ b/doc/howto/dev/new_op_cn.md @@ -54,9 +54,9 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { public: MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of mul op"); - AddInput("Y", "The second input of mul op"); - AddOutput("Out", "The output of mul op"); + AddInput("X", "(Tensor), 2D tensor of size (M x K)"); + AddInput("Y", "(Tensor), 2D tensor of size (K x N)"); + AddOutput("Out", "(Tensor), 2D tensor of size (M x N)"); AddComment(R"DOC( Two Element Mul Operator. The equation is: Out = X * Y @@ -72,7 +72,7 @@ The equation is: Out = X * Y 构造函数里通过`AddInput`添加输入参数,通过`AddOutput`添加输出参数,通过`AddComment`添加Op的注释。这些函数会将对应内容添加到`OpProto`中。 -上面的代码在`MulOp`中添加两个输入`X`和`Y`,添加了一个输出`Out`,并解释了各自含义,命名请遵守命名规范。 +上面的代码在`MulOp`中添加两个输入`X`和`Y`,添加了一个输出`Out`,并解释了各自含义,命名请遵守[命名规范](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md)。 再以[`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37)为例: diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index ed166935f76be9d25062b5e69536c7b7ac19045d..6d2c14f4c47afb755b1c74f6dc4dd10ab25ed191 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -130,15 +130,19 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE_LT(begin_idx, end_idx, "Begin index must be less than end index."); - PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1."); - size_t base = numel() / dims_[0]; - Tensor dst; - dst.holder_ = holder_; - DDim dst_dims = dims_; - dst_dims[0] = end_idx - begin_idx; - dst.Resize(dst_dims); - dst.offset_ = offset_ + begin_idx * base * sizeof(T); - return dst; + + if (dims_[0] == 1) { + return *this; + } else { + size_t base = numel() / dims_[0]; + Tensor dst; + dst.holder_ = holder_; + DDim dst_dims = dims_; + dst_dims[0] = end_idx - begin_idx; + dst.Resize(dst_dims); + dst.offset_ = offset_ + begin_idx * base * sizeof(T); + return dst; + } } inline Tensor& Tensor::Resize(const DDim& dims) { diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..12db65b5cbf224e95d91c7b4839afa552c084ee7 --- /dev/null +++ b/paddle/operators/conv2d_op.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/gemm_conv2d_op.h" + +namespace paddle { +namespace operators { + +int outputSize(int input_size, int filter_size, int padding, int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +class Conv2DOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), + "Input(Input) of Conv2DOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"), + "Input(Filter) of Conv2DOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), + "Output(Output) of Conv2DOp should not be null."); + + auto in = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto out = ctx.Output("Output"); + std::vector strides = Attr>("strides"); + std::vector paddings = Attr>("paddings"); + int groups = Attr("groups"); + int input_channels = in->dims()[1]; + int output_channels = filter->dims()[0]; + + PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D."); + PADDLE_ENFORCE_EQ(filter->dims().size(), 4, + "Conv2DOp filter should be 4-D."); + PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + output_channels % groups, 0, + "The number of output channels should be divided by groups."); + + auto output_height = + outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]); + auto output_width = + outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]); + out->Resize( + {in->dims()[0], filter->dims()[0], output_height, output_width}); + } +}; + +class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "The input tensor of convolution operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddInput( + "Filter", + "The filter tensor of convolution operator." + "The format of the filter tensor is MCHW, where M is the number of " + "output image channels, C is the number of input image channels, " + "H and W is height and width of filter. " + "If the groups attribute is greater than 1, C equal the number of " + "input image channels divided by the groups."); + AddOutput("Output", + "The output tensor of convolution operator." + "The format of output tensor is also NCHW."); + AddAttr>("strides", "strides of convolution operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", "paddings of convolution operator.") + .SetDefault({0, 0}); + AddAttr( + "groups", + "group size of convolution operator. " + "Refer to grouped convolution in Alex Krizhevsky's paper: " + "when group=2, the first half of the filters are only connected to the " + "first half of the input channels, and the second half only connected " + "to the second half.") + .SetDefault(1); + AddComment(R"DOC( +The convolution operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +)DOC"); + } +}; + +class Conv2DOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto in = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto d_in = + ctx.Output(framework::GradVarName("Input")); + auto d_filter = + ctx.Output(framework::GradVarName("Filter")); + if (d_in) d_in->Resize(in->dims()); + if (d_filter) d_filter->Resize(filter->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, + ops::Conv2DOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2d, ops::GemmConv2DKernel); +REGISTER_OP_CPU_KERNEL( + conv2d_grad, ops::GemmConvGrad2DKernel); diff --git a/paddle/operators/conv2d_op.cu b/paddle/operators/conv2d_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5df818ba0496a65502dde37fd1397ec56f8c1101 --- /dev/null +++ b/paddle/operators/conv2d_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/gemm_conv2d_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + conv2d, ops::GemmConv2DKernel); +REGISTER_OP_GPU_KERNEL( + conv2d_grad, ops::GemmConvGrad2DKernel); diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h new file mode 100644 index 0000000000000000000000000000000000000000..08b7df1dfead72fe8de8e89fa633c7bfc7bdbf33 --- /dev/null +++ b/paddle/operators/gemm_conv2d_op.h @@ -0,0 +1,231 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/im2col.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GemmConv2DKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + Tensor filter = *context.Input("Filter"); + Tensor* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + int groups = context.Attr("groups"); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output->dims()[1]; + int output_height = output->dims()[2]; + int output_width = output->dims()[3]; + + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + im2col; + // use col_shape in the im2col calculation + framework::DDim col_shape = {input_channels / groups, filter_height, + filter_width, output_height, output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels / groups * filter_height * filter_width, + output_height * output_width}; + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); + + framework::DDim input_shape = {input->dims()[1], input->dims()[2], + input->dims()[3]}; + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + + framework::DDim output_matrix_shape = {output_channels, + output_height * output_width}; + + auto* device_context = + const_cast(context.device_context_); + + // convolution operator: im2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], + device_context); + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(filter_slice, false, col_matrix, false, T(1.0), + &out_slice, T(0.0), device_context); + } + } + } +}; + +template +class GemmConvGrad2DKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + + // The filter and filter_grad will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + Tensor filter = *context.Input("Filter"); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + int groups = context.Attr("groups"); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output_grad->dims()[1]; + int output_height = output_grad->dims()[2]; + int output_width = output_grad->dims()[3]; + + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + col2im; + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + im2col; + // use col_shape in the im2col and col2im calculation + framework::DDim col_shape = {input_channels / groups, filter_height, + filter_width, output_height, output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels / groups * filter_height * filter_width, + output_height * output_width}; + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); + + framework::DDim input_shape = {input->dims()[1], input->dims()[2], + input->dims()[3]}; + framework::DDim output_matrix_shape = { + output_grad->dims()[1], + output_grad->dims()[2] * output_grad->dims()[3]}; + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + + auto* device_context = + const_cast(context.device_context_); + + // convolution backward input operator: gemm + col2im + // convolution backward weight operator: im2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; + + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = + filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(filter_slice, true, out_grad_slice, false, + T(1.0), &col_matrix, T(0.0), device_context); + + // col2im + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], + paddings[1], device_context); + } + } + } + + if (filter_grad) { + filter_grad->mutable_data(context.GetPlace()); + Tensor filter_grad_ = *filter_grad; + filter_grad_.Resize(filter_matrix_shape); + auto t = framework::EigenVector::Flatten(filter_grad_); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + im2col(in_slice, col, strides[0], strides[1], paddings[0], + paddings[1], device_context); + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(out_grad_slice, false, col_matrix, true, + T(1.0), &filter_grad_slice, T(1.0), + device_context); + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3142a60a1ae7d1874d02b81a4bb90c1fc50d07b9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -0,0 +1,94 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestConv2dOp(OpTest): + def setUp(self): + self.init_groups() + self.op_type = "conv2d" + batch_size = 2 + input_channels = 3 + input_height = 5 + input_width = 5 + output_channels = 6 + filter_height = 3 + filter_width = 3 + stride = 1 + padding = 0 + output_height = (input_height - filter_height + 2 * padding + ) / stride + 1 + output_width = (input_width - filter_width + 2 * padding) / stride + 1 + input = np.random.random((batch_size, input_channels, input_height, + input_width)).astype("float32") + + filter = np.random.random( + (output_channels, input_channels / self.groups, filter_height, + filter_width)).astype("float32") + output = np.ndarray( + (batch_size, output_channels, output_height, output_width)) + + self.inputs = {'Input': input, 'Filter': filter} + self.attrs = { + 'strides': [1, 1], + 'paddings': [0, 0], + 'groups': self.groups + } + + output_group_channels = output_channels / self.groups + input_group_channels = input_channels / self.groups + for batchid in xrange(batch_size): + for group in xrange(self.groups): + for outchannelid in range(group * output_group_channels, + (group + 1) * output_group_channels): + for rowid in xrange(output_height): + for colid in xrange(output_width): + start_h = (rowid * stride) - padding + start_w = (colid * stride) - padding + output_value = 0.0 + for inchannelid in range( + group * input_group_channels, + (group + 1) * input_group_channels): + for frowid in xrange(filter_height): + for fcolid in xrange(filter_width): + input_value = 0.0 + inrowid = start_h + frowid + incolid = start_w + fcolid + if ((inrowid >= 0 and + inrowid < input_height) and + (incolid >= 0 and + incolid < input_width)): + input_value = input[batchid][ + inchannelid][inrowid][incolid] + filter_value = filter[outchannelid][ + inchannelid % input_group_channels][ + frowid][fcolid] + output_value += input_value * filter_value + output[batchid][outchannelid][rowid][ + colid] = output_value + + self.outputs = {'Output': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(set(['Input', 'Filter']), 'Output') + + def test_check_grad_no_filter(self): + self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input'])) + + def init_groups(self): + self.groups = 1 + + +class TestWithGroup(TestConv2dOp): + def init_groups(self): + self.groups = 3 + + +if __name__ == '__main__': + unittest.main()