From 40fe0a8c47cb9613f3e2db462ca74886754f41fe Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 12 Sep 2017 18:08:32 +0800 Subject: [PATCH] Add backward of convolution. --- paddle/operators/conv_op.cc | 24 ++-- paddle/operators/gemm_conv_op.h | 105 ++++++++++++++++-- .../v2/framework/tests/test_conv2d_op.py | 38 +++++++ 3 files changed, 146 insertions(+), 21 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 873366394..107682848 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -28,9 +28,9 @@ class Conv2DOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto *in = ctx.Input("Input"); - auto *filter = ctx.Input("Filter"); - auto *out = ctx.Output("Output"); + auto in = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto out = ctx.Output("Output"); PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D."); PADDLE_ENFORCE_EQ(filter->dims().size(), 4, "Conv2DOp filter should be 4-D."); @@ -46,10 +46,9 @@ class Conv2DOp : public framework::OperatorWithKernel { } }; -class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker { +class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { public: - Conv2DOppMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", @@ -62,7 +61,7 @@ class Conv2DOppMaker : public framework::OpProtoAndCheckerMaker { "The format of the filter tensor is MCHW, where M is the number of " "output " "image channels, C is the number of input image channels, H and W is " - " height and width of filter."); + "height and width of filter."); AddOutput("Output", "The output tensor of convolution operator." "The format of output tensor is also NCHW."); @@ -80,14 +79,21 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(const framework::InferShapeContext &ctx) const override { + auto in = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto d_in = ctx.Output(framework::GradVarName("Input")); + auto d_filter = ctx.Output(framework::GradVarName("Filter")); + d_in->Resize(in->dims()); + d_filter->Resize(filter->dims()); + } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOppMaker, conv2d_grad, +REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, ops::Conv2DOpGrad); REGISTER_OP_CPU_KERNEL(conv2d, diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h index 16ea5ff74..6c7236219 100644 --- a/paddle/operators/gemm_conv_op.h +++ b/paddle/operators/gemm_conv_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/im2col.h" #include "paddle/operators/math/math_function.h" @@ -31,12 +32,10 @@ class GemmConvKernel : public framework::OpKernel { Tensor* filter = const_cast(context.Input("Filter")); Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); - paddle::framework::Tensor col; - paddle::framework::Tensor in_slice; - paddle::framework::Tensor out_slice; std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + auto filter_dims = filter->dims(); int batch_size = input->dims()[0]; int input_channels = input->dims()[1]; @@ -50,6 +49,7 @@ class GemmConvKernel : public framework::OpKernel { im2col; framework::DDim col_shape = {input_channels, filter_height, filter_width, output_height, output_width}; + Tensor col; col.mutable_data(col_shape, context.GetPlace()); auto* device_context = @@ -67,22 +67,23 @@ class GemmConvKernel : public framework::OpKernel { output->dims()[1], output->dims()[2] * output->dims()[3]}; filter->Resize(filter_matrix_shape); - // convolution opperator: im2col + gemm + // convolution operator: im2col + gemm for (int i = 0; i < batch_size; i++) { // im2col - in_slice = input->Slice(i, i + 1); + Tensor in_slice = input->Slice(i, i + 1); in_slice.Resize(input_shape); col.Resize(col_shape); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], device_context); // gemm - out_slice = output->Slice(i, i + 1); + Tensor out_slice = output->Slice(i, i + 1); out_slice.Resize(output_matrix_shape); col.Resize(col_matrix_shape); math::matmul(*filter, false, col, false, T(1.0), &out_slice, T(0.0), device_context); } + filter->Resize(filter_dims); } }; @@ -90,12 +91,92 @@ template class GemmConvGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { -#if 0 - auto input = context.Input("Input"); - auto filter = context.Input("Filter"); - auto output = context.Output("Output"); - output->mutable_data(context.GetPlace()); -#endif + const Tensor* input = context.Input("Input"); + Tensor* filter = const_cast(context.Input("Filter")); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + input_grad->mutable_data(context.GetPlace()); + filter_grad->mutable_data(context.GetPlace()); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + auto filter_dims = filter->dims(); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_height = filter->dims()[filter->dims().size() - 2]; + int filter_width = filter->dims()[filter->dims().size() - 1]; + int output_height = output_grad->dims()[2]; + int output_width = output_grad->dims()[3]; + + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + col2im; + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + im2col; + Tensor col; + framework::DDim col_shape = {input_channels, filter_height, filter_width, + output_height, output_width}; + col.mutable_data(col_shape, context.GetPlace()); + + auto* device_context = + const_cast(context.device_context_); + + framework::DDim input_shape = {input->dims()[1], input->dims()[2], + input->dims()[3]}; + framework::DDim filter_matrix_shape = { + filter->dims()[0], + filter->dims()[1] * filter->dims()[2] * filter->dims()[3]}; + framework::DDim col_matrix_shape = { + input_channels * filter_height * filter_width, + output_height * output_width}; + framework::DDim output_matrix_shape = { + output_grad->dims()[1], + output_grad->dims()[2] * output_grad->dims()[3]}; + filter->Resize(filter_matrix_shape); + filter_grad->Resize(filter_matrix_shape); + + auto t1 = framework::EigenVector::Flatten(*filter_grad); + t1.device(context.GetEigenDevice()) = t1.constant(static_cast(0)); + auto t2 = framework::EigenVector::Flatten(*input_grad); + t2.device(context.GetEigenDevice()) = t2.constant(static_cast(0)); + + // convolution backward input operator: gemm + col2im + // convolution backward weight operator: im2col + gemm + for (int i = 0; i < batch_size; i++) { + // gemm + Tensor out_slice = output_grad->Slice(i, i + 1); + out_slice.Resize(output_matrix_shape); + col.Resize(col_matrix_shape); + math::matmul(*filter, true, out_slice, false, T(1.0), &col, + T(0.0), device_context); + + // col2im + Tensor in_grad_slice = input_grad->Slice(i, i + 1); + in_grad_slice.Resize(input_shape); + col.Resize(col_shape); + col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], + paddings[1], device_context); + + // im2col + Tensor in_slice = input->Slice(i, i + 1); + in_slice.Resize(input_shape); + col.Resize(col_shape); + im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], + device_context); + + // gemm + col.Resize(col_matrix_shape); + math::matmul(out_slice, false, col, true, T(1.0), filter_grad, + T(1.0), device_context); + } + filter->Resize(filter_dims); + filter_grad->Resize(filter_dims); } }; diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index d2015d0ce..43f328ca0 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -2,6 +2,7 @@ import unittest import numpy as np from gradient_checker import GradientChecker, create_op from op_test_util import OpTestMeta +from paddle.v2.framework.op import Operator class TestConv2dOp(unittest.TestCase): @@ -58,5 +59,42 @@ class TestConv2dOp(unittest.TestCase): self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} +class TestConv2dGradOp(GradientChecker): + def setUp(self): + batch_size = 2 + input_channels = 3 + input_height = 5 + input_width = 5 + output_channels = 6 + filter_height = 3 + filter_width = 3 + stride = 1 + padding = 0 + output_height = (input_height - filter_height + 2 * padding + ) / stride + 1 + output_width = (input_width - filter_width + 2 * padding) / stride + 1 + input = np.random.random((batch_size, input_channels, input_height, + input_width)).astype("float32") + filter = np.random.random( + (output_channels, input_channels, filter_height, + filter_width)).astype("float32") + + self.inputs = {'Input': input, 'Filter': filter} + self.op = Operator( + "conv2d", + Input='Input', + Filter='Filter', + Output='Output', + strides=[1, 1], + paddings=[0, 0]) + + def test_compare_grad(self): + self.compare_grad(self.op, self.inputs) + + def test_check_grad(self): + self.check_grad(self.op, self.inputs, + set(['Input', 'Filter']), 'Output') + + if __name__ == '__main__': unittest.main() -- GitLab