diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 3d8a6aa23e676411093f775ed516e8ece5647580..ba9c266d133b637fd99f128bbfe42253a2400aaf 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -111,7 +111,7 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" -"tensor_array_read_write_op" "tensorrt_engine_op") +"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index df2a3e7aa635c9ba41dad85ccb8316823f875639..4c0370d6ec2187fbda47d0724d44538ffa7cb152 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -34,7 +34,7 @@ if (WITH_GPU AND TENSORRT_FOUND) add_subdirectory(tensorrt) endif() -register_operators(EXCLUDES warpctc_op) +register_operators(EXCLUDES warpctc_op conv_fusion_op) # warpctc_cudnn need cudnn 7 above if (WITH_GPU) @@ -43,6 +43,8 @@ if (WITH_GPU) else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() + op_library(conv_fusion_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n") else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 3a4086274d8a4bf6725df9f3195cec2446ceae6c..42c2b3a24c116f92f4dd6ad0966dcb963ec702d6 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -43,26 +43,6 @@ using DataLayout = platform::DataLayout; template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; -static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache"; -static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache"; -static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache"; - -static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = - static_cast(1024) * 1024 * 1024; - -#if CUDNN_VERSION_MIN(6, 0, 5) -static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; -static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; -static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = - CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; -#else -// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc. -static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7; -static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4; -static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5; -#endif - template class CUDNNConvOpKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/conv_cudnn_op_cache.h b/paddle/fluid/operators/conv_cudnn_op_cache.h index 4b534321f746d5620005743eb8d45b71259156dd..92d394eb3c5aeb84605179cb2b5106f56a13f88e 100644 --- a/paddle/fluid/operators/conv_cudnn_op_cache.h +++ b/paddle/fluid/operators/conv_cudnn_op_cache.h @@ -17,10 +17,31 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/platform/cudnn_helper.h" namespace paddle { namespace operators { +static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache"; +static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache"; +static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache"; + +static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = + static_cast(1024) * 1024 * 1024; + +#if CUDNN_VERSION_MIN(6, 0, 5) +static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; +static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; +static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = + CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; +#else +// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc. +static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7; +static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4; +static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5; +#endif + template class AlgorithmsCache { public: diff --git a/paddle/fluid/operators/conv_fusion_op.cc b/paddle/fluid/operators/conv_fusion_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9bdedb10e0b1bc2d45c084bbc070875117675b75 --- /dev/null +++ b/paddle/fluid/operators/conv_fusion_op.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/operators/conv_op.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +// This fused conv follows the equation: +// y = act ( alpha1 * conv(x) + alpha2 * z + bias ). +// here, y is Output, +// x is Input, +// z is ResidualData, +// bias is Bias +class Conv2DFusionOpMaker : public Conv2DOpMaker { + protected: + void Apply() override { + AddAttr( + "activation", + "The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' " + "'relux' , 'tanh', 'band_pass'") + .SetDefault("relu"); + } +}; +// TODO(qingqing): add gradient operator for conv2d_fusion + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker, + ops::ConvOpInferVarType, paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/conv_fusion_op.cu.cc b/paddle/fluid/operators/conv_fusion_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd1041ce0836014dc73fabd4a3896243a943bd38 --- /dev/null +++ b/paddle/fluid/operators/conv_fusion_op.cu.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +DECLARE_uint64(conv_workspace_size_limit); +DECLARE_bool(cudnn_exhaustive_search); + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using ScopedActivationDescriptor = platform::ScopedActivationDescriptor; +using DataLayout = platform::DataLayout; +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; + +template +class CUDNNConvFusionOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* bias = ctx.Input("Bias"); + PADDLE_ENFORCE(bias, "The bias should not be null."); + auto* residual = ctx.Input("ResidualData"); + auto* output = ctx.Output("Output"); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + const std::string activation = ctx.Attr("activation"); + int groups = ctx.Attr("groups"); + int64_t user_workspace_size = + static_cast(ctx.Attr("workspace_size_MB")); + bool exhaustive_search = + FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); + + const T* input_data = input->data(); + const T* filter_data = filter->data(); + const T* bias_data = bias->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + const T* residual_data = residual ? residual->data() : output_data; + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedTensorDescriptor bias_desc; + ScopedConvolutionDescriptor conv_desc; + ScopedActivationDescriptor act_desc; + DataLayout layout = DataLayout::kNCHW; + if (input->dims().size() == 5) { + layout = DataLayout::kNCDHW; + } + + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount( + cudnn_conv_desc, groups)); + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims())); + // Now only support NCHW + std::vector bias_dim = {1, static_cast(output->dims()[1]), 1, 1}; + cudnnTensorDescriptor_t cudnn_bias_desc = + bias_desc.descriptor(layout, bias_dim); + cudnnActivationDescriptor_t cudnn_act_desc = + act_desc.descriptor(activation); + + // ------------------- cudnn conv workspace --------------------- + size_t workspace_size_in_bytes; // final workspace to allocate. + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { + int64_t max_user_size = + std::max(static_cast(FLAGS_conv_workspace_size_limit), + user_workspace_size); + workspace_size_limit = max_user_size * 1024 * 1024; + } + + // ------------------- cudnn conv algorithm --------------------- + cudnnConvolutionFwdAlgo_t algo; + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + cudnn_conv_desc, CUDNN_DEFAULT_MATH)); + + auto x_dims = framework::vectorize(input->dims()); + auto f_dims = framework::vectorize(filter->dims()); + if (activation == "identity") { + // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is + // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. + algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + } else if (!exhaustive_search) { + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + VLOG(3) << "cuDNN forward algo " << algo; + } else { + AlgorithmsCache* algo_cache = nullptr; + if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { + algo_cache = + ctx.scope() + .FindVar(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } else { + algo_cache = + const_cast(ctx.scope()) + .Var(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } + algo = algo_cache->GetAlgorithm( + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { + int returned_algo_count; + std::array + fwd_perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + handle, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, cudnn_output_desc, + output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, + fwd_perf_stat.data(), cudnn_workspace, + workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit); + VLOG(3) << "Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = fwd_perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time + << " " << stat.memory; + } + return fwd_perf_stat[0].algo; + }); + VLOG(3) << "choose algo " << algo; + } + + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_output_desc, algo, &workspace_size_in_bytes)); + PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, + "workspace_size to be allocated exceeds the limit"); + + // ------------------- cudnn conv+bias+act forward -------------------- + ScalingParamType alpha1 = 1.0f; + ScalingParamType alpha2 = residual ? 1.0f : 0.0f; + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward( + handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, algo, cudnn_workspace, + workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data, + cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc, + output_data)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel, + ops::CUDNNConvFusionOpKernel); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 1ac4bef615a546ec97aefd83653c649f187caba0..342525be49e28f1785e25d4daad38c3c81b4774f 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -225,17 +225,9 @@ $$ W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1 $$ )DOC"); + Apply(); } -class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map GetInputOutputWithSameType() - const override { - return std::unordered_map{ - {"Input", /*->*/ "Output"}}; - } -}; - void Conv3DOpMaker::Make() { AddInput( "Input", @@ -334,6 +326,7 @@ Example: W_{out}= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1 $$ )DOC"); + Apply(); } void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index ef76106f17218a03d24ebc0eca43dbb0ae935093..e69814001e4da5d10e51ee57c1dbe291338b8b49 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -60,12 +61,27 @@ inline bool IsExpand(const std::vector& filter_dim, // operator implementations can reuse the code. class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override; + void Make() final; + + protected: + virtual void Apply() {} }; class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override; + void Make() final; + + protected: + virtual void Apply() {} +}; + +class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{ + {"Input", /*->*/ "Output"}}; + } }; class ConvOp : public framework::OperatorWithKernel { diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index f174a7bc4867634df01367cc0091b0ec55849985..682b0c0ff39b71f08fe1a8b0c9c7b7d386b67738 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/operator.h" @@ -81,6 +82,16 @@ enum class PoolingMode { kAverageInclusive, }; +enum ActivationMode { + kNone, // activation identity + kSigmoid, + kRelu, + kRelu6, + kReluX, + kTanh, + kBandPass, +}; + #if CUDNN_VERSION < 6000 #pragma message "CUDNN version under 6.0 is supported at best effort." #pragma message "We strongly encourage you to move to 6.0 and above." @@ -120,6 +131,26 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) { } #endif // CUDNN_VERSION < 6000 +inline ActivationMode StringToActivationMode(const std::string& str) { + if (str == "identity") { + return ActivationMode::kNone; + } else if (str == "sigmoid") { + return ActivationMode::kSigmoid; + } else if (str == "relu") { + return ActivationMode::kRelu; + } else if (str == "relu6") { + return ActivationMode::kRelu6; + } else if (str == "relux") { + return ActivationMode::kReluX; + } else if (str == "tanh") { + return ActivationMode::kTanh; + } else if (str == "bandpass") { + return ActivationMode::kBandPass; + } else { + PADDLE_THROW("Unknown activation string: %s", str); + } +} + template class CudnnDataType; @@ -368,6 +399,58 @@ class ScopedSpatialTransformerDescriptor { DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor); }; +class ScopedActivationDescriptor { + public: + ScopedActivationDescriptor() { + PADDLE_ENFORCE(dynload::cudnnCreateActivationDescriptor(&desc_)); + } + ~ScopedActivationDescriptor() { + PADDLE_ENFORCE(dynload::cudnnDestroyActivationDescriptor(desc_)); + } + + template + inline cudnnActivationDescriptor_t descriptor( + const std::string& act, double value_max = static_cast(0.)) { + double relu_ceiling = 0.0; + ActivationMode activation_mode = StringToActivationMode(act); + cudnnActivationMode_t mode; + switch (activation_mode) { +#if CUDNN_VERSION >= 7100 + case ActivationMode::kNone: + mode = CUDNN_ACTIVATION_IDENTITY; + break; +#endif + case ActivationMode::kRelu6: + relu_ceiling = 6.0; + mode = CUDNN_ACTIVATION_CLIPPED_RELU; + break; + case ActivationMode::kReluX: + relu_ceiling = value_max; + mode = CUDNN_ACTIVATION_CLIPPED_RELU; + break; + case ActivationMode::kRelu: + mode = CUDNN_ACTIVATION_RELU; + break; + case ActivationMode::kSigmoid: + mode = CUDNN_ACTIVATION_SIGMOID; + break; + case ActivationMode::kTanh: + mode = CUDNN_ACTIVATION_TANH; + break; + default: + PADDLE_THROW("unrecognized activation mode: %d .", + static_cast(activation_mode)); + } + CUDNN_ENFORCE(dynload::cudnnSetActivationDescriptor( + desc_, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling)); + return desc_; + } + + private: + cudnnActivationDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor); +}; + inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index db2e28bc911a05a077c786c57c4f0e4c34bd61f7..065b940b9ca6fb7522790d2145d1a93469169461 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -152,14 +152,15 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif #if CUDNN_VERSION >= 7001 -#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ - __macro(cudnnSetConvolutionGroupCount); \ - __macro(cudnnSetConvolutionMathType); \ - __macro(cudnnCreateCTCLossDescriptor); \ - __macro(cudnnDestroyCTCLossDescriptor); \ - __macro(cudnnGetCTCLossDescriptor); \ - __macro(cudnnSetCTCLossDescriptor); \ - __macro(cudnnGetCTCLossWorkspaceSize); \ +#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ + __macro(cudnnSetConvolutionGroupCount); \ + __macro(cudnnSetConvolutionMathType); \ + __macro(cudnnConvolutionBiasActivationForward); \ + __macro(cudnnCreateCTCLossDescriptor); \ + __macro(cudnnDestroyCTCLossDescriptor); \ + __macro(cudnnGetCTCLossDescriptor); \ + __macro(cudnnSetCTCLossDescriptor); \ + __macro(cudnnGetCTCLossWorkspaceSize); \ __macro(cudnnCTCLoss); CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3f2f348166864be9583855fcd1949fd4ac818c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py @@ -0,0 +1,158 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid.core as core +from op_test import OpTest + +from test_conv2d_op import conv2d_forward_naive + + +class TestConv2dFusionOp(OpTest): + def setUp(self): + self.op_type = "conv2d_fusion" + self.exhaustive_search = False + self.data_format = "AnyLayout" + self.dtype = np.float32 + self.activation = 'relu' + self.add_bias = True + self.add_residual_data = True + + self.init_group() + self.init_dilation() + self.init_test_case() + self.init_bias_residual() + self.init_activation() + self.set_search_method() + + conv2d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilation': self.dilations + } + + input = np.random.random(self.input_size).astype(self.dtype) + filter = np.random.random(self.filter_size).astype(self.dtype) + + output = conv2d_forward_naive(input, filter, self.groups, + conv2d_param).astype(self.dtype) + + self.inputs = { + 'Input': OpTest.np_dtype_to_fluid_dtype(input), + 'Filter': OpTest.np_dtype_to_fluid_dtype(filter) + } + + if self.add_residual_data: + residual_data = np.random.random(output.shape).astype(self.dtype) + self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype( + residual_data) + output += residual_data + + if self.add_bias: + bias = np.random.random(self.filter_size[0]).astype(self.dtype) + self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias) + output = output + bias.reshape((1, bias.size, 1, 1)) + + assert self.activation in ['relu', 'identity'] + if self.activation == 'relu': + output = np.maximum(output, 0) + + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'dilations': self.dilations, + 'data_format': self.data_format, + 'exhaustive_search': self.exhaustive_search, + 'activation': self.activation + } + self.outputs = {'Output': output} + + def testcuda(self): + return core.is_compiled_with_cuda() + + def test_check_output(self): + if self.testcuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-5) + else: + pass + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_dilation(self): + self.dilations = [1, 1] + + def init_group(self): + self.groups = 1 + + def init_bias_residual(self): + self.add_bias = True + self.add_residual_data = True + + def init_activation(self): + self.activation = 'relu' + + def set_search_method(self): + self.exhaustive_search = False + + +class TestWithoutResidual(TestConv2dFusionOp): + def init_bias_residual(self): + self.add_residual_data = False + + +class TestIdentityActivation(TestConv2dFusionOp): + def init_activation(self): + self.activation = 'identity' + + +class TestWithGroup(TestConv2dFusionOp): + def init_group(self): + self.groups = 3 + + +class TestWithDilation(TestConv2dFusionOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 10, 10] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def init_dilation(self): + self.dilations = [2, 2] + + def init_group(self): + self.groups = 3 + + +class TestCUDNNExhaustiveSearch(TestConv2dFusionOp): + def set_search_method(self): + self.exhaustive_search = True + + +if __name__ == '__main__': + unittest.main()