Unverified commit edc3496f, authored by jiangcheng, committed by GitHub

Optimize cinn_cache_key by replacing GraphToProgram with a Dot string (#37317)

* optimize cache key by replacing GraphToProgram with a Dot string

* fix compile failure
Parent d29cc7b4
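Before the per-file diffs, a minimal standalone sketch of the idea behind the change: render the graph as DOT text and hash that text, instead of serializing the whole ProgramDesc protobuf. This is illustrative code, not Paddle's; the graph content is made up.

#include <functional>
#include <iostream>
#include <string>

int main() {
  // Render the graph as DOT text; structurally identical graphs should
  // produce identical text, so the hash of that text can serve as a cache key.
  std::string viz_graph =
      "digraph G {\n"
      "node_0[label=\"relu\"]\n"
      "node_1[label=\"var_x\"]\n"
      "node_0->node_1\n"
      "}\n";
  size_t hash_val = std::hash<std::string>()(viz_graph);
  std::cout << "graph hash: " << hash_val << "\n";
  return 0;
}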
@@ -14,13 +14,16 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <string>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle {
namespace framework {
@@ -39,13 +42,42 @@ CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
this->SetKey(graph, input_shapes, arch_str);
}
size_t CinnCacheKey::HashGraph(const ir::Graph& graph) {
// use Dot to uniquely represent the graph
inference::analysis::Dot dot;
std::unordered_map<const ir::Node*, std::string> node2dot;
int id = 0;
// Create nodes
// graph.Nodes() returns an unordered_set, so could the same graph
// yield a different node order (and hence a different hash)?
for (const ir::Node* n : graph.Nodes()) {
std::string node_id = std::to_string(id++);
dot.AddNode(node_id, {}, n->Name(), true);
node2dot[n] = node_id;
}
// Create edges
for (const ir::Node* n : graph.Nodes()) {
const auto& src_id = node2dot.at(n);
for (auto* out : n->outputs) {
const auto& dest_id = node2dot.at(out);
dot.AddEdge(src_id, dest_id, {});
}
}
const std::string& viz_graph = dot.Build();
VLOG(1) << "The hash graph:\n" << viz_graph;
size_t hash_val = std::hash<std::string>()(viz_graph);
VLOG(4) << "The graph's hash value is: " << hash_val;
return hash_val;
}
void CinnCacheKey::SetKey(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const std::string& arch_str) {
-  ProgramDesc program;
-  GraphToProgram(graph, &program);
-  program.Proto()->SerializeToString(&graph_serialize_str_);
+  graph_serialize_str_ = std::to_string(HashGraph(graph));
for (const auto& name_tensor : input_tensors) {
input_shapes_[name_tensor.first] = name_tensor.second->dims();
}
@@ -55,9 +87,7 @@ void CinnCacheKey::SetKey(
void CinnCacheKey::SetKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes,
const std::string& arch_str) {
-  ProgramDesc program;
-  GraphToProgram(graph, &program);
-  program.Proto()->SerializeToString(&graph_serialize_str_);
+  graph_serialize_str_ = std::to_string(HashGraph(graph));
input_shapes_ = input_shapes;
arch_str_ = arch_str;
}
......
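The comment inside HashGraph flags a real caveat: graph.Nodes() iterates an unordered_set, so node ids are assigned in an unspecified order and the same graph could in principle hash differently. One way to make the hash order-insensitive, sketched below, is to sort canonical node/edge strings before hashing; this illustrates the caveat and is not what the commit implements.

#include <algorithm>
#include <functional>
#include <string>
#include <vector>

// Hash a set of node/edge descriptions independently of iteration order
// by sorting them into a canonical sequence first.
size_t OrderInsensitiveHash(std::vector<std::string> parts) {
  std::sort(parts.begin(), parts.end());
  std::string joined;
  for (const auto& part : parts) joined += part + "\n";
  return std::hash<std::string>()(joined);
}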
@@ -58,6 +58,8 @@ class CinnCacheKey {
};
private:
size_t HashGraph(const ir::Graph& graph);
std::string graph_serialize_str_;
std::map<std::string, DDim> input_shapes_;
std::string arch_str_;
......
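The header makes the key a triple: the graph hash string, the input shapes, and the arch string. How the three fields are folded into one hash value lies outside this diff; below is a hedged sketch using the classic boost-style hash_combine mixer (HashCombine and KeyHash are illustrative names, not Paddle APIs).

#include <functional>
#include <string>

// Boost-style mixer: fold another hash value into an accumulated seed.
inline void HashCombine(size_t* seed, size_t value) {
  *seed ^= value + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}

size_t KeyHash(const std::string& graph_str, const std::string& shapes_str,
               const std::string& arch_str) {
  size_t seed = std::hash<std::string>()(graph_str);
  HashCombine(&seed, std::hash<std::string>()(shapes_str));
  HashCombine(&seed, std::hash<std::string>()(arch_str));
  return seed;
}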
@@ -139,7 +139,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
node_id,
{Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")},
-        n->Name());
+        n->Name(), true);
} else if (n->IsVar()) {
auto label = n->Name();
if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
@@ -155,7 +155,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"),
Dot::Attr("fontcolor",
n->Var()->IsParameter() ? "#ffffff" : "#000000")},
-          label);
+          label, true);
}
node2dot[n] = node_id;
}
......
@@ -59,6 +59,9 @@ class Dot {
attrs(attrs),
id_("node_" + std::to_string(dot_node_counter++)) {}
Node(const std::string& name, const std::vector<Attr>& attrs, size_t id)
: name(name), attrs(attrs), id_("node_" + std::to_string(id)) {}
std::string id() const { return id_; }
std::string repr() const {
@@ -113,10 +116,14 @@ class Dot {
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
void AddNode(const std::string& id, const std::vector<Attr>& attrs,
std::string label = "") {
std::string label = "", bool use_local_id = false) {
CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
if (label.empty()) label = id;
-    nodes_.emplace(id, Node{label, attrs});
+    if (use_local_id) {
+      nodes_.emplace(id, Node{label, attrs, local_node_counter_++});
+    } else {
+      nodes_.emplace(id, Node{label, attrs});
+    }
}
void AddEdge(const std::string& source, const std::string& target,
@@ -154,6 +161,8 @@ class Dot {
std::unordered_map<std::string, Node> nodes_;
std::vector<Edge> edges_;
std::vector<Attr> attrs_;
size_t local_node_counter_{0};
};
} // namespace analysis
......
@@ -98,13 +98,13 @@ CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
return cinn_scope_->GetTensor(var_name);
}
-std::vector<std::string> CinnLaunchContext::GetInternalVariableNames() {
+std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() {
std::unordered_set<std::string> all_parameters(cinn_variable_names_);
std::for_each(name2argument_.begin(), name2argument_.end(),
[&all_parameters](const auto& name2arg) {
all_parameters.erase(name2arg.first);
});
-  return {all_parameters.begin(), all_parameters.end()};
+  return all_parameters;
}
void CinnLaunchContext::MutableTensorData(const std::string& var_name,
......
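GetInternalVariableNames now returns the unordered_set directly rather than copying it into a vector. Its body is a copy-then-erase set difference: start from every CINN variable name, then drop the ones used as execution arguments. A self-contained sketch of the same pattern (the names here are made up):

#include <algorithm>
#include <map>
#include <string>
#include <unordered_set>

// Everything in all_vars that is not an execution argument is "internal".
std::unordered_set<std::string> InternalNames(
    const std::unordered_set<std::string>& all_vars,
    const std::map<std::string, int>& name2argument) {
  std::unordered_set<std::string> internal(all_vars);
  std::for_each(name2argument.begin(), name2argument.end(),
                [&internal](const auto& name2arg) {
                  internal.erase(name2arg.first);
                });
  return internal;
}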
@@ -62,7 +62,7 @@ class CinnLaunchContext {
// Extract internal variable names from CinnScope
// by excluding used input and output variables
-  std::vector<std::string> GetInternalVariableNames();
+  std::unordered_set<std::string> GetInternalVariableNames();
// Finalize all execution arguments and return them
const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const;
......
@@ -223,7 +223,7 @@ TEST(CinnLaunchContextTest, TestGetInternalVariableNames) {
std::make_unique<CinnLaunchContext>(GetDefaultCompiledObj());
auto internal_variable_names = launch_context->GetInternalVariableNames();
ASSERT_EQ(internal_variable_names.size(), 1);
-  EXPECT_EQ(internal_variable_names.front(), "cinn_var2");
+  EXPECT_EQ(*internal_variable_names.begin(), "cinn_var2");
}
TEST(CinnLaunchContextTest, TestMutableTensorData) {
......
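Note on the test update: an unordered_set has no front(), and *begin() has no guaranteed order; the EXPECT_EQ is only deterministic here because the preceding ASSERT_EQ pins the set to exactly one element.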