未验证 提交 ccfde2da 编写于 作者: Z Zhen Wang 提交者: GitHub

Update the lock logic used in CinnCompiler::Compile. (#43876)

* Update the lock logic used in CinnCompiler::Compile.
上级 8bd69193
......@@ -18,6 +18,7 @@
#include <iterator>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
......@@ -43,7 +44,6 @@
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
DECLARE_bool(enable_pe_launch_cinn);
DECLARE_bool(enable_cinn_auto_tune);
......@@ -60,66 +60,61 @@ using inference::analysis::Dot;
using ir::Graph;
using ir::Node;
CinnCompiler* CinnCompiler::GetInstance() {
static CinnCompiler* instance = new CinnCompiler();
CinnCompiler *CinnCompiler::GetInstance() {
static CinnCompiler *instance = new CinnCompiler();
return instance;
}
const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target,
void* stream) {
const CinnCompiledObject &CinnCompiler::Compile(
const Graph &graph,
const std::map<std::string, const LoDTensor *> &input_tensors,
const Target &target,
void *stream) {
VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
CinnCacheKeyByAddress cur_key_by_address(
graph, input_tensors, target.arch_str());
CinnCacheKeyByStructure cur_key_by_struct;
bool exist = false;
{
phi::AutoRDLock r_guard{&rwlock_};
exist = cache_by_address_.count(cur_key_by_address) != 0;
// if cannot find graph by address, checkout whether the graph structure
// have been stored in cache.
if (!exist) {
if (!cache_by_address_.count(cur_key_by_address)) {
// generate the structure cache key
cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
// if the graph structure can be found, storing the graph address in
// cache for next query.
if (cache_by_struct_.count(cur_key_by_struct) != 0) {
exist = true;
cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct);
}
}
}
if (!exist) {
if (!cache_by_struct_.count(cur_key_by_struct)) {
std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
auto compiled_res =
CompileGraph(graph, input_tensors, target, compiled_num, stream);
phi::AutoWRLock w_guard{&rwlock_};
std::unique_lock<std::mutex> guard(lock_);
// double check cache_by_struct_
if (!cache_by_struct_.count(cur_key_by_struct)) {
cache_by_address_[cur_key_by_address] = compiled_num;
cache_by_struct_[cur_key_by_struct] = compiled_num;
index2cache_.emplace(compiled_num, std::move(compiled_res));
}
// double check cache_by_address_
if (!cache_by_address_.count(cur_key_by_address)) {
cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct);
}
} else {
std::unique_lock<std::mutex> guard(lock_);
// double check cache_by_address_
if (!cache_by_address_.count(cur_key_by_address)) {
cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct);
}
}
}
phi::AutoRDLock guard{&rwlock_};
const auto& cached_boj = *index2cache_[cache_by_address_[cur_key_by_address]];
return cached_boj;
return *index2cache_.at(cache_by_address_.at(cur_key_by_address));
}
const CinnCompiledObject& CinnCompiler::Compile(
const CinnCompiledObject &CinnCompiler::Compile(
int64_t compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target,
void* stream) {
const auto& graph = FindGraph(compilation_key);
const std::map<std::string, const LoDTensor *> &input_tensors,
const Target &target,
void *stream) {
const auto &graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target, stream);
}
const CinnCompiledObject& CinnCompiler::GetCompiledObject(
const CinnCompiledObject &CinnCompiler::GetCompiledObject(
int64_t cached_index) const {
auto res = index2cache_.find(cached_index);
PADDLE_ENFORCE_NE(res,
......@@ -130,7 +125,7 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
}
int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
int64_t graph_key = std::hash<Graph*>()((&(*graph)));
int64_t graph_key = std::hash<Graph *>()((&(*graph)));
PADDLE_ENFORCE_EQ(
graphs_.count(graph_key),
0,
......@@ -143,7 +138,7 @@ int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
return graph_key;
}
const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
const Graph &CinnCompiler::FindGraph(int64_t graph_key) const {
auto it = graphs_.find(graph_key);
PADDLE_ENFORCE_NE(
it,
......@@ -155,16 +150,16 @@ const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
}
std::string CinnCompiler::VizGraph(int64_t graph_key) const {
const Graph& graph = FindGraph(graph_key);
const Graph &graph = FindGraph(graph_key);
return VizGraph(graph);
}
std::string CinnCompiler::VizGraph(const Graph& graph) const {
std::string CinnCompiler::VizGraph(const Graph &graph) const {
Dot dot;
std::unordered_map<const Node*, std::string> node2dot;
std::unordered_map<const Node *, std::string> node2dot;
int id = 0;
// Create nodes
for (const Node* n : graph.Nodes()) {
for (const Node *n : graph.Nodes()) {
std::string node_id = "Node" + std::to_string(id++);
if (n->IsOp()) {
dot.AddNode(node_id,
......@@ -180,7 +175,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
auto shape = n->Var()->GetShape();
std::vector<std::string> shape_str(shape.size());
std::transform(
shape.begin(), shape.end(), shape_str.begin(), [](const auto& val) {
shape.begin(), shape.end(), shape_str.begin(), [](const auto &val) {
return std::to_string(val);
});
label += "\n" + string::join_strings(shape_str, ',');
......@@ -198,10 +193,10 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
node2dot[n] = node_id;
}
// Create edges
for (const Node* n : graph.Nodes()) {
const auto& src_id = node2dot.at(n);
for (auto* out : n->outputs) {
const auto& dest_id = node2dot.at(out);
for (const Node *n : graph.Nodes()) {
const auto &src_id = node2dot.at(n);
for (auto *out : n->outputs) {
const auto &dest_id = node2dot.at(out);
dot.AddEdge(src_id, dest_id, {});
}
}
......@@ -209,7 +204,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
}
std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
const auto& graph = FindGraph(compilation_key);
const auto &graph = FindGraph(compilation_key);
ProgramDesc program;
GraphToProgram(graph, &program);
......@@ -220,7 +215,7 @@ std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
}
std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
const auto& graph = FindGraph(compilation_key);
const auto &graph = FindGraph(compilation_key);
ProgramDesc program;
GraphToProgram(graph, &program);
......@@ -230,7 +225,7 @@ std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
void CinnCompiler::Clear() {
{
phi::AutoWRLock guard{&rwlock_};
std::unique_lock<std::mutex> guard(lock_);
graphs_.clear();
cache_by_address_.clear();
cache_by_struct_.clear();
......@@ -240,22 +235,22 @@ void CinnCompiler::Clear() {
}
void CinnCompiler::CheckCompiledValid(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const CinnCompiledObject& compiled_obj) const {
const auto& input_var_names = graph.Get<std::vector<std::string>>(kInputVars);
const auto& output_var_names =
const ir::Graph &graph,
const std::map<std::string, const LoDTensor *> &input_tensors,
const CinnCompiledObject &compiled_obj) const {
const auto &input_var_names = graph.Get<std::vector<std::string>>(kInputVars);
const auto &output_var_names =
graph.Get<std::vector<std::string>>(kOutputVars);
auto* launch_context = compiled_obj.launch_context.get();
auto *launch_context = compiled_obj.launch_context.get();
// 1. check all of the output variables will be assigned by compiled program
for (auto&& var_name : output_var_names) {
for (auto &&var_name : output_var_names) {
PADDLE_ENFORCE_EQ(launch_context->IsVariableUsed(var_name),
true,
platform::errors::PreconditionNotMet(
"Variable(%s) not applied in CINN", var_name));
}
// 2. check all of the used input variables were correctly deduced by CINN.
for (const auto& var_name : input_var_names) {
for (const auto &var_name : input_var_names) {
// some input variables were not used by CINN because they were eliminated
// by its optimized passes or some operators of it need less inputs
if (!launch_context->IsVariableUsed(var_name)) {
......@@ -268,11 +263,11 @@ void CinnCompiler::CheckCompiledValid(
}
std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target,
const ir::Graph &graph,
const std::map<std::string, const LoDTensor *> &input_tensors,
const Target &target,
std::int64_t compiled_num,
void* stream) const {
void *stream) const {
CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
auto frontend_program = symbol();
auto fetch_ids = symbol.GetFetchIds();
......
......@@ -18,6 +18,7 @@
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
......@@ -26,7 +27,6 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace cinn {
namespace common {
......@@ -129,7 +129,7 @@ class CinnCompiler {
std::unordered_map<std::int64_t, std::unique_ptr<CinnCompiledObject>>
index2cache_;
std::atomic_int64_t real_compiled_num_{0};
mutable phi::RWLock rwlock_;
mutable std::mutex lock_;
DISABLE_COPY_AND_ASSIGN(CinnCompiler);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册