diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 70d159b4f3549662e080794efad8af943ce1f0bc..59c40a0e5d18b753038f2b9301d1c9494e3901be 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -110,7 +110,7 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" -"fusion_transpose_flatten_concat_op") +"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/operators/conv_cudnn_op_cache.h b/paddle/fluid/operators/conv_cudnn_op_cache.h index 92d394eb3c5aeb84605179cb2b5106f56a13f88e..f172431e483f38665251617e6fcfddb4bcc0d9d4 100644 --- a/paddle/fluid/operators/conv_cudnn_op_cache.h +++ b/paddle/fluid/operators/conv_cudnn_op_cache.h @@ -19,6 +19,10 @@ limitations under the License. */ #include #include "paddle/fluid/platform/cudnn_helper.h" +DECLARE_uint64(conv_workspace_size_limit); +DECLARE_bool(cudnn_exhaustive_search); +DECLARE_int64(cudnn_exhaustive_search_times); + namespace paddle { namespace operators { @@ -45,6 +49,7 @@ static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5; template class AlgorithmsCache { public: + AlgorithmsCache() : search_times_(0) { hash_.clear(); } // Caches the best algorithm for a given // combination of tensor dimensions & compute data type. TAlgorithm GetAlgorithm( @@ -54,9 +59,14 @@ class AlgorithmsCache { int algorithmFlags, // can set for different data type std::function gen_func); + TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags, + std::function gen_func); + private: std::unordered_map hash_; std::mutex mutex_; + + int search_times_; }; template @@ -107,5 +117,29 @@ TAlgorithm AlgorithmsCache::GetAlgorithm( return hash_[seed]; } +template +TAlgorithm AlgorithmsCache::GetAlgorithm( + int64_t area, int search_times, int algorithmFlags, + std::function gen_func) { + if (hash_.find(area) != hash_.end()) { + return hash_[area]; + } + if (search_times_ < search_times) { + auto algo = gen_func(); + hash_[area] = algo; + ++search_times_; + return algo; + } + TAlgorithm algo; + int64_t min = static_cast(INT_MAX); + for (const auto& m : hash_) { + if (m.first < min) { + min = m.first; + algo = m.second; + } + } + return algo; +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/conv_fusion_op.cc b/paddle/fluid/operators/conv_fusion_op.cc index 9bdedb10e0b1bc2d45c084bbc070875117675b75..23b8087e781da30ed7b66ba651f8071ecb7aaf50 100644 --- a/paddle/fluid/operators/conv_fusion_op.cc +++ b/paddle/fluid/operators/conv_fusion_op.cc @@ -28,6 +28,8 @@ namespace operators { // x is Input, // z is ResidualData, // bias is Bias +// When `split_channels` is set, y will be splitted into multiple outputs, +// each output has split_channels[i] number of channels. class Conv2DFusionOpMaker : public Conv2DOpMaker { protected: void Apply() override { @@ -36,8 +38,65 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker { "The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' " "'relux' , 'tanh', 'band_pass'") .SetDefault("relu"); + AddAttr>( + "split_channels", + "When `split_channels` are set, there will be multiple outputs, the " + "output size is equal to the number of `split_channels`.") + .SetDefault({}); + AddOutput("Outputs", + "This Outputs is used when setting `split_channels`." + "Usually used to fuse conv with same input and same filter size, " + "padding, stride, dilation size.") + .AsDuplicable() + .AsDispensable(); + AddInput("AlgoCache", + "The cache of convolution algorithm, a RAW type variable.") + .AsDispensable(); + AddAttr( + "search_times", + "The number of exhaustive search times for convolution algorithm.") + .SetDefault(-1); } }; + +class Conv2DFusionOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of ConvOp should not be null."); + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + std::vector dilations = + ctx->Attrs().Get>("dilations"); + + std::vector oshape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + oshape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], strides[i])); + } + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of ConvOp should not be null."); + ctx->SetOutputDim("Output", framework::make_ddim(oshape)); + std::vector channels = + ctx->Attrs().Get>("split_channels"); + if (channels.size()) { + PADDLE_ENFORCE(ctx->HasOutputs("Outputs"), + "Output(Outputs) of ConvOp should not be null."); + std::vector oshapes; + oshapes.reserve(channels.size()); + for (size_t i = 0; i < channels.size(); ++i) { + oshapes.push_back({oshape[0], channels[i], oshape[2], oshape[3]}); + } + ctx->SetOutputsDim("Outputs", oshapes); + } + } +}; + // TODO(qingqing): add gradient operator for conv2d_fusion } // namespace operators @@ -45,4 +104,5 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker, - ops::ConvOpInferVarType, paddle::framework::EmptyGradOpMaker); + ops::Conv2DFusionOpInferShape, ops::ConvOpInferVarType, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/conv_fusion_op.cu.cc b/paddle/fluid/operators/conv_fusion_op.cu.cc index e73762f5fb2386633212c5aa9fc768153cf63f85..d8b997cca613f660046106512fc03bf55f9b992d 100644 --- a/paddle/fluid/operators/conv_fusion_op.cu.cc +++ b/paddle/fluid/operators/conv_fusion_op.cu.cc @@ -16,8 +16,9 @@ limitations under the License. */ #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_helper.h" -DECLARE_uint64(conv_workspace_size_limit); -DECLARE_bool(cudnn_exhaustive_search); +DEFINE_int64(cudnn_exhaustive_search_times, -1, + "Exhaustive search times for cuDNN convolution, " + "defalut is 1, only search once."); namespace paddle { namespace operators { @@ -117,41 +118,60 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { workspace_size_limit, &algo)); VLOG(3) << "cuDNN forward algo " << algo; } else { + auto search_func = [&]() { + int returned_algo_count; + std::array + fwd_perf_stat; + auto cudnn_find_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + handle, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, cudnn_output_desc, output_data, + kNUM_CUDNN_FWD_ALGS, &returned_algo_count, + fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit); + VLOG(3) << "Perf result: (algo: stat, time, memory)"; + for (int i = 0; i < returned_algo_count; ++i) { + const auto& stat = fwd_perf_stat[i]; + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time << " " + << stat.memory; + } + return fwd_perf_stat[0].algo; + }; AlgorithmsCache* algo_cache = nullptr; - if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { + int search_times = ctx.Attr("search_times"); + search_times = std::max( + static_cast(FLAGS_cudnn_exhaustive_search_times), search_times); + if (search_times > 0) { + // The searched algo will be cached by `search_times` times for + // different input dimension. For other dimensions, select the algo + // of closest area. + auto var_name = ctx.Inputs("AlgoCache")[0]; algo_cache = ctx.scope() - .FindVar(kCUDNNFwdAlgoCache) + .FindVar(var_name) ->GetMutable>(); + algo = algo_cache->GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0, + search_func); } else { - algo_cache = - const_cast(ctx.scope()) - .Var(kCUDNNFwdAlgoCache) - ->GetMutable>(); + // Cache searched algo in Var(kCUDNNFwdAlgoCache). + // all conv ops use the same kCUDNNFwdAlgoCache variable. + if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { + algo_cache = + ctx.scope() + .FindVar(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } else { + // TODO(qingqing) remove const_cast + algo_cache = + const_cast(ctx.scope().parent()) + ->Var(kCUDNNFwdAlgoCache) + ->GetMutable>(); + } + algo = algo_cache->GetAlgorithm(x_dims, f_dims, strides, paddings, + dilations, 0, search_func); } - algo = algo_cache->GetAlgorithm( - x_dims, f_dims, strides, paddings, dilations, 0, [&]() { - int returned_algo_count; - std::array - fwd_perf_stat; - auto cudnn_find_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE( - platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( - handle, cudnn_input_desc, input_data, cudnn_filter_desc, - filter_data, cudnn_conv_desc, cudnn_output_desc, - output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, - fwd_perf_stat.data(), cudnn_workspace, - workspace_size_limit)); - }; - workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit); - VLOG(3) << "Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = fwd_perf_stat[i]; - VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time - << " " << stat.memory; - } - return fwd_perf_stat[0].algo; - }); VLOG(3) << "choose algo " << algo; } @@ -195,6 +215,27 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { }; workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } + std::vector channels = ctx.Attr>("split_channels"); + if (channels.size()) { + auto outs = ctx.MultiOutput("Outputs"); + if (x_dims[0] == 1) { + // share data with Output + framework::Tensor t; + t.ShareDataWith(*output); + auto y_dims = output->dims(); + t.Resize({y_dims[1], y_dims[2], y_dims[3]}); + int s = 0; + for (size_t i = 0; i < channels.size(); ++i) { + int e = s + channels[i]; + outs[i]->ShareDataWith(t.Slice(s, e)); + outs[i]->Resize({x_dims[0], channels[i], y_dims[2], y_dims[3]}); + s = e; + } + } else { + // TODO(qingiqng): do copy when batch size large than 1 + PADDLE_THROW("Batch size greater than 1 is Unsupported"); + } + } } }; #endif diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index a0397acab1267365b8aeba30a63152b61b5b25bb..2bddba7db2f1c1a4bf7a207d361d900ec625807f 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -1,6 +1,8 @@ include(operators) -register_operators(EXCLUDES fusion_transpose_flatten_concat_op) +register_operators(EXCLUDES fusion_transpose_flatten_concat_op fusion_conv_inception_op) if (WITH_GPU) op_library(fusion_transpose_flatten_concat_op) + op_library(fusion_conv_inception_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n") + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_inception_fusion);\n") endif() diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4690bd766d0b8a4b7a249fb5ccad5f278d1830f5 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc @@ -0,0 +1,110 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +class ConvInceptionFusionOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + // 1 x + auto in_dims = ctx->GetInputDim("Input"); + // 4 filters + auto w_dims = ctx->GetInputsDim("Filter"); + + PADDLE_ENFORCE(in_dims.size(), 4, "Conv intput should be 4-D tensor."); + PADDLE_ENFORCE_EQ(w_dims.size(), 4, "There should be 4 filters"); + PADDLE_ENFORCE_EQ(w_dims[0][1], in_dims[1]); + PADDLE_ENFORCE_EQ(w_dims[1][1], in_dims[1]); + + int n = in_dims[0]; + // compute output channel + // 1st channel + int c = w_dims[0][0]; + // add 2nd channel + c += (w_dims[1][0] - w_dims[2][1] * 2); + // add 3rd channel + c += (w_dims[2][0] - w_dims[3][1]); + // add 4-th channel + c += w_dims[3][0]; + + int h = in_dims[2]; + int w = in_dims[3]; + + ctx->SetOutputDim("Output", {n, c, h, w}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input("Input")->type(), ctx.device_context()); + } +}; + +class ConvInceptionFusionOpMaker : public framework::OpProtoAndCheckerMaker { + protected: + void Make() override { + AddInput("Input", "(Tensor) NCHW layout."); + AddInput("Filter", "(vector) 4 aggregated filters").AsDuplicable(); + AddInput("Bias", "(vector) it's lenght is equal to Filter") + .AsDuplicable(); + AddOutput("Output", + "(Tensor) The output tensor of convolution operator. " + "The format of output tensor is also NCHW."); + AddOutput("TempOutput", "").AsDuplicable(); + AddAttr( + "pooling_type", + "(string), pooling type, can be \"max\" for max-pooling " + "and \"avg\" for average-pooling.") + .InEnum({"max", "avg"}); + AddAttr( + "exclusive", + "(bool, default True) When true, will exclude the zero-padding in the " + "averaging calculating, otherwise, include the zero-padding. Note, it " + "is only used when pooling_type is avg. The defalut is True.") + .SetDefault(true); + AddAttr( + "activation", + "The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' " + "'relux' , 'tanh', 'band_pass'") + .SetDefault("relu"); + AddAttr("workspace_size_MB", + "Only used in cudnn kernel. Need set use_cudnn to true." + "workspace size for cudnn, in MB, " + "workspace is a section of GPU memory which will be " + "allocated/freed each time the operator runs, larger " + "workspace size can increase performance but also requires " + "better hardware. This size should be chosen carefully.") + .SetDefault(4096); + AddComment(R"DOC( +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(conv2d_inception_fusion, ops::ConvInceptionFusionOp, + ops::ConvInceptionFusionOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cu b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3349b0b31ebf6e266820b077011f4f4d11974e09 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu @@ -0,0 +1,272 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +DECLARE_uint64(conv_workspace_size_limit); + +namespace paddle { +namespace operators { + +#if CUDNN_VERSION >= 7001 +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using ScopedActivationDescriptor = platform::ScopedActivationDescriptor; +using DataLayout = platform::DataLayout; + +using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor; +using PoolingMode = platform::PoolingMode; +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; + +template +using CudnnDataType = platform::CudnnDataType; + +template +class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + auto* input = ctx.Input("Input"); + auto filters = ctx.MultiInput("Filter"); + auto bias = ctx.MultiInput("Bias"); + + auto* output = ctx.Output("Output"); + auto temp_outs = ctx.MultiOutput("TempOutput"); + + const std::string pool_type = ctx.Attr("pooling_type"); + const std::string activation = ctx.Attr("activation"); + const bool exclusive = ctx.Attr("exclusive"); + + int64_t user_workspace_size = + static_cast(ctx.Attr("workspace_size_MB")); + + const T* input_data = input->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + T* temp_data = temp_outs[0]->mutable_data(input->dims(), ctx.GetPlace()); + + DataLayout layout = DataLayout::kNCHW; + std::vector in_dim = framework::vectorize2int(input->dims()); + + // ------------------- cudnn descriptors --------------------- + PoolingMode pooling_mode; + if (pool_type == "max") { + pooling_mode = PoolingMode::kMaximum; + } else { + pooling_mode = exclusive ? PoolingMode::kAverageExclusive + : (PoolingMode::kAverageInclusive); + } + std::vector k0x0 = {0, 0}; + std::vector k1x1 = {1, 1}; + std::vector k1x1_2 = {1, 1}; + std::vector k3x3 = {3, 3}; + ScopedPoolingDescriptor pool_desc; + ScopedActivationDescriptor act_desc; + ScopedTensorDescriptor out_pool_desc; + ScopedTensorDescriptor input_desc; + cudnnPoolingDescriptor_t cudnn_pool_desc = + pool_desc.descriptor(pooling_mode, k3x3, k1x1, k1x1); + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t pool_out_desc = out_pool_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + + cudnnDataType_t cudnn_dtype = CudnnDataType::type; + cudnnTensorDescriptor_t* out_desc = new cudnnTensorDescriptor_t[4]; + cudnnFilterDescriptor_t* filter_desc = new cudnnFilterDescriptor_t[4]; + cudnnTensorDescriptor_t* bias_desc = new cudnnTensorDescriptor_t[4]; + cudnnTensorDescriptor_t* in_desc = new cudnnTensorDescriptor_t[4]; + cudnnConvolutionDescriptor_t* conv_desc = + new cudnnConvolutionDescriptor_t[4]; + for (int i = 0; i < 4; ++i) { + CUDNN_ENFORCE( + platform::dynload::cudnnCreateFilterDescriptor(&filter_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&bias_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&in_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&out_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateConvolutionDescriptor(&conv_desc[i])); + } + + std::vector> filter_dims; + std::vector> bias_dims; + std::vector> in_dims; + std::vector> out_dims; + std::vector> in_strides; + std::vector> out_strides; + std::vector> bias_strides; + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + int n = in_dim[0]; + int h = in_dim[2]; + int w = in_dim[3]; + int oc = output->dims()[1]; + + cudnnDataType_t compute_type = (cudnn_dtype == CUDNN_DATA_DOUBLE) + ? CUDNN_DATA_DOUBLE + : CUDNN_DATA_FLOAT; + + for (int i = 0; i < 4; ++i) { + filter_dims.push_back(framework::vectorize2int(filters[i]->dims())); + CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor( + filter_desc[i], cudnn_dtype, format, 4, filter_dims[i].data())); + bias_dims.push_back({1, filter_dims[i][0], 1, 1}); + bias_strides.push_back({filter_dims[i][0], 1, 1, 1}); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + bias_desc[i], cudnn_dtype, 4, bias_dims[i].data(), + bias_strides[i].data())); + in_dims.push_back({n, filter_dims[i][1], h, w}); + out_dims.push_back({n, filter_dims[i][0], h, w}); + in_strides.push_back({filter_dims[i][1] * h * w, h * w, w, 1}); + out_strides.push_back({oc * h * w, h * w, w, 1}); + + if (i < 2) { + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionNdDescriptor( + conv_desc[i], 2, k0x0.data(), k1x1.data(), k1x1.data(), + CUDNN_CROSS_CORRELATION, compute_type)); + } else { + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionNdDescriptor( + conv_desc[i], 2, k1x1.data(), k1x1.data(), k1x1.data(), + CUDNN_CROSS_CORRELATION, compute_type)); + } + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( + conv_desc[i], CUDNN_DEFAULT_MATH)); + } + in_dims[2][1] *= 2; + in_strides[2][0] = oc * h * w; + out_strides[2][0] = filter_dims[2][0] * h * w; // this out is continuous. + in_strides[3][0] = filter_dims[2][0] * h * w; + CUDNN_ENFORCE( + platform::dynload::cudnnSetConvolutionGroupCount(conv_desc[2], 2)); + + cudnnConvolutionFwdAlgo_t algo[4]; + auto handle = dev_ctx.cudnn_handle(); + size_t workspace_size_in_bytes = 0; // final workspace to allocate. + + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { + int64_t max_user_size = + std::max(static_cast(FLAGS_conv_workspace_size_limit), + user_workspace_size); + workspace_size_limit = max_user_size * 1024 * 1024; + } + + for (int i = 0; i < 4; ++i) { + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + in_desc[i], cudnn_dtype, 4, in_dims[i].data(), in_strides[i].data())); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + out_desc[i], cudnn_dtype, 4, out_dims[i].data(), + out_strides[i].data())); + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, in_desc[i], filter_desc[i], conv_desc[i], out_desc[i], + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, + &algo[i])); + size_t tmp_size = 0; + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, in_desc[i], filter_desc[i], conv_desc[i], out_desc[i], + algo[i], &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + } + cudnnActivationDescriptor_t cudnn_act_desc = + act_desc.descriptor(activation); + + int oc0 = filter_dims[0][0]; + int oc1 = filter_dims[1][0] - filter_dims[2][1] * 2; + int oc3 = filter_dims[3][0]; + int oc2 = oc - oc0 - oc1 - oc3; + + // branch1: pool + 1x1 conv + ScalingParamType alpha = 1.0f, beta = 0.0f; + CUDNN_ENFORCE(platform::dynload::cudnnPoolingForward( + handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta, + pool_out_desc, temp_data)); + + std::vector in_datas; + in_datas.push_back(static_cast(temp_data)); + in_datas.push_back(static_cast(input_data)); + in_datas.push_back( + static_cast(output_data + (oc0 + oc1) * h * w)); + T* temp2_data = temp_outs[1]->mutable_data( + framework::make_ddim(out_dims[2]), ctx.GetPlace()); + in_datas.push_back(static_cast(temp2_data + oc2 * h * w)); + + std::vector out_datas; + out_datas.push_back(static_cast(output_data)); + out_datas.push_back(static_cast(output_data + oc0 * h * w)); + out_datas.push_back(static_cast(temp2_data)); + out_datas.push_back( + static_cast(output_data + (oc0 + oc1 + oc2) * h * w)); + + for (int i = 0; i < 4; ++i) { + auto func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward( + handle, &alpha, in_desc[i], in_datas[i], filter_desc[i], + static_cast(filters[i]->data()), conv_desc[i], + algo[i], cudnn_workspace, workspace_size_in_bytes, &beta, + out_desc[i], out_datas[i], bias_desc[i], + static_cast(bias[i]->data()), cudnn_act_desc, + out_desc[i], out_datas[i])); + }; + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + workspace_handle.RunFunc(func, workspace_size_in_bytes); + } + + cudnnTensorDescriptor_t x_desc; + cudnnTensorDescriptor_t y_desc; + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&x_desc)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&y_desc)); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + x_desc, cudnn_dtype, 4, out_dims[3].data(), out_strides[2].data())); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + y_desc, cudnn_dtype, 4, out_dims[3].data(), out_strides[3].data())); + CUDNN_ENFORCE(platform::dynload::cudnnTransformTensor( + handle, CudnnDataType::kOne(), x_desc, + static_cast(out_datas[2]), CudnnDataType::kZero(), + y_desc, static_cast(output_data + (oc0 + oc1) * h * w))); + + for (int i = 0; i < 4; ++i) { + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(in_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(out_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyFilterDescriptor(filter_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(bias_desc[i])); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyConvolutionDescriptor(conv_desc[i])); + } + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(x_desc)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(y_desc)); + } +}; +#endif + +} // namespace operators +} // namespace paddle + +#if CUDNN_VERSION >= 7001 +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(conv2d_inception_fusion, + ops::CUDNNConvInceptionFusionOpKernel, + ops::CUDNNConvInceptionFusionOpKernel); +#endif diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index e0078e53141ac7834fd00e4f2dbd8a6c8a1d6b1b..7433c2cbb63577b398a58c8bfb0855f547b0c9d3 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -154,9 +154,14 @@ def __bootstrap__(): if core.is_compiled_with_cuda(): read_env_flags += [ - 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', - 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', - 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus' + 'fraction_of_gpu_memory_to_use', + 'cudnn_deterministic', + 'enable_cublas_tensor_op_math', + 'conv_workspace_size_limit', + 'cudnn_exhaustive_search', + 'memory_optimize_debug', + 'selected_gpus', + 'cudnn_exhaustive_search_times', ] core.init_gflags([sys.argv[0]] + diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 45e6a856f209d0b5badb22ce40063960087809d9..921d59158f90686f9c2044f51651a7c4c3090c0e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -647,20 +647,16 @@ class Operator(object): self.desc.set_input(in_proto.name, []) if outputs is not None: - given = set() - need = set() - for n in outputs: - given.add(n) for m in proto.outputs: - need.add(m.name) - if not given == need: - raise ValueError(("Incorrect setting for output(s) of " - "operator \"%s\". Need: [%s] Given: [%s]") % - (type, - ", ".join(six.binary_type(e) for e in need), - ", ".join(six.binary_type(e) for e in given))) - + if (m.name not in outputs) and m.dispensable: + continue + if not ((m.name in outputs) or m.dispensable): + raise ValueError( + ("Incorrect setting for output(s) of " + "operator \"%s\", should set: [%s].") % (type, m.name)) for out_proto in proto.outputs: + if out_proto.name not in outputs: + continue out_args = outputs[out_proto.name] if not isinstance(out_args, list): out_args = [out_args] diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py index 6cd71e39e41dae5d07e5761fc9caeca113f3b47e..a27212f38f4e96090f6bc30d507581ce5c0a26ff 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py @@ -32,6 +32,8 @@ class TestConv2dFusionOp(OpTest): self.activation = 'relu' self.add_bias = True self.add_residual_data = True + self.channels = None + self.outputs = None self.init_group() self.init_dilation() @@ -49,8 +51,8 @@ class TestConv2dFusionOp(OpTest): input = np.random.random(self.input_size).astype(self.dtype) filter = np.random.random(self.filter_size).astype(self.dtype) - output = conv2d_forward_naive(input, filter, self.groups, - conv2d_param).astype(self.dtype) + self.output = conv2d_forward_naive(input, filter, self.groups, + conv2d_param).astype(self.dtype) self.inputs = { 'Input': OpTest.np_dtype_to_fluid_dtype(input), @@ -58,19 +60,20 @@ class TestConv2dFusionOp(OpTest): } if self.add_residual_data: - residual_data = np.random.random(output.shape).astype(self.dtype) + residual_data = np.random.random(self.output.shape).astype( + self.dtype) self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype( residual_data) - output += residual_data + self.output += residual_data if self.add_bias: bias = np.random.random(self.filter_size[0]).astype(self.dtype) self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias) - output = output + bias.reshape((1, bias.size, 1, 1)) + self.output = self.output + bias.reshape((1, bias.size, 1, 1)) assert self.activation in ['relu', 'identity'] if self.activation == 'relu': - output = np.maximum(output, 0) + self.output = np.maximum(self.output, 0) self.attrs = { 'strides': self.stride, @@ -79,9 +82,12 @@ class TestConv2dFusionOp(OpTest): 'dilations': self.dilations, 'data_format': self.data_format, 'exhaustive_search': self.exhaustive_search, - 'activation': self.activation + 'activation': self.activation, + 'split_channels': self.channels } - self.outputs = {'Output': output} + self.outputs = {'Output': self.output} + + self.set_outputs() def testcuda(self): return core.is_compiled_with_cuda() @@ -117,6 +123,9 @@ class TestConv2dFusionOp(OpTest): def set_search_method(self): self.exhaustive_search = False + def set_outputs(self): + pass + class TestWithoutResidual(TestConv2dFusionOp): def init_bias_residual(self): @@ -160,5 +169,21 @@ class TestCUDNNExhaustiveSearch(TestConv2dFusionOp): self.exhaustive_search = True +class TestMultipleOutputs(TestConv2dFusionOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [1, 32, 17, 17] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [126, f_c, 3, 3] + self.channels = [84, 42] + + def set_outputs(self): + out1 = self.output[:, 0:84, :, :] + out2 = self.output[:, 84:126, :, :] + self.outputs['Outputs'] = [('out1', out1), ('out2', out2)] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index dc3b2cb8bc15836a4bf067caa05c3a37a917ecad..c4eb26893cd1faac72ac06c70a68c52f26b39182 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -137,9 +137,9 @@ def append_input_output(block, op_proto, np_list, is_input, dtype): var_dict = {} for var_proto in proto_list: var_name = str(var_proto.name) + if (var_name not in np_list) and var_proto.dispensable: + continue if is_input: - if (var_name not in np_list) and var_proto.dispensable: - continue assert (var_name in np_list) or (var_proto.dispensable), \ "Missing {} as input".format(var_name) if var_proto.duplicable: