diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu b/paddle/fluid/operators/conv_transpose_cudnn_op.cu index 5249264b1c9bcf13c5ee8227828087659de5254b..94148109c7369fa15572e3e9d27912c82cdb150e 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu @@ -551,6 +551,487 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { } }; +/* + * Inputs: I, W, dO, ddI, ddW + * Outputs: ddO, dW, dI + * ddo = conv_bp_data(W, ddI) + conv_bp_data(ddW, I) + * dW = conv_bp_filter(dO, ddI) + * dI = conv(dO, ddW) + */ +template +class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + paddle::platform::errors::PreconditionNotMet("It must use CUDAPlace.")); + auto X = ctx.Input("Input"); + auto W = ctx.Input("Filter"); + auto dO = ctx.Input("DOutput"); + auto ddX = ctx.Input("DDInput"); + auto ddW = ctx.Input("DDFilter"); + + auto ddO = ctx.Output("DDOutput"); + auto dW = ctx.Output("DFilter"); + auto dX = ctx.Output("DInput"); + + if (ddO) { + ddO->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(dev_ctx, ddO, static_cast(0)); + } + if (dW) { + dW->mutable_data(ctx.GetPlace()); + } + if (dX) { + dX->mutable_data(ctx.GetPlace()); + } + + const T* dy = dO->data(); + const T* w = W->data(); + + const T* ddx = nullptr; + const T* ddw = nullptr; + T *dw, *dx, *ddy; + dw = dx = ddy = nullptr; + T* transformed_dx = nullptr; + const std::vector& strides = ctx.Attr>("strides"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + + bool deterministic = FLAGS_cudnn_deterministic; + + std::vector paddings = ctx.Attr>("paddings"); + + std::string padding_algorithm = ctx.Attr("padding_algorithm"); + const std::string data_format = ctx.Attr("data_format"); + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + // transform Tensors to channel first----------- + Tensor transformed_X_channel(X->type()); + Tensor transformed_dO_channel(dO->type()); + Tensor transformed_ddX_channel(X->type()); + + Tensor transformed_ddO_channel(dO->type()); + Tensor transformed_dX_channel(X->type()); + + if (channel_last) { + ResizeToChannelFirst( + ctx, X, &transformed_X_channel); + TransToChannelFirst( + ctx, X, &transformed_X_channel); + + ResizeToChannelFirst( + ctx, dO, &transformed_dO_channel); + TransToChannelFirst( + ctx, dO, &transformed_dO_channel); + + if (ddX) { + ResizeToChannelFirst( + ctx, ddX, &transformed_ddX_channel); + TransToChannelFirst( + ctx, ddX, &transformed_ddX_channel); + } + + if (ddO) { + ResizeToChannelFirst( + ctx, ddO, &transformed_ddO_channel); + } + if (dX) { + ResizeToChannelFirst( + ctx, dX, &transformed_dX_channel); + transformed_dX_channel.mutable_data(ctx.GetPlace()); + } + + } else { + transformed_X_channel = *X; + transformed_dO_channel = *dO; + if (ddX) { + transformed_ddX_channel = *ddX; + } + if (dX) { + transformed_dX_channel = *dX; + } + } + std::vector output_vec = + framework::vectorize(transformed_dO_channel.dims()); + + auto in_dims = transformed_X_channel.dims(); + auto filter_dims = W->dims(); + framework::DDim in_data_dims = + framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim); + Tensor transformed_X(X->type()); + Tensor transformed_ddX(X->type()); + + Tensor transformed_dO(dO->type()); + + std::vector padding_common(data_dim, 0); + std::vector input_pad(X->dims().size() * 2, 0); + + if (!is_sys_pad) { + // get pad + std::vector padding_diff(data_dim); + std::vector new_input_shape_vec(data_dim + 2); + std::vector new_output_grad_shape_vec(data_dim + 2); + + new_input_shape_vec[0] = transformed_X_channel.dims()[0]; + new_input_shape_vec[1] = transformed_X_channel.dims()[1]; + + new_output_grad_shape_vec[0] = transformed_dO_channel.dims()[0]; + new_output_grad_shape_vec[1] = transformed_dO_channel.dims()[1]; + + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + new_input_shape_vec[i + 2] = + transformed_X_channel.dims()[i + 2] + padding_diff[i]; + + new_output_grad_shape_vec[i + 2] = + transformed_dO_channel.dims()[i + 2] + padding_diff[i]; + + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + framework::DDim new_input_shape( + framework::make_ddim(new_input_shape_vec)); + transformed_X.Resize(new_input_shape); + transformed_ddX.Resize(new_input_shape); + + framework::DDim new_output_grad_shape( + framework::make_ddim(new_output_grad_shape_vec)); + transformed_dO.Resize(new_output_grad_shape); + + transformed_dO = + ctx.AllocateTmpTensor( + new_output_grad_shape, dev_ctx); + + transformed_X = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + if (ddX) { + transformed_ddX = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + } + + // pad for input + const int rank = X->dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + math::PadFunction( + ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); + if (dO) { + math::PadFunction( + ctx, input_pad, transformed_dO_channel, pad_value, + &transformed_dO); + } + + if (ddX) { + math::PadFunction( + ctx, input_pad, transformed_ddX_channel, pad_value, + &transformed_ddX); + } + } break; + case 5: { + math::PadFunction( + ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); + if (ddX) { + math::PadFunction( + ctx, input_pad, transformed_ddX_channel, pad_value, + &transformed_ddX); + } + } break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "ConvOp only support tensors with 4 or 5 dimensions.")); + } + + } else { + transformed_X = transformed_X_channel; + transformed_dO = transformed_dO_channel; + if (ddX) { + transformed_ddX = transformed_ddX_channel; + } + + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + std::vector starts(data_dim, 0); + std::vector ends(data_dim, 0); + std::vector axes(data_dim, 0); + for (size_t i = 0; i < data_dim; ++i) { + starts[i] = input_pad[2 * i + 4] * (strides[i] + 1); + ends[i] = starts[i] + output_vec[i + 2]; + axes[i] = i + 2; + } + + std::vector transformed_output_vec = output_vec; + for (size_t i = 0; i < data_dim; ++i) { + transformed_output_vec[i + 2] = + output_vec[i + 2] + + (input_pad[2 * i + 4] + input_pad[2 * i + 5]) * strides[i] - + 2 * padding_common[i] + paddings[2 * i] + paddings[2 * i + 1]; + } + + if (!is_sys_pad) { + DDim transformed_output_shape( + framework::make_ddim(transformed_output_vec)); + transformed_ddO_channel.mutable_data(transformed_output_shape, + ctx.GetPlace()); + } else { + ddO->mutable_data(ctx.GetPlace()); + transformed_ddO_channel = *ddO; + transformed_ddO_channel.Resize( + framework::make_ddim(transformed_output_vec)); + } + + const T* x = transformed_X.data(); + + int iwo_group = groups; + int c_group = 1; +#if CUDNN_VERSION_MIN(7, 0, 1) + iwo_group = 1; + c_group = groups; + groups = 1; +#endif + auto dtype = platform::CudnnDataType::type; + + auto handle = dev_ctx.cudnn_handle(); + + ConvArgs args1{&transformed_ddO_channel, + W, + &transformed_ddX, + strides, + padding_common, + dilations, + dtype}; + ConvArgs args2{&transformed_ddO_channel, ddW, &transformed_X, strides, + padding_common, dilations, dtype}; + + ConvArgs args3{&transformed_dO, + dW, + &transformed_ddX_channel, + strides, + padding_common, + dilations, + dtype}; + ConvArgs args4{ + &transformed_dO, ddW, &transformed_dX_channel, strides, padding_common, + dilations, dtype}; + + cudnnConvolutionBwdDataAlgo_t bwd_algo1 = + static_cast(0); + cudnnConvolutionBwdDataAlgo_t bwd_algo2 = + static_cast(0); + cudnnConvolutionFwdAlgo_t data_algo = + static_cast(0); + cudnnConvolutionBwdFilterAlgo_t filter_algo = + static_cast(0); + + auto layout = GetCudnnTensorFormat(DataLayout::kNCHW); + + // ddo = conv(ddI, W) + conv(I, ddW) + size_t workspace_size = 0; + + T* transformed_ddy_channel = nullptr; + + if (ddO) { + ddy = ddO->data(); + transformed_ddy_channel = transformed_ddO_channel.data(); + if (ddX) { + args1.handle = handle; + args1.idesc.set(transformed_ddO_channel, iwo_group); + args1.wdesc.set(*W, layout, iwo_group); + args1.odesc.set(transformed_ddX, iwo_group); + args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); + using search1 = SearchAlgorithm; + bwd_algo1 = search1::Find(args1, false, deterministic, ctx); + workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1); + } + + if (ddW) { + ddw = ddW->data(); + args2.handle = handle; + args2.idesc.set(transformed_ddO_channel, iwo_group); + args2.wdesc.set(*ddW, layout, iwo_group); + args2.odesc.set(transformed_X, iwo_group); + args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); + using search2 = SearchAlgorithm; + bwd_algo2 = search2::Find(args2, false, deterministic, ctx); + workspace_size = std::max(workspace_size, + search2::GetWorkspaceSize(args2, bwd_algo2)); + } + } + + if (dW && ddX) { + dw = dW->data(); + args3.handle = handle; + args3.idesc.set(transformed_dO, iwo_group); + args3.wdesc.set(*dW, layout, iwo_group); + + args3.odesc.set(transformed_ddX_channel, iwo_group); + + args3.cdesc.set(dtype, padding_common, strides, dilations, c_group); + + using search3 = SearchAlgorithm; + filter_algo = search3::Find(args3, false, deterministic, ctx); + workspace_size = std::max(workspace_size, + search3::GetWorkspaceSize(args3, filter_algo)); + } + + if (ddW && dX) { + transformed_dx = transformed_dX_channel.data(); + + args4.handle = handle; + args4.idesc.set(transformed_dO, iwo_group); + args4.wdesc.set(*ddW, layout, iwo_group); + args4.odesc.set(transformed_dX_channel, iwo_group); + args4.cdesc.set(dtype, padding_common, strides, dilations, c_group); + + using search4 = SearchAlgorithm; + data_algo = search4::Find(args4, false, deterministic, ctx); + workspace_size = + std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); + } + + int i_n, i_c, i_d, i_h, i_w; + GetNCDHW(transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, + &i_w); + + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(transformed_dO.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, + &o_w); + + int group_offset_in = + transformed_X.numel() / transformed_X.dims()[0] / groups; + int group_offset_out = + transformed_dO.numel() / transformed_dO.dims()[0] / groups; + int group_offset_filter = W->numel() / groups; + + ScalingParamType alpha = 1.0f; + ScalingParamType beta = 0.0f; + + auto wkspace_handle = dev_ctx.cudnn_workspace_handle(); + + if (ddO) { + if (ddX) { + ddx = transformed_ddX.data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, args1.wdesc.desc(), + w + i * group_offset_filter, args1.odesc.desc(), + ddx + i * group_offset_in, args1.cdesc.desc(), + bwd_algo1, workspace_ptr, workspace_size, &beta, + args1.idesc.desc(), + transformed_ddy_channel + i * group_offset_out)); + }, + workspace_size); + } + } + if (ddW) { + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, args2.wdesc.desc(), + ddw + i * group_offset_filter, args2.odesc.desc(), + x + i * group_offset_in, args2.cdesc.desc(), bwd_algo2, + workspace_ptr, workspace_size, &alpha, + args2.idesc.desc(), + transformed_ddy_channel + i * group_offset_out)); + }, + workspace_size); + } + } + if ((!is_sys_pad) && (!channel_last)) { + if (strides.size() == 2U) { + Slice( + ctx, &transformed_ddO_channel, ddO, starts, ends, axes); + } else if (!is_sys_pad && strides.size() == 3U) { + Slice( + ctx, &transformed_ddO_channel, ddO, starts, ends, axes); + } + } else if ((!is_sys_pad) && (channel_last)) { + if (strides.size() == 2U) { + Slice( + ctx, &transformed_ddO_channel, &transformed_ddO_channel, starts, + ends, axes); + } else if (!is_sys_pad && strides.size() == 3U) { + Slice( + ctx, &transformed_ddO_channel, &transformed_ddO_channel, starts, + ends, axes); + } + + TransToChannelLast( + ctx, &transformed_ddO_channel, ddO); + } + } + + T* transformed_dy_channel = transformed_dO.data(); + if (dW && ddX) { + ddx = transformed_ddX_channel.data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, args3.idesc.desc(), + transformed_dy_channel + i * group_offset_out, + args3.odesc.desc(), ddx + i * group_offset_in, + args3.cdesc.desc(), filter_algo, workspace_ptr, + workspace_size, &beta, args3.wdesc.desc(), + dw + i * group_offset_filter)); + }, + workspace_size); + } + } + + if (dX && ddW) { + ddw = ddW->data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnConvolutionForward( + handle, &alpha, args4.idesc.desc(), + transformed_dy_channel + i * group_offset_out, + args4.wdesc.desc(), ddw + i * group_offset_filter, + args4.cdesc.desc(), data_algo, workspace_ptr, + workspace_size, &beta, args4.odesc.desc(), + transformed_dx + i * group_offset_in)); + }, + workspace_size); + } + if (channel_last) { + TransToChannelLast( + ctx, &transformed_dX_channel, dX); + } + } + } +}; + } // namespace operators } // namespace paddle @@ -565,6 +1046,11 @@ REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, ops::CUDNNConvTransposeGradOpKernel, ops::CUDNNConvTransposeGradOpKernel, ops::CUDNNConvTransposeGradOpKernel); +REGISTER_OP_KERNEL( + conv2d_transpose_grad_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvTransposeDoubleGradOpKernel, + paddle::operators::CUDNNConvTransposeDoubleGradOpKernel, + paddle::operators::CUDNNConvTransposeDoubleGradOpKernel); REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace, ops::CUDNNConvTransposeOpKernel, diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 6c4844855591911c025230822768d091826cb794..a4f00f6cd809b8475ac8e39daae97b99a7cc87b3 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -513,6 +513,85 @@ class ConvTransposeGradOpMaker : public framework::SingleGradOpMaker { } }; +/* + * Inputs: I, W, dO, ddI, ddW + * Outputs: ddO, dW, dI + */ +template +class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + // I, W, dO, ddI, ddW + op->SetInput("Input", this->Input("Input")); + op->SetInput("Filter", this->Input("Filter")); + op->SetInput("DOutput", this->Input(framework::GradVarName("Output"))); + op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input"))); + op->SetInput("DDFilter", + this->OutputGrad(framework::GradVarName("Filter"))); + + // ddO, dI, dW + // Unlike grad op, double grad op does not use name@GRAD@GRAD + // as key of ops' inputs and outputs. + auto ddx = this->OutputGrad(framework::GradVarName("Input")); + auto ddw = this->OutputGrad(framework::GradVarName("Filter")); + + op->SetOutput("DDOutput", + ddx.empty() + ? this->EmptyInputGrad() + : this->InputGrad(framework::GradVarName("Output"))); + op->SetOutput("DFilter", ddx.empty() ? this->EmptyInputGrad() + : this->InputGrad("Filter")); + op->SetOutput("DInput", ddw.empty() ? this->EmptyInputGrad() + : this->InputGrad("Input")); + + op->SetAttrMap(this->Attrs()); + } +}; + +void ConvTransposeOpDoubleGrad::InferShape( + framework::InferShapeContext* ctx) const { + auto x_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("Filter"); + auto do_dims = ctx->GetInputDim("DOutput"); + + if (ctx->HasOutput("DDOutput") && + (ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) { + ctx->SetOutputDim("DDOutput", do_dims); + } + if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) { + ctx->SetOutputDim("DFilter", w_dims); + } + if (ctx->HasOutput("DInput") && ctx->HasInput("DDFilter")) { + ctx->SetOutputDim("DInput", x_dims); + } +} + +framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + bool use_cudnn = ctx.Attr("use_cudnn"); + use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + auto& dev_ctx = ctx.template device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif + framework::LibraryType library_; + if (use_cudnn) { + library_ = framework::LibraryType::kCUDNN; + } else { + library_ = framework::LibraryType::kPlain; + } + + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); +} + } // namespace operators } // namespace paddle @@ -523,7 +602,11 @@ REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker, ops::ConvTransposeGradOpMaker, ops::ConvTransposeGradOpMaker); -REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad); +REGISTER_OPERATOR( + conv2d_transpose_grad, ops::ConvTransposeOpGrad, + ops::ConvTransposeDoubleGradMaker, + ops::ConvTransposeDoubleGradMaker); +REGISTER_OPERATOR(conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad); REGISTER_OP_CPU_KERNEL( conv2d_transpose, diff --git a/paddle/fluid/operators/conv_transpose_op.cu b/paddle/fluid/operators/conv_transpose_op.cu index a6d5665df83ae5c89d42840e91a6abd853fedd12..b2a4910222f1178d23e94eade9580248bb103c88 100644 --- a/paddle/fluid/operators/conv_transpose_op.cu +++ b/paddle/fluid/operators/conv_transpose_op.cu @@ -24,6 +24,9 @@ REGISTER_OP_CUDA_KERNEL(conv2d_transpose, REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad, ops::GemmConvTransposeGradKernel, ops::GemmConvTransposeGradKernel); +REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad_grad, + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); // conv3d REGISTER_OP_CUDA_KERNEL(conv3d_transpose, diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index 1ea869e002af3ac8157321c66616b82517e4fabc..651719f1052806ad356f2bc8fd4c2f3a0abe210b 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -114,6 +114,16 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override; }; +class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + template class GemmConvTransposeKernel : public framework::OpKernel { public: diff --git a/python/paddle/fluid/tests/unittests/test_conv_transpose_nn_grad.py b/python/paddle/fluid/tests/unittests/test_conv_transpose_nn_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..110cfc47cae4126dddbb6c3c68c0fe2e3bb42def --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_conv_transpose_nn_grad.py @@ -0,0 +1,159 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +import gradient_checker + +from decorator_helper import prog_scope + + +class TestConvTransposeDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 4, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d_transpose( + x, 2, filter_size=1, groups=1, bias_attr=False) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [] + + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConvTranspose2DoubleGradCheck_AsyPadding( + TestConvTransposeDoubleGradCheck): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d_transpose( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 0, 0, 1], + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + +class TestConvTranspose2DoubleGradCheck_PaddingSAME( + TestConvTransposeDoubleGradCheck): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d_transpose( + input=x, + num_filters=2, + filter_size=1, + padding="SAME", + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + +class TestConvTranspose2DoubleGradCheck_PaddingVALID( + TestConvTransposeDoubleGradCheck): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d_transpose( + input=x, + num_filters=2, + filter_size=1, + padding="VALID", + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + +class TestConvTranspose2DoubleGradCheck_ChannelLast( + TestConvTransposeDoubleGradCheck): + @prog_scope() + def func(self, place): + shape = [2, 3, 3, 2] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d_transpose( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 1], + bias_attr=False, + use_cudnn=True, + groups=1, + data_format="NHWC") + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 05dfc9c621ee1ea37e437932b7b87884509da2e4..7d9f44f90503511083a16bcc77c850dac3cd002a 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -108,6 +108,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_conv3d_transpose_layer', 'test_conv3d_transpose_part2_op', 'test_conv_nn_grad', + 'test_conv_transpose_nn_grad', 'test_conv_shift_op', 'test_cos_sim_op', 'test_create_global_var',