diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 40fc14231c0526eeaf9222d7ccca9aadfb218930..37bbcbf11a9a94bbffc4cbb822f20d8d4a4f66dd 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -153,8 +153,8 @@ paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'moment paddle.fluid.layers.instance_norm (ArgSpec(args=['input', 'epsilon', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None)), ('document', '02972097e089629efdb0ed9404fd36ae')) paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', '2460b30fb87037555208fa8ac6fc1787')) paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '83e08f21af41ac8bac37aeab1f86fdd0')) -paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'ab58296b567bf0c686084add7f3280a4')) -paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'fe15dbfb17d97d3d29b2fa7ee6390ee6')) +paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCHW')), ('document', '9391d75358b6cba0cc5d22a01a223420')) +paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCDHW')), ('document', '74bce3cd4224e6ff133d54508dc7f150')) paddle.fluid.layers.sequence_expand (ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '10e122eb755c2bd1f78ef2332b28f1a0')) paddle.fluid.layers.sequence_expand_as (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '858c432e7cbd8bb952cc2eb555457d50')) paddle.fluid.layers.sequence_pad (ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'df08b9c499ab3a90f95d08ab5b6c6c62')) diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu b/paddle/fluid/operators/conv_transpose_cudnn_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..0dc80d8f29c32e3dd087100df5d3eee09b5e65b8 --- /dev/null +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu @@ -0,0 +1,586 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/conv_transpose_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/padding.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using DataLayout = platform::DataLayout; + +static constexpr size_t kConvCUDNNWorkspaceLimitBytes = 1024 * 1024 * 1024; + +template +static void DataTranspose(const framework::ExecutionContext& ctx, + const Tensor* input, Tensor* output, + const std::vector& axis, int flag = 0) { + auto& dev_ctx = ctx.template device_context(); + math::Transpose transpose; + auto in_dims = input->dims(); + std::vector input_transpose_vec; + for (size_t i = 0; i < axis.size(); ++i) { + if (flag == 0) + input_transpose_vec.push_back(in_dims[axis[i]]); + else + input_transpose_vec.push_back(in_dims[i]); + } + framework::DDim input_transpose_dims( + framework::make_ddim(input_transpose_vec)); + output->mutable_data(input_transpose_dims, ctx.GetPlace()); + transpose(dev_ctx, *input, output, axis); +} + +static inline bool IsSymmetricPadding(const std::vector& paddings, + const int data_dim) { + bool is_sys_pad = true; + if (paddings.size() == data_dim * 2) { + for (size_t i = 0; i < data_dim; ++i) { + if (paddings[2 * i] != paddings[2 * i + 1]) { + is_sys_pad = false; + return is_sys_pad; + } + } + } + return is_sys_pad; +} + +template +class CUDNNConvTransposeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + "It must use CUDAPlace."); + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::string padding_algorithm = ctx.Attr("padding_algorithm"); + + // cudnn v5 does not support dilations + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + const T* filter_data = filter->data(); + const std::string data_layout_str = ctx.Attr("data_format"); + const paddle::operators::DataLayout data_layout = + (data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC); + + // if channel_last, transpose to channel_first + Tensor input_transpose; + std::vector input_vec = framework::vectorize(input->dims()); + std::vector output_vec = framework::vectorize(output->dims()); + if (data_layout == DataLayout::kNHWC) { + if (strides.size() == 2U) { + std::vector axis = {0, 3, 1, 2}; + for (size_t i = 0; i < axis.size(); ++i) { + input_vec[i] = input->dims()[axis[i]]; + output_vec[i] = output->dims()[axis[i]]; + } + DataTranspose(ctx, input, &input_transpose, axis); + } else if (strides.size() == 3U) { + std::vector axis = {0, 4, 1, 2, 3}; + for (size_t i = 0; i < axis.size(); ++i) { + input_vec[i] = input->dims()[axis[i]]; + output_vec[i] = output->dims()[axis[i]]; + } + DataTranspose(ctx, input, &input_transpose, axis); + } + } else { + input_transpose = *input; + } + + // update padding and dilation + auto in_dims = input_transpose.dims(); + auto filter_dims = filter->dims(); + framework::DDim in_data_dims; + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = IsSymmetricPadding(paddings, data_dim); + + std::vector input_pad(input_transpose.dims().size() * 2, 0); + Tensor transformed_input; + std::vector padding_common(data_dim, 0); + if (!is_sys_pad) { + std::vector padding_diff(data_dim); + std::vector new_input_shape_vec(data_dim + 2); + new_input_shape_vec[0] = input_transpose.dims()[0]; + new_input_shape_vec[1] = input_transpose.dims()[1]; + + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + new_input_shape_vec[i + 2] = + input_transpose.dims()[i + 2] + padding_diff[i]; + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + framework::DDim new_input_shape( + framework::make_ddim(new_input_shape_vec)); + transformed_input.Resize(new_input_shape); + auto& dev_ctx = + ctx.template device_context(); + + transformed_input = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + const int rank = input_transpose.dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + math::PadFunction( + ctx, input_pad, input_transpose, pad_value, &transformed_input); + } break; + case 5: { + math::PadFunction( + ctx, input_pad, input_transpose, pad_value, &transformed_input); + } break; + default: + PADDLE_ENFORCE_EQ( + rank == 4 || rank == 5, true, + "Op(ConvTranspose) only supports 4-D or 5-D input Tensor."); + } + } else { + transformed_input = input_transpose; + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + std::vector starts(data_dim, 0); + std::vector ends(data_dim, 0); + std::vector axes(data_dim, 0); + for (size_t i = 0; i < data_dim; ++i) { + starts[i] = input_pad[2 * i + 4] * (strides[i] + 1); + ends[i] = starts[i] + output_vec[i + 2]; + axes[i] = i + 2; + } + + const T* input_data = transformed_input.data(); + input_vec = framework::vectorize(transformed_input.dims()); + + std::vector transformed_output_vec = output_vec; + for (size_t i = 0; i < data_dim; ++i) { + transformed_output_vec[i + 2] = + output_vec[i + 2] + + (input_pad[2 * i + 4] + input_pad[2 * i + 5]) * strides[i] - + 2 * padding_common[i] + paddings[2 * i] + paddings[2 * i + 1]; + } + + Tensor transformed_output; + if (!is_sys_pad) { + DDim transformed_output_shape( + framework::make_ddim(transformed_output_vec)); + transformed_output.mutable_data(transformed_output_shape, + ctx.GetPlace()); + } else { + output->mutable_data(ctx.GetPlace()); + transformed_output.ShareDataWith(*output); + transformed_output.Resize(framework::make_ddim(transformed_output_vec)); + } + T* transformed_output_data = transformed_output.data(); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } + + // (N, M, H, W) or (N, M, D, H, W) + cudnnTensorDescriptor_t cudnn_input_desc = + input_desc.descriptor(layout, input_vec, groups); + // (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) + cudnnTensorDescriptor_t cudnn_output_desc = + output_desc.descriptor(layout, transformed_output_vec, groups); + // (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w) + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize(filter->dims()), groups); + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(padding_common, strides, dilations); + + // ------------------- cudnn conv workspace --------------------- + size_t workspace_size_in_bytes; // final workspace to allocate. + size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + // ------------------- cudnn conv algorithm --------------------- + cudnnConvolutionBwdDataAlgo_t algo; + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + // Get the algorithm + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + // dxDesc: Handle to the previously initialized output tensor + // descriptor. + cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + + // get workspace size able to allocate + CUDNN_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_output_desc, algo, &workspace_size_in_bytes)); + + // ------------------- cudnn conv transpose forward --------------------- + int input_offset = + transformed_input.numel() / transformed_input.dims()[0] / groups; + int output_offset = + transformed_output.numel() / transformed_output.dims()[0] / groups; + int filter_offset = filter->numel() / groups; + T alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + for (int g = 0; g < groups; g++) { + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, + cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, + algo, cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_output_desc, transformed_output_data + output_offset * g)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); + } + + if (!is_sys_pad && strides.size() == 2U) { + Slice( + ctx, &transformed_output, output, starts, ends, axes); + } else if (!is_sys_pad && strides.size() == 3U) { + Slice( + ctx, &transformed_output, output, starts, ends, axes); + } + + if (data_layout == DataLayout::kNHWC) { + Tensor output_transpose; + Tensor output_nchw; + output_nchw.ShareDataWith(*output); + output_nchw.Resize(framework::make_ddim(output_vec)); + if (strides.size() == 2U) { + std::vector axis = {0, 2, 3, 1}; + DataTranspose(ctx, &output_nchw, &output_transpose, axis); + *output = output_transpose; + } else if (strides.size() == 3U) { + std::vector axis = {0, 2, 3, 4, 1}; + DataTranspose(ctx, &output_nchw, &output_transpose, axis); + *output = output_transpose; + } + } + } +}; + +template +class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto input = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto filter_grad = ctx.Output(framework::GradVarName("Filter")); + const T* filter_data = filter->data(); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + // cudnn v5 does not support dilations + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + std::string padding_algorithm = ctx.Attr("padding_algorithm"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + const std::string data_layout_str = ctx.Attr("data_format"); + const paddle::operators::DataLayout data_layout = + (data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC); + + // if channel_last, transpose to channel_first + Tensor input_transpose; + Tensor output_grad_transpose; + std::vector input_vec = framework::vectorize(input->dims()); + std::vector output_vec = + framework::vectorize(output_grad->dims()); + if (data_layout == DataLayout::kNHWC) { + if (strides.size() == 2U) { + std::vector axis = {0, 3, 1, 2}; + for (size_t i = 0; i < axis.size(); ++i) { + input_vec[i] = input->dims()[axis[i]]; + output_vec[i] = output_grad->dims()[axis[i]]; + } + DataTranspose(ctx, input, &input_transpose, axis); + DataTranspose(ctx, output_grad, &output_grad_transpose, axis); + } else if (strides.size() == 3U) { + std::vector axis = {0, 4, 1, 2, 3}; + for (size_t i = 0; i < axis.size(); ++i) { + input_vec[i] = input->dims()[axis[i]]; + output_vec[i] = output_grad->dims()[axis[i]]; + } + DataTranspose(ctx, input, &input_transpose, axis); + DataTranspose(ctx, output_grad, &output_grad_transpose, axis); + } + } else { + input_transpose = *input; + output_grad_transpose = *output_grad; + } + + // update padding and dilation + auto in_dims = input_transpose.dims(); + auto filter_dims = filter->dims(); + framework::DDim in_data_dims; + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = IsSymmetricPadding(paddings, data_dim); + + std::vector input_pad(input_transpose.dims().size() * 2, 0); + Tensor transformed_output_grad; + std::vector padding_common(data_dim, 0); + if (!is_sys_pad) { + std::vector padding_diff(data_dim); + std::vector new_output_grad_shape_vec(data_dim + 2); + new_output_grad_shape_vec[0] = output_grad_transpose.dims()[0]; + new_output_grad_shape_vec[1] = output_grad_transpose.dims()[1]; + + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + new_output_grad_shape_vec[i + 2] = + output_grad_transpose.dims()[i + 2] + padding_diff[i]; + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + framework::DDim new_output_grad_shape( + framework::make_ddim(new_output_grad_shape_vec)); + transformed_output_grad.Resize(new_output_grad_shape); + auto& dev_ctx = + ctx.template device_context(); + + transformed_output_grad = + ctx.AllocateTmpTensor( + new_output_grad_shape, dev_ctx); + const int rank = input_transpose.dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + math::PadFunction( + ctx, input_pad, output_grad_transpose, pad_value, + &transformed_output_grad); + } break; + case 5: { + math::PadFunction( + ctx, input_pad, output_grad_transpose, pad_value, + &transformed_output_grad); + } break; + default: + PADDLE_ENFORCE_EQ( + rank == 4 || rank == 5, true, + "Op(ConvTranspose) only supports 4-D or 5-D input Tensor."); + } + } else { + transformed_output_grad = output_grad_transpose; + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + const T* input_data = input_transpose.data(); + const T* output_grad_data = transformed_output_grad.data(); + output_vec = framework::vectorize(transformed_output_grad.dims()); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } + + // Input: (N, M, H, W) or (N, M, D, H, W) + cudnnTensorDescriptor_t cudnn_input_desc = + input_desc.descriptor(layout, input_vec, groups); + // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) + cudnnTensorDescriptor_t cudnn_output_desc = + output_desc.descriptor(layout, output_vec, groups); + // Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w) + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize(filter->dims()), groups); + + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(padding_common, strides, dilations); + + // ------------------- cudnn backward algorithm --------------------- + cudnnConvolutionFwdAlgo_t data_algo; + cudnnConvolutionBwdFilterAlgo_t filter_algo; + size_t bwd_filter_ws_size, fwd_ws_size; + size_t workspace_size_in_bytes = 0; + size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + if (input_grad) { + // choose backward algorithm for data + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &data_algo)); + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_input_desc, data_algo, &fwd_ws_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size); + } + + if (filter_grad) { + // choose backward algorithm for filter + CUDNN_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_filter_desc, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &filter_algo)); + + // get workspace for backwards filter algorithm + CUDNN_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_filter_desc, filter_algo, &bwd_filter_ws_size)); + workspace_size_in_bytes = + std::max(workspace_size_in_bytes, bwd_filter_ws_size); + } + + // ------------------- cudnn conv backward data --------------------- + // FIXME(typhoonzero): template type T may not be the same as cudnn call. + int input_offset = input->numel() / input->dims()[0] / groups; + int output_grad_offset = transformed_output_grad.numel() / + transformed_output_grad.dims()[0] / groups; + int filter_offset = filter->numel() / groups; + T alpha = 1.0f, beta = 0.0f; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + if (input_grad) { + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + // Because beta is zero, it is unnecessary to reset input_grad. + for (int g = 0; g < groups; g++) { + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_filter_desc, + filter_data + filter_offset * g, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data + input_offset * g)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); + } + + if (data_layout == DataLayout::kNHWC) { + Tensor input_grad_transpose; + Tensor input_grad_nchw; + input_grad_nchw.ShareDataWith(*input_grad); + input_grad_nchw.Resize(framework::make_ddim(input_vec)); + if (strides.size() == 2U) { + std::vector axis = {0, 2, 3, 1}; + DataTranspose(ctx, &input_grad_nchw, &input_grad_transpose, + axis); + *input_grad = input_grad_transpose; + } else if (strides.size() == 3U) { + std::vector axis = {0, 2, 3, 4, 1}; + DataTranspose(ctx, &input_grad_nchw, &input_grad_transpose, + axis); + *input_grad = input_grad_transpose; + } + } + } + + // ------------------- cudnn conv backward filter --------------------- + if (filter_grad) { + T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + // Because beta is zero, it is unnecessary to reset filter_grad. + // Gradient with respect to the filter + for (int g = 0; g < groups; g++) { + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_input_desc, + input_data + input_offset * g, cudnn_conv_desc, filter_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_filter_desc, filter_grad_data + filter_offset * g)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeOpKernel, + ops::CUDNNConvTransposeOpKernel); +REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeGradOpKernel, + ops::CUDNNConvTransposeGradOpKernel); + +REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeOpKernel, + ops::CUDNNConvTransposeOpKernel); +REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, + ops::CUDNNConvTransposeGradOpKernel, + ops::CUDNNConvTransposeGradOpKernel); diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc deleted file mode 100644 index bab6fe24e42f15e2703a977d1500bc63f343e79c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ /dev/null @@ -1,265 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/conv_transpose_op.h" -#include "paddle/fluid/platform/cudnn_helper.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; -using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; -using DataLayout = platform::DataLayout; - -static constexpr size_t kConvCUDNNWorkspaceLimitBytes = 1024 * 1024 * 1024; - -template -class CUDNNConvTransposeOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto* input = ctx.Input("Input"); - auto* filter = ctx.Input("Filter"); - auto* output = ctx.Output("Output"); - - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - // cudnn v5 does not support dilations - std::vector dilations = ctx.Attr>("dilations"); - int groups = ctx.Attr("groups"); - int user_workspace_size = ctx.Attr("workspace_size_MB"); - - const T* input_data = input->data(); - const T* filter_data = filter->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); - // ------------------- cudnn descriptors --------------------- - ScopedTensorDescriptor input_desc; - ScopedTensorDescriptor output_desc; - ScopedFilterDescriptor filter_desc; - ScopedConvolutionDescriptor conv_desc; - DataLayout layout; - - if (strides.size() == 2U) { - layout = DataLayout::kNCHW; - } else { - layout = DataLayout::kNCDHW; - } - - // (N, M, H, W) or (N, M, D, H, W) - cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - layout, framework::vectorize(input->dims()), groups); - // (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) - cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( - layout, framework::vectorize(output->dims()), groups); - // (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w) - cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( - layout, framework::vectorize(filter->dims()), groups); - cudnnConvolutionDescriptor_t cudnn_conv_desc = - conv_desc.descriptor(paddings, strides, dilations); - - // ------------------- cudnn conv workspace --------------------- - size_t workspace_size_in_bytes; // final workspace to allocate. - size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes; - if (user_workspace_size > 0) { - workspace_size_limit = user_workspace_size * 1024 * 1024; - } - // ------------------- cudnn conv algorithm --------------------- - cudnnConvolutionBwdDataAlgo_t algo; - auto& dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - // Get the algorithm - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( - handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, - // dxDesc: Handle to the previously initialized output tensor - // descriptor. - cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); - - // get workspace size able to allocate - CUDNN_ENFORCE( - platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( - handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, - cudnn_output_desc, algo, &workspace_size_in_bytes)); - - // ------------------- cudnn conv transpose forward --------------------- - int input_offset = input->numel() / input->dims()[0] / groups; - int output_offset = output->numel() / output->dims()[0] / groups; - int filter_offset = filter->numel() / groups; - T alpha = 1.0f, beta = 0.0f; - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, - cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, - algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_output_desc, output_data + output_offset * g)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); - } - } -}; - -template -class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto input = ctx.Input("Input"); - auto filter = ctx.Input("Filter"); - auto output_grad = ctx.Input(framework::GradVarName("Output")); - auto input_grad = ctx.Output(framework::GradVarName("Input")); - auto filter_grad = ctx.Output(framework::GradVarName("Filter")); - const T* input_data = input->data(); - const T* output_grad_data = output_grad->data(); - const T* filter_data = filter->data(); - - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - // cudnn v5 does not support dilations - std::vector dilations = ctx.Attr>("dilations"); - int groups = ctx.Attr("groups"); - int user_workspace_size = ctx.Attr("workspace_size_MB"); - - // ------------------- cudnn descriptors --------------------- - ScopedTensorDescriptor input_desc; - ScopedTensorDescriptor output_desc; - ScopedFilterDescriptor filter_desc; - ScopedConvolutionDescriptor conv_desc; - DataLayout layout = DataLayout::kNCHW; - - // Input: (N, M, H, W) or (N, M, D, H, W) - cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - layout, framework::vectorize(input->dims()), groups); - // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) - cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( - layout, framework::vectorize(output_grad->dims()), groups); - // Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w) - cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( - layout, framework::vectorize(filter->dims()), groups); - - cudnnConvolutionDescriptor_t cudnn_conv_desc = - conv_desc.descriptor(paddings, strides, dilations); - - // ------------------- cudnn backward algorithm --------------------- - cudnnConvolutionFwdAlgo_t data_algo; - cudnnConvolutionBwdFilterAlgo_t filter_algo; - size_t bwd_filter_ws_size, fwd_ws_size; - size_t workspace_size_in_bytes = 0; - size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes; - if (user_workspace_size > 0) { - workspace_size_limit = user_workspace_size * 1024 * 1024; - } - - auto& dev_ctx = ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - if (input_grad) { - // choose backward algorithm for data - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &data_algo)); - CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( - handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, - cudnn_input_desc, data_algo, &fwd_ws_size)); - workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size); - } - - if (filter_grad) { - // choose backward algorithm for filter - CUDNN_ENFORCE( - platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( - handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, - cudnn_filter_desc, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &filter_algo)); - - // get workspace for backwards filter algorithm - CUDNN_ENFORCE( - platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( - handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, - cudnn_filter_desc, filter_algo, &bwd_filter_ws_size)); - workspace_size_in_bytes = - std::max(workspace_size_in_bytes, bwd_filter_ws_size); - } - - // ------------------- cudnn conv backward data --------------------- - // FIXME(typhoonzero): template type T may not be the same as cudnn call. - int input_offset = input->numel() / input->dims()[0] / groups; - int output_grad_offset = - output_grad->numel() / output_grad->dims()[0] / groups; - int filter_offset = filter->numel() / groups; - T alpha = 1.0f, beta = 0.0f; - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - if (input_grad) { - T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - // Because beta is zero, it is unnecessary to reset input_grad. - for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_output_desc, - output_grad_data + output_grad_offset * g, cudnn_filter_desc, - filter_data + filter_offset * g, cudnn_conv_desc, data_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, - input_grad_data + input_offset * g)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); - } - } - - // ------------------- cudnn conv backward filter --------------------- - if (filter_grad) { - T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); - // Because beta is zero, it is unnecessary to reset filter_grad. - // Gradient with respect to the filter - for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_output_desc, - output_grad_data + output_grad_offset * g, cudnn_input_desc, - input_data + input_offset * g, cudnn_conv_desc, filter_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_filter_desc, filter_grad_data + filter_offset * g)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace, - ops::CUDNNConvTransposeOpKernel, - ops::CUDNNConvTransposeOpKernel); -REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, - ops::CUDNNConvTransposeGradOpKernel, - ops::CUDNNConvTransposeGradOpKernel); - -REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace, - ops::CUDNNConvTransposeOpKernel, - ops::CUDNNConvTransposeOpKernel); -REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace, - ops::CUDNNConvTransposeGradOpKernel, - ops::CUDNNConvTransposeGradOpKernel); diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index e76c57abc6300d845908a9c6db939747d17ca289..3758a3c0798d19948576dab1ee140dcb3a79e17f 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/platform/cudnn_workspace_helper.h" #ifdef PADDLE_WITH_MKLDNN @@ -25,13 +26,15 @@ limitations under the License. */ namespace paddle { namespace operators { +using DataLayout = framework::DataLayout; + void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of ConvTransposeOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of ConvTransposeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of ConvTransposeOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, + "Input(Input) of ConvTransposeOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true, + "Input(Filter) of ConvTransposeOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true, + "Output(Output) of ConvTransposeOp should not be null."); auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); @@ -41,52 +44,75 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { std::vector paddings = ctx->Attrs().Get>("paddings"); std::vector dilations = ctx->Attrs().Get>("dilations"); int groups = ctx->Attrs().Get("groups"); + std::string padding_algorithm = + ctx->Attrs().Get("padding_algorithm"); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_format")); - PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, - "ConvTransposeOp intput should be 4-D or 5-D tensor."); + PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true, + "ConvTransposeOp intput should be 4-D or 5-D tensor."); PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(), "ConvTransposeOp input dimension and filter dimension " "should be the same."); - PADDLE_ENFORCE(in_dims.size() - strides.size() == 2U, - "ConvTransposeOp input dimension and strides dimension should " - "be consistent."); + PADDLE_ENFORCE_EQ( + in_dims.size() - strides.size(), 2U, + "ConvTransposeOp input dimension and strides dimension should " + "be consistent."); if (output_size.size()) PADDLE_ENFORCE_EQ(output_size.size(), strides.size(), "ConvTransposeOp output_size dimension and strides " "dimension should be the same."); - PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), - "ConvTransposeOp paddings dimension and strides " - "dimension should be the same."); - PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(), - "ConvTransposeOp paddings dimension and dilations " - "dimension should be the same."); - PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], - "In ConvTransposeOp, The number of input channels should " - "be equal to the number of filter's channels."); - - std::vector output_shape({in_dims[0], filter_dims[1] * groups}); + + const int64_t C = + (data_layout == DataLayout::kNCHW ? in_dims[1] + : in_dims[in_dims.size() - 1]); + PADDLE_ENFORCE_EQ( + C, filter_dims[0], + "The number of input channels of Op(ConvTransposeOp) should " + "be equal to the number of filter's channels."); + + framework::DDim in_data_dims; + if (data_layout == DataLayout::kNCHW) { + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + } else { + in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); + } + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + std::vector output_shape({in_dims[0]}); + if (data_layout == DataLayout::kNCHW) { + output_shape.push_back(filter_dims[1] * groups); + } + const int offset = (data_layout == DataLayout::kNCHW ? 2 : 1); for (size_t i = 0; i < strides.size(); ++i) { auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; - auto infer_shape = - (in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + filter_extent; + auto infer_shape = (in_dims[i + offset] - 1) * strides[i] - + paddings[2 * i] - paddings[2 * i + 1] + filter_extent; if (output_size.size()) { - PADDLE_ENFORCE((output_size[i] >= infer_shape && - output_size[i] < infer_shape + strides[i]), - "ConvTransposeOp output_size should be " - "in appropriate range."); + PADDLE_ENFORCE_EQ((output_size[i] >= infer_shape && + output_size[i] < infer_shape + strides[i]), + true, + "output_size of Op(ConvTransposeOp) should be " + "in appropriate range."); output_shape.push_back(output_size[i]); } else { output_shape.push_back(infer_shape); } } + if (data_layout == DataLayout::kNHWC) { + output_shape.push_back(filter_dims[1] * groups); + } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); #ifdef PADDLE_WITH_CUDA @@ -115,12 +141,11 @@ void Conv2DTransposeOpMaker::Make() { "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddInput( - "Input", - "(Tensor) The input tensor of convolution transpose operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of input channels, H is the height of the feature, and " - "W is the width of the feature."); + AddInput("Input", + "(Tensor) The input tensor of convolution transpose operator. " + "The format of input tensor is NCHW or NHWC. Where N is batch size, " + "C is the number of input channels, H is the height of the feature, " + "and W is the width of the feature."); AddInput( "Filter", "(Tensor) The filter tensor of convolution transpose operator. " @@ -137,7 +162,7 @@ void Conv2DTransposeOpMaker::Make() { AddOutput("Output", "(Tensor) The output tensor of convolution transpose operator. " - "The format of output tensor is also NCHW."); + "The format of output tensor is the same as input tensor."); AddAttr>("output_size", "(vector default: []), the " "size of the output tensor") @@ -182,10 +207,15 @@ void Conv2DTransposeOpMaker::Make() { "data_format", "(string, default NCHW) Only used in " "An optional string from: \"NHWC\", \"NCHW\". " - "Defaults to \"NHWC\". Specify the data format of the output data, " - "the input will be transformed automatically. ") - .SetDefault("AnyLayout"); - // TODO(dzhwinter): need to registered layout transform function + "Specify that the data format of the input and output data is " + "channel_first or channel_last.") + .SetDefault("NCHW"); + AddAttr( + "padding_algorithm", + "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\"," + "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. " + "Set to \"SAME\" or \"VALID\" for algorithm of padding. ") + .SetDefault("EXPLICIT"); AddAttr("workspace_size_MB", "Used in cudnn kernel only. workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " @@ -199,7 +229,7 @@ Convolution2D Transpose Operator. The convolution transpose operation calculates the output based on the input, filter and dilations, strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. -Input(Input) and output(Output) are in NCHW format. Where N is batchsize, C is the +Input(Input) and output(Output) are in NCHW or NHWC format. Where N is batchsize, C is the number of channels, H is the height of the feature, and W is the width of the feature. Filter(Input) is in MCHW format. Where M is the number of input feature channels, C is the number of output feature channels, H is the height of the filter, @@ -216,19 +246,19 @@ For an example: Output shape: $(N, C_{out}, H_{out}, W_{out})$ Where $$ - H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\ - W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 + H_{out} = (H_{in} - 1) * strides[0] - pad_height_top - pad_height_bottom + dilations[0] * (H_f - 1) + 1 \\ + W_{out} = (W_{in} - 1) * strides[1] - pad_width_left - pad_width_right + dilations[1] * (W_f - 1) + 1 $$ )DOC"); } void Conv3DTransposeOpMaker::Make() { - AddInput("Input", - "(Tensor) The input tensor of convolution transpose operator." - "The format of input tensor is NCDHW. Where N is batch size, C is " - "the number of channels, D is the depth of the feature, H is the " - "height of the feature, and " - "W is the width of the feature."); + AddInput( + "Input", + "(Tensor) The input tensor of convolution transpose operator." + "The format of input tensor is NCDHW or NDHWC. Where N is batch " + "size, C is the number of channels, D is the depth of the feature, " + "H is the height of the feature, and W is the width of the feature."); AddInput("Filter", "(Tensor) The filter tensor of convolution transpose operator." "The format of the filter tensor is MCDHW, where M is the number of " @@ -240,7 +270,7 @@ void Conv3DTransposeOpMaker::Make() { "the convolution3d transpose scenario."); AddOutput("Output", "(Tensor) The output tensor of convolution transpose operator." - "The format of output tensor is also NCDHW." + "The format of output tensor is the same as input tensor." "Where N is batch size, C is " "the number of channels, D is the depth of the feature, H is the " "height of the feature, and W is the width of the feature."); @@ -278,10 +308,15 @@ void Conv3DTransposeOpMaker::Make() { "data_format", "(string, default NCHW) Only used in " "An optional string from: \"NHWC\", \"NCHW\". " - "Defaults to \"NHWC\". Specify the data format of the output data, " - "the input will be transformed automatically. ") - .SetDefault("AnyLayout"); - // TODO(dzhwinter): need to registered layout transform function + "Specify that the data format of the input and output data is " + "channel_first or channel_last.") + .SetDefault("NCHW"); + AddAttr( + "padding_algorithm", + "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\"," + "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. " + "Set to \"SAME\" or \"VALID\" for algorithm of padding. ") + .SetDefault("EXPLICIT"); AddAttr("workspace_size_MB", "Used in cudnn kernel only. workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " @@ -295,7 +330,7 @@ Convolution3D Transpose Operator. The convolution transpose operation calculates the output based on the input, filter and dilations, strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. -Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the +Input(Input) and output(Output) are in NCDHW or NDHWC format. Where N is batch size, C is the number of channels, D is the depth of the feature, H is the height of the feature, and W is the width of the feature. Filter(Input) is in MCDHW format. Where M is the number of input feature channels, @@ -313,9 +348,9 @@ Example: Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$ Where $$ - D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (D_f - 1) + 1 \\ - H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (H_f - 1) + 1 \\ - W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + dilations[2] * (W_f - 1) + 1 + D_{out} = (D_{in} - 1) * strides[0] - pad_depth_front - pad_depth_back + dilations[0] * (D_f - 1) + 1 \\ + H_{out} = (H_{in} - 1) * strides[1] - pad_height_top - pad_height_bottom + dilations[1] * (H_f - 1) + 1 \\ + W_{out} = (W_{in} - 1) * strides[2] - pad_width_left - pad_width_right + dilations[2] * (W_f - 1) + 1 $$ )DOC"); } @@ -348,8 +383,7 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( library_ = framework::LibraryType::kPlain; } - std::string data_format = ctx.Attr("data_format"); - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; return framework::OpKernelType(ctx.Input("Input")->type(), ctx.GetPlace(), layout_, library_); } diff --git a/paddle/fluid/operators/conv_transpose_op.cu.cc b/paddle/fluid/operators/conv_transpose_op.cu similarity index 100% rename from paddle/fluid/operators/conv_transpose_op.cu.cc rename to paddle/fluid/operators/conv_transpose_op.cu diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index 88c578b1410558b9adcd55f1cd6b53fb9cb124e2..56cfa8618f2a1cb1616f245c768d6222c3d08ea0 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/vol2col.h" @@ -27,6 +30,94 @@ namespace operators { using Tensor = framework::Tensor; using DDim = framework::DDim; +template +static void Slice(const framework::ExecutionContext& context, + const Tensor* input, Tensor* out, + const std::vector& begin_vec, + const std::vector& end_vec, + const std::vector& axes_vec) { + auto& place = + *context.template device_context().eigen_device(); + auto in_dims = input->dims(); + auto offsets = Eigen::array(); + auto extents = Eigen::array(); + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = in_dims[i]; + } + + std::vector out_shape_vec = framework::vectorize(in_dims); + for (size_t i = 0; i < axes_vec.size(); ++i) { + offsets[axes_vec[i]] = begin_vec[i]; + extents[axes_vec[i]] = end_vec[i] - begin_vec[i]; + out_shape_vec[axes_vec[i]] = end_vec[i] - begin_vec[i]; + } + + framework::DDim out_dims(framework::make_ddim(out_shape_vec)); + out->mutable_data(out_dims, context.GetPlace()); + + auto in_t = + framework::EigenTensor::From( + *input); + auto out_t = + framework::EigenTensor::From( + *out, out_dims); + + out_t.device(place) = in_t.slice(offsets, extents); + out->Resize(out_dims); +} + +template +static void Slice(const framework::ExecutionContext& context, + const Tensor* input, Tensor* out, int64_t begin_idx, + int64_t end_idx, int64_t axes) { + std::vector begin_vec = {begin_idx}; + std::vector end_vec = {end_idx}; + std::vector axes_vec = {axes}; + Slice(context, input, out, begin_vec, end_vec, axes_vec); +} + +inline void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilation, + const std::string padding_algorithm, + const framework::DDim data_dims, + const std::vector& strides, + const std::vector& ksize) { + // set padding size == data_dims.size() * 2 + auto data_shape = framework::vectorize(data_dims); + if (paddings->size() == data_dims.size()) { + for (size_t i = 0; i < data_dims.size(); ++i) { + int copy_pad = *(paddings->begin() + 2 * i); + paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); + } + } else { + PADDLE_ENFORCE_EQ( + data_dims.size() * 2, paddings->size(), + "Paddings size should be the same or twice as the input data size."); + } + + // when padding_algorithm is "VALID" or "SAME" + if (padding_algorithm == "SAME") { + for (size_t i = 0; i < data_dims.size(); ++i) { + int out_size = (data_dims[i] + strides[i] - 1) / strides[0]; + int pad_sum = + std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0); + int pad_0 = pad_sum / 2; + int pad_1 = pad_sum - pad_0; + *(paddings->begin() + i * 2) = pad_0; + *(paddings->begin() + i * 2 + 1) = pad_1; + + // dilation + *(dilation->begin() + i) = 1; + } + + } else if (padding_algorithm == "VALID") { + for (auto it = paddings->begin(); it != paddings->end(); it++) { + *it = 0; + } + } +} + // Define Op classes in .h file so that other conv transpose // operator implementations can reuse the code. class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { @@ -63,6 +154,10 @@ template class GemmConvTransposeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + const std::string data_layout_str = + context.Attr("data_format"); + const framework::DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const Tensor* input = context.Input("Input"); // The filter will be reshaped, so it should not be constant pointer Tensor filter = *context.Input("Filter"); @@ -72,28 +167,54 @@ class GemmConvTransposeKernel : public framework::OpKernel { std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); int groups = context.Attr("groups"); + std::string padding_algorithm = + context.Attr("padding_algorithm"); + auto in_dims = input->dims(); + auto filter_dims = filter.dims(); + auto out_dims = output->dims(); const int batch_size = static_cast(input->dims()[0]); - // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} + framework::DDim in_data_dims; + if (data_layout == framework::DataLayout::kNCHW) { + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + } else { + in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); + } + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first + // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last std::vector input_shape_vec = framework::vectorize(input->dims()); - // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w} + // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w} std::vector filter_shape_vec = framework::vectorize(filter.dims()); // use col_shape in the im2col and col2im (or vol2col and col2vol) // calculation - // col_shape_vec: {c/g, k_h, k_w, h, w} or {c/g, k_d, k_h, k_w, d, h, w} + // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w} size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = output->dims()[1] / groups; - for (size_t j = 0; j < data_dim; ++j) { - col_shape_vec[j + 1] = filter_shape_vec[j + 2]; - col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2]; + if (data_layout == framework::DataLayout::kNCHW) { + col_shape_vec[0] = out_dims[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2]; + } + } else { + col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1]; + } } DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation - // size: (c/g * k_h * k_w, h * w) or (c/g * k_d * k_h * k_w, d * h * w) + // size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w) DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); Tensor col; @@ -105,15 +226,27 @@ class GemmConvTransposeKernel : public framework::OpKernel { col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) + // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first + // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last DDim output_shape = framework::slice_ddim(output->dims(), 1, output->dims().size()); - // input matrix size: (m, h * w) or (m, d * h * w) - DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]}; + // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first + // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last + DDim input_matrix_shape; + if (data_layout == framework::DataLayout::kNCHW) { + input_matrix_shape = {in_dims[1], col_matrix_shape[1]}; + } else { + input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]}; + } - // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w) - DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]}; + // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w) + DDim filter_matrix_shape; + if (data_layout == framework::DataLayout::kNCHW) { + filter_matrix_shape = {in_dims[1], col_matrix_shape[0]}; + } else { + filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]}; + } filter.Resize(filter_matrix_shape); output->mutable_data(context.GetPlace()); @@ -122,43 +255,84 @@ class GemmConvTransposeKernel : public framework::OpKernel { auto blas = math::GetBlas(dev_ctx); set_zero(dev_ctx, output, static_cast(0)); - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output->dims()[1]) / groups; + int in_step = + (data_layout == framework::DataLayout::kNCHW + ? static_cast(in_dims[1]) / groups + : static_cast(in_dims[in_dims.size() - 1]) / groups); + + int out_step = + (data_layout == framework::DataLayout::kNCHW + ? static_cast(out_dims[1]) / groups + : static_cast(out_dims[out_dims.size() - 1]) / groups); math::Col2ImFunctor col2im; math::Col2VolFunctor col2vol; + math::ConcatFunctor concat_functor; // convolution transpose: gemm + col2im or col2vol (similar to conv-backward // on input) + size_t D = input->dims().size(); for (int i = 0; i < batch_size; i++) { - // batch with size (m, h * w) or (m, d * h * w) + // batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first + // batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) + // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first + // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); + std::vector output_batch_vec; for (int g = 0; g < groups; g++) { - Tensor in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step); + int64_t start = g * in_step; + int64_t end = (g + 1) * in_step; + int axes = (data_layout == framework::DataLayout::kNCHW ? 0 : 1); Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step); - Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice, out_slice; // col_matrix = filter_slice * input_slice - // of shape (c/g * k_h * k_w, h * w) - // or (c/g * k_d * k_h * k_w, d * h * w) - blas.MatMul(filter_slice, true, in_slice, false, static_cast(1.0), - &col_matrix, static_cast(0.0)); + // of shape (o_c/g * k_h * k_w, h * w) + // or (o_c/g * k_d * k_h * k_w, d * h * w) + if (data_layout == framework::DataLayout::kNCHW) { + in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step); + out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step); + blas.MatMul(filter_slice, true, in_slice, false, static_cast(1.0), + &col_matrix, static_cast(0.0)); + } else { + Slice(context, &input_batch, &in_slice, start, + end, axes); + start = g * out_step; + end = (g + 1) * out_step; + axes = D - 2; + if (D == 4U) { + Slice(context, &output_batch, &out_slice, + start, end, axes); + } else if (D == 5U) { + Slice(context, &output_batch, &out_slice, + start, end, axes); + } + blas.MatMul(filter_slice, true, in_slice, true, static_cast(1.0), + &col_matrix, static_cast(0.0)); + } if (data_dim == 2U) { // col2im: col_matrix -> dy - // from (c/g * k_h * k_w, h * w) to (c/g, o_h, o_w) + // from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w, + // o_c/g) col2im(dev_ctx, col, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &out_slice); + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, + &out_slice, data_layout); } else if (data_dim == 3U) { // col2vol: col_matrix -> dy - // from (c/g * k_d * k_h * k_w, d * h * w) to (c/g, o_d, o_h, o_w) - col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice); + // from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w) + // or (o_d, o_h, o_w, o_c/g) + col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice, + data_layout); } + output_batch_vec.push_back(out_slice); + } + if (data_layout == framework::DataLayout::kNHWC) { + concat_functor(dev_ctx, output_batch_vec, static_cast(D - 2), + &output_batch); } } } @@ -168,6 +342,10 @@ template class GemmConvTransposeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + const std::string data_layout_str = + context.Attr("data_format"); + const framework::DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const Tensor* input = context.Input("Input"); const Tensor* output_grad = context.Input(framework::GradVarName("Output")); @@ -185,41 +363,84 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); int groups = context.Attr("groups"); + std::string padding_algorithm = + context.Attr("padding_algorithm"); + auto in_dims = input->dims(); + auto filter_dims = filter.dims(); + auto out_grad_dims = output_grad->dims(); const int batch_size = static_cast(input->dims()[0]); - // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} + framework::DDim in_data_dims; + if (data_layout == framework::DataLayout::kNCHW) { + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + } else { + in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); + } + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first + // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last std::vector input_shape_vec = framework::vectorize(input->dims()); - // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w} + // filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w} std::vector filter_shape_vec = framework::vectorize(filter.dims()); // use col_shape in the im2col and col2im (or vol2col and col2vol) // calculation - // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w} + // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = output_grad->dims()[1]; - for (size_t j = 0; j < data_dim; ++j) { - col_shape_vec[j + 1] = filter_shape_vec[j + 2]; - col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2]; + if (data_layout == framework::DataLayout::kNCHW) { + col_shape_vec[0] = out_grad_dims[1]; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2]; + } + } else { + col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1]; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1]; + } } DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation - // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) + // size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w) DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); - // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) + // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first + // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last DDim output_shape = framework::slice_ddim(output_grad->dims(), 1, output_grad->dims().size()); - // input matrix size: (m, h * w) or (m, d * h * w) - DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]}; + // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first + // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last + DDim input_matrix_shape; + if (data_layout == framework::DataLayout::kNCHW) { + input_matrix_shape = {in_dims[1], col_matrix_shape[1]}; + } else { + input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]}; + } - // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w) - DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0] / groups}; + // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w) + DDim filter_matrix_shape; + if (data_layout == framework::DataLayout::kNCHW) { + filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups}; + } else { + filter_matrix_shape = {in_dims[in_dims.size() - 1], + col_matrix_shape[0] / groups}; + } filter.Resize(filter_matrix_shape); - int in_step = static_cast(input->dims()[1]) / groups; + + int in_step = + (data_layout == framework::DataLayout::kNCHW + ? static_cast(in_dims[1]) / groups + : static_cast(in_dims[in_dims.size() - 1]) / groups); int col_step = static_cast(col_matrix_shape[0]) / groups; // convolution transpose grad on input: @@ -242,75 +463,136 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { math::Im2ColFunctor im2col; math::Vol2ColFunctor vol2col; + math::ConcatFunctor concat_functor; if (input_grad) { input_grad->mutable_data(context.GetPlace()); + set_zero(dev_ctx, input_grad, static_cast(0)); } - if (filter_grad) { // filter size (m, c/g, k_h, k_w) + if (filter_grad) { // filter_grad_ size (i_c, o_c/g, k_h, k_w) filter_grad->mutable_data(context.GetPlace()); set_zero(dev_ctx, filter_grad, static_cast(0)); filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); } + size_t D = input->dims().size(); for (int i = 0; i < batch_size; i++) { - // batch with size (c, o_h * o_w) + // batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for + // channel_first + // batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for + // channel_last Tensor output_grad_batch = output_grad->Slice(i, i + 1).Resize(output_shape); if (data_dim == 2U) { // im2col: dy -> col matrix - // from (c, o_h, o_w) to (c * k_h * k_w, h * w) + // from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for + // channel_first + // from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for + // channel_last im2col(dev_ctx, output_grad_batch, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &col); + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, + &col, data_layout); } else if (data_dim == 3U) { // vol2col: dy -> col_matrix - // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) + // from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h * + // i_w) for channel_first + // from (o_d, o_h, o_w, o_c) to (i_d * i_h * i_w, o_c * k_d * k_h * + // k_w) for channel_last vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings, - &col); + &col, data_layout); } if (input_grad) { - // batch with size (m, h, w) + // batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c) Tensor input_grad_batch = input_grad->Slice(i, i + 1).Resize(input_matrix_shape); + // gemm: dx = filter * dy - // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w) + // (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h + // * i_w) // or - // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m, - // d, h, w) + // (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h * + // i_w) -> (i_c, + // i_d, i_h, i_w) + // gemm: dx = dy^T * filter^T for channel_last + + std::vector input_grad_batch_vec; for (int g = 0; g < groups; g++) { - Tensor input_grad_slice = - input_grad_batch.Slice(g * in_step, (g + 1) * in_step); + // input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w) + // for channel_first + // input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g) + // for channel_last + // filter_slice: (i_c/g, o_c/g * k_h * k_w) Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step); + // col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * + // k_h * k_w, d * h * w) Tensor col_matrix_slice = col_matrix.Slice(g * col_step, (g + 1) * col_step); - - blas.MatMul(filter_slice, false, col_matrix_slice, false, - static_cast(1.0), &input_grad_slice, - static_cast(0.0)); + if (data_layout == framework::DataLayout::kNCHW) { + Tensor input_grad_slice = + input_grad_batch.Slice(g * in_step, (g + 1) * in_step); + blas.MatMul(filter_slice, false, col_matrix_slice, false, + static_cast(1.0), &input_grad_slice, + static_cast(0.0)); + } else { + Tensor input_grad_slice; + Slice(context, &input_grad_batch, + &input_grad_slice, g * in_step, + (g + 1) * in_step, 1); + blas.MatMul(col_matrix_slice, true, filter_slice, true, + static_cast(1.0), &input_grad_slice, + static_cast(0.0)); + DDim input_grad_slice_shape; + if (data_dim == 2U) { + input_grad_slice_shape = {in_dims[1], in_dims[2], in_step}; + } else { + input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3], + in_step}; + } + input_grad_slice = + input_grad_slice.Resize(input_grad_slice_shape); + input_grad_batch_vec.push_back(input_grad_slice); + } + } + if (data_layout == framework::DataLayout::kNHWC) { + concat_functor(dev_ctx, input_grad_batch_vec, + static_cast(D - 2), &input_grad_batch); } } if (filter_grad) { - // input batch + // input batch: (i_c, i_h * i_w) or (i_h, i_w * i_c) Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); // gemm: d_filter = x * dy^T - // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w) + // (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h + // * k_w) // or - // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d * + // (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w) + // -> (i_c, o_c * k_d * // k_h * k_w) + // gemm: d_filter = x^T * dy^T for channel_last + for (int g = 0; g < groups; g++) { - Tensor in_batch_slice = - in_batch.Slice(g * in_step, (g + 1) * in_step); Tensor filter_grad_slice = filter_grad_.Slice(g * in_step, (g + 1) * in_step); Tensor col_matrix_slice = col_matrix.Slice(g * col_step, (g + 1) * col_step); - blas.MatMul(in_batch_slice, false, col_matrix_slice, true, - static_cast(1.0), &filter_grad_slice, - static_cast(1.0)); + if (data_layout == framework::DataLayout::kNCHW) { + Tensor in_batch_slice = + in_batch.Slice(g * in_step, (g + 1) * in_step); + blas.MatMul(in_batch_slice, false, col_matrix_slice, true, + static_cast(1.0), &filter_grad_slice, + static_cast(1.0)); + } else { + Tensor in_batch_slice; + Slice(context, &in_batch, &in_batch_slice, + g * in_step, (g + 1) * in_step, 1); + blas.MatMul(in_batch_slice, true, col_matrix_slice, true, + static_cast(1.0), &filter_grad_slice, + static_cast(1.0)); + } } } } @@ -322,6 +604,10 @@ template class DepthwiseConvTransposeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + const std::string data_layout_str = + context.Attr("data_format"); + const framework::DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const Tensor* input = context.Input("Input"); Tensor filter = *context.Input("Filter"); Tensor* output = context.Output("Output"); @@ -333,10 +619,27 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); + std::string padding_algorithm = + context.Attr("padding_algorithm"); for (auto v : dilations) { PADDLE_ENFORCE_EQ(v, 1); } + auto in_dims = input->dims(); + auto filter_dims = filter.dims(); + + framework::DDim in_data_dims; + if (data_layout == framework::DataLayout::kNCHW) { + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + } else { + in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); + } + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + output->mutable_data(context.GetPlace()); auto& dev_ctx = context.template device_context(); math::SetConstant set_zero; @@ -344,8 +647,10 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel { math::DepthwiseConvInputGradFunctor depthwiseConvInputGrad; - depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings, - dilations, output); + depthwiseConvInputGrad( + dev_ctx, *output, filter, *input, strides, + std::vector{paddings[0], paddings[2], paddings[1], paddings[3]}, + dilations, output, data_layout); } }; @@ -353,6 +658,10 @@ template class DepthwiseConvTransposeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + const std::string data_layout_str = + context.Attr("data_format"); + const framework::DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const Tensor* input = context.Input("Input"); const Tensor* output_grad = context.Input(framework::GradVarName("Output")); @@ -368,11 +677,30 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); + std::string padding_algorithm = + context.Attr("padding_algorithm"); + + auto in_dims = input->dims(); + auto filter_dims = filter.dims(); + + framework::DDim in_data_dims; + if (data_layout == framework::DataLayout::kNCHW) { + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + } else { + in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); + } + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); if (input_grad) { math::DepthwiseConvFunctor depthwiseConv; - depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings, dilations, - input_grad); + depthwiseConv( + dev_ctx, *output_grad, filter, strides, paddings, + std::vector{paddings[0], paddings[2], paddings[1], paddings[3]}, + input_grad, data_layout); } if (filter_grad) { @@ -382,8 +710,10 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel { math::DepthwiseConvFilterGradFunctor depthwiseConvFilterGrad; - depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings, - dilations, filter_grad); + depthwiseConvFilterGrad( + dev_ctx, *output_grad, *input, strides, + std::vector{paddings[0], paddings[2], paddings[1], paddings[3]}, + dilations, filter_grad, data_layout); } } }; diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index a372f6fa718e5db6b8de5d77391e0171d82f18dd..28083a1b2050649cc98ad3c5e10b211b4e7d7d57 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -39,7 +39,8 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { const int filter_multiplier, const int filter_height, \ const int filter_width, const int stride_height, const int stride_width, \ const int padding_height, const int padding_width, \ - const int dilate_height, const int dilate_width, T *const output_data + const int dilate_height, const int dilate_width, T *const output_data, \ + const DataLayout data_layout = DataLayout::kNCHW // A Cuda kernel to compute the depthwise convolution forward pass // in NCHW format. @@ -58,8 +59,13 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { const int h_in_end = h_in_start + filter_height * dilate_height; const int w_in_end = w_in_start + filter_width * dilate_width; - const int in_offset = - ((batch * input_channels + c_in) * input_height) * input_width; + int in_offset; + if (data_layout == DataLayout::kNCHW) { + in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + } else { + in_offset = batch * input_height * input_width * input_channels; + } const int h_end = h_in_end < input_height ? h_in_end : input_height; const int w_end = w_in_end < input_width ? w_in_end : input_width; @@ -71,7 +77,13 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) { - const int offset = in_offset + h_in * input_width + w_in; + int offset; + if (data_layout == DataLayout::kNCHW) { + offset = in_offset + h_in * input_width + w_in; + } else { + offset = in_offset + + (h_in * input_width + w_in) * input_channels + c_in; + } if (fuse_relu_before_conv) { value += weight[weight_offset] * max(0.0f, input_data[offset]); } else { @@ -81,9 +93,16 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { weight_offset++; } } - int index = - ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + - w_out; + int index; + if (data_layout == DataLayout::kNCHW) { + index = ((batch * gridDim.x + c_out) * output_height + h_out) * + output_width + + w_out; + } else { + index = ((batch * output_height + h_out) * output_width + w_out) * + gridDim.x + + c_out; + } output_data[index] = value; } } @@ -111,8 +130,13 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( const int h_in_end = h_in_start + c_filter * dilate_height; const int w_in_end = w_in_start + c_filter * dilate_width; - const int in_offset = - ((batch * input_channels + c_in) * input_height) * input_width; + int in_offset; + if (data_layout == DataLayout::kNCHW) { + in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + } else { + in_offset = batch * input_height * input_width * input_channels; + } const int h_end = h_in_end < input_height ? h_in_end : input_height; const int w_end = w_in_end < input_width ? w_in_end : input_width; @@ -125,7 +149,13 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( w_in += dilate_width, w_f++) { if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) { - const int offset = in_offset + h_in * input_width + w_in; + int offset; + if (data_layout == DataLayout::kNCHW) { + offset = in_offset + h_in * input_width + w_in; + } else { + offset = in_offset + + (h_in * input_width + w_in) * input_channels + c_in; + } if (fuse_relu_before_conv) { value += r_weight[h_f * c_filter + w_f] * max(0.0f, input_data[offset]); @@ -135,9 +165,16 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( } } } - int index = - ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + - w_out; + int index; + if (data_layout == DataLayout::kNCHW) { + index = ((batch * gridDim.x + c_out) * output_height + h_out) * + output_width + + w_out; + } else { + index = ((batch * output_height + h_out) * output_width + w_out) * + gridDim.x + + c_out; + } output_data[index] = value; } } @@ -153,14 +190,14 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { output_width, input_channels, input_height, input_width, filter_multiplier, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, dilate_height, - dilate_width, output_data); + dilate_width, output_data, data_layout); else KernelDepthwiseConvCFilter( input_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, filter_multiplier, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, dilate_height, - dilate_width, output_data); + dilate_width, output_data, data_layout); } else { if (c_filter == -1) KernelDepthwiseConv( @@ -168,14 +205,14 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { output_width, input_channels, input_height, input_width, c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, padding_height, padding_width, dilate_height, dilate_width, - output_data); + output_data, data_layout); else KernelDepthwiseConvCFilter( input_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, padding_height, padding_width, dilate_height, dilate_width, - output_data); + output_data, data_layout); } } @@ -190,7 +227,8 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { const int filter_width, const int stride_height, const int stride_width, \ const int padding_height, const int padding_width, \ const int dilate_height, const int dilate_width, \ - T *const input_grad_data + T *const input_grad_data, \ + const DataLayout data_layout = DataLayout::kNCHW template __device__ __inline__ void KernelDepthwiseConvInputGrad( @@ -213,9 +251,17 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad( int w_out_end = w_in + padding_width; T value = 0; - int index = - ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + - w_in; + int index; + if (data_layout == DataLayout::kNCHW) { + index = + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + + w_in; + } else { + index = + ((batch * input_height + h_in) * input_width + w_in) * gridDim.x + + c_in; + } + if (fuse_relu_before_conv) { if (input_data[index] <= 0) { input_grad_data[index] = 0; @@ -236,11 +282,20 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad( if (h_out % stride_height == 0 && w_out % stride_width == 0 && s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && s_w_out < output_width) { - const int output_grad_offset = - ((batch * output_channels + c_out) * output_height + - s_h_out) * - output_width + - s_w_out; + int output_grad_offset; + if (data_layout == DataLayout::kNCHW) { + output_grad_offset = + ((batch * output_channels + c_out) * output_height + + s_h_out) * + output_width + + s_w_out; + } else { + output_grad_offset = + ((batch * output_height + s_h_out) * output_width + + s_w_out) * + output_channels + + c_out; + } value += output_grad_data[output_grad_offset] * filter_data[filter_offset]; } @@ -279,9 +334,16 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter( int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width; T value = 0; - int index = - ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + - w_in; + int index; + if (data_layout == DataLayout::kNCHW) { + index = + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + + w_in; + } else { + index = + ((batch * input_height + h_in) * input_width + w_in) * gridDim.x + + c_in; + } if (fuse_relu_before_conv) { if (input_data[index] <= 0) { input_grad_data[index] = 0; @@ -300,11 +362,20 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter( if (h_out % stride_height == 0 && w_out % stride_width == 0 && s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && s_w_out < output_width) { - const int output_grad_offset = - ((batch * output_channels + c_out) * output_height + - s_h_out) * - output_width + - s_w_out; + int output_grad_offset; + if (data_layout == DataLayout::kNCHW) { + output_grad_offset = + ((batch * output_channels + c_out) * output_height + + s_h_out) * + output_width + + s_w_out; + } else { + output_grad_offset = + ((batch * output_height + s_h_out) * output_width + + s_w_out) * + output_channels + + c_out; + } value += output_grad_data[output_grad_offset] * r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter]; @@ -327,14 +398,14 @@ __global__ void KernelDepthwiseConvInputGradSp( output_height, output_width, input_channels, input_height, input_width, filter_multiplier, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, dilate_height, - dilate_width, input_grad_data); + dilate_width, input_grad_data, data_layout); else if (c_filter == -1) KernelDepthwiseConvInputGrad( input_data, output_grad_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, c_filter_multiplier, filter_height, filter_width, c_stride, c_stride, padding_height, padding_width, dilate_height, dilate_width, - input_grad_data); + input_grad_data, data_layout); else KernelDepthwiseConvInputGradCFilter( @@ -342,7 +413,7 @@ __global__ void KernelDepthwiseConvInputGradSp( output_height, output_width, input_channels, input_height, input_width, c_filter_multiplier, filter_height, filter_width, c_stride, c_stride, padding_height, padding_width, dilate_height, dilate_width, - input_grad_data); + input_grad_data, data_layout); } // Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. @@ -354,7 +425,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad( const int filter_multiplier, const int filter_height, const int filter_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* filter_grad_data) { + const int dilate_width, T* filter_grad_data, + const DataLayout data_layout = DataLayout::kNCHW) { T s = 0; int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; @@ -374,18 +446,35 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad( if (image_wk < 0 || image_wk >= input_width) continue; #define gaid(N, C, H, W) \ ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W)) - int input_id = ((bid * (gridDim.z / filter_multiplier) + - kernel_id / filter_multiplier) * - input_height + - image_hk) * - input_width + - image_wk; - if (fuse_relu_before_conv) { - s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * - max(0.0f, input_data[input_id]); +#define gaid_nhwc(N, H, W, C) \ + ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C)) + int input_id; + if (data_layout == DataLayout::kNCHW) { + input_id = ((bid * (gridDim.z / filter_multiplier) + + kernel_id / filter_multiplier) * + input_height + + image_hk) * + input_width + + image_wk; + if (fuse_relu_before_conv) { + s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * + max(0.0f, input_data[input_id]); + } else { + s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * + input_data[input_id]; + } } else { - s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * - input_data[input_id]; + input_id = + ((bid * input_height + image_hk) * input_width + image_wk) * + (gridDim.z / filter_multiplier) + + kernel_id / filter_multiplier; + if (fuse_relu_before_conv) { + s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] * + max(0.0f, input_data[input_id]); + } else { + s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] * + input_data[input_id]; + } } #undef gaid @@ -403,21 +492,22 @@ __global__ void KernelDepthwiseConvFilterGradSp( const int filter_multiplier, const int filter_height, const int filter_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* filter_grad_data) { + const int dilate_width, T* filter_grad_data, + const DataLayout data_layout = DataLayout::kNCHW) { if (c_filter_multiplier == 0) KernelDepthwiseConvFilterGrad( output_grad_data, input_data, num, output_channels, output_height, output_width, input_channels, input_height, input_width, filter_multiplier, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, dilate_height, - dilate_width, filter_grad_data); + dilate_width, filter_grad_data, data_layout); else KernelDepthwiseConvFilterGrad( output_grad_data, input_data, num, output_channels, output_height, output_width, input_channels, input_height, input_width, c_filter_multiplier, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, dilate_height, - dilate_width, filter_grad_data); + dilate_width, filter_grad_data, data_layout); } /* @@ -434,15 +524,24 @@ class DepthwiseConvFunctor& strides, const std::vector& paddings, - const std::vector& dilations, - framework::Tensor* output) { + const std::vector& dilations, framework::Tensor* output, + const DataLayout data_layout = DataLayout::kNCHW) { const int batch_size = input.dims()[0]; - const int input_channels = input.dims()[1]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output->dims()[1]; - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; + const int input_channels = + (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]); + const int input_height = + (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]); + const int input_width = + (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]); + const int output_channels = + (data_layout == DataLayout::kNCHW ? output->dims()[1] + : output->dims()[3]); + const int output_height = + (data_layout == DataLayout::kNCHW ? output->dims()[2] + : output->dims()[1]); + const int output_width = + (data_layout == DataLayout::kNCHW ? output->dims()[3] + : output->dims()[2]); const int ksize_height = filter.dims()[2]; const int ksize_width = filter.dims()[3]; const int stride_height = strides[0]; @@ -478,7 +577,7 @@ class DepthwiseConvFunctor& strides, const std::vector& paddings, const std::vector& dilations, - framework::Tensor* input_grad) { + framework::Tensor* input_grad, + const DataLayout data_layout = DataLayout::kNCHW) { const int batch_size = input.dims()[0]; - const int input_channels = input.dims()[1]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output_grad.dims()[1]; - const int output_height = output_grad.dims()[2]; - const int output_width = output_grad.dims()[3]; + const int input_channels = + (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]); + const int input_height = + (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]); + const int input_width = + (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]); + const int output_channels = + (data_layout == DataLayout::kNCHW ? output_grad.dims()[1] + : output_grad.dims()[3]); + const int output_height = + (data_layout == DataLayout::kNCHW ? output_grad.dims()[2] + : output_grad.dims()[1]); + const int output_width = + (data_layout == DataLayout::kNCHW ? output_grad.dims()[3] + : output_grad.dims()[2]); const int ksize_height = filter.dims()[2]; const int ksize_width = filter.dims()[3]; const int stride_height = strides[0]; @@ -556,7 +665,8 @@ class DepthwiseConvInputGradFunctor& strides, const std::vector& paddings, const std::vector& dilations, - framework::Tensor* filter_grad) { + framework::Tensor* filter_grad, + const DataLayout data_layout = DataLayout::kNCHW) { const int batch_size = input.dims()[0]; - const int input_channels = input.dims()[1]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output_grad.dims()[1]; - const int output_height = output_grad.dims()[2]; - const int output_width = output_grad.dims()[3]; + const int input_channels = + (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]); + const int input_height = + (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]); + const int input_width = + (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]); + const int output_channels = + (data_layout == DataLayout::kNCHW ? output_grad.dims()[1] + : output_grad.dims()[3]); + const int output_height = + (data_layout == DataLayout::kNCHW ? output_grad.dims()[2] + : output_grad.dims()[1]); + const int output_width = + (data_layout == DataLayout::kNCHW ? output_grad.dims()[3] + : output_grad.dims()[2]); const int ksize_height = filter_grad->dims()[2]; const int ksize_width = filter_grad->dims()[3]; const int stride_height = strides[0]; @@ -629,7 +749,7 @@ class DepthwiseConvFilterGradFunctor& strides, const std::vector& paddings, - const std::vector& dilations, framework::Tensor* output); + const std::vector& dilations, framework::Tensor* output, + const DataLayout data_layout = DataLayout::kNCHW); }; template & strides, const std::vector& paddings, const std::vector& dilations, - framework::Tensor* input_grad); + framework::Tensor* input_grad, + const DataLayout data_layout = DataLayout::kNCHW); }; template & strides, const std::vector& paddings, const std::vector& dilations, - framework::Tensor* filter_grad); + framework::Tensor* filter_grad, + const DataLayout data_layout = DataLayout::kNCHW); }; } // namespace math diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc index fe646ea2e779e85614d4ac3e0295775c58100b1b..4736c78fe98b1338b4ed0290e51e4bc39db7554c 100644 --- a/paddle/fluid/operators/math/im2col.cc +++ b/paddle/fluid/operators/math/im2col.cc @@ -32,7 +32,8 @@ class Im2ColFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* col) { + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout) { PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); PADDLE_ENFORCE_EQ(col->dims().size(), 5, "The dimension of col should be 5."); @@ -41,16 +42,16 @@ class Im2ColFunctor(im, col); + im2col_sh1sw1dh1dw1ph0pw0(im, col, data_layout); return; } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 && padding[3] == 1) { - im2col_sh1sw1dh1dw1ph1pw1(im, col); + im2col_sh1sw1dh1dw1ph1pw1(im, col, data_layout); return; } // TODO(TJ): complete padding >=2 } - im2col_common(im, dilation, stride, padding, col); + im2col_common(im, dilation, stride, padding, col, data_layout); } }; @@ -67,13 +68,17 @@ class Col2ImFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* im) { + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout) { PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); PADDLE_ENFORCE_EQ(col.dims().size(), 5, "The dimension of col should be 5."); - int im_channels = im->dims()[0]; - int im_height = im->dims()[1]; - int im_width = im->dims()[2]; + int im_channels = + (data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]); + int im_height = + (data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]); + int im_width = + (data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]); int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; @@ -109,7 +114,15 @@ class Col2ImFunctor= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { - im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] += + int im_offset; + if (data_layout == DataLayout::kNCHW) { + im_offset = + (c_im * im_height + im_row_idx) * im_width + im_col_idx; + } else { + im_offset = + (im_row_idx * im_width + im_col_idx) * im_channels + c_im; + } + im_data[im_offset] += col_data[(c * col_height + h) * col_width + w]; } } @@ -139,7 +152,8 @@ class Im2ColFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* col) { + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout) { PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); PADDLE_ENFORCE_EQ(col->dims().size(), 5, "The dimension of col should be 5."); @@ -202,7 +216,8 @@ class Col2ImFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* im) { + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout) { PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); PADDLE_ENFORCE_EQ(col.dims().size(), 5, "The dimension of col should be 5."); diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/fluid/operators/math/im2col.cu index 809014ea3d6ce51fd0dae478e7f0bedca2420412..ffb598dceda3448f3f595e3f2b68836fdcf37200 100644 --- a/paddle/fluid/operators/math/im2col.cu +++ b/paddle/fluid/operators/math/im2col.cu @@ -26,27 +26,41 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height, int im_width, int dilation_h, int dilation_w, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int col_height, int col_width, T* data_col) { + int col_height, int col_width, T* data_col, + const DataLayout data_layout) { + int input_channels = num_outs / col_height / col_width; + int channels_col = input_channels * filter_height * filter_width; const int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if (index < num_outs) { - int w_out = index % col_width; - int h_out = (index / col_width) % col_height; - int channel_in = index / col_width / col_height; + int w_out = (data_layout == DataLayout::kNCHW + ? index % col_width + : (index / input_channels) % col_width); + int h_out = (data_layout == DataLayout::kNCHW + ? (index / col_width) % col_height + : (index / input_channels / col_width) % col_height); + int channel_in = + (data_layout == DataLayout::kNCHW ? index / col_width / col_height + : index % input_channels); int channel_out = channel_in * filter_height * filter_width; int h_in = h_out * stride_height - padding_height; int w_in = w_out * stride_width - padding_width; data_col += (channel_out * col_height + h_out) * col_width + w_out; - data_im += (channel_in * im_height + h_in) * im_width + w_in; for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { int rIdx = h_in + i * dilation_h; int cIdx = w_in + j * dilation_w; + int im_idx; + if (data_layout == DataLayout::kNCHW) { + im_idx = (channel_in * im_height + rIdx) * im_width + cIdx; + } else { + im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in; + } *data_col = (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0) ? 0 - : data_im[i * dilation_h * im_width + j * dilation_w]; + : data_im[im_idx]; data_col += col_height * col_width; } } @@ -65,13 +79,18 @@ class Im2ColFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* col) { - PADDLE_ENFORCE_EQ(im.dims().size(), 3); - PADDLE_ENFORCE_EQ(col->dims().size(), 5); - - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col->dims().size(), 5, + "The dimension of col should be 5."); + + int im_channels = + (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); + int im_height = + (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); + int im_width = + (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); int filter_height = col->dims()[1]; int filter_width = col->dims()[2]; int col_height = col->dims()[3]; @@ -86,7 +105,8 @@ class Im2ColFunctor<<>>( im.data(), num_outputs, im_height, im_width, dilation[0], dilation[1], filter_height, filter_width, stride[0], stride[1], - padding[0], padding[1], col_height, col_width, col->data()); + padding[0], padding[1], col_height, col_width, col->data(), + data_layout); } }; @@ -95,18 +115,27 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width, int dilation_h, int dilation_w, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, int col_height, - int col_width, T* data_im) { + int col_width, T* data_im, + const DataLayout data_layout) { const int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; const int d_filter_height = dilation_h * (filter_height - 1) + 1; const int d_filter_width = dilation_w * (filter_width - 1) + 1; + int input_channels = n / im_height / im_width; + if (index < n) { T val = 0; - int w = index % im_width + padding_width; - int h = (index / im_width) % im_height + padding_height; - int c = index / (im_width * im_height); + int w = (data_layout == DataLayout::kNCHW + ? index % im_width + padding_width + : (index / input_channels) % im_width + padding_width); + int h = (data_layout == DataLayout::kNCHW + ? (index / im_width) % im_height + padding_height + : (index / input_channels / im_width) % im_height + + padding_height); + int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height + : index % input_channels); // compute the start and end of the output int w_col_start = @@ -151,13 +180,18 @@ class Col2ImFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* im) { - PADDLE_ENFORCE_EQ(im->dims().size(), 3); - PADDLE_ENFORCE_EQ(col.dims().size(), 5); - - int im_channels = im->dims()[0]; - int im_height = im->dims()[1]; - int im_width = im->dims()[2]; + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col.dims().size(), 5, + "The dimension of col should be 5."); + + int im_channels = + (data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]); + int im_height = + (data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]); + int im_width = + (data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]); int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; @@ -191,7 +225,8 @@ class Col2ImFunctor<<>>( num_kernels, col.data(), im_height, im_width, dilation[0], dilation[1], filter_height, filter_width, stride[0], stride[1], - padding[0], padding[2], col_height, col_width, im->data()); + padding[0], padding[1], col_height, col_width, im->data(), + data_layout); } }; @@ -248,9 +283,12 @@ class Im2ColFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* col) { - PADDLE_ENFORCE_EQ(im.dims().size(), 3); - PADDLE_ENFORCE_EQ(col->dims().size(), 5); + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col->dims().size(), 5, + "The dimension of col should be 5."); + int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; @@ -330,9 +368,12 @@ class Col2ImFunctor& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* im) { - PADDLE_ENFORCE_EQ(im->dims().size(), 3); - PADDLE_ENFORCE_EQ(col.dims().size(), 5); + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col.dims().size(), 5, + "The dimension of col should be 5."); + int im_channels = im->dims()[0]; int im_height = im->dims()[1]; int im_width = im->dims()[2]; diff --git a/paddle/fluid/operators/math/im2col.h b/paddle/fluid/operators/math/im2col.h index 26d94e0f2e6163eb7452cf1fbea5966b4344ace1..3865443170481de53ea4679d43075e14d386bb71 100644 --- a/paddle/fluid/operators/math/im2col.h +++ b/paddle/fluid/operators/math/im2col.h @@ -23,6 +23,8 @@ namespace paddle { namespace operators { namespace math { +using DataLayout = framework::DataLayout; + /* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */ enum class ColFormat { kCFO = 0, kOCF = 1 }; @@ -86,7 +88,8 @@ class Im2ColFunctor { void operator()(const DeviceContext& context, const framework::Tensor& im, const std::vector& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* col); + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW); }; template @@ -95,7 +98,8 @@ class Col2ImFunctor { void operator()(const DeviceContext& context, const framework::Tensor& col, const std::vector& dilation, const std::vector& stride, - const std::vector& padding, framework::Tensor* im); + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout = DataLayout::kNCHW); }; } // namespace math diff --git a/paddle/fluid/operators/math/im2col_cfo_cpu.h b/paddle/fluid/operators/math/im2col_cfo_cpu.h index 0d32bc5bd0d7f25479370959cabeb9b9c9e7e2d6..bd42bd1a186198fd72f33e3fde62d399a56e0f08 100644 --- a/paddle/fluid/operators/math/im2col_cfo_cpu.h +++ b/paddle/fluid/operators/math/im2col_cfo_cpu.h @@ -30,10 +30,14 @@ inline void im2col_common(const framework::Tensor& im, const std::vector& dilation, const std::vector& stride, const std::vector& padding, - framework::Tensor* col) { - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW) { + int im_channels = + (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); + int im_height = + (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); + int im_width = + (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); int filter_height = col->dims()[1]; int filter_width = col->dims()[2]; int output_height = col->dims()[3]; @@ -50,8 +54,14 @@ inline void im2col_common(const framework::Tensor& im, int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; for (int w = 0; w < output_width; ++w) { int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; + int im_idx; + if (data_layout == DataLayout::kNCHW) { + im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; + } else { + im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im; + } int col_idx = (c * output_height + h) * output_width + w; - int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 || im_col_idx >= im_width) ? static_cast(0) @@ -65,11 +75,15 @@ inline void im2col_common(const framework::Tensor& im, * im2col algorithm with strides == 1, dilations == 1, paddings == 0 */ template -inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im, - framework::Tensor* col) { - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; +inline void im2col_sh1sw1dh1dw1ph0pw0( + const framework::Tensor& im, framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW) { + int im_channels = + (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); + int im_height = + (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); + int im_width = + (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); int filter_height = col->dims()[1]; int filter_width = col->dims()[2]; int output_height = col->dims()[3]; @@ -89,7 +103,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im, const T* src_data = src_data_ic; for (int kh = 0; kh < filter_height; ++kh) { for (int kw = 0; kw < filter_width; ++kw) { - std::memcpy(dst_data, src_data + kw, copy_size); + if (data_layout == DataLayout::kNCHW) { + std::memcpy(dst_data, src_data + kw, copy_size); + } else { + for (int kow = 0; kow < output_width; ++kow) { + dst_data[kow] = + im_data[((oh + kh) * im_width + kw + kow) * im_channels + ic]; + } + } dst_data = dst_data + col_matrix_width; } src_data = src_data + im_width; @@ -107,10 +128,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im, */ template inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, - framework::Tensor* col) { - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + framework::Tensor* col, + const DataLayout data_layout) { + int im_channels = + (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); + int im_height = + (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); + int im_width = + (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); int filter_height = col->dims()[1]; int filter_width = col->dims()[2]; int output_height = col->dims()[3]; @@ -180,7 +205,17 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, dst_data = dst_data + col_matrix_width; continue; } - std::memcpy(dst_data + plw, src_data, copy_size); + if (data_layout == DataLayout::kNCHW) { + std::memcpy(dst_data + plw, src_data, copy_size); + } else { + for (int kow = 0; kow < output_width - plw - prw; ++kow) { + dst_data[plw + kow] = + im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + + kow) * + im_channels + + ic]; + } + } dst_data = dst_data + col_matrix_width; src_data = src_data + im_width; } @@ -226,19 +261,49 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, // TODO(TJ): reuse plw-kw outside this for // try to unify for (int kw = 0; kw < plw; ++kw) { - std::memcpy(dst_data + (plw - kw), src_data, - sizeof(T) * (output_width - (plw - kw))); + if (data_layout == DataLayout::kNCHW) { + std::memcpy(dst_data + (plw - kw), src_data, + sizeof(T) * (output_width - (plw - kw))); + } else { + for (int kow = 0; kow < output_width - (plw - kw); ++kow) { + dst_data[plw - kw + kow] = + im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + + kow) * + im_channels + + ic]; + } + } dst_data = dst_data + col_matrix_width; } for (int kw = plw; kw < filter_width - prw; ++kw) { - std::memcpy(dst_data, src_data + (kw - plw), - sizeof(T) * output_width); + if (data_layout == DataLayout::kNCHW) { + std::memcpy(dst_data, src_data + (kw - plw), + sizeof(T) * output_width); + } else { + for (int kow = 0; kow < output_width; ++kow) { + dst_data[kow] = + im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + + kw - plw + kow) * + im_channels + + ic]; + } + } dst_data = dst_data + col_matrix_width; } int i = 1; for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) { - std::memcpy(dst_data, src_data + (kw - plw), - sizeof(T) * (output_width - i)); + if (data_layout == DataLayout::kNCHW) { + std::memcpy(dst_data, src_data + (kw - plw), + sizeof(T) * (output_width - i)); + } else { + for (int kow = 0; kow < output_width - i; ++kow) { + dst_data[kow] = + im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + + kw - plw + kow) * + im_channels + + ic]; + } + } dst_data = dst_data + col_matrix_width; } src_data = src_data + im_width; diff --git a/paddle/fluid/operators/math/vol2col.cc b/paddle/fluid/operators/math/vol2col.cc index 1083cac3020162e9e31fdfe4091db722d7847515..da051034da4252e9ba78db08a4fef70ef1d6bfc3 100644 --- a/paddle/fluid/operators/math/vol2col.cc +++ b/paddle/fluid/operators/math/vol2col.cc @@ -32,16 +32,21 @@ class Vol2ColFunctor { const framework::Tensor& vol, const std::vector& dilations, const std::vector& strides, - const std::vector& paddings, - framework::Tensor* col) const { + const std::vector& paddings, framework::Tensor* col, + const DataLayout data_layout) const { PADDLE_ENFORCE_EQ(vol.dims().size(), 4, "The dimension of vol should be 4."); PADDLE_ENFORCE_EQ(col->dims().size(), 7, "The dimension of col should be 7."); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + + int input_channels = + (data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]); + int input_depth = + (data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]); + int input_height = + (data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]); + int input_width = + (data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]); int filter_depth = col->dims()[1]; int filter_height = col->dims()[2]; int filter_width = col->dims()[3]; @@ -59,6 +64,7 @@ class Vol2ColFunctor { int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back - ((dilations[0] * (filter_depth - 1) + 1))) / strides[0] + @@ -97,10 +103,16 @@ class Vol2ColFunctor { int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; - int vol_idx = - ((c_in * input_depth + d_pad) * input_height + h_pad) * - input_width + - w_pad; + int vol_idx; + if (data_layout == DataLayout::kNCHW) { + vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + } else { + vol_idx = ((d_pad * input_height + h_pad) * input_width + w_pad) * + input_channels + + c_in; + } col_data[col_idx] = (h_pad < 0 || h_pad >= input_height || w_pad < 0 || w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) @@ -126,16 +138,21 @@ class Col2VolFunctor { const framework::Tensor& col, const std::vector& dilations, const std::vector& strides, - const std::vector& paddings, - framework::Tensor* vol) const { + const std::vector& paddings, framework::Tensor* vol, + const DataLayout data_layout) const { PADDLE_ENFORCE_EQ(vol->dims().size(), 4, "The dimension of vol should be 4."); PADDLE_ENFORCE_EQ(col.dims().size(), 7, "The dimension of col should be 7."); - int input_channels = vol->dims()[0]; - int input_depth = vol->dims()[1]; - int input_height = vol->dims()[2]; - int input_width = vol->dims()[3]; + + int input_channels = + (data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]); + int input_depth = + (data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]); + int input_height = + (data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]); + int input_width = + (data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]); int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -191,11 +208,17 @@ class Col2VolFunctor { if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { - int vol_idx = - ((cIm * input_depth + d_pad) * input_height + h_pad) * - input_width + - w_pad; - + int vol_idx; + if (data_layout == DataLayout::kNCHW) { + vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + } else { + vol_idx = + ((d_pad * input_height + h_pad) * input_width + w_pad) * + input_channels + + cIm; + } int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; diff --git a/paddle/fluid/operators/math/vol2col.cu b/paddle/fluid/operators/math/vol2col.cu index a167a9021bc5bf865c364f9fd3d332db0895289d..b42dd55bda51e56c7c7994af1b745fba3fac1691 100644 --- a/paddle/fluid/operators/math/vol2col.cu +++ b/paddle/fluid/operators/math/vol2col.cu @@ -28,7 +28,12 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, int filter_width, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width, int output_detph, int output_height, - int output_width, T* data_col) { + int output_width, T* data_col, + const DataLayout data_layout) { + int input_channels = + num_kernels / output_detph / output_height / output_width; + int channels_col = + input_channels * filter_depth * filter_height * filter_width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { int w_out = index % output_width; @@ -43,18 +48,22 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, data_col += ((channel_out * output_detph + d_out) * output_height + h_out) * output_width + w_out; - data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in; for (int k = 0; k < filter_depth; ++k) { for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { int d = d_in + k * dilation_d; int h = h_in + i * dilation_h; int w = w_in + j * dilation_w; - int col_idx = (k * dilation_d * height + i * dilation_h) * width + - j * dilation_w; + int vol_idx; + if (data_layout == DataLayout::kNCHW) { + vol_idx = ((channel_in * depth + d) * height + h) * width + w; + } else { + vol_idx = + ((d * height + h) * width + w) * input_channels + channel_in; + } *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width) - ? data_vol[col_idx] + ? data_vol[vol_idx] : 0; data_col += output_detph * output_height * output_width; } @@ -64,7 +73,10 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, } /* - * im = [input_channels,intpu_depth, input_height, input_width] + * im = [input_channels,intpu_depth, input_height, input_width] for + * channels_first + * im = [input_depth, input_height, input_width, input_channels] for + * channels_last * col = * [input_channels, filter_depth, filter_height, filter_width, * output_depth, output_height, output_width] @@ -76,15 +88,21 @@ class Vol2ColFunctor { const framework::Tensor& vol, const std::vector& dilations, const std::vector& strides, - const std::vector& paddings, - framework::Tensor* col) const { - PADDLE_ENFORCE_EQ(vol.dims().size(), 4); - PADDLE_ENFORCE_EQ(col->dims().size(), 7); + const std::vector& paddings, framework::Tensor* col, + const DataLayout data_layout) const { + PADDLE_ENFORCE_EQ(vol.dims().size(), 4, + "The dimension of vol should be 4."); + PADDLE_ENFORCE_EQ(col->dims().size(), 7, + "The dimension of col should be 7."); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + int input_channels = + (data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]); + int input_depth = + (data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]); + int input_height = + (data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]); + int input_width = + (data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]); int filter_depth = col->dims()[1]; int filter_height = col->dims()[2]; int filter_width = col->dims()[3]; @@ -130,7 +148,8 @@ class Vol2ColFunctor { num_outputs, vol.data(), input_depth, input_height, input_width, dilations[0], dilations[1], dilations[2], filter_depth, filter_height, filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, - pad_w_left, output_depth, output_height, output_width, col->data()); + pad_w_left, output_depth, output_height, output_width, col->data(), + data_layout); } }; @@ -141,18 +160,27 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, int filter_width, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width, int output_detph, int output_height, - int output_width, T* data_vol) { + int output_width, T* data_vol, + const DataLayout data_layout) { const int d_filter_depth = dilation_d * (filter_depth - 1) + 1; const int d_filter_height = dilation_h * (filter_height - 1) + 1; const int d_filter_width = dilation_w * (filter_width - 1) + 1; + int input_channels = num_kernels / depth / height / width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { T src_val = 0; - int w = index % width + padding_width; - int h = (index / width) % height + padding_height; - int d = (index / width / height) % depth + padding_depth; - int c = index / width / height / depth; + int w = (data_layout == DataLayout::kNCHW + ? index % width + padding_width + : (index / input_channels) % width + padding_width); + int h = (data_layout == DataLayout::kNCHW + ? (index / width) % height + padding_height + : (index / input_channels / width) % height + padding_height); + int d = (data_layout == DataLayout::kNCHW + ? (index / width / height) % depth + padding_depth + : index / input_channels / width / height + padding_depth); + int c = (data_layout == DataLayout::kNCHW ? index / width / height / depth + : index % input_channels); // compute the start and end of the output int w_col_start = @@ -196,7 +224,10 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, } /* - * im = [input_channels, input_depth, input_height, input_width] + * im = [input_channels,intpu_depth, input_height, input_width] for + * channels_first + * im = [input_depth, input_height, input_width, input_channels] for + * channels_last * col = * [input_channels, filter_depth, filter_height, filter_width, * output_depth, output_height, output_width] @@ -208,15 +239,21 @@ class Col2VolFunctor { const framework::Tensor& col, const std::vector& dilations, const std::vector& strides, - const std::vector& paddings, - framework::Tensor* vol) const { - PADDLE_ENFORCE_EQ(vol->dims().size(), 4); - PADDLE_ENFORCE_EQ(col.dims().size(), 7); + const std::vector& paddings, framework::Tensor* vol, + const DataLayout data_layout) const { + PADDLE_ENFORCE_EQ(vol->dims().size(), 4, + "The dimension of vol should be 4."); + PADDLE_ENFORCE_EQ(col.dims().size(), 7, + "The dimension of col should be 7."); - int input_channels = vol->dims()[0]; - int input_depth = vol->dims()[1]; - int input_height = vol->dims()[2]; - int input_width = vol->dims()[3]; + int input_channels = + (data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]); + int input_depth = + (data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]); + int input_height = + (data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]); + int input_width = + (data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]); int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -263,7 +300,8 @@ class Col2VolFunctor { num_kernels, col.data(), input_depth, input_height, input_width, dilations[0], dilations[1], dilations[2], filter_depth, filter_height, filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, - pad_w_left, output_depth, output_height, output_width, vol->data()); + pad_w_left, output_depth, output_height, output_width, vol->data(), + data_layout); } }; diff --git a/paddle/fluid/operators/math/vol2col.h b/paddle/fluid/operators/math/vol2col.h index 5f59de8f02a52209a3901ca03680eb2d0dbc2658..3122828b2eeba5fb1428235dd3a5f926705bd78e 100644 --- a/paddle/fluid/operators/math/vol2col.h +++ b/paddle/fluid/operators/math/vol2col.h @@ -22,6 +22,9 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { + +using DataLayout = framework::DataLayout; + /* * \brief Converts the feature data of four dimensions(CDHW) into a colData of * seven dimensions in the Vol2ColFunctor calculation, @@ -70,8 +73,8 @@ class Vol2ColFunctor { void operator()(const DeviceContext& context, const framework::Tensor& vol, const std::vector& dilations, const std::vector& strides, - const std::vector& paddings, - framework::Tensor* col) const; + const std::vector& paddings, framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW) const; }; template @@ -80,8 +83,8 @@ class Col2VolFunctor { void operator()(const DeviceContext& context, const framework::Tensor& col, const std::vector& dilations, const std::vector& strides, - const std::vector& paddings, - framework::Tensor* vol) const; + const std::vector& paddings, framework::Tensor* vol, + const DataLayout data_layout = DataLayout::kNCHW) const; }; } // namespace math diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c7bc28091df3c97b20be081899bfa8020ffe8c6f..e4c599d519907868b66d6fa97a48da49a2b96de6 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4424,13 +4424,14 @@ def conv2d_transpose(input, bias_attr=None, use_cudnn=True, act=None, - name=None): + name=None, + data_format='NCHW'): """ **Convlution2D transpose layer** The convolution2D transpose layer calculates the output based on the input, filter, and dilations, strides, paddings. Input(Input) and output(Output) - are in NCHW format. Where N is batch size, C is the number of channels, + are in NCHW or NHWC format. Where N is batch size, C is the number of channels, H is the height of the feature, and W is the width of the feature. Parameters(dilations, strides, paddings) are two elements. These two elements represent height and width, respectively. The details of convolution transpose @@ -4448,12 +4449,12 @@ def conv2d_transpose(input, Where: - * :math:`X`: Input value, a tensor with NCHW format. - * :math:`W`: Filter value, a tensor with MCHW format. + * :math:`X`: Input value, a 4-D Tensor with NCHW or NHWC format. + * :math:`W`: Filter value, a 4-D Tensor with MCHW format. * :math:`\\ast`: Convolution operation. - * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. + * :math:`b`: Bias value, a 2-D Tensor with shape [M, 1]. * :math:`\\sigma`: Activation function. - * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + * :math:`Out`: Output value, a 4-D Tensor with data format 'NCHW' or 'NHWC', the shape of :math:`Out` and :math:`X` may be different. Example: @@ -4471,10 +4472,12 @@ def conv2d_transpose(input, .. math:: - H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\ - W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\ + H^\prime_{out} &= (H_{in} - 1) * strides[0] - pad_height_top - pad_height_bottom + dilations[0] * (H_f - 1) + 1 \\\\ + W^\prime_{out} &= (W_{in} - 1) * strides[1] - pad_width_left - pad_width_right + dilations[1] * (W_f - 1) + 1 \\\\ H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ] \\\\ - W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] ] + W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] ] + + padding mode is 'SAME' and 'VALID' can reference this link`_ Note: if output_size is None, :math:`H_{out} = H^\prime_{out}, W_{out} = W^\prime_{out}`; @@ -4484,51 +4487,63 @@ def conv2d_transpose(input, conv2d_transpose can compute the kernel size automatically. Args: - input(Variable): The input image with [N, C, H, W] format. + input(Variable): 4-D Tensor with [N, C, H, W] or [N, H, W, C] format, + its data type is float32 or float64. num_filters(int): The number of the filter. It is as same as the output image channel. - output_size(int|tuple|None): The output image size. If output size is a + output_size(int|tuple, optional): The output image size. If output size is a tuple, it must contain two integers, (image_height, image_width). None if use filter_size, padding, and stride to calculate output_size. if output_size and filter_size are specified at the same time, They - should follow the formula above. - filter_size(int|tuple|None): The filter size. If filter_size is a tuple, + should follow the formula above. Default: None. + filter_size(int|tuple, optional): The filter size. If filter_size is a tuple, it must contain two integers, (filter_size_height, filter_size_width). Otherwise, filter_size_height = filter_size_width = filter_size. None if - use output size to calculate filter_size. - padding(int|tuple): The padding size. If padding is a tuple, it must - contain two integers, (padding_height, padding_width). Otherwise, - padding_height = padding_width = padding. Default: padding = 0. - stride(int|tuple): The stride size. If stride is a tuple, it must + use output size to calculate filter_size. Default: None. + padding(int|list|str|tuple, optional):The padding size. If `padding` is a + string, either 'VALID' or 'SAME' supported, which is the padding algorithm. + If `padding` is a tuple or list, it could be in three forms: + `[pad_height, pad_width]` or + `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and + when `data_format` is `'NCHW'`, + `padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `'NHWC'`, `padding` can be in the form + `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. + stride(int|tuple, optional): The stride size. If stride is a tuple, it must contain two integers, (stride_height, stride_width). Otherwise, stride_height = stride_width = stride. Default: stride = 1. - dilation(int|tuple): The dilation size. If dilation is a tuple, it must + dilation(int|tuple, optional): The dilation size. If dilation is a tuple, it must contain two integers, (dilation_height, dilation_width). Otherwise, dilation_height = dilation_width = dilation. Default: dilation = 1. - groups(int): The groups number of the Conv2d transpose layer. Inspired by + groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by grouped convolution in Alex Krizhevsky's Deep CNN paper, in which when group=2, the first half of the filters is only connected to the first half of the input channels, while the second half of the filters is only connected to the second half of the input channels. Default: groups = 1. - param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights + param_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose will create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized with Xavier. Default: None. - bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose. + bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of conv2d_transpose. If it is set to False, no bias will be added to the output units. If it is set to None or one attribute of ParamAttr, conv2d_transpose will create ParamAttr as bias_attr. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. - use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn + use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn library is installed. Default: True. - act (str): Activation type, if it is set to None, activation is not appended. + act (str, optional): Activation type, if it is set to None, activation is not appended. Default: None. - name(str|None): A name for this layer(optional). If set None, the layer + name(str, optional): A name for this layer(optional). If set None, the layer will be named automatically. Default: True. + data_format(str, optional): The data format of the input and output data. An optional string + from: `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. Default: 'NCHW'. Returns: - Variable: The tensor variable storing the convolution transpose result. + Variable: A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or + (num_batches, out_h, out_w, channels). Raises: ValueError: If the shapes of input, filter_size, stride, padding and @@ -4542,8 +4557,12 @@ def conv2d_transpose(input, conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3) """ assert param_attr is not False, "param_attr should not be False in conv2d_transpose." - input_channel = input.shape[1] + if data_format not in ['NCHW', 'NHWC']: + raise ValueError( + "Attr(data_format) of Op(fluid.layers.conv2d_transpose) got wrong value: received " + + data_format + " but only NCHW or NHWC supported.") + input_channel = input.shape[1] if data_format == 'NCHW' else input.shape[-1] op_type = 'conv2d_transpose' if (input_channel == groups and num_filters == input_channel and not use_cudnn): @@ -4553,26 +4572,68 @@ def conv2d_transpose(input, if not isinstance(input, Variable): raise TypeError("Input of conv2d_transpose must be Variable") - padding = utils.convert_to_list(padding, 2, 'padding') stride = utils.convert_to_list(stride, 2, 'stride') dilation = utils.convert_to_list(dilation, 2, 'dilation') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") + def _update_padding(padding, data_format): + def is_list_or_tuple(ele): + if isinstance(ele, list) or isinstance(ele, tuple): + return True + return False + + if is_list_or_tuple(padding) and len(padding) == 4: + if is_list_or_tuple(padding[0]) and (data_format == "NCHW"): + if not (padding[0] == [0, 0] and padding[1] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[2:4] + padding = [ele for a_list in padding for ele in a_list] + elif is_list_or_tuple(padding[0]) and (data_format == "NHWC"): + if not (padding[0] == [0, 0] and padding[3] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[1:3] + padding = [ele for a_list in padding for ele in a_list] + padding = utils.convert_to_list(padding, 4, 'padding') + else: + padding = utils.convert_to_list(padding, 2, 'padding') + padding = [padding[0], padding[0], padding[1], padding[1]] + return padding + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." % + str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0, 0, 0] + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0, 0, 0] + + padding = _update_padding(padding, data_format) + if filter_size is None: if output_size is None: raise ValueError("output_size must be set when filter_size is None") if isinstance(output_size, int): output_size = [output_size, output_size] - h_in = input.shape[2] - w_in = input.shape[3] + h_in = input.shape[2] if data_format == 'NCHW' else input.shape[1] + w_in = input.shape[3] if data_format == 'NCHW' else input.shape[2] - filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 * - padding[0] - 1) // dilation[0] + 1 - filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * - padding[1] - 1) // dilation[1] + 1 + filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + padding[0] + + padding[1] - 1) // dilation[0] + 1 + filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + padding[2] + + padding[3] - 1) // dilation[1] + 1 filter_size = [filter_size_h, filter_size_w] else: filter_size = utils.convert_to_list(filter_size, 2, @@ -4584,7 +4645,6 @@ def conv2d_transpose(input, output_size = utils.convert_to_list(output_size, 2, 'output_size') else: raise ValueError("output_size should be list or int") - padding = utils.convert_to_list(padding, 2, 'padding') groups = 1 if groups is None else groups filter_shape = [input_channel, num_filters // groups] + filter_size @@ -4601,9 +4661,11 @@ def conv2d_transpose(input, 'output_size': output_size, 'strides': stride, 'paddings': padding, + 'padding_algorithm': padding_algorithm, 'dilations': dilation, 'groups': groups, - 'use_cudnn': use_cudnn + 'use_cudnn': use_cudnn, + 'data_format': data_format }) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) @@ -4623,13 +4685,14 @@ def conv3d_transpose(input, bias_attr=None, use_cudnn=True, act=None, - name=None): + name=None, + data_format='NCDHW'): """ **Convlution3D transpose layer** The convolution3D transpose layer calculates the output based on the input, filter, and dilations, strides, paddings. Input(Input) and output(Output) - are in NCDHW format. Where N is batch size, C is the number of channels, + are in NCDHW or NDHWC format. Where N is batch size, C is the number of channels, D is the depth of the feature, H is the height of the feature, and W is the width of the feature. Parameters(dilations, strides, paddings) are two elements. These two elements represent height and width, respectively. @@ -4647,10 +4710,10 @@ def conv3d_transpose(input, In the above equation: - * :math:`X`: Input value, a tensor with NCDHW format. - * :math:`W`: Filter value, a tensor with MCDHW format. + * :math:`X`: Input value, a Tensor with NCDHW or NDHWC format. + * :math:`W`: Filter value, a Tensor with MCDHW format. * :math:`\\ast`: Convolution operation. - * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. + * :math:`b`: Bias value, a 2-D Tensor with shape [M, 1]. * :math:`\\sigma`: Activation function. * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. @@ -4670,55 +4733,68 @@ def conv3d_transpose(input, .. math:: - D_{out} &= (D_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (D_f - 1) + 1 \\\\ - H_{out} &= (H_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (H_f - 1) + 1 \\\\ - W_{out} &= (W_{in} - 1) * strides[2] - 2 * paddings[2] + dilations[2] * (W_f - 1) + 1 + D_{out} &= (D_{in} - 1) * strides[0] - pad_depth_front - pad_depth_back + dilations[0] * (D_f - 1) + 1 \\\\ + H_{out} &= (H_{in} - 1) * strides[1] - pad_height_top - pad_height_bottom + dilations[1] * (H_f - 1) + 1 \\\\ + W_{out} &= (W_{in} - 1) * strides[2] - pad_width_left - pad_width_right + dilations[2] * (W_f - 1) + 1 + + Padding mode is 'SAME' and 'VALID' can reference this + link`_ Args: - input(Variable): The input image with [N, C, D, H, W] format. + input(Variable): A 5-D Tensor with [N, C, H, W] or [N, H, W, C] format. Its data type is float32 or float64. num_filters(int): The number of the filter. It is as same as the output image channel. - output_size(int|tuple|None): The output image size. If output size is a + output_size(int|tuple, optional): The output image size. If output size is a tuple, it must contain three integers, (image_D, image_H, image_W). This parameter only works when filter_size is None. - filter_size(int|tuple|None): The filter size. If filter_size is a tuple, + filter_size(int|tuple, optional): The filter size. If filter_size is a tuple, it must contain three integers, (filter_size_depth, filter_size_height, \ filter_size_width). Otherwise, filter_size_depth = filter_size_height = \ filter_size_width = filter_size. None if use output size to calculate filter_size. - padding(int|tuple): The padding size. If padding is a tuple, it must - contain three integers, (padding_depth, padding_height, padding_width). Otherwise, - padding_depth = padding_height = padding_width = padding. Default: padding = 0. - stride(int|tuple): The stride size. If stride is a tuple, it must + padding(int|list|str|tuple, optional): The padding size. if `padding` is a string, + either 'VALID' or 'SAME' supported, which is the padding algorithm. If `padding` + is a tuple or list, it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `'NCDHW'`, `padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `'NDHWC'`, `padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. + stride(int|tuple, optional): The stride size. If stride is a tuple, it must contain three integers, (stride_depth, stride_height, stride_width). Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1. - dilation(int|tuple): The dilation size. If dilation is a tuple, it must + dilation(int|tuple, optional): The dilation size. If dilation is a tuple, it must contain three integers, (dilation_depth, dilation_height, dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation. Default: dilation = 1. - groups(int): The groups number of the Conv3d transpose layer. Inspired by + groups(int, optional): The groups number of the Conv3d transpose layer. Inspired by grouped convolution in Alex Krizhevsky's Deep CNN paper, in which when group=2, the first half of the filters is only connected to the first half of the input channels, while the second half of the filters is only connected to the second half of the input channels. Default: groups=1 - param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights + param_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights of conv3d_transpose. If it is set to None or one attribute of ParamAttr, conv3d_transpose will create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized with Xavier. Default: None. - bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv3d_transpose. + bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of conv3d_transpose. If it is set to False, no bias will be added to the output units. If it is set to None or one attribute of ParamAttr, conv3d_transpose will create ParamAttr as bias_attr. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. - use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn + use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn library is installed. Default: True - act (str): Activation type, if it is set to None, activation is not appended. + act (str, optional): Activation type, if it is set to None, activation is not appended. Default: None. - name(str|None): A name for this layer(optional). If set None, the layer + name(str, optional): A name for this layer(optional). If set None, the layer will be named automatically. + data_format(str, optional):The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. + When it is `"NCHW"`, the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`. + Default: 'NCDHW'. Returns: - Variable: The tensor variable storing the convolution transpose result. + A 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or + (num_batches, out_d, out_h, out_w, channels). Raises: ValueError: If the shapes of input, filter_size, stride, padding and @@ -4732,35 +4808,89 @@ def conv3d_transpose(input, conv3d_transpose = fluid.layers.conv3d_transpose(input=data, num_filters=2, filter_size=3) """ assert param_attr is not False, "param_attr should not be False in conv3d_transpose." + if data_format not in ['NCDHW', 'NDHWC']: + raise ValueError( + "Param(data_format) of Op(fluid.layers.conv3d_transpose) got wrong value: received " + + data_format + " but only NCDHW or NDHWC supported.") l_type = "conv3d_transpose" helper = LayerHelper(l_type, **locals()) if not isinstance(input, Variable): raise TypeError("Input of conv3d_transpose must be Variable") - input_channel = input.shape[1] + input_channel = input.shape[1] if data_format == 'NCDHW' else input.shape[ + -1] - padding = utils.convert_to_list(padding, 3, 'padding') stride = utils.convert_to_list(stride, 3, 'stride') dilation = utils.convert_to_list(dilation, 3, 'dilation') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") + def _update_padding(padding, data_format): + def is_list_or_tuple(ele): + if isinstance(ele, list) or isinstance(ele, tuple): + return True + return False + + if is_list_or_tuple(padding) and len(padding) == 5: + if is_list_or_tuple(padding[0]) and (data_format == "NCDHW"): + if not (padding[0] == [0, 0] and padding[1] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[2:5] + padding = [ele for a_list in padding for ele in a_list] + elif is_list_or_tuple(padding[0]) and (data_format == "NDHWC"): + if not (padding[0] == [0, 0] and padding[4] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[1:4] + padding = [ele for a_list in padding for ele in a_list] + padding = utils.convert_to_list(padding, 6, 'padding') + + elif is_list_or_tuple(padding) and len(padding) == 6: + padding = utils.convert_to_list(padding, 6, 'padding') + else: + padding = utils.convert_to_list(padding, 3, 'padding') + padding = [ + padding[0], padding[0], padding[1], padding[1], padding[2], + padding[2] + ] + + return padding + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." % + str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0, 0, 0, 0, 0] + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0, 0, 0, 0, 0] + + padding = _update_padding(padding, data_format) + if filter_size is None: if output_size is None: raise ValueError("output_size must be set when filter_size is None") if isinstance(output_size, int): output_size = [output_size, output_size] - d_in = input.shape[2] - h_in = input.shape[3] - w_in = input.shape[4] + d_in = input.shape[2] if data_format == 'NCDHW' else input.shape[1] + h_in = input.shape[3] if data_format == 'NCDHW' else input.shape[2] + w_in = input.shape[4] if data_format == 'NCDHW' else input.shape[3] - filter_size_d = (output_size[0] - (d_in - 1) * stride[0] + 2 * - padding[0] - 1) // dilation[0] + 1 - filter_size_h = (output_size[1] - (h_in - 1) * stride[1] + 2 * - padding[1] - 1) // dilation[1] + 1 - filter_size_w = (output_size[2] - (w_in - 1) * stride[2] + 2 * - padding[2] - 1) // dilation[2] + 1 + filter_size_d = (output_size[0] - (d_in - 1) * stride[0] + padding[0] + + padding[1] - 1) // dilation[0] + 1 + filter_size_h = (output_size[1] - (h_in - 1) * stride[1] + padding[2] + + padding[3] - 1) // dilation[1] + 1 + filter_size_w = (output_size[2] - (w_in - 1) * stride[2] + padding[4] + + padding[5] - 1) // dilation[2] + 1 filter_size = [filter_size_d, filter_size_h, filter_size_w] else: filter_size = utils.convert_to_list(filter_size, 3, @@ -4771,6 +4901,11 @@ def conv3d_transpose(input, img_filter = helper.create_parameter( dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) + if data_format == 'NCDHW': + data_format = 'NCHW' + if data_format == 'NDHWC': + data_format = 'NHWC' + pre_bias = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op( type=l_type, @@ -4780,9 +4915,11 @@ def conv3d_transpose(input, attrs={ 'strides': stride, 'paddings': padding, + 'padding_algorithm': padding_algorithm, 'dilations': dilation, 'groups': groups, - 'use_cudnn': use_cudnn + 'use_cudnn': use_cudnn, + 'data_format': data_format }) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index 3b820f6ad716e5717e45d0c6341fb89010406d59..08eb559d957a77e823d3a9e37541a53f5ab492f0 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -18,10 +18,19 @@ import unittest import numpy as np import paddle.fluid.core as core +import paddle.fluid as fluid from op_test import OpTest def conv2dtranspose_forward_naive(input_, filter_, attrs): + padding_algorithm = attrs['padding_algorithm'] + if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]: + raise ValueError("Unknown Attr(padding_algorithm): '%s'. " + "It can only be 'SAME' or 'VALID'." % + str(padding_algorithm)) + + if attrs['data_format'] == 'NHWC': + input_ = np.transpose(input_, [0, 3, 1, 2]) in_n, in_c, in_h, in_w = input_.shape f_c, f_out_c, f_h, f_w = filter_.shape groups = attrs['groups'] @@ -31,14 +40,47 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ 'dilations'] + + # update pad and dilation + def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride): + padding = [] + for input_size, filter_size, stride_size in zip( + input_shape, kernel_size, kernel_stride): + out_size = int((input_size + stride_size - 1) / stride_size) + pad_sum = np.max(( + (out_size - 1) * stride_size + filter_size - input_size, 0)) + pad_0 = int(pad_sum / 2) + pad_1 = int(pad_sum - pad_0) + padding.append(pad_0) + padding.append(pad_1) + return padding + + ksize = filter_.shape[2:4] + if padding_algorithm == "VALID": + pad = [0, 0, 0, 0] + elif padding_algorithm == "SAME": + dilation = [1, 1] + input_data_shape = [] + if attrs['data_format'] == "NCHW": + input_data_shape = input_.shape[2:4] + elif attrs['data_format'] == "NHWC": + input_data_shape = input_.shape[1:3] + pad = _get_padding_with_SAME(input_data_shape, ksize, stride) + + pad_h_0, pad_h_1 = pad[0], pad[0] + pad_w_0, pad_w_1 = pad[1], pad[1] + if len(pad) == 4: + pad_h_0, pad_h_1 = pad[0], pad[1] + pad_w_0, pad_w_1 = pad[2], pad[3] + d_bolck_h = dilations[0] * (f_h - 1) + 1 d_bolck_w = dilations[1] * (f_w - 1) + 1 out_h = (in_h - 1) * stride[0] + d_bolck_h out_w = (in_w - 1) * stride[1] + d_bolck_w if 'output_size' in attrs: output_size = attrs['output_size'] - out_h = output_size[0] + 2 * pad[0] - out_w = output_size[1] + 2 * pad[1] + out_h = output_size[0] + pad_h_0 + pad_h_1 + out_w = output_size[1] + pad_w_0 + pad_w_1 out = np.zeros((in_n, out_c, out_h, out_w)) @@ -61,7 +103,9 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): out[n, g * f_out_c + k, i1:i2:dilations[0], j1:j2: dilations[1]] += tmp_out - out = out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]] + out = out[:, :, pad_h_0:out_h - pad_h_1, pad_w_0:out_w - pad_w_1] + if attrs['data_format'] == 'NHWC': + out = np.transpose(out, [0, 2, 3, 1]) return out @@ -72,7 +116,9 @@ class TestConv2dTransposeOp(OpTest): self.use_cudnn = False self.use_mkldnn = False self.output_size = None - self.data_format = "AnyLayout" + self.data_format = "NCHW" + self.pad = [0, 0] + self.padding_algorithm = "EXPLICIT" self.init_op_type() self.init_test_case() @@ -83,6 +129,7 @@ class TestConv2dTransposeOp(OpTest): self.attrs = { 'strides': self.stride, 'paddings': self.pad, + 'padding_algorithm': self.padding_algorithm, 'groups': self.groups, 'dilations': self.dilations, 'use_cudnn': self.use_cudnn, @@ -160,7 +207,7 @@ class TestConv2dTransposeOp(OpTest): self.op_type = "conv2d_transpose" -class TestWithPad(TestConv2dTransposeOp): +class TestWithSymmetricPad(TestConv2dTransposeOp): def init_test_case(self): self.pad = [1, 1] self.stride = [1, 1] @@ -171,6 +218,39 @@ class TestWithPad(TestConv2dTransposeOp): self.filter_size = [f_c, 6, 3, 3] +class TestWithAsymmetricPad(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + +class TestWithSAMEPad(TestConv2dTransposeOp): + def init_test_case(self): + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + self.padding_algorithm = 'SAME' + + +class TestWithVALIDPad(TestConv2dTransposeOp): + def init_test_case(self): + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + self.padding_algorithm = 'VALID' + + class TestWithGroups(TestConv2dTransposeOp): def init_test_case(self): self.pad = [1, 1] @@ -216,6 +296,91 @@ class TestWithEvenUpsample(TestConv2dTransposeOp): self.filter_size = [f_c, 6, 5, 5] +class Test_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithSymmetricPad_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithAsymmetricPad_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithGroups_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 5, 5, 4] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 3, 3, 3] + self.data_format = 'NHWC' + + +class TestWithStride_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NCHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithDilation_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [2, 2] + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithEvenUpsample_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_size = [14, 14] + self.input_size = [2, 7, 7, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 5, 5] + self.data_format = 'NHWC' + + # ------------ test_cudnn ------------ @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") @@ -227,7 +392,7 @@ class TestCUDNN(TestConv2dTransposeOp): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestCUDNNWithPad(TestWithPad): +class TestCUDNNWithSymmetricPad(TestWithSymmetricPad): def init_test_case(self): self.pad = [1, 1] self.stride = [1, 1] @@ -242,6 +407,57 @@ class TestCUDNNWithPad(TestWithPad): self.op_type = "conv2d_transpose" +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithAsymmetricPad(TestWithAsymmetricPad): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithSAMEPad(TestWithSAMEPad): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithVALIDPad(TestWithVALIDPad): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestCUDNNWithStride(TestWithStride): @@ -276,19 +492,6 @@ class TestCUDNNWithGroups(TestWithGroups): self.op_type = "conv2d_transpose" -class TestDepthwiseConvTranspose(TestConv2dTransposeOp): - def init_test_case(self): - self.pad = [1, 1] - self.stride = [2, 2] - self.dilations = [1, 1] - self.input_size = [2, 8, 16, 16] # NCHW - self.groups = 8 - assert np.mod(self.input_size[1], self.groups) == 0 - f_c = self.input_size[1] // self.groups - self.filter_size = [self.input_size[1], f_c, 4, 4] - self.op_type = "depthwise_conv2d_transpose" - - # ------------ test_cudnn ------------ @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") @@ -312,5 +515,334 @@ class TestCUDNNWithEvenUpsample(TestWithEvenUpsample): # def init_op_type(self): # self.op_type = "conv2d_transpose" + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNN_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithSymmetricPad_NHWC(TestWithSymmetricPad): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [1, 1] + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithAsymmetricPad_NHWC(TestWithSymmetricPad): + def init_test_case(self): + self.pad = [1, 0, 2, 3] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithStride_NHWC(TestWithStride): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithGroups_NHWC(TestWithGroups): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 5, 5, 4] # NCHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 3, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithEvenUpsample_NHWC(TestWithEvenUpsample): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_size = [14, 14] + self.input_size = [2, 7, 7, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 5, 5] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + +class TestDepthwiseConvTranspose(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 8, 16, 16] # NCHW + self.groups = 8 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [self.input_size[1], f_c, 4, 4] + self.op_type = "depthwise_conv2d_transpose" + + +class TestDepthwiseConvTransposeAsymmetricPad(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 8, 16, 16] # NCHW + self.groups = 8 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [self.input_size[1], f_c, 3, 3] + self.op_type = "depthwise_conv2d_transpose" + self.data_format = 'NCHW' + + +class TestDepthwiseConvTransposeSAMEPad(TestConv2dTransposeOp): + def init_test_case(self): + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 8, 16, 16] # NHWC + self.groups = 8 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [self.input_size[1], f_c, 3, 3] + self.op_type = "depthwise_conv2d_transpose" + self.padding_algorithm = 'SAME' + + +class TestDepthwiseConvTransposeVALIDPad(TestConv2dTransposeOp): + def init_test_case(self): + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 8, 16, 16] # NHWC + self.groups = 8 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [self.input_size[1], f_c, 3, 3] + self.op_type = "depthwise_conv2d_transpose" + self.padding_algorithm = 'VALID' + + +class TestDepthwiseConvTranspose_NHWC_4x4kernel(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 16, 16, 8] # NHWC + self.groups = 8 + assert np.mod(self.input_size[3], self.groups) == 0 + f_c = self.input_size[3] // self.groups + self.filter_size = [self.input_size[3], f_c, 4, 4] + self.op_type = "depthwise_conv2d_transpose" + self.data_format = 'NHWC' + + +class TestDepthwiseConvTranspose_NHWC_3x3kernel(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 16, 16, 8] # NHWC + self.groups = 8 + assert np.mod(self.input_size[3], self.groups) == 0 + f_c = self.input_size[3] // self.groups + self.filter_size = [self.input_size[3], f_c, 3, 3] + self.op_type = "depthwise_conv2d_transpose" + self.data_format = 'NHWC' + + +class TestDepthwiseConvTransposeAsymmetricPad_NHWC(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 16, 16, 8] # NHWC + self.groups = 8 + assert np.mod(self.input_size[3], self.groups) == 0 + f_c = self.input_size[3] // self.groups + self.filter_size = [self.input_size[3], f_c, 3, 3] + self.op_type = "depthwise_conv2d_transpose" + self.data_format = 'NHWC' + + +class TestConv2dTransposeAPI(OpTest): + def test_case1(self): + data1 = fluid.layers.data( + name='data1', shape=[3, 5, 5], dtype='float32') + data2 = fluid.layers.data( + name='data2', shape=[5, 5, 3], dtype='float32') + out1 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, + filter_size=3, + data_format='NCHW') + out2 = fluid.layers.conv2d_transpose( + input=data2, + groups=1, + num_filters=6, + filter_size=3, + data_format='NHWC') + out3 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, + filter_size=3, + padding=[[0, 0], [1, 1], [1, 1], [0, 0]], + data_format='NHWC') + out4 = fluid.layers.conv2d_transpose( + input=data1, + groups=3, + num_filters=6, + filter_size=3, + padding=[[0, 0], [0, 0], [2, 1], [0, 0]], + data_format='NCHW') + out5 = fluid.layers.conv2d_transpose( + input=data2, + groups=1, + num_filters=6, + filter_size=3, + padding='SAME', + data_format='NCHW') + out6 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, + filter_size=3, + padding='VALID', + data_format='NHWC') + out7 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, + output_size=[7, 7], + padding=[0, 0], + data_format='NHWC') + + data1_np = np.random.random((2, 3, 5, 5)).astype("float32") + data2_np = np.random.random((2, 5, 5, 3)).astype("float32") + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run( + fluid.default_main_program(), + feed={"data1": data1_np, + "data2": data2_np}, + fetch_list=[out1, out2, out3, out4, out5, out6, out7], + return_numpy=True) + self.assertIsNotNone(results[0]) + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[2]) + self.assertIsNotNone(results[3]) + self.assertIsNotNone(results[4]) + self.assertIsNotNone(results[5]) + self.assertIsNotNone(results[6]) + + +class TestConv2dTransposeOpException(OpTest): + def test_exception(self): + data = fluid.layers.data(name='data', shape=[3, 5, 5], dtype="float32") + + def attr_data_format(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + data_format="NCDHW") + + self.assertRaises(ValueError, attr_data_format) + + def attr_padding_str(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding='Vald') + + self.assertRaises(ValueError, attr_padding_str) + + def attr_padding_list(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding=[[1, 1], [1, 1], [0, 0], [0, 0]]) + + self.assertRaises(ValueError, attr_padding_list) + + def attr_padding_with_data_format(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding=[[1, 1], [0, 0], [0, 0], [1, 1]], + data_format='NHWC') + + self.assertRaises(ValueError, attr_padding_with_data_format) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py index 8d9075961cbec32bc34fcf0c92cfbb7e6c00d886..f90ca27c09e875ac08bf68d244df5a15773ef3f9 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py @@ -18,10 +18,19 @@ import unittest import numpy as np import paddle.fluid.core as core +import paddle.fluid as fluid from op_test import OpTest def conv3dtranspose_forward_naive(input_, filter_, attrs): + padding_algorithm = attrs['padding_algorithm'] + if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]: + raise ValueError("Unknown Attr(padding_algorithm): '%s'. " + "It can only be 'SAME' or 'VALID'." % + str(padding_algorithm)) + + if attrs['data_format'] == 'NHWC': + input_ = np.transpose(input_, [0, 4, 1, 2, 3]) in_n, in_c, in_d, in_h, in_w = input_.shape f_c, f_out_c, f_d, f_h, f_w = filter_.shape groups = attrs['groups'] @@ -32,6 +41,39 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs): stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ 'dilations'] + def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride): + padding = [] + for input_size, filter_size, stride_size in zip( + input_shape, kernel_size, kernel_stride): + out_size = int((input_size + stride_size - 1) / stride_size) + pad_sum = np.max(( + (out_size - 1) * stride_size + filter_size - input_size, 0)) + pad_0 = int(pad_sum / 2) + pad_1 = int(pad_sum - pad_0) + padding.append(pad_0) + padding.append(pad_1) + return padding + + ksize = filter_.shape[2:5] + if padding_algorithm == "VALID": + pad = [0, 0, 0, 0, 0, 0] + elif padding_algorithm == "SAME": + dilation = [1, 1, 1] + input_data_shape = [] + if attrs['data_format'] == "NCHW": + input_data_shape = input_.shape[2:5] + elif attrs['data_format'] == "NHWC": + input_data_shape = input_.shape[1:4] + pad = _get_padding_with_SAME(input_data_shape, ksize, stride) + + pad_d_0, pad_d_1 = pad[0], pad[0] + pad_h_0, pad_h_1 = pad[1], pad[1] + pad_w_0, pad_w_1 = pad[2], pad[2] + if len(pad) == 6: + pad_d_0, pad_d_1 = pad[0], pad[1] + pad_h_0, pad_h_1 = pad[2], pad[3] + pad_w_0, pad_w_1 = pad[4], pad[5] + d_bolck_d = dilations[0] * (f_d - 1) + 1 d_bolck_h = dilations[1] * (f_h - 1) + 1 d_bolck_w = dilations[2] * (f_w - 1) + 1 @@ -62,8 +104,10 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs): out[n, g * f_out_c + k, d1:d2:dilations[0], i1:i2: dilations[1], j1:j2:dilations[2]] += tmp_out - out = out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1], pad[2]:out_w - - pad[2]] + out = out[:, :, pad_d_0:out_d - pad_d_1, pad_h_0:out_h - pad_h_1, pad_w_0: + out_w - pad_w_1] + if attrs['data_format'] == 'NHWC': + out = np.transpose(out, [0, 2, 3, 4, 1]) return out @@ -71,6 +115,9 @@ class TestConv3dTransposeOp(OpTest): def setUp(self): # init as conv transpose self.use_cudnn = False + self.data_format = 'NCHW' + self.pad = [0, 0, 0] + self.padding_algorithm = "EXPLICIT" self.init_op_type() self.init_test_case() @@ -81,10 +128,11 @@ class TestConv3dTransposeOp(OpTest): self.attrs = { 'strides': self.stride, 'paddings': self.pad, + 'padding_algorithm': self.padding_algorithm, 'dilations': self.dilations, 'groups': self.groups, 'use_cudnn': self.use_cudnn, - 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter + 'data_format': self.data_format } output = conv3dtranspose_forward_naive(input_, filter_, @@ -154,7 +202,7 @@ class TestConv3dTransposeOp(OpTest): self.op_type = "conv3d_transpose" -class TestWithPad(TestConv3dTransposeOp): +class TestWithSymmetricPad(TestConv3dTransposeOp): def init_test_case(self): self.pad = [1, 1, 1] self.stride = [1, 1, 1] @@ -165,6 +213,39 @@ class TestWithPad(TestConv3dTransposeOp): self.filter_size = [f_c, 6, 3, 3, 3] +class TestWithAsymmetricPad(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 0, 1, 2] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + + +class TestWithSAMEPad(TestConv3dTransposeOp): + def init_test_case(self): + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.padding_algorithm = 'SAME' + + +class TestWithVALIDPad(TestConv3dTransposeOp): + def init_test_case(self): + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.padding_algorithm = 'VALID' + + class TestWithGroups(TestConv3dTransposeOp): def init_test_case(self): self.pad = [1, 1, 1] @@ -198,6 +279,78 @@ class TestWithDilation(TestConv3dTransposeOp): self.filter_size = [f_c, 6, 3, 3, 3] +class Test_NHWC(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NDHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + +class TestWithSymmetricPad_NHWC(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NDHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + +class TestWithAsymmetricPad_NHWC(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 0, 1, 2] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NDHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + +class TestWithGroups_NHWC(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 2 + self.input_size = [2, 5, 5, 5, 4] # NDHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 3, 3, 3, 3] + self.data_format = 'NHWC' + + +class TestWithStride_NHWC(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [2, 2, 2] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NCDHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + +class TestWithDilation_NHWC(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [2, 2, 2] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NCDHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + # ------------ test_cudnn ------------ @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") @@ -209,7 +362,7 @@ class TestCUDNN(TestConv3dTransposeOp): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestCUDNNWithPad(TestWithPad): +class TestCUDNNWithSymmetricPad(TestWithSymmetricPad): def init_test_case(self): self.pad = [1, 1, 1] self.stride = [1, 1, 1] @@ -224,6 +377,57 @@ class TestCUDNNWithPad(TestWithPad): self.op_type = "conv3d_transpose" +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithAsymmetricPad(TestWithAsymmetricPad): + def init_test_case(self): + self.pad = [1, 1, 1, 0, 0, 2] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 3, 4, 4, 4] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithSAMEPad(TestWithSAMEPad): + def init_test_case(self): + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.padding_algorithm = 'SAME' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithVALIDPad(TestWithVALIDPad): + def init_test_case(self): + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.padding_algorithm = 'VALID' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestCUDNNWithStride(TestWithStride): @@ -272,5 +476,222 @@ class TestCUDNNWithGroups(TestWithGroups): # def init_op_type(self): # self.op_type = "conv3d_transpose" + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNN_NHWC(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NDHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithSymmetricPad_NHWC(TestWithSymmetricPad): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NDHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithAsymmetricPad_NHWC(TestWithAsymmetricPad): + def init_test_case(self): + self.pad = [1, 0, 1, 0, 0, 2] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NDHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithStride_NHWC(TestWithStride): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [2, 2, 2] + self.dilations = [1, 1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 5, 3] # NCDHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNWithGroups_NHWC(TestWithGroups): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 2 + self.input_size = [2, 5, 5, 5, 4] # NCHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 3, 3, 3, 3] + self.data_format = 'NHWC' + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + +class TestConv3dTransposeAPI(OpTest): + def test_case1(self): + data1 = fluid.layers.data( + name='data1', shape=[3, 5, 5, 5], dtype='float32') + data2 = fluid.layers.data( + name='data2', shape=[5, 5, 5, 3], dtype='float32') + + out1 = fluid.layers.conv3d_transpose( + input=data1, + groups=1, + num_filters=6, + filter_size=3, + data_format='NCDHW') + out2 = fluid.layers.conv3d_transpose( + input=data2, + groups=1, + num_filters=6, + filter_size=3, + data_format='NDHWC') + out3 = fluid.layers.conv3d_transpose( + input=data1, + groups=1, + num_filters=6, + filter_size=3, + padding=[[0, 0], [0, 0], [1, 1], [0, 0], [1, 1]], + data_format='NCDHW') + out4 = fluid.layers.conv3d_transpose( + input=data2, + groups=3, + num_filters=6, + filter_size=3, + padding=[[0, 0], [0, 0], [1, 1], [1, 2], [0, 0]], + data_format='NDHWC') + out5 = fluid.layers.conv3d_transpose( + input=data2, + groups=1, + num_filters=6, + filter_size=3, + padding='SAME', + data_format='NCDHW') + out6 = fluid.layers.conv3d_transpose( + input=data2, + groups=1, + num_filters=6, + filter_size=3, + padding='VALID', + data_format='NDHWC') + out7 = fluid.layers.conv3d_transpose( + input=data2, + groups=1, + num_filters=6, + output_size=[7, 7, 7], + padding=[0, 0, 0], + data_format='NDHWC') + + data1_np = np.random.random((2, 3, 5, 5, 5)).astype("float32") + data2_np = np.random.random((2, 5, 5, 5, 3)).astype("float32") + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run( + fluid.default_main_program(), + feed={"data1": data1_np, + "data2": data2_np}, + fetch_list=[out1, out2, out3, out4, out5, out6, out7], + return_numpy=True) + self.assertIsNotNone(results[0]) + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[2]) + self.assertIsNotNone(results[3]) + self.assertIsNotNone(results[4]) + self.assertIsNotNone(results[5]) + self.assertIsNotNone(results[6]) + + +class TestConv3dTransposeOpException(OpTest): + def test_exception(self): + data = fluid.layers.data( + name='data', shape=[3, 5, 5, 5], dtype="float32") + + def attr_data_format(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + data_format="NCDW") + + self.assertRaises(ValueError, attr_data_format) + + def attr_padding_str(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding='Vald') + + self.assertRaises(ValueError, attr_padding_str) + + def attr_padding_list(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding=[[1, 1], [1, 1], [0, 0], [0, 0], [1, 1]]) + + self.assertRaises(ValueError, attr_padding_list) + + def attr_padding_with_data_format(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding=[[1, 1], [0, 0], [0, 0], [1, 0], [1, 1]], + data_format='NDHWC') + + self.assertRaises(ValueError, attr_padding_with_data_format) + + if __name__ == '__main__': unittest.main()