From 3aa331d97e470d6b5c4669362cc24fec25d1dd66 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Sun, 29 Sep 2019 19:39:14 +0800 Subject: [PATCH] fix conv2d and conv3d: (#20042) 1.support asymmetric padding; 2.support padding algorithm:"SAME" and "VALID"; 3.support channel_last: data_format NHWC and NDHWC; 4.change doc of python API and c++; test=develop, test=document_preview --- cmake/operators.cmake | 14 +- paddle/fluid/API.spec | 4 +- paddle/fluid/operators/conv_cudnn_op.cu | 1063 +++++++++++++++++ paddle/fluid/operators/conv_cudnn_op.cu.cc | 523 -------- paddle/fluid/operators/conv_op.cc | 141 ++- paddle/fluid/operators/conv_op.h | 607 ++++++++-- paddle/fluid/operators/math/im2col.cc | 26 +- paddle/fluid/operators/math/vol2col.cc | 55 +- paddle/fluid/operators/math/vol2col.cu | 49 +- python/paddle/fluid/layers/nn.py | 200 +++- .../fluid/tests/unittests/test_conv2d_op.py | 974 ++++++++++++++- .../fluid/tests/unittests/test_conv3d_op.py | 626 +++++++++- .../tests/unittests/test_conv_nn_grad.py | 358 +++++- 13 files changed, 3801 insertions(+), 839 deletions(-) create mode 100644 paddle/fluid/operators/conv_cudnn_op.cu delete mode 100644 paddle/fluid/operators/conv_cudnn_op.cu.cc diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 28e880fb51e..0f675b68d27 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -9,6 +9,7 @@ function(op_library TARGET) set(miopen_hip_cc_srcs) set(cu_cc_srcs) set(cudnn_cu_cc_srcs) + set(cudnn_cu_srcs) set(CUDNN_FILE) set(mkldnn_cc_srcs) set(MKLDNN_FILE) @@ -44,6 +45,9 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc) list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc) endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu) + list(APPEND cudnn_cu_srcs ${CUDNN_FILE}.cu) + endif() if(WITH_AMD_GPU) string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}") if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc) @@ -60,6 +64,8 @@ function(op_library TARGET) foreach(src ${op_library_SRCS}) if (${src} MATCHES ".*\\.hip.cu$") list(APPEND hip_cu_srcs ${src}) + elseif(${src} MATCHES ".*_cudnn_op.cu$") + list(APPEND cudnn_cu_srcs ${src}) elseif (${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") @@ -97,7 +103,7 @@ function(op_library TARGET) set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) endif() if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cudnn_cu_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) elseif (WITH_AMD_GPU) hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} @@ -160,6 +166,12 @@ function(op_library TARGET) endif() endif() + # pybind USE_OP_DEVICE_KERNEL for CUDNN + list(LENGTH cudnn_cu_srcs cudnn_cu_srcs_len) + if (WITH_GPU AND ${cudnn_cu_srcs_len} GREATER 0) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") + endif() + # pybind USE_OP_DEVICE_KERNEL for MIOPEN if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n") diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 1b8ad725000..171df9cff69 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -140,8 +140,8 @@ paddle.fluid.layers.bpr_loss (ArgSpec(args=['input', 'label', 'name'], varargs=N paddle.fluid.layers.square_error_cost (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'bbb9e708bab250359864fefbdf48e9d9')) paddle.fluid.layers.chunk_eval (ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types', 'seq_length'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'b02844e0ad4bd713c5fe6802aa13219c')) paddle.fluid.layers.sequence_conv (ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'padding_start', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, True, None, None, None, None, None)), ('document', '2bf23e7884c380c3b27f2709aa322cb9')) -paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '06de9adb5994f6f8cb806c75b55550af')) -paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '71b09227709475fa178c1739dff64af6')) +paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None, 'NCHW')), ('document', 'b8da17862ba02b5297a37d2edd571d76')) +paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None, 'NCDHW')), ('document', '73a15322d460ef9aa90d4d237b0bc5d5')) paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test', 'pad_value'], varargs=None, keywords=None, defaults=(False, 0.0)), ('document', 'e90a93251c52dc4e6fb34fb3991b3f82')) paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'eaa9d0bbd3d4e017c8bc4ecdac483711')) paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', 'cee673c79e3ff4582656a24e04f841e5')) diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu new file mode 100644 index 00000000000..b38461f8cc9 --- /dev/null +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -0,0 +1,1063 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the spopecific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/conv_cudnn_helper.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" +#include "paddle/fluid/operators/conv_op.h" +#include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/cudnn_workspace_helper.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/profiler.h" + +DECLARE_bool(cudnn_deterministic); +DECLARE_uint64(conv_workspace_size_limit); +DECLARE_bool(cudnn_exhaustive_search); + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using DataLayout = platform::DataLayout; +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; +using framework::AlgorithmsCache; + +static inline void GetNCDHW(const framework::DDim& dims, + const DataLayout& layout, int* N, int* C, int* D, + int* H, int* W) { + *N = dims[0]; + *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; + int i = layout == DataLayout::kNCHW ? 0 : 1; + if (dims.size() == 5) { + *D = dims[2 - i]; + *H = dims[3 - i]; + *W = dims[4 - i]; + } else { + *D = 1; + *H = dims[2 - i]; + *W = dims[3 - i]; + } +} + +static inline bool IsSymmetricPadding(const std::vector& paddings, + const int data_dim) { + bool is_sys_pad = true; + if (paddings.size() == data_dim * 2) { + for (size_t i = 0; i < data_dim; ++i) { + if (paddings[2 * i] != paddings[2 * i + 1]) { + is_sys_pad = false; + return is_sys_pad; + } + } + } + return is_sys_pad; +} + +template +using EigenTensor = framework::EigenTensor; + +template +static void PadFunction(const framework::ExecutionContext& context, + const std::vector& pads, + const framework::Tensor& src, T pad_value, + framework::Tensor* out) { + Eigen::array, D> paddings; + + for (size_t i = 0; i < paddings.size(); ++i) { + paddings[i].first = pads[i * 2]; + paddings[i].second = pads[i * 2 + 1]; + } + + auto src_tensor = EigenTensor::From(src); + auto out_tensor = EigenTensor::From(*out); + + auto& place = + *context.template device_context().eigen_device(); + out_tensor.device(place) = src_tensor.pad(paddings, pad_value); +} + +template +static void Slice_2(const framework::ExecutionContext& context, + const Tensor* input, Tensor* out, + const std::vector& starts, + const std::vector& axes) { + auto& place = + *context.template device_context().eigen_device(); + auto in_dims = input->dims(); + auto new_out_dims = out->dims(); + auto offsets = Eigen::array(); + auto extents = Eigen::array(); + for (size_t i = 0; i < D; ++i) { + offsets[i] = 0; + extents[i] = new_out_dims[i]; + } + + int start; + for (size_t i = 0; i < axes.size(); ++i) { + start = starts[i]; + if (start < 0) { + start = (start + in_dims[axes[i]]); + } + start = std::max(start, 0); + offsets[axes[i]] = start; + } + auto in_t = + framework::EigenTensor::From( + *input); + + auto out_t = + framework::EigenTensor::From( + *out, new_out_dims); + out_t.device(place) = in_t.slice(offsets, extents); +} + +template +class CUDNNConvOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + "It must use CUDAPlace."); + const Tensor* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + output->mutable_data(ctx.GetPlace()); + const std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); + + if (exhaustive_search && FLAGS_cudnn_deterministic) { + PADDLE_THROW( + "Cann't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time."); + } + const std::string padding_algorithm = + ctx.Attr("padding_algorithm"); + const std::string data_format = ctx.Attr("data_format"); + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + // ------------ transformed tensor ----------- + Tensor transformed_input_channel(input->type()); + Tensor transformed_output(output->type()); + T* output_data = nullptr; + if (channel_last) { + ResizeToChannelFirst( + ctx, input, &transformed_input_channel); + TransToChannelFirst( + ctx, input, &transformed_input_channel); + + ResizeToChannelFirst(ctx, output, + &transformed_output); + + } else { + transformed_input_channel = *input; + transformed_output = *output; + } + output_data = transformed_output.data(); + + // update padding and dilation + auto in_dims = transformed_input_channel.dims(); + auto filter_dims = filter->dims(); + framework::DDim in_data_dims; + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = IsSymmetricPadding(paddings, data_dim); + + Tensor transformed_input; + std::vector padding_common(data_dim, 0); + if (!is_sys_pad) { + std::vector padding_diff(data_dim); + std::vector new_input_shape_vec(data_dim + 2); + new_input_shape_vec[0] = transformed_input_channel.dims()[0]; + new_input_shape_vec[1] = transformed_input_channel.dims()[1]; + + std::vector input_pad(transformed_input_channel.dims().size() * 2, + 0); + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + new_input_shape_vec[i + 2] = + transformed_input_channel.dims()[i + 2] + padding_diff[i]; + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + framework::DDim new_input_shape( + framework::make_ddim(new_input_shape_vec)); + transformed_input.Resize(new_input_shape); + auto& dev_ctx = + ctx.template device_context(); + + transformed_input = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + const int rank = transformed_input_channel.dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + PadFunction( + ctx, input_pad, transformed_input_channel, pad_value, + &transformed_input); + } break; + case 5: { + PadFunction( + ctx, input_pad, transformed_input_channel, pad_value, + &transformed_input); + } break; + default: + PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); + } + + } else { + transformed_input = transformed_input_channel; + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + const T* input_data = transformed_input.data(); + const T* filter_data = filter->data(); + + // ------------------- cudnn descriptors --------------------- + ConvArgs args{&transformed_input, filter, &transformed_output, strides, + padding_common, dilations}; + + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + auto dtype = platform::CudnnDataType::type; + DataLayout layout = DataLayout::kNCHW; + if (transformed_input_channel.dims().size() == 5) { + layout = DataLayout::kNCDHW; + } + auto layout_format = GetCudnnTensorFormat(layout); + + args.handle = handle; + args.cdesc.set(dtype, padding_common, strides, dilations); + +#if CUDNN_VERSION_MIN(7, 0, 1) + // cudnn 7 can support groups, no need to do it manually + // FIXME(typhoonzero): find a better way to disable groups + // rather than setting it to 1. + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount( + args.cdesc.desc(), groups)); + groups = 1; +#endif + args.idesc.set(transformed_input, groups); + + args.wdesc.set(*filter, layout_format, groups); + args.odesc.set(transformed_output, groups); + int i_n, i_c, i_d, i_h, i_w; + + GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, + &i_h, &i_w); + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(transformed_output.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, + &o_h, &o_w); + + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; + int group_offset_filter = filter->numel() / groups; + // ------------------- cudnn conv workspace --------------------- + size_t workspace_size = 0; // final workspace to allocate. + // ------------------- cudnn conv algorithm --------------------- + cudnnConvolutionFwdAlgo_t algo{}; + + using search = SearchAlgorithm; + algo = search::Find(args, exhaustive_search, false, 0, ctx); + workspace_size = search::GetWorkspaceSize(args, algo); + + // ------------------- cudnn conv forward --------------------- + ScalingParamType alpha = 1.0f, beta = 0.0f; + for (int i = 0; i < groups; i++) { + workspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, args.idesc.desc(), + input_data + i * group_offset_in, args.wdesc.desc(), + filter_data + i * group_offset_filter, args.cdesc.desc(), algo, + workspace_ptr, workspace_size, &beta, args.odesc.desc(), + output_data + i * group_offset_out)); + }, + workspace_size); + } + + if (channel_last) { + TransToChannelLast( + ctx, &transformed_output, output); + } + } +}; + +template +class CUDNNConvGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + "It must use CUDAPlace."); + auto input = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto filter_grad = ctx.Output(framework::GradVarName("Filter")); + + const T* filter_data = filter->data(); + if (input_grad) { + input_grad->mutable_data(ctx.GetPlace()); + } + if (filter_grad) { + filter_grad->mutable_data(ctx.GetPlace()); + } + + std::vector dilations = ctx.Attr>("dilations"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::string padding_algorithm = ctx.Attr("padding_algorithm"); + int groups = ctx.Attr("groups"); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); + bool deterministic = FLAGS_cudnn_deterministic; + if (exhaustive_search && deterministic) { + PADDLE_THROW( + "Can't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time."); + } + const std::string data_format = ctx.Attr("data_format"); + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + // transform Tensor + Tensor transformed_input_channel(input->type()); + Tensor transformed_output_grad_channel(output_grad->type()); + Tensor transformed_input_grad_channel(input->type()); + + if (channel_last) { + ResizeToChannelFirst( + ctx, input, &transformed_input_channel); + TransToChannelFirst( + ctx, input, &transformed_input_channel); + + ResizeToChannelFirst( + ctx, output_grad, &transformed_output_grad_channel); + TransToChannelFirst( + ctx, output_grad, &transformed_output_grad_channel); + + if (input_grad) { + ResizeToChannelFirst( + ctx, input_grad, &transformed_input_grad_channel); + } + + } else { + transformed_input_channel = *input; + transformed_output_grad_channel = *output_grad; + if (input_grad) { + transformed_input_grad_channel.ShareDataWith(*input_grad); + } + } + + // update paddings + auto in_dims = transformed_input_channel.dims(); + auto filter_dims = filter->dims(); + framework::DDim in_data_dims; + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + // cuDNN only supports padding the same amount on every dimension. + // So we create a new padded input tensor. + int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = IsSymmetricPadding(paddings, data_dim); + Tensor transformed_input(input->type()); + Tensor transformed_input_grad(input->type()); + std::vector padding_common(data_dim, 0); + std::vector input_pad(transformed_input_channel.dims().size() * 2, 0); + + if (!is_sys_pad) { + // get pad + std::vector padding_diff(data_dim); + std::vector new_input_shape_vec(data_dim + 2); + new_input_shape_vec[0] = transformed_input_channel.dims()[0]; + new_input_shape_vec[1] = transformed_input_channel.dims()[1]; + + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + new_input_shape_vec[i + 2] = + transformed_input_channel.dims()[i + 2] + padding_diff[i]; + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + framework::DDim new_input_shape( + framework::make_ddim(new_input_shape_vec)); + transformed_input.Resize(new_input_shape); + + transformed_input_grad.Resize(new_input_shape); + auto& dev_ctx = + ctx.template device_context(); + + transformed_input = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + if (input_grad) { + transformed_input_grad = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + } + // pad for input + const int rank = transformed_input_channel.dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + PadFunction( + ctx, input_pad, transformed_input_channel, pad_value, + &transformed_input); + } break; + case 5: { + PadFunction( + ctx, input_pad, transformed_input_channel, pad_value, + &transformed_input); + } break; + default: + PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); + } + } else { + transformed_input.ShareDataWith(transformed_input_channel); + if (input_grad) { + transformed_input_grad.ShareDataWith(transformed_input_grad_channel); + } + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + const T* input_data = transformed_input.data(); + const T* output_grad_data = transformed_output_grad_channel.data(); + T* filter_grad_data = nullptr; + T* input_grad_data = nullptr; + T* transformed_input_grad_data = nullptr; + + ConvArgs args1{&transformed_input_grad, + filter, + &transformed_output_grad_channel, + strides, + padding_common, + dilations}; + ConvArgs args2{&transformed_input, + filter_grad, + &transformed_output_grad_channel, + strides, + padding_common, + dilations}; + + auto handle = dev_ctx.cudnn_handle(); + auto dtype = platform::CudnnDataType::type; + DataLayout layout = DataLayout::kNCHW; + if (input->dims().size() == 5) { + layout = DataLayout::kNCDHW; + } + auto layout_tensor = GetCudnnTensorFormat(layout); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + int i_n, i_c, i_d, i_h, i_w; + GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, + &i_h, &i_w); + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(transformed_output_grad_channel.dims(), DataLayout::kNCHW, &o_n, + &o_c, &o_d, &o_h, &o_w); + + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; + int group_offset_filter = filter->numel() / groups; + // ------------------- cudnn backward algorithm --------------------- + cudnnConvolutionBwdDataAlgo_t data_algo = + static_cast(0); + cudnnConvolutionBwdFilterAlgo_t filter_algo = + static_cast(0); + size_t workspace_size = 0; + int iwo_groups, c_groups; + +#if CUDNN_VERSION_MIN(7, 0, 1) + iwo_groups = 1; + c_groups = groups; + groups = 1; +#endif + + if (input_grad) { + // ------------------- cudnn descriptors --------------------- + input_grad_data = input_grad->data(); + transformed_input_grad_data = transformed_input_grad.data(); + args1.handle = handle; + args1.idesc.set(transformed_input_grad, iwo_groups); + args1.wdesc.set(*filter, layout_tensor, iwo_groups); + args1.odesc.set(transformed_output_grad_channel, iwo_groups); + args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups); + + using search1 = SearchAlgorithm; + data_algo = + search1::Find(args1, exhaustive_search, deterministic, 0, ctx); + workspace_size = + std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); + } + + if (filter_grad) { + // ------------------- cudnn descriptors --------------------- + filter_grad_data = filter_grad->data(); + args2.handle = handle; + args2.idesc.set(transformed_input, iwo_groups); + args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups); + args2.odesc.set(transformed_output_grad_channel, iwo_groups); + args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups); + + using search2 = SearchAlgorithm; + filter_algo = + search2::Find(args2, exhaustive_search, deterministic, 1, ctx); + workspace_size = std::max(workspace_size, + search2::GetWorkspaceSize(args2, filter_algo)); + } + + // ------------------- cudnn conv backward data --------------------- + ScalingParamType alpha = 1.0f, beta = 0.0f; + if (input_grad) { + // Because beta is zero, it is unnecessary to reset input_grad. + for (int i = 0; i < groups; i++) { + workspace_handle.RunFunc( + [&](void* cudnn_workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, args1.wdesc.desc(), + filter_data + i * group_offset_filter, args1.odesc.desc(), + output_grad_data + i * group_offset_out, args1.cdesc.desc(), + data_algo, cudnn_workspace_ptr, workspace_size, &beta, + args1.idesc.desc(), + transformed_input_grad_data + i * group_offset_in)); + }, + workspace_size); + } + + std::vector starts(transformed_input_channel.dims().size(), 0); + std::vector axes(transformed_input_channel.dims().size(), 0); + + for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) { + starts[i] = input_pad[2 * i]; + axes[i] = i; + } + + transformed_input_grad_channel.mutable_data(ctx.GetPlace()); + if (transformed_input_channel.dims().size() == 4) { + Slice_2( + ctx, &transformed_input_grad, &transformed_input_grad_channel, + starts, axes); + } else { + Slice_2( + ctx, &transformed_input_grad, &transformed_input_grad_channel, + starts, axes); + } + + if (channel_last) { + TransToChannelLast( + ctx, &transformed_input_grad_channel, input_grad); + } + } + // ------------------- cudnn conv backward filter --------------------- + if (filter_grad) { + // Because beta is zero, it is unnecessary to reset filter_grad. + for (int i = 0; i < groups; i++) { + workspace_handle.RunFunc( + [&](void* cudnn_workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, args2.idesc.desc(), + input_data + i * group_offset_in, args2.odesc.desc(), + output_grad_data + i * group_offset_out, args2.cdesc.desc(), + filter_algo, cudnn_workspace_ptr, workspace_size, &beta, + args2.wdesc.desc(), + filter_grad_data + i * group_offset_filter)); + }, + workspace_size); + } + } + } +}; + +/* + * Inputs: I, W, dO, ddI, ddW + * Outputs: ddO, dW, dI + * ddo = conv(ddI, W) + conv(I, ddW) + * dW = conv_bp_filter(ddI, dO) + * dI = conv_bp_data(ddW, dO) + */ +template +class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + "It must use CUDAPlace."); + auto X = ctx.Input("Input"); + auto W = ctx.Input("Filter"); + auto dO = ctx.Input("DOutput"); + auto ddX = ctx.Input("DDInput"); + auto ddW = ctx.Input("DDFilter"); + + auto ddO = ctx.Output("DDOutput"); + auto dW = ctx.Output("DFilter"); + auto dX = ctx.Output("DInput"); + if (ddO) { + ddO->mutable_data(ctx.GetPlace()); + } + if (dW) { + dW->mutable_data(ctx.GetPlace()); + } + if (dX) { + dX->mutable_data(ctx.GetPlace()); + } + + // const T* x = X->data(); + const T* dy = dO->data(); + const T* w = W->data(); + + const T* ddx = nullptr; + const T* ddw = nullptr; + T *dw, *dx, *ddy; + dw = dx = ddy = nullptr; + T* transformed_dx = nullptr; + const std::vector& strides = ctx.Attr>("strides"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); + bool deterministic = FLAGS_cudnn_deterministic; + if (exhaustive_search && deterministic) { + PADDLE_THROW( + "Can't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time."); + } + std::vector paddings = ctx.Attr>("paddings"); + + std::string padding_algorithm = ctx.Attr("padding_algorithm"); + const std::string data_format = ctx.Attr("data_format"); + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + // transform Tensors to channel first----------- + Tensor transformed_X_channel(X->type()); + Tensor transformed_dO_channel(dO->type()); + Tensor transformed_ddX_channel(ddX->type()); + + Tensor transformed_ddO_channel(dO->type()); + Tensor transformed_dX_channel(X->type()); + + if (channel_last) { + ResizeToChannelFirst( + ctx, X, &transformed_X_channel); + TransToChannelFirst( + ctx, X, &transformed_X_channel); + + ResizeToChannelFirst( + ctx, dO, &transformed_dO_channel); + TransToChannelFirst( + ctx, dO, &transformed_dO_channel); + + ResizeToChannelFirst( + ctx, ddX, &transformed_ddX_channel); + TransToChannelFirst( + ctx, ddX, &transformed_ddX_channel); + + if (ddO) { + ResizeToChannelFirst( + ctx, ddO, &transformed_ddO_channel); + } + if (dX) { + ResizeToChannelFirst( + ctx, dX, &transformed_dX_channel); + transformed_dX_channel.mutable_data(ctx.GetPlace()); + } + + } else { + transformed_X_channel = *X; + transformed_dO_channel = *dO; + transformed_ddX_channel = *ddX; + if (ddO) { + transformed_ddO_channel.ShareDataWith(*ddO); + } + if (dX) { + transformed_dX_channel.ShareDataWith(*dX); + } + } + + auto in_dims = transformed_X_channel.dims(); + auto filter_dims = W->dims(); + framework::DDim in_data_dims = + framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = IsSymmetricPadding(paddings, data_dim); + Tensor transformed_X(X->type()); + Tensor transformed_ddX(X->type()); + + Tensor transformed_dX(X->type()); + + std::vector padding_common(data_dim, 0); + std::vector input_pad(X->dims().size() * 2, 0); + + if (!is_sys_pad) { + // get pad + std::vector padding_diff(data_dim); + std::vector new_input_shape_vec(data_dim + 2); + new_input_shape_vec[0] = transformed_X_channel.dims()[0]; + new_input_shape_vec[1] = transformed_X_channel.dims()[1]; + + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + new_input_shape_vec[i + 2] = + transformed_X_channel.dims()[i + 2] + padding_diff[i]; + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; + } + framework::DDim new_input_shape( + framework::make_ddim(new_input_shape_vec)); + transformed_X.Resize(new_input_shape); + transformed_ddX.Resize(new_input_shape); + transformed_dX.Resize(new_input_shape); + auto& dev_ctx = + ctx.template device_context(); + + transformed_X = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + transformed_ddX = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + if (dX) { + transformed_dX = + ctx.AllocateTmpTensor( + new_input_shape, dev_ctx); + } + + // pad for input + const int rank = X->dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + PadFunction( + ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); + PadFunction( + ctx, input_pad, transformed_ddX_channel, pad_value, + &transformed_ddX); + } break; + case 5: { + PadFunction( + ctx, input_pad, transformed_X_channel, pad_value, &transformed_X); + PadFunction( + ctx, input_pad, transformed_ddX_channel, pad_value, + &transformed_ddX); + } break; + default: + PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions."); + } + + } else { + transformed_X.ShareDataWith(transformed_X_channel); + transformed_ddX.ShareDataWith(transformed_ddX_channel); + if (dX) { + transformed_dX.ShareDataWith(transformed_dX_channel); + } + + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + const T* x = transformed_X.data(); + + int iwo_group = groups; + int c_group = 1; +#if CUDNN_VERSION_MIN(7, 0, 1) + iwo_group = 1; + c_group = groups; +#endif + auto dtype = platform::CudnnDataType::type; + + auto handle = dev_ctx.cudnn_handle(); + + ConvArgs args1{&transformed_ddX, W, + &transformed_ddO_channel, strides, + padding_common, dilations}; + ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides, + padding_common, dilations}; + ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides, + padding_common, dilations}; + ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides, + padding_common, dilations}; + + cudnnConvolutionFwdAlgo_t fwd_algo1 = + static_cast(0); + cudnnConvolutionFwdAlgo_t fwd_algo2 = + static_cast(0); + cudnnConvolutionBwdDataAlgo_t data_algo = + static_cast(0); + cudnnConvolutionBwdFilterAlgo_t filter_algo = + static_cast(0); + + auto layout = GetCudnnTensorFormat(DataLayout::kNCHW); + + // ddo = conv(ddI, W) + conv(I, ddW) + size_t workspace_size = 0; + + T* transformed_ddy_channel = nullptr; + if (ddO) { + ddy = ddO->data(); + transformed_ddy_channel = transformed_ddO_channel.data(); + if (ddX) { + args1.handle = handle; + args1.idesc.set(transformed_ddX, iwo_group); + args1.wdesc.set(*W, layout, iwo_group); + args1.odesc.set(transformed_ddO_channel, iwo_group); + args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); + + using search1 = SearchAlgorithm; + fwd_algo1 = search1::Find(args1, exhaustive_search, false, 0, ctx); + workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); + } + + if (ddW) { + ddw = ddW->data(); + args2.handle = handle; + args2.idesc.set(transformed_X, iwo_group); + + args2.wdesc.set(*ddW, layout, iwo_group); + + args2.odesc.set(transformed_ddO_channel, iwo_group); + args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); + + using search2 = SearchAlgorithm; + fwd_algo2 = search2::Find(args2, exhaustive_search, false, 0, ctx); + workspace_size = std::max(workspace_size, + search2::GetWorkspaceSize(args2, fwd_algo2)); + } + } + + if (dW && ddX) { + dw = dW->data(); + args3.handle = handle; + args3.idesc.set(transformed_ddX, iwo_group); + args3.wdesc.set(*dW, layout, iwo_group); + + args3.odesc.set(transformed_dO_channel, iwo_group); + + args3.cdesc.set(dtype, padding_common, strides, dilations, c_group); + + using search3 = SearchAlgorithm; + filter_algo = + search3::Find(args3, exhaustive_search, deterministic, 1, ctx); + workspace_size = std::max(workspace_size, + search3::GetWorkspaceSize(args3, filter_algo)); + } + + if (ddW && dX) { + transformed_dx = transformed_dX.data(); + + args4.handle = handle; + args4.idesc.set(transformed_dX, iwo_group); + args4.wdesc.set(*ddW, layout, iwo_group); + args4.odesc.set(transformed_dO_channel, iwo_group); + args4.cdesc.set(dtype, padding_common, strides, dilations, c_group); + + using search4 = SearchAlgorithm; + data_algo = + search4::Find(args4, exhaustive_search, deterministic, 2, ctx); + workspace_size = + std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); + } + + int i_n, i_c, i_d, i_h, i_w; + GetNCDHW(transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, + &i_w); + + int o_n, o_c, o_d, o_h, o_w; + GetNCDHW(transformed_dO_channel.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, + &o_h, &o_w); + + int group_offset_in = i_c / groups * i_h * i_w * i_d; + int group_offset_out = o_c / groups * o_h * o_w * o_d; + int group_offset_filter = W->numel() / groups; + + ScalingParamType alpha = 1.0f, beta = 0.0f; + auto wkspace_handle = dev_ctx.cudnn_workspace_handle(); + + if (ddO) { + if (ddX) { + ddx = transformed_ddX.data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, args1.idesc.desc(), + ddx + i * group_offset_in, args1.wdesc.desc(), + w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1, + workspace_ptr, workspace_size, &beta, args1.odesc.desc(), + transformed_ddy_channel + i * group_offset_out)); + }, + workspace_size); + } + } + if (ddW) { + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, args2.idesc.desc(), x + i * group_offset_in, + args2.wdesc.desc(), ddw + i * group_offset_filter, + args2.cdesc.desc(), fwd_algo2, workspace_ptr, + workspace_size, &alpha, args2.odesc.desc(), + transformed_ddy_channel + i * group_offset_out)); + }, + workspace_size); + } + } + if (channel_last) { + TransToChannelLast( + ctx, &transformed_ddO_channel, ddO); + } + } + T* transformed_dy_channel = nullptr; + if (dW && ddX) { + ddx = transformed_ddX.data(); + transformed_dy_channel = transformed_dO_channel.data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in, + args3.odesc.desc(), + transformed_dy_channel + i * group_offset_out, + args3.cdesc.desc(), filter_algo, workspace_ptr, + workspace_size, &beta, args3.wdesc.desc(), + dw + i * group_offset_filter)); + }, + workspace_size); + } + } + + if (dX && ddW) { + ddw = ddW->data(); + for (int i = 0; i < groups; i++) { + wkspace_handle.RunFunc( + [&](void* workspace_ptr) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, args4.wdesc.desc(), + ddw + i * group_offset_filter, args4.odesc.desc(), + transformed_dy_channel + i * group_offset_out, + args4.cdesc.desc(), data_algo, workspace_ptr, workspace_size, + &beta, args4.idesc.desc(), + transformed_dx + i * group_offset_in)); + }, + workspace_size); + } + + // reverse padded input + std::vector starts(X->dims().size(), 0); + std::vector axes(X->dims().size(), 0); + + for (size_t i = 0; i < X->dims().size(); ++i) { + starts[i] = input_pad[2 * i]; + axes[i] = i; + } + if (X->dims().size() == 4) { + Slice_2( + ctx, &transformed_dX, &transformed_dX_channel, starts, axes); + } else { + Slice_2( + ctx, &transformed_dX, &transformed_dX_channel, starts, axes); + } + if (channel_last) { + TransToChannelLast( + ctx, &transformed_dX_channel, dX); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace plat = paddle::platform; +REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvOpKernel, + paddle::operators::CUDNNConvOpKernel, + paddle::operators::CUDNNConvOpKernel); +REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel); +REGISTER_OP_KERNEL( + conv2d_grad_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel); + +REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvOpKernel, + paddle::operators::CUDNNConvOpKernel, + paddle::operators::CUDNNConvOpKernel); +REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvGradOpKernel, + paddle::operators::CUDNNConvGradOpKernel); +REGISTER_OP_KERNEL( + conv3d_grad_grad, CUDNN, plat::CUDAPlace, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel, + paddle::operators::CUDNNConvDoubleGradOpKernel); diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc deleted file mode 100644 index f82d9f6d2b9..00000000000 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ /dev/null @@ -1,523 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/conv_cudnn_helper.h" -#include "paddle/fluid/operators/conv_cudnn_op_cache.h" -#include "paddle/fluid/operators/conv_op.h" -#include "paddle/fluid/platform/cudnn_helper.h" -#include "paddle/fluid/platform/cudnn_workspace_helper.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/profiler.h" - -DECLARE_bool(cudnn_deterministic); -DECLARE_uint64(conv_workspace_size_limit); -DECLARE_bool(cudnn_exhaustive_search); - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; -using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; -using DataLayout = platform::DataLayout; -template -using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; -using framework::AlgorithmsCache; - -static inline void GetNCDHW(const framework::DDim& dims, - const DataLayout& layout, int* N, int* C, int* D, - int* H, int* W) { - *N = dims[0]; - *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; - int i = layout == DataLayout::kNCHW ? 0 : 1; - if (dims.size() == 5) { - *D = dims[2 - i]; - *H = dims[3 - i]; - *W = dims[4 - i]; - } else { - *D = 1; - *H = dims[2 - i]; - *W = dims[3 - i]; - } -} - -template -class CUDNNConvOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto* input = ctx.Input("Input"); - auto* filter = ctx.Input("Filter"); - auto* output = ctx.Output("Output"); - - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - std::vector dilations = ctx.Attr>("dilations"); - int groups = ctx.Attr("groups"); - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); - - if (exhaustive_search && FLAGS_cudnn_deterministic) { - PADDLE_THROW( - "Cann't set exhaustive_search True and " - "FLAGS_cudnn_deterministic True at same time."); - } - - const T* input_data = input->data(); - const T* filter_data = filter->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); - // ------------------- cudnn descriptors --------------------- - ConvArgs args{input, filter, output, strides, paddings, dilations}; - auto handle = dev_ctx.cudnn_handle(); - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto dtype = platform::CudnnDataType::type; - DataLayout layout = DataLayout::kNCHW; - if (input->dims().size() == 5) { - layout = DataLayout::kNCDHW; - } - auto layout_format = GetCudnnTensorFormat(layout); - - args.handle = handle; - args.cdesc.set(dtype, paddings, strides, dilations); -#if CUDNN_VERSION_MIN(7, 0, 1) - // cudnn 7 can support groups, no need to do it manually - // FIXME(typhoonzero): find a better way to disable groups - // rather than setting it to 1. - CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount( - args.cdesc.desc(), groups)); - groups = 1; -#endif - args.idesc.set(*input, groups); - args.wdesc.set(*filter, layout_format, groups); - args.odesc.set(*output, groups); - int i_n, i_c, i_d, i_h, i_w; - GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); - int o_n, o_c, o_d, o_h, o_w; - GetNCDHW(output->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); - - int group_offset_in = i_c / groups * i_h * i_w * i_d; - int group_offset_out = o_c / groups * o_h * o_w * o_d; - int group_offset_filter = filter->numel() / groups; - // ------------------- cudnn conv workspace --------------------- - size_t workspace_size = 0; // final workspace to allocate. - // ------------------- cudnn conv algorithm --------------------- - cudnnConvolutionFwdAlgo_t algo{}; - - using search = SearchAlgorithm; - algo = search::Find(args, exhaustive_search, false, 0, ctx); - workspace_size = search::GetWorkspaceSize(args, algo); - - // ------------------- cudnn conv forward --------------------- - ScalingParamType alpha = 1.0f, beta = 0.0f; - for (int i = 0; i < groups; i++) { - workspace_handle.RunFunc( - [&](void* workspace_ptr) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, args.idesc.desc(), - input_data + i * group_offset_in, args.wdesc.desc(), - filter_data + i * group_offset_filter, args.cdesc.desc(), algo, - workspace_ptr, workspace_size, &beta, args.odesc.desc(), - output_data + i * group_offset_out)); - }, - workspace_size); - } - } -}; - -template -class CUDNNConvGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto input = ctx.Input("Input"); - auto filter = ctx.Input("Filter"); - auto output_grad = ctx.Input(framework::GradVarName("Output")); - auto input_grad = ctx.Output(framework::GradVarName("Input")); - auto filter_grad = ctx.Output(framework::GradVarName("Filter")); - - const T* input_data = input->data(); - const T* output_grad_data = output_grad->data(); - const T* filter_data = filter->data(); - - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - std::vector dilations = ctx.Attr>("dilations"); - int groups = ctx.Attr("groups"); - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); - bool deterministic = FLAGS_cudnn_deterministic; - if (exhaustive_search && deterministic) { - PADDLE_THROW( - "Can't set exhaustive_search True and " - "FLAGS_cudnn_deterministic True at same time."); - } - - T* filter_grad_data = nullptr; - T* input_grad_data = nullptr; - ConvArgs args1{input_grad, filter, output_grad, - strides, paddings, dilations}; - ConvArgs args2{input, filter_grad, output_grad, - strides, paddings, dilations}; - // conv_cudnn_helper.h - auto handle = dev_ctx.cudnn_handle(); - auto dtype = platform::CudnnDataType::type; - DataLayout layout = DataLayout::kNCHW; - if (input->dims().size() == 5) { - layout = DataLayout::kNCDHW; - } - auto layout_tensor = GetCudnnTensorFormat(layout); - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - - int i_n, i_c, i_d, i_h, i_w; - GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); - int o_n, o_c, o_d, o_h, o_w; - GetNCDHW(output_grad->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, - &o_w); - - int group_offset_in = i_c / groups * i_h * i_w * i_d; - int group_offset_out = o_c / groups * o_h * o_w * o_d; - int group_offset_filter = filter->numel() / groups; - // ------------------- cudnn backward algorithm --------------------- - cudnnConvolutionBwdDataAlgo_t data_algo = - static_cast(0); - cudnnConvolutionBwdFilterAlgo_t filter_algo = - static_cast(0); - size_t workspace_size = 0; - int iwo_groups, c_groups; - -#if CUDNN_VERSION_MIN(7, 0, 1) - iwo_groups = 1; - c_groups = groups; - groups = 1; -#endif - - if (input_grad) { - // ------------------- cudnn descriptors --------------------- - input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - args1.handle = handle; - args1.idesc.set(*input_grad, iwo_groups); - args1.wdesc.set(*filter, layout_tensor, iwo_groups); - args1.odesc.set(*output_grad, iwo_groups); - args1.cdesc.set(dtype, paddings, strides, dilations, c_groups); - - using search1 = SearchAlgorithm; - data_algo = - search1::Find(args1, exhaustive_search, deterministic, 0, ctx); - workspace_size = - std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); - } - - if (filter_grad) { - // ------------------- cudnn descriptors --------------------- - filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); - args2.handle = handle; - args2.idesc.set(*input, iwo_groups); - args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups); - args2.odesc.set(*output_grad, iwo_groups); - args2.cdesc.set(dtype, paddings, strides, dilations, c_groups); - - using search2 = SearchAlgorithm; - filter_algo = - search2::Find(args2, exhaustive_search, deterministic, 1, ctx); - workspace_size = std::max(workspace_size, - search2::GetWorkspaceSize(args2, filter_algo)); - } - - // ------------------- cudnn conv backward data --------------------- - ScalingParamType alpha = 1.0f, beta = 0.0f; - if (input_grad) { - // Because beta is zero, it is unnecessary to reset input_grad. - for (int i = 0; i < groups; i++) { - workspace_handle.RunFunc( - [&](void* cudnn_workspace_ptr) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, args1.wdesc.desc(), - filter_data + i * group_offset_filter, args1.odesc.desc(), - output_grad_data + i * group_offset_out, args1.cdesc.desc(), - data_algo, cudnn_workspace_ptr, workspace_size, &beta, - args1.idesc.desc(), input_grad_data + i * group_offset_in)); - }, - workspace_size); - } - } - // ------------------- cudnn conv backward filter --------------------- - if (filter_grad) { - // Because beta is zero, it is unnecessary to reset filter_grad. - for (int i = 0; i < groups; i++) { - workspace_handle.RunFunc( - [&](void* cudnn_workspace_ptr) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, args2.idesc.desc(), - input_data + i * group_offset_in, args2.odesc.desc(), - output_grad_data + i * group_offset_out, args2.cdesc.desc(), - filter_algo, cudnn_workspace_ptr, workspace_size, &beta, - args2.wdesc.desc(), - filter_grad_data + i * group_offset_filter)); - }, - workspace_size); - } - } - } -}; - -/* - * Inputs: I, W, dO, ddI, ddW - * Outputs: ddO, dW, dI - * ddo = conv(ddI, W) + conv(I, ddW) - * dW = conv_bp_filter(ddI, dO) - * dI = conv_bp_data(ddW, dO) - */ -template -class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto X = ctx.Input("Input"); - auto W = ctx.Input("Filter"); - auto dO = ctx.Input("DOutput"); - auto ddX = ctx.Input("DDInput"); - auto ddW = ctx.Input("DDFilter"); - - auto ddO = ctx.Output("DDOutput"); - auto dW = ctx.Output("DFilter"); - auto dX = ctx.Output("DInput"); - - const T* x = X->data(); - const T* dy = dO->data(); - const T* w = W->data(); - - const T* ddx = nullptr; - const T* ddw = nullptr; - T *dw, *dx, *ddy; - dw = dx = ddy = nullptr; - - const std::vector& strides = ctx.Attr>("strides"); - const std::vector& paddings = ctx.Attr>("paddings"); - const std::vector& dilations = ctx.Attr>("dilations"); - int groups = ctx.Attr("groups"); - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); - bool deterministic = FLAGS_cudnn_deterministic; - if (exhaustive_search && deterministic) { - PADDLE_THROW( - "Can't set exhaustive_search True and " - "FLAGS_cudnn_deterministic True at same time."); - } - - int iwo_group = groups; - int c_group = 1; -#if CUDNN_VERSION_MIN(7, 0, 1) - iwo_group = 1; - c_group = groups; -#endif - auto dtype = platform::CudnnDataType::type; - - auto handle = dev_ctx.cudnn_handle(); - - ConvArgs args1{ddX, W, ddO, strides, paddings, dilations}; - ConvArgs args2{X, ddW, ddO, strides, paddings, dilations}; - ConvArgs args3{ddX, dW, dO, strides, paddings, dilations}; - ConvArgs args4{dX, ddW, dO, strides, paddings, dilations}; - - cudnnConvolutionFwdAlgo_t fwd_algo1 = - static_cast(0); - cudnnConvolutionFwdAlgo_t fwd_algo2 = - static_cast(0); - cudnnConvolutionBwdDataAlgo_t data_algo = - static_cast(0); - cudnnConvolutionBwdFilterAlgo_t filter_algo = - static_cast(0); - - auto layout = GetCudnnTensorFormat(DataLayout::kNCHW); - - // ddo = conv(ddI, W) + conv(I, ddW) - size_t workspace_size = 0; - if (ddO) { - ddy = ddO->mutable_data(ctx.GetPlace()); - if (ddX) { - args1.handle = handle; - args1.idesc.set(*ddX, iwo_group); - args1.wdesc.set(*W, layout, iwo_group); - args1.odesc.set(*ddO, iwo_group); - args1.cdesc.set(dtype, paddings, strides, dilations, c_group); - - using search1 = SearchAlgorithm; - fwd_algo1 = search1::Find(args1, exhaustive_search, false, 0, ctx); - workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); - } - - if (ddW) { - ddw = ddW->data(); - args2.handle = handle; - args2.idesc.set(*X, iwo_group); - args2.wdesc.set(*ddW, layout, iwo_group); - args2.odesc.set(*ddO, iwo_group); - args2.cdesc.set(dtype, paddings, strides, dilations, c_group); - - using search2 = SearchAlgorithm; - fwd_algo2 = search2::Find(args2, exhaustive_search, false, 0, ctx); - workspace_size = std::max(workspace_size, - search2::GetWorkspaceSize(args2, fwd_algo2)); - } - } - - if (dW && ddX) { - dw = dW->mutable_data(ctx.GetPlace()); - args3.handle = handle; - args3.idesc.set(*ddX, iwo_group); - args3.wdesc.set(*dW, layout, iwo_group); - args3.odesc.set(*dO, iwo_group); - args3.cdesc.set(dtype, paddings, strides, dilations, c_group); - - using search3 = SearchAlgorithm; - filter_algo = - search3::Find(args3, exhaustive_search, deterministic, 1, ctx); - workspace_size = std::max(workspace_size, - search3::GetWorkspaceSize(args3, filter_algo)); - } - - if (ddW && dX) { - dx = dX->mutable_data(ctx.GetPlace()); - args4.handle = handle; - args4.idesc.set(*dX, iwo_group); - args4.wdesc.set(*ddW, layout, iwo_group); - args4.odesc.set(*dO, iwo_group); - args4.cdesc.set(dtype, paddings, strides, dilations, c_group); - - using search4 = SearchAlgorithm; - data_algo = - search4::Find(args4, exhaustive_search, deterministic, 2, ctx); - workspace_size = - std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); - } - - int i_n, i_c, i_d, i_h, i_w; - GetNCDHW(X->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w); - int o_n, o_c, o_d, o_h, o_w; - GetNCDHW(dO->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w); - - int group_offset_in = i_c / groups * i_h * i_w * i_d; - int group_offset_out = o_c / groups * o_h * o_w * o_d; - int group_offset_filter = W->numel() / groups; - - ScalingParamType alpha = 1.0f, beta = 0.0f; - auto wkspace_handle = dev_ctx.cudnn_workspace_handle(); - - if (ddO) { - if (ddX) { - ddx = ddX->data(); - for (int i = 0; i < groups; i++) { - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, args1.idesc.desc(), - ddx + i * group_offset_in, args1.wdesc.desc(), - w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1, - workspace_ptr, workspace_size, &beta, args1.odesc.desc(), - ddy + i * group_offset_out)); - }, - workspace_size); - } - } - if (ddW) { - for (int i = 0; i < groups; i++) { - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, args2.idesc.desc(), x + i * group_offset_in, - args2.wdesc.desc(), ddw + i * group_offset_filter, - args2.cdesc.desc(), fwd_algo2, workspace_ptr, - workspace_size, &alpha, args2.odesc.desc(), - ddy + i * group_offset_out)); - }, - workspace_size); - } - } - } - - if (dW && ddX) { - ddx = ddX->data(); - for (int i = 0; i < groups; i++) { - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in, - args3.odesc.desc(), dy + i * group_offset_out, - args3.cdesc.desc(), filter_algo, workspace_ptr, - workspace_size, &beta, args3.wdesc.desc(), - dw + i * group_offset_filter)); - }, - workspace_size); - } - } - - if (dX && ddW) { - ddw = ddW->data(); - for (int i = 0; i < groups; i++) { - wkspace_handle.RunFunc( - [&](void* workspace_ptr) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, args4.wdesc.desc(), - ddw + i * group_offset_filter, args4.odesc.desc(), - dy + i * group_offset_out, args4.cdesc.desc(), data_algo, - workspace_ptr, workspace_size, &beta, args4.idesc.desc(), - dx + i * group_offset_in)); - }, - workspace_size); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace plat = paddle::platform; -REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNConvOpKernel, - paddle::operators::CUDNNConvOpKernel, - paddle::operators::CUDNNConvOpKernel); -REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); -REGISTER_OP_KERNEL( - conv2d_grad_grad, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNConvDoubleGradOpKernel, - paddle::operators::CUDNNConvDoubleGradOpKernel, - paddle::operators::CUDNNConvDoubleGradOpKernel); - -REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNConvOpKernel, - paddle::operators::CUDNNConvOpKernel, - paddle::operators::CUDNNConvOpKernel); -REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel, - paddle::operators::CUDNNConvGradOpKernel); -REGISTER_OP_KERNEL( - conv3d_grad_grad, CUDNN, plat::CUDAPlace, - paddle::operators::CUDNNConvDoubleGradOpKernel, - paddle::operators::CUDNNConvDoubleGradOpKernel, - paddle::operators::CUDNNConvDoubleGradOpKernel); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 5528f758732..1230b848fdd 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -31,53 +31,76 @@ namespace paddle { namespace operators { void ConvOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of ConvOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, + "Input(Input) of ConvOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true, + "Input(Filter) of ConvOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true, + "Output(Output) of ConvOp should not be null."); auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); + std::string padding_algorithm = + ctx->Attrs().Get("padding_algorithm"); int groups = ctx->Attrs().Get("groups"); std::vector dilations = ctx->Attrs().Get>("dilations"); + const std::string data_format = ctx->Attrs().Get("data_format"); + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); - PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, - "Conv intput should be 4-D or 5-D tensor, get %u", - in_dims.size()); + PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true, + "Conv intput should be 4-D or 5-D tensor, get %u", + in_dims.size()); PADDLE_ENFORCE_EQ( in_dims.size(), filter_dims.size(), "Conv input dimension and filter dimension should be the same."); - PADDLE_ENFORCE( - in_dims.size() - strides.size() == 2U, - "Conv input dimension and strides dimension should be consistent."); PADDLE_ENFORCE_EQ( - paddings.size(), strides.size(), - "Conv paddings dimension and Conv strides dimension should be the same."); + in_dims.size() - strides.size() == 2U, true, + "Conv input dimension and strides dimension should be consistent."); + + const auto input_channels = + channel_last ? in_dims[in_dims.size() - 1] : in_dims[1]; - PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, + PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, "The number of input channels should be equal to filter " "channels * groups."); PADDLE_ENFORCE_EQ( filter_dims[0] % groups, 0, "The number of output channels should be divided by groups."); - std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < strides.size(); ++i) { + framework::DDim in_data_dims; + if (channel_last) { + in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); + } else { + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + } + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + std::vector output_shape({in_dims[0]}); + if (!channel_last) { + output_shape.push_back(filter_dims[0]); + } + for (size_t i = 0; i < in_data_dims.size(); ++i) { if ((!ctx->IsRuntime()) && - (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) { + (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) { output_shape.push_back(-1); } else { - output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], - dilations[i], paddings[i], - strides[i])); + output_shape.push_back(ConvOutputSize(in_data_dims[i], filter_dims[i + 2], + dilations[i], paddings[2 * i], + paddings[2 * i + 1], strides[i])); } } + if (channel_last) { + output_shape.push_back(filter_dims[0]); + } + ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Output"); } @@ -89,7 +112,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( framework::LibraryType library{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready auto input_data_type = ctx.Input("Input")->type(); - std::string data_format = ctx.Attr("data_format"); + std::string data_format = + "AnyLayout"; // todo enable data layout when it's ready framework::DataLayout layout = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA @@ -142,12 +166,12 @@ void Conv2DOpMaker::Make() { "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddInput( - "Input", - "(Tensor) The input tensor of convolution operator. " - "The format of input tensor is NCHW, where N is batch size, C is the " - "number of channels, H is the height of the feature, " - "and W is the width of the feature."); + AddInput("Input", + "(Tensor) The input tensor of convolution operator. " + "The format of input tensor is NCHW or NHWC, where N is batch size, " + "C is the " + "number of channels, H is the height of the feature, " + "and W is the width of the feature."); AddInput("Filter", "(Tensor) The filter tensor of convolution operator. " "The format of the filter tensor is MCHW, where M is the number of " @@ -167,7 +191,7 @@ void Conv2DOpMaker::Make() { .AsDispensable(); AddOutput("Output", "(Tensor) The output tensor of convolution operator. " - "The format of output tensor is also NCHW."); + "It has same data fromat and data type as the Input."); AddAttr>("strides", "(vector default:{1, 1}), the " "strides(h_stride, w_stride) of " @@ -175,9 +199,16 @@ void Conv2DOpMaker::Make() { .SetDefault({1, 1}); AddAttr>("paddings", "(vector default:{0, 0}), the " - "paddings(h_pad, w_pad) of " + "paddings(pad_height_top, pad_height_bottom, " + "pad_width_left, pad_wifth_right) of " "convolution operator.") .SetDefault({0, 0}); + AddAttr( + "padding_algorithm", + "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\"," + "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. " + "Set to \"SAME\" or \"VALID\" for algorithm of padding. ") + .SetDefault("EXPLICIT"); AddAttr( "groups", "(int default:1), the groups number of the convolution operator. " @@ -254,7 +285,7 @@ void Conv2DOpMaker::Make() { "An optional string from: \"NHWC\", \"NCHW\". " "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") - .SetDefault("AnyLayout"); + .SetDefault("NCHW"); // TODO(dzhwinter): need to registered layout transform function AddAttr("workspace_size_MB", "Only used in cudnn kernel. Need set use_cudnn to true." @@ -269,13 +300,14 @@ void Conv2DOpMaker::Make() { "convolution, whether enable exhaustive search " "for cuDNN convolution or not, default is False.") .SetDefault(false); + AddComment(R"DOC( Convolution Operator. The convolution operation calculates the output based on the input, filter and strides, paddings, dilations, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. -Input(Input) and Output(Output) are in NCHW format. Where N is batch +Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch size, C is the number of channels, H is the height of the feature, and W is the width of the feature. Filters(Input) is MCHW format. Where M is the number of output image channels, C is @@ -293,8 +325,8 @@ Example: Output shape: $(N, C_{out}, H_{out}, W_{out})$ Where $$ - H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\ - W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1 + H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\ + W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1 $$ )DOC"); Apply(); @@ -308,7 +340,8 @@ void Conv3DOpMaker::Make() { AddInput( "Input", "(Tensor) The input tensor of convolution operator. " - "The format of input tensor is NCDHW. Where N is batch size, C is the " + "The format of input tensor is NCDHW or NDHWC. Where N is batch size, C " + "is the " "number of channels, D is the depth of the feature, H is the height of " "the feature, " "and W is the width of the feature."); @@ -327,17 +360,25 @@ void Conv3DOpMaker::Make() { .AsDispensable(); AddOutput("Output", "(Tensor) The output tensor of convolution operator." - "The format of output tensor is also NCDHW."); + "It has same data fromat and data type as the Input."); AddAttr>("strides", "(vector, default:{1, 1, 1}), the " "strides(d_stride, h_stride, w_stride) of " "convolution operator.") .SetDefault({1, 1, 1}); - AddAttr>("paddings", - "(vector, default:{0, 0, 0}), the " - "paddings(d_pad, h_pad, w_pad) of convolution " - "operator.") + AddAttr>( + "paddings", + "(vector, default:{0, 0, 0}), the " + "paddings(pad_depth_front, pad_depth_back, pad_height_top, " + "pad_height_bottom, pad_width_left, pad_width_right) of convolution " + "operator.") .SetDefault({0, 0, 0}); + AddAttr( + "padding_algorithm", + "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\"," + "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. " + "Set to \"SAME\" or \"VALID\" for algorithm of padding. ") + .SetDefault("EXPLICIT"); AddAttr( "groups", "(int default:1), the groups number of the convolution operator. " @@ -375,11 +416,11 @@ void Conv3DOpMaker::Make() { .SetDefault(false); AddAttr( "data_format", - "(string, default NCHW) Only used in " - "An optional string from: \"NHWC\", \"NCHW\". " - "Defaults to \"NHWC\". Specify the data format of the output data, " + "(string, default NCDHW) Only used in " + "An optional string from: \"NDHWC\", \"NCDHW\". " + "Defaults to \"NDHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") - .SetDefault("AnyLayout"); + .SetDefault("NCDHW"); AddAttr("force_fp32_output", "(bool, default false) Only used in mkldnn INT8 kernel") .SetDefault(false); @@ -402,7 +443,7 @@ Convolution3D Operator. The convolution operation calculates the output based on the input, filter and strides, paddings, dilations, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. -Input(Input) and output(Output) are in NCDHW format, where N is batch +Input(Input) and output(Output) are in NCDHW or NDHWC format, where N is batch size, C is the number of channels,D is the depth of the feature, H is the height of the feature, and W is the width of the feature. Filters(Input) is MCDHW format, where M is the number of output image channels, @@ -420,9 +461,9 @@ Example: Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$ Where $$ - D_{out}= \frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\ - H_{out}= \frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\ - W_{out}= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1 + D_{out}= \frac{(D_{in} + pad_depth_front + pad_depth_back - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\ + H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\ + W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1 $$ )DOC"); Apply(); @@ -445,7 +486,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library_{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - std::string data_format = ctx.Attr("data_format"); + std::string data_format = "AnyLayout"; framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA @@ -623,7 +664,7 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( int customized_type_value = framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); + std::string data_format = "AnyLayout"; framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index a6882897ad7..b6820a7bef4 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -39,8 +40,8 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, int padding, int stride) { const int dkernel = dilation * (filter_size - 1) + 1; int output_size = (input_size + 2 * padding - dkernel) / stride + 1; - PADDLE_ENFORCE( - output_size > 0, + PADDLE_ENFORCE_GT( + output_size, 0, "Due to the settings of padding(%d), filter_size(%d), dilation(%d) and " "stride(%d), the output size is less than 0, please check " "again. Input_size:%d", @@ -48,6 +49,62 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, return output_size; } + +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding_1, int padding_2, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1; + PADDLE_ENFORCE_GT(output_size, 0, + "Due to the settings of padding(%d, %d), filter_size(%d), " + "dilation(%d) and " + "stride(%d), the output size is less than 0, please check " + "again. Input_size:%d", + padding_1, padding_2, filter_size, dilation, stride, + input_size); + + return output_size; +} +inline void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilation, + const std::string padding_algorithm, + const framework::DDim data_dims, + const std::vector& strides, + const std::vector& ksize) { + // set padding size == data_dims.size() * 2 + auto data_shape = framework::vectorize(data_dims); + if (paddings->size() == data_dims.size()) { + for (size_t i = 0; i < data_dims.size(); ++i) { + int copy_pad = *(paddings->begin() + 2 * i); + paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); + } + } else { + PADDLE_ENFORCE_EQ( + data_dims.size() * 2, paddings->size(), + "Paddings size should be the same or twice as the input data size."); + } + + // when padding_desc is "VALID" or "SAME" + if (padding_algorithm == "SAME") { + for (size_t i = 0; i < data_dims.size(); ++i) { + int out_size = (data_dims[i] + strides[i] - 1) / strides[0]; + int pad_sum = + std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0); + int pad_0 = pad_sum / 2; + int pad_1 = pad_sum - pad_0; + *(paddings->begin() + i * 2) = pad_0; + *(paddings->begin() + i * 2 + 1) = pad_1; + + // dilation + *(dilation->begin() + i) = 1; + } + + } else if (padding_algorithm == "VALID") { + for (auto it = paddings->begin(); it != paddings->end(); it++) { + *it = 0; + } + } +} + inline bool IsExpand(const std::vector& filter_dim, const std::vector& strides, const std::vector& paddings, @@ -59,9 +116,80 @@ inline bool IsExpand(const std::vector& filter_dim, padding_0 = padding_0 && (paddings[j] == 0); dilation_1 = dilation_1 && (dilations[j] == 1); } + if (paddings.size() != strides.size()) { + for (size_t j = 0; j < paddings.size(); ++j) { + padding_0 = padding_0 && (paddings[j] == 0); + } + } return !(filter_1 && strides_1 && padding_0 && dilation_1); } +template +inline void ResizeToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[4]; + in_dims_vec[2] = input->dims()[1]; + in_dims_vec[3] = input->dims()[2]; + in_dims_vec[4] = input->dims()[3]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + + } else if (dim == 2) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[3]; + in_dims_vec[2] = input->dims()[1]; + in_dims_vec[3] = input->dims()[2]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } +} + +template +inline void TransToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 4, 1, 2, 3}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 3, 1, 2}; + math::Transpose trans4; + trans4(dev_ctx, *input, transformed_input, axis); + } +} + +template +inline void TransToChannelLast(const framework::ExecutionContext& context, + const Tensor* input, Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 3, 4, 1}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 3, 1}; + math::Transpose trans4; + trans4(dev_ctx, *input, transformed_input, axis); + } +} // Define Op classes in .h file so that other conv // operator implementations can reuse the code. class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { @@ -131,39 +259,82 @@ class GemmConvKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); - int groups = context.Attr("groups"); - std::vector strides = context.Attr>("strides"); + const int groups = context.Attr("groups"); + const std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + const std::string data_format = context.Attr("data_format"); + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + Tensor transformed_input(input->type()); + Tensor transformed_output(output->type()); + + if (channel_last) { + ResizeToChannelFirst(context, input, + &transformed_input); + TransToChannelFirst(context, input, &transformed_input); + + ResizeToChannelFirst(context, output, + &transformed_output); + + } else { + transformed_input = *input; + transformed_output = *output; + } + + // update padding and dilation + auto trans_in_dims = transformed_input.dims(); + auto filter_dims = filter.dims(); + + framework::DDim in_data_dims = + framework::slice_ddim(trans_in_dims, 2, trans_in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); auto& dev_ctx = context.template device_context(); - const int batch_size = static_cast(input->dims()[0]); + const int batch_size = static_cast(transformed_input.dims()[0]); - // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w} + // filter_shape_vec: + // {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w} std::vector filter_shape_vec(framework::vectorize(filter.dims())); - // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w} - std::vector output_shape_vec(framework::vectorize(output->dims())); + + // output_shape_vec: + // {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w} + std::vector output_shape_vec( + framework::vectorize(transformed_output.dims())); // use col_shape in the im2col calculation - // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d, - // o_h, o_w} + // col_shape_vec: + // {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, + // o_d,o_h, o_w} size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = input->dims()[1] / groups; + col_shape_vec[0] = trans_in_dims[1] / groups; for (size_t j = 0; j < data_dim; ++j) { col_shape_vec[j + 1] = filter_shape_vec[j + 2]; col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation - // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * - // o_h * o_w) + // size: + // (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * o_h * + // o_w) + framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, data_dim + 1); + framework::flatten_to_2d(col_shape, data_dim); bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape @@ -175,28 +346,31 @@ class GemmConvKernel : public framework::OpKernel { col_matrix.Resize(col_matrix_shape); } - framework::DDim input_shape = - framework::slice_ddim(input->dims(), 1, input->dims().size()); + framework::DDim in_matrix_shape = framework::slice_ddim( + transformed_input.dims(), 1, transformed_input.dims().size()); framework::DDim filter_matrix_shape = {filter.dims()[0], filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = { - output->dims()[1], - output->numel() / (output->dims()[0] * output->dims()[1])}; + transformed_output.dims()[1], + transformed_output.numel() / + (transformed_output.dims()[0] * transformed_output.dims()[1])}; // convolution operator: im2col(or vol2col) + gemm - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output->dims()[1]) / groups; + int in_step = static_cast(transformed_input.dims()[1]) / groups; + int out_step = static_cast(transformed_output.dims()[1]) / groups; math::Vol2ColFunctor vol2col; math::Im2ColFunctor im2col; auto blas = math::GetBlas(dev_ctx); for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = + transformed_input.Slice(i, i + 1).Resize(in_matrix_shape); + Tensor out_batch = + transformed_output.Slice(i, i + 1).Resize(output_matrix_shape); for (int g = 0; g < groups; g++) { Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); @@ -206,13 +380,12 @@ class GemmConvKernel : public framework::OpKernel { col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } else if (data_dim == 2U) { - // im2col im2col(dev_ctx, in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &col); + } else if (data_dim == 3U) { - // vol2col vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col); } @@ -223,6 +396,10 @@ class GemmConvKernel : public framework::OpKernel { T(0.0)); } } + if (channel_last) { + TransToChannelLast(context, &transformed_output, + output); + } } }; @@ -245,11 +422,44 @@ class GemmConvGradKernel : public framework::OpKernel { if (!input_grad && !filter_grad) return; int groups = context.Attr("groups"); - std::vector strides = context.Attr>("strides"); + const std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + const std::string data_format = context.Attr("data_format"); + + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + Tensor transformed_input(input->type()); + Tensor transformed_output_grad(output_grad->type()); - const int batch_size = static_cast(input->dims()[0]); + if (channel_last) { + ResizeToChannelFirst(context, input, + &transformed_input); + TransToChannelFirst(context, input, &transformed_input); + + ResizeToChannelFirst(context, output_grad, + &transformed_output_grad); + TransToChannelFirst(context, output_grad, + &transformed_output_grad); + } else { + transformed_input = *input; + transformed_output_grad = *output_grad; + } + + // update padding and dilation + auto in_dims = transformed_input.dims(); + auto filter_dims = filter.dims(); + framework::DDim in_data_dims = + framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + const int batch_size = static_cast(transformed_input.dims()[0]); auto& dev_ctx = context.template device_context(); @@ -257,14 +467,14 @@ class GemmConvGradKernel : public framework::OpKernel { std::vector filter_shape_vec(framework::vectorize(filter.dims())); // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w} std::vector output_shape_vec( - framework::vectorize(output_grad->dims())); + framework::vectorize(transformed_output_grad.dims())); // use col_shape in the im2col calculation // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d, // o_h, o_w} size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = input->dims()[1] / groups; + col_shape_vec[0] = transformed_input.dims()[1] / groups; for (size_t j = 0; j < data_dim; ++j) { col_shape_vec[j + 1] = filter_shape_vec[j + 2]; col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; @@ -278,24 +488,25 @@ class GemmConvGradKernel : public framework::OpKernel { framework::DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); - framework::DDim input_shape = - framework::slice_ddim(input->dims(), 1, input->dims().size()); + framework::DDim input_shape = framework::slice_ddim( + transformed_input.dims(), 1, transformed_input.dims().size()); framework::DDim filter_matrix_shape = {filter.dims()[0], filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = { - output_grad->dims()[1], - output_grad->numel() / - (output_grad->dims()[0] * output_grad->dims()[1])}; + transformed_output_grad.dims()[1], + transformed_output_grad.numel() / (transformed_output_grad.dims()[0] * + transformed_output_grad.dims()[1])}; // convolution backward input operator: gemm + col2im(or col2vol) // convolution backward weight operator: im2col(or vol2col) + gemm - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output_grad->dims()[1]) / groups; + int in_step = static_cast(transformed_input.dims()[1]) / groups; + int out_step = static_cast(transformed_output_grad.dims()[1]) / groups; bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape @@ -312,19 +523,27 @@ class GemmConvGradKernel : public framework::OpKernel { if (input_grad) { input_grad->mutable_data(context.GetPlace()); + Tensor transformed_input_grad(input_grad->type()); + if (channel_last) { + ResizeToChannelFirst(context, input_grad, + &transformed_input_grad); + } else { + transformed_input_grad = *input_grad; + } // if is_expand is false, the operation of set_zero is unnecessary, // because math::matmul will reset input_grad. if (is_expand) { - set_zero(dev_ctx, input_grad, static_cast(0)); + set_zero(dev_ctx, &transformed_input_grad, static_cast(0)); } math::Col2VolFunctor col2vol; math::Col2ImFunctor col2im; for (int i = 0; i < batch_size; i++) { Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + transformed_input_grad.Slice(i, i + 1).Resize(input_shape); for (int g = 0; g < groups; g++) { // gemm Tensor out_grad_slice = @@ -343,14 +562,18 @@ class GemmConvGradKernel : public framework::OpKernel { if (is_expand && data_dim == 2U) { col2im(dev_ctx, col, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &in_grad_slice); } else if (is_expand && data_dim == 3U) { col2vol(dev_ctx, col, dilations, strides, paddings, &in_grad_slice); } } } + if (channel_last) { + TransToChannelLast(context, &transformed_input_grad, + input_grad); + } } if (filter_grad) { @@ -362,8 +585,8 @@ class GemmConvGradKernel : public framework::OpKernel { math::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; i++) { Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = transformed_input.Slice(i, i + 1).Resize(input_shape); for (int g = 0; g < groups; g++) { // im2col Tensor out_grad_slice = @@ -376,9 +599,10 @@ class GemmConvGradKernel : public framework::OpKernel { col_matrix.Resize(col_matrix_shape); } else if (data_dim == 2U) { im2col(dev_ctx, in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &col); + } else if (data_dim == 3U) { vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col); } @@ -412,21 +636,60 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { Tensor W = detail::Ref(ctx.Input("Filter"), "Cannot find input Filter(%s) in scope)", ctx.Inputs("Filter")[0]); - if (!ddY && !dW && !dX) return; - int groups = ctx.Attr("groups"); - std::vector strides = ctx.Attr>("strides"); + + const int groups = ctx.Attr("groups"); + const std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); + const std::string padding_algorithm = + ctx.Attr("padding_algorithm"); + const std::string data_format = ctx.Attr("data_format"); + + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + // transform Tensor + Tensor transformed_X(X->type()); + Tensor transformed_dY(dY->type()); + Tensor transformed_ddX(ddX->type()); + + if (channel_last) { + ResizeToChannelFirst(ctx, X, &transformed_X); + TransToChannelFirst(ctx, X, &transformed_X); + + ResizeToChannelFirst(ctx, dY, &transformed_dY); + TransToChannelFirst(ctx, dY, &transformed_dY); + + ResizeToChannelFirst(ctx, ddX, &transformed_ddX); + TransToChannelFirst(ctx, ddX, &transformed_ddX); - const int batch_size = static_cast(X->dims()[0]); + } else { + transformed_X = *X; + transformed_dY = *dY; + transformed_ddX = *ddX; + } + + // update padding and dilation + auto in_dims = transformed_X.dims(); + auto filter_dims = W.dims(); + + framework::DDim in_data_dims = + framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + const int batch_size = static_cast(transformed_X.dims()[0]); std::vector filter_shape_vec(framework::vectorize(W.dims())); - std::vector output_shape_vec(framework::vectorize(dY->dims())); + std::vector output_shape_vec( + framework::vectorize(transformed_dY.dims())); size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); // col_shape [in_channel/group, kh, kw, oh, ow] - col_shape_vec[0] = X->dims()[1] / groups; + col_shape_vec[0] = transformed_X.dims()[1] / groups; for (size_t j = 0; j < data_dim; ++j) { col_shape_vec[j + 1] = filter_shape_vec[j + 2]; col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2]; @@ -436,17 +699,19 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { framework::DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); // input_shape [Cin, H, W] - framework::DDim input_shape = - framework::slice_ddim(X->dims(), 1, X->dims().size()); + framework::DDim input_shape = framework::slice_ddim( + transformed_X.dims(), 1, transformed_X.dims().size()); // filter_matrix_shape [Cout, Cin * kh * kw] framework::DDim filter_matrix_shape = {W.dims()[0], W.numel() / W.dims()[0]}; W.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = { - dY->dims()[1], dY->numel() / (dY->dims()[0] * dY->dims()[1])}; - int in_step = static_cast(X->dims()[1]) / groups; - int out_step = static_cast(dY->dims()[1]) / groups; + transformed_dY.dims()[1], + transformed_dY.numel() / + (transformed_dY.dims()[0] * transformed_dY.dims()[1])}; + int in_step = static_cast(transformed_X.dims()[1]) / groups; + int out_step = static_cast(transformed_dY.dims()[1]) / groups; bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; @@ -466,19 +731,28 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { if (dX && ddW_in) { Tensor ddW; ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape); - dX->mutable_data(ctx.GetPlace()); + + Tensor transformed_dX(dX->type()); + + if (channel_last) { + ResizeToChannelFirst(ctx, dX, &transformed_dX); + + } else { + transformed_dX = *dX; + } // if is_expand is false, the operation of set_zero is unnecessary // because math::matmul will reset dx if (is_expand) { - set_zero(dev_ctx, dX, static_cast(0)); + set_zero(dev_ctx, &transformed_dX, static_cast(0)); } math::Col2VolFunctor col2vol; math::Col2ImFunctor col2im; for (int i = 0; i < batch_size; i++) { - Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor dx_batch = dX->Slice(i, i + 1).Resize(input_shape); + Tensor dy_batch = + transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape); + Tensor dx_batch = transformed_dX.Slice(i, i + 1).Resize(input_shape); for (int g = 0; g < groups; g++) { // gemm Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step); @@ -493,14 +767,17 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { if (is_expand && data_dim == 2U) { col2im(dev_ctx, col, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &dx_slice); } else if (is_expand && data_dim == 3U) { col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice); } } } + if (channel_last) { + TransToChannelLast(ctx, &transformed_dX, dX); + } } // dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout, @@ -514,8 +791,9 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { math::Im2ColFunctor im2col; math::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { - Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape); + Tensor dy_batch = + transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape); + Tensor ddx_batch = transformed_ddX.Slice(i, i + 1).Resize(input_shape); for (int g = 0; g < groups; ++g) { // im2col Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step); @@ -526,8 +804,8 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { col_matrix.Resize(col_matrix_shape); } else if (data_dim == 2U) { im2col(dev_ctx, ddx_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &col); } else if (data_dim == 3U) { vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col); @@ -545,55 +823,62 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { // ddy convolution double grad: im2col(vol2col) + gemm if (ddY) { ddY->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, ddY, static_cast(0)); + + Tensor transformed_ddY(ddY->type()); + if (channel_last) { + ResizeToChannelFirst(ctx, ddY, &transformed_ddY); + } else { + transformed_ddY = *ddY; + } + + set_zero(dev_ctx, &transformed_ddY, static_cast(0)); math::Im2ColFunctor im2col; math::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { - Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor ddy_batch = + transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape); for (int g = 0; g < groups; ++g) { + // gemm Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step); + if (ddX) { - Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape); + Tensor ddx_batch = + transformed_ddX.Slice(i, i + 1).Resize(input_shape); Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step); if (!is_expand) { col.ShareDataWith(ddx_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } else if (data_dim == 2U) { - // im2col im2col(dev_ctx, ddx_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &col); } else if (data_dim == 3U) { - // vol2col vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col); } - - // gemm - Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step); - blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice, - T(0.0)); } + Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step); + blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice, + T(0.0)); + if (ddW_in) { - Tensor ddW; - ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape); - Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape); + Tensor x_batch = transformed_X.Slice(i, i + 1).Resize(input_shape); Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step); + Tensor ddW; + ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape); if (!is_expand) { col.ShareDataWith(x_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } else if (data_dim == 2U) { - // im2col im2col(dev_ctx, x_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &col); } else if (data_dim == 3U) { - // vol2col vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col); } @@ -604,6 +889,9 @@ class GemmConvDoubleGradKernel : public framework::OpKernel { } } } + if (channel_last) { + TransToChannelLast(ctx, &transformed_ddY, ddY); + } } } }; @@ -617,23 +905,77 @@ class DepthwiseConvKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); - PADDLE_ENFORCE_EQ( - output->dims()[1] % input->dims()[1], 0, - "The output channels must be a multiple of the input channels"); - std::vector strides = context.Attr>("strides"); + const std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); bool fuse_relu = context.Attr("fuse_relu_before_depthwise_conv"); + + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + const std::string data_format = context.Attr("data_format"); + + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + if (channel_last) { + PADDLE_ENFORCE_EQ( + output->dims()[output->dims().size() - 1] % + input->dims()[input->dims().size() - 1], + 0, "The output channels must be a multiple of the input channels"); + } else { + PADDLE_ENFORCE_EQ( + output->dims()[1] % input->dims()[1], 0, + "The output channels must be a multiple of the input channels"); + } + // transform tensor + Tensor transformed_input(input->type()); + Tensor transformed_output(output->type()); + + if (channel_last) { + ResizeToChannelFirst(context, input, + &transformed_input); + TransToChannelFirst(context, input, &transformed_input); + + ResizeToChannelFirst(context, output, + &transformed_output); + + } else { + transformed_input = *input; + transformed_output = *output; + } + + // update padding and dilation + auto in_dims = transformed_input.dims(); + auto filter_dims = filter.dims(); + + framework::DDim in_data_dims; + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + bool is_sys_pad = strides.size() * 2 == paddings.size() ? false : true; + if (!is_sys_pad) { + for (size_t i = 0; i < strides.size(); ++i) { + paddings.erase(paddings.begin() + i + 1); + } + } + auto& dev_ctx = context.template device_context(); if (fuse_relu) { math::DepthwiseConvFunctor depthwiseConv; - depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations, - output); + depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings, + dilations, &transformed_output); } else { math::DepthwiseConvFunctor depthwiseConv; - depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations, - output); + depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings, + dilations, &transformed_output); + } + if (channel_last) { + TransToChannelLast(context, &transformed_output, + output); } } }; @@ -657,24 +999,81 @@ class DepthwiseConvGradKernel : public framework::OpKernel { std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); bool fuse_relu = context.Attr("fuse_relu_before_depthwise_conv"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + const std::string data_format = context.Attr("data_format"); + + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + // transform Tensor + Tensor transformed_input(input->type()); + Tensor transformed_output_grad(output_grad->type()); + if (channel_last) { + ResizeToChannelFirst(context, input, + &transformed_input); + TransToChannelFirst(context, input, &transformed_input); + + ResizeToChannelFirst(context, output_grad, + &transformed_output_grad); + TransToChannelFirst(context, output_grad, + &transformed_output_grad); + + } else { + transformed_input = *input; + transformed_output_grad = *output_grad; + } + + // update padding and dilation + auto in_dims = transformed_input.dims(); + auto filter_dims = filter.dims(); + + framework::DDim in_data_dims; + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + bool is_sys_pad = strides.size() * 2 == paddings.size() ? false : true; + if (!is_sys_pad) { + for (size_t i = 0; i < strides.size(); ++i) { + paddings.erase(paddings.begin() + i + 1); + } + } math::SetConstant set_zero; auto& dev_ctx = context.template device_context(); if (input_grad) { input_grad->mutable_data(context.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); + Tensor transformed_input_grad(input_grad->type()); + if (channel_last) { + ResizeToChannelFirst(context, input_grad, + &transformed_input_grad); + + } else { + transformed_input_grad = *input_grad; + } + + set_zero(dev_ctx, &transformed_input_grad, static_cast(0)); if (fuse_relu) { math::DepthwiseConvInputGradFunctor depthwiseConvInputGrad; - depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides, - paddings, dilations, input_grad); + depthwiseConvInputGrad(dev_ctx, transformed_input, filter, + transformed_output_grad, strides, paddings, + dilations, &transformed_input_grad); } else { math::DepthwiseConvInputGradFunctor depthwiseConvInputGrad; - depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides, - paddings, dilations, input_grad); + depthwiseConvInputGrad(dev_ctx, transformed_input, filter, + transformed_output_grad, strides, paddings, + dilations, &transformed_input_grad); + } + if (channel_last) { + TransToChannelLast(context, &transformed_input_grad, + input_grad); } } @@ -684,13 +1083,15 @@ class DepthwiseConvGradKernel : public framework::OpKernel { if (fuse_relu) { math::DepthwiseConvFilterGradFunctor depthwiseConvFilterGrad; - depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, - paddings, dilations, filter_grad); + depthwiseConvFilterGrad(dev_ctx, transformed_input, + transformed_output_grad, strides, paddings, + dilations, filter_grad); } else { math::DepthwiseConvFilterGradFunctor depthwiseConvFilterGrad; - depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, - paddings, dilations, filter_grad); + depthwiseConvFilterGrad(dev_ctx, transformed_input, + transformed_output_grad, strides, paddings, + dilations, filter_grad); } } } diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc index 1472edbbf47..fe646ea2e77 100644 --- a/paddle/fluid/operators/math/im2col.cc +++ b/paddle/fluid/operators/math/im2col.cc @@ -33,15 +33,18 @@ class Im2ColFunctor& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* col) { - PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col->dims().size() == 5); + PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col->dims().size(), 5, + "The dimension of col should be 5."); if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 && dilation[1] == 1) { - if (padding[0] == 0 && padding[1] == 0) { + if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 && + padding[3] == 0) { im2col_sh1sw1dh1dw1ph0pw0(im, col); return; - } else if (padding[0] == 1 && padding[1] == 1) { + } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 && + padding[3] == 1) { im2col_sh1sw1dh1dw1ph1pw1(im, col); return; } @@ -65,8 +68,9 @@ class Col2ImFunctor& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* im) { - PADDLE_ENFORCE(im->dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col.dims().size(), 5, + "The dimension of col should be 5."); int im_channels = im->dims()[0]; int im_height = im->dims()[1]; int im_width = im->dims()[2]; @@ -136,8 +140,9 @@ class Im2ColFunctor& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* col) { - PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col->dims().size() == 5); + PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col->dims().size(), 5, + "The dimension of col should be 5."); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; @@ -198,8 +203,9 @@ class Col2ImFunctor& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* im) { - PADDLE_ENFORCE(im->dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col.dims().size(), 5, + "The dimension of col should be 5."); int im_channels = im->dims()[0]; int im_height = im->dims()[1]; int im_width = im->dims()[2]; diff --git a/paddle/fluid/operators/math/vol2col.cc b/paddle/fluid/operators/math/vol2col.cc index e92adc09ba0..1083cac3020 100644 --- a/paddle/fluid/operators/math/vol2col.cc +++ b/paddle/fluid/operators/math/vol2col.cc @@ -34,9 +34,10 @@ class Vol2ColFunctor { const std::vector& strides, const std::vector& paddings, framework::Tensor* col) const { - PADDLE_ENFORCE(vol.dims().size() == 4); - PADDLE_ENFORCE(col->dims().size() == 7); - + PADDLE_ENFORCE_EQ(vol.dims().size(), 4, + "The dimension of vol should be 4."); + PADDLE_ENFORCE_EQ(col->dims().size(), 7, + "The dimension of col should be 7."); int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; @@ -50,28 +51,35 @@ class Vol2ColFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + // changed + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back - ((dilations[0] * (filter_depth - 1) + 1))) / strides[0] + 1, output_depth, "input_depth and output_depth are " "mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down - ((dilations[1] * (filter_height - 1) + 1))) / strides[1] + 1, output_height, "input_height and output_height are " "mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right - ((dilations[2] * (filter_width - 1) + 1))) / strides[2] + 1, output_width, "input_width and output_width are " "mismatching."); - const T* vol_data = vol.data(); T* col_data = col->data(); @@ -81,11 +89,11 @@ class Vol2ColFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int c_in = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; + int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; + int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; + int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2]; int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; @@ -120,9 +128,10 @@ class Col2VolFunctor { const std::vector& strides, const std::vector& paddings, framework::Tensor* vol) const { - PADDLE_ENFORCE(vol->dims().size() == 4); - PADDLE_ENFORCE(col.dims().size() == 7); - + PADDLE_ENFORCE_EQ(vol->dims().size(), 4, + "The dimension of vol should be 4."); + PADDLE_ENFORCE_EQ(col.dims().size(), 7, + "The dimension of col should be 7."); int input_channels = vol->dims()[0]; int input_depth = vol->dims()[1]; int input_height = vol->dims()[2]; @@ -136,21 +145,29 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + + PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back - ((dilations[0] * (filter_depth - 1) + 1))) / strides[0] + 1, output_depth, "input_depth and output_depth are " "mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down - ((dilations[1] * (filter_height - 1) + 1))) / strides[1] + 1, output_height, "input_height and output_height are " "mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right - ((dilations[2] * (filter_width - 1) + 1))) / strides[2] + 1, @@ -166,11 +183,11 @@ class Col2VolFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int cIm = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; + int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; + int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; + int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2]; if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { diff --git a/paddle/fluid/operators/math/vol2col.cu b/paddle/fluid/operators/math/vol2col.cu index 25d8a247bca..a167a9021bc 100644 --- a/paddle/fluid/operators/math/vol2col.cu +++ b/paddle/fluid/operators/math/vol2col.cu @@ -92,27 +92,34 @@ class Vol2ColFunctor { int output_height = col->dims()[5]; int output_width = col->dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back - ((dilations[0] * (filter_depth - 1) + 1))) / strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down - ((dilations[1] * (filter_height - 1) + 1))) / strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right - ((dilations[2] * (filter_width - 1) + 1))) / strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); + "mismatching."); int num_outputs = input_channels * output_depth * output_height * output_width; @@ -122,9 +129,8 @@ class Vol2ColFunctor { vol2col<<>>( num_outputs, vol.data(), input_depth, input_height, input_width, dilations[0], dilations[1], dilations[2], filter_depth, filter_height, - filter_width, strides[0], strides[1], strides[2], paddings[0], - paddings[1], paddings[2], output_depth, output_height, output_width, - col->data()); + filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, + pad_w_left, output_depth, output_height, output_width, col->data()); } }; @@ -218,27 +224,35 @@ class Col2VolFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + + PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back - ((dilations[0] * (filter_depth - 1) + 1))) / strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down - ((dilations[1] * (filter_height - 1) + 1))) / strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right - ((dilations[2] * (filter_width - 1) + 1))) / strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); + "mismatching."); int num_kernels = input_channels * input_depth * input_height * input_width; @@ -248,9 +262,8 @@ class Col2VolFunctor { col2vol<<>>( num_kernels, col.data(), input_depth, input_height, input_width, dilations[0], dilations[1], dilations[2], filter_depth, filter_height, - filter_width, strides[0], strides[1], strides[2], paddings[0], - paddings[1], paddings[2], output_depth, output_height, output_width, - vol->data()); + filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, + pad_w_left, output_depth, output_height, output_width, vol->data()); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 194e007fcb4..3eb9fb505dc 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2259,11 +2259,12 @@ def conv2d(input, bias_attr=None, use_cudnn=True, act=None, - name=None): + name=None, + data_format="NCHW"): """ The convolution2D layer calculates the output based on the input, filter and strides, paddings, dilations, groups parameters. Input and - Output are in NCHW format, where N is batch size, C is the number of + Output are in NCHW or NHWC format, where N is batch size, C is the number of channels, H is the height of the feature, and W is the width of the feature. Filter is in MCHW format, where M is the number of output image channels, C is the number of input image channels, H is the height of the filter, @@ -2284,7 +2285,7 @@ def conv2d(input, Where: - * :math:`X`: Input value, a tensor with NCHW format. + * :math:`X`: Input value, a tensor with NCHW or NHWC format. * :math:`W`: Filter value, a tensor with MCHW format. * :math:`\\ast`: Convolution operation. * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. @@ -2314,7 +2315,7 @@ def conv2d(input, padding mode is 'SAME' and 'VALID' can reference this link`_ Args: - input (Variable): The input image with [N, C, H, W] format. + input (Variable): The input image with [N, C, H, W] or [N, H, W, C] format. num_filters(int): The number of filter. It is as same as the output image channel. filter_size (int|tuple): The filter size. If filter_size @@ -2324,9 +2325,14 @@ def conv2d(input, stride (int|tuple): The stride size. If stride is a tuple, it must contain two integers, (stride_height, stride_width). Otherwise, stride_height = stride_width = stride. Default: stride = 1. - padding (int|tuple): The padding size. If padding is a tuple, it must - contain two integers, (padding_height, padding_width). Otherwise, - padding_height = padding_width = padding. Default: padding = 0. + padding (string|int|list|tuple): The padding size. If `padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If padding size is a tuple or list, + it could be in three forms: `[pad_height, pad_width]` or + `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and when `data_format` is `"NCHW"`, + `padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. dilation (int|tuple): The dilation size. If dilation is a tuple, it must contain two integers, (dilation_height, dilation_width). Otherwise, dilation_height = dilation_width = dilation. Default: dilation = 1. @@ -2350,7 +2356,10 @@ def conv2d(input, act (str): Activation type, if it is set to None, activation is not appended. Default: None name (str|None): A name for this layer(optional). If set None, the layer - will be named automatically. Default: None + will be named automatically. Default: None. + data_format (str): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. Returns: Variable: The tensor variable storing the convolution and \ @@ -2368,8 +2377,23 @@ def conv2d(input, conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu") """ - num_channels = input.shape[1] + if not isinstance(use_cudnn, bool): + raise ValueError("Attr(use_cudnn) should be True or False. Received " + "Attr(use_cudnn): %s. " % str(use_cudnn)) + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCHW' or 'NHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + + channel_last = (data_format == "NHWC") + num_channels = input.shape[3] if channel_last else input.shape[1] + if num_channels < 0: + raise ValueError( + "The channel dimmention of the input(%s) should be defined. " + "Received: %s." % (str(input.shape), str(num_channels))) assert param_attr is not False, "param_attr should not be False here." + l_type = 'conv2d' if (num_channels == groups and num_filters % num_channels == 0 and not use_cudnn): @@ -2382,18 +2406,61 @@ def conv2d(input, num_filter_channels = num_channels else: if num_channels % groups != 0: - raise ValueError("num_channels must be divisible by groups.") + raise ValueError( + "The number of input channels must be divisible by Attr(groups). " + "Received: number of channels(%s), groups(%s)." % + (str(num_channels), str(groups))) num_filter_channels = num_channels // groups filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') stride = utils.convert_to_list(stride, 2, 'stride') - padding = utils.convert_to_list(padding, 2, 'padding') dilation = utils.convert_to_list(dilation, 2, 'dilation') - if not isinstance(use_cudnn, bool): - raise ValueError("use_cudnn should be True or False") + # padding + def _update_padding(padding, data_format): + def is_list_or_tuple(ele): + if isinstance(ele, list) or isinstance(ele, tuple): + return True + return False + + if is_list_or_tuple(padding) and len(padding) == 4: + if is_list_or_tuple(padding[0]) and (data_format == "NCHW"): + if not (padding[0] == [0, 0] and padding[1] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[2:4] + padding = [ele for a_list in padding for ele in a_list] + elif is_list_or_tuple(padding[0]) and (data_format == "NHWC"): + if not (padding[0] == [0, 0] and padding[3] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[1:3] + padding = [ele for a_list in padding for ele in a_list] + padding = utils.convert_to_list(padding, 4, 'padding') + else: + padding = utils.convert_to_list(padding, 2, 'padding') + padding = [padding[0], padding[0], padding[1], padding[1]] + + return padding + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." % + str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0, 0, 0] + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0, 0, 0] + + padding = _update_padding(padding, data_format) - input_shape = input.shape filter_shape = [num_filters, int(num_filter_channels)] + filter_size def _get_default_param_initializer(): @@ -2423,7 +2490,9 @@ def conv2d(input, 'groups': groups, 'use_cudnn': use_cudnn, 'use_mkldnn': False, - 'fuse_relu_before_depthwise_conv': False + 'fuse_relu_before_depthwise_conv': False, + "padding_algorithm": padding_algorithm, + "data_format": data_format, }) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) @@ -2442,13 +2511,14 @@ def conv3d(input, bias_attr=None, use_cudnn=True, act=None, - name=None): + name=None, + data_format="NCDHW"): """ **Convlution3D Layer** The convolution3D layer calculates the output based on the input, filter and strides, paddings, dilations, groups parameters. Input(Input) and - Output(Output) are in NCDHW format. Where N is batch size C is the number of + Output(Output) are in NCDHW or NDHWC format. Where N is batch size C is the number of channels, D is the depth of the feature, H is the height of the feature, and W is the width of the feature. Convlution3D is similar with Convlution2D but adds one dimension(depth). If bias attribution and activation type are @@ -2463,7 +2533,7 @@ def conv3d(input, In the above equation: - * :math:`X`: Input value, a tensor with NCDHW format. + * :math:`X`: Input value, a tensor with NCDHW or NDHWC format. * :math:`W`: Filter value, a tensor with MCDHW format. * :math:`\\ast`: Convolution operation. * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. @@ -2490,7 +2560,7 @@ def conv3d(input, W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1 Args: - input (Variable): The input image with [N, C, D, H, W] format. + input (Variable): The input image with [N, C, D, H, W] or [N, D, H, W, C]format. num_filters(int): The number of filter. It is as same as the output image channel. filter_size (int|tuple): The filter size. If filter_size is a tuple, @@ -2500,9 +2570,15 @@ def conv3d(input, stride (int|tuple): The stride size. If stride is a tuple, it must contain three integers, (stride_depth, stride_height, stride_width). Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1. - padding (int|tuple): The padding size. If padding is a tuple, it must - contain three integers, (padding_depth, padding_height, padding_width). Otherwise, - padding_depth = padding_height = padding_width = padding. Default: padding = 0. + padding (string|int|list|tuple): The padding size. f `padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `pool_padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NDHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. dilation (int|tuple): The dilation size. If dilation is a tuple, it must contain three integers, (dilation_depth, dilation_height, dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation. Default: dilation = 1. @@ -2527,6 +2603,9 @@ def conv3d(input, Default: None. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None. + data_format (str): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. + The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_depth, input_height, input_width]`. Returns: Variable: The tensor variable storing the convolution and \ @@ -2549,22 +2628,85 @@ def conv3d(input, helper = LayerHelper(l_type, **locals()) dtype = helper.input_dtype() - num_channels = input.shape[1] + if not isinstance(use_cudnn, bool): + raise ValueError("Attr(use_cudnn) should be True or False. Received " + "Attr(use_cudnn): %s. " % str(use_cudnn)) + + if data_format not in ["NCDHW", "NDHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + + channel_last = (data_format == "NDHWC") + num_channels = input.shape[4] if channel_last else input.shape[1] + if num_channels < 0: + raise ValueError( + "The channel dimmention of the input(%s) should be defined. " + "Received: %s." % (str(input.shape), str(num_channels))) if groups is None: num_filter_channels = num_channels else: if num_channels % groups != 0: - raise ValueError("num_channels must be divisible by groups.") + raise ValueError( + "The number of input channels must be divisible by Attr(groups). " + "Received: number of channels(%s), groups(%s)." % + (str(num_channels), str(groups))) num_filter_channels = num_channels // groups filter_size = utils.convert_to_list(filter_size, 3, 'filter_size') stride = utils.convert_to_list(stride, 3, 'stride') - padding = utils.convert_to_list(padding, 3, 'padding') dilation = utils.convert_to_list(dilation, 3, 'dilation') - if not isinstance(use_cudnn, bool): - raise ValueError("use_cudnn should be True or False") + def _update_padding(padding, data_format): + def is_list_or_tuple(ele): + if isinstance(ele, list) or isinstance(ele, tuple): + return True + return False + + if is_list_or_tuple(padding) and len(padding) == 5: + if is_list_or_tuple(padding[0]) and (data_format == "NCDHW"): + if not (padding[0] == [0, 0] and padding[1] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[2:5] + padding = [ele for a_list in padding for ele in a_list] + elif is_list_or_tuple(padding[0]) and (data_format == "NDHWC"): + if not (padding[0] == [0, 0] and padding[4] == [0, 0]): + raise ValueError( + "Non-zero padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[1:4] + padding = [ele for a_list in padding for ele in a_list] + padding = utils.convert_to_list(padding, 6, 'padding') + + elif is_list_or_tuple(padding) and len(padding) == 6: + padding = utils.convert_to_list(padding, 6, 'padding') + else: + padding = utils.convert_to_list(padding, 3, 'padding') + padding = [ + padding[0], padding[0], padding[1], padding[1], padding[2], + padding[2] + ] + + return padding + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown padding: '%s'. It can only be 'SAME' or 'VALID'." % + str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0, 0, 0, 0, 0] + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0, 0, 0, 0, 0] + + padding = _update_padding(padding, data_format) input_shape = input.shape filter_shape = [num_filters, num_filter_channels] + filter_size @@ -2596,7 +2738,9 @@ def conv3d(input, 'dilations': dilation, 'groups': groups, 'use_cudnn': use_cudnn, - 'use_mkldnn': False + 'use_mkldnn': False, + "padding_algorithm": padding_algorithm, + "data_format": data_format, }) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 725953b67df..c9dd714f2a5 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -19,29 +19,87 @@ import numpy as np import paddle.fluid.core as core from op_test import OpTest +import paddle.fluid as fluid -def conv2d_forward_naive(input, filter, group, conv_param): +def conv2d_forward_naive(input, + filter, + group, + conv_param, + padding_algorithm='EXPLICIT', + data_format='NCHW'): + if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]: + raise ValueError("Unknown Attr(padding_algorithm): '%s'. " + "It can only be 'SAME' or 'VALID'." % + str(padding_algorithm)) + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Unknown Attr(data_format): '%s' ." + "It can only be 'NCHW' or 'NHWC'." % str(data_format)) + + channel_last = (data_format == "NHWC") + if channel_last: + input = np.transpose(input, [0, 3, 1, 2]) + in_n, in_c, in_h, in_w = input.shape - out_c, f_c, f_h, f_w = filter.shape + f_n, f_c, f_h, f_w = filter.shape + out_n = in_n + out_c = f_n assert f_c * group == in_c assert np.mod(out_c, group) == 0 sub_out_c = out_c // group + sub_f_n = f_n // group stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ 'dilation'] - out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0] - out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1] - out = np.zeros((in_n, out_c, out_h, out_w)) + + # update pad and dilation + def _get_padding_with_SAME(input_shape, pool_size, pool_stride): + padding = [] + for input_size, filter_size, stride_size in zip(input_shape, pool_size, + pool_stride): + out_size = int((input_size + stride_size - 1) / stride_size) + pad_sum = np.max(( + (out_size - 1) * stride_size + filter_size - input_size, 0)) + pad_0 = int(pad_sum / 2) + pad_1 = int(pad_sum - pad_0) + padding.append(pad_0) + padding.append(pad_1) + return padding + + ksize = filter.shape[2:4] + if padding_algorithm == "VALID": + pad = [0, 0, 0, 0] + elif padding_algorithm == "SAME": + dilation = [1, 1] + input_data_shape = [] + if data_format == "NCHW": + input_data_shape = input.shape[2:4] + elif data_format == "NHWC": + input_data_shape = input.shape[1:3] + pad = _get_padding_with_SAME(input_data_shape, ksize, stride) + + pad_h_0, pad_h_1 = pad[0], pad[0] + pad_w_0, pad_w_1 = pad[1], pad[1] + if len(pad) == 4: + pad_h_0, pad_h_1 = pad[0], pad[1] + pad_w_0, pad_w_1 = pad[2], pad[3] + + out_h = 1 + (in_h + pad_h_0 + pad_h_1 - (dilation[0] * + (f_h - 1) + 1)) // stride[0] + out_w = 1 + (in_w + pad_w_0 + pad_w_1 - (dilation[1] * + (f_w - 1) + 1)) // stride[1] + out = np.zeros((out_n, out_c, out_h, out_w)) d_bolck_h = (dilation[0] * (f_h - 1) + 1) d_bolck_w = (dilation[1] * (f_w - 1) + 1) - input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )), + input_pad = np.pad(input, ((0, 0), (0, 0), (pad_h_0, pad_h_1), + (pad_w_0, pad_w_1)), mode='constant', constant_values=0) - filter_dilation = np.zeros((out_c, f_c, d_bolck_h, d_bolck_w)) + filter_dilation = np.zeros((f_n, f_c, d_bolck_h, d_bolck_w)) filter_dilation[:, :, 0:d_bolck_h:dilation[0], 0:d_bolck_w:dilation[ 1]] = filter @@ -53,16 +111,156 @@ def conv2d_forward_naive(input, filter, group, conv_param): i * stride[0]:i * stride[0] + d_bolck_h, j * stride[1]:j * stride[1] + d_bolck_w] - f_sub = filter_dilation[g * sub_out_c:(g + 1) * - sub_out_c, :, :, :] + f_sub = filter_dilation[g * sub_f_n:(g + 1) * sub_f_n, :, :, :] + # sub_f_n == sub_out_c for k in range(sub_out_c): + # Multiplication of Corresponding Elements, then sum all out[:, g * sub_out_c + k, i, j] = \ np.sum(input_pad_masked * f_sub[k, :, :, :], axis=(1, 2, 3)) + if channel_last: + out = np.transpose(out, [0, 2, 3, 1]) + return out, in_n, out_h, out_w, out_c +def create_test_cudnn_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestCUDNNCase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + cls_name = "{0}_{1}".format(parent.__name__, "CUDNN") + TestCUDNNCase.__name__ = cls_name + globals()[cls_name] = TestCUDNNCase + + +def create_test_cudnn_fp16_class(parent, grad_check=True): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestConv2DCUDNNFp16(parent): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) + + def test_check_grad_no_filter(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place) and grad_check: + self.check_grad_with_place( + place, ['Input'], + 'Output', + max_relative_error=0.02, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place) and grad_check: + self.check_grad_with_place( + place, ['Filter'], + 'Output', + max_relative_error=0.02, + no_grad_set=set(['Input'])) + + cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16") + TestConv2DCUDNNFp16.__name__ = cls_name + globals()[cls_name] = TestConv2DCUDNNFp16 + + +def create_test_channel_last_class(parent): + class TestChannelLastCase(parent): + def init_data_format(self): + self.data_format = "NHWC" + + def init_test_case_2(self): + N, C, H, W = self.input_size + self.input_size = [N, H, W, C] + + cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast") + TestChannelLastCase.__name__ = cls_name + globals()[cls_name] = TestChannelLastCase + + +def create_test_cudnn_channel_last_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestCudnnChannelLastCase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + def init_data_format(self): + self.data_format = "NHWC" + + def init_test_case_2(self): + N, C, H, W = self.input_size + self.input_size = [N, H, W, C] + + cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLast") + TestCudnnChannelLastCase.__name__ = cls_name + globals()[cls_name] = TestCudnnChannelLastCase + + +def create_test_padding_SAME_class(parent): + class TestPaddingSMAECase(parent): + def init_paddings(self): + self.pad = [0, 0] + self.padding_algorithm = "SAME" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp") + TestPaddingSMAECase.__name__ = cls_name + globals()[cls_name] = TestPaddingSMAECase + + +def create_test_padding_VALID_class(parent): + class TestPaddingVALIDCase(parent): + def init_paddings(self): + self.pad = [1, 1] + self.padding_algorithm = "VALID" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp") + TestPaddingVALIDCase.__name__ = cls_name + globals()[cls_name] = TestPaddingVALIDCase + + +def create_test_cudnn_padding_SAME_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestCUDNNPaddingSMAECase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + def init_paddings(self): + self.pad = [1, 1] + self.padding_algorithm = "SAME" + + cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingSAMEOp") + TestCUDNNPaddingSMAECase.__name__ = cls_name + globals()[cls_name] = TestCUDNNPaddingSMAECase + + +def create_test_cudnn_padding_VALID_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestCUDNNPaddingVALIDCase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + def init_paddings(self): + self.pad = [1, 1] + self.padding_algorithm = "VALID" + + cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingVALIDOp") + TestCUDNNPaddingVALIDCase.__name__ = cls_name + globals()[cls_name] = TestCUDNNPaddingVALIDCase + + class TestConv2dOp(OpTest): def setUp(self): self.op_type = "conv2d" @@ -95,6 +293,7 @@ class TestConv2dOp(OpTest): else: input2 = input filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype) + output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups, conv2d_param) output = output.astype(self.dtype) @@ -160,6 +359,9 @@ class TestConv2dOp(OpTest): f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] + def init_test_case_2(self): + pass + def init_dilation(self): self.dilations = [1, 1] @@ -281,19 +483,6 @@ class TestWithInput1x1Filter1x1(TestConv2dOp): #----------------Conv2dCUDNN---------------- - -def create_test_cudnn_class(parent): - @unittest.skipIf(not core.is_compiled_with_cuda(), - "core is not compiled with CUDA") - class TestCUDNNCase(parent): - def init_kernel_type(self): - self.use_cudnn = True - - cls_name = "{0}_{1}".format(parent.__name__, "CUDNN") - TestCUDNNCase.__name__ = cls_name - globals()[cls_name] = TestCUDNNCase - - create_test_cudnn_class(TestConv2dOp) create_test_cudnn_class(TestWithPad) create_test_cudnn_class(TestWithStride) @@ -301,45 +490,7 @@ create_test_cudnn_class(TestWithGroup) create_test_cudnn_class(TestWith1x1) create_test_cudnn_class(TestWithInput1x1Filter1x1) -#----------------Conv2dCUDNN---------------- - - -def create_test_cudnn_fp16_class(parent, grad_check=True): - @unittest.skipIf(not core.is_compiled_with_cuda(), - "core is not compiled with CUDA") - class TestConv2DCUDNNFp16(parent): - def init_kernel_type(self): - self.use_cudnn = True - self.dtype = np.float16 - - def test_check_output(self): - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) - if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-2) - - def test_check_grad_no_filter(self): - place = core.CUDAPlace(0) - if core.is_float16_supported(place) and grad_check: - self.check_grad_with_place( - place, ['Input'], - 'Output', - max_relative_error=0.02, - no_grad_set=set(['Filter'])) - - def test_check_grad_no_input(self): - place = core.CUDAPlace(0) - if core.is_float16_supported(place) and grad_check: - self.check_grad_with_place( - place, ['Filter'], - 'Output', - max_relative_error=0.02, - no_grad_set=set(['Input'])) - - cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16") - TestConv2DCUDNNFp16.__name__ = cls_name - globals()[cls_name] = TestConv2DCUDNNFp16 - +#----------------Conv2dCUDNN fp16---------------- create_test_cudnn_fp16_class(TestConv2dOp, grad_check=False) create_test_cudnn_fp16_class(TestWithPad, grad_check=False) @@ -348,7 +499,7 @@ create_test_cudnn_fp16_class(TestWithGroup, grad_check=False) create_test_cudnn_fp16_class(TestWith1x1, grad_check=False) create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False) -# -------TestDepthwiseConv +#----------------TestDepthwiseConv ----- class TestDepthwiseConv(TestConv2dOp): @@ -502,5 +653,704 @@ class TestCUDNNExhaustiveSearch(TestConv2dOp): # def init_op_type(self): # self.op_type = "conv_cudnn" +# ---- test asymmetric padding ---- + + +class TestConv2dOp_v2(OpTest): + def setUp(self): + self.op_type = "conv2d" + self.use_cudnn = False + self.exhaustive_search = False + self.use_cuda = False + self.use_mkldnn = False + self.fuse_relu_before_depthwise_conv = False + self.dtype = np.float32 + self.init_kernel_type() + self.init_group() + self.init_dilation() + self.init_data_format() + self.init_test_case() + + self.init_paddings() + self.init_test_case_2() + + conv2d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilation': self.dilations + } + + input = np.random.random(self.input_size).astype(self.dtype) + if not self.has_cuda(): + self.fuse_relu_before_depthwise_conv = False + if self.fuse_relu_before_depthwise_conv: + input = input - 0.5 + input -= (input < 0) * 0.1 + input += (input >= 0) * 0.1 + input2 = np.maximum(input, 0.0) + else: + input2 = input + filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype) + output, _, _, _, _ = conv2d_forward_naive( + input2, filter, self.groups, conv2d_param, self.padding_algorithm, + self.data_format) + output = output.astype(self.dtype) + + self.inputs = { + 'Input': OpTest.np_dtype_to_fluid_dtype(input), + 'Filter': OpTest.np_dtype_to_fluid_dtype(filter) + } + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'padding_algorithm': self.padding_algorithm, + 'groups': self.groups, + 'dilations': self.dilations, + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn, + 'data_format': self.data_format, + 'fuse_relu_before_depthwise_conv': + self.fuse_relu_before_depthwise_conv, + 'exhaustive_search': self.exhaustive_search + } + self.outputs = {'Output': output} + + def has_cuda(self): + return core.is_compiled_with_cuda() and (self.use_cudnn or + self.use_cuda) + + def test_check_output(self): + place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace() + self.check_output_with_place(place, atol=1e-5) + + def test_check_grad(self): + if self.dtype == np.float16: + return + place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace() + self.check_grad_with_place( + place, {'Input', 'Filter'}, 'Output', max_relative_error=0.02) + + def test_check_grad_no_filter(self): + if self.dtype == np.float16: + return + place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace() + self.check_grad_with_place( + place, ['Input'], + 'Output', + max_relative_error=0.02, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + if self.dtype == np.float16: + return + place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace() + self.check_grad_with_place( + place, ['Filter'], + 'Output', + max_relative_error=0.02, + no_grad_set=set(['Input'])) + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_dilation(self): + self.dilations = [1, 1] + + def init_group(self): + self.groups = 1 + + def init_kernel_type(self): + pass + + def init_paddings(self): + self.pad = [0, 0] + self.padding_algorithm = "EXPLICIT" + + def init_data_format(self): + self.data_format = "NCHW" + + def init_test_case_2(self): + pass + + +class TestConv2dOp_AsyPadding(TestConv2dOp_v2): + def init_paddings(self): + self.pad = [0, 0, 1, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestWithPad_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_paddings(self): + self.pad = [2, 1, 3, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestWithStride_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [2, 2] + self.input_size = [2, 3, 6, 6] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_paddings(self): + self.pad = [2, 1, 3, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestWithGroup_AsyPadding(TestConv2dOp_v2): + def init_group(self): + self.groups = 3 + + +class TestWith1x1_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 1, 1] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [2, 2, 4, 0] + self.padding_algorithm = "EXPLICIT" + + +class TestWithDepthWise3x3_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [1, 1] + self.input_size = [3, 4, 10, 10] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [8, f_c, 3, 3] + + def init_dilation(self): + self.dilations = [2, 2] + + def init_group(self): + self.groups = 4 + + def init_paddings(self): + self.pad = [1, 3, 2, 1] + self.padding_algorithm = "EXPLICIT" + + +class TestWithDepthWise5x5_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [1, 1] + self.input_size = [2, 4, 10, 10] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [8, f_c, 5, 5] + + def init_group(self): + self.groups = 4 + + def init_paddings(self): + self.pad = [0, 1, 1, 0] + self.padding_algorithm = "EXPLICIT" + + +class TestWithDepthWise7x7_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [2, 2] + self.input_size = [2, 8, 10, 10] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [16, f_c, 7, 7] + + def init_group(self): + self.groups = 8 + + def init_paddings(self): + self.pad = [1, 3, 4, 1] + self.padding_algorithm = "EXPLICIT" + + +class TestWithDilation_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [1, 1] + self.input_size = [2, 3, 10, 10] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_dilation(self): + self.dilations = [2, 2] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [0, 1, 3, 0] + self.padding_algorithm = "EXPLICIT" + + +class TestWithInput1x1Filter1x1_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.stride = [1, 1] + self.input_size = [2, 3, 1, 1] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 1, 1] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [0, 3, 4, 0] + self.padding_algorithm = "EXPLICIT" + + +create_test_cudnn_class(TestConv2dOp_AsyPadding) +create_test_cudnn_class(TestWithPad_AsyPadding) +create_test_cudnn_class(TestWithStride_AsyPadding) +create_test_cudnn_class(TestWithGroup_AsyPadding) +create_test_cudnn_class(TestWith1x1_AsyPadding) +create_test_cudnn_class(TestWithInput1x1Filter1x1_AsyPadding) + + +class TestDepthwiseConv_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.use_cuda = True + self.stride = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [3, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [1, 1, 0, 1] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConv2_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.use_cuda = True + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [3, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [0, 1, 0, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConv3_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.use_cuda = True + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [1, 1, 0, 0] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConvWithDilation_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.use_cuda = True + self.pad = [1, 1] + self.stride = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + self.dilations = [2, 2] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [1, 1, 2, 1] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConvWithDilation2_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.use_cuda = True + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + self.dilations = [2, 2] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [0, 1, 1, 0] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConvandFuse_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.fuse_relu_before_depthwise_conv = True + self.use_cuda = True + self.pad = [1, 1] + self.stride = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [3, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [2, 1, 2, 3] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConv2andFuse_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.fuse_relu_before_depthwise_conv = True + self.use_cuda = True + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [3, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [1, 1, 1, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConv3andFuse_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.fuse_relu_before_depthwise_conv = True + self.use_cuda = True + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [1, 2, 0, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConvWithDilationandFuse_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.fuse_relu_before_depthwise_conv = True + self.use_cuda = True + self.pad = [1, 1] + self.stride = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + self.dilations = [2, 2] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [2, 1, 1, 0] + self.padding_algorithm = "EXPLICIT" + + +class TestDepthwiseConvWithDilation2andFuse_AsyPadding(TestConv2dOp_v2): + def init_test_case(self): + self.fuse_relu_before_depthwise_conv = True + self.use_cuda = True + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + self.dilations = [2, 2] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + def init_paddings(self): + self.pad = [1, 3, 1, 3] + self.padding_algorithm = "EXPLICIT" + + +#---------- test SAME VALID ----------- +create_test_padding_SAME_class(TestConv2dOp_AsyPadding) +create_test_padding_SAME_class(TestWithPad_AsyPadding) +create_test_padding_SAME_class(TestWithStride_AsyPadding) +create_test_padding_SAME_class(TestWithGroup_AsyPadding) +create_test_padding_SAME_class(TestWithInput1x1Filter1x1_AsyPadding) + +create_test_padding_VALID_class(TestConv2dOp_AsyPadding) +create_test_padding_VALID_class(TestWithPad_AsyPadding) +create_test_padding_VALID_class(TestWithStride_AsyPadding) +create_test_padding_VALID_class(TestWithGroup_AsyPadding) +create_test_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding) + +create_test_cudnn_padding_SAME_class(TestConv2dOp_AsyPadding) +create_test_cudnn_padding_SAME_class(TestWithPad_AsyPadding) +create_test_cudnn_padding_SAME_class(TestWithStride_AsyPadding) +create_test_cudnn_padding_SAME_class(TestWithGroup_AsyPadding) +create_test_cudnn_padding_SAME_class(TestWithInput1x1Filter1x1_AsyPadding) + +create_test_cudnn_padding_VALID_class(TestConv2dOp_AsyPadding) +create_test_cudnn_padding_VALID_class(TestWithPad_AsyPadding) +create_test_cudnn_padding_VALID_class(TestWithStride_AsyPadding) +create_test_cudnn_padding_VALID_class(TestWithGroup_AsyPadding) +create_test_cudnn_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding) + +# depthwise conv2d + +create_test_padding_SAME_class(TestDepthwiseConv_AsyPadding) +create_test_padding_SAME_class(TestDepthwiseConvWithDilation_AsyPadding) +create_test_padding_SAME_class(TestDepthwiseConvandFuse_AsyPadding) +create_test_padding_SAME_class(TestDepthwiseConvWithDilationandFuse_AsyPadding) + +create_test_padding_VALID_class(TestDepthwiseConv_AsyPadding) +create_test_padding_VALID_class(TestDepthwiseConvWithDilation_AsyPadding) +create_test_padding_VALID_class(TestDepthwiseConvandFuse_AsyPadding) +create_test_padding_VALID_class(TestDepthwiseConvWithDilationandFuse_AsyPadding) + +# ------------ test channel last --------- +create_test_channel_last_class(TestConv2dOp_AsyPadding) +create_test_channel_last_class(TestWithPad_AsyPadding) +create_test_channel_last_class(TestWithGroup_AsyPadding) +create_test_channel_last_class(TestWith1x1_AsyPadding) +create_test_channel_last_class(TestWithInput1x1Filter1x1_AsyPadding) + +create_test_channel_last_class(TestDepthwiseConv_AsyPadding) +create_test_channel_last_class(TestDepthwiseConvWithDilation2_AsyPadding) +create_test_channel_last_class(TestDepthwiseConvandFuse_AsyPadding) +create_test_channel_last_class(TestDepthwiseConvWithDilationandFuse_AsyPadding) + +create_test_cudnn_channel_last_class(TestConv2dOp_AsyPadding) +create_test_cudnn_channel_last_class(TestWithPad_AsyPadding) +create_test_cudnn_channel_last_class(TestWithStride_AsyPadding) +create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding) +create_test_cudnn_channel_last_class(TestWithDilation_AsyPadding) + + +# --------- test python API --------------- +class TestConv2dAPI(OpTest): + def test_api(self): + + input_NHWC = fluid.layers.data( + name="input_NHWC", + shape=[2, 5, 5, 3], + append_batch_size=False, + dtype="float32") + + input_NCHW = fluid.layers.data( + name="input_NCHW", + shape=[2, 3, 5, 5], + append_batch_size=False, + dtype="float32") + + fluid.layers.conv2d( + input=input_NHWC, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=0, + dilation=[1, 1], + groups=1, + data_format="NCHW") + + fluid.layers.conv2d( + input=input_NCHW, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=[1, 2, 1, 0], + dilation=[1, 1], + groups=1, + data_format="NCHW") + + fluid.layers.conv2d( + input=input_NCHW, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=[[0, 0], [0, 0], [1, 1], [1, 1]], + dilation=[1, 1], + groups=1, + data_format="NCHW") + + fluid.layers.conv2d( + input=input_NHWC, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=[[0, 0], [1, 1], [1, 1], [0, 0]], + dilation=[1, 1], + groups=1, + data_format="NHWC") + + fluid.layers.conv2d( + input=input_NCHW, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding="SAME", + dilation=[1, 1], + groups=1, + data_format="NCHW") + + fluid.layers.conv2d( + input=input_NCHW, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding="VALID", + dilation=[1, 1], + groups=1, + data_format="NCHW") + + +class TestConv2dAPI_Error(OpTest): + def test_api(self): + input = fluid.layers.data( + name="input", + shape=[2, 5, 5, 5], + append_batch_size=False, + dtype="float32") + + # ValueError: cudnn + def run_1(): + fluid.layers.conv2d( + input=input, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=0, + dilation=[1, 1], + groups=1, + use_cudnn=[0], + data_format="NCHW") + + self.assertRaises(ValueError, run_1) + + # ValueError: data_format + def run_2(): + fluid.layers.conv2d( + input=input, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=0, + dilation=[1, 1], + groups=1, + use_cudnn=False, + data_format="NCHWC") + + self.assertRaises(ValueError, run_2) + + # ValueError: padding + def run_3(): + fluid.layers.conv2d( + input=input, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding="SAMEE", + dilation=[1, 1], + groups=1, + use_cudnn=False, + data_format="NCHW") + + self.assertRaises(ValueError, run_3) + + def run_4(): + fluid.layers.conv2d( + input=input, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=[[0, 1], [0, 1], [0, 1], [0, 1]], + dilation=[1, 1], + groups=1, + use_cudnn=False, + data_format="NCHW") + + self.assertRaises(ValueError, run_4) + + def run_5(): + fluid.layers.conv2d( + input=input, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=[[0, 1], [0, 1], [0, 1], [0, 1]], + dilation=[1, 1], + groups=1, + use_cudnn=False, + data_format="NHWC") + + self.assertRaises(ValueError, run_5) + + # ValueError: channel dimmention + x = fluid.layers.data( + name="x", + shape=[2, 5, 5, -1], + append_batch_size=False, + dtype="float32") + + def run_6(): + fluid.layers.conv2d( + input=x, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=0, + dilation=[1, 1], + groups=1, + use_cudnn=False, + data_format="NHWC") + + self.assertRaises(ValueError, run_6) + + # ValueError: groups + def run_7(): + fluid.layers.conv2d( + input=input, + num_filters=3, + filter_size=[3, 3], + stride=[1, 1], + padding=0, + dilation=[1, 1], + groups=3, + use_cudnn=False, + data_format="NHWC") + + self.assertRaises(ValueError, run_7) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py index aedd85ad9a7..015d3caaa9a 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py @@ -19,21 +19,83 @@ import numpy as np import paddle.fluid.core as core from op_test import OpTest +import paddle.fluid as fluid -def conv3d_forward_naive(input, filter, group, conv_param): +def conv3d_forward_naive(input, + filter, + group, + conv_param, + padding_algorithm='EXPLICIT', + data_format="NCDHW"): + + if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]: + raise ValueError("Unknown Attr(padding_algorithm): '%s'. " + "It can only be 'SAME' or 'VALID'." % + str(padding_algorithm)) + + if data_format not in ["NCDHW", "NDHWC"]: + raise ValueError("Unknown Attr(data_format): '%s' ." + "It can only be 'NCDHW' or 'NDHWC'." % + str(data_format)) + + channel_last = (data_format == "NDHWC") + if channel_last: + input = np.transpose(input, [0, 4, 1, 2, 3]) + in_n, in_c, in_d, in_h, in_w = input.shape - out_c, f_c, f_d, f_h, f_w = filter.shape + + f_n, f_c, f_d, f_h, f_w = filter.shape + out_n = in_n + out_c = f_n assert f_c * group == in_c assert np.mod(out_c, group) == 0 sub_out_c = out_c // group + sub_f_n = f_n // group stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ 'dilations'] - out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) // stride[0] - out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) // stride[1] - out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) // stride[2] + # update pad and dilation + def _get_padding_with_SAME(input_shape, pool_size, pool_stride): + padding = [] + for input_size, filter_size, stride_size in zip(input_shape, pool_size, + pool_stride): + out_size = int((input_size + stride_size - 1) / stride_size) + pad_sum = np.max(( + (out_size - 1) * stride_size + filter_size - input_size, 0)) + pad_0 = int(pad_sum / 2) + pad_1 = int(pad_sum - pad_0) + padding.append(pad_0) + padding.append(pad_1) + return padding + + ksize = filter.shape[2:5] + if padding_algorithm == "VALID": + pad = [0, 0, 0, 0, 0, 0] + elif padding_algorithm == "SAME": + dilation = [1, 1, 1] + input_data_shape = [] + if data_format == "NCDHW": + input_data_shape = input.shape[2:5] + elif data_format == "NDHWC": + input_data_shape = input.shape[1:4] + pad = _get_padding_with_SAME(input_data_shape, ksize, stride) + + pad_d_0, pad_d_1 = pad[0], pad[0] + pad_h_0, pad_h_1 = pad[1], pad[1] + pad_w_0, pad_w_1 = pad[2], pad[2] + if len(pad) == 6: + pad_d_0, pad_d_1 = pad[0], pad[1] + pad_h_0, pad_h_1 = pad[2], pad[3] + pad_w_0, pad_w_1 = pad[4], pad[5] + + out_d = 1 + (in_d + pad_d_0 + pad_d_1 - (dilation[0] * + (f_d - 1) + 1)) // stride[0] + out_h = 1 + (in_h + pad_h_0 + pad_h_1 - (dilation[1] * + (f_h - 1) + 1)) // stride[1] + out_w = 1 + (in_w + pad_w_0 + pad_w_1 - (dilation[2] * + (f_w - 1) + 1)) // stride[2] out = np.zeros((in_n, out_c, out_d, out_h, out_w)) @@ -41,12 +103,12 @@ def conv3d_forward_naive(input, filter, group, conv_param): d_bolck_h = (dilation[1] * (f_h - 1) + 1) d_bolck_w = (dilation[2] * (f_w - 1) + 1) - input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ), - (pad[2], )), + input_pad = np.pad(input, ((0, 0), (0, 0), (pad_d_0, pad_d_1), + (pad_h_0, pad_h_1), (pad_w_0, pad_w_1)), mode='constant', constant_values=0) - filter_dilation = np.zeros((out_c, f_c, d_bolck_d, d_bolck_h, d_bolck_w)) + filter_dilation = np.zeros((f_n, f_c, d_bolck_d, d_bolck_h, d_bolck_w)) filter_dilation[:, :, 0:d_bolck_d:dilation[0], 0:d_bolck_h:dilation[1], 0: d_bolck_w:dilation[2]] = filter @@ -60,16 +122,114 @@ def conv3d_forward_naive(input, filter, group, conv_param): i * stride[1]:i * stride[1] + d_bolck_h, j * stride[2]:j * stride[2] + d_bolck_w] - f_sub = filter_dilation[g * sub_out_c:(g + 1) * - sub_out_c, :, :, :, :] + f_sub = filter_dilation[g * sub_f_n:(g + 1) * + sub_f_n, :, :, :, :] for k in range(sub_out_c): out[:, g * sub_out_c + k, d, i, j] = \ np.sum(input_pad_masked * f_sub[k, :, :, :, :], axis=(1, 2, 3, 4)) - + if channel_last: + out = np.transpose(out, [0, 2, 3, 4, 1]) return out +def create_test_cudnn_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestCUDNNCase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + cls_name = "{0}_{1}".format(parent.__name__, "CUDNN") + TestCUDNNCase.__name__ = cls_name + globals()[cls_name] = TestCUDNNCase + + +def create_test_padding_SAME_class(parent): + class TestPaddingSMAECase(parent): + def init_paddings(self): + self.pad = [0, 0, 0] + self.padding_algorithm = "SAME" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp") + TestPaddingSMAECase.__name__ = cls_name + globals()[cls_name] = TestPaddingSMAECase + + +def create_test_padding_VALID_class(parent): + class TestPaddingVALIDCase(parent): + def init_paddings(self): + self.pad = [1, 1, 1] + self.padding_algorithm = "VALID" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp") + TestPaddingVALIDCase.__name__ = cls_name + globals()[cls_name] = TestPaddingVALIDCase + + +def create_test_cudnn_padding_SAME_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestCUDNNPaddingSMAECase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + def init_paddings(self): + self.pad = [1, 1, 1] + self.padding_algorithm = "SAME" + + cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingSAMEOp") + TestCUDNNPaddingSMAECase.__name__ = cls_name + globals()[cls_name] = TestCUDNNPaddingSMAECase + + +def create_test_cudnn_padding_VALID_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestCUDNNPaddingVALIDCase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + def init_paddings(self): + self.pad = [1, 1, 1] + self.padding_algorithm = "VALID" + + cls_name = "{0}_{1}".format(parent.__name__, "CudnnPaddingVALIDOp") + TestCUDNNPaddingVALIDCase.__name__ = cls_name + globals()[cls_name] = TestCUDNNPaddingVALIDCase + + +def create_test_channel_last_class(parent): + class TestChannelLastCase(parent): + def init_data_format(self): + self.data_format = "NDHWC" + + def init_test_case_2(self): + N, C, D, H, W = self.input_size + self.input_size = [N, D, H, W, C] + + cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast") + TestChannelLastCase.__name__ = cls_name + globals()[cls_name] = TestChannelLastCase + + +def create_test_cudnn_channel_last_class(parent): + class TestCudnnChannelLastCase(parent): + def init_kernel_type(self): + self.use_cudnn = True + + def init_data_format(self): + self.data_format = "NDHWC" + + def init_test_case_2(self): + N, C, D, H, W = self.input_size + self.input_size = [N, D, H, W, C] + + cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLast") + TestCudnnChannelLastCase.__name__ = cls_name + globals()[cls_name] = TestCudnnChannelLastCase + + class TestConv3dOp(OpTest): def setUp(self): self.op_type = "conv3d" @@ -90,8 +250,11 @@ class TestConv3dOp(OpTest): input = np.random.random(self.input_size).astype(self.dtype) filter = np.random.random(self.filter_size).astype(self.dtype) - output = conv3d_forward_naive(input, filter, self.groups, - conv3d_param).astype(self.dtype) + output = conv3d_forward_naive( + input, + filter, + self.groups, + conv3d_param, ).astype(self.dtype) self.inputs = { 'Input': OpTest.np_dtype_to_fluid_dtype(input), @@ -150,6 +313,9 @@ class TestConv3dOp(OpTest): f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3, 3] + def init_test_case_2(self): + pass + def init_dilation(self): self.dilations = [1, 1, 1] @@ -184,7 +350,7 @@ class TestWith1x1(TestConv3dOp): def init_test_case(self): self.pad = [0, 0, 0] self.stride = [1, 1, 1] - self.input_size = [2, 3, 4, 4, 4] # NCHW + self.input_size = [2, 3, 4, 4, 4] assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1, 1] @@ -200,7 +366,7 @@ class TestWithInput1x1Filter1x1(TestConv3dOp): def init_test_case(self): self.pad = [0, 0, 0] self.stride = [1, 1, 1] - self.input_size = [2, 3, 1, 1, 1] # NCHW + self.input_size = [2, 3, 1, 1, 1] assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1, 1] @@ -216,7 +382,7 @@ class TestWithDilation(TestConv3dOp): def init_test_case(self): self.pad = [0, 0, 0] self.stride = [1, 1, 1] - self.input_size = [2, 3, 6, 6, 6] # NCDHW + self.input_size = [2, 3, 6, 6, 6] assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 2, 2, 2] @@ -228,7 +394,9 @@ class TestWithDilation(TestConv3dOp): self.groups = 3 -#----------------Conv3dCUDNN---------------- +#---------------- Conv3dCUDNN ---------------- + + class TestCUDNN(TestConv3dOp): def init_kernel_type(self): self.use_cudnn = True @@ -320,11 +488,435 @@ class TestCUDNNExhaustiveSearch(TestCUDNN): self.exhaustive_search = True +# ---- test asymmetric padding ---- + + +class TestConv3dOp_2(OpTest): + def setUp(self): + self.op_type = "conv3d" + self.use_cudnn = False + self.use_mkldnn = False + self.data_format = "NCDHW" + self.dtype = np.float32 + self.init_kernel_type() + self.init_group() + self.init_dilation() + self.init_data_format() + self.init_test_case() + self.init_paddings() + + self.init_test_case_2() + + conv3d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilations': self.dilations + } + + input = np.random.random(self.input_size).astype(self.dtype) + filter = np.random.random(self.filter_size).astype(self.dtype) + output = conv3d_forward_naive(input, filter, self.groups, conv3d_param, + self.padding_algorithm, + self.data_format).astype(self.dtype) + + self.inputs = { + 'Input': OpTest.np_dtype_to_fluid_dtype(input), + 'Filter': OpTest.np_dtype_to_fluid_dtype(filter) + } + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'padding_algorithm': self.padding_algorithm, + 'groups': self.groups, + 'dilations': self.dilations, + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn, + 'data_format': self.data_format + } + self.outputs = {'Output': output} + + def has_cudnn(self): + return core.is_compiled_with_cuda() and self.use_cudnn + + def test_check_output(self): + place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace() + self.check_output_with_place(place, atol=1e-5) + + def test_check_grad(self): + if self.dtype == np.float16: + return + place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace() + self.check_grad_with_place( + place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03) + + def test_check_grad_no_filter(self): + if self.dtype == np.float16: + return + place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace() + self.check_grad_with_place( + place, ['Input'], + 'Output', + max_relative_error=0.03, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + if self.dtype == np.float16: + return + place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace() + self.check_grad_with_place( + place, ['Input'], + 'Output', + max_relative_error=0.03, + no_grad_set=set(['Input'])) + + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_test_case_2(self): + pass + + def init_dilation(self): + self.dilations = [1, 1, 1] + + def init_group(self): + self.groups = 1 + + def init_kernel_type(self): + pass + + def init_paddings(self): + self.pad = [0, 0, 0] + self.padding_algorithm = "EXPLICIT" + + def init_data_format(self): + self.data_format = "NCDHW" + + +class TestConv3dOp_AsyPadding(TestConv3dOp_2): + def init_paddings(self): + self.pad = [1, 0, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestCase1_AsyPadding(TestConv3dOp_2): + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_paddings(self): + self.pad = [0, 0, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestWithGroup1_AsyPadding(TestConv3dOp_2): + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [1, 1, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestWithGroup2_AsyPadding(TestConv3dOp_2): + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [1, 1, 0, 1, 0, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestWith1x1_AsyPadding(TestConv3dOp_2): + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 1, 1, 1] + + def init_dilation(self): + self.dilations = [1, 1, 1] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [0, 0, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + +class TestWithDilation_AsyPadding(TestConv3dOp_2): + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 6, 6, 6] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 2, 2, 2] + + def init_dilation(self): + self.dilations = [2, 2, 2] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [0, 0, 1, 0, 1, 0] + self.padding_algorithm = "EXPLICIT" + + +create_test_cudnn_class(TestConv3dOp_AsyPadding) +create_test_cudnn_class(TestWithGroup1_AsyPadding) +create_test_cudnn_class(TestWithGroup2_AsyPadding) +create_test_cudnn_class(TestWith1x1_AsyPadding) +create_test_cudnn_class(TestWithDilation_AsyPadding) + +create_test_padding_SAME_class(TestConv3dOp_AsyPadding) +create_test_padding_SAME_class(TestWithGroup1_AsyPadding) +create_test_padding_SAME_class(TestWith1x1_AsyPadding) + +create_test_padding_VALID_class(TestConv3dOp_AsyPadding) +create_test_padding_VALID_class(TestWithGroup1_AsyPadding) +create_test_padding_VALID_class(TestWith1x1_AsyPadding) + +create_test_cudnn_padding_SAME_class(TestConv3dOp_AsyPadding) +create_test_cudnn_padding_SAME_class(TestWithGroup1_AsyPadding) +create_test_cudnn_padding_SAME_class(TestWith1x1_AsyPadding) + +create_test_cudnn_padding_VALID_class(TestConv3dOp_AsyPadding) +create_test_cudnn_padding_VALID_class(TestWithGroup1_AsyPadding) +create_test_cudnn_padding_VALID_class(TestWith1x1_AsyPadding) + +create_test_channel_last_class(TestConv3dOp_AsyPadding) +create_test_channel_last_class(TestWithGroup1_AsyPadding) +create_test_channel_last_class(TestWith1x1_AsyPadding) + +create_test_channel_last_class(TestConv3dOp_AsyPadding) +create_test_channel_last_class(TestWithGroup1_AsyPadding) +create_test_channel_last_class(TestWith1x1_AsyPadding) + +create_test_cudnn_channel_last_class(TestConv3dOp_AsyPadding) +create_test_cudnn_channel_last_class(TestWithGroup1_AsyPadding) +create_test_cudnn_channel_last_class(TestWith1x1_AsyPadding) + +create_test_cudnn_channel_last_class(TestConv3dOp_AsyPadding) +create_test_cudnn_channel_last_class(TestWithGroup1_AsyPadding) +create_test_cudnn_channel_last_class(TestWith1x1_AsyPadding) + # FIXME(typhoonzero): find a way to determine if # using cudnn > 6 in python # class TestWithDilationCUDNN(TestWithDilation): # def init_op_type(self): # self.op_type = "conv3d" + +# --------- test python API --------------- +class TestConv3dAPI(OpTest): + def test_api(self): + + input_NDHWC = fluid.layers.data( + name="input_NDHWC", + shape=[2, 5, 5, 5, 3], + append_batch_size=False, + dtype="float32") + + input_NCDHW = fluid.layers.data( + name="input_NCDHW", + shape=[2, 3, 5, 5, 3], + append_batch_size=False, + dtype="float32") + + fluid.layers.conv3d( + input=input_NDHWC, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=0, + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW") + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[1, 2, 1, 0, 1, 0], + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW") + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[[0, 0], [0, 0], [1, 1], [1, 1], [1, 1]], + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW") + + fluid.layers.conv3d( + input=input_NDHWC, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]], + dilation=[1, 1, 1], + groups=1, + data_format="NDHWC") + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding="SAME", + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW") + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding="VALID", + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW") + + +class TestConv3dAPI_Error(OpTest): + def test_api(self): + input = fluid.layers.data( + name="input", + shape=[2, 5, 5, 5, 4], + append_batch_size=False, + dtype="float32") + + # ValueError: cudnn + def run_1(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding=0, + dilation=1, + groups=1, + use_cudnn=[0], + data_format="NCDHW") + + self.assertRaises(ValueError, run_1) + + # ValueError: data_format + def run_2(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=0, + dilation=[1, 1, 1], + groups=1, + use_cudnn=False, + data_format="NCHWC") + + self.assertRaises(ValueError, run_2) + + # ValueError: padding + def run_3(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding="SAMEE", + dilation=1, + groups=1, + use_cudnn=False, + data_format="NCDHW") + + self.assertRaises(ValueError, run_3) + + def run_4(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding=[[0, 1], [0, 0], [0, 1], [0, 1], [0, 1]], + dilation=1, + groups=1, + use_cudnn=False, + data_format="NCDHW") + + self.assertRaises(ValueError, run_4) + + def run_5(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=0, + stride=0, + padding=[[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]], + dilation=1, + groups=1, + use_cudnn=False, + data_format="NDHWC") + + self.assertRaises(ValueError, run_5) + + # ValueError: channel dimmention + x = fluid.layers.data( + name="x", + shape=[2, 5, 5, 5, -1], + append_batch_size=False, + dtype="float32") + + def run_6(): + fluid.layers.conv3d( + input=x, + num_filters=3, + filter_size=3, + stride=1, + padding=0, + dilation=1, + groups=1, + use_cudnn=False, + data_format="NDHWC") + + self.assertRaises(ValueError, run_6) + + # ValueError: groups + def run_7(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding=0, + dilation=1, + groups=3, + use_cudnn=False, + data_format="NDHWC") + + self.assertRaises(ValueError, run_7) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_conv_nn_grad.py b/python/paddle/fluid/tests/unittests/test_conv_nn_grad.py index 81f902d529e..c953841be02 100644 --- a/python/paddle/fluid/tests/unittests/test_conv_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_conv_nn_grad.py @@ -28,11 +28,38 @@ from decorator_helper import prog_scope class TestConvDoubleGradCheck(unittest.TestCase): @prog_scope() def func(self, place): - shape = [2, 4, 7, 8] + shape = [2, 4, 3, 3] eps = 0.005 dtype = np.float64 x = layers.data('x', shape, False, dtype) - y = layers.conv2d(x, 4, 1, bias_attr=False) + y = layers.conv2d(x, 2, 1, groups=1, bias_attr=False) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + places = [] + + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConvDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 4, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d(x, 2, 1, bias_attr=False) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) w = fluid.default_main_program().global_block().all_parameters() @@ -53,11 +80,11 @@ class TestConvDoubleGradCheck(unittest.TestCase): class TestConvDoubleGradCheckTest1(unittest.TestCase): @prog_scope() def func(self, place): - shape = [2, 3, 4, 5] + shape = [2, 3, 3, 3] eps = 0.005 dtype = np.float64 x = layers.data('x', shape, False, dtype) - y = layers.conv2d(x, 4, 1, padding=1, bias_attr=False) + y = layers.conv2d(x, 2, 1, padding=1, bias_attr=False) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) w = fluid.default_main_program().global_block().all_parameters() @@ -82,7 +109,7 @@ class TestConv3DDoubleGradCheck(unittest.TestCase): eps = 0.005 dtype = np.float64 x = layers.data('x', shape, False, dtype) - y = layers.conv3d(x, 4, 1, bias_attr=False) + y = layers.conv3d(x, 2, 1, bias_attr=False) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) w = fluid.default_main_program().global_block().all_parameters() @@ -107,7 +134,326 @@ class TestConv3DDoubleGradCheckTest1(unittest.TestCase): eps = 0.005 dtype = np.float64 x = layers.data('x', shape, False, dtype) - y = layers.conv3d(x, 4, 1, padding=1, bias_attr=False) + y = layers.conv3d(x, 2, 1, padding=1, bias_attr=False) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv2DoubleGradCheck_AsyPadding(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 0, 0, 1], + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv2DoubleGradCheck_PaddingSAME(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d( + input=x, + num_filters=2, + filter_size=1, + padding="SAME", + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv2DoubleGradCheck_PaddingVALID(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d( + input=x, + num_filters=2, + filter_size=1, + padding="VALID", + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv2DoubleGradCheck_ChannelLast(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 1], + bias_attr=False, + use_cudnn=True, + groups=1, + data_format="NHWC") + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv2DoubleGradCheck_ChannelLast_AsyPadding(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv2d( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 0, 1, 0], + bias_attr=False, + use_cudnn=True, + groups=1, + data_format="NHWC") + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv3DDoubleGradCheck_AsyPadding(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 2, 2, 2] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv3d( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 0, 0, 1, 1, 2], + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv3DoubleGradCheck_PaddingSAME(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 2, 2, 2] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv3d( + input=x, + num_filters=2, + filter_size=1, + padding="SAME", + groups=1, + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv3DoubleGradCheck_PaddingVALID(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 3, 3, 2] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv3d( + input=x, + num_filters=2, + filter_size=1, + padding="VALID", + bias_attr=False, + use_cudnn=True) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv3DDoubleGradCheck_ChannelLast(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 2, 2, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv3d( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 1, 1], + bias_attr=False, + use_cudnn=True, + groups=1, + data_format="NDHWC") + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + + w = fluid.default_main_program().global_block().all_parameters() + w_arr = [] + for p in w: + w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype)) + gradient_checker.double_grad_check( + [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestConv3dDoubleGradCheck_ChannelLast_AsyPadding(unittest.TestCase): + @prog_scope() + def func(self, place): + shape = [2, 2, 2, 2, 3] + eps = 0.005 + dtype = np.float64 + x = layers.data('x', shape, False, dtype) + y = layers.conv3d( + input=x, + num_filters=2, + filter_size=1, + padding=[1, 0, 1, 0, 1, 0], + bias_attr=False, + use_cudnn=True, + groups=1, + data_format="NDHWC") x_arr = np.random.uniform(-1, 1, shape).astype(dtype) w = fluid.default_main_program().global_block().all_parameters() -- GitLab