diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
index 7255d122e8eef013e071f6a39c06d25d60bf11d9..48b5b73b4c0b901d71074f86c2e91c2f4a4b3852 100644
--- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <mutex>
 #include
 #include
@@ -43,7 +44,6 @@
 #include "paddle/fluid/operators/cinn/cinn_launch_context.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/string/string_helper.h"
-#include "paddle/phi/core/utils/rw_lock.h"
 
 DECLARE_bool(enable_pe_launch_cinn);
 DECLARE_bool(enable_cinn_auto_tune);
@@ -60,66 +60,61 @@ using inference::analysis::Dot;
 using ir::Graph;
 using ir::Node;
 
-CinnCompiler* CinnCompiler::GetInstance() {
-  static CinnCompiler* instance = new CinnCompiler();
+CinnCompiler *CinnCompiler::GetInstance() {
+  static CinnCompiler *instance = new CinnCompiler();
   return instance;
 }
 
-const CinnCompiledObject& CinnCompiler::Compile(
-    const Graph& graph,
-    const std::map& input_tensors,
-    const Target& target,
-    void* stream) {
+const CinnCompiledObject &CinnCompiler::Compile(
+    const Graph &graph,
+    const std::map &input_tensors,
+    const Target &target,
+    void *stream) {
   VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
   CinnCacheKeyByAddress cur_key_by_address(
       graph, input_tensors, target.arch_str());
   CinnCacheKeyByStructure cur_key_by_struct;
 
-  bool exist = false;
-  {
-    phi::AutoRDLock r_guard{&rwlock_};
-    exist = cache_by_address_.count(cur_key_by_address) != 0;
-    // if cannot find graph by address, checkout whether the graph structure
-    // have been stored in cache.
-    if (!exist) {
-      // generate the structure cache key
-      cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
-
-      // if the graph structure can be found, storing the graph address in
-      // cache for next query.
-      if (cache_by_struct_.count(cur_key_by_struct) != 0) {
-        exist = true;
+  if (!cache_by_address_.count(cur_key_by_address)) {
+    // generate the structure cache key
+    cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
+    if (!cache_by_struct_.count(cur_key_by_struct)) {
+      std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
+      auto compiled_res =
+          CompileGraph(graph, input_tensors, target, compiled_num, stream);
+      std::unique_lock guard(lock_);
+      // double check cache_by_struct_
+      if (!cache_by_struct_.count(cur_key_by_struct)) {
+        cache_by_struct_[cur_key_by_struct] = compiled_num;
+        index2cache_.emplace(compiled_num, std::move(compiled_res));
+      }
+      // double check cache_by_address_
+      if (!cache_by_address_.count(cur_key_by_address)) {
+        cache_by_address_[cur_key_by_address] =
+            cache_by_struct_.at(cur_key_by_struct);
+      }
+    } else {
+      std::unique_lock guard(lock_);
+      // double check cache_by_address_
+      if (!cache_by_address_.count(cur_key_by_address)) {
         cache_by_address_[cur_key_by_address] =
             cache_by_struct_.at(cur_key_by_struct);
       }
     }
   }
-  if (!exist) {
-    std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
-    auto compiled_res =
-        CompileGraph(graph, input_tensors, target, compiled_num, stream);
-    phi::AutoWRLock w_guard{&rwlock_};
-    if (!cache_by_struct_.count(cur_key_by_struct)) {
-      cache_by_address_[cur_key_by_address] = compiled_num;
-      cache_by_struct_[cur_key_by_struct] = compiled_num;
-      index2cache_.emplace(compiled_num, std::move(compiled_res));
-    }
-  }
-  phi::AutoRDLock guard{&rwlock_};
-  const auto& cached_boj = *index2cache_[cache_by_address_[cur_key_by_address]];
-  return cached_boj;
+  return *index2cache_.at(cache_by_address_.at(cur_key_by_address));
 }
 
-const CinnCompiledObject& CinnCompiler::Compile(
+const CinnCompiledObject &CinnCompiler::Compile(
     int64_t compilation_key,
-    const std::map& input_tensors,
-    const Target& target,
-    void* stream) {
-  const auto& graph = FindGraph(compilation_key);
+    const std::map &input_tensors,
+    const Target &target,
+    void *stream) {
+  const auto &graph = FindGraph(compilation_key);
   return Compile(graph, input_tensors, target, stream);
 }
 
-const CinnCompiledObject& CinnCompiler::GetCompiledObject(
+const CinnCompiledObject &CinnCompiler::GetCompiledObject(
     int64_t cached_index) const {
   auto res = index2cache_.find(cached_index);
   PADDLE_ENFORCE_NE(res,
@@ -130,7 +125,7 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
 }
 
 int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
-  int64_t graph_key = std::hash<Graph*>()((&(*graph)));
+  int64_t graph_key = std::hash<Graph *>()((&(*graph)));
   PADDLE_ENFORCE_EQ(
       graphs_.count(graph_key),
       0,
@@ -143,7 +138,7 @@ int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
   return graph_key;
 }
 
-const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
+const Graph &CinnCompiler::FindGraph(int64_t graph_key) const {
   auto it = graphs_.find(graph_key);
   PADDLE_ENFORCE_NE(
       it,
@@ -155,16 +150,16 @@ const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
 }
 
 std::string CinnCompiler::VizGraph(int64_t graph_key) const {
-  const Graph& graph = FindGraph(graph_key);
+  const Graph &graph = FindGraph(graph_key);
   return VizGraph(graph);
 }
 
-std::string CinnCompiler::VizGraph(const Graph& graph) const {
+std::string CinnCompiler::VizGraph(const Graph &graph) const {
   Dot dot;
-  std::unordered_map<const Node*, std::string> node2dot;
+  std::unordered_map<const Node *, std::string> node2dot;
   int id = 0;
   // Create nodes
-  for (const Node* n : graph.Nodes()) {
+  for (const Node *n : graph.Nodes()) {
     std::string node_id = "Node" + std::to_string(id++);
     if (n->IsOp()) {
       dot.AddNode(node_id,
@@ -180,7 +175,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
         auto shape = n->Var()->GetShape();
         std::vector<std::string> shape_str(shape.size());
         std::transform(
-            shape.begin(), shape.end(), shape_str.begin(), [](const auto& val) {
+            shape.begin(), shape.end(), shape_str.begin(), [](const auto &val) {
              return std::to_string(val);
            });
         label += "\n" + string::join_strings(shape_str, ',');
@@ -198,10 +193,10 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
     node2dot[n] = node_id;
   }
   // Create edges
-  for (const Node* n : graph.Nodes()) {
-    const auto& src_id = node2dot.at(n);
-    for (auto* out : n->outputs) {
-      const auto& dest_id = node2dot.at(out);
+  for (const Node *n : graph.Nodes()) {
+    const auto &src_id = node2dot.at(n);
+    for (auto *out : n->outputs) {
+      const auto &dest_id = node2dot.at(out);
       dot.AddEdge(src_id, dest_id, {});
     }
   }
@@ -209,7 +204,7 @@
 }
 
 std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
-  const auto& graph = FindGraph(compilation_key);
+  const auto &graph = FindGraph(compilation_key);
 
   ProgramDesc program;
   GraphToProgram(graph, &program);
@@ -220,7 +215,7 @@
 std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
-  const auto& graph = FindGraph(compilation_key);
+  const auto &graph = FindGraph(compilation_key);
 
   ProgramDesc program;
   GraphToProgram(graph, &program);
@@ -230,7 +225,7 @@
 void CinnCompiler::Clear() {
   {
-    phi::AutoWRLock guard{&rwlock_};
+    std::unique_lock guard(lock_);
     graphs_.clear();
     cache_by_address_.clear();
     cache_by_struct_.clear();
@@ -240,22 +235,22 @@ void CinnCompiler::Clear() {
 }
 
 void CinnCompiler::CheckCompiledValid(
-    const ir::Graph& graph,
-    const std::map& input_tensors,
-    const CinnCompiledObject& compiled_obj) const {
-  const auto& input_var_names = graph.Get<std::vector<std::string>>(kInputVars);
-  const auto& output_var_names =
+    const ir::Graph &graph,
+    const std::map &input_tensors,
+    const CinnCompiledObject &compiled_obj) const {
+  const auto &input_var_names = graph.Get<std::vector<std::string>>(kInputVars);
+  const auto &output_var_names =
       graph.Get<std::vector<std::string>>(kOutputVars);
-  auto* launch_context = compiled_obj.launch_context.get();
+  auto *launch_context = compiled_obj.launch_context.get();
   // 1. check all of the output variables will be assigned by compiled program
-  for (auto&& var_name : output_var_names) {
+  for (auto &&var_name : output_var_names) {
     PADDLE_ENFORCE_EQ(launch_context->IsVariableUsed(var_name),
                       true,
                       platform::errors::PreconditionNotMet(
                           "Variable(%s) not applied in CINN", var_name));
   }
   // 2. check all of the used input variables were correctly deduced by CINN.
-  for (const auto& var_name : input_var_names) {
+  for (const auto &var_name : input_var_names) {
     // some input variables were not used by CINN because they were eliminated
     // by its optimized passes or some operators of it need less inputs
    if (!launch_context->IsVariableUsed(var_name)) {
@@ -268,11 +263,11 @@ void CinnCompiler::CheckCompiledValid(
 }
 
 std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
-    const ir::Graph& graph,
-    const std::map& input_tensors,
-    const Target& target,
+    const ir::Graph &graph,
+    const std::map &input_tensors,
+    const Target &target,
     std::int64_t compiled_num,
-    void* stream) const {
+    void *stream) const {
   CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
   auto frontend_program = symbol();
   auto fetch_ids = symbol.GetFetchIds();
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h
index 6f46731566a33531508b5f8b3d52a86fec754141..d193afce9deb08f28345869a06dc40f6a039d127 100644
--- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <mutex>
 #include
 #include
@@ -26,7 +27,6 @@
 #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/macros.h"
-#include "paddle/phi/core/utils/rw_lock.h"
 
 namespace cinn {
 namespace common {
@@ -129,7 +129,7 @@ class CinnCompiler {
   std::unordered_map<std::int64_t, std::unique_ptr<CinnCompiledObject>>
       index2cache_;
   std::atomic_int64_t real_compiled_num_{0};
-  mutable phi::RWLock rwlock_;
+  mutable std::mutex lock_;
 
   DISABLE_COPY_AND_ASSIGN(CinnCompiler);
 };
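
Note for reviewers of the locking change: the patched `Compile` path runs the expensive `CompileGraph` call without holding any lock, then takes `lock_` and re-checks `cache_by_struct_` / `cache_by_address_` before inserting, so two threads racing on the same graph may both compile but only one result is kept. The following is a minimal, self-contained sketch of that get-or-compile, double-checked-insert pattern around a single `std::mutex`; `CompileCache`, `CompiledObject`, and `GetOrCompile` are hypothetical names and this is not part of the patch. Unlike the patch, the sketch also holds the mutex for the initial lookup.

// Illustrative sketch only (hypothetical names) -- not part of the patch above.
// It mirrors the double-checked insert used by the new Compile(): compile
// outside the mutex, then re-check the cache under the lock before inserting.
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct CompiledObject {
  std::string artifact;
};

class CompileCache {
 public:
  const CompiledObject &GetOrCompile(const std::string &key) {
    {
      // Fast path: return a previously cached result.
      std::unique_lock<std::mutex> guard(lock_);
      auto it = cache_.find(key);
      if (it != cache_.end()) return *it->second;
    }
    // Slow path: do the expensive work without holding the lock, so other
    // keys can still be served while this one "compiles".
    auto fresh = std::make_unique<CompiledObject>();
    fresh->artifact = "compiled:" + key;  // stands in for the real compile step

    std::unique_lock<std::mutex> guard(lock_);
    // Double check: another thread may have inserted the same key while we
    // were compiling; if so, keep its result and drop ours.
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      it = cache_.emplace(key, std::move(fresh)).first;
    }
    // Entries are never erased, so the returned reference stays valid.
    return *it->second;
  }

 private:
  std::mutex lock_;
  std::map<std::string, std::unique_ptr<CompiledObject>> cache_;
};

int main() {
  CompileCache cache;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    // All workers race on the same key; at most one compiled result is kept.
    workers.emplace_back([&cache] { cache.GetOrCompile("graph_42"); });
  }
  for (auto &t : workers) t.join();
  std::cout << cache.GetOrCompile("graph_42").artifact << std::endl;
  return 0;
}

The trade-off is the same one the diff makes: a losing racer wastes one compilation, but the mutex is never held across the long-running compile step.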