From f78b4079b40bb85b083ab2ef853c87baab7c3f95 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Mon, 27 Feb 2023 19:52:02 +0800 Subject: [PATCH] [CINN] fix cinn cache key should save var name bug (#50955) --- .../framework/paddle2cinn/cinn_cache_key.cc | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc index f8b518452e7..6fdbbaae9d7 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc @@ -25,6 +25,7 @@ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/paddle2cinn/transform_type.h" #include "paddle/phi/core/ddim.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { @@ -127,18 +128,33 @@ size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) { } } - static std::unordered_set ignore_attr = {"op_callstack", - "op_device", - "op_namescope", - "op_role", - "op_role_var", - "with_quant_attr"}; + static const std::unordered_set ignore_attr = { + "op_callstack", + "op_device", + "op_namescope", + "op_role", + "op_role_var", + "with_quant_attr"}; + + std::set input_set(compare), + output_set(compare); std::ostringstream hash_str; for (ir::Node* op : node_set) { hash_str << op->Name() << ":"; - hash_str << "input_num=" << op->inputs.size() << ","; - hash_str << "output_num=" << op->outputs.size() << ","; + input_set.clear(); + input_set.insert(op->inputs.begin(), op->inputs.end()); + hash_str << "inputs=[" + << string::join_strings( + input_set, ",", [](ir::Node* n) { return n->Name(); }) + << "],"; + + output_set.clear(); + output_set.insert(op->outputs.begin(), op->outputs.end()); + hash_str << "outputs=[" + << string::join_strings( + output_set, ",", [](ir::Node* n) { return n->Name(); }) + << "],"; const auto& attrs_unordered_map = op->Op()->GetAttrMap(); std::map attrs_map(attrs_unordered_map.begin(), -- GitLab