From e32c9888f5c0160f19ef2faa4b6bdddb16bde303 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Fri, 10 May 2019 10:29:41 +0800 Subject: [PATCH] Double backward of conv2d. (#17211) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add conv2d_grad_grad_op * Extracte the cuDNN conv algo searching code in conv_cudnn_helper.h. - Now use it in conv2d_grad_grad. - Will simply the searching code in conv2d and conv2d_grad in next PR. * Enhance and fix bug in unit testing of gradient_checker. * Support to fetch empty variables,return None in Python. --- paddle/fluid/framework/operator.h | 7 +- paddle/fluid/operators/activation_op.cc | 6 +- .../fluid/operators/controlflow/fetch_op.cc | 8 +- paddle/fluid/operators/conv_cudnn_helper.h | 271 ++++++++++++++++ paddle/fluid/operators/conv_cudnn_op.cu.cc | 300 ++++++++++++++---- paddle/fluid/operators/conv_op.cc | 89 +++++- paddle/fluid/operators/conv_op.h | 11 + paddle/fluid/platform/cudnn_desc.h | 93 +++++- paddle/fluid/pybind/tensor_py.h | 3 + python/paddle/fluid/executor.py | 5 +- .../fluid/tests/unittests/gradient_checker.py | 46 ++- .../fluid/tests/unittests/test_nn_grad.py | 29 +- 12 files changed, 791 insertions(+), 77 deletions(-) create mode 100644 paddle/fluid/operators/conv_cudnn_helper.h diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index d94326563fa..4bc94b4c5cd 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -386,9 +386,10 @@ class ExecutionContext { template T& GetKernelConfig(int idx) const { - PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx, - "%s selected kernel doesn't have kernel config %lu <= %d", - op_.Type().c_str(), kernel_configs_->size(), idx); + PADDLE_ENFORCE( + kernel_configs_ && kernel_configs_->size() > static_cast(idx), + "%s selected kernel doesn't have kernel config %lu <= %d", + op_.Type().c_str(), kernel_configs_->size(), idx); return *boost::get>(kernel_configs_->at(idx)); } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 2100264823b..f93474a122f 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -644,6 +644,7 @@ class LeakyReluDoubleGrad : public framework::OperatorWithKernel { // // ReluGrad: dx = dy if y >= 0 else 0 // ReluGradGrad: ddy = ddx if y >= 0 else 0 +// dy = 0 // class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker { public: @@ -655,11 +656,12 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker { op->SetType("relu_grad_grad"); // input1: Out op->SetInput("Out", Input("Out")); - // X@GRAD@GRAD: ddx + // input2: ddx op->SetInput("DDX", OutputGrad(framework::GradVarName("X"))); op->SetAttrMap(Attrs()); - // Out@GRAD@GRAD: ddy + // output1: ddy op->SetOutput("DOut", InputGrad("Out")); + // output2: ddy op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); return std::unique_ptr<::paddle::framework::OpDesc>(op); } diff --git a/paddle/fluid/operators/controlflow/fetch_op.cc b/paddle/fluid/operators/controlflow/fetch_op.cc index c197b45e819..85d36c5c3af 100644 --- a/paddle/fluid/operators/controlflow/fetch_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_op.cc @@ -54,7 +54,13 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? 
- TensorCopySync(src_item, platform::CPUPlace(), &dst_item); + if (src_item.IsInitialized() && src_item.numel() > 0) { + TensorCopySync(src_item, platform::CPUPlace(), &dst_item); + } else { + // Not copy, if the src tensor is empty. + dst_item.clear(); + dst_item.Resize({0}); + } dst_item.set_lod(src_item.lod()); VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name; diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h new file mode 100644 index 00000000000..c2ad468fa60 --- /dev/null +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -0,0 +1,271 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/operator_kernel_configs.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" +#include "paddle/fluid/platform/cudnn_desc.h" + +namespace paddle { +namespace operators { + +using framework::AlgorithmsCache; + +struct ConvArgs { + cudnnHandle_t handle; + platform::TensorDescriptor idesc, odesc; + platform::FilterDescriptor wdesc; + platform::ConvolutionDescriptor cdesc; + const framework::Tensor *x, *w, *o; + + // strides + std::vector s; + // paddings + std::vector p; + // dilations + std::vector d; + + ConvArgs(const framework::Tensor* x, const framework::Tensor* w, + const framework::Tensor* o, const std::vector s, + const std::vector p, const std::vector d) + : x(x), w(w), o(o), s(s), p(p), d(d) {} +}; + +template +struct SearchAlgorithm {}; + +template <> +struct SearchAlgorithm { + using perf_t = cudnnConvolutionFwdAlgoPerf_t; + using algo_t = cudnnConvolutionFwdAlgo_t; + + template + static algo_t Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, int algo_cache_id, + const framework::ExecutionContext& ctx) { + auto dtype = platform::CudnnDataType::type; + bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); + + size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; + + algo_t algo; + if (!exhaustive) { + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(), + args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + VLOG(3) << "choose algo " << algo; + } else { + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(algo_cache_id); + auto& dev_ctx = + ctx.template device_context(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { + int returned_algo_count; + std::array perf_stat; + + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + CUDNN_ENFORCE( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + args.handle, args.idesc.desc(), args.x->data(), + args.wdesc.desc(), args.w->data(), args.cdesc.desc(), + 
args.odesc.desc(), const_cast(args.o->data()), + kNUM_CUDNN_FWD_ALGS, &returned_algo_count, + perf_stat.data(), cudnn_workspace_ptr, + workspace_size_limit)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); + + VLOG(3) << "FwdAlgo Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + return perf_stat[0].algo; + }); + } + VLOG(3) << "choose algo " << algo; + return algo; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + size_t workspace_size = 0; + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(), + args.odesc.desc(), algo, &workspace_size)); + return workspace_size; + } +}; + +template <> +struct SearchAlgorithm { + using perf_t = cudnnConvolutionBwdDataAlgoPerf_t; + using algo_t = cudnnConvolutionBwdDataAlgo_t; + + template + static algo_t Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, int algo_cache_id, + const framework::ExecutionContext& ctx) { + auto dtype = platform::CudnnDataType::type; + bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); + + size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; + + algo_t algo; + if (!exhaustive && !deterministic) { + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + args.handle, args.wdesc.desc(), args.idesc.desc(), args.cdesc.desc(), + args.odesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + } else if (deterministic) { + return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } else { + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(algo_cache_id); + auto& dev_ctx = + ctx.template device_context(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { + int returned_algo_count; + std::array perf_stat; + + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardDataAlgorithmEx( + args.handle, args.wdesc.desc(), args.w->data(), + args.odesc.desc(), args.o->data(), + args.cdesc.desc(), args.idesc.desc(), + const_cast(args.x->data()), + kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, + perf_stat.data(), cudnn_workspace_ptr, + workspace_size_limit)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); + + VLOG(3) << "BwdDataAlgo Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + + return perf_stat[0].algo; + }); + } + VLOG(3) << "choose algo " << algo; + return algo; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + size_t workspace_size = 0; + CUDNN_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + args.handle, args.wdesc.desc(), args.idesc.desc(), + args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size)); + return workspace_size; + } +}; + +template <> +struct SearchAlgorithm { + using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t; + using algo_t = 
cudnnConvolutionBwdFilterAlgo_t; + + template + static algo_t Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, int algo_cache_id, + const framework::ExecutionContext& ctx) { + auto dtype = platform::CudnnDataType::type; + bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); + + size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; + + algo_t algo; + if (!exhaustive && !deterministic) { + CUDNN_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + args.handle, args.idesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + } else if (deterministic) { + return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } else { + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(algo_cache_id); + auto& dev_ctx = + ctx.template device_context(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + auto x_dims = framework::vectorize(args.x->dims()); + auto w_dims = framework::vectorize(args.w->dims()); + + algo = algo_cache.GetAlgorithm( + x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { + int returned_algo_count; + std::array perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardFilterAlgorithmEx( + args.handle, args.idesc.desc(), args.x->data(), + args.odesc.desc(), args.o->data(), + args.cdesc.desc(), args.wdesc.desc(), + const_cast(args.w->data()), + kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, + perf_stat.data(), cudnn_workspace_ptr, + workspace_size_limit)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); + + VLOG(3) << "BwdFilterAlgo Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + return perf_stat[0].algo; + }); + } + VLOG(3) << "choose algo " << algo; + return algo; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { + size_t workspace_size = 0; + CUDNN_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + args.handle, args.idesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size)); + return workspace_size; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 9a545160a10..158d6ced274 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/conv_cudnn_helper.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/assert.h" @@ -46,6 +47,23 @@ template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; using framework::AlgorithmsCache; +static inline void GetNCDHW(const framework::DDim& dims, + const DataLayout& layout, int* N, int* C, int* D, + int* H, int* W) { + *N = dims[0]; + *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; + int i = layout == DataLayout::kNCHW ? 
0 : 1; + if (dims.size() == 5) { + *D = dims[2 - i]; + *H = dims[3 - i]; + *W = dims[4 - i]; + } else { + *D = 1; + *H = dims[2 - i]; + *W = dims[3 - i]; + } +} + template class CUDNNConvOpKernel : public framework::OpKernel { public: @@ -99,33 +117,13 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( layout, framework::vectorize2int(filter->dims()), groups); - int input_channels = input->dims()[1]; - int input_height, input_width, input_depth; - if (input->dims().size() == 5) { - input_depth = input->dims()[2]; - input_height = input->dims()[3]; - input_width = input->dims()[4]; - } else { // dim size is enforced in InferShape - input_depth = 1; - input_height = input->dims()[2]; - input_width = input->dims()[3]; - } - int output_channels = filter->dims()[0]; - int output_height, output_width, output_depth; - if (output->dims().size() == 5) { - output_depth = output->dims()[2]; - output_height = output->dims()[3]; - output_width = output->dims()[4]; - } else { - output_depth = 1; - output_height = output->dims()[2]; - output_width = output->dims()[3]; - } + int i_n, i_c, i_d, i_h, i_w; + GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(output->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); - int group_offset_in = - input_channels / groups * input_height * input_width * input_depth; - int group_offset_out = - output_channels / groups * output_height * output_width * output_depth; + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = filter->numel() / groups; // ------------------- cudnn conv workspace --------------------- size_t workspace_size_in_bytes; // final workspace to allocate. 
@@ -164,6 +162,9 @@ class CUDNNConvOpKernel : public framework::OpKernel { auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto x_dims = framework::vectorize(input->dims()); auto f_dims = framework::vectorize(filter->dims()); + + // TODO(dangqingqing) simplify the following code by SearchAlgorithm in + // conv_cudnn_helper.h if ((!exhaustive_search) && (!half_float)) { CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, @@ -315,34 +316,14 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } #endif - int input_channels = input->dims()[1]; - int input_height, input_width, input_depth; - if (input->dims().size() == 5) { - input_depth = input->dims()[2]; - input_height = input->dims()[3]; - input_width = input->dims()[4]; - } else { // dim size is enforced in InferShape - input_depth = 1; - input_height = input->dims()[2]; - input_width = input->dims()[3]; - } + int i_n, i_c, i_d, i_h, i_w; + GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(output_grad->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, + &o_w); - int output_grad_channels = filter->dims()[0]; - int output_grad_height, output_grad_width, output_grad_depth; - if (input->dims().size() == 5) { - output_grad_depth = output_grad->dims()[2]; - output_grad_height = output_grad->dims()[3]; - output_grad_width = output_grad->dims()[4]; - } else { - output_grad_depth = 1; - output_grad_height = output_grad->dims()[2]; - output_grad_width = output_grad->dims()[3]; - } - - int group_offset_in = - input_channels / groups * input_height * input_width * input_depth; - int group_offset_out = output_grad_channels / groups * output_grad_height * - output_grad_width * output_grad_depth; + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = filter->numel() / groups; // ------------------- cudnn backward algorithm --------------------- cudnnConvolutionBwdDataAlgo_t data_algo; @@ -367,6 +348,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); } + // TODO(dangqingqing) simplify the following code by SearchAlgorithm in + // conv_cudnn_helper.h auto x_dims = framework::vectorize(input->dims()); auto f_dims = framework::vectorize(filter->dims()); auto handle = dev_ctx.cudnn_handle(); @@ -512,6 +495,212 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } }; +/* + * Inputs: I, W, dO, ddI, ddW + * Outputs: ddO, dW, dI + * ddo = conv(ddI, W) + conv(I, ddW) + * dW = conv_bp_filter(ddI, dO) + * dI = conv_bp_data(ddW, dO) + */ +template +class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto X = ctx.Input("Input"); + auto W = ctx.Input("Filter"); + auto dO = ctx.Input("DOutput"); + auto ddX = ctx.Input("DDInput"); + auto ddW = ctx.Input("DDFilter"); + + auto ddO = ctx.Output("DDOutput"); + auto dW = ctx.Output("DFilter"); + auto dX = ctx.Output("DInput"); + + const T* x = X->data(); + const T* dy = dO->data(); + const T* w = W->data(); + + const T* ddx = nullptr; + const T* ddw = nullptr; + T *dw, *dx, *ddy; + dw = dx = ddy = nullptr; + + const std::vector& strides = ctx.Attr>("strides"); + const 
std::vector& paddings = ctx.Attr>("paddings"); + const std::vector& dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); + bool deterministic = FLAGS_cudnn_deterministic; + if (exhaustive_search && deterministic) { + PADDLE_THROW( + "Cann't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time."); + } + + int iwo_group = groups; + int c_group = 1; +#if CUDNN_VERSION_MIN(7, 0, 1) + iwo_group = 1; + c_group = groups; +#endif + auto dtype = platform::CudnnDataType::type; + + auto handle = dev_ctx.cudnn_handle(); + + ConvArgs args1{ddX, W, ddO, strides, paddings, dilations}; + ConvArgs args2{X, ddW, ddO, strides, paddings, dilations}; + ConvArgs args3{ddX, dW, dO, strides, paddings, dilations}; + ConvArgs args4{dX, ddW, dO, strides, paddings, dilations}; + + cudnnConvolutionFwdAlgo_t fwd_algo1 = + static_cast(0); + cudnnConvolutionFwdAlgo_t fwd_algo2 = + static_cast(0); + cudnnConvolutionBwdDataAlgo_t data_algo = + static_cast(0); + cudnnConvolutionBwdFilterAlgo_t filter_algo = + static_cast(0); + + auto layout = GetCudnnTensorFormat(DataLayout::kNCHW); + + // ddo = conv(ddI, W) + conv(I, ddW) + size_t workspace_size = 0; + if (ddO) { + ddy = ddO->mutable_data(ctx.GetPlace()); + args1.handle = handle; + args1.idesc.set(*ddX, iwo_group); + args1.wdesc.set(*W, layout, iwo_group); + args1.odesc.set(*ddO, iwo_group); + args1.cdesc.set(dtype, paddings, strides, dilations, c_group); + + using search1 = SearchAlgorithm; + fwd_algo1 = search1::Find(args1, exhaustive_search, false, 0, ctx); + workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); + + if (ddW) { + ddw = ddW->data(); + args2.handle = handle; + args2.idesc.set(*X, iwo_group); + args2.wdesc.set(*ddW, layout, iwo_group); + args2.odesc.set(*ddO, iwo_group); + args2.cdesc.set(dtype, paddings, strides, dilations, c_group); + + using search2 = SearchAlgorithm; + fwd_algo2 = search2::Find(args2, exhaustive_search, false, 0, ctx); + workspace_size = std::max(workspace_size, + search2::GetWorkspaceSize(args2, fwd_algo2)); + } + } + + if (dW) { + dw = dW->mutable_data(ctx.GetPlace()); + args3.handle = handle; + args3.idesc.set(*ddX, iwo_group); + args3.wdesc.set(*dW, layout, iwo_group); + args3.odesc.set(*dO, iwo_group); + args3.cdesc.set(dtype, paddings, strides, dilations, c_group); + + using search3 = SearchAlgorithm; + filter_algo = + search3::Find(args3, exhaustive_search, deterministic, 1, ctx); + workspace_size = std::max(workspace_size, + search3::GetWorkspaceSize(args3, filter_algo)); + } + + if (ddW && dX) { + dx = dX->mutable_data(ctx.GetPlace()); + args4.handle = handle; + args4.idesc.set(*dX, iwo_group); + args4.wdesc.set(*ddW, layout, iwo_group); + args4.odesc.set(*dO, iwo_group); + args4.cdesc.set(dtype, paddings, strides, dilations, c_group); + + using search4 = SearchAlgorithm; + data_algo = + search4::Find(args4, exhaustive_search, deterministic, 2, ctx); + workspace_size = + std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); + } + + int i_n, i_c, i_d, i_h, i_w; + GetNCDHW(X->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(dO->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); + + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; + int group_offset_filter = W->numel() / groups; + + ScalingParamType alpha = 1.0f, beta = 0.0f; + auto wkspace_handle = 
dev_ctx.cudnn_workspace_handle(); + + if (ddO) { + ddx = ddX->data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in, + args1.wdesc.desc(), w + i * group_offset_filter, + args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size, + &beta, args1.odesc.desc(), ddy + i * group_offset_out)); + }, + workspace_size); + } + if (ddW) { + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, args2.idesc.desc(), x + i * group_offset_in, + args2.wdesc.desc(), ddw + i * group_offset_filter, + args2.cdesc.desc(), fwd_algo2, workspace_ptr, + workspace_size, &alpha, args2.odesc.desc(), + ddy + i * group_offset_out)); + }, + workspace_size); + } + } + } + + if (dW) { + ddx = ddX->data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in, + args3.odesc.desc(), dy + i * group_offset_out, + args3.cdesc.desc(), filter_algo, workspace_ptr, + workspace_size, &beta, args3.wdesc.desc(), + dw + i * group_offset_filter)); + }, + workspace_size); + } + } + + if (dX && ddW) { + ddw = ddW->data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, args4.wdesc.desc(), + ddw + i * group_offset_filter, args4.odesc.desc(), + dy + i * group_offset_out, args4.cdesc.desc(), data_algo, + workspace_ptr, workspace_size, &beta, args4.idesc.desc(), + dx + i * group_offset_in)); + }, + workspace_size); + } + } + } +}; + } // namespace operators } // namespace paddle @@ -524,6 +713,11 @@ REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, paddle::operators::CUDNNConvGradOpKernel, paddle::operators::CUDNNConvGradOpKernel); +REGISTER_OP_KERNEL( + conv2d_grad_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel); REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 1bacc54b61d..5b923f8a5eb 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -506,13 +506,100 @@ class Conv3DGradMaker : public framework::SingleGradOpDescMaker { } }; +/* + * Inputs: I, W, dO, ddI, ddW + * Outputs: ddO, dW, dI + */ +class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType(this->ForwardOpType() + "_grad"); + // I, W, dO, ddI, ddW + op->SetInput("Input", Input("Input")); + op->SetInput("Filter", Input("Filter")); + op->SetInput("DOutput", Input(framework::GradVarName("Output"))); + op->SetInput("DDInput", OutputGrad(framework::GradVarName("Input"))); + op->SetInput("DDFilter", OutputGrad(framework::GradVarName("Filter"))); + + // ddO, dI, dW + // Unlike grad op, double grad op does not use name@GRAD@GRAD + // as key of ops' inputs and outputs. 
+ op->SetOutput("DDOutput", InputGrad(framework::GradVarName("Output"))); + op->SetOutput("DFilter", InputGrad("Filter")); + op->SetOutput("DInput", InputGrad("Input")); + op->SetAttrMap(Attrs()); + + return std::unique_ptr(op); + } +}; + +void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const { + auto x_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("Filter"); + auto do_dims = ctx->GetInputDim("DOutput"); + + if (ctx->HasOutput("DDOutput")) { + ctx->SetOutputDim("DDOutput", do_dims); + } + if (ctx->HasOutput("DFilter")) { + ctx->SetOutputDim("DFilter", w_dims); + } + if (ctx->HasOutput("DInput")) { + ctx->SetOutputDim("DInput", x_dims); + } +} + +framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + +#ifdef PADDLE_WITH_CUDA + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; + } else { + PADDLE_THROW("Now ConvDoubleGrad only supports cuDNN."); + } +#endif + auto type = framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout_, library_, + customized_type_value); +#ifdef PADDLE_WITH_CUDA + if (library_ == framework::LibraryType::kCUDNN) { + std::vector& configs = kernel_configs_map_[type]; + if (configs.empty()) { + std::shared_ptr> p0( + new framework::AlgorithmsCache()); + configs.push_back(p0); + + std::shared_ptr< + framework::AlgorithmsCache> + p1(new framework::AlgorithmsCache()); + configs.push_back(p1); + + std::shared_ptr> + p2(new framework::AlgorithmsCache()); + configs.push_back(p2); + } + } +#endif + return type; +} + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker, ops::ConvOpInferVarType, ops::Conv2DGradMaker); -REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad); +REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad, ops::Conv2DDoubleGradMaker); +REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad); // depthwise convolution op REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker, diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 797c6651659..4df47ef261e 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -107,6 +108,16 @@ class ConvOpGrad : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override; }; +class ConvOpDoubleGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + template class GemmConvKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/platform/cudnn_desc.h b/paddle/fluid/platform/cudnn_desc.h index 1062b403f28..4ed51acb587 100644 --- a/paddle/fluid/platform/cudnn_desc.h +++ b/paddle/fluid/platform/cudnn_desc.h @@ -29,13 +29,14 @@ namespace platform { using framework::Tensor; template -cudnnDataType_t ToCudnnDataType(const T& t) { +inline cudnnDataType_t ToCudnnDataType(const T& t) { auto type = framework::ToDataType(t); return ToCudnnDataType(type); } template <> -cudnnDataType_t ToCudnnDataType(const framework::proto::VarType::Type& t) { +inline cudnnDataType_t ToCudnnDataType( + const framework::proto::VarType::Type& t) { cudnnDataType_t type = CUDNN_DATA_FLOAT; switch (t) { case framework::proto::VarType::FP16: @@ -59,14 +60,14 @@ class ActivationDescriptor { struct Deleter { void operator()(T* t) { if (t != nullptr) { - PADDLE_ENFORCE(dynload::cudnnDestroyActivationDescriptor(t)); + CUDNN_ENFORCE(dynload::cudnnDestroyActivationDescriptor(t)); t = nullptr; } } }; ActivationDescriptor() { T* raw_ptr; - PADDLE_ENFORCE(dynload::cudnnCreateActivationDescriptor(&raw_ptr)); + CUDNN_ENFORCE(dynload::cudnnCreateActivationDescriptor(&raw_ptr)); desc_.reset(raw_ptr); } template @@ -88,14 +89,14 @@ class TensorDescriptor { struct Deleter { void operator()(T* t) { if (t != nullptr) { - PADDLE_ENFORCE(dynload::cudnnDestroyTensorDescriptor(t)); + CUDNN_ENFORCE(dynload::cudnnDestroyTensorDescriptor(t)); t = nullptr; } } }; TensorDescriptor() { T* raw_ptr; - PADDLE_ENFORCE(dynload::cudnnCreateTensorDescriptor(&raw_ptr)); + CUDNN_ENFORCE(dynload::cudnnCreateTensorDescriptor(&raw_ptr)); desc_.reset(raw_ptr); } T* desc() { return desc_.get(); } @@ -111,7 +112,7 @@ class TensorDescriptor { if (groups > 1) { dims_with_group[1] = dims_with_group[1] / groups; } - PADDLE_ENFORCE(dynload::cudnnSetTensorNdDescriptor( + CUDNN_ENFORCE(dynload::cudnnSetTensorNdDescriptor( desc_.get(), ToCudnnDataType(tensor.type()), dims_with_group.size(), dims_with_group.data(), strides.data())); } @@ -120,5 +121,83 @@ class TensorDescriptor { std::unique_ptr desc_; }; +class FilterDescriptor { + public: + using T = cudnnFilterStruct; + struct Deleter { + void operator()(T* t) { + if (t != nullptr) { + CUDNN_ENFORCE(dynload::cudnnDestroyFilterDescriptor(t)); + t = nullptr; + } + } + }; + FilterDescriptor() { + T* raw_ptr; + CUDNN_ENFORCE(dynload::cudnnCreateFilterDescriptor(&raw_ptr)); + desc_.reset(raw_ptr); + } + T* desc() { return desc_.get(); } + T* desc() const { return desc_.get(); } + + void set(const Tensor& tensor, const cudnnTensorFormat_t format, + const int groups = 1) { + auto dims = framework::vectorize2int(tensor.dims()); + if (groups > 1) { + dims[1] = dims[1] / groups; + } + CUDNN_ENFORCE(dynload::cudnnSetFilterNdDescriptor( + desc_.get(), ToCudnnDataType(tensor.type()), format, dims.size(), + dims.data())); + } + + private: + std::unique_ptr desc_; +}; + 
+class ConvolutionDescriptor { + public: + using T = cudnnConvolutionStruct; + struct Deleter { + void operator()(T* t) { + if (t != nullptr) { + CUDNN_ENFORCE(dynload::cudnnDestroyConvolutionDescriptor(t)); + t = nullptr; + } + } + }; + ConvolutionDescriptor() { + T* raw_ptr; + CUDNN_ENFORCE(dynload::cudnnCreateConvolutionDescriptor(&raw_ptr)); + desc_.reset(raw_ptr); + } + T* desc() { return desc_.get(); } + T* desc() const { return desc_.get(); } + + void set(cudnnDataType_t dtype, const std::vector& pads, + const std::vector& strides, const std::vector& dilations, + const int groups = 1) { + cudnnDataType_t compute_type = + (dtype == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT; + T* desc = desc_.get(); + CUDNN_ENFORCE(dynload::cudnnSetConvolutionNdDescriptor( + desc, pads.size(), pads.data(), strides.data(), dilations.data(), + CUDNN_CROSS_CORRELATION, compute_type)); + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + desc, CUDNN_DEFAULT_MATH)); +#if CUDNN_VERSION_MIN(7, 0, 1) + CUDNN_ENFORCE( + platform::dynload::cudnnSetConvolutionGroupCount(desc, groups)); + if (dtype == CUDNN_DATA_HALF) { + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + desc, CUDNN_TENSOR_OP_MATH)); + } +#endif + } + + private: + std::unique_ptr desc_; +}; + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index cec21f40073..08e43bf24ce 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -472,6 +472,9 @@ inline std::string TensorDTypeToPyDTypeStr( } // namespace details inline py::array TensorToPyArray(const framework::Tensor &tensor) { + if (!tensor.IsInitialized()) { + return py::array(); + } bool is_gpu_tensor = platform::is_gpu_place(tensor.place()); const auto &tensor_dims = tensor.dims(); auto tensor_dtype = tensor.type(); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index f3988edf08f..063b65e8eef 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -119,7 +119,10 @@ def as_numpy(tensor): They can not be completely cast to Python ndarray. 
\ Please set the parameter 'return_numpy' as 'False' to \ return LoDTensor itself directly.") - return np.array(tensor) + if tensor._is_initialized(): + return np.array(tensor) + else: + return None def has_feed_operators(block, feed_targets, feed_holder_name): diff --git a/python/paddle/fluid/tests/unittests/gradient_checker.py b/python/paddle/fluid/tests/unittests/gradient_checker.py index 14a828f28ee..87c917873cd 100644 --- a/python/paddle/fluid/tests/unittests/gradient_checker.py +++ b/python/paddle/fluid/tests/unittests/gradient_checker.py @@ -82,6 +82,10 @@ def set_var_in_scope(scope, place, name, value, recursive_seq_len=None): return t +def var_to_np_array_in_scope(scope, place, name): + return np.array(scope.var(name).get_tensor()) + + def make_jacobian(x, y_size, np_dtype): if isinstance(x, fluid.framework.Variable): return np.zeros((_product(x.shape), y_size), dtype=np_dtype) @@ -192,14 +196,18 @@ def _compute_analytical_jacobian(program, x, y, place, scope): x = _as_list(x) jacobian = make_jacobian(x, y_size, np_type) - dx = _as_list(dx) for i in six.moves.xrange(y_size): _set_item(dy_t, i, 1, np_type) dx_res = exe.run(program, scope=scope, fetch_list=dx) for j in six.moves.xrange(len(x)): - jacobian[j][:, i] = dx_res[j].flatten() + if dx_res[j] is not None: + jacobian[j][:, i] = dx_res[j].flatten() + else: + jacobian[j][:, i] = np.zeros( + dx[j].shape, dtype=np_type).flatten() + _set_item(dy_t, i, 0, np_type) return jacobian @@ -242,6 +250,7 @@ def grad_check(x, # check input arguments x = _as_list(x) y = _as_list(y) + for v in x: v.stop_gradient = False v.persistable = True @@ -274,9 +283,24 @@ def grad_check(x, ] # [y_idx, x_idx] - analytical = [ - _compute_analytical_jacobian(program, x, yi, place, scope) for yi in y - ] + analytical = [] + for yi in y: + prog = program.clone() + + clone_x = [] + clone_y = None + for b in prog.blocks: + if b.has_var(yi.name): + clone_y = b.var(yi.name) + break + for xi in x: + for b in prog.blocks: + if b.has_var(xi.name): + clone_x.append(b.var(xi.name)) + break + + analytical.append( + _compute_analytical_jacobian(prog, clone_x, clone_y, place, scope)) for i, (x_idx, y_idx) in enumerate(product(*[range(len(x)), range(len(y))])): @@ -334,6 +358,7 @@ def double_grad_check(x, if y_grads is None: scope = fluid.executor.global_scope() y_grads = [] + y_grads_init = [] for yi in y: dyi_name = _append_grad_suffix_(yi.name) np_type = dtype_to_np_dtype(yi.dtype) @@ -343,9 +368,20 @@ def double_grad_check(x, v = np.random.random(size=yi.shape).astype(np_type) set_var_in_scope(scope, place, dyi_name, v) y_grads.append(dy) + y_grads_init.append(v) else: y_grads = _as_list(y_grads) + y_grads_init = [ + var_to_np_array_in_scope(scope, place, v.name) for v in y_grads + ] # append first order grads target_grads = calc_gradient(y, x, y_grads) + + # y_grads are the input of first-order backward, + # so, they are also the input of second-order backward. 
+ x += y_grads + x_init = _as_list(x_init) + x_init += y_grads_init + grad_check(x, target_grads, x_init, place, program, eps, atol, rtol) diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index e2d540fea55..df0d8e0345c 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -46,7 +46,6 @@ class TestMulGradCheck(unittest.TestCase): class TestReluDoubleGradCheck(unittest.TestCase): @prog_scope() def func(self, place): - # the shape of input variable shoule be clearly specified, not inlcude -1. shape = [2, 8] eps = 0.005 dtype = np.float64 @@ -71,7 +70,6 @@ class TestReluDoubleGradCheck(unittest.TestCase): class TestLeakyReluDoubleGradCheck(unittest.TestCase): @prog_scope() def func(self, place): - # the shape of input variable shoule be clearly specified, not inlcude -1. shape = [3, 7] eps = 0.005 alpha = 0.2 @@ -79,6 +77,7 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase): x = layers.data('x', shape, False, dtype) x.persistable = True + y = layers.leaky_relu(x, alpha=alpha) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) x_arr[np.abs(x_arr) < 0.005] = 0.02 @@ -90,8 +89,30 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase): places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): places.append(fluid.CUDAPlace(0)) - for p in places: - self.func(p) + + +class TestConvDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 4, 14, 16] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d(x, 4, 1, bias_attr=False) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + if core.is_compiled_with_cuda(): + places = [fluid.CUDAPlace(0)] + for p in places: + self.func(p) if __name__ == "__main__": -- GitLab
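
Editor's note (appended for illustration; not part of the patch): the new CUDNNConvDoubleGradOpKernel relies on the identity ddO = conv(ddI, W) + conv(I, ddW), which follows from convolution being bilinear in its input and its filter. The sketch below, assuming a single-channel, stride-1, unpadded 2D cross-correlation (the CUDNN_CROSS_CORRELATION convention) and using only NumPy, checks that identity numerically; the helper name conv2d_valid is illustrative, not a Paddle API.

import numpy as np

def conv2d_valid(x, w):
    """Plain 2D cross-correlation with stride 1 and no padding."""
    H, W_ = x.shape
    kh, kw = w.shape
    oh, ow = H - kh + 1, W_ - kw + 1
    out = np.zeros((oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * w)
    return out

rng = np.random.RandomState(0)
x   = rng.randn(6, 6)   # input I
w   = rng.randn(3, 3)   # filter W
ddx = rng.randn(6, 6)   # perturbation of I (DDInput)
ddw = rng.randn(3, 3)   # perturbation of W (DDFilter)

# Analytic second-order output, as computed by the new kernel:
#   ddO = conv(ddI, W) + conv(I, ddW)
ddo = conv2d_valid(ddx, w) + conv2d_valid(x, ddw)

# Bilinearity check:
#   conv(I + ddI, W + ddW) - conv(I, W)
#     = conv(ddI, W) + conv(I, ddW) + conv(ddI, ddW),
# so ddO is recovered exactly once the quadratic cross term is subtracted.
lhs = conv2d_valid(x + ddx, w + ddw) - conv2d_valid(x, w) - conv2d_valid(ddx, ddw)
assert np.allclose(lhs, ddo), "ddO identity does not hold"
print("max abs error:", np.max(np.abs(lhs - ddo)))

The remaining two outputs follow the same pattern: dW applies the ordinary filter-backward formula with ddI in place of I (cudnnConvolutionBackwardFilter on args3), and dI applies the ordinary data-backward formula with ddW in place of W (cudnnConvolutionBackwardData on args4).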
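
Editor's note (appended for illustration; not part of the patch): the change to double_grad_check above appends y_grads to the checked inputs (x += y_grads, x_init += y_grads_init) because the first-order gradients being differentiated are functions of both the original inputs and the injected output gradients. Below is a toy sketch of that scheme using plain Python functions in place of Paddle programs; f, g and numeric_grad are hypothetical helpers, not Paddle APIs.

def f(x):            # forward: y = x ** 3
    return x ** 3

def g(x, dy):        # first-order backward: dx = 3 * x**2 * dy
    return 3.0 * x ** 2 * dy

def numeric_grad(fn, args, idx, eps=1e-6):
    # Central finite difference of fn w.r.t. its idx-th argument.
    args_hi, args_lo = list(args), list(args)
    args_hi[idx] += eps
    args_lo[idx] -= eps
    return (fn(*args_hi) - fn(*args_lo)) / (2.0 * eps)

x, dy = 0.7, 1.3
# Analytic gradients of g w.r.t. both of its inputs, i.e. the second-order
# gradients of f; dy is treated as an extra input, just as the patch does.
dg_dx_analytic  = 6.0 * x * dy      # d/dx  (3 x^2 dy)
dg_ddy_analytic = 3.0 * x ** 2      # d/ddy (3 x^2 dy)

assert abs(numeric_grad(g, (x, dy), 0) - dg_dx_analytic) < 1e-5
assert abs(numeric_grad(g, (x, dy), 1) - dg_ddy_analytic) < 1e-5
print("double-grad toy check passed")

double_grad_check does the analogous thing at graph level: calc_gradient builds the first-order backward program (the role of g here), and grad_check then numerically verifies its gradients with respect to the enlarged input list.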