/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cstdlib>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_transpose_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = platform::DataLayout;

// Permutes `input` into `output` with math::Transpose on the CUDA device
// context. `output` is allocated before the transpose:
//  * flag == 0: its dims are in_dims[axis[i]] (the permuted shape);
//  * otherwise: its dims are the first axis.size() dims of `input`, unpermuted.
// Every caller in this file passes D == axis.size().
template <typename T, int D>
static void DataTranspose(const framework::ExecutionContext& ctx,
                          const Tensor* input, Tensor* output,
                          const std::vector<int>& axis, int flag = 0) {
  auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  math::Transpose<platform::CUDADeviceContext, T, D> transpose;
  auto in_dims = input->dims();
  std::vector<int64_t> input_transpose_vec;
  for (size_t i = 0; i < axis.size(); ++i) {
    if (flag == 0)
      input_transpose_vec.push_back(in_dims[axis[i]]);
    else
      input_transpose_vec.push_back(in_dims[i]);
  }
  framework::DDim input_transpose_dims(
      framework::make_ddim(input_transpose_vec));
  output->mutable_data<T>(input_transpose_dims, ctx.GetPlace());
  transpose(dev_ctx, *input, output, axis);
}

// cuDNN forward kernel for conv2d_transpose / conv3d_transpose.
//
// The transposed convolution is evaluated with cudnnConvolutionBackwardData:
// the op's "Input" is passed as the dy operand (odesc) and the op's "Output"
// is written through the dx operand (idesc).
//  * data_format == "NHWC": the input is transposed to channel-first, the
//    computation runs channel-first, and the result is transposed back.
//  * Asymmetric padding: the convolution descriptor takes one padding value
//    per spatial dim (padding_common = min(before, after)). The input is
//    zero-padded by the per-side excess, the result is produced in a larger
//    buffer, and that buffer is cropped back to the output shape with Slice.
template <typename T>
class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        paddle::platform::errors::PreconditionNotMet("It must use CUDAPlace."));
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* output = ctx.Output<Tensor>("Output");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
    // cudnn v5 does not support dilations
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
    const T* filter_data = filter->data<T>();
    const std::string data_layout_str = ctx.Attr<std::string>("data_format");
    // Only the exact string "NHWC" selects the channel-last path here.
    const paddle::operators::DataLayout data_layout =
        (data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);

    // If channel_last, transpose the input to channel_first. input_vec and
    // output_vec end up holding the channel-first dims in both cases.
    Tensor input_transpose;
    std::vector<int> input_vec = framework::vectorize<int>(input->dims());
    std::vector<int> output_vec = framework::vectorize<int>(output->dims());
    if (data_layout == DataLayout::kNHWC) {
      if (strides.size() == 2U) {
        std::vector<int> axis = {0, 3, 1, 2};
        for (size_t i = 0; i < axis.size(); ++i) {
          input_vec[i] = input->dims()[axis[i]];
          output_vec[i] = output->dims()[axis[i]];
        }
        DataTranspose<T, 4>(ctx, input, &input_transpose, axis);
      } else if (strides.size() == 3U) {
        std::vector<int> axis = {0, 4, 1, 2, 3};
        for (size_t i = 0; i < axis.size(); ++i) {
          input_vec[i] = input->dims()[axis[i]];
          output_vec[i] = output->dims()[axis[i]];
        }
        DataTranspose<T, 5>(ctx, input, &input_transpose, axis);
      }
    } else {
      input_transpose = *input;
    }

    // update padding and dilation
    auto in_dims = input_transpose.dims();
    auto filter_dims = filter->dims();
    framework::DDim in_data_dims;
    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    int data_dim = strides.size();  // 2d or 3d
    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);

    // input_pad holds (before, after) pad amounts for every dim of the
    // channel-first tensor; only the spatial entries (index >= 4) are set.
    std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
    Tensor transformed_input;
    std::vector<int> padding_common(data_dim, 0);
    if (!is_sys_pad) {
      std::vector<int> padding_diff(data_dim);
      std::vector<int> new_input_shape_vec(data_dim + 2);
      new_input_shape_vec[0] = input_transpose.dims()[0];
      new_input_shape_vec[1] = input_transpose.dims()[1];

      for (size_t i = 0; i < data_dim; ++i) {
        padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
        padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
        new_input_shape_vec[i + 2] =
            input_transpose.dims()[i + 2] + padding_diff[i];
        input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
        input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
      }
      framework::DDim new_input_shape(
          framework::make_ddim(new_input_shape_vec));
      transformed_input.Resize(new_input_shape);
      auto& dev_ctx =
          ctx.template device_context<paddle::platform::CUDADeviceContext>();

      transformed_input =
          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
              new_input_shape, dev_ctx);
      const int rank = input_transpose.dims().size();
      T pad_value(0.0);
      switch (rank) {
        case 4: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
              ctx, input_pad, input_transpose, pad_value, &transformed_input);
        } break;
        case 5: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
              ctx, input_pad, input_transpose, pad_value, &transformed_input);
        } break;
        default:
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Op(ConvTranspose) only supports 4-D or 5-D input Tensor."));
      }
    } else {
      transformed_input = input_transpose;
      // Symmetric padding may be given as one value per dim or as
      // (before, after) pairs; take one value per spatial dim either way.
      if (paddings.size() == data_dim) {
        for (size_t i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[i];
        }
      } else {
        for (size_t i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[2 * i];
        }
      }
    }

    // Crop window (along the spatial axes) that maps the enlarged result
    // back onto the requested output; only used when padding is asymmetric.
    std::vector<int64_t> starts(data_dim, 0);
    std::vector<int64_t> ends(data_dim, 0);
    std::vector<int64_t> axes(data_dim, 0);
    for (size_t i = 0; i < data_dim; ++i) {
      starts[i] = input_pad[2 * i + 4] * (strides[i] + 1);
      ends[i] = starts[i] + output_vec[i + 2];
      axes[i] = i + 2;
    }

    const T* input_data = transformed_input.data<T>();
    input_vec = framework::vectorize<int>(transformed_input.dims());

    // Spatial size of the result computed from the padded input with
    // padding_common (equals output_vec when padding is symmetric).
    std::vector<int> transformed_output_vec = output_vec;
    for (size_t i = 0; i < data_dim; ++i) {
      transformed_output_vec[i + 2] =
          output_vec[i + 2] +
          (input_pad[2 * i + 4] + input_pad[2 * i + 5]) * strides[i] -
          2 * padding_common[i] + paddings[2 * i] + paddings[2 * i + 1];
    }

    Tensor transformed_output;
    if (!is_sys_pad) {
      DDim transformed_output_shape(
          framework::make_ddim(transformed_output_vec));
      transformed_output.mutable_data<T>(transformed_output_shape,
                                         ctx.GetPlace());
    } else {
      // No cropping needed: compute straight into Output's buffer
      // (channel-first view).
      output->mutable_data<T>(ctx.GetPlace());
      transformed_output.ShareDataWith(*output);
      transformed_output.Resize(framework::make_ddim(transformed_output_vec));
    }
    T* transformed_output_data = transformed_output.data<T>();

    DataLayout layout;

    // With cuDNN >= 7.0.1 the group count goes into the convolution
    // descriptor (c_groups) and the per-group loop below runs once;
    // otherwise each group is issued as a separate cuDNN call.
    int iwo_groups = groups;
    int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
    iwo_groups = 1;
    c_groups = groups;
    groups = 1;
#endif

    if (strides.size() == 2U) {
      layout = DataLayout::kNCHW;
    } else {
      layout = DataLayout::kNCDHW;
    }

    size_t workspace_size = 0;
    cudnnConvolutionBwdDataAlgo_t algo{};
    // ------------------- cudnn conv algorithm ---------------------
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto layout_tensor = GetCudnnTensorFormat(layout);
    bool deterministic = FLAGS_cudnn_deterministic;

    auto dtype = platform::CudnnDataType<T>::type;
    // ------------------- cudnn descriptors ---------------------
    // Roles are swapped relative to a regular convolution: the op's output
    // is described by idesc and the op's input by odesc.
    ConvArgs args{&transformed_output, filter, &transformed_input, strides,
                  padding_common, dilations, dtype};
    args.handle = handle;
    args.idesc.set(transformed_output, iwo_groups);
    args.wdesc.set(*filter, layout_tensor, iwo_groups);
    args.odesc.set(transformed_input, iwo_groups);
    args.cdesc.set(dtype, padding_common, strides, dilations,
                   platform::AllowTF32Cudnn(), c_groups);

    using search = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
    algo = search::Find<T>(args, false, deterministic, ctx);
    workspace_size =
        std::max(workspace_size, search::GetWorkspaceSize(args, algo));

    // ------------------- cudnn conv transpose forward ---------------------
    int input_offset =
        transformed_input.numel() / transformed_input.dims()[0] / groups;
    int output_offset =
        transformed_output.numel() / transformed_output.dims()[0] / groups;
    int filter_offset = filter->numel() / groups;
    ScalingParamType<T> alpha = 1.0f;
    ScalingParamType<T> beta = 0.0f;
    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    for (int g = 0; g < groups; g++) {
      auto cudnn_func = [&](void* cudnn_workspace) {
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnConvolutionBackwardData(
                handle, &alpha, args.wdesc.desc(),
                filter_data + filter_offset * g, args.odesc.desc(),
                input_data + input_offset * g, args.cdesc.desc(), algo,
                cudnn_workspace, workspace_size, &beta, args.idesc.desc(),
                transformed_output_data + output_offset * g));
      };
      workspace_handle.RunFunc(cudnn_func, workspace_size);
    }
    if (!is_sys_pad && strides.size() == 2U) {
      Slice<paddle::platform::CUDADeviceContext, T, 4>(
          ctx, &transformed_output, output, starts, ends, axes);
    } else if (!is_sys_pad && strides.size() == 3U) {
      Slice<paddle::platform::CUDADeviceContext, T, 5>(
          ctx, &transformed_output, output, starts, ends, axes);
    }

    // Output currently holds channel-first data; transpose it back.
    if (data_layout == DataLayout::kNHWC) {
      Tensor output_transpose;
      Tensor output_nchw;
      output_nchw.ShareDataWith(*output);
      output_nchw.Resize(framework::make_ddim(output_vec));
      if (strides.size() == 2U) {
        std::vector<int> axis = {0, 2, 3, 1};
        DataTranspose<T, 4>(ctx, &output_nchw, &output_transpose, axis);
        *output = output_transpose;
      } else if (strides.size() == 3U) {
        std::vector<int> axis = {0, 2, 3, 4, 1};
        DataTranspose<T, 5>(ctx, &output_nchw, &output_transpose, axis);
        *output = output_transpose;
      }
    }
  }
};

// cuDNN backward kernel for conv2d_transpose / conv3d_transpose.
//
//  * Input@GRAD  = cudnnConvolutionForward(Output@GRAD, Filter)
//  * Filter@GRAD = cudnnConvolutionBackwardFilter(Output@GRAD, Input)
// Channel-last inputs and asymmetric padding are normalised the same way as
// in the forward kernel, except that here Output@GRAD is the tensor that is
// zero-padded.
template <typename T>
class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        paddle::platform::errors::PreconditionNotMet("It must use CUDAPlace."));
    auto input = ctx.Input<Tensor>("Input");
    auto filter = ctx.Input<Tensor>("Filter");
    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
    const T* filter_data = filter->data<T>();
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    // cudnn v5 does not support dilations
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
    const std::string data_layout_str = ctx.Attr<std::string>("data_format");
    const paddle::operators::DataLayout data_layout =
        (data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);

    // if channel_last, transpose to channel_first
    Tensor input_transpose;
    Tensor output_grad_transpose;
    std::vector<int> input_vec = framework::vectorize<int>(input->dims());
    std::vector<int> output_vec =
        framework::vectorize<int>(output_grad->dims());
    if (data_layout == DataLayout::kNHWC) {
      if (strides.size() == 2U) {
        std::vector<int> axis = {0, 3, 1, 2};
        for (size_t i = 0; i < axis.size(); ++i) {
          input_vec[i] = input->dims()[axis[i]];
          output_vec[i] = output_grad->dims()[axis[i]];
        }
        DataTranspose<T, 4>(ctx, input, &input_transpose, axis);
        DataTranspose<T, 4>(ctx, output_grad, &output_grad_transpose, axis);
      } else if (strides.size() == 3U) {
        std::vector<int> axis = {0, 4, 1, 2, 3};
        for (size_t i = 0; i < axis.size(); ++i) {
          input_vec[i] = input->dims()[axis[i]];
          output_vec[i] = output_grad->dims()[axis[i]];
        }
        DataTranspose<T, 5>(ctx, input, &input_transpose, axis);
        DataTranspose<T, 5>(ctx, output_grad, &output_grad_transpose, axis);
      }
    } else {
      input_transpose = *input;
      output_grad_transpose = *output_grad;
    }

    // update padding and dilation
    auto in_dims = input_transpose.dims();
    auto filter_dims = filter->dims();
    framework::DDim in_data_dims;
    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    int data_dim = strides.size();  // 2d or 3d
    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);

    std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
    Tensor transformed_output_grad;
    std::vector<int> padding_common(data_dim, 0);
    if (!is_sys_pad) {
      // Zero-pad Output@GRAD by the per-side excess over padding_common.
      std::vector<int> padding_diff(data_dim);
      std::vector<int> new_output_grad_shape_vec(data_dim + 2);
      new_output_grad_shape_vec[0] = output_grad_transpose.dims()[0];
      new_output_grad_shape_vec[1] = output_grad_transpose.dims()[1];

      for (size_t i = 0; i < data_dim; ++i) {
        padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
        padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
        new_output_grad_shape_vec[i + 2] =
            output_grad_transpose.dims()[i + 2] + padding_diff[i];
        input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
        input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
      }
      framework::DDim new_output_grad_shape(
          framework::make_ddim(new_output_grad_shape_vec));
      transformed_output_grad.Resize(new_output_grad_shape);
      auto& dev_ctx =
          ctx.template device_context<paddle::platform::CUDADeviceContext>();

      transformed_output_grad =
          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
              new_output_grad_shape, dev_ctx);
      const int rank = input_transpose.dims().size();
      T pad_value(0.0);
      switch (rank) {
        case 4: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
              ctx, input_pad, output_grad_transpose, pad_value,
              &transformed_output_grad);
        } break;
        case 5: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
              ctx, input_pad, output_grad_transpose, pad_value,
              &transformed_output_grad);
        } break;
        default:
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Op(ConvTranspose) only supports 4-D or 5-D input Tensor."));
      }
    } else {
      transformed_output_grad = output_grad_transpose;
      if (paddings.size() == data_dim) {
        for (size_t i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[i];
        }
      } else {
        for (size_t i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[2 * i];
        }
      }
    }

    const T* input_data = input_transpose.data<T>();
    const T* output_grad_data = transformed_output_grad.data<T>();
    output_vec = framework::vectorize<int>(transformed_output_grad.dims());

    // ------------------- cudnn descriptors ---------------------
    DataLayout layout;

    if (strides.size() == 2U) {
      layout = DataLayout::kNCHW;
    } else {
      layout = DataLayout::kNCDHW;
    }

    // See the forward kernel: with cuDNN >= 7.0.1 groups are handled by the
    // convolution descriptor and the per-group loops run once.
    int iwo_groups = groups;
    int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
    iwo_groups = 1;
    c_groups = groups;
    groups = 1;
#endif

    auto dtype = platform::CudnnDataType<T>::type;

    ConvArgs args1{&transformed_output_grad,
                   filter,
                   &input_transpose,
                   strides,
                   padding_common,
                   dilations,
                   dtype};
    ConvArgs args2{&transformed_output_grad,
                   filter,
                   &input_transpose,
                   strides,
                   padding_common,
                   dilations,
                   dtype};
    cudnnConvolutionFwdAlgo_t data_algo{};
    cudnnConvolutionBwdFilterAlgo_t filter_algo{};

    auto layout_tensor = GetCudnnTensorFormat(layout);
    size_t workspace_size = 0;
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    bool deterministic = FLAGS_cudnn_deterministic;
    T* input_grad_data = nullptr;
    T* filter_grad_data = nullptr;
    if (input_grad)
      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    if (filter_grad)
      filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());

    if (input_grad) {
      args1.handle = handle;
      args1.idesc.set(transformed_output_grad, iwo_groups);
      args1.wdesc.set(*filter, layout_tensor, iwo_groups);
      args1.odesc.set(input_transpose, iwo_groups);
      args1.cdesc.set(dtype, padding_common, strides, dilations,
                      platform::AllowTF32Cudnn(), c_groups);
      using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      data_algo = search1::Find<T>(args1, false, deterministic, ctx);
      workspace_size =
          std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
    }

    if (filter_grad) {
      args2.handle = handle;
      args2.idesc.set(transformed_output_grad, iwo_groups);
      args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups);
      args2.odesc.set(input_transpose, iwo_groups);
      args2.cdesc.set(dtype, padding_common, strides, dilations,
                      platform::AllowTF32Cudnn(), c_groups);
      using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
      filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
      workspace_size = std::max(workspace_size,
                                search2::GetWorkspaceSize(args2, filter_algo));
    }

    // ------------------- cudnn conv backward data ---------------------
    // FIXME(typhoonzero): template type T may not be the same as cudnn call.
    int input_offset = input->numel() / input->dims()[0] / groups;
    int output_grad_offset = transformed_output_grad.numel() /
                             transformed_output_grad.dims()[0] / groups;
    int filter_offset = filter->numel() / groups;
    ScalingParamType<T> alpha = 1.0f;
    ScalingParamType<T> beta = 0.0f;
    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    if (input_grad) {
      // Because beta is zero, it is unnecessary to reset input_grad.
      for (int g = 0; g < groups; g++) {
        auto cudnn_func = [&](void* cudnn_workspace) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnConvolutionForward(
                  handle, &alpha, args1.idesc.desc(),
                  output_grad_data + output_grad_offset * g,
                  args1.wdesc.desc(), filter_data + filter_offset * g,
                  args1.cdesc.desc(), data_algo, cudnn_workspace,
                  workspace_size, &beta, args1.odesc.desc(),
                  input_grad_data + input_offset * g));
        };
        workspace_handle.RunFunc(cudnn_func, workspace_size);
      }

      // Input@GRAD was written channel-first; transpose it back.
      if (data_layout == DataLayout::kNHWC) {
        Tensor input_grad_transpose;
        Tensor input_grad_nchw;
        input_grad_nchw.ShareDataWith(*input_grad);
        input_grad_nchw.Resize(framework::make_ddim(input_vec));
        if (strides.size() == 2U) {
          std::vector<int> axis = {0, 2, 3, 1};
          DataTranspose<T, 4>(ctx, &input_grad_nchw, &input_grad_transpose,
                              axis);
          *input_grad = input_grad_transpose;
        } else if (strides.size() == 3U) {
          std::vector<int> axis = {0, 2, 3, 4, 1};
          DataTranspose<T, 5>(ctx, &input_grad_nchw, &input_grad_transpose,
                              axis);
          *input_grad = input_grad_transpose;
        }
      }
    }

    // ------------------- cudnn conv backward filter ---------------------
    if (filter_grad) {
      // Because beta is zero, it is unnecessary to reset filter_grad.
      // Gradient with respect to the filter
      for (int g = 0; g < groups; g++) {
        auto cudnn_func = [&](void* cudnn_workspace) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnConvolutionBackwardFilter(
                  handle, &alpha, args2.idesc.desc(),
                  output_grad_data + output_grad_offset * g,
                  args2.odesc.desc(), input_data + input_offset * g,
                  args2.cdesc.desc(), filter_algo, cudnn_workspace,
                  workspace_size, &beta, args2.wdesc.desc(),
                  filter_grad_data + filter_offset * g));
        };
        workspace_handle.RunFunc(cudnn_func, workspace_size);
      }
    }
  }
};

/*
 * Inputs:  I, W, dO, ddI, ddW
 * Outputs: ddO, dW, dI
 * ddo = conv_bp_data(W, ddI) + conv_bp_data(ddW, I)
 * dW = conv_bp_filter(dO, ddI)
 * dI = conv(dO, ddW)
 *
 * Every output is optional; each term is computed only when the output and
 * the inputs it needs are present. Channel-last tensors are transposed to
 * channel-first, and asymmetric padding is handled by zero-padding and
 * cropping, as in the first-order kernels.
 */
template <typename T>
class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        paddle::platform::errors::PreconditionNotMet("It must use CUDAPlace."));
    auto X = ctx.Input<Tensor>("Input");
    auto W = ctx.Input<Tensor>("Filter");
    auto dO = ctx.Input<Tensor>("DOutput");
    auto ddX = ctx.Input<Tensor>("DDInput");
    auto ddW = ctx.Input<Tensor>("DDFilter");
    auto ddO = ctx.Output<Tensor>("DDOutput");
    auto dW = ctx.Output<Tensor>("DFilter");
    auto dX = ctx.Output<Tensor>("DInput");

    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    if (ddO) {
      ddO->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, ddO, static_cast<T>(0));
    }
    if (dW) {
      dW->mutable_data<T>(ctx.GetPlace());
    }
    if (dX) {
      dX->mutable_data<T>(ctx.GetPlace());
    }

    const T* dy = dO->data<T>();
    const T* w = W->data<T>();
    const T* ddx = nullptr;
    const T* ddw = nullptr;
    T *dw, *dx, *ddy;
    dw = dx = ddy = nullptr;
    T* transformed_dx = nullptr;
    const std::vector<int>& strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");

    bool deterministic = FLAGS_cudnn_deterministic;

    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");

    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
    const std::string data_format = ctx.Attr<std::string>("data_format");
    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    // transform Tensors to channel first-----------
    Tensor transformed_X_channel(X->type());
    Tensor transformed_dO_channel(dO->type());
    Tensor transformed_ddX_channel(X->type());

    Tensor transformed_ddO_channel(dO->type());
    Tensor transformed_dX_channel(X->type());

    if (channel_last) {
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(
          ctx, X, &transformed_X_channel);
      TransToChannelFirst<platform::CUDADeviceContext, T>(
          ctx, X, &transformed_X_channel);

      ResizeToChannelFirst<platform::CUDADeviceContext, T>(
          ctx, dO, &transformed_dO_channel);
      TransToChannelFirst<platform::CUDADeviceContext, T>(
          ctx, dO, &transformed_dO_channel);

      if (ddX) {
        ResizeToChannelFirst<platform::CUDADeviceContext, T>(
            ctx, ddX, &transformed_ddX_channel);
        TransToChannelFirst<platform::CUDADeviceContext, T>(
            ctx, ddX, &transformed_ddX_channel);
      }

      if (ddO) {
        ResizeToChannelFirst<platform::CUDADeviceContext, T>(
            ctx, ddO, &transformed_ddO_channel);
      }
      if (dX) {
        ResizeToChannelFirst<platform::CUDADeviceContext, T>(
            ctx, dX, &transformed_dX_channel);
        transformed_dX_channel.mutable_data<T>(ctx.GetPlace());
      }

    } else {
      transformed_X_channel = *X;
      transformed_dO_channel = *dO;
      if (ddX) {
        transformed_ddX_channel = *ddX;
      }
      if (dX) {
        transformed_dX_channel = *dX;
      }
    }
    std::vector<int> output_vec =
        framework::vectorize<int>(transformed_dO_channel.dims());

    auto in_dims = transformed_X_channel.dims();
    auto filter_dims = W->dims();
    framework::DDim in_data_dims =
        framework::slice_ddim(in_dims, 2, in_dims.size());
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    int data_dim = strides.size();  // 2d or 3d
    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
    Tensor transformed_X(X->type());
    Tensor transformed_ddX(X->type());

    Tensor transformed_dO(dO->type());

    std::vector<int> padding_common(data_dim, 0);
    std::vector<int> input_pad(X->dims().size() * 2, 0);

    if (!is_sys_pad) {
      // get pad
      std::vector<int> padding_diff(data_dim);
      std::vector<int> new_input_shape_vec(data_dim + 2);
      std::vector<int> new_output_grad_shape_vec(data_dim + 2);

      new_input_shape_vec[0] = transformed_X_channel.dims()[0];
      new_input_shape_vec[1] = transformed_X_channel.dims()[1];

      new_output_grad_shape_vec[0] = transformed_dO_channel.dims()[0];
      new_output_grad_shape_vec[1] = transformed_dO_channel.dims()[1];

      for (size_t i = 0; i < data_dim; ++i) {
        padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
        padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
        new_input_shape_vec[i + 2] =
            transformed_X_channel.dims()[i + 2] + padding_diff[i];

        new_output_grad_shape_vec[i + 2] =
            transformed_dO_channel.dims()[i + 2] + padding_diff[i];

        input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
        input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
      }
      framework::DDim new_input_shape(
          framework::make_ddim(new_input_shape_vec));
      transformed_X.Resize(new_input_shape);
      transformed_ddX.Resize(new_input_shape);

      framework::DDim new_output_grad_shape(
          framework::make_ddim(new_output_grad_shape_vec));
      transformed_dO.Resize(new_output_grad_shape);

      transformed_dO =
          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
              new_output_grad_shape, dev_ctx);

      transformed_X =
          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
              new_input_shape, dev_ctx);
      if (ddX) {
        transformed_ddX =
            ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
                new_input_shape, dev_ctx);
      }

      // pad for input
      const int rank = X->dims().size();
      T pad_value(0.0);
      switch (rank) {
        case 4: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
              ctx, input_pad, transformed_X_channel, pad_value,
              &transformed_X);
          if (dO) {
            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
                ctx, input_pad, transformed_dO_channel, pad_value,
                &transformed_dO);
          }

          if (ddX) {
            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
                ctx, input_pad, transformed_ddX_channel, pad_value,
                &transformed_ddX);
          }
        } break;
        case 5: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
              ctx, input_pad, transformed_X_channel, pad_value,
              &transformed_X);
          // transformed_dO was allocated above; fill it exactly as the
          // rank-4 case does instead of leaving it uninitialized.
          if (dO) {
            math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
                ctx, input_pad, transformed_dO_channel, pad_value,
                &transformed_dO);
          }
          if (ddX) {
            math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
                ctx, input_pad, transformed_ddX_channel, pad_value,
                &transformed_ddX);
          }
        } break;
        default:
          PADDLE_THROW(platform::errors::InvalidArgument(
              "ConvOp only support tensors with 4 or 5 dimensions."));
      }
    } else {
      transformed_X = transformed_X_channel;
      transformed_dO = transformed_dO_channel;
      if (ddX) {
        transformed_ddX = transformed_ddX_channel;
      }

      if (paddings.size() == data_dim) {
        for (size_t i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[i];
        }
      } else {
        for (size_t i = 0; i < data_dim; ++i) {
          padding_common[i] = paddings[2 * i];
        }
      }
    }

    // Crop window used to cut ddO out of the enlarged result when padding
    // is asymmetric (same formula as the forward kernel).
    std::vector<int64_t> starts(data_dim, 0);
    std::vector<int64_t> ends(data_dim, 0);
    std::vector<int64_t> axes(data_dim, 0);
    for (size_t i = 0; i < data_dim; ++i) {
      starts[i] = input_pad[2 * i + 4] * (strides[i] + 1);
      ends[i] = starts[i] + output_vec[i + 2];
      axes[i] = i + 2;
    }

    std::vector<int> transformed_output_vec = output_vec;
    for (size_t i = 0; i < data_dim; ++i) {
      transformed_output_vec[i + 2] =
          output_vec[i + 2] +
          (input_pad[2 * i + 4] + input_pad[2 * i + 5]) * strides[i] -
          2 * padding_common[i] + paddings[2 * i] + paddings[2 * i + 1];
    }

    // Channel-first buffer that receives ddO. It aliases ddO (already zeroed
    // above) only when no crop and no layout transpose are needed. Otherwise
    // it is a separate buffer. That buffer must start at zero when ddX is
    // absent, because the ddW term below accumulates into it (its beta is
    // 1) instead of overwriting it.
    if (ddO) {
      if (is_sys_pad && !channel_last) {
        transformed_ddO_channel = *ddO;
        transformed_ddO_channel.Resize(
            framework::make_ddim(transformed_output_vec));
      } else {
        DDim transformed_output_shape(
            framework::make_ddim(transformed_output_vec));
        transformed_ddO_channel.mutable_data<T>(transformed_output_shape,
                                                ctx.GetPlace());
        if (!ddX) {
          set_zero(dev_ctx, &transformed_ddO_channel, static_cast<T>(0));
        }
      }
    }

    const T* x = transformed_X.data<T>();

    // See the forward kernel: with cuDNN >= 7.0.1 groups are handled by the
    // convolution descriptor and the per-group loops run once.
    int iwo_group = groups;
    int c_group = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
    iwo_group = 1;
    c_group = groups;
    groups = 1;
#endif
    auto dtype = platform::CudnnDataType<T>::type;

    auto handle = dev_ctx.cudnn_handle();

    ConvArgs args1{&transformed_ddO_channel,
                   W,
                   &transformed_ddX,
                   strides,
                   padding_common,
                   dilations,
                   dtype};
    ConvArgs args2{&transformed_ddO_channel,
                   ddW,
                   &transformed_X,
                   strides,
                   padding_common,
                   dilations,
                   dtype};

    ConvArgs args3{&transformed_dO,
                   dW,
                   &transformed_ddX_channel,
                   strides,
                   padding_common,
                   dilations,
                   dtype};
    ConvArgs args4{&transformed_dO,
                   ddW,
                   &transformed_dX_channel,
                   strides,
                   padding_common,
                   dilations,
                   dtype};

    cudnnConvolutionBwdDataAlgo_t bwd_algo1 =
        static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
    cudnnConvolutionBwdDataAlgo_t bwd_algo2 =
        static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
    cudnnConvolutionFwdAlgo_t data_algo =
        static_cast<cudnnConvolutionFwdAlgo_t>(0);
    cudnnConvolutionBwdFilterAlgo_t filter_algo =
        static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);

    auto layout = GetCudnnTensorFormat(DataLayout::kNCHW);

    // ddO = conv_bp_data(W, ddX) + conv_bp_data(ddW, X)
    size_t workspace_size = 0;

    T* transformed_ddy_channel = nullptr;

    if (ddO) {
      ddy = ddO->data<T>();
      transformed_ddy_channel = transformed_ddO_channel.data<T>();
      if (ddX) {
        args1.handle = handle;
        args1.idesc.set(transformed_ddO_channel, iwo_group);
        args1.wdesc.set(*W, layout, iwo_group);
        args1.odesc.set(transformed_ddX, iwo_group);
        // allow_tf32 must be passed explicitly so that c_group lands in
        // the group-count parameter, as in the first-order kernels.
        args1.cdesc.set(dtype, padding_common, strides, dilations,
                        platform::AllowTF32Cudnn(), c_group);
        using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
        bwd_algo1 = search1::Find<T>(args1, false, deterministic, ctx);
        workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1);
      }

      if (ddW) {
        ddw = ddW->data<T>();
        args2.handle = handle;
        args2.idesc.set(transformed_ddO_channel, iwo_group);
        args2.wdesc.set(*ddW, layout, iwo_group);
        args2.odesc.set(transformed_X, iwo_group);
        args2.cdesc.set(dtype, padding_common, strides, dilations,
                        platform::AllowTF32Cudnn(), c_group);
        using search2 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
        bwd_algo2 = search2::Find<T>(args2, false, deterministic, ctx);
        workspace_size = std::max(workspace_size,
                                  search2::GetWorkspaceSize(args2, bwd_algo2));
      }
    }

    if (dW && ddX) {
      dw = dW->data<T>();
      args3.handle = handle;
      args3.idesc.set(transformed_dO, iwo_group);
      args3.wdesc.set(*dW, layout, iwo_group);

      args3.odesc.set(transformed_ddX_channel, iwo_group);

      args3.cdesc.set(dtype, padding_common, strides, dilations,
                      platform::AllowTF32Cudnn(), c_group);

      using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
      filter_algo = search3::Find<T>(args3, false, deterministic, ctx);
      workspace_size = std::max(workspace_size,
                                search3::GetWorkspaceSize(args3, filter_algo));
    }

    if (ddW && dX) {
      transformed_dx = transformed_dX_channel.data<T>();

      args4.handle = handle;
      args4.idesc.set(transformed_dO, iwo_group);
      args4.wdesc.set(*ddW, layout, iwo_group);
      args4.odesc.set(transformed_dX_channel, iwo_group);
      args4.cdesc.set(dtype, padding_common, strides, dilations,
                      platform::AllowTF32Cudnn(), c_group);
      using search4 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      data_algo = search4::Find<T>(args4, false, deterministic, ctx);
      workspace_size =
          std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
    }

    int i_n, i_c, i_d, i_h, i_w;
    GetNCDHW(transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h,
             &i_w);

    int o_n, o_c, o_d, o_h, o_w;
    GetNCDHW(transformed_dO.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h,
             &o_w);

    int group_offset_in =
        transformed_X.numel() / transformed_X.dims()[0] / groups;
    int group_offset_out =
        transformed_dO.numel() / transformed_dO.dims()[0] / groups;
    int group_offset_filter = W->numel() / groups;

    ScalingParamType<T> alpha = 1.0f;
    ScalingParamType<T> beta = 0.0f;

    auto wkspace_handle = dev_ctx.cudnn_workspace_handle();

    if (ddO) {
      if (ddX) {
        ddx = transformed_ddX.data<T>();
        for (int i = 0; i < groups; i++) {
          wkspace_handle.RunFunc(
              [&](void* workspace_ptr) {
                PADDLE_ENFORCE_CUDA_SUCCESS(
                    platform::dynload::cudnnConvolutionBackwardData(
                        handle, &alpha, args1.wdesc.desc(),
                        w + i * group_offset_filter, args1.odesc.desc(),
                        ddx + i * group_offset_in, args1.cdesc.desc(),
                        bwd_algo1, workspace_ptr, workspace_size, &beta,
                        args1.idesc.desc(),
                        transformed_ddy_channel + i * group_offset_out));
              },
              workspace_size);
        }
      }
      if (ddW) {
        for (int i = 0; i < groups; i++) {
          wkspace_handle.RunFunc(
              [&](void* workspace_ptr) {
                // beta = alpha (1): accumulate onto the ddX term / zeros.
                PADDLE_ENFORCE_CUDA_SUCCESS(
                    platform::dynload::cudnnConvolutionBackwardData(
                        handle, &alpha, args2.wdesc.desc(),
                        ddw + i * group_offset_filter, args2.odesc.desc(),
                        x + i * group_offset_in, args2.cdesc.desc(),
                        bwd_algo2, workspace_ptr, workspace_size, &alpha,
                        args2.idesc.desc(),
                        transformed_ddy_channel + i * group_offset_out));
              },
              workspace_size);
        }
      }

      // Move the channel-first result into ddO: crop if padding was
      // asymmetric, then transpose back if the data is channel-last. The
      // crop goes into a separate tensor rather than slicing in place.
      if (!is_sys_pad) {
        Tensor cropped_ddO_channel;
        Tensor* crop_dst = channel_last ? &cropped_ddO_channel : ddO;
        if (strides.size() == 2U) {
          Slice<paddle::platform::CUDADeviceContext, T, 4>(
              ctx, &transformed_ddO_channel, crop_dst, starts, ends, axes);
        } else if (strides.size() == 3U) {
          Slice<paddle::platform::CUDADeviceContext, T, 5>(
              ctx, &transformed_ddO_channel, crop_dst, starts, ends, axes);
        }
        if (channel_last) {
          TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
              ctx, &cropped_ddO_channel, ddO);
        }
      } else if (channel_last) {
        TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
            ctx, &transformed_ddO_channel, ddO);
      }
    }

    T* transformed_dy_channel = transformed_dO.data<T>();
    if (dW && ddX) {
      ddx = transformed_ddX_channel.data<T>();
      for (int i = 0; i < groups; i++) {
        wkspace_handle.RunFunc(
            [&](void* workspace_ptr) {
              PADDLE_ENFORCE_CUDA_SUCCESS(
                  platform::dynload::cudnnConvolutionBackwardFilter(
                      handle, &alpha, args3.idesc.desc(),
                      transformed_dy_channel + i * group_offset_out,
                      args3.odesc.desc(), ddx + i * group_offset_in,
                      args3.cdesc.desc(), filter_algo, workspace_ptr,
                      workspace_size, &beta, args3.wdesc.desc(),
                      dw + i * group_offset_filter));
            },
            workspace_size);
      }
    }

    if (dX && ddW) {
      ddw = ddW->data<T>();
      for (int i = 0; i < groups; i++) {
        wkspace_handle.RunFunc(
            [&](void* workspace_ptr) {
              PADDLE_ENFORCE_CUDA_SUCCESS(
                  platform::dynload::cudnnConvolutionForward(
                      handle, &alpha, args4.idesc.desc(),
                      transformed_dy_channel + i * group_offset_out,
                      args4.wdesc.desc(), ddw + i * group_offset_filter,
                      args4.cdesc.desc(), data_algo, workspace_ptr,
                      workspace_size, &beta, args4.odesc.desc(),
                      transformed_dx + i * group_offset_in));
            },
            workspace_size);
      }
      if (channel_last) {
        TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
            ctx, &transformed_dX_channel, dX);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
                   ops::CUDNNConvTransposeOpKernel<plat::float16>,
                   ops::CUDNNConvTransposeOpKernel<float>,
                   ops::CUDNNConvTransposeOpKernel<double>);
REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
                   ops::CUDNNConvTransposeGradOpKernel<plat::float16>,
                   ops::CUDNNConvTransposeGradOpKernel<float>,
                   ops::CUDNNConvTransposeGradOpKernel<double>);
REGISTER_OP_KERNEL(
    conv2d_transpose_grad_grad, CUDNN, plat::CUDAPlace,
    paddle::operators::CUDNNConvTransposeDoubleGradOpKernel<float>,
    paddle::operators::CUDNNConvTransposeDoubleGradOpKernel<double>,
    paddle::operators::CUDNNConvTransposeDoubleGradOpKernel<plat::float16>);

REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
                   ops::CUDNNConvTransposeOpKernel<plat::float16>,
                   ops::CUDNNConvTransposeOpKernel<float>,
                   ops::CUDNNConvTransposeOpKernel<double>);
REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
                   ops::CUDNNConvTransposeGradOpKernel<plat::float16>,
                   ops::CUDNNConvTransposeGradOpKernel<float>,
                   ops::CUDNNConvTransposeGradOpKernel<double>);