From bfb07aafe82fee83808ef41a1df72328414721fe Mon Sep 17 00:00:00 2001 From: zhongpu <2013000149@qq.com> Date: Thu, 2 Apr 2020 14:25:14 +0800 Subject: [PATCH] Revert "Exhaustive search (#22821)", test=develop (#23401) This reverts commit 48144e40995e5ecc77a9649e4f2e4f180249ce2d. --- paddle/fluid/framework/operator.cc | 18 +++- paddle/fluid/framework/operator.h | 36 ++++++- .../fluid/framework/operator_kernel_configs.h | 82 ++++------------ paddle/fluid/framework/operator_test.cc | 2 +- paddle/fluid/imperative/execution_context.h | 3 +- paddle/fluid/imperative/prepared_operator.cc | 32 ++++--- paddle/fluid/imperative/prepared_operator.h | 4 +- paddle/fluid/imperative/tests/test_layer.cc | 2 +- .../fluid/operators/beam_search_decode_op.cc | 2 +- paddle/fluid/operators/conv_cudnn_helper.h | 96 ++++++------------- paddle/fluid/operators/conv_cudnn_op.cu | 59 ++++-------- paddle/fluid/operators/conv_op.cc | 45 +++++++++ .../test_elementwise_mul_op_dim.cc | 3 +- .../fluid/operators/fused/conv_fusion_op.cu | 7 +- paddle/fluid/operators/warpctc_op.cc | 6 +- 15 files changed, 200 insertions(+), 197 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 87eda389028..f4440e44124 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -905,6 +905,16 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, this->InferShape(&infer_shape_ctx); } +std::vector* OperatorWithKernel::GetKernelConfig( + const OpKernelType& key) const { + auto config_iter = kernel_configs_map_.find(key); + std::vector* kernel_configs = nullptr; + if (config_iter != kernel_configs_map_.end()) { + kernel_configs = &(config_iter->second); + } + return kernel_configs; +} + void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { // To reduce the elapsed time of HasAttr, we use bool variable to record the @@ -941,6 +951,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ChooseKernel(*runtime_ctx, scope, place); } + std::vector* kernel_configs = GetKernelConfig(*kernel_type_); + // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; Scope* transfer_scope = nullptr; @@ -976,8 +988,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, { platform::RecordEvent record_event("compute", platform::EventRole::kInnerOp); - (*kernel_func_)( - ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); + (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx, + kernel_configs)); } if (!transfered_inplace_vars.empty()) { @@ -1046,7 +1058,7 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, OpKernelMap& kernels = kernels_iter->second; auto expected_kernel_key = this->GetExpectedKernelType( - ExecutionContext(*this, scope, *dev_ctx, ctx)); + ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); if (HasAttr("op_device")) { if (Attr("op_device") == "cpu") { expected_kernel_key.place_ = platform::CPUPlace(); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 8e2dc860a4c..b58ad71b8da 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -31,6 +31,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_kernel_type.h" +#include "paddle/fluid/framework/operator_kernel_configs.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" @@ -215,12 +216,30 @@ class OperatorBase { const platform::Place& place) const = 0; }; +#ifdef PADDLE_WITH_CUDA +using KernelConfig = boost::variant< + std::shared_ptr>, + std::shared_ptr>, + std::shared_ptr>>; +#else +using KernelConfig = boost::variant; +#endif + +using OpKernelConfigsMap = + std::unordered_map, + OpKernelType::Hash>; + class ExecutionContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, const platform::DeviceContext& device_context, - const RuntimeContext& ctx) - : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {} + const RuntimeContext& ctx, + std::vector* configs) + : op_(op), + scope_(scope), + device_context_(device_context), + ctx_(ctx), + kernel_configs_(configs) {} virtual ~ExecutionContext() {} virtual std::string InputName(const std::string& name) const { @@ -386,6 +405,15 @@ class ExecutionContext { return temp_tensor; } + template + T& GetKernelConfig(size_t idx) const { + PADDLE_ENFORCE( + kernel_configs_ && kernel_configs_->size() > static_cast(idx), + "%s selected kernel doesn't have kernel config %lu <= %lu", + op_.Type().c_str(), kernel_configs_->size(), idx); + return *boost::get>((*kernel_configs_)[idx]); + } + const RuntimeContext Context() const { return ctx_; } std::string DebugString() const { return op_.DebugString(); } @@ -395,6 +423,7 @@ class ExecutionContext { const Scope& scope_; const platform::DeviceContext& device_context_; const RuntimeContext& ctx_; + mutable std::vector* kernel_configs_; }; template <> @@ -470,6 +499,8 @@ class OperatorWithKernel : public OperatorBase { virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; + std::vector* GetKernelConfig(const OpKernelType& key) const; + // change this to public so that in dygraph mode we can call it to check if we // need transform data virtual OpKernelType GetKernelTypeForVar( @@ -506,6 +537,7 @@ class OperatorWithKernel : public OperatorBase { const platform::Place& place) const; protected: + mutable OpKernelConfigsMap kernel_configs_map_; mutable std::unique_ptr kernel_type_; mutable std::unique_ptr kernel_func_; mutable std::unique_ptr runtime_ctx_; diff --git a/paddle/fluid/framework/operator_kernel_configs.h b/paddle/fluid/framework/operator_kernel_configs.h index 68edb7c89dd..5c5a7423832 100644 --- a/paddle/fluid/framework/operator_kernel_configs.h +++ b/paddle/fluid/framework/operator_kernel_configs.h @@ -21,21 +21,19 @@ limitations under the License. */ namespace paddle { namespace framework { -// thread-safe. +// Not thread-safe. Should be owned per-kernel. template class AlgorithmsCache { public: AlgorithmsCache() : search_times_(0) { hash_.clear(); } // Caches the best algorithm for a given // combination of tensor dimensions & compute data type. 
- // cudnn_dtype set for different data type - TAlgorithm GetAlgorithm(const std::vector& dims1, - const std::vector& dims2, - const std::vector& strides, - const std::vector& paddings, - const std::vector& dilations, int algorithmFlags, - int64_t cudnn_dtype, - std::function gen_func); + TAlgorithm GetAlgorithm( + const std::vector& dims1, const std::vector& dims2, + const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, + int algorithmFlags, // can set for different data type + std::function gen_func); TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags, std::function gen_func); @@ -43,14 +41,13 @@ class AlgorithmsCache { private: std::unordered_map hash_; int search_times_; - std::mutex cache_mutex; }; template TAlgorithm framework::AlgorithmsCache::GetAlgorithm( const std::vector& dims1, const std::vector& dims2, const std::vector& strides, const std::vector& paddings, - const std::vector& dilations, int algorithmFlags, int64_t cudnn_dtype, + const std::vector& dilations, int algorithmFlags, std::function gen_func) { int64_t seed = 0; // Hash all of the inputs, use to try and look up a previously @@ -84,73 +81,36 @@ TAlgorithm framework::AlgorithmsCache::GetAlgorithm( seed ^= hashFn(static_cast(algorithmFlags)) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 5; - seed ^= hashFn(static_cast(cudnn_dtype)) + 0x9e3779b9 + (seed << 6) + - (seed >> 2) + 6; - VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size(); if (seed == 0) return gen_func(); - TAlgorithm ret; - auto it = hash_.end(); - bool have_found = false; - { - std::lock_guard lock(cache_mutex); - it = hash_.find(seed); - - if (it != hash_.end()) { - ret = it->second; - have_found = true; - } - } - - if (!have_found) { - ret = gen_func(); - std::lock_guard lock(cache_mutex); - hash_[seed] = ret; + if (hash_.find(seed) == hash_.end()) { + TAlgorithm value = gen_func(); + hash_[seed] = value; } - - return ret; + return hash_[seed]; } template TAlgorithm AlgorithmsCache::GetAlgorithm( int64_t area, int search_times, int algorithmFlags, std::function gen_func) { - auto it = hash_.end(); - { - std::lock_guard lock(cache_mutex); - it = hash_.find(area); - - if (it != hash_.end()) { - return it->second; - } - } - - bool gene_flag = false; - - { - std::lock_guard lock(cache_mutex); - gene_flag = (search_times_ < search_times); + if (hash_.find(area) != hash_.end()) { + return hash_[area]; } - - TAlgorithm algo{}; - if (gene_flag) { - algo = gen_func(); - std::lock_guard lock(cache_mutex); + if (search_times_ < search_times) { + auto algo = gen_func(); hash_[area] = algo; ++search_times_; return algo; } - + TAlgorithm algo{}; int64_t min = static_cast(INT_MAX); - { - std::lock_guard lock(cache_mutex); - for (const auto& m : hash_) { - if (m.first < min) { - min = m.first; - algo = m.second; - } + for (const auto& m : hash_) { + if (m.first < min) { + min = m.first; + algo = m.second; } } return algo; diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index 6fbaca7174e..77c98a08cf0 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -525,7 +525,7 @@ TEST(ExecutionContextAttrAndInOut, new_api) { paddle::framework::RuntimeContext ctx({}, {}); paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx, - ctx); + ctx, nullptr); ASSERT_EQ(exe_context.InputSize("input"), 1u); ASSERT_EQ(exe_context.OutputSize("output"), 1u); diff --git 
a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index 398b1292e2f..0537370b074 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -33,10 +33,11 @@ class DygraphExecutionContext : public framework::ExecutionContext { const framework::Scope& scope, const platform::DeviceContext& device_context, const framework::RuntimeContext& ctx, + std::vector* configs, const NameVarMap& var_base_map_in, const NameVarMap& var_base_map_out, const framework::AttributeMap& attrs) - : ExecutionContext(op, scope, device_context, ctx), + : ExecutionContext(op, scope, device_context, ctx, configs), var_base_map_in_(var_base_map_in), var_base_map_out_(var_base_map_out), attrs_(attrs) {} diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 0e29ed86d13..c4aa2f7392a 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -80,8 +80,13 @@ void PreparedOp::PrepareData( PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorWithKernel::OpKernelFunc& func, - platform::DeviceContext* dev_ctx) - : op_(op), ctx_(ctx), func_(func), dev_ctx_(dev_ctx) {} + platform::DeviceContext* dev_ctx, + std::vector* kernel_configs) + : op_(op), + ctx_(ctx), + func_(func), + dev_ctx_(dev_ctx), + kernel_configs_(kernel_configs) {} template PreparedOp PrepareOpImpl(const NameVarMap& ins, @@ -106,7 +111,7 @@ PreparedOp PrepareOpImpl(const NameVarMap& ins, framework::RuntimeContext ctx({}, {}); auto expected_kernel_key = op.GetExpectedKernelType(DygraphExecutionContext( - op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); + op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); @@ -115,6 +120,8 @@ PreparedOp PrepareOpImpl(const NameVarMap& ins, PADDLE_THROW("op %s does not have kernel for %s", op.Type(), KernelTypeToString(expected_kernel_key)); } + std::vector* kernel_configs = + op.GetKernelConfig(expected_kernel_key); if (!(expected_kernel_key.place_ == place)) { dev_ctx = pool.Get(expected_kernel_key.place_); @@ -122,7 +129,7 @@ PreparedOp PrepareOpImpl(const NameVarMap& ins, } PrepareDataImpl(place, ins, op, expected_kernel_key); - return PreparedOp(op, ctx, kernel_iter->second, dev_ctx); + return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, @@ -145,8 +152,10 @@ template static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorWithKernel::OpKernelFunc& func, - platform::DeviceContext* dev_ctx, const NameVarMap& ins, - const NameVarMap& outs, const framework::AttributeMap& attrs) { + platform::DeviceContext* dev_ctx, + std::vector* kernel_configs, + const NameVarMap& ins, const NameVarMap& outs, + const framework::AttributeMap& attrs) { // TODO(zjl): remove scope in dygraph framework::Scope scope; @@ -154,21 +163,22 @@ static void PreparedOpRunImpl( static_cast(op).InferShape( &infer_shape_ctx); - func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, - attrs)); + func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, + kernel_configs, ins, outs, attrs)); } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { - 
PreparedOpRunImpl(op_, ctx_, func_, dev_ctx_, ins, outs, attrs); + PreparedOpRunImpl(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins, + outs, attrs); } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { - PreparedOpRunImpl(op_, ctx_, func_, dev_ctx_, ins, outs, - attrs); + PreparedOpRunImpl(op_, ctx_, func_, dev_ctx_, + kernel_configs_, ins, outs, attrs); } } // namespace imperative diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 8ffc3eaf82f..7120960b902 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -33,7 +33,8 @@ class PreparedOp { PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorWithKernel::OpKernelFunc& func, - platform::DeviceContext* dev_ctx); + platform::DeviceContext* dev_ctx, + std::vector* kernel_configs); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, @@ -71,6 +72,7 @@ class PreparedOp { const framework::RuntimeContext& ctx_; framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; + std::vector* kernel_configs_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index f249a09f4b5..9f7cb3344fb 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -235,7 +235,7 @@ TEST(test_layer, test_dygraph_execution_context) { framework::Scope scope; DygraphExecutionContext dy_exe_context( - *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map); + *(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map); ASSERT_EQ(dy_exe_context.InputSize("X"), 1u); ASSERT_EQ(dy_exe_context.InputName("X"), "vin"); diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index 8c9e397b13f..3edaf58cd01 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { auto& dev_ctx = *pool.Get(dev_place); framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope); - framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx); + framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr); const LoDTensorArray* ids = ctx.Input("Ids"); const LoDTensorArray* scores = ctx.Input("Scores"); diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index df94de8a18f..0d4a3c56bb2 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -21,7 +21,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/operator_kernel_configs.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_desc.h" -// #include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { @@ -90,43 +89,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector& v) { return out; } -// ConvSearchCache using framework::AlgorithmsCache to search -// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or -// cudnnConvolutionBwdFilterAlgo_t -class ConvSearchCache { - public: - static ConvSearchCache& Instance() { - static ConvSearchCache instance; - return instance; - } - - framework::AlgorithmsCache* GetForward() { - return &forward_cache_; - } - framework::AlgorithmsCache* GetBackwardData() { - return &backward_data_cache_; - } - framework::AlgorithmsCache* - GetBackwardFilter() { - return &backward_filter_cache_; - } - framework::AlgorithmsCache* GetConvFusion() { - return &fusion_forward_cache_; - } - - private: - ConvSearchCache() {} - ~ConvSearchCache() {} - ConvSearchCache(const ConvSearchCache&) {} - ConvSearchCache& operator=(const ConvSearchCache&) {} - - framework::AlgorithmsCache forward_cache_; - framework::AlgorithmsCache - backward_data_cache_; - framework::AlgorithmsCache - backward_filter_cache_; - framework::AlgorithmsCache fusion_forward_cache_; -}; +using framework::AlgorithmsCache; struct ConvArgs { cudnnHandle_t handle; @@ -134,7 +97,6 @@ struct ConvArgs { platform::FilterDescriptor wdesc; platform::ConvolutionDescriptor cdesc; const framework::Tensor *x, *w, *o; - cudnnDataType_t cudnn_dtype; // strides std::vector s; @@ -145,9 +107,8 @@ struct ConvArgs { ConvArgs(const framework::Tensor* x, const framework::Tensor* w, const framework::Tensor* o, const std::vector s, - const std::vector p, const std::vector d, - cudnnDataType_t dtype) - : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} + const std::vector p, const std::vector d) + : x(x), w(w), o(o), s(s), p(p), d(d) {} }; template @@ -160,7 +121,7 @@ struct SearchAlgorithm { template static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, + bool deterministic, int algo_cache_id, const framework::ExecutionContext& ctx) { auto dtype = platform::CudnnDataType::type; bool has_got_workspace_size = true; @@ -222,24 +183,22 @@ struct SearchAlgorithm { #endif VLOG(3) << "choose algo " << algo; } else { + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(algo_cache_id); auto& dev_ctx = ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto& temp = ctx.cuda_device_context(); - AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetForward()); - auto x_dims = framework::vectorize(args.x->dims()); auto w_dims = framework::vectorize(args.w->dims()); - VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; + VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:" + << algo_cache_id << ", x_dims:" << x_dims + << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p" + << args.p << ", args.d" << args.d; algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { + x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { int returned_algo_count; std::array perf_stat; @@ -285,7 +244,7 @@ struct SearchAlgorithm { template static algo_t Find(const ConvArgs& args, bool exhaustive_search, - 
bool deterministic, + bool deterministic, int algo_cache_id, const framework::ExecutionContext& ctx) { auto dtype = platform::CudnnDataType::type; bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); @@ -362,23 +321,22 @@ struct SearchAlgorithm { } else if (deterministic) { return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } else { + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(algo_cache_id); auto& dev_ctx = ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetBackwardData()); - auto x_dims = framework::vectorize(args.x->dims()); auto w_dims = framework::vectorize(args.w->dims()); - VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; + VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:" + << algo_cache_id << ", x_dims:" << x_dims + << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p" + << args.p << ", args.d" << args.d; algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { + x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { int returned_algo_count; std::array perf_stat; @@ -427,7 +385,7 @@ struct SearchAlgorithm { template static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, + bool deterministic, int algo_cache_id, const framework::ExecutionContext& ctx) { auto dtype = platform::CudnnDataType::type; bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); @@ -491,22 +449,22 @@ struct SearchAlgorithm { } else if (deterministic) { return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; } else { + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(algo_cache_id); auto& dev_ctx = ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetBackwardFilter()); auto x_dims = framework::vectorize(args.x->dims()); auto w_dims = framework::vectorize(args.w->dims()); - VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; + VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:" + << algo_cache_id << ", x_dims:" << x_dims + << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p" + << args.p << ", args.d" << args.d; algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { + x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { int returned_algo_count; std::array perf_stat; auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu index 7f705755915..c885bf0d50e 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -216,13 +216,9 @@ class CUDNNConvOpKernel : public framework::OpKernel { const T* filter_data = transformed_filter_channel.data(); // ------------------- cudnn descriptors --------------------- - ConvArgs args{&transformed_input, - &transformed_filter_channel, - &transformed_output, - strides, - padding_common, - dilations, - dtype}; + ConvArgs args{&transformed_input, &transformed_filter_channel, + &transformed_output, strides, + padding_common, dilations}; auto handle = dev_ctx.cudnn_handle(); auto workspace_handle = 
dev_ctx.cudnn_workspace_handle(); @@ -273,7 +269,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { cudnnConvolutionFwdAlgo_t algo{}; using search = SearchAlgorithm; - algo = search::Find(args, exhaustive_search, false, ctx); + algo = search::Find(args, exhaustive_search, false, 0, ctx); workspace_size = search::GetWorkspaceSize(args, algo); #if CUDNN_VERSION_MIN(7, 0, 1) @@ -522,15 +518,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { &transformed_output_grad_channel, strides, padding_common, - dilations, - dtype}; + dilations}; ConvArgs args2{&transformed_input, &transformed_filter_grad_channel, &transformed_output_grad_channel, strides, padding_common, - dilations, - dtype}; + dilations}; auto handle = dev_ctx.cudnn_handle(); DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC @@ -586,7 +580,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { using search1 = SearchAlgorithm; data_algo = - search1::Find(args1, exhaustive_search, deterministic, ctx); + search1::Find(args1, exhaustive_search, deterministic, 0, ctx); workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); } @@ -603,7 +597,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { using search2 = SearchAlgorithm; filter_algo = - search2::Find(args2, exhaustive_search, deterministic, ctx); + search2::Find(args2, exhaustive_search, deterministic, 1, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo)); } @@ -904,26 +898,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { auto handle = dev_ctx.cudnn_handle(); - ConvArgs args1{&transformed_ddX, - W, - &transformed_ddO_channel, - strides, - padding_common, - dilations, - dtype}; - ConvArgs args2{ - &transformed_X, ddW, &transformed_ddO_channel, strides, padding_common, - dilations, dtype}; - ConvArgs args3{&transformed_ddX, - dW, - &transformed_dO_channel, - strides, - padding_common, - dilations, - dtype}; - ConvArgs args4{ - &transformed_dX, ddW, &transformed_dO_channel, strides, padding_common, - dilations, dtype}; + ConvArgs args1{&transformed_ddX, W, + &transformed_ddO_channel, strides, + padding_common, dilations}; + ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides, + padding_common, dilations}; + ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides, + padding_common, dilations}; + ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides, + padding_common, dilations}; cudnnConvolutionFwdAlgo_t fwd_algo1 = static_cast(0); @@ -951,7 +934,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); using search1 = SearchAlgorithm; - fwd_algo1 = search1::Find(args1, exhaustive_search, false, ctx); + fwd_algo1 = search1::Find(args1, exhaustive_search, false, 0, ctx); workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); } @@ -966,7 +949,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); using search2 = SearchAlgorithm; - fwd_algo2 = search2::Find(args2, exhaustive_search, false, ctx); + fwd_algo2 = search2::Find(args2, exhaustive_search, false, 0, ctx); workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2)); } @@ -984,7 +967,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { using search3 = SearchAlgorithm; filter_algo = - search3::Find(args3, 
exhaustive_search, deterministic, ctx); + search3::Find(args3, exhaustive_search, deterministic, 1, ctx); workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); } @@ -1000,7 +983,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { using search4 = SearchAlgorithm; data_algo = - search4::Find(args4, exhaustive_search, deterministic, ctx); + search4::Find(args4, exhaustive_search, deterministic, 2, ctx); workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); } diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 81d1a39309a..8b0d5710384 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -178,6 +178,17 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, library, customized_type_value); +#ifdef PADDLE_WITH_CUDA + std::vector& configs = kernel_configs_map_[type]; + // TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn + // to false. It should be fixed and then here should only create if library + // is kCUDNN. + if (configs.empty()) { + std::shared_ptr> p( + new framework::AlgorithmsCache()); + configs.push_back(p); + } +#endif return type; } @@ -552,6 +563,21 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( auto type = framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), layout_, library_, customized_type_value); +#ifdef PADDLE_WITH_CUDA + if (library_ == framework::LibraryType::kCUDNN) { + std::vector& configs = kernel_configs_map_[type]; + if (configs.empty()) { + std::shared_ptr> + p(new framework::AlgorithmsCache()); + configs.push_back(p); + + std::shared_ptr< + framework::AlgorithmsCache> + p2(new framework::AlgorithmsCache()); + configs.push_back(p2); + } + } +#endif return type; } @@ -728,6 +754,25 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( auto type = framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), layout_, library_, customized_type_value); +#ifdef PADDLE_WITH_CUDA + if (library_ == framework::LibraryType::kCUDNN) { + std::vector& configs = kernel_configs_map_[type]; + if (configs.empty()) { + std::shared_ptr> p0( + new framework::AlgorithmsCache()); + configs.push_back(p0); + + std::shared_ptr< + framework::AlgorithmsCache> + p1(new framework::AlgorithmsCache()); + configs.push_back(p1); + + std::shared_ptr> + p2(new framework::AlgorithmsCache()); + configs.push_back(p2); + } + } +#endif return type; } diff --git a/paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc b/paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc index 6a04aa7dedd..7443c142d0f 100644 --- a/paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc +++ b/paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc @@ -58,7 +58,8 @@ void MainTest(const TestData& test_data) { RuntimeContext runtime_ctx = RuntimeContext(op->Inputs(), op->Outputs(), scope); - ExecutionContext ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_ctx); + ExecutionContext ctx = + ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr); bool result = ElementwiseMulOp::AreDimsAndFormatCorrect( ctx, 16, MKLDNNMemoryFormat::nChw16c); if (test_data.supposed_to_fail) diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index 92769bb93ea..858cde00865 
100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -14,10 +14,10 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/conv_cudnn_helper.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/math/padding.h" +#include "paddle/fluid/platform/cudnn_helper.h" DECLARE_int64(cudnn_exhaustive_search_times); @@ -233,7 +233,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { return fwd_perf_stat[0].algo; }; AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetConvFusion()); + ctx.GetKernelConfig>(0); int search_times = ctx.Attr("search_times"); search_times = std::max( static_cast(FLAGS_cudnn_exhaustive_search_times), search_times); @@ -245,9 +245,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0, search_func); } else { - auto dtype = platform::CudnnDataType::type; algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings, - dilations, 0, dtype, search_func); + dilations, 0, search_func); } VLOG(3) << "choose algo " << algo; } diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index 2ca2588470e..7ec0aa0e296 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -61,8 +61,8 @@ class WarpCTCOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace(), - layout_, library_); + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context(), layout_, library_); } }; @@ -174,7 +174,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Loss")), - ctx.GetPlace()); + ctx.device_context()); } }; -- GitLab
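For reference, the mechanism this revert restores works as follows: when an operator such as conv2d selects its kernel, it registers one or more AlgorithmsCache objects in its kernel_configs_map_ keyed by the OpKernelType; the ExecutionContext is then built with a pointer to that config vector, and the CUDNN kernel fetches a cache slot by index through GetKernelConfig so repeated runs reuse a previously found algorithm instead of searching again. Below is a minimal, self-contained C++ sketch of that pattern; the class and method names mirror the patch but are simplified stand-ins, not the actual Paddle signatures.

// Simplified sketch (assumed names, not the real Paddle API) of the
// per-operator kernel-config caching pattern restored by this revert.
#include <cstdint>
#include <cstddef>
#include <functional>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <vector>

// Stand-in for framework::AlgorithmsCache<TAlgorithm>: caches the result of
// an expensive search (gen_func) under an integer key such as a dims hash.
template <typename TAlgorithm>
class AlgorithmsCache {
 public:
  TAlgorithm GetAlgorithm(int64_t key, std::function<TAlgorithm()> gen_func) {
    auto it = hash_.find(key);
    if (it != hash_.end()) return it->second;  // cache hit
    TAlgorithm value = gen_func();             // run the search once
    hash_[key] = value;
    return value;
  }

 private:
  std::unordered_map<int64_t, TAlgorithm> hash_;
};

// Stand-in for the KernelConfig variant: type-erased shared ownership.
using KernelConfig = std::shared_ptr<void>;

// Minimal execution-context analogue: carries a non-owning pointer to the
// config vector selected for the chosen kernel (may be nullptr).
class ExecutionContext {
 public:
  explicit ExecutionContext(std::vector<KernelConfig>* configs)
      : kernel_configs_(configs) {}

  template <typename T>
  T& GetKernelConfig(std::size_t idx) const {
    if (!kernel_configs_ || kernel_configs_->size() <= idx)
      throw std::out_of_range("selected kernel has no config at this index");
    return *std::static_pointer_cast<T>((*kernel_configs_)[idx]);
  }

 private:
  std::vector<KernelConfig>* kernel_configs_;
};

int main() {
  // The operator registers one cache per kernel type when it decides which
  // kernel to run (mirrors what conv2d does in GetExpectedKernelType).
  std::vector<KernelConfig> configs;
  configs.push_back(std::make_shared<AlgorithmsCache<int>>());

  // The kernel later fetches slot 0 and reuses any previously found algo.
  ExecutionContext ctx(&configs);
  auto& cache = ctx.GetKernelConfig<AlgorithmsCache<int>>(0);
  int algo = cache.GetAlgorithm(/*key=*/42, [] { return 7; /* "search" */ });
  return algo == cache.GetAlgorithm(42, [] { return -1; }) ? 0 : 1;
}

In the actual patch the configs are stored as a boost::variant over shared_ptr<AlgorithmsCache<...>> for the forward, backward-data and backward-filter CUDNN algorithm types; the void pointer above only stands in for that variant to keep the sketch self-contained.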
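The cache key in operator_kernel_configs.h is produced by folding every input dimension, stride, padding, dilation value and the algorithm flags into a single 64-bit seed using the boost-style hash-combine recipe visible in the patch (the 0x9e3779b9 constant plus shifts). A minimal sketch of that keying scheme, with a simplified signature:

// Fold a sequence of integers into one seed with the boost::hash_combine
// recipe, as the cache does for dims, strides, paddings, dilations and
// the algorithm flags (simplified: the original also mixes in a per-group
// offset for each argument).
#include <cstdint>
#include <functional>
#include <vector>

int64_t CombineHash(const std::vector<int64_t>& values, int64_t seed = 0) {
  std::hash<int64_t> hasher;
  for (int64_t v : values) {
    seed ^= static_cast<int64_t>(hasher(v)) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2);
  }
  return seed;
}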