From b76343c3b703918543604f5936dfac390a5bb8a6 Mon Sep 17 00:00:00 2001
From: lvmengsi
Date: Tue, 17 Sep 2019 11:05:39 +0800
Subject: [PATCH] cpu Conv double grad (#19672)

* cpu conv_grad_grad
---
 paddle/fluid/operators/conv_cudnn_op.cu.cc    |   5 +
 paddle/fluid/operators/conv_op.cc             |  55 ++++-
 paddle/fluid/operators/conv_op.h              | 213 ++++++++++++++++++
 .../tests/unittests/test_conv_nn_grad.py      | 129 +++++++++++
 .../fluid/tests/unittests/test_nn_grad.py     |  24 --
 5 files changed, 399 insertions(+), 27 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_conv_nn_grad.py

diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index ecedb7d70ff..1c20cf9cc20 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -510,3 +510,8 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
 REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
                    paddle::operators::CUDNNConvGradOpKernel<double>);
+REGISTER_OP_KERNEL(
+    conv3d_grad_grad, CUDNN, plat::CUDAPlace,
+    paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
+    paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
+    paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>);
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index cdecd816524..1cfdf7da86a 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -565,6 +565,40 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
+/*
+ * Inputs:  I, W, dO, ddI, ddW
+ * Outputs: ddO, dW, dI
+ */
+class Conv3DDoubleGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType(this->ForwardOpType() + "_grad");
+    // I, W, dO, ddI, ddW
+    op->SetInput("Input", Input("Input"));
+    op->SetInput("Filter", Input("Filter"));
+    op->SetInput("DOutput", Input(framework::GradVarName("Output")));
+    op->SetInput("DDInput", OutputGrad(framework::GradVarName("Input")));
+    op->SetInput("DDFilter", OutputGrad(framework::GradVarName("Filter")));
+
+    auto ddx = OutputGrad(framework::GradVarName("Input"));
+    auto ddw = OutputGrad(framework::GradVarName("Filter"));
+    std::vector<std::string> empty_str = {};
+
+    op->SetOutput(
+        "DDOutput",
+        ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
+    op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
+    op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
+
+    op->SetAttrMap(Attrs());
+
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
 void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
   auto x_dims = ctx->GetInputDim("Input");
   auto w_dims = ctx->GetInputDim("Filter");
@@ -592,8 +626,14 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
-  } else {
-    PADDLE_THROW("Now ConvDoubleGrad only supports cuDNN.");
+  }
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
+    layout_ = framework::DataLayout::kMKLDNN;
+    customized_type_value = kConvMKLDNNFP32;
   }
 #endif
   auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
@@ -637,7 +677,8 @@ REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
 
 REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
                   ops::ConvOpInferVarType, ops::Conv3DGradMaker);
-REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
+REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad, ops::Conv3DDoubleGradMaker);
+REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);
 
 // depthwise conv kernel
 // TODO(xingzhaolong): neon kernel for mobile
@@ -658,6 +699,10 @@ REGISTER_OP_CPU_KERNEL(
     conv2d_grad,
     ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    conv2d_grad_grad,
+    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     conv3d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
@@ -666,3 +711,7 @@ REGISTER_OP_CPU_KERNEL(
     conv3d_grad,
     ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    conv3d_grad_grad,
+    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index 4df47ef261e..aa621529b52 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/depthwise_conv.h"
 #include "paddle/fluid/operators/math/im2col.h"
@@ -393,6 +394,218 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
+                      "It must use CPUPlace.");
+    const Tensor* X = ctx.Input<Tensor>("Input");
+    const Tensor* dY = ctx.Input<Tensor>("DOutput");
+    const Tensor* ddX = ctx.Input<Tensor>("DDInput");
+    const Tensor* ddW_in = ctx.Input<Tensor>("DDFilter");
+
+    Tensor* ddY = ctx.Output<Tensor>("DDOutput");
+    Tensor* dW = ctx.Output<Tensor>("DFilter");
+    Tensor* dX = ctx.Output<Tensor>("DInput");
+    Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"),
+                           "Cannot find input Filter(%s) in scope)",
+                           ctx.Inputs("Filter")[0]);
+
+    if (!ddY && !dW && !dX) return;
+    int groups = ctx.Attr<int>("groups");
+    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+
+    const int batch_size = static_cast<int>(X->dims()[0]);
+    std::vector<int64_t> filter_shape_vec(framework::vectorize(W.dims()));
+    std::vector<int64_t> output_shape_vec(framework::vectorize(dY->dims()));
+
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    // col_shape [in_channel/group, kh, kw, oh, ow]
+    col_shape_vec[0] = X->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2];
+    }
+    framework::DDim col_shape(framework::make_ddim(col_shape_vec));
+    // col_matrix_shape [in_channel/group * kh * kw, oh * ow]
+    framework::DDim col_matrix_shape =
+        framework::flatten_to_2d(col_shape, data_dim + 1);
+    // input_shape [Cin, H, W]
+    framework::DDim input_shape =
+        framework::slice_ddim(X->dims(), 1, X->dims().size());
+    // filter_matrix_shape [Cout, Cin * kh * kw]
+    framework::DDim filter_matrix_shape = {W.dims()[0],
+                                           W.numel() / W.dims()[0]};
+
+    W.Resize(filter_matrix_shape);
+    framework::DDim output_matrix_shape = {
+        dY->dims()[1], dY->numel() / (dY->dims()[0] * dY->dims()[1])};
+    int in_step = static_cast<int>(X->dims()[1]) / groups;
+    int out_step = static_cast<int>(dY->dims()[1]) / groups;
+
+    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
+    Tensor col;
+    Tensor col_matrix;
+    if (is_expand) {
+      col = ctx.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
+      col_matrix.ShareDataWith(col);
+      col_matrix.Resize(col_matrix_shape);
+    }
+
+    math::SetConstant<DeviceContext, T> set_zero;
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+    // dx convolution double grad:  gemm + col2im(col2vol)
+    // dx = ddw * dy  ==> dx(N, Cin, H, W), ddw(Cout, Cin, kh, kw), dy(N, Cout,
+    // oH, oW)
+    if (dX && ddW_in) {
+      Tensor ddW;
+      ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
+
+      dX->mutable_data<T>(ctx.GetPlace());
+      // if is_expand is false, the operation of set_zero is unnecessary
+      // because math::matmul will reset dx
+      if (is_expand) {
+        set_zero(dev_ctx, dX, static_cast<T>(0));
+      }
+      math::Col2VolFunctor<DeviceContext, T> col2vol;
+      math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor dx_batch = dX->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
+          // gemm
+          Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
+          Tensor dx_slice = dx_batch.Slice(g * in_step, (g + 1) * in_step);
+          if (!is_expand) {
+            col_matrix.ShareDataWith(dx_slice);
+            col_matrix.Resize(col_matrix_shape);
+          }
+          blas.MatMul(ddw_slice, true, dy_slice, false, T(1.0), &col_matrix,
+                      T(0.0));
+
+          if (is_expand && data_dim == 2U) {
+            col2im(dev_ctx, col, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &dx_slice);
+          } else if (is_expand && data_dim == 3U) {
+            col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice);
+          }
+        }
+      }
+    }
+
+    // dw = ddx * dy  ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
+    // oH, oW)
+    // dw convolution double grad:  im2col(vol2col) + gemm
+    if (dW) {
+      dW->mutable_data<T>(ctx.GetPlace());
+      set_zero(dev_ctx, dW, static_cast<T>(0));
+      Tensor dW_arr = *dW;
+      dW_arr.Resize(filter_matrix_shape);
+      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
+      math::Vol2ColFunctor<DeviceContext, T> vol2col;
+      for (int i = 0; i < batch_size; ++i) {
+        Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; ++g) {
+          // im2col
+          Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
+          if (!is_expand) {
+            col.ShareDataWith(ddx_slice);
+            col_matrix.ShareDataWith(col);
+            col_matrix.Resize(col_matrix_shape);
+          } else if (data_dim == 2U) {
+            im2col(dev_ctx, ddx_slice, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &col);
+          } else if (data_dim == 3U) {
+            vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
+          }
+
+          Tensor dw_slice = dW_arr.Slice(g * out_step, (g + 1) * out_step);
+          blas.MatMul(dy_slice, false, col_matrix, true, T(1.0), &dw_slice,
+                      T(1.0));
+        }
+      }
+    }
+
+    // ddy = w * ddx + x * ddw ==> ddy(N, Cout, oH, oW), x/ddx(N, Cin, H, W),
+    // w/ddw(Cout, Cin, kh, kw)
+    // ddy convolution double grad: im2col(vol2col) + gemm
+    if (ddY) {
+      ddY->mutable_data<T>(ctx.GetPlace());
+      set_zero(dev_ctx, ddY, static_cast<T>(0));
+      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
+      math::Vol2ColFunctor<DeviceContext, T> vol2col;
+      for (int i = 0; i < batch_size; ++i) {
+        Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
+        Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
+        Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
+        for (int g = 0; g < groups; ++g) {
+          Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
+          Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
+          if (!is_expand) {
+            col.ShareDataWith(ddx_slice);
+            col_matrix.ShareDataWith(col);
+            col_matrix.Resize(col_matrix_shape);
+          } else if (data_dim == 2U) {
+            // im2col
+            im2col(dev_ctx, ddx_slice, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &col);
+          } else if (data_dim == 3U) {
+            // vol2col
+            vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
+          }
+
+          // gemm
+          Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
+          blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
+                      T(0.0));
+
+          if (ddW_in) {
+            Tensor ddW;
+            ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
+
+            if (!is_expand) {
+              col.ShareDataWith(x_slice);
+              col_matrix.ShareDataWith(col);
+              col_matrix.Resize(col_matrix_shape);
+            } else if (data_dim == 2U) {
+              // im2col
+              im2col(dev_ctx, x_slice, dilations, strides,
+                     std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                      paddings[1]},
+                     &col);
+            } else if (data_dim == 3U) {
+              // vol2col
+              vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col);
+            }
+
+            // gemm
+            Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
+            blas.MatMul(ddw_slice, false, col_matrix, false, T(1.0), &ddy_slice,
+                        T(1.0));
+          }
+        }
+      }
+    }
+  }
+};
+
 template <typename DeviceContext, typename T>
 class DepthwiseConvKernel : public framework::OpKernel<T> {
  public:
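The three branches of GemmConvDoubleGradKernel implement the algebra noted in its comments: dX = ddW * dY, dW = ddX * dY, and ddY = W * ddX + X * ddW, each realized per group as im2col/col2im plus a GEMM. For a 1x1, stride-1, unpadded convolution these identities collapse to channel-wise matrix products; the following NumPy sketch (hypothetical shapes and variable names, an illustration of the formulas rather than the kernel itself) shows the same three quantities computed directly:

    import numpy as np

    # Hypothetical sizes for a 1x1, stride-1, unpadded conv: Y = conv(X, W).
    N, Cin, Cout, H, Wd = 2, 3, 4, 5, 5
    rng = np.random.RandomState(0)
    X, ddX = rng.randn(N, Cin, H, Wd), rng.randn(N, Cin, H, Wd)
    W, ddW = rng.randn(Cout, Cin), rng.randn(Cout, Cin)  # 1x1 kernels, spatial dims dropped
    dY = rng.randn(N, Cout, H, Wd)

    # ddY = W * ddX + X * ddW  (forward conv applied to the two perturbations)
    ddY = np.einsum('oc,nchw->nohw', W, ddX) + np.einsum('oc,nchw->nohw', ddW, X)
    # dW  = ddX * dY  (filter gradient recomputed against the perturbed input)
    dW = np.einsum('nohw,nchw->oc', dY, ddX)
    # dX  = ddW * dY  (input gradient recomputed against the perturbed filter)
    dX = np.einsum('oc,nohw->nchw', ddW, dY)
    print(ddY.shape, dW.shape, dX.shape)  # (2, 4, 5, 5) (4, 3) (2, 3, 5, 5)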
diff --git a/python/paddle/fluid/tests/unittests/test_conv_nn_grad.py b/python/paddle/fluid/tests/unittests/test_conv_nn_grad.py
new file mode 100644
index 00000000000..81f902d529e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_conv_nn_grad.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+import paddle.fluid.core as core
+import gradient_checker
+
+from decorator_helper import prog_scope
+
+
+class TestConvDoubleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 4, 7, 8]
+        eps = 0.005
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype)
+        y = layers.conv2d(x, 4, 1, bias_attr=False)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+
+        w = fluid.default_main_program().global_block().all_parameters()
+        w_arr = []
+        for p in w:
+            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
+        gradient_checker.double_grad_check(
+            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestConvDoubleGradCheckTest1(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 3, 4, 5]
+        eps = 0.005
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype)
+        y = layers.conv2d(x, 4, 1, padding=1, bias_attr=False)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+
+        w = fluid.default_main_program().global_block().all_parameters()
+        w_arr = []
+        for p in w:
+            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
+        gradient_checker.double_grad_check(
+            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestConv3DDoubleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 4, 3, 4, 2]
+        eps = 0.005
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype)
+        y = layers.conv3d(x, 4, 1, bias_attr=False)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+
+        w = fluid.default_main_program().global_block().all_parameters()
+        w_arr = []
+        for p in w:
+            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
+        gradient_checker.double_grad_check(
+            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestConv3DDoubleGradCheckTest1(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 4, 5, 3, 2]
+        eps = 0.005
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype)
+        y = layers.conv3d(x, 4, 1, padding=1, bias_attr=False)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+
+        w = fluid.default_main_program().global_block().all_parameters()
+        w_arr = []
+        for p in w:
+            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
+        gradient_checker.double_grad_check(
+            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py
index ae1e85c483e..8bbd9443230 100644
--- a/python/paddle/fluid/tests/unittests/test_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -43,30 +43,6 @@ class TestMulGradCheck(unittest.TestCase):
             self.func(p)
 
 
-class TestConvDoubleGradCheck(unittest.TestCase):
-    @prog_scope()
-    def func(self, place):
-        shape = [2, 4, 14, 16]
-        eps = 0.005
-        dtype = np.float64
-        x = layers.data('x', shape, False, dtype)
-        y = layers.conv2d(x, 4, 1, bias_attr=False)
-        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
-
-        w = fluid.default_main_program().global_block().all_parameters()
-        w_arr = []
-        for p in w:
-            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
-        gradient_checker.double_grad_check(
-            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
-
-    def test_grad(self):
-        if core.is_compiled_with_cuda():
-            places = [fluid.CUDAPlace(0)]
-        for p in places:
-            self.func(p)
-
-
 class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
     @prog_scope()
     def func(self, place):
--
GitLab
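With conv2d_grad_grad and conv3d_grad_grad now registered as CPU kernels, second-order gradients of convolution no longer require cuDNN, which is what lets the new tests run on fluid.CPUPlace(). A minimal usage sketch from the static-graph side (assuming the fluid.gradients interface of the same release; shapes and names here are illustrative, not taken from the patch):

    import numpy as np
    import paddle.fluid as fluid

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.layers.data('x', [2, 4, 7, 8], False, 'float64')
        x.stop_gradient = False
        y = fluid.layers.conv2d(x, num_filters=4, filter_size=1, bias_attr=False)
        dx = fluid.gradients(y, x)    # first order: conv2d_grad
        ddx = fluid.gradients(dx, x)  # second order: conv2d_grad_grad

    exe = fluid.Executor(fluid.CPUPlace())  # CPU is enough after this patch
    exe.run(startup)
    out = exe.run(main,
                  feed={'x': np.random.rand(2, 4, 7, 8).astype('float64')},
                  fetch_list=[ddx[0]])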