/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the spopecific language governing permissions and limitations under the License. */ #include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/conv_cudnn_helper.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/math/padding.h" #include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_workspace_helper.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" DECLARE_bool(cudnn_deterministic); DECLARE_uint64(conv_workspace_size_limit); DECLARE_bool(cudnn_exhaustive_search); namespace paddle { namespace operators { using Tensor = framework::Tensor; using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using DataLayout = platform::DataLayout; static inline bool IsVoltaOrLater(const platform::CUDADeviceContext& dev_ctx) { return dev_ctx.GetComputeCapability() >= 70; } template class CUDNNConvOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, "It must use CUDAPlace."); const Tensor* input = ctx.Input("Input"); auto* filter = ctx.Input("Filter"); auto* output = ctx.Output("Output"); output->mutable_data(ctx.GetPlace()); const std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); int groups = ctx.Attr("groups"); bool exhaustive_search = FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); if (exhaustive_search && FLAGS_cudnn_deterministic) { PADDLE_THROW( "Cann't set exhaustive_search True and " "FLAGS_cudnn_deterministic True at same time."); } const std::string padding_algorithm = ctx.Attr("padding_algorithm"); const std::string data_format = ctx.Attr("data_format"); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); auto dtype = platform::CudnnDataType::type; // Tensor Core introduced from Volta GPUs supports more faster conv op // with FP16 in NHWC data format. const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(dev_ctx); // We will only do data format conversion from NHWC to NCHW. // cudnn will convert NCHW to NHWC automatically on Tensor Core. auto compute_format = compute_in_nhwc && channel_last ? DataLayout::kNHWC : DataLayout::kNCHW; VLOG(3) << "Compute ConvOp with cuDNN:" << " data_format=" << data_format << " compute_format=" << (compute_format == DataLayout::kNHWC ? "NHWC" : "NCHW"); // ------------ transformed tensor ----------- Tensor transformed_input_channel(input->type()); Tensor transformed_output(output->type()); Tensor transformed_filter_channel(filter->type()); T* output_data = nullptr; if (channel_last && compute_format == DataLayout::kNCHW) { VLOG(3) << "Transform input tensor from NHWC to NCHW."; ResizeToChannelFirst( ctx, input, &transformed_input_channel); TransToChannelFirst( ctx, input, &transformed_input_channel); ResizeToChannelFirst(ctx, output, &transformed_output); } else { transformed_input_channel.ShareDataWith(*input); transformed_output.ShareDataWith(*output); } if (compute_format == DataLayout::kNHWC) { VLOG(3) << "Transform filter tensor from NCHW to NHWC."; ResizeToChannelLast( ctx, filter, &transformed_filter_channel); TransToChannelLast( ctx, filter, &transformed_filter_channel); } else { transformed_filter_channel.ShareDataWith(*filter); } output_data = transformed_output.data(); // update padding and dilation auto in_dims = transformed_input_channel.dims(); auto filter_dims = transformed_filter_channel.dims(); framework::DDim in_data_dims; framework::DDim filter_data_dims; if (compute_format == DataLayout::kNCHW) { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size()); } else { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); filter_data_dims = framework::slice_ddim(filter_dims, 1, filter_dims.size() - 1); } std::vector ksize = framework::vectorize(filter_data_dims); UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); int data_dim = strides.size(); // 2d or 3d bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim); Tensor transformed_input; std::vector padding_common(data_dim, 0); if (!is_sys_pad) { std::vector padding_diff(data_dim); std::vector new_input_shape_vec(data_dim + 2); new_input_shape_vec[0] = transformed_input_channel.dims()[0]; if (compute_format == DataLayout::kNCHW) { new_input_shape_vec[1] = transformed_input_channel.dims()[1]; } else { new_input_shape_vec[data_dim + 1] = transformed_input_channel.dims()[data_dim + 1]; } std::vector input_pad(transformed_input_channel.dims().size() * 2, 0); for (size_t i = 0; i < data_dim; ++i) { padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); if (compute_format == DataLayout::kNCHW) { new_input_shape_vec[i + 2] = transformed_input_channel.dims()[i + 2] + padding_diff[i]; } else { new_input_shape_vec[i + 1] = transformed_input_channel.dims()[i + 1] + padding_diff[i]; } if (compute_format == DataLayout::kNCHW) { input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; } else { input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i]; } } framework::DDim new_input_shape( framework::make_ddim(new_input_shape_vec)); transformed_input.Resize(new_input_shape); auto& dev_ctx = ctx.template device_context(); transformed_input = ctx.AllocateTmpTensor( new_input_shape, dev_ctx); const int rank = transformed_input_channel.dims().size(); T pad_value(0.0); switch (rank) { case 4: { math::PadFunction( ctx, input_pad, transformed_input_channel, pad_value, &transformed_input); } break; case 5: { math::PadFunction( ctx, input_pad, transformed_input_channel, pad_value, &transformed_input); } break; default: PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); } } else { transformed_input.ShareDataWith(transformed_input_channel); if (paddings.size() == data_dim) { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[i]; } } else { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[2 * i]; } } } const T* input_data = transformed_input.data(); const T* filter_data = transformed_filter_channel.data(); // ------------------- cudnn descriptors --------------------- ConvArgs args{&transformed_input, &transformed_filter_channel, &transformed_output, strides, padding_common, dilations}; auto handle = dev_ctx.cudnn_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW; if (transformed_input.dims().size() == 5) { layout = compute_format == DataLayout::kNHWC ? DataLayout::kNDHWC : DataLayout::kNCDHW; } auto layout_format = GetCudnnTensorFormat(layout); args.handle = handle; args.cdesc.set(dtype, padding_common, strides, dilations); #if CUDNN_VERSION_MIN(7, 0, 1) // cudnn 7 can support groups, no need to do it manually // FIXME(typhoonzero): find a better way to disable groups // rather than setting it to 1. CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount( args.cdesc.desc(), groups)); groups = 1; #endif args.idesc.set(transformed_input, layout_format); args.wdesc.set(transformed_filter_channel, layout_format, groups); args.odesc.set(transformed_output, layout_format); int i_n, i_c, i_d, i_h, i_w; int o_n, o_c, o_d, o_h, o_w; if (compute_format == DataLayout::kNHWC) { GetNCDHW(transformed_input.dims(), DataLayout::kNHWC, &i_n, &i_c, &i_d, &i_h, &i_w); GetNCDHW(transformed_output.dims(), DataLayout::kNHWC, &o_n, &o_c, &o_d, &o_h, &o_w); } else { GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); GetNCDHW(transformed_output.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); } int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = transformed_filter_channel.numel() / groups; // ------------------- cudnn conv workspace --------------------- size_t workspace_size = 0; // final workspace to allocate. // ------------------- cudnn conv algorithm --------------------- cudnnConvolutionFwdAlgo_t algo{}; using search = SearchAlgorithm; algo = search::Find(args, exhaustive_search, false, 0, ctx); workspace_size = search::GetWorkspaceSize(args, algo); // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { workspace_handle.RunFunc( [&](void* workspace_ptr) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( handle, &alpha, args.idesc.desc(), input_data + i * group_offset_in, args.wdesc.desc(), filter_data + i * group_offset_filter, args.cdesc.desc(), algo, workspace_ptr, workspace_size, &beta, args.odesc.desc(), output_data + i * group_offset_out)); }, workspace_size); } if (channel_last && compute_format == DataLayout::kNCHW) { TransToChannelLast( ctx, &transformed_output, output); } } }; template class CUDNNConvGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, "It must use CUDAPlace."); auto input = ctx.Input("Input"); auto filter = ctx.Input("Filter"); auto output_grad = ctx.Input(framework::GradVarName("Output")); auto input_grad = ctx.Output(framework::GradVarName("Input")); auto filter_grad = ctx.Output(framework::GradVarName("Filter")); if (input_grad) { input_grad->mutable_data(ctx.GetPlace()); } if (filter_grad) { filter_grad->mutable_data(ctx.GetPlace()); } std::vector dilations = ctx.Attr>("dilations"); std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); std::string padding_algorithm = ctx.Attr("padding_algorithm"); int groups = ctx.Attr("groups"); bool exhaustive_search = FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); bool deterministic = FLAGS_cudnn_deterministic; if (exhaustive_search && deterministic) { PADDLE_THROW( "Can't set exhaustive_search True and " "FLAGS_cudnn_deterministic True at same time."); } const std::string data_format = ctx.Attr("data_format"); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); auto dtype = platform::CudnnDataType::type; const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(dev_ctx); auto compute_format = compute_in_nhwc && channel_last ? DataLayout::kNHWC : DataLayout::kNCHW; VLOG(3) << "Compute ConvGradOp with cuDNN:" << " data_format=" << data_format << " compute_format=" << (compute_format == DataLayout::kNHWC ? "NHWC" : "NCHW"); // transform Tensor Tensor transformed_input_channel(input->type()); Tensor transformed_output_grad_channel(output_grad->type()); Tensor transformed_input_grad_channel(input->type()); Tensor transformed_filter_channel(filter->type()); Tensor transformed_filter_grad_channel(filter->type()); if (channel_last && compute_format == DataLayout::kNCHW) { VLOG(3) << "Transform input, output_grad, input_grad and tensor from " "NHWC to NCHW."; ResizeToChannelFirst( ctx, input, &transformed_input_channel); TransToChannelFirst( ctx, input, &transformed_input_channel); ResizeToChannelFirst( ctx, output_grad, &transformed_output_grad_channel); TransToChannelFirst( ctx, output_grad, &transformed_output_grad_channel); if (input_grad) { ResizeToChannelFirst( ctx, input_grad, &transformed_input_grad_channel); } } else { transformed_input_channel.ShareDataWith(*input); transformed_output_grad_channel.ShareDataWith(*output_grad); if (input_grad) { transformed_input_grad_channel.ShareDataWith(*input_grad); } } if (compute_format == DataLayout::kNHWC) { VLOG(3) << "Transform filter and filter_grad tensor from NCHW to NHWC."; ResizeToChannelLast( ctx, filter, &transformed_filter_channel); TransToChannelLast( ctx, filter, &transformed_filter_channel); if (filter_grad) { ResizeToChannelLast( ctx, filter_grad, &transformed_filter_grad_channel); } } else { transformed_filter_channel.ShareDataWith(*filter); if (filter_grad) { transformed_filter_grad_channel.ShareDataWith(*filter_grad); } } // update paddings auto in_dims = transformed_input_channel.dims(); auto filter_dims = transformed_filter_channel.dims(); framework::DDim in_data_dims; framework::DDim filter_data_dims; if (compute_format == DataLayout::kNCHW) { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size()); } else { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); filter_data_dims = framework::slice_ddim(filter_dims, 1, filter_dims.size() - 1); } std::vector ksize = framework::vectorize(filter_data_dims); UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); // cuDNN only supports padding the same amount on every dimension. // So we create a new padded input tensor. int data_dim = strides.size(); // 2d or 3d bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim); Tensor transformed_input(input->type()); Tensor transformed_input_grad(input->type()); std::vector padding_common(data_dim, 0); std::vector input_pad(transformed_input_channel.dims().size() * 2, 0); if (!is_sys_pad) { // get pad std::vector padding_diff(data_dim); std::vector new_input_shape_vec(data_dim + 2); new_input_shape_vec[0] = transformed_input_channel.dims()[0]; if (compute_format == DataLayout::kNCHW) { new_input_shape_vec[1] = transformed_input_channel.dims()[1]; } else { new_input_shape_vec[data_dim + 1] = transformed_input_channel.dims()[data_dim + 1]; } for (size_t i = 0; i < data_dim; ++i) { padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); if (compute_format == DataLayout::kNCHW) { new_input_shape_vec[i + 2] = transformed_input_channel.dims()[i + 2] + padding_diff[i]; } else { new_input_shape_vec[i + 1] = transformed_input_channel.dims()[i + 1] + padding_diff[i]; } if (compute_format == DataLayout::kNCHW) { input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; } else { input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i]; } } framework::DDim new_input_shape( framework::make_ddim(new_input_shape_vec)); transformed_input.Resize(new_input_shape); transformed_input_grad.Resize(new_input_shape); auto& dev_ctx = ctx.template device_context(); transformed_input = ctx.AllocateTmpTensor( new_input_shape, dev_ctx); if (input_grad) { transformed_input_grad = ctx.AllocateTmpTensor( new_input_shape, dev_ctx); } // pad for input const int rank = transformed_input_channel.dims().size(); T pad_value(0.0); switch (rank) { case 4: { math::PadFunction( ctx, input_pad, transformed_input_channel, pad_value, &transformed_input); } break; case 5: { math::PadFunction( ctx, input_pad, transformed_input_channel, pad_value, &transformed_input); } break; default: PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); } } else { transformed_input.ShareDataWith(transformed_input_channel); if (input_grad) { transformed_input_grad.ShareDataWith(transformed_input_grad_channel); } if (paddings.size() == data_dim) { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[i]; } } else { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[2 * i]; } } } const T* input_data = transformed_input.data(); const T* output_grad_data = transformed_output_grad_channel.data(); const T* filter_data = transformed_filter_channel.data(); T* filter_grad_data = nullptr; T* input_grad_data = nullptr; T* transformed_input_grad_data = nullptr; ConvArgs args1{&transformed_input_grad, &transformed_filter_channel, &transformed_output_grad_channel, strides, padding_common, dilations}; ConvArgs args2{&transformed_input, &transformed_filter_grad_channel, &transformed_output_grad_channel, strides, padding_common, dilations}; auto handle = dev_ctx.cudnn_handle(); DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW; if (transformed_input.dims().size() == 5) { layout = compute_format == DataLayout::kNHWC ? DataLayout::kNDHWC : DataLayout::kNCDHW; } auto layout_tensor = GetCudnnTensorFormat(layout); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); int i_n, i_c, i_d, i_h, i_w; int o_n, o_c, o_d, o_h, o_w; if (compute_format == DataLayout::kNHWC) { GetNCDHW(transformed_input.dims(), DataLayout::kNHWC, &i_n, &i_c, &i_d, &i_h, &i_w); GetNCDHW(transformed_output_grad_channel.dims(), DataLayout::kNHWC, &o_n, &o_c, &o_d, &o_h, &o_w); } else { GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); GetNCDHW(transformed_output_grad_channel.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); } int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = transformed_filter_channel.numel() / groups; // ------------------- cudnn backward algorithm --------------------- cudnnConvolutionBwdDataAlgo_t data_algo = static_cast(0); cudnnConvolutionBwdFilterAlgo_t filter_algo = static_cast(0); size_t workspace_size = 0; int iwo_groups, c_groups; #if CUDNN_VERSION_MIN(7, 0, 1) iwo_groups = 1; c_groups = groups; groups = 1; #endif if (input_grad) { // ------------------- cudnn descriptors --------------------- input_grad_data = input_grad->data(); transformed_input_grad_data = transformed_input_grad.data(); args1.handle = handle; args1.idesc.set(transformed_input_grad, layout_tensor); args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups); args1.odesc.set(transformed_output_grad_channel, layout_tensor); args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups); using search1 = SearchAlgorithm; data_algo = search1::Find(args1, exhaustive_search, deterministic, 0, ctx); workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); } if (filter_grad) { // ------------------- cudnn descriptors --------------------- filter_grad_data = transformed_filter_grad_channel.data(); args2.handle = handle; args2.idesc.set(transformed_input, layout_tensor); args2.wdesc.set(transformed_filter_grad_channel, layout_tensor, iwo_groups); args2.odesc.set(transformed_output_grad_channel, layout_tensor); args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups); using search2 = SearchAlgorithm; filter_algo = search2::Find(args2, exhaustive_search, deterministic, 1, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo)); } // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { // Because beta is zero, it is unnecessary to reset input_grad. for (int i = 0; i < groups; i++) { workspace_handle.RunFunc( [&](void* cudnn_workspace_ptr) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( handle, &alpha, args1.wdesc.desc(), filter_data + i * group_offset_filter, args1.odesc.desc(), output_grad_data + i * group_offset_out, args1.cdesc.desc(), data_algo, cudnn_workspace_ptr, workspace_size, &beta, args1.idesc.desc(), transformed_input_grad_data + i * group_offset_in)); }, workspace_size); } if (!is_sys_pad) { std::vector starts(transformed_input_channel.dims().size(), 0); std::vector axes(transformed_input_channel.dims().size(), 0); for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) { starts[i] = input_pad[2 * i]; axes[i] = i; } transformed_input_grad_channel.mutable_data(ctx.GetPlace()); if (transformed_input_channel.dims().size() == 4) { RemovePaddingSlice( ctx, &transformed_input_grad, &transformed_input_grad_channel, starts, axes); } else { RemovePaddingSlice( ctx, &transformed_input_grad, &transformed_input_grad_channel, starts, axes); } } if (channel_last && compute_format == DataLayout::kNCHW) { TransToChannelLast( ctx, &transformed_input_grad_channel, input_grad); } } // ------------------- cudnn conv backward filter --------------------- if (filter_grad) { // Because beta is zero, it is unnecessary to reset filter_grad. for (int i = 0; i < groups; i++) { workspace_handle.RunFunc( [&](void* cudnn_workspace_ptr) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( handle, &alpha, args2.idesc.desc(), input_data + i * group_offset_in, args2.odesc.desc(), output_grad_data + i * group_offset_out, args2.cdesc.desc(), filter_algo, cudnn_workspace_ptr, workspace_size, &beta, args2.wdesc.desc(), filter_grad_data + i * group_offset_filter)); }, workspace_size); } if (compute_format == DataLayout::kNHWC) { TransToChannelFirst( ctx, &transformed_filter_grad_channel, filter_grad); } } } }; /* * Inputs: I, W, dO, ddI, ddW * Outputs: ddO, dW, dI * ddo = conv(ddI, W) + conv(I, ddW) * dW = conv_bp_filter(ddI, dO) * dI = conv_bp_data(ddW, dO) */ template class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto& dev_ctx = ctx.template device_context(); PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, "It must use CUDAPlace."); auto X = ctx.Input("Input"); auto W = ctx.Input("Filter"); auto dO = ctx.Input("DOutput"); auto ddX = ctx.Input("DDInput"); auto ddW = ctx.Input("DDFilter"); auto ddO = ctx.Output("DDOutput"); auto dW = ctx.Output("DFilter"); auto dX = ctx.Output("DInput"); if (ddO) { ddO->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; set_zero(dev_ctx, ddO, static_cast(0)); } if (dW) { dW->mutable_data(ctx.GetPlace()); } if (dX) { dX->mutable_data(ctx.GetPlace()); } // const T* x = X->data(); const T* dy = dO->data(); const T* w = W->data(); const T* ddx = nullptr; const T* ddw = nullptr; T *dw, *dx, *ddy; dw = dx = ddy = nullptr; T* transformed_dx = nullptr; const std::vector& strides = ctx.Attr>("strides"); std::vector dilations = ctx.Attr>("dilations"); int groups = ctx.Attr("groups"); bool exhaustive_search = FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); bool deterministic = FLAGS_cudnn_deterministic; if (exhaustive_search && deterministic) { PADDLE_THROW( "Can't set exhaustive_search True and " "FLAGS_cudnn_deterministic True at same time."); } std::vector paddings = ctx.Attr>("paddings"); std::string padding_algorithm = ctx.Attr("padding_algorithm"); const std::string data_format = ctx.Attr("data_format"); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); // transform Tensors to channel first----------- Tensor transformed_X_channel(X->type()); Tensor transformed_dO_channel(dO->type()); Tensor transformed_ddX_channel(X->type()); Tensor transformed_ddO_channel(dO->type()); Tensor transformed_dX_channel(X->type()); if (channel_last) { ResizeToChannelFirst( ctx, X, &transformed_X_channel); TransToChannelFirst( ctx, X, &transformed_X_channel); ResizeToChannelFirst( ctx, dO, &transformed_dO_channel); TransToChannelFirst( ctx, dO, &transformed_dO_channel); if (ddX) { ResizeToChannelFirst( ctx, ddX, &transformed_ddX_channel); TransToChannelFirst( ctx, ddX, &transformed_ddX_channel); } if (ddO) { ResizeToChannelFirst( ctx, ddO, &transformed_ddO_channel); } if (dX) { ResizeToChannelFirst( ctx, dX, &transformed_dX_channel); transformed_dX_channel.mutable_data(ctx.GetPlace()); } } else { transformed_X_channel = *X; transformed_dO_channel = *dO; if (ddX) { transformed_ddX_channel = *ddX; } if (ddO) { transformed_ddO_channel.ShareDataWith(*ddO); } if (dX) { transformed_dX_channel.ShareDataWith(*dX); } } auto in_dims = transformed_X_channel.dims(); auto filter_dims = W->dims(); framework::DDim in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); framework::DDim filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size()); std::vector ksize = framework::vectorize(filter_data_dims); UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); int data_dim = strides.size(); // 2d or 3d bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim); Tensor transformed_X(X->type()); Tensor transformed_ddX(X->type()); Tensor transformed_dX(X->type()); std::vector padding_common(data_dim, 0); std::vector input_pad(X->dims().size() * 2, 0); if (!is_sys_pad) { // get pad std::vector padding_diff(data_dim); std::vector new_input_shape_vec(data_dim + 2); new_input_shape_vec[0] = transformed_X_channel.dims()[0]; new_input_shape_vec[1] = transformed_X_channel.dims()[1]; for (size_t i = 0; i < data_dim; ++i) { padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); new_input_shape_vec[i + 2] = transformed_X_channel.dims()[i + 2] + padding_diff[i]; input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; } framework::DDim new_input_shape( framework::make_ddim(new_input_shape_vec)); transformed_X.Resize(new_input_shape); transformed_ddX.Resize(new_input_shape); transformed_dX.Resize(new_input_shape); transformed_X = ctx.AllocateTmpTensor( new_input_shape, dev_ctx); if (ddX) { transformed_ddX = ctx.AllocateTmpTensor( new_input_shape, dev_ctx); } if (dX) { transformed_dX = ctx.AllocateTmpTensor( new_input_shape, dev_ctx); } // pad for input const int rank = X->dims().size(); T pad_value(0.0); switch (rank) { case 4: { math::PadFunction( ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); if (ddX) { math::PadFunction( ctx, input_pad, transformed_ddX_channel, pad_value, &transformed_ddX); } } break; case 5: { math::PadFunction( ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); if (ddX) { math::PadFunction( ctx, input_pad, transformed_ddX_channel, pad_value, &transformed_ddX); } } break; default: PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); } } else { transformed_X.ShareDataWith(transformed_X_channel); if (ddX) { transformed_ddX.ShareDataWith(transformed_ddX_channel); } if (dX) { transformed_dX.ShareDataWith(transformed_dX_channel); } if (paddings.size() == data_dim) { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[i]; } } else { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[2 * i]; } } } const T* x = transformed_X.data(); int iwo_group = groups; int c_group = 1; #if CUDNN_VERSION_MIN(7, 0, 1) iwo_group = 1; c_group = groups; #endif auto dtype = platform::CudnnDataType::type; auto handle = dev_ctx.cudnn_handle(); ConvArgs args1{&transformed_ddX, W, &transformed_ddO_channel, strides, padding_common, dilations}; ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides, padding_common, dilations}; ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides, padding_common, dilations}; ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides, padding_common, dilations}; cudnnConvolutionFwdAlgo_t fwd_algo1 = static_cast(0); cudnnConvolutionFwdAlgo_t fwd_algo2 = static_cast(0); cudnnConvolutionBwdDataAlgo_t data_algo = static_cast(0); cudnnConvolutionBwdFilterAlgo_t filter_algo = static_cast(0); auto layout = GetCudnnTensorFormat(DataLayout::kNCHW); // ddo = conv(ddI, W) + conv(I, ddW) size_t workspace_size = 0; T* transformed_ddy_channel = nullptr; if (ddO) { ddy = ddO->data(); transformed_ddy_channel = transformed_ddO_channel.data(); if (ddX) { args1.handle = handle; args1.idesc.set(transformed_ddX, iwo_group); args1.wdesc.set(*W, layout, iwo_group); args1.odesc.set(transformed_ddO_channel, iwo_group); args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); using search1 = SearchAlgorithm; fwd_algo1 = search1::Find(args1, exhaustive_search, false, 0, ctx); workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); } if (ddW) { ddw = ddW->data(); args2.handle = handle; args2.idesc.set(transformed_X, iwo_group); args2.wdesc.set(*ddW, layout, iwo_group); args2.odesc.set(transformed_ddO_channel, iwo_group); args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); using search2 = SearchAlgorithm; fwd_algo2 = search2::Find(args2, exhaustive_search, false, 0, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2)); } } if (dW && ddX) { dw = dW->data(); args3.handle = handle; args3.idesc.set(transformed_ddX, iwo_group); args3.wdesc.set(*dW, layout, iwo_group); args3.odesc.set(transformed_dO_channel, iwo_group); args3.cdesc.set(dtype, padding_common, strides, dilations, c_group); using search3 = SearchAlgorithm; filter_algo = search3::Find(args3, exhaustive_search, deterministic, 1, ctx); workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); } if (ddW && dX) { transformed_dx = transformed_dX.data(); args4.handle = handle; args4.idesc.set(transformed_dX, iwo_group); args4.wdesc.set(*ddW, layout, iwo_group); args4.odesc.set(transformed_dO_channel, iwo_group); args4.cdesc.set(dtype, padding_common, strides, dilations, c_group); using search4 = SearchAlgorithm; data_algo = search4::Find(args4, exhaustive_search, deterministic, 2, ctx); workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); } int i_n, i_c, i_d, i_h, i_w; GetNCDHW(transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); int o_n, o_c, o_d, o_h, o_w; GetNCDHW(transformed_dO_channel.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = W->numel() / groups; ScalingParamType alpha = 1.0f, beta = 0.0f; auto wkspace_handle = dev_ctx.cudnn_workspace_handle(); if (ddO) { if (ddX) { ddx = transformed_ddX.data(); for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in, args1.wdesc.desc(), w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size, &beta, args1.odesc.desc(), transformed_ddy_channel + i * group_offset_out)); }, workspace_size); } } if (ddW) { for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( handle, &alpha, args2.idesc.desc(), x + i * group_offset_in, args2.wdesc.desc(), ddw + i * group_offset_filter, args2.cdesc.desc(), fwd_algo2, workspace_ptr, workspace_size, &alpha, args2.odesc.desc(), transformed_ddy_channel + i * group_offset_out)); }, workspace_size); } } if (channel_last) { TransToChannelLast( ctx, &transformed_ddO_channel, ddO); } } T* transformed_dy_channel = transformed_dO_channel.data(); if (dW && ddX) { ddx = transformed_ddX.data(); for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in, args3.odesc.desc(), transformed_dy_channel + i * group_offset_out, args3.cdesc.desc(), filter_algo, workspace_ptr, workspace_size, &beta, args3.wdesc.desc(), dw + i * group_offset_filter)); }, workspace_size); } } if (dX && ddW) { ddw = ddW->data(); for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( handle, &alpha, args4.wdesc.desc(), ddw + i * group_offset_filter, args4.odesc.desc(), transformed_dy_channel + i * group_offset_out, args4.cdesc.desc(), data_algo, workspace_ptr, workspace_size, &beta, args4.idesc.desc(), transformed_dx + i * group_offset_in)); }, workspace_size); } if (!is_sys_pad) { // reverse padded input std::vector starts(X->dims().size(), 0); std::vector axes(X->dims().size(), 0); for (size_t i = 0; i < X->dims().size(); ++i) { starts[i] = input_pad[2 * i]; axes[i] = i; } if (X->dims().size() == 4) { RemovePaddingSlice( ctx, &transformed_dX, &transformed_dX_channel, starts, axes); } else { RemovePaddingSlice( ctx, &transformed_dX, &transformed_dX_channel, starts, axes); } } if (channel_last) { TransToChannelLast( ctx, &transformed_dX_channel, dX); } } } }; } // namespace operators } // namespace paddle namespace plat = paddle::platform; REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, paddle::operators::CUDNNConvOpKernel, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, paddle::operators::CUDNNConvGradOpKernel, paddle::operators::CUDNNConvGradOpKernel); REGISTER_OP_KERNEL( conv2d_grad_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvDoubleGradOpKernel, paddle::operators::CUDNNConvDoubleGradOpKernel, paddle::operators::CUDNNConvDoubleGradOpKernel); REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, paddle::operators::CUDNNConvOpKernel, paddle::operators::CUDNNConvOpKernel); REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, paddle::operators::CUDNNConvGradOpKernel); REGISTER_OP_KERNEL( conv3d_grad_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvDoubleGradOpKernel, paddle::operators::CUDNNConvDoubleGradOpKernel, paddle::operators::CUDNNConvDoubleGradOpKernel);