未验证 提交 ccfde2da 编写于 作者: Z Zhen Wang 提交者: GitHub

Update the lock logic used in CinnCompiler::Compile. (#43876)

* Update the lock logic used in CinnCompiler::Compile.
上级 8bd69193
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <iterator> #include <iterator>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -43,7 +44,6 @@ ...@@ -43,7 +44,6 @@
#include "paddle/fluid/operators/cinn/cinn_launch_context.h" #include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
DECLARE_bool(enable_pe_launch_cinn); DECLARE_bool(enable_pe_launch_cinn);
DECLARE_bool(enable_cinn_auto_tune); DECLARE_bool(enable_cinn_auto_tune);
...@@ -60,66 +60,61 @@ using inference::analysis::Dot; ...@@ -60,66 +60,61 @@ using inference::analysis::Dot;
using ir::Graph; using ir::Graph;
using ir::Node; using ir::Node;
CinnCompiler* CinnCompiler::GetInstance() { CinnCompiler *CinnCompiler::GetInstance() {
static CinnCompiler* instance = new CinnCompiler(); static CinnCompiler *instance = new CinnCompiler();
return instance; return instance;
} }
const CinnCompiledObject& CinnCompiler::Compile( const CinnCompiledObject &CinnCompiler::Compile(
const Graph& graph, const Graph &graph,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor *> &input_tensors,
const Target& target, const Target &target,
void* stream) { void *stream) {
VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph); VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
CinnCacheKeyByAddress cur_key_by_address( CinnCacheKeyByAddress cur_key_by_address(
graph, input_tensors, target.arch_str()); graph, input_tensors, target.arch_str());
CinnCacheKeyByStructure cur_key_by_struct; CinnCacheKeyByStructure cur_key_by_struct;
bool exist = false; if (!cache_by_address_.count(cur_key_by_address)) {
{ // generate the structure cache key
phi::AutoRDLock r_guard{&rwlock_}; cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
exist = cache_by_address_.count(cur_key_by_address) != 0; if (!cache_by_struct_.count(cur_key_by_struct)) {
// if cannot find graph by address, checkout whether the graph structure std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
// have been stored in cache. auto compiled_res =
if (!exist) { CompileGraph(graph, input_tensors, target, compiled_num, stream);
// generate the structure cache key std::unique_lock<std::mutex> guard(lock_);
cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str()); // double check cache_by_struct_
if (!cache_by_struct_.count(cur_key_by_struct)) {
// if the graph structure can be found, storing the graph address in cache_by_struct_[cur_key_by_struct] = compiled_num;
// cache for next query. index2cache_.emplace(compiled_num, std::move(compiled_res));
if (cache_by_struct_.count(cur_key_by_struct) != 0) { }
exist = true; // double check cache_by_address_
if (!cache_by_address_.count(cur_key_by_address)) {
cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct);
}
} else {
std::unique_lock<std::mutex> guard(lock_);
// double check cache_by_address_
if (!cache_by_address_.count(cur_key_by_address)) {
cache_by_address_[cur_key_by_address] = cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct); cache_by_struct_.at(cur_key_by_struct);
} }
} }
} }
if (!exist) { return *index2cache_.at(cache_by_address_.at(cur_key_by_address));
std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
auto compiled_res =
CompileGraph(graph, input_tensors, target, compiled_num, stream);
phi::AutoWRLock w_guard{&rwlock_};
if (!cache_by_struct_.count(cur_key_by_struct)) {
cache_by_address_[cur_key_by_address] = compiled_num;
cache_by_struct_[cur_key_by_struct] = compiled_num;
index2cache_.emplace(compiled_num, std::move(compiled_res));
}
}
phi::AutoRDLock guard{&rwlock_};
const auto& cached_boj = *index2cache_[cache_by_address_[cur_key_by_address]];
return cached_boj;
} }
const CinnCompiledObject& CinnCompiler::Compile( const CinnCompiledObject &CinnCompiler::Compile(
int64_t compilation_key, int64_t compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor *> &input_tensors,
const Target& target, const Target &target,
void* stream) { void *stream) {
const auto& graph = FindGraph(compilation_key); const auto &graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target, stream); return Compile(graph, input_tensors, target, stream);
} }
const CinnCompiledObject& CinnCompiler::GetCompiledObject( const CinnCompiledObject &CinnCompiler::GetCompiledObject(
int64_t cached_index) const { int64_t cached_index) const {
auto res = index2cache_.find(cached_index); auto res = index2cache_.find(cached_index);
PADDLE_ENFORCE_NE(res, PADDLE_ENFORCE_NE(res,
...@@ -130,7 +125,7 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject( ...@@ -130,7 +125,7 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
} }
int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) { int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
int64_t graph_key = std::hash<Graph*>()((&(*graph))); int64_t graph_key = std::hash<Graph *>()((&(*graph)));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
graphs_.count(graph_key), graphs_.count(graph_key),
0, 0,
...@@ -143,7 +138,7 @@ int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) { ...@@ -143,7 +138,7 @@ int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
return graph_key; return graph_key;
} }
const Graph& CinnCompiler::FindGraph(int64_t graph_key) const { const Graph &CinnCompiler::FindGraph(int64_t graph_key) const {
auto it = graphs_.find(graph_key); auto it = graphs_.find(graph_key);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, it,
...@@ -155,16 +150,16 @@ const Graph& CinnCompiler::FindGraph(int64_t graph_key) const { ...@@ -155,16 +150,16 @@ const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
} }
std::string CinnCompiler::VizGraph(int64_t graph_key) const { std::string CinnCompiler::VizGraph(int64_t graph_key) const {
const Graph& graph = FindGraph(graph_key); const Graph &graph = FindGraph(graph_key);
return VizGraph(graph); return VizGraph(graph);
} }
std::string CinnCompiler::VizGraph(const Graph& graph) const { std::string CinnCompiler::VizGraph(const Graph &graph) const {
Dot dot; Dot dot;
std::unordered_map<const Node*, std::string> node2dot; std::unordered_map<const Node *, std::string> node2dot;
int id = 0; int id = 0;
// Create nodes // Create nodes
for (const Node* n : graph.Nodes()) { for (const Node *n : graph.Nodes()) {
std::string node_id = "Node" + std::to_string(id++); std::string node_id = "Node" + std::to_string(id++);
if (n->IsOp()) { if (n->IsOp()) {
dot.AddNode(node_id, dot.AddNode(node_id,
...@@ -180,7 +175,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const { ...@@ -180,7 +175,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
auto shape = n->Var()->GetShape(); auto shape = n->Var()->GetShape();
std::vector<std::string> shape_str(shape.size()); std::vector<std::string> shape_str(shape.size());
std::transform( std::transform(
shape.begin(), shape.end(), shape_str.begin(), [](const auto& val) { shape.begin(), shape.end(), shape_str.begin(), [](const auto &val) {
return std::to_string(val); return std::to_string(val);
}); });
label += "\n" + string::join_strings(shape_str, ','); label += "\n" + string::join_strings(shape_str, ',');
...@@ -198,10 +193,10 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const { ...@@ -198,10 +193,10 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
node2dot[n] = node_id; node2dot[n] = node_id;
} }
// Create edges // Create edges
for (const Node* n : graph.Nodes()) { for (const Node *n : graph.Nodes()) {
const auto& src_id = node2dot.at(n); const auto &src_id = node2dot.at(n);
for (auto* out : n->outputs) { for (auto *out : n->outputs) {
const auto& dest_id = node2dot.at(out); const auto &dest_id = node2dot.at(out);
dot.AddEdge(src_id, dest_id, {}); dot.AddEdge(src_id, dest_id, {});
} }
} }
...@@ -209,7 +204,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const { ...@@ -209,7 +204,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
} }
std::string CinnCompiler::SerializeKey(int64_t compilation_key) const { std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
const auto& graph = FindGraph(compilation_key); const auto &graph = FindGraph(compilation_key);
ProgramDesc program; ProgramDesc program;
GraphToProgram(graph, &program); GraphToProgram(graph, &program);
...@@ -220,7 +215,7 @@ std::string CinnCompiler::SerializeKey(int64_t compilation_key) const { ...@@ -220,7 +215,7 @@ std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
} }
std::string CinnCompiler::ReadableKey(int64_t compilation_key) const { std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
const auto& graph = FindGraph(compilation_key); const auto &graph = FindGraph(compilation_key);
ProgramDesc program; ProgramDesc program;
GraphToProgram(graph, &program); GraphToProgram(graph, &program);
...@@ -230,7 +225,7 @@ std::string CinnCompiler::ReadableKey(int64_t compilation_key) const { ...@@ -230,7 +225,7 @@ std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
void CinnCompiler::Clear() { void CinnCompiler::Clear() {
{ {
phi::AutoWRLock guard{&rwlock_}; std::unique_lock<std::mutex> guard(lock_);
graphs_.clear(); graphs_.clear();
cache_by_address_.clear(); cache_by_address_.clear();
cache_by_struct_.clear(); cache_by_struct_.clear();
...@@ -240,22 +235,22 @@ void CinnCompiler::Clear() { ...@@ -240,22 +235,22 @@ void CinnCompiler::Clear() {
} }
void CinnCompiler::CheckCompiledValid( void CinnCompiler::CheckCompiledValid(
const ir::Graph& graph, const ir::Graph &graph,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor *> &input_tensors,
const CinnCompiledObject& compiled_obj) const { const CinnCompiledObject &compiled_obj) const {
const auto& input_var_names = graph.Get<std::vector<std::string>>(kInputVars); const auto &input_var_names = graph.Get<std::vector<std::string>>(kInputVars);
const auto& output_var_names = const auto &output_var_names =
graph.Get<std::vector<std::string>>(kOutputVars); graph.Get<std::vector<std::string>>(kOutputVars);
auto* launch_context = compiled_obj.launch_context.get(); auto *launch_context = compiled_obj.launch_context.get();
// 1. check all of the output variables will be assigned by compiled program // 1. check all of the output variables will be assigned by compiled program
for (auto&& var_name : output_var_names) { for (auto &&var_name : output_var_names) {
PADDLE_ENFORCE_EQ(launch_context->IsVariableUsed(var_name), PADDLE_ENFORCE_EQ(launch_context->IsVariableUsed(var_name),
true, true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Variable(%s) not applied in CINN", var_name)); "Variable(%s) not applied in CINN", var_name));
} }
// 2. check all of the used input variables were correctly deduced by CINN. // 2. check all of the used input variables were correctly deduced by CINN.
for (const auto& var_name : input_var_names) { for (const auto &var_name : input_var_names) {
// some input variables were not used by CINN because they were eliminated // some input variables were not used by CINN because they were eliminated
// by its optimized passes or some operators of it need less inputs // by its optimized passes or some operators of it need less inputs
if (!launch_context->IsVariableUsed(var_name)) { if (!launch_context->IsVariableUsed(var_name)) {
...@@ -268,11 +263,11 @@ void CinnCompiler::CheckCompiledValid( ...@@ -268,11 +263,11 @@ void CinnCompiler::CheckCompiledValid(
} }
std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph( std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
const ir::Graph& graph, const ir::Graph &graph,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor *> &input_tensors,
const Target& target, const Target &target,
std::int64_t compiled_num, std::int64_t compiled_num,
void* stream) const { void *stream) const {
CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
auto frontend_program = symbol(); auto frontend_program = symbol();
auto fetch_ids = symbol.GetFetchIds(); auto fetch_ids = symbol.GetFetchIds();
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <cstdint> #include <cstdint>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -26,7 +27,6 @@ ...@@ -26,7 +27,6 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h" #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace cinn { namespace cinn {
namespace common { namespace common {
...@@ -129,7 +129,7 @@ class CinnCompiler { ...@@ -129,7 +129,7 @@ class CinnCompiler {
std::unordered_map<std::int64_t, std::unique_ptr<CinnCompiledObject>> std::unordered_map<std::int64_t, std::unique_ptr<CinnCompiledObject>>
index2cache_; index2cache_;
std::atomic_int64_t real_compiled_num_{0}; std::atomic_int64_t real_compiled_num_{0};
mutable phi::RWLock rwlock_; mutable std::mutex lock_;
DISABLE_COPY_AND_ASSIGN(CinnCompiler); DISABLE_COPY_AND_ASSIGN(CinnCompiler);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册