Unverified commit 2567dfa4, authored by jiangcheng, committed by GitHub

Optimize CINN cache key (#37786)

* optimize cache key

* add cinn cache key by graph address

* perfect cache key test script

* rename GraphHashProto to GraphHashStrategy

* rename graph_serialize_str_ to graph_hash_val_ and apply other changes suggested in review
Parent 151c5d74
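In short, the old `CinnCacheKey` serialized every graph through `Dot` before hashing, which is expensive for repeated lookups. This commit introduces two strategies — `CinnCacheKeyByAddress`, which hashes the graph's pointer, and `CinnCacheKeyByStructure`, which hashes a deterministic, name-sorted serialization — and keeps one cache per strategy. A minimal sketch of the resulting lookup flow, using only types this diff introduces (the real code is `CinnCompiler::Compile` below, which also handles locking and compilation):

```cpp
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

// Sketch only; assumes the key classes and CinnCompiledObject from this diff.
using AddressCache = std::unordered_map<CinnCacheKeyByAddress,
                                        CinnCompiledObject*, CinnCacheKey::Hash>;
using StructCache =
    std::unordered_map<CinnCacheKeyByStructure,
                       std::unique_ptr<CinnCompiledObject>, CinnCacheKey::Hash>;

// Returns the cached object, memoizing the address key on a structure hit;
// returns nullptr on a full miss (the caller compiles and fills both caches).
CinnCompiledObject* Lookup(
    const ir::Graph& graph,
    const std::map<std::string, const LoDTensor*>& inputs,
    const std::string& arch, AddressCache* by_address, StructCache* by_struct) {
  CinnCacheKeyByAddress addr_key(graph, inputs, arch);
  auto addr_it = by_address->find(addr_key);
  if (addr_it != by_address->end()) {
    return addr_it->second;  // fast path: the exact same graph object
  }
  // The graph object is new, but a structurally identical graph may already
  // have been compiled: the structure hash is deterministic across runs.
  CinnCacheKeyByStructure struct_key(graph, inputs, arch);
  auto struct_it = by_struct->find(struct_key);
  if (struct_it != by_struct->end()) {
    (*by_address)[addr_key] = struct_it->second.get();  // memoize fast path
    return struct_it->second.get();
  }
  return nullptr;
}
```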
--- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
@@ -29,55 +29,32 @@ namespace paddle {
 namespace framework {
 namespace paddle2cinn {
 
+using GraphHashStrategy = CinnCacheKey::GraphHashStrategy;
+
+CinnCacheKey::CinnCacheKey(GraphHashStrategy graph_hash)
+    : graph_hash_(graph_hash) {}
+
 CinnCacheKey::CinnCacheKey(
     const ir::Graph& graph,
     const std::map<std::string, const LoDTensor*>& input_tensors,
-    const std::string& arch_str) {
+    const std::string& arch_str, GraphHashStrategy graph_hash)
+    : graph_hash_(graph_hash) {
   this->SetKey(graph, input_tensors, arch_str);
 }
 
 CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
                            const std::map<std::string, DDim>& input_shapes,
-                           const std::string& arch_str) {
+                           const std::string& arch_str,
+                           GraphHashStrategy graph_hash)
+    : graph_hash_(graph_hash) {
   this->SetKey(graph, input_shapes, arch_str);
 }
 
-size_t CinnCacheKey::HashGraph(const ir::Graph& graph) {
-  // using Dot to unqiue graph
-  inference::analysis::Dot dot;
-  std::unordered_map<const ir::Node*, std::string> node2dot;
-  int id = 0;
-  // Create nodes
-  // graph.Nodes() return unordered_set, the same graph may
-  // return different result?
-  for (const ir::Node* n : graph.Nodes()) {
-    std::string node_id = std::to_string(id++);
-    dot.AddNode(node_id, {}, n->Name(), true);
-    node2dot[n] = node_id;
-  }
-
-  // Create edges
-  for (const ir::Node* n : graph.Nodes()) {
-    const auto& src_id = node2dot.at(n);
-    for (auto* out : n->outputs) {
-      const auto& dest_id = node2dot.at(out);
-      dot.AddEdge(src_id, dest_id, {});
-    }
-  }
-
-  const std::string& viz_graph = dot.Build();
-  VLOG(1) << "The hash graph:\n" << viz_graph;
-
-  size_t hash_val = std::hash<std::string>()(viz_graph);
-  VLOG(4) << "The graph's hash value is: " << hash_val;
-  return hash_val;
-}
-
 void CinnCacheKey::SetKey(
     const ir::Graph& graph,
     const std::map<std::string, const LoDTensor*>& input_tensors,
     const std::string& arch_str) {
-  graph_serialize_str_ = std::to_string(HashGraph(graph));
+  graph_hash_val_ = graph_hash_(graph);
   for (const auto& name_tensor : input_tensors) {
     input_shapes_[name_tensor.first] = name_tensor.second->dims();
   }
@@ -87,7 +64,7 @@ void CinnCacheKey::SetKey(
 void CinnCacheKey::SetKey(const ir::Graph& graph,
                           const std::map<std::string, DDim>& input_shapes,
                           const std::string& arch_str) {
-  graph_serialize_str_ = std::to_string(HashGraph(graph));
+  graph_hash_val_ = graph_hash_(graph);
   input_shapes_ = input_shapes;
   arch_str_ = arch_str;
 }
@@ -97,7 +74,7 @@ bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {
 }
 
 bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
-  return graph_serialize_str_ == other.graph_serialize_str_ &&
+  return graph_hash_val_ == other.graph_hash_val_ &&
          input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_;
 }
@@ -114,11 +91,48 @@ size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
     ret = hash_combine(ret, string_hasher(name_shape.second.to_str()));
   }
 
-  ret = hash_combine(ret, string_hasher(key.graph_serialize_str_));
+  ret = hash_combine(ret, key.graph_hash_val_);
   ret = hash_combine(ret, string_hasher(key.arch_str_));
   return ret;
 }
 
+size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
+  // sort the graph nodes by name, then by id
+  auto compare = [](ir::Node* n1, ir::Node* n2) {
+    return (n1->Name() == n2->Name()) ? (n1->id() < n2->id())
+                                      : (n1->Name() < n2->Name());
+  };
+
+  // graph.Nodes() returns an unordered_set, so iterate through an ordered
+  // set here to ensure the same graph always hashes to the same value
+  std::set<ir::Node *, bool (*)(ir::Node *, ir::Node *)> node_set(compare),
+      output_set(compare);
+  node_set.insert(graph.Nodes().begin(), graph.Nodes().end());
+
+  std::string hash_str;
+  for (ir::Node* n : node_set) {
+    hash_str.append(n->Name());
+
+    output_set.clear();
+    output_set.insert(n->outputs.begin(), n->outputs.end());
+    for (auto* out : output_set) {
+      hash_str.append(out->Name());
+    }
+  }
+
+  VLOG(1) << "The hash graph:\n" << hash_str;
+
+  size_t hash_val = std::hash<std::string>()(hash_str);
+  VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
+  return hash_val;
+}
+
+size_t CinnCacheKeyByAddress::HashGraph(const ir::Graph& graph) {
+  size_t hash_val = reinterpret_cast<size_t>(&graph);
+  VLOG(4) << "The graph's hash value by graph address is: " << hash_val;
+  return hash_val;
+}
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
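Note: `hash_combine`, used in `CinnCacheKey::Hash::operator()` above, is a Paddle helper that does not appear in this diff. A minimal sketch, assuming the usual Boost-style recipe (the real helper may differ in detail):

```cpp
#include <cstddef>
#include <functional>

// Assumed Boost-style combiner, returning the combined value to match the
// `ret = hash_combine(ret, x);` call style in the diff above.
template <typename T>
std::size_t hash_combine(std::size_t seed, const T& value) {
  // 0x9e3779b9 is the 32-bit golden-ratio constant from Boost's hash_combine.
  return seed ^
         (std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}
```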
--- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h
+++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h
@@ -14,6 +14,7 @@
 #pragma once
 
+#include <functional>
 #include <map>
 
 #include "paddle/fluid/framework/ddim.h"
@@ -33,14 +34,18 @@ namespace paddle2cinn {
 // shapes.
 class CinnCacheKey {
  public:
+  using GraphHashStrategy = std::function<size_t(const ir::Graph&)>;
+
+  explicit CinnCacheKey(GraphHashStrategy graph_hash);
+
   CinnCacheKey(const ir::Graph& graph,
                const std::map<std::string, const LoDTensor*>& input_tensors,
-               const std::string& arch_str);
+               const std::string& arch_str, GraphHashStrategy graph_hash);
+
   CinnCacheKey(const ir::Graph& graph,
                const std::map<std::string, DDim>& input_shapes,
-               const std::string& arch_str);
+               const std::string& arch_str, GraphHashStrategy graph_hash);
 
-  ~CinnCacheKey() {}
+  ~CinnCacheKey() = default;
 
   void SetKey(const ir::Graph& graph,
               const std::map<std::string, const LoDTensor*>& input_tensors,
@@ -58,13 +63,38 @@ class CinnCacheKey {
   };
 
  private:
-  size_t HashGraph(const ir::Graph& graph);
-
-  std::string graph_serialize_str_;
+  GraphHashStrategy graph_hash_;
+  size_t graph_hash_val_;
   std::map<std::string, DDim> input_shapes_;
   std::string arch_str_;
 };
 
+#define CINN_CACHE_KEY_CREATE(NAME)                                     \
+  class NAME : public CinnCacheKey {                                    \
+   public:                                                              \
+    NAME() : CinnCacheKey(HashGraph) {}                                 \
+                                                                        \
+    NAME(const ir::Graph& graph,                                        \
+         const std::map<std::string, const LoDTensor*>& input_tensors,  \
+         const std::string& arch_str)                                   \
+        : CinnCacheKey(graph, input_tensors, arch_str, HashGraph) {}    \
+                                                                        \
+    NAME(const ir::Graph& graph,                                        \
+         const std::map<std::string, DDim>& input_shapes,               \
+         const std::string& arch_str)                                   \
+        : CinnCacheKey(graph, input_shapes, arch_str, HashGraph) {}     \
+                                                                        \
+   private:                                                             \
+    static size_t HashGraph(const ir::Graph& graph);                    \
+  };
+
+// Class to store the keys by graph address for compiling CINN.
+CINN_CACHE_KEY_CREATE(CinnCacheKeyByAddress)
+// Class to store the keys by graph structure for compiling CINN.
+CINN_CACHE_KEY_CREATE(CinnCacheKeyByStructure)
+#undef CINN_CACHE_KEY_CREATE
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
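For readers unfamiliar with the macro trick above: `CINN_CACHE_KEY_CREATE(NAME)` stamps out a trivial subclass whose only job is to bind its own static `HashGraph` as the hashing strategy. Expanded by hand for `CinnCacheKeyByAddress`, it reads:

```cpp
// Hand-written expansion of CINN_CACHE_KEY_CREATE(CinnCacheKeyByAddress),
// shown for illustration only.
class CinnCacheKeyByAddress : public CinnCacheKey {
 public:
  CinnCacheKeyByAddress() : CinnCacheKey(HashGraph) {}

  CinnCacheKeyByAddress(
      const ir::Graph& graph,
      const std::map<std::string, const LoDTensor*>& input_tensors,
      const std::string& arch_str)
      : CinnCacheKey(graph, input_tensors, arch_str, HashGraph) {}

  CinnCacheKeyByAddress(const ir::Graph& graph,
                        const std::map<std::string, DDim>& input_shapes,
                        const std::string& arch_str)
      : CinnCacheKey(graph, input_shapes, arch_str, HashGraph) {}

 private:
  static size_t HashGraph(const ir::Graph& graph);
};
```

The two generated subclasses differ only in the `HashGraph` definitions given in cinn_cache_key.cc above.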
--- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc
@@ -26,8 +26,8 @@ namespace paddle {
 namespace framework {
 namespace paddle2cinn {
 
-TEST(CinnCacheKeyTest, TestAsUnorderedKey) {
-  std::unordered_set<CinnCacheKey, CinnCacheKey::Hash> test_set;
+TEST(CinnCacheKeyTest, TestAsUnorderedKeyByStructure) {
+  std::unordered_set<CinnCacheKeyByStructure, CinnCacheKey::Hash> test_set;
 
   ProgramDesc empty_program;
   ir::Graph empty_graph(empty_program);
@@ -47,19 +47,20 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKey) {
   DDim ddim = paddle::framework::make_ddim({1, 2, 3});
   std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
 
-  CinnCacheKey cache_key0(empty_graph, feed_tensors, "x86");
-  CinnCacheKey cache_key1(empty_graph, feed_shapes, "x86");
+  CinnCacheKeyByStructure cache_key0(empty_graph, feed_tensors, "x86");
+  CinnCacheKeyByStructure cache_key1(empty_graph, feed_shapes, "x86");
   EXPECT_EQ(cache_key0, cache_key1);
 
-  CinnCacheKey cache_key2(graph, feed_shapes, "x86");
-  CinnCacheKey cache_key3(graph, feed_shapes, "nvgpu");
-  CinnCacheKey cache_key4(graph, feed_tensors, "nvgpu");
+  CinnCacheKeyByStructure cache_key2(graph, feed_shapes, "x86");
+  CinnCacheKeyByStructure cache_key3(graph, feed_shapes, "nvgpu");
+  CinnCacheKeyByStructure cache_key4(graph, feed_tensors, "nvgpu");
   EXPECT_NE(cache_key2, cache_key3);
   EXPECT_EQ(cache_key3, cache_key4);
 
-  CinnCacheKey cache_key5(empty_graph,
-                          std::map<std::string, const LoDTensor *>(), "unk");
-  CinnCacheKey cache_key6(empty_graph, std::map<std::string, DDim>(), "unk");
+  CinnCacheKeyByStructure cache_key5(
+      empty_graph, std::map<std::string, const LoDTensor *>(), "unk");
+  CinnCacheKeyByStructure cache_key6(empty_graph, std::map<std::string, DDim>(),
+                                     "unk");
   EXPECT_EQ(cache_key5, cache_key6);
 
   EXPECT_NE(cache_key1, cache_key3);
@@ -98,6 +99,107 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKey) {
   EXPECT_EQ(test_set.find(cache_key6), test_set.end());
 }
+
+TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) {
+  std::unordered_set<CinnCacheKeyByAddress, CinnCacheKey::Hash> test_set;
+
+  ProgramDesc empty_program;
+  ir::Graph empty_graph(empty_program);
+
+  ProgramDesc program;
+  auto *global_block = program.MutableBlock(0);
+  auto *x = global_block->Var("X");
+  x->SetType(proto::VarType::LOD_TENSOR);
+  ir::Graph graph(program);
+
+  LoDTensor tensor;
+  tensor.Resize({1, 2, 3});
+  const LoDTensor *tensor_pointer = &tensor;
+  std::map<std::string, const LoDTensor *> feed_tensors = {
+      {"X", tensor_pointer}};
+
+  DDim ddim = paddle::framework::make_ddim({1, 2, 3});
+  std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
+
+  CinnCacheKeyByAddress cache_key0(empty_graph, feed_tensors, "x86");
+  CinnCacheKeyByAddress cache_key1(empty_graph, feed_shapes, "x86");
+  EXPECT_EQ(cache_key0, cache_key1);
+
+  CinnCacheKeyByAddress cache_key2(graph, feed_shapes, "x86");
+  CinnCacheKeyByAddress cache_key3(graph, feed_shapes, "nvgpu");
+  CinnCacheKeyByAddress cache_key4(graph, feed_tensors, "nvgpu");
+  EXPECT_NE(cache_key2, cache_key3);
+  EXPECT_EQ(cache_key3, cache_key4);
+
+  CinnCacheKeyByAddress cache_key5(
+      empty_graph, std::map<std::string, const LoDTensor *>(), "unk");
+  CinnCacheKeyByAddress cache_key6(empty_graph, std::map<std::string, DDim>(),
+                                   "unk");
+  EXPECT_EQ(cache_key5, cache_key6);
+
+  EXPECT_NE(cache_key1, cache_key3);
+  EXPECT_NE(cache_key4, cache_key2);
+
+  EXPECT_NE(cache_key3, cache_key5);
+  EXPECT_NE(cache_key6, cache_key4);
+
+  EXPECT_NE(cache_key5, cache_key1);
+  EXPECT_NE(cache_key2, cache_key6);
+
+  test_set.insert(cache_key0);
+  test_set.insert(cache_key1);
+  test_set.insert(cache_key3);
+  test_set.insert(cache_key4);
+  test_set.insert(cache_key5);
+  test_set.insert(cache_key6);
+  EXPECT_EQ(test_set.size(), 3U);
+
+  auto iter = test_set.find(cache_key0);
+  EXPECT_NE(iter, test_set.end());
+  test_set.erase(iter);
+  EXPECT_EQ(test_set.size(), 2U);
+  EXPECT_EQ(test_set.find(cache_key1), test_set.end());
+
+  iter = test_set.find(cache_key3);
+  EXPECT_NE(iter, test_set.end());
+  test_set.erase(iter);
+  EXPECT_EQ(test_set.size(), 1U);
+  EXPECT_EQ(test_set.find(cache_key4), test_set.end());
+
+  iter = test_set.find(cache_key5);
+  EXPECT_NE(iter, test_set.end());
+  test_set.erase(iter);
+  EXPECT_EQ(test_set.size(), 0U);
+  EXPECT_EQ(test_set.find(cache_key6), test_set.end());
+}
+
+TEST(CinnCacheKeyTest, TestSameGraph) {
+  ProgramDesc program1;
+  auto *global_block1 = program1.MutableBlock(0);
+  auto *x1 = global_block1->Var("X");
+  x1->SetType(proto::VarType::LOD_TENSOR);
+  ir::Graph graph1(program1);
+
+  ProgramDesc program2;
+  auto *global_block2 = program2.MutableBlock(0);
+  auto *x2 = global_block2->Var("X");
+  x2->SetType(proto::VarType::LOD_TENSOR);
+  ir::Graph graph2(program2);
+
+  LoDTensor tensor;
+  tensor.Resize({1, 2, 3});
+  const LoDTensor *tensor_pointer = &tensor;
+  std::map<std::string, const LoDTensor *> feed_tensors = {
+      {"X", tensor_pointer}};
+
+  CinnCacheKeyByAddress cache_key_by_address1(graph1, feed_tensors, "x86");
+  CinnCacheKeyByAddress cache_key_by_address2(graph2, feed_tensors, "x86");
+  EXPECT_NE(cache_key_by_address1, cache_key_by_address2);
+
+  CinnCacheKeyByStructure cache_key_by_struct1(graph1, feed_tensors, "x86");
+  CinnCacheKeyByStructure cache_key_by_struct2(graph2, feed_tensors, "x86");
+  EXPECT_EQ(cache_key_by_struct1, cache_key_by_struct2);
+}
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
--- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
@@ -69,23 +69,41 @@ const CinnCompiledObject& CinnCompiler::Compile(
     const std::map<std::string, const LoDTensor*>& input_tensors,
     const Target& target, void* stream) {
   VLOG(1) << "-- The graph to be compiled is:\n" << VizGraph(graph);
-  CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
+  CinnCacheKeyByAddress cur_key_by_address(graph, input_tensors,
+                                           target.arch_str());
+  CinnCacheKeyByStructure cur_key_by_struct;
+
   bool exist = false;
   {
     AutoRDLock r_guard{&rwlock_};
-    exist = cache_.count(cur_key) != 0;
+    exist = cache_by_address_.count(cur_key_by_address) != 0;
+    // If the graph cannot be found by address, check whether the same graph
+    // structure has already been stored in the cache.
+    if (!exist) {
+      // generate the structure cache key
+      cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
+      // If the structure key hits, store the address key as well so the
+      // next query takes the fast path.
+      if (cache_by_struct_.count(cur_key_by_struct) != 0) {
+        exist = true;
+        cache_by_address_[cur_key_by_address] =
+            cache_by_struct_.at(cur_key_by_struct).get();
+      }
+    }
   }
   if (!exist) {
     std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
     auto compiled_res =
         CompileGraph(graph, input_tensors, target, compiled_num, stream);
     AutoWRLock w_guard{&rwlock_};
-    if (!cache_.count(cur_key)) {
-      cache_[cur_key] = std::move(compiled_res);
+    if (!cache_by_struct_.count(cur_key_by_struct)) {
+      cache_by_address_[cur_key_by_address] = compiled_res.get();
+      cache_by_struct_[cur_key_by_struct] = std::move(compiled_res);
     }
   }
   AutoRDLock guard{&rwlock_};
-  const auto& cached_boj = *cache_[cur_key];
+  const auto& cached_boj = *cache_by_address_[cur_key_by_address];
   return cached_boj;
 }
@@ -182,7 +200,8 @@ void CinnCompiler::Clear() {
   {
     AutoWRLock guard{&rwlock_};
     graphs_.clear();
-    cache_.clear();
+    cache_by_address_.clear();
+    cache_by_struct_.clear();
   }
   real_compiled_num_.store(0);
 }
--- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h
@@ -95,9 +95,12 @@ class CinnCompiler {
                         void* stream = nullptr) const;
 
   std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
-  std::unordered_map<CinnCacheKey, std::unique_ptr<CinnCompiledObject>,
-                     CinnCacheKey::Hash>
-      cache_;
+  std::unordered_map<CinnCacheKeyByAddress, CinnCompiledObject*,
+                     CinnCacheKey::Hash>
+      cache_by_address_;
+  std::unordered_map<CinnCacheKeyByStructure,
+                     std::unique_ptr<CinnCompiledObject>, CinnCacheKey::Hash>
+      cache_by_struct_;
   std::atomic_int64_t real_compiled_num_{0};
   mutable RWLock rwlock_;
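Note the asymmetry between the two caches: `cache_by_struct_` owns the compiled objects through `std::unique_ptr`, while `cache_by_address_` holds non-owning raw pointers into those same objects. An address entry is therefore just a memoized alias of a structure entry, which is why `Compile` stores each result under both keys and why `Clear` must empty both maps together.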