未验证 提交 661d0800 编写于 作者: J jiangcheng 提交者: GitHub

optimize cinn find graph by graph address (#42697)

* optimize cinn find graph by graph address

* graph_key use int64_t instead of program string

* fix framework _to_readable_code python code

* rename get_readable_comile_key to get_serialize_comile_key
上级 f55c0b33
...@@ -66,7 +66,7 @@ static void ShareVarInfoToCinnLaunch( ...@@ -66,7 +66,7 @@ static void ShareVarInfoToCinnLaunch(
<< paddle::string::join_strings(vars_to_delete, ','); << paddle::string::join_strings(vars_to_delete, ',');
const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph( const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph(
cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey)); cinn_launch_op->GetOp()->Attr<int64_t>(operators::kCompilationKey));
auto& dst_varinfo_map = auto& dst_varinfo_map =
subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph); subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
const Name2VarInfoMap& src_varinfo_map = const Name2VarInfoMap& src_varinfo_map =
......
...@@ -51,8 +51,7 @@ static ProgramDesc BuildProgramInsideCinnLaunchOp() { ...@@ -51,8 +51,7 @@ static ProgramDesc BuildProgramInsideCinnLaunchOp() {
return program; return program;
} }
static ProgramDesc BuildProgramWithCinnLaunchOp( static ProgramDesc BuildProgramWithCinnLaunchOp(int64_t compilation_key) {
const std::string& compilation_key) {
// create a cinn_launch op // create a cinn_launch op
ProgramDesc program; ProgramDesc program;
auto* block = program.MutableBlock(0); auto* block = program.MutableBlock(0);
...@@ -89,7 +88,7 @@ TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) { ...@@ -89,7 +88,7 @@ TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) {
auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp()); auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp());
subgraph->GetOrInit<Name2VarInfoMap>( subgraph->GetOrInit<Name2VarInfoMap>(
paddle2cinn::kMemOptVarInfoFromMainGraph); paddle2cinn::kMemOptVarInfoFromMainGraph);
std::string compilation_key = auto compilation_key =
paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph)); paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph));
// build test data and apply pass // build test data and apply pass
......
...@@ -487,7 +487,7 @@ void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs, ...@@ -487,7 +487,7 @@ void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
void AddCinnOpToGraph(const GraphNodeSet& cluster, void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_outputs,
const std::string& compilation_key, int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set, const std::unordered_set<std::string>& deny_var_set,
Graph* graph) { Graph* graph) {
// Add the cinn launch op // Add the cinn launch op
...@@ -536,7 +536,7 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster, ...@@ -536,7 +536,7 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
void ReplaceSubGraphWithCinnOpNode( void ReplaceSubGraphWithCinnOpNode(
const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_internals, const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_internals,
const std::string& compilation_key, int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set, Graph* graph) { const std::unordered_set<std::string>& deny_var_set, Graph* graph) {
// Add the cinn op node whose name is "kCinnLaunchOp" into graph // Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key, AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key,
...@@ -613,7 +613,7 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -613,7 +613,7 @@ void SearchAllSubgraphs(Graph* graph) {
// Create a new subgraph according to the found cluster and // Create a new subgraph according to the found cluster and
// save it in CinnCompiler // save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph( auto compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs)); cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
VLOG(4) << "Compilation Key:\n" VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key); << cinn_compiler->ReadableKey(compilation_key);
......
...@@ -90,12 +90,12 @@ inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) { ...@@ -90,12 +90,12 @@ inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
} }
// Get compilation_key values // Get compilation_key values
std::vector<std::string> GetCompilationKeys(const Graph& graph) { std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys; std::vector<int64_t> compilation_keys;
for (auto& node : graph.Nodes()) { for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) { if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(BOOST_GET_CONST( compilation_keys.emplace_back(BOOST_GET_CONST(
std::string, node->Op()->GetAttr(operators::kCompilationKey))); int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
} }
} }
return compilation_keys; return compilation_keys;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <functional> #include <functional>
#include <map> #include <map>
#include <set> #include <set>
#include <sstream>
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -77,22 +78,17 @@ bool CinnCacheKey::operator==(const CinnCacheKey& other) const { ...@@ -77,22 +78,17 @@ bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_; input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_;
} }
size_t CinnCacheKey::Hash::hash_combine(size_t seed, size_t value) {
return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}
size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const { size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
std::size_t ret = 0; std::ostringstream has_str;
std::hash<std::string> string_hasher;
for (const auto& name_shape : key.input_shapes_) { for (const auto& name_shape : key.input_shapes_) {
ret = hash_combine(ret, string_hasher(name_shape.first)); has_str << name_shape.first;
ret = hash_combine(ret, string_hasher(name_shape.second.to_str())); has_str << name_shape.second.to_str();
} }
ret = hash_combine(ret, key.graph_hash_val_); has_str << key.graph_hash_val_;
ret = hash_combine(ret, string_hasher(key.arch_str_)); has_str << key.arch_str_;
return ret; return std::hash<std::string>()(has_str.str());
} }
size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) { size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
......
...@@ -58,7 +58,6 @@ class CinnCacheKey { ...@@ -58,7 +58,6 @@ class CinnCacheKey {
bool operator!=(const CinnCacheKey& other) const; bool operator!=(const CinnCacheKey& other) const;
struct Hash { struct Hash {
static size_t hash_combine(size_t seed, size_t value);
size_t operator()(const CinnCacheKey& key) const; size_t operator()(const CinnCacheKey& key) const;
}; };
......
...@@ -110,7 +110,7 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -110,7 +110,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
} }
const CinnCompiledObject& CinnCompiler::Compile( const CinnCompiledObject& CinnCompiler::Compile(
const std::string& compilation_key, int64_t compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target, void* stream) { const Target& target, void* stream) {
const auto& graph = FindGraph(compilation_key); const auto& graph = FindGraph(compilation_key);
...@@ -126,12 +126,8 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject( ...@@ -126,12 +126,8 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
return *res->second; return *res->second;
} }
std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) { int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
std::string graph_key; int64_t graph_key = std::hash<Graph*>()((&(*graph)));
ProgramDesc program;
GraphToProgram(*graph, &program);
program.Proto()->SerializeToString(&graph_key);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
graphs_.count(graph_key), 0, graphs_.count(graph_key), 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
...@@ -143,16 +139,17 @@ std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) { ...@@ -143,16 +139,17 @@ std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
return graph_key; return graph_key;
} }
const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const { const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
auto it = graphs_.find(graph_key);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
graphs_.count(graph_key), 0, it, graphs_.end(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Can not find the target graph, of which the key is:\n%s", "Can not find the target graph, of which the key is: %lld",
ReadableKey(graph_key).c_str())); graph_key));
return *graphs_.at(graph_key); return *it->second;
} }
std::string CinnCompiler::VizGraph(const std::string& graph_key) const { std::string CinnCompiler::VizGraph(int64_t graph_key) const {
const Graph& graph = FindGraph(graph_key); const Graph& graph = FindGraph(graph_key);
return VizGraph(graph); return VizGraph(graph);
} }
...@@ -200,11 +197,24 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const { ...@@ -200,11 +197,24 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
return dot.Build(); return dot.Build();
} }
std::string CinnCompiler::ReadableKey( std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
const std::string& compilation_key) const { const auto& graph = FindGraph(compilation_key);
proto::ProgramDesc desc;
desc.ParseFromString(compilation_key); ProgramDesc program;
return desc.DebugString(); GraphToProgram(graph, &program);
std::string serial_graph;
program.Proto()->SerializeToString(&serial_graph);
return serial_graph;
}
std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
const auto& graph = FindGraph(compilation_key);
ProgramDesc program;
GraphToProgram(graph, &program);
return program.Proto()->DebugString();
} }
void CinnCompiler::Clear() { void CinnCompiler::Clear() {
......
...@@ -78,21 +78,23 @@ class CinnCompiler { ...@@ -78,21 +78,23 @@ class CinnCompiler {
const ::cinn::common::Target& target, void* stream = nullptr); const ::cinn::common::Target& target, void* stream = nullptr);
const CinnCompiledObject& Compile( const CinnCompiledObject& Compile(
const std::string& compilation_key, int64_t compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor*>& input_tensors,
const ::cinn::common::Target& target, void* stream = nullptr); const ::cinn::common::Target& target, void* stream = nullptr);
const CinnCompiledObject& GetCompiledObject(int64_t cached_index) const; const CinnCompiledObject& GetCompiledObject(int64_t cached_index) const;
std::string AddGraph(std::unique_ptr<ir::Graph> graph); int64_t AddGraph(std::unique_ptr<ir::Graph> graph);
const ir::Graph& FindGraph(const std::string& graph_key) const; const ir::Graph& FindGraph(int64_t graph_key) const;
std::string VizGraph(const std::string& graph_key) const; std::string VizGraph(int64_t graph_key) const;
std::string VizGraph(const ir::Graph& graph) const; std::string VizGraph(const ir::Graph& graph) const;
std::string ReadableKey(const std::string& compilation_key) const; std::string SerializeKey(int64_t compilation_key) const;
std::string ReadableKey(int64_t compilation_key) const;
void Clear(); void Clear();
...@@ -115,7 +117,7 @@ class CinnCompiler { ...@@ -115,7 +117,7 @@ class CinnCompiler {
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor*>& input_tensors,
const CinnCompiledObject& compiled_obj) const; const CinnCompiledObject& compiled_obj) const;
std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_; std::unordered_map<int64_t, std::unique_ptr<ir::Graph>> graphs_;
std::unordered_map<CinnCacheKeyByAddress, std::int64_t, CinnCacheKey::Hash> std::unordered_map<CinnCacheKeyByAddress, std::int64_t, CinnCacheKey::Hash>
cache_by_address_; cache_by_address_;
std::unordered_map<CinnCacheKeyByStructure, std::int64_t, CinnCacheKey::Hash> std::unordered_map<CinnCacheKeyByStructure, std::int64_t, CinnCacheKey::Hash>
......
...@@ -59,12 +59,12 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T, Alloc>& vec) { ...@@ -59,12 +59,12 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T, Alloc>& vec) {
} }
// Get compilation_key values // Get compilation_key values
std::vector<std::string> GetCompilationKeys(const Graph& graph) { std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys; std::vector<int64_t> compilation_keys;
for (auto& node : graph.Nodes()) { for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) { if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(BOOST_GET_CONST( compilation_keys.emplace_back(BOOST_GET_CONST(
std::string, node->Op()->GetAttr(operators::kCompilationKey))); int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
} }
} }
return compilation_keys; return compilation_keys;
...@@ -83,13 +83,12 @@ std::unordered_set<std::string> ExtractOpTypes(const Graph& graph) { ...@@ -83,13 +83,12 @@ std::unordered_set<std::string> ExtractOpTypes(const Graph& graph) {
// Get inputs info // Get inputs info
std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo( std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo(
const std::string& key, const Graph& graph) { int64_t key, const Graph& graph) {
std::unordered_set<std::string> inputs; std::unordered_set<std::string> inputs;
for (auto& node : graph.Nodes()) { for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) { if (node->IsOp() && node->Name() == kCinnLaunchOp) {
if (BOOST_GET_CONST(std::string, if (BOOST_GET_CONST(int64_t, node->Op()->GetAttr(
node->Op()->GetAttr(operators::kCompilationKey)) != operators::kCompilationKey)) != key) {
key) {
continue; continue;
} }
for (auto in_var_name : node->Op()->InputArgumentNames()) { for (auto in_var_name : node->Op()->InputArgumentNames()) {
...@@ -251,8 +250,7 @@ TEST(CinnCompilerTest, Compile) { ...@@ -251,8 +250,7 @@ TEST(CinnCompilerTest, Compile) {
const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key); const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key);
viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph)); viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));
EXPECT_THROW(cinn_compiler->FindGraph("no_existed"), EXPECT_THROW(cinn_compiler->FindGraph(0), paddle::platform::EnforceNotMet);
paddle::platform::EnforceNotMet);
auto inputs_info = GetInputsInfo(compilation_key, *graph); auto inputs_info = GetInputsInfo(compilation_key, *graph);
std::unordered_map<std::string, LoDTensor> create_inputs; std::unordered_map<std::string, LoDTensor> create_inputs;
......
...@@ -136,7 +136,7 @@ class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -136,7 +136,7 @@ class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<LoDTensor>)" "(vector<LoDTensor>)"
"which are the output of graph inside the CinnLaunchOp.") "which are the output of graph inside the CinnLaunchOp.")
.AsDuplicable(); .AsDuplicable();
AddAttr<std::string>( AddAttr<int64_t>(
kCompilationKey, kCompilationKey,
"(string)" "(string)"
"a hash key used to get the graph object or its computation result."); "a hash key used to get the graph object or its computation result.");
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h" #include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h" #include "paddle/fluid/operators/cinn/cinn_op_helper.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(enable_pe_launch_cinn); DECLARE_bool(enable_pe_launch_cinn);
namespace paddle { namespace paddle {
...@@ -61,13 +62,14 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -61,13 +62,14 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
const auto& scope = ctx.scope(); const auto& scope = ctx.scope();
const auto& place = ctx.GetPlace(); const auto& place = ctx.GetPlace();
void* stream = details::GetStream<DeviceContext>(ctx); void* stream = details::GetStream<DeviceContext>(ctx);
platform::RecordEvent record_event_1(
"Step 1. Find graph object and prepare input");
// Step 1. Find graph object and prepare input // Step 1. Find graph object and prepare input
PADDLE_ENFORCE_EQ(ctx.HasAttr(kCompilationKey), true, PADDLE_ENFORCE_EQ(ctx.HasAttr(kCompilationKey), true,
platform::errors::NotFound( platform::errors::NotFound(
"No Attribute(%s) found for CinnLaunchOp operator.", "No Attribute(%s) found for CinnLaunchOp operator.",
kCompilationKey)); kCompilationKey));
const auto& compilation_key = const auto& compilation_key = ctx.template Attr<int64_t>(kCompilationKey);
ctx.template Attr<std::string>(kCompilationKey);
VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") " VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") "
<< "value:\n" << "value:\n"
<< CinnCompiler::GetInstance()->ReadableKey(compilation_key); << CinnCompiler::GetInstance()->ReadableKey(compilation_key);
...@@ -100,6 +102,8 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -100,6 +102,8 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
input_no_need_buffer_tensors); input_no_need_buffer_tensors);
} }
platform::RecordEvent record_event_2(
"Step 2. Get compilation result of the graph");
// Step 2. Get compilation result of the graph // Step 2. Get compilation result of the graph
auto target = details::PlaceToCinnTarget(place); auto target = details::PlaceToCinnTarget(place);
using ClockType = std::chrono::steady_clock; using ClockType = std::chrono::steady_clock;
...@@ -120,17 +124,22 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -120,17 +124,22 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
details::DebugCinnCompiledResult(cinn_compiled_object); details::DebugCinnCompiledResult(cinn_compiled_object);
auto* launch_context = cinn_compiled_object.launch_context.get(); auto* launch_context = cinn_compiled_object.launch_context.get();
platform::RecordEvent record_event_3("Step 3. Set CINN runtime FLAGS.");
// Step 3. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic. // Step 3. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
details::SetCinnRuntimeFlags(); details::SetCinnRuntimeFlags();
// Step 4. Execute the compiled CINN instructions by a PE or // Step 4. Execute the compiled CINN instructions by a PE or
// by the CINN compiled program in sequential order // by the CINN compiled program in sequential order
if (FLAGS_enable_pe_launch_cinn) { if (FLAGS_enable_pe_launch_cinn) {
platform::RecordEvent record_event_4(
"Step 4. Execute the runtime graph by PE.");
VLOG(4) << "Execute the runtime graph by PE"; VLOG(4) << "Execute the runtime graph by PE";
framework::Scope& exec_scope = scope.NewScope(); framework::Scope& exec_scope = scope.NewScope();
auto* pe = launch_context->InitializePE(place, &exec_scope); auto* pe = launch_context->InitializePE(place, &exec_scope);
pe->RunWithoutFetch(launch_context->GetSkipEagerVars()); pe->RunWithoutFetch(launch_context->GetSkipEagerVars());
} else { } else {
platform::RecordEvent record_event_4(
"Step 4. Execute the compiled executable program.");
VLOG(4) << "Execute the compiled executable program"; VLOG(4) << "Execute the compiled executable program";
launch_context->UpdateCapturedEnv(scope, place); launch_context->UpdateCapturedEnv(scope, place);
LaunchCinnExecution(cinn_compiled_object, *launch_context, stream); LaunchCinnExecution(cinn_compiled_object, *launch_context, stream);
......
...@@ -166,6 +166,10 @@ limitations under the License. */ ...@@ -166,6 +166,10 @@ limitations under the License. */
#include "paddle/fluid/pybind/fleet_py.h" #include "paddle/fluid/pybind/fleet_py.h"
#endif #endif
#ifdef PADDLE_WITH_CINN
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#endif
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/imperative/layout_autotune.h" #include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/eager_utils.h"
...@@ -1930,16 +1934,18 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1930,16 +1934,18 @@ All parameter, weight, gradient are variables in Paddle.
which contains the id pair of pruned block and corresponding which contains the id pair of pruned block and corresponding
origin block. origin block.
)DOC"); )DOC");
m.def("get_readable_comile_key", [](const OpDesc &op_desc) { m.def("get_serialize_comile_key", [](int64_t compilation_key) {
auto compilation_key = #ifdef PADDLE_WITH_CINN
BOOST_GET_CONST(std::string, op_desc.GetAttr("compilation_key")); auto compiler = framework::paddle2cinn::CinnCompiler::GetInstance();
VLOG(4) << std::hash<std::string>{}(compilation_key) << " " auto s = compiler->SerializeKey(compilation_key);
<< compilation_key.size();
proto::ProgramDesc desc;
desc.ParseFromString(compilation_key);
auto s = desc.DebugString();
VLOG(4) << s; VLOG(4) << s;
return s; return s;
#else
PADDLE_THROW(
platform::errors::PermissionDenied(
"Cannot get compilation key in non-CINN version, "
"Please recompile or reinstall Paddle with CINN support."));
#endif
}); });
m.def("empty_var_name", m.def("empty_var_name",
[]() { return std::string(framework::kEmptyVarName); }); []() { return std::string(framework::kEmptyVarName); });
......
...@@ -2864,9 +2864,10 @@ class Operator(object): ...@@ -2864,9 +2864,10 @@ class Operator(object):
continue continue
# it is bytes of serialized protobuf # it is bytes of serialized protobuf
if self.type == 'cinn_launch' and name == 'compilation_key': if is_compiled_with_cinn(
# value = core.get_readable_comile_key(self.desc) ) and self.type == 'cinn_launch' and name == 'compilation_key':
v = self.desc.attr(name) key = self.desc.attr(name)
v = core.get_serialize_comile_key(key)
prog = Program() prog = Program()
prog = prog.parse_from_string(v) prog = prog.parse_from_string(v)
s = prog._to_readable_code() s = prog._to_readable_code()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册