未验证 提交 edc3496f 编写于 作者: J jiangcheng 提交者: GitHub

Optimize cinn_cache_key by replace GraphToProgram to Dot string (#37317)

* optimize cache-key by replace GraphToProgram to Dot string

* fix compile failure bug
上级 d29cc7b4
...@@ -14,13 +14,16 @@ ...@@ -14,13 +14,16 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h" #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include <algorithm>
#include <functional>
#include <map> #include <map>
#include <set>
#include <string> #include <string>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -39,13 +42,42 @@ CinnCacheKey::CinnCacheKey(const ir::Graph& graph, ...@@ -39,13 +42,42 @@ CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
this->SetKey(graph, input_shapes, arch_str); this->SetKey(graph, input_shapes, arch_str);
} }
size_t CinnCacheKey::HashGraph(const ir::Graph& graph) {
// using Dot to unqiue graph
inference::analysis::Dot dot;
std::unordered_map<const ir::Node*, std::string> node2dot;
int id = 0;
// Create nodes
// graph.Nodes() return unordered_set, the same graph may
// return different result?
for (const ir::Node* n : graph.Nodes()) {
std::string node_id = std::to_string(id++);
dot.AddNode(node_id, {}, n->Name(), true);
node2dot[n] = node_id;
}
// Create edges
for (const ir::Node* n : graph.Nodes()) {
const auto& src_id = node2dot.at(n);
for (auto* out : n->outputs) {
const auto& dest_id = node2dot.at(out);
dot.AddEdge(src_id, dest_id, {});
}
}
const std::string& viz_graph = dot.Build();
VLOG(1) << "The hash graph:\n" << viz_graph;
size_t hash_val = std::hash<std::string>()(viz_graph);
VLOG(4) << "The graph's hash value is: " << hash_val;
return hash_val;
}
void CinnCacheKey::SetKey( void CinnCacheKey::SetKey(
const ir::Graph& graph, const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor*>& input_tensors,
const std::string& arch_str) { const std::string& arch_str) {
ProgramDesc program; graph_serialize_str_ = std::to_string(HashGraph(graph));
GraphToProgram(graph, &program);
program.Proto()->SerializeToString(&graph_serialize_str_);
for (const auto& name_tensor : input_tensors) { for (const auto& name_tensor : input_tensors) {
input_shapes_[name_tensor.first] = name_tensor.second->dims(); input_shapes_[name_tensor.first] = name_tensor.second->dims();
} }
...@@ -55,9 +87,7 @@ void CinnCacheKey::SetKey( ...@@ -55,9 +87,7 @@ void CinnCacheKey::SetKey(
void CinnCacheKey::SetKey(const ir::Graph& graph, void CinnCacheKey::SetKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes, const std::map<std::string, DDim>& input_shapes,
const std::string& arch_str) { const std::string& arch_str) {
ProgramDesc program; graph_serialize_str_ = std::to_string(HashGraph(graph));
GraphToProgram(graph, &program);
program.Proto()->SerializeToString(&graph_serialize_str_);
input_shapes_ = input_shapes; input_shapes_ = input_shapes;
arch_str_ = arch_str; arch_str_ = arch_str;
} }
......
...@@ -58,6 +58,8 @@ class CinnCacheKey { ...@@ -58,6 +58,8 @@ class CinnCacheKey {
}; };
private: private:
size_t HashGraph(const ir::Graph& graph);
std::string graph_serialize_str_; std::string graph_serialize_str_;
std::map<std::string, DDim> input_shapes_; std::map<std::string, DDim> input_shapes_;
std::string arch_str_; std::string arch_str_;
......
...@@ -139,7 +139,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const { ...@@ -139,7 +139,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
node_id, node_id,
{Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"), {Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")}, Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")},
n->Name()); n->Name(), true);
} else if (n->IsVar()) { } else if (n->IsVar()) {
auto label = n->Name(); auto label = n->Name();
if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) { if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
...@@ -155,7 +155,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const { ...@@ -155,7 +155,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"), Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"),
Dot::Attr("fontcolor", Dot::Attr("fontcolor",
n->Var()->IsParameter() ? "#ffffff" : "#000000")}, n->Var()->IsParameter() ? "#ffffff" : "#000000")},
label); label, true);
} }
node2dot[n] = node_id; node2dot[n] = node_id;
} }
......
...@@ -59,6 +59,9 @@ class Dot { ...@@ -59,6 +59,9 @@ class Dot {
attrs(attrs), attrs(attrs),
id_("node_" + std::to_string(dot_node_counter++)) {} id_("node_" + std::to_string(dot_node_counter++)) {}
Node(const std::string& name, const std::vector<Attr>& attrs, size_t id)
: name(name), attrs(attrs), id_("node_" + std::to_string(id)) {}
std::string id() const { return id_; } std::string id() const { return id_; }
std::string repr() const { std::string repr() const {
...@@ -113,10 +116,14 @@ class Dot { ...@@ -113,10 +116,14 @@ class Dot {
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {} explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
void AddNode(const std::string& id, const std::vector<Attr>& attrs, void AddNode(const std::string& id, const std::vector<Attr>& attrs,
std::string label = "") { std::string label = "", bool use_local_id = false) {
CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'"; CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
if (label.empty()) label = id; if (label.empty()) label = id;
nodes_.emplace(id, Node{label, attrs}); if (use_local_id) {
nodes_.emplace(id, Node{label, attrs, local_node_counter_++});
} else {
nodes_.emplace(id, Node{label, attrs});
}
} }
void AddEdge(const std::string& source, const std::string& target, void AddEdge(const std::string& source, const std::string& target,
...@@ -154,6 +161,8 @@ class Dot { ...@@ -154,6 +161,8 @@ class Dot {
std::unordered_map<std::string, Node> nodes_; std::unordered_map<std::string, Node> nodes_;
std::vector<Edge> edges_; std::vector<Edge> edges_;
std::vector<Attr> attrs_; std::vector<Attr> attrs_;
size_t local_node_counter_{0};
}; };
} // namespace analysis } // namespace analysis
......
...@@ -98,13 +98,13 @@ CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) { ...@@ -98,13 +98,13 @@ CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
return cinn_scope_->GetTensor(var_name); return cinn_scope_->GetTensor(var_name);
} }
std::vector<std::string> CinnLaunchContext::GetInternalVariableNames() { std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() {
std::unordered_set<std::string> all_parameters(cinn_variable_names_); std::unordered_set<std::string> all_parameters(cinn_variable_names_);
std::for_each(name2argument_.begin(), name2argument_.end(), std::for_each(name2argument_.begin(), name2argument_.end(),
[&all_parameters](const auto& name2arg) { [&all_parameters](const auto& name2arg) {
all_parameters.erase(name2arg.first); all_parameters.erase(name2arg.first);
}); });
return {all_parameters.begin(), all_parameters.end()}; return all_parameters;
} }
void CinnLaunchContext::MutableTensorData(const std::string& var_name, void CinnLaunchContext::MutableTensorData(const std::string& var_name,
......
...@@ -62,7 +62,7 @@ class CinnLaunchContext { ...@@ -62,7 +62,7 @@ class CinnLaunchContext {
// Extract internal variable names from CinnScope // Extract internal variable names from CinnScope
// by excluding used input and output variables // by excluding used input and output variables
std::vector<std::string> GetInternalVariableNames(); std::unordered_set<std::string> GetInternalVariableNames();
// Finalize all execution arguments and return them // Finalize all execution arguments and return them
const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const; const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const;
......
...@@ -223,7 +223,7 @@ TEST(CinnLaunchContextTest, TestGetInternalVariableNames) { ...@@ -223,7 +223,7 @@ TEST(CinnLaunchContextTest, TestGetInternalVariableNames) {
std::make_unique<CinnLaunchContext>(GetDefaultCompiledObj()); std::make_unique<CinnLaunchContext>(GetDefaultCompiledObj());
auto internal_variable_names = launch_context->GetInternalVariableNames(); auto internal_variable_names = launch_context->GetInternalVariableNames();
ASSERT_EQ(internal_variable_names.size(), 1); ASSERT_EQ(internal_variable_names.size(), 1);
EXPECT_EQ(internal_variable_names.front(), "cinn_var2"); EXPECT_EQ(*internal_variable_names.begin(), "cinn_var2");
} }
TEST(CinnLaunchContextTest, TestMutableTensorData) { TEST(CinnLaunchContextTest, TestMutableTensorData) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册