Unverified commit 661d0800, authored by J jiangcheng, committed by GitHub

optimize cinn find graph by graph address (#42697)

* optimize cinn find graph by graph address

* graph_key use int64_t instead of program string

* fix framework _to_readable_code python code

* rename get_readable_comile_key to get_serialize_comile_key
Parent: f55c0b33
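The gist of the change: the compilation key attached to a `cinn_launch` op is no longer the serialized `ProgramDesc` string of the subgraph but an `int64_t` derived from the graph object's address, so `FindGraph` becomes a plain integer hash-map lookup. Below is a minimal standalone sketch of that idea, not the verbatim Paddle code — `GraphRegistry` is an illustrative stand-in for `CinnCompiler`, and plain exceptions replace `PADDLE_ENFORCE`:

```cpp
#include <cstdint>
#include <functional>
#include <memory>
#include <stdexcept>
#include <unordered_map>

// Stand-in for paddle::framework::ir::Graph, used only for illustration.
struct Graph {};

class GraphRegistry {
 public:
  // Derive an int64_t key from the graph's address, much like the patch does
  // with std::hash<Graph*>(). The key is computed before ownership moves into
  // the map, and stays valid because the registry owns the graph afterwards,
  // so its address never changes.
  int64_t AddGraph(std::unique_ptr<Graph> graph) {
    int64_t key = static_cast<int64_t>(std::hash<Graph*>()(graph.get()));
    if (graphs_.count(key) != 0) {
      throw std::runtime_error("graph already registered");
    }
    graphs_[key] = std::move(graph);
    return key;
  }

  // Lookup is now a hash-map find on a 64-bit integer instead of a
  // comparison over a serialized program string.
  const Graph& FindGraph(int64_t key) const {
    auto it = graphs_.find(key);
    if (it == graphs_.end()) {
      throw std::runtime_error("graph not found");
    }
    return *it->second;
  }

 private:
  std::unordered_map<int64_t, std::unique_ptr<Graph>> graphs_;
};
```

Because the compiler owns each graph through a `unique_ptr`, the address is stable for the graph's lifetime, which is what makes it usable as a key.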
@@ -66,7 +66,7 @@ static void ShareVarInfoToCinnLaunch(
           << paddle::string::join_strings(vars_to_delete, ',');
   const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph(
-      cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey));
+      cinn_launch_op->GetOp()->Attr<int64_t>(operators::kCompilationKey));
   auto& dst_varinfo_map =
       subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
   const Name2VarInfoMap& src_varinfo_map =
......
@@ -51,8 +51,7 @@ static ProgramDesc BuildProgramInsideCinnLaunchOp() {
   return program;
 }
-static ProgramDesc BuildProgramWithCinnLaunchOp(
-    const std::string& compilation_key) {
+static ProgramDesc BuildProgramWithCinnLaunchOp(int64_t compilation_key) {
   // create a cinn_launch op
   ProgramDesc program;
   auto* block = program.MutableBlock(0);
@@ -89,7 +88,7 @@ TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) {
   auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp());
   subgraph->GetOrInit<Name2VarInfoMap>(
       paddle2cinn::kMemOptVarInfoFromMainGraph);
-  std::string compilation_key =
+  auto compilation_key =
       paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph));
   // build test data and apply pass
......
@@ -487,7 +487,7 @@ void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
 void AddCinnOpToGraph(const GraphNodeSet& cluster,
                       const GraphNodeSet& cluster_inputs,
                       const GraphNodeSet& cluster_outputs,
-                      const std::string& compilation_key,
+                      int64_t compilation_key,
                       const std::unordered_set<std::string>& deny_var_set,
                       Graph* graph) {
   // Add the cinn launch op
@@ -536,7 +536,7 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
 void ReplaceSubGraphWithCinnOpNode(
     const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs,
     const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_internals,
-    const std::string& compilation_key,
+    int64_t compilation_key,
     const std::unordered_set<std::string>& deny_var_set, Graph* graph) {
   // Add the cinn op node whose name is "kCinnLaunchOp" into graph
   AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key,
@@ -613,7 +613,7 @@ void SearchAllSubgraphs(Graph* graph) {
     // Create a new subgraph according to the found cluster and
     // save it in CinnCompiler
-    std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
+    auto compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
         cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
     VLOG(4) << "Compilation Key:\n"
             << cinn_compiler->ReadableKey(compilation_key);
......
@@ -90,12 +90,12 @@ inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
 }
 // Get compilation_key values
-std::vector<std::string> GetCompilationKeys(const Graph& graph) {
-  std::vector<std::string> compilation_keys;
+std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
+  std::vector<int64_t> compilation_keys;
   for (auto& node : graph.Nodes()) {
     if (node->IsOp() && node->Name() == kCinnLaunchOp) {
       compilation_keys.emplace_back(BOOST_GET_CONST(
-          std::string, node->Op()->GetAttr(operators::kCompilationKey)));
+          int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
     }
   }
   return compilation_keys;
......
@@ -18,6 +18,7 @@
 #include <functional>
 #include <map>
 #include <set>
+#include <sstream>
 #include <string>
 #include "paddle/fluid/framework/ir/graph.h"
@@ -77,22 +78,17 @@ bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
          input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_;
 }
-size_t CinnCacheKey::Hash::hash_combine(size_t seed, size_t value) {
-  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
-}
 size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
-  std::size_t ret = 0;
-  std::hash<std::string> string_hasher;
+  std::ostringstream has_str;
   for (const auto& name_shape : key.input_shapes_) {
-    ret = hash_combine(ret, string_hasher(name_shape.first));
-    ret = hash_combine(ret, string_hasher(name_shape.second.to_str()));
+    has_str << name_shape.first;
+    has_str << name_shape.second.to_str();
   }
-  ret = hash_combine(ret, key.graph_hash_val_);
-  ret = hash_combine(ret, string_hasher(key.arch_str_));
-  return ret;
+  has_str << key.graph_hash_val_;
+  has_str << key.arch_str_;
+  return std::hash<std::string>()(has_str.str());
 }
 size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
......
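In the same spirit, `CinnCacheKey::Hash` drops the per-field `hash_combine` folding and instead streams every field into one `std::ostringstream`, then hashes the concatenated string once. A simplified sketch of that pattern follows; `CacheKeySketch` is a stand-in struct (the real key stores `DDim` shapes and streams their `to_str()` form), so the field names here are illustrative:

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <sstream>
#include <string>

// Simplified stand-in for CinnCacheKey: only the fields that the patched
// Hash::operator() streams into the string are kept here.
struct CacheKeySketch {
  std::map<std::string, std::string> input_shapes;  // variable name -> shape string
  int64_t graph_hash_val = 0;
  std::string arch_str;
};

// The patched hasher: concatenate every field into one string and hash it
// once, instead of folding per-field hashes with hash_combine.
struct CacheKeyHash {
  size_t operator()(const CacheKeySketch& key) const {
    std::ostringstream has_str;
    for (const auto& name_shape : key.input_shapes) {
      has_str << name_shape.first;
      has_str << name_shape.second;
    }
    has_str << key.graph_hash_val;
    has_str << key.arch_str;
    return std::hash<std::string>()(has_str.str());
  }
};
```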
@@ -58,7 +58,6 @@ class CinnCacheKey {
   bool operator!=(const CinnCacheKey& other) const;
   struct Hash {
-    static size_t hash_combine(size_t seed, size_t value);
     size_t operator()(const CinnCacheKey& key) const;
   };
......
@@ -110,7 +110,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
 }
 const CinnCompiledObject& CinnCompiler::Compile(
-    const std::string& compilation_key,
+    int64_t compilation_key,
     const std::map<std::string, const LoDTensor*>& input_tensors,
     const Target& target, void* stream) {
   const auto& graph = FindGraph(compilation_key);
@@ -126,12 +126,8 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
   return *res->second;
 }
-std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
-  std::string graph_key;
-  ProgramDesc program;
-  GraphToProgram(*graph, &program);
-  program.Proto()->SerializeToString(&graph_key);
+int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
+  int64_t graph_key = std::hash<Graph*>()((&(*graph)));
   PADDLE_ENFORCE_EQ(
       graphs_.count(graph_key), 0,
       platform::errors::PreconditionNotMet(
@@ -143,16 +139,17 @@ std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
   return graph_key;
 }
-const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
+const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
+  auto it = graphs_.find(graph_key);
   PADDLE_ENFORCE_NE(
-      graphs_.count(graph_key), 0,
+      it, graphs_.end(),
      platform::errors::PreconditionNotMet(
-          "Can not find the target graph, of which the key is:\n%s",
-          ReadableKey(graph_key).c_str()));
-  return *graphs_.at(graph_key);
+          "Can not find the target graph, of which the key is: %lld",
+          graph_key));
+  return *it->second;
 }
-std::string CinnCompiler::VizGraph(const std::string& graph_key) const {
+std::string CinnCompiler::VizGraph(int64_t graph_key) const {
   const Graph& graph = FindGraph(graph_key);
   return VizGraph(graph);
 }
@@ -200,11 +197,24 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
   return dot.Build();
 }
-std::string CinnCompiler::ReadableKey(
-    const std::string& compilation_key) const {
-  proto::ProgramDesc desc;
-  desc.ParseFromString(compilation_key);
-  return desc.DebugString();
+std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
+  const auto& graph = FindGraph(compilation_key);
+  ProgramDesc program;
+  GraphToProgram(graph, &program);
+  std::string serial_graph;
+  program.Proto()->SerializeToString(&serial_graph);
+  return serial_graph;
 }
+std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
+  const auto& graph = FindGraph(compilation_key);
+  ProgramDesc program;
+  GraphToProgram(graph, &program);
+  return program.Proto()->DebugString();
+}
 void CinnCompiler::Clear() {
......
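Since the key no longer embeds the serialized program, `SerializeKey` and `ReadableKey` rebuild the `ProgramDesc` from the stored graph on demand. A rough sketch of that pattern with stubbed-out types (the `Graph`, `Program`, and `GraphToProgram` names below are illustrative stand-ins, not Paddle's real signatures):

```cpp
#include <cstdint>
#include <string>

// Illustrative stand-ins for ir::Graph, ProgramDesc and GraphToProgram;
// the real code converts the graph to a protobuf ProgramDesc.
struct Graph {};
struct Program {
  std::string SerializeToString() const { return "<binary proto bytes>"; }
  std::string DebugString() const { return "<human-readable proto>"; }
};
inline Program GraphToProgram(const Graph& /*graph*/) { return Program{}; }

// SerializeKey(): regenerate the program from the stored graph and return its
// serialized bytes (previously those bytes *were* the key itself).
std::string SerializeKeySketch(const Graph& graph) {
  Program program = GraphToProgram(graph);
  return program.SerializeToString();
}

// ReadableKey(): same regeneration, but return the debug string instead.
std::string ReadableKeySketch(const Graph& graph) {
  Program program = GraphToProgram(graph);
  return program.DebugString();
}
```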
@@ -78,21 +78,23 @@ class CinnCompiler {
       const ::cinn::common::Target& target, void* stream = nullptr);
   const CinnCompiledObject& Compile(
-      const std::string& compilation_key,
+      int64_t compilation_key,
       const std::map<std::string, const LoDTensor*>& input_tensors,
       const ::cinn::common::Target& target, void* stream = nullptr);
   const CinnCompiledObject& GetCompiledObject(int64_t cached_index) const;
-  std::string AddGraph(std::unique_ptr<ir::Graph> graph);
+  int64_t AddGraph(std::unique_ptr<ir::Graph> graph);
-  const ir::Graph& FindGraph(const std::string& graph_key) const;
+  const ir::Graph& FindGraph(int64_t graph_key) const;
-  std::string VizGraph(const std::string& graph_key) const;
+  std::string VizGraph(int64_t graph_key) const;
   std::string VizGraph(const ir::Graph& graph) const;
-  std::string ReadableKey(const std::string& compilation_key) const;
+  std::string SerializeKey(int64_t compilation_key) const;
+  std::string ReadableKey(int64_t compilation_key) const;
   void Clear();
@@ -115,7 +117,7 @@ class CinnCompiler {
       const std::map<std::string, const LoDTensor*>& input_tensors,
       const CinnCompiledObject& compiled_obj) const;
-  std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
+  std::unordered_map<int64_t, std::unique_ptr<ir::Graph>> graphs_;
   std::unordered_map<CinnCacheKeyByAddress, std::int64_t, CinnCacheKey::Hash>
       cache_by_address_;
   std::unordered_map<CinnCacheKeyByStructure, std::int64_t, CinnCacheKey::Hash>
......
@@ -59,12 +59,12 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T, Alloc>& vec) {
 }
 // Get compilation_key values
-std::vector<std::string> GetCompilationKeys(const Graph& graph) {
-  std::vector<std::string> compilation_keys;
+std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
+  std::vector<int64_t> compilation_keys;
   for (auto& node : graph.Nodes()) {
     if (node->IsOp() && node->Name() == kCinnLaunchOp) {
       compilation_keys.emplace_back(BOOST_GET_CONST(
-          std::string, node->Op()->GetAttr(operators::kCompilationKey)));
+          int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
     }
   }
   return compilation_keys;
@@ -83,13 +83,12 @@ std::unordered_set<std::string> ExtractOpTypes(const Graph& graph) {
 // Get inputs info
 std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo(
-    const std::string& key, const Graph& graph) {
+    int64_t key, const Graph& graph) {
   std::unordered_set<std::string> inputs;
   for (auto& node : graph.Nodes()) {
     if (node->IsOp() && node->Name() == kCinnLaunchOp) {
-      if (BOOST_GET_CONST(std::string,
-                          node->Op()->GetAttr(operators::kCompilationKey)) !=
-          key) {
+      if (BOOST_GET_CONST(int64_t, node->Op()->GetAttr(
+                                       operators::kCompilationKey)) != key) {
        continue;
      }
      for (auto in_var_name : node->Op()->InputArgumentNames()) {
@@ -251,8 +250,7 @@ TEST(CinnCompilerTest, Compile) {
   const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key);
   viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));
-  EXPECT_THROW(cinn_compiler->FindGraph("no_existed"),
-               paddle::platform::EnforceNotMet);
+  EXPECT_THROW(cinn_compiler->FindGraph(0), paddle::platform::EnforceNotMet);
   auto inputs_info = GetInputsInfo(compilation_key, *graph);
   std::unordered_map<std::string, LoDTensor> create_inputs;
......
@@ -136,7 +136,7 @@ class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker {
               "(vector<LoDTensor>)"
               "which are the output of graph inside the CinnLaunchOp.")
         .AsDuplicable();
-    AddAttr<std::string>(
+    AddAttr<int64_t>(
         kCompilationKey,
         "(string)"
        "a hash key used to get the graph object or its computation result.");
......
@@ -28,6 +28,7 @@
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
 #include "paddle/fluid/operators/cinn/cinn_launch_context.h"
 #include "paddle/fluid/operators/cinn/cinn_op_helper.h"
+#include "paddle/fluid/platform/profiler.h"
 DECLARE_bool(enable_pe_launch_cinn);
@@ -61,13 +62,14 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     const auto& scope = ctx.scope();
     const auto& place = ctx.GetPlace();
     void* stream = details::GetStream<DeviceContext>(ctx);
+    platform::RecordEvent record_event_1(
+        "Step 1. Find graph object and prepare input");
     // Step 1. Find graph object and prepare input
     PADDLE_ENFORCE_EQ(ctx.HasAttr(kCompilationKey), true,
                       platform::errors::NotFound(
                           "No Attribute(%s) found for CinnLaunchOp operator.",
                           kCompilationKey));
-    const auto& compilation_key =
-        ctx.template Attr<std::string>(kCompilationKey);
+    const auto& compilation_key = ctx.template Attr<int64_t>(kCompilationKey);
     VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") "
             << "value:\n"
             << CinnCompiler::GetInstance()->ReadableKey(compilation_key);
@@ -100,6 +102,8 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
                                 input_no_need_buffer_tensors);
     }
+    platform::RecordEvent record_event_2(
+        "Step 2. Get compilation result of the graph");
     // Step 2. Get compilation result of the graph
     auto target = details::PlaceToCinnTarget(place);
     using ClockType = std::chrono::steady_clock;
@@ -120,17 +124,22 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     details::DebugCinnCompiledResult(cinn_compiled_object);
     auto* launch_context = cinn_compiled_object.launch_context.get();
+    platform::RecordEvent record_event_3("Step 3. Set CINN runtime FLAGS.");
     // Step 3. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
     details::SetCinnRuntimeFlags();
     // Step 4. Execute the compiled CINN instructions by a PE or
     // by the CINN compiled program in sequential order
     if (FLAGS_enable_pe_launch_cinn) {
+      platform::RecordEvent record_event_4(
+          "Step 4. Execute the runtime graph by PE.");
       VLOG(4) << "Execute the runtime graph by PE";
       framework::Scope& exec_scope = scope.NewScope();
       auto* pe = launch_context->InitializePE(place, &exec_scope);
       pe->RunWithoutFetch(launch_context->GetSkipEagerVars());
     } else {
+      platform::RecordEvent record_event_4(
+          "Step 4. Execute the compiled executable program.");
       VLOG(4) << "Execute the compiled executable program";
       launch_context->UpdateCapturedEnv(scope, place);
       LaunchCinnExecution(cinn_compiled_object, *launch_context, stream);
......
@@ -166,6 +166,10 @@ limitations under the License. */
 #include "paddle/fluid/pybind/fleet_py.h"
 #endif
+#ifdef PADDLE_WITH_CINN
+#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
+#endif
 #include "paddle/fluid/eager/api/utils/global_utils.h"
 #include "paddle/fluid/imperative/layout_autotune.h"
 #include "paddle/fluid/pybind/eager_utils.h"
@@ -1930,16 +1934,18 @@ All parameter, weight, gradient are variables in Paddle.
            which contains the id pair of pruned block and corresponding
            origin block.
         )DOC");
-  m.def("get_readable_comile_key", [](const OpDesc &op_desc) {
-    auto compilation_key =
-        BOOST_GET_CONST(std::string, op_desc.GetAttr("compilation_key"));
-    VLOG(4) << std::hash<std::string>{}(compilation_key) << " "
-            << compilation_key.size();
-    proto::ProgramDesc desc;
-    desc.ParseFromString(compilation_key);
-    auto s = desc.DebugString();
+  m.def("get_serialize_comile_key", [](int64_t compilation_key) {
+#ifdef PADDLE_WITH_CINN
+    auto compiler = framework::paddle2cinn::CinnCompiler::GetInstance();
+    auto s = compiler->SerializeKey(compilation_key);
     VLOG(4) << s;
     return s;
+#else
+    PADDLE_THROW(
+        platform::errors::PermissionDenied(
+            "Cannot get compilation key in non-CINN version, "
+            "Please recompile or reinstall Paddle with CINN support."));
+#endif
   });
   m.def("empty_var_name",
         []() { return std::string(framework::kEmptyVarName); });
......
@@ -2864,9 +2864,10 @@ class Operator(object):
                 continue
            # it is bytes of serialized protobuf
-           if self.type == 'cinn_launch' and name == 'compilation_key':
-               # value = core.get_readable_comile_key(self.desc)
-               v = self.desc.attr(name)
+           if is_compiled_with_cinn(
+           ) and self.type == 'cinn_launch' and name == 'compilation_key':
+               key = self.desc.attr(name)
+               v = core.get_serialize_comile_key(key)
                prog = Program()
                prog = prog.parse_from_string(v)
                s = prog._to_readable_code()
......