// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/conv_grad_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/fluid/framework/eigen.h" #ifdef PADDLE_WITH_HIP #include "paddle/fluid/operators/conv_miopen_helper.h" #else #include "paddle/fluid/operators/conv_cudnn_helper.h" #endif #include "paddle/fluid/platform/cudnn_workspace_helper.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/phi/kernels/funcs/padding.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/impl/conv_cudnn_impl.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { template void ConvCudnnGradGradKernel( const Context& ctx, paddle::optional input_grad_grad, paddle::optional filter_grad_grad, const DenseTensor& out_grad, const DenseTensor& input, const DenseTensor& filter, const std::vector& strides, const std::vector& paddings_t, const std::string& padding_algorithm, int groups, const std::vector& dilations_t, const std::string& data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search_t, DenseTensor* out_grad_grad, DenseTensor* input_grad, DenseTensor* filter_grad) { auto X = &input; auto W = &filter; auto dO = &out_grad; auto ddX = input_grad_grad.get_ptr(); auto ddW = filter_grad_grad.get_ptr(); auto ddO = out_grad_grad; auto dW = filter_grad; auto dX = input_grad; if (ddO) { ctx.template Alloc(ddO); phi::funcs::SetConstant set_zero; set_zero(ctx, ddO, static_cast(0)); } if (dW) { ctx.template Alloc(dW); } if (dX) { ctx.template Alloc(dX); } // const T* x = X->data(); const T* dy = dO->data(); const T* w = W->data(); const T* ddx = nullptr; const T* ddw = nullptr; T *dw, *dx, *ddy; dw = dx = ddy = nullptr; T* transformed_dx = nullptr; std::vector dilations = dilations_t; bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t; bool deterministic = FLAGS_cudnn_deterministic; auto exhaustive_deterministic = exhaustive_search && deterministic; PADDLE_ENFORCE_EQ(exhaustive_deterministic, false, phi::errors::InvalidArgument( "Cann't set exhaustive_search True and " "FLAGS_cudnn_deterministic True at same time.")); std::vector paddings = paddings_t; const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); // transform Tensors to channel first----------- DenseTensor transformed_X_channel(X->type()); DenseTensor transformed_dO_channel(dO->type()); DenseTensor transformed_ddX_channel(X->type()); DenseTensor transformed_ddO_channel(dO->type()); DenseTensor transformed_dX_channel(X->type()); if (channel_last) { ResizeToChannelFirst(ctx, X, &transformed_X_channel); TransToChannelFirst(ctx, X, &transformed_X_channel); ResizeToChannelFirst(ctx, dO, &transformed_dO_channel); TransToChannelFirst(ctx, dO, &transformed_dO_channel); if (ddX) { ResizeToChannelFirst(ctx, ddX, &transformed_ddX_channel); TransToChannelFirst(ctx, ddX, &transformed_ddX_channel); } if (ddO) { ResizeToChannelFirst(ctx, ddO, &transformed_ddO_channel); } if (dX) { ResizeToChannelFirst(ctx, dX, &transformed_dX_channel); ctx.template Alloc(&transformed_dX_channel); } } else { transformed_X_channel = *X; transformed_dO_channel = *dO; if (ddX) { transformed_ddX_channel = *ddX; } if (ddO) { transformed_ddO_channel.ShareDataWith(*ddO); } if (dX) { transformed_dX_channel.ShareDataWith(*dX); } } auto in_dims = transformed_X_channel.dims(); auto filter_dims = W->dims(); DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size()); DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size()); std::vector ksize = vectorize(filter_data_dims); UpdatePaddingAndDilation( &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); int data_dim = strides.size(); // 2d or 3d bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim); DenseTensor transformed_X(X->type()); DenseTensor transformed_ddX(X->type()); DenseTensor transformed_dX(X->type()); std::vector padding_common(data_dim, 0); std::vector input_pad(X->dims().size() * 2, 0); if (!is_sys_pad) { // get pad std::vector padding_diff(data_dim); std::vector new_input_shape_vec(data_dim + 2); new_input_shape_vec[0] = transformed_X_channel.dims()[0]; new_input_shape_vec[1] = transformed_X_channel.dims()[1]; for (size_t i = 0; i < data_dim; ++i) { padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); new_input_shape_vec[i + 2] = transformed_X_channel.dims()[i + 2] + padding_diff[i]; input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; } DDim new_input_shape(make_ddim(new_input_shape_vec)); transformed_X.Resize(new_input_shape); transformed_ddX.Resize(new_input_shape); transformed_dX.Resize(new_input_shape); ctx.template Alloc(&transformed_X); if (ddX) { ctx.template Alloc(&transformed_ddX); } if (dX) { ctx.template Alloc(&transformed_dX); } // pad for input const int rank = X->dims().size(); T pad_value(0.0); switch (rank) { case 4: { funcs::PadFunction( ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); if (ddX) { funcs::PadFunction(ctx, input_pad, transformed_ddX_channel, pad_value, &transformed_ddX); } } break; case 5: { funcs::PadFunction( ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); if (ddX) { funcs::PadFunction(ctx, input_pad, transformed_ddX_channel, pad_value, &transformed_ddX); } } break; default: PADDLE_THROW(phi::errors::InvalidArgument( "ConvOp only support tensors with 4 or 5 dimensions.")); } } else { transformed_X.ShareDataWith(transformed_X_channel); if (ddX) { transformed_ddX.ShareDataWith(transformed_ddX_channel); } if (dX) { transformed_dX.ShareDataWith(transformed_dX_channel); } if (paddings.size() == data_dim) { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[i]; } } else { for (size_t i = 0; i < data_dim; ++i) { padding_common[i] = paddings[2 * i]; } } } const T* x = transformed_X.data(); int iwo_group = groups; int c_group = 1; #if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1) iwo_group = 1; c_group = groups; groups = 1; #endif auto dtype = paddle::platform::CudnnDataType::type; auto handle = ctx.cudnn_handle(); paddle::operators::ConvArgs args1{&transformed_ddX, W, &transformed_ddO_channel, strides, padding_common, dilations, dtype}; paddle::operators::ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides, padding_common, dilations, dtype}; paddle::operators::ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides, padding_common, dilations, dtype}; paddle::operators::ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides, padding_common, dilations, dtype}; #ifdef PADDLE_WITH_HIP miopenConvFwdAlgorithm_t fwd_algo1 = static_cast(0); miopenConvFwdAlgorithm_t fwd_algo2 = static_cast(0); miopenConvBwdDataAlgorithm_t data_algo = static_cast(0); miopenConvBwdWeightsAlgorithm_t filter_algo = static_cast(0); #else cudnnConvolutionFwdAlgo_t fwd_algo1 = static_cast(0); cudnnConvolutionFwdAlgo_t fwd_algo2 = static_cast(0); cudnnConvolutionBwdDataAlgo_t data_algo = static_cast(0); cudnnConvolutionBwdFilterAlgo_t filter_algo = static_cast(0); #endif auto layout = paddle::platform::GetCudnnTensorFormat( paddle::platform::DataLayout::kNCHW); // ddo = conv(ddI, W) + conv(I, ddW) size_t workspace_size = 0; T* transformed_ddy_channel = nullptr; if (ddO) { ddy = ddO->data(); transformed_ddy_channel = transformed_ddO_channel.data(); if (ddX) { args1.handle = handle; args1.idesc.set(transformed_ddX, iwo_group); args1.wdesc.set(*W, layout, iwo_group); args1.odesc.set(transformed_ddO_channel, iwo_group); args1.cdesc.set(dtype, padding_common, strides, dilations, paddle::platform::AllowTF32Cudnn(), c_group); #ifdef PADDLE_WITH_HIP using search1 = paddle::operators::SearchAlgorithm; workspace_size = search1::GetWorkspaceSize(args1); fwd_algo1 = search1::Find( args1, exhaustive_search, false, workspace_size, ctx); #else using search1 = paddle::operators::SearchAlgorithm; fwd_algo1 = search1::Find(args1, exhaustive_search, false, ctx); workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); #endif } if (ddW) { ddw = ddW->data(); args2.handle = handle; args2.idesc.set(transformed_X, iwo_group); args2.wdesc.set(*ddW, layout, iwo_group); args2.odesc.set(transformed_ddO_channel, iwo_group); args2.cdesc.set(dtype, padding_common, strides, dilations, paddle::platform::AllowTF32Cudnn(), c_group); #ifdef PADDLE_WITH_HIP using search2 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2)); fwd_algo2 = search2::Find( args2, exhaustive_search, false, workspace_size, ctx); #else using search2 = paddle::operators::SearchAlgorithm; fwd_algo2 = search2::Find(args2, exhaustive_search, false, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2)); #endif } } if (dW && ddX) { dw = dW->data(); args3.handle = handle; args3.idesc.set(transformed_ddX, iwo_group); args3.wdesc.set(*dW, layout, iwo_group); args3.odesc.set(transformed_dO_channel, iwo_group); args3.cdesc.set(dtype, padding_common, strides, dilations, paddle::platform::AllowTF32Cudnn(), c_group); #ifdef PADDLE_WITH_HIP using search3 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3)); filter_algo = search3::Find( args3, exhaustive_search, deterministic, workspace_size, ctx); #else using search3 = paddle::operators::SearchAlgorithm; filter_algo = search3::Find(args3, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); #endif } if (ddW && dX) { transformed_dx = transformed_dX.data(); args4.handle = handle; args4.idesc.set(transformed_dX, iwo_group); args4.wdesc.set(*ddW, layout, iwo_group); args4.odesc.set(transformed_dO_channel, iwo_group); args4.cdesc.set(dtype, padding_common, strides, dilations, paddle::platform::AllowTF32Cudnn(), c_group); #ifdef PADDLE_WITH_HIP using search4 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4)); data_algo = search4::Find( args4, exhaustive_search, deterministic, workspace_size, ctx); #else using search4 = paddle::operators::SearchAlgorithm; data_algo = search4::Find(args4, exhaustive_search, deterministic, ctx); workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); #endif } int i_n, i_c, i_d, i_h, i_w; GetNCDHW( transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); int o_n, o_c, o_d, o_h, o_w; GetNCDHW(transformed_dO_channel.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = W->numel() / groups; paddle::operators::ScalingParamType alpha = 1.0f; paddle::operators::ScalingParamType beta = 0.0f; // NOTE(zhiqiu): inplace addto is not supportted in double grad yet. // ScalingParamType beta = ctx.Attr("use_addto") ? 1.0f : // 0.0f; // VLOG(4) << "Conv_grad_grad: use_addto = " << ctx.Attr("use_addto"); auto wkspace_handle = ctx.cudnn_workspace_handle(); if (ddO) { if (ddX) { ddx = transformed_ddX.data(); #ifdef PADDLE_WITH_HIP wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::miopenConvolutionForward( handle, &alpha, args1.idesc.desc(), ddx, args1.wdesc.desc(), w, args1.cdesc.desc(), fwd_algo1, &beta, args1.odesc.desc(), transformed_ddy_channel, workspace_ptr, workspace_size)); }, workspace_size); #else for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnConvolutionForward( handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in, args1.wdesc.desc(), w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size, &beta, args1.odesc.desc(), transformed_ddy_channel + i * group_offset_out)); }, workspace_size); } #endif } if (ddW) { #ifdef PADDLE_WITH_HIP // MIOPEN ONLY support beta to be 0.0f wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::miopenConvolutionForward( handle, &alpha, args2.idesc.desc(), x, args2.wdesc.desc(), ddw, args2.cdesc.desc(), fwd_algo2, &beta, args2.odesc.desc(), transformed_ddy_channel, workspace_ptr, workspace_size)); }, workspace_size); #else for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnConvolutionForward( handle, &alpha, args2.idesc.desc(), x + i * group_offset_in, args2.wdesc.desc(), ddw + i * group_offset_filter, args2.cdesc.desc(), fwd_algo2, workspace_ptr, workspace_size, &alpha, args2.odesc.desc(), transformed_ddy_channel + i * group_offset_out)); }, workspace_size); } #endif } if (channel_last) { TransToChannelLast(ctx, &transformed_ddO_channel, ddO); } } T* transformed_dy_channel = transformed_dO_channel.data(); if (dW && ddX) { ddx = transformed_ddX.data(); #ifdef PADDLE_WITH_HIP wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::miopenConvolutionBackwardWeights( handle, &alpha, args3.odesc.desc(), transformed_dy_channel, args3.idesc.desc(), ddx, args3.cdesc.desc(), filter_algo, &beta, args3.wdesc.desc(), dw, workspace_ptr, workspace_size)); }, workspace_size); #else for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnConvolutionBackwardFilter( handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in, args3.odesc.desc(), transformed_dy_channel + i * group_offset_out, args3.cdesc.desc(), filter_algo, workspace_ptr, workspace_size, &beta, args3.wdesc.desc(), dw + i * group_offset_filter)); }, workspace_size); } #endif } if (dX && ddW) { ddw = ddW->data(); #ifdef PADDLE_WITH_HIP wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::miopenConvolutionBackwardData( handle, &alpha, args4.odesc.desc(), transformed_dy_channel, args4.wdesc.desc(), ddw, args4.cdesc.desc(), data_algo, &beta, args4.idesc.desc(), transformed_dx, workspace_ptr, workspace_size)); }, workspace_size); #else for (int i = 0; i < groups; i++) { wkspace_handle.RunFunc( [&](void* workspace_ptr) { PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnConvolutionBackwardData( handle, &alpha, args4.wdesc.desc(), ddw + i * group_offset_filter, args4.odesc.desc(), transformed_dy_channel + i * group_offset_out, args4.cdesc.desc(), data_algo, workspace_ptr, workspace_size, &beta, args4.idesc.desc(), transformed_dx + i * group_offset_in)); }, workspace_size); } #endif if (!is_sys_pad) { // reverse padded input std::vector starts(X->dims().size(), 0); std::vector axes(X->dims().size(), 0); for (size_t i = 0; i < X->dims().size(); ++i) { starts[i] = input_pad[2 * i]; axes[i] = i; } if (X->dims().size() == 4) { paddle::operators::RemovePaddingSlice( ctx, &transformed_dX, &transformed_dX_channel, starts, axes); } else { paddle::operators::RemovePaddingSlice( ctx, &transformed_dX, &transformed_dX_channel, starts, axes); } } if (channel_last) { TransToChannelLast(ctx, &transformed_dX_channel, dX); } } } template void DepthwiseConvCudnnGradGradKernel( const Context& ctx, paddle::optional input_grad_grad, paddle::optional filter_grad_grad, const DenseTensor& out_grad, const DenseTensor& input, const DenseTensor& filter, const std::vector& strides, const std::vector& paddings_t, const std::string& padding_algorithm, int groups, const std::vector& dilations_t, const std::string& data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search_t, bool fuse_relu, DenseTensor* out_grad_grad, DenseTensor* input_grad, DenseTensor* filter_grad) { ConvCudnnGradGradKernel(ctx, input_grad_grad, filter_grad_grad, out_grad, input, filter, strides, paddings_t, padding_algorithm, groups, dilations_t, data_format, use_addto, workspace_size_MB, exhaustive_search_t, out_grad_grad, input_grad, filter_grad); } template void Conv3DCudnnGradGradKernel( const Context& ctx, paddle::optional input_grad_grad, paddle::optional filter_grad_grad, const DenseTensor& out_grad, const DenseTensor& input, const DenseTensor& filter, const std::vector& strides, const std::vector& paddings_t, const std::string& padding_algorithm, int groups, const std::vector& dilations_t, const std::string& data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search_t, DenseTensor* out_grad_grad, DenseTensor* input_grad, DenseTensor* filter_grad) { ConvCudnnGradGradKernel(ctx, input_grad_grad, filter_grad_grad, out_grad, input, filter, strides, paddings_t, padding_algorithm, groups, dilations_t, data_format, use_addto, workspace_size_MB, exhaustive_search_t, out_grad_grad, input_grad, filter_grad); } } // namespace phi #ifdef PADDLE_WITH_HIP PD_REGISTER_KERNEL(conv2d_grad_grad, GPUDNN, ALL_LAYOUT, phi::ConvCudnnGradGradKernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL(conv3d_grad_grad, GPUDNN, ALL_LAYOUT, phi::Conv3DCudnnGradGradKernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad, GPU, ALL_LAYOUT, phi::DepthwiseConvCudnnGradGradKernel, float, phi::dtype::float16) {} #else #if CUDNN_VERSION_MIN(8, 1, 0) PD_REGISTER_KERNEL(conv2d_grad_grad, GPUDNN, ALL_LAYOUT, phi::ConvCudnnGradGradKernel, float, double, phi::dtype::float16, phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(conv3d_grad_grad, GPUDNN, ALL_LAYOUT, phi::Conv3DCudnnGradGradKernel, float, double, phi::dtype::float16, phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad, GPU, ALL_LAYOUT, phi::DepthwiseConvCudnnGradGradKernel, float, double, phi::dtype::float16, phi::dtype::bfloat16) {} #else PD_REGISTER_KERNEL(conv2d_grad_grad, GPUDNN, ALL_LAYOUT, phi::ConvCudnnGradGradKernel, float, double, phi::dtype::float16) {} PD_REGISTER_KERNEL(conv3d_grad_grad, GPUDNN, ALL_LAYOUT, phi::Conv3DCudnnGradGradKernel, float, double, phi::dtype::float16) {} PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad, GPU, ALL_LAYOUT, phi::DepthwiseConvCudnnGradGradKernel, float, double, phi::dtype::float16) {} #endif #endif