From 21c6eccf3d9281d66126fbcd7a78deedbf84c97a Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Fri, 24 Feb 2023 10:26:45 +0800
Subject: [PATCH] [CINN]Enhance CacheKey hash logic by considering input
 dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng
---
 .../framework/paddle2cinn/CMakeLists.txt      | 16 ++--
 .../framework/paddle2cinn/cinn_cache_key.cc   | 77 ++++++++++++++-----
 .../framework/paddle2cinn/cinn_cache_key.h    |  8 +-
 .../paddle2cinn/cinn_cache_key_test.cc        | 37 ++++++---
 .../framework/paddle2cinn/cinn_compiler.cc    |  6 +-
 .../framework/paddle2cinn/transform_desc.cc   | 27 +++++++
 .../framework/paddle2cinn/transform_desc.h    |  4 +
 .../paddle2cinn/transform_desc_test.cc        | 10 +++
 .../framework/paddle2cinn/transform_type.cc   | 49 ++++++++++++
 .../framework/paddle2cinn/transform_type.h    |  5 ++
 .../paddle2cinn/transform_type_test.cc        | 36 +++++++++
 11 files changed, 236 insertions(+), 39 deletions(-)
 mode change 100755 => 100644 paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc

diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt
index 6bb21c569b3..a082dff6e54 100644
--- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt
+++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt
@@ -8,14 +8,6 @@ pass_library(
   errors
   enforce)
 
-cc_library(
-  cinn_cache_key
-  SRCS cinn_cache_key.cc
-  DEPS graph graph_helper lod_tensor proto_desc)
-cc_library(
-  cinn_subgraph_detector
-  SRCS cinn_subgraph_detector.cc
-  DEPS graph graph_helper subgraph_detector lod_tensor proto_desc)
 cc_library(
   transform_desc
   SRCS transform_desc.cc
@@ -24,6 +16,14 @@ cc_library(
   transform_type
   SRCS transform_type.cc
   DEPS errors enforce cinn)
+cc_library(
+  cinn_cache_key
+  SRCS cinn_cache_key.cc
+  DEPS graph graph_helper lod_tensor proto_desc transform_type)
+cc_library(
+  cinn_subgraph_detector
+  SRCS cinn_subgraph_detector.cc
+  DEPS graph graph_helper subgraph_detector lod_tensor proto_desc)
 cc_library(
   cinn_graph_symbolization
   SRCS cinn_graph_symbolization.cc
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
old mode 100755
new mode 100644
index 3a7aa273f27..f8b518452e7
--- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
@@ -23,6 +23,7 @@
 
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
 #include "paddle/phi/core/ddim.h"
 
 namespace paddle {
@@ -45,10 +46,11 @@ CinnCacheKey::CinnCacheKey(
 CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
                            const std::map<std::string, DDim>& input_shapes,
+                           const std::map<std::string, phi::DataType>& input_dtypes,
                            const std::string& arch_str,
                            GraphHashStrategy graph_hash)
     : graph_hash_(graph_hash) {
-  this->SetKey(graph, input_shapes, arch_str);
+  this->SetKey(graph, input_shapes, input_dtypes, arch_str);
 }
 
 void CinnCacheKey::SetKey(
@@ -58,15 +60,24 @@ void CinnCacheKey::SetKey(
   graph_hash_val_ = graph_hash_(graph);
   for (const auto& name_tensor : input_tensors) {
     input_shapes_[name_tensor.first] = name_tensor.second->dims();
+    input_dtypes_[name_tensor.first] = name_tensor.second->dtype();
   }
   arch_str_ = arch_str;
 }
 
 void CinnCacheKey::SetKey(const ir::Graph& graph,
                           const std::map<std::string, DDim>& input_shapes,
+                          const std::map<std::string, phi::DataType>& input_dtypes,
                          const std::string& arch_str) {
+  PADDLE_ENFORCE_EQ(
+      input_shapes.size(),
+      input_dtypes.size(),
+      platform::errors::PreconditionNotMet(
+          "Required input_shapes to have the same length as input_dtypes."));
+
   graph_hash_val_ = graph_hash_(graph);
   input_shapes_ = input_shapes;
+  input_dtypes_ = input_dtypes;
   arch_str_ = arch_str;
 }
 
@@ -76,19 +87,26 @@ bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {
 
 bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
   return graph_hash_val_ == other.graph_hash_val_ &&
-         input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_;
+         input_shapes_ == other.input_shapes_ &&
+         input_dtypes_ == other.input_dtypes_ && arch_str_ == other.arch_str_;
 }
 
 size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
   std::ostringstream has_str;
   for (const auto& name_shape : key.input_shapes_) {
-    has_str << name_shape.first;
-    has_str << std::hash<phi::DDim>()(name_shape.second);
+    has_str << name_shape.first << ",";
+    has_str << "[" << name_shape.second << "],";
+    PADDLE_ENFORCE_NE(key.input_dtypes_.find(name_shape.first),
+                      key.input_dtypes_.end(),
+                      platform::errors::PreconditionNotMet(
+                          "%s is not in key.input_dtypes_.", name_shape.first));
+    has_str << key.input_dtypes_.at(name_shape.first) << ";";
   }
 
+  has_str << key.arch_str_ << ",";
   has_str << key.graph_hash_val_;
-  has_str << key.arch_str_;
+  VLOG(1) << "CinnCacheKey : " << has_str.str();
   return std::hash<std::string>()(has_str.str());
 }
@@ -101,24 +119,45 @@ size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
 
   // graph.Nodes() returns an unordered_set; use an ordered set here so that
   // the same graph cannot produce different hash results
-  std::set<ir::Node *, decltype(compare)> node_set(compare),
-      output_set(compare);
-  node_set.insert(graph.Nodes().begin(), graph.Nodes().end());
-
-  std::string hash_str;
-  for (ir::Node* n : node_set) {
-    hash_str.append(n->Name());
-
-    output_set.clear();
-    output_set.insert(n->outputs.begin(), n->outputs.end());
-    for (auto* out : output_set) {
-      hash_str.append(out->Name());
+  std::set<ir::Node *, decltype(compare)> node_set(compare);
+  for (ir::Node* node : graph.Nodes()) {
+    if (node->IsOp()) {
+      // only op nodes need to be hashed: graphs with the same ops share a key
+      node_set.insert(node);
+    }
+  }
+
+  static std::unordered_set<std::string> ignore_attr = {"op_callstack",
+                                                        "op_device",
+                                                        "op_namescope",
+                                                        "op_role",
+                                                        "op_role_var",
+                                                        "with_quant_attr"};
+
+  std::ostringstream hash_str;
+  for (ir::Node* op : node_set) {
+    hash_str << op->Name() << ":";
+    hash_str << "input_num=" << op->inputs.size() << ",";
+    hash_str << "output_num=" << op->outputs.size() << ",";
+
+    const auto& attrs_unordered_map = op->Op()->GetAttrMap();
+    std::map<std::string, Attribute> attrs_map(attrs_unordered_map.begin(),
+                                               attrs_unordered_map.end());
+    for (const auto& attr : attrs_map) {
+      if (ignore_attr.count(attr.first)) {
+        continue;
+      }
+      const auto& attr_str = PaddleAttributeToString(attr.second);
+      if (!attr_str.empty()) {
+        hash_str << attr.first << "=" << attr_str << ",";
+      }
     }
+    hash_str << ";";
   }
 
-  VLOG(1) << "The hash graph:\n" << hash_str;
+  VLOG(1) << "The hash graph:\n" << hash_str.str();
 
-  size_t hash_val = std::hash<std::string>()(hash_str);
+  size_t hash_val = std::hash<std::string>()(hash_str.str());
   VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
   return hash_val;
 }
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h
index 008a7f4579f..d1797ddf6bb 100644
--- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h
+++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h
@@ -19,6 +19,7 @@
"paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/ddim.h" namespace paddle { @@ -45,6 +46,7 @@ class CinnCacheKey { GraphHashStrategy graph_hash); CinnCacheKey(const ir::Graph& graph, const std::map& input_shapes, + const std::map& input_dtypes, const std::string& arch_str, GraphHashStrategy graph_hash); @@ -56,6 +58,7 @@ class CinnCacheKey { const std::string& arch_str); void SetKey(const ir::Graph& graph, const std::map& input_shapes, + const std::map& input_dtypes, const std::string& arch_str); bool operator==(const CinnCacheKey& other) const; @@ -69,6 +72,7 @@ class CinnCacheKey { GraphHashStrategy graph_hash_; size_t graph_hash_val_; std::map input_shapes_; + std::map input_dtypes_; std::string arch_str_; }; @@ -84,8 +88,10 @@ class CinnCacheKey { \ NAME(const ir::Graph& graph, \ const std::map& input_shapes, \ + const std::map& input_dtypes, \ const std::string& arch_str) \ - : CinnCacheKey(graph, input_shapes, arch_str, HashGraph) {} \ + : CinnCacheKey( \ + graph, input_shapes, input_dtypes, arch_str, HashGraph) {} \ \ private: \ static size_t HashGraph(const ir::Graph& graph); \ diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc index 959ebf8aadb..cd2da68a7f6 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc @@ -39,7 +39,9 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByStructure) { x->SetType(proto::VarType::LOD_TENSOR); ir::Graph graph(program); + DataType fp32 = DataType::FLOAT32; phi::DenseTensor tensor; + tensor.set_type(fp32); tensor.Resize({1, 2, 3}); const phi::DenseTensor *tensor_pointer = &tensor; std::map feed_tensors = { @@ -47,21 +49,25 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByStructure) { DDim ddim = phi::make_ddim({1, 2, 3}); std::map feed_shapes = {{"X", ddim}}; + std::map feed_dtypes = {{"X", fp32}}; CinnCacheKeyByStructure cache_key0(empty_graph, feed_tensors, "x86"); - CinnCacheKeyByStructure cache_key1(empty_graph, feed_shapes, "x86"); + CinnCacheKeyByStructure cache_key1( + empty_graph, feed_shapes, feed_dtypes, "x86"); EXPECT_EQ(cache_key0, cache_key1); - CinnCacheKeyByStructure cache_key2(graph, feed_shapes, "x86"); - CinnCacheKeyByStructure cache_key3(graph, feed_shapes, "nvgpu"); + CinnCacheKeyByStructure cache_key2(graph, feed_shapes, feed_dtypes, "x86"); + CinnCacheKeyByStructure cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu"); CinnCacheKeyByStructure cache_key4(graph, feed_tensors, "nvgpu"); EXPECT_NE(cache_key2, cache_key3); EXPECT_EQ(cache_key3, cache_key4); CinnCacheKeyByStructure cache_key5( empty_graph, std::map(), "unk"); - CinnCacheKeyByStructure cache_key6( - empty_graph, std::map(), "unk"); + CinnCacheKeyByStructure cache_key6(empty_graph, + std::map(), + std::map(), + "unk"); EXPECT_EQ(cache_key5, cache_key6); EXPECT_NE(cache_key1, cache_key3); @@ -112,6 +118,7 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) { x->SetType(proto::VarType::LOD_TENSOR); ir::Graph graph(program); + DataType fp32 = DataType::FLOAT32; phi::DenseTensor tensor; tensor.Resize({1, 2, 3}); const phi::DenseTensor *tensor_pointer = &tensor; @@ -120,21 +127,29 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) { DDim ddim = phi::make_ddim({1, 2, 3}); std::map feed_shapes = {{"X", ddim}}; + std::map feed_dtypes = {{"X", fp32}}; + std::map new_dtypes = {{"X", DataType::FLOAT64}}; 
 
   CinnCacheKeyByAddress cache_key0(empty_graph, feed_tensors, "x86");
-  CinnCacheKeyByAddress cache_key1(empty_graph, feed_shapes, "x86");
+  CinnCacheKeyByAddress cache_key1(
+      empty_graph, feed_shapes, feed_dtypes, "x86");
   EXPECT_EQ(cache_key0, cache_key1);
 
+  CinnCacheKeyByAddress cache_key7(empty_graph, feed_shapes, new_dtypes, "x86");
+  EXPECT_NE(cache_key1, cache_key7);
+
-  CinnCacheKeyByAddress cache_key2(graph, feed_shapes, "x86");
-  CinnCacheKeyByAddress cache_key3(graph, feed_shapes, "nvgpu");
+  CinnCacheKeyByAddress cache_key2(graph, feed_shapes, feed_dtypes, "x86");
+  CinnCacheKeyByAddress cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu");
   CinnCacheKeyByAddress cache_key4(graph, feed_tensors, "nvgpu");
   EXPECT_NE(cache_key2, cache_key3);
   EXPECT_EQ(cache_key3, cache_key4);
 
   CinnCacheKeyByAddress cache_key5(
       empty_graph, std::map<std::string, const phi::DenseTensor *>(), "unk");
-  CinnCacheKeyByAddress cache_key6(
-      empty_graph, std::map<std::string, DDim>(), "unk");
+  CinnCacheKeyByAddress cache_key6(empty_graph,
+                                   std::map<std::string, DDim>(),
+                                   std::map<std::string, DataType>(),
+                                   "unk");
   EXPECT_EQ(cache_key5, cache_key6);
 
   EXPECT_NE(cache_key1, cache_key3);
@@ -186,7 +201,9 @@ TEST(CinnCacheKeyTest, TestSameGraph) {
   x2->SetType(proto::VarType::LOD_TENSOR);
   ir::Graph graph2(program2);
 
+  DataType fp32 = DataType::FLOAT32;
   phi::DenseTensor tensor;
+  tensor.set_type(fp32);
   tensor.Resize({1, 2, 3});
   const phi::DenseTensor *tensor_pointer = &tensor;
   std::map<std::string, const phi::DenseTensor *> feed_tensors = {
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
index c01624a5549..359bab84430 100644
--- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
@@ -39,6 +39,7 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
+#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/inference/analysis/dot.h"
@@ -78,9 +79,11 @@ const CinnCompiledObject &CinnCompiler::Compile(
   CinnCacheKeyByStructure cur_key_by_struct;
 
   if (!cache_by_address_.count(cur_key_by_address)) {
+    VLOG(4) << "CinnCompiledObject not found in cache_by_address_.";
     // generate the structure cache key
     cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
     if (!cache_by_struct_.count(cur_key_by_struct)) {
+      VLOG(4) << "CinnCompiledObject not found in cache_by_struct_.";
       std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
       auto compiled_res =
           CompileGraph(graph, input_tensors, target, compiled_num, stream);
@@ -180,7 +183,8 @@ std::string CinnCompiler::VizGraph(const Graph &graph) const {
             shape.begin(), shape.end(), shape_str.begin(), [](const auto &val) {
               return std::to_string(val);
             });
-        label += "\n" + string::join_strings(shape_str, ',');
+        label += "\n[" + string::join_strings(shape_str, ',') + "]";
+        label += "\n" + VarDataTypeToString(n->Var()->GetDataType());
       }
       dot.AddNode(
           node_id,
diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.cc b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
index af2d1c06de5..732edff8283 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_desc.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_desc.cc
@@ -97,6 +97,33 @@ namespace cpp = ::cinn::frontend::paddle::cpp;
 #undef SET_DATA_TYPE_CASE_ITEM
 }
 
+std::string VarDataTypeToString(
+    const ::paddle::framework::proto::VarType::Type &type) {
+#define SET_DATA_TYPE_CASE_ITEM(type__)               \
+  case ::paddle::framework::proto::VarType::type__:   \
+    return std::string(#type__);                      \
+    break;
+
+  switch (type) {
+    SET_DATA_TYPE_CASE_ITEM(BOOL);
+    SET_DATA_TYPE_CASE_ITEM(SIZE_T);
+    SET_DATA_TYPE_CASE_ITEM(UINT8);
+    SET_DATA_TYPE_CASE_ITEM(INT8);
+    SET_DATA_TYPE_CASE_ITEM(INT16);
+    SET_DATA_TYPE_CASE_ITEM(INT32);
+    SET_DATA_TYPE_CASE_ITEM(INT64);
+    SET_DATA_TYPE_CASE_ITEM(FP16);
+    SET_DATA_TYPE_CASE_ITEM(FP32);
+    SET_DATA_TYPE_CASE_ITEM(FP64);
+    SET_DATA_TYPE_CASE_ITEM(BF16);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX64);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX128);
+    default:
+      PADDLE_THROW(platform::errors::NotFound("Cannot find var data type"));
+  }
+#undef SET_DATA_TYPE_CASE_ITEM
+}
+
 ::paddle::framework::proto::VarType::Type TransformVarDataTypeFromCpp(
     const ::cinn::frontend::paddle::cpp::VarDescAPI::Type &type) {
 #define SET_DATA_TYPE_CASE_ITEM(type__) \
diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.h b/paddle/fluid/framework/paddle2cinn/transform_desc.h
index 76a4f812730..2e6fd3755f5 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_desc.h
+++ b/paddle/fluid/framework/paddle2cinn/transform_desc.h
@@ -74,6 +74,10 @@ void TransformProgramDescFromCinn(
     const ::cinn::frontend::paddle::cpp::ProgramDesc& cpp_desc,
     framework::ProgramDesc* pb_desc);
 
+// debug function
+std::string VarDataTypeToString(
+    const ::paddle::framework::proto::VarType::Type& type);
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc_test.cc b/paddle/fluid/framework/paddle2cinn/transform_desc_test.cc
index 1ee108f566f..153794b5d44 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_desc_test.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_desc_test.cc
@@ -232,6 +232,16 @@ TEST(TransformProgramDesc, pb2cpp) {
   ASSERT_EQ(cpp_prog.BlocksSize(), correct_prog.BlocksSize());
 }
 
+TEST(HelperFunction, VarDataTypeToString) {
+  const auto &pd_fp32_var = CreatePbVarDesc();
+  const auto &debug_fp32_string =
+      VarDataTypeToString(pd_fp32_var.GetDataType());
+  ASSERT_EQ(debug_fp32_string, std::string("FP32"));
+
+  ASSERT_EQ(VarDataTypeToString(::paddle::framework::proto::VarType::INT32),
+            std::string("INT32"));
+}
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/paddle2cinn/transform_type.cc b/paddle/fluid/framework/paddle2cinn/transform_type.cc
index b0877b618bb..0d4d5b57b57 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_type.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_type.cc
@@ -18,6 +18,7 @@
 #include "cinn/runtime/cinn_runtime.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
+#include "paddle/utils/string/string_helper.h"
 
 namespace paddle::framework::paddle2cinn {
 
@@ -78,4 +79,52 @@
 #undef SET_TYPE_CASE_ITEM
 }
 
+std::string PaddleAttributeToString(const framework::Attribute& attr) {
+  std::ostringstream ss;
+#define EXPAND_ATTRIBUTE_MACRO(TYPE_)                              \
+  if (attr.type() == typeid(TYPE_)) {                              \
+    ss << PADDLE_GET_CONST(TYPE_, attr);                           \
+    return ss.str();                                               \
+  }                                                                \
+  if (attr.type() == typeid(std::vector<TYPE_>)) {                 \
+    const auto& vals = PADDLE_GET_CONST(std::vector<TYPE_>, attr); \
+    if (!vals.empty()) {                                           \
+      ss << "[" << string::join_strings(vals, ", ") << "]";        \
+    }                                                              \
+    return ss.str();                                               \
+  }
+
+  if (attr.type() == typeid(bool)) {
+    ss << std::boolalpha << PADDLE_GET_CONST(bool, attr);
+    return ss.str();
+  }
+  if (attr.type() == typeid(std::vector<bool>)) {
+    // string::join_strings fails to compile for std::vector<bool>:
+    // cannot bind non-const lvalue reference of type ‘bool&’
+    const auto& vals = PADDLE_GET_CONST(std::vector<bool>, attr);
+    if (!vals.empty()) {
+      ss << "[";
+      bool first_value = true;
+      for (bool val : vals) {
+        if (!first_value) {
+          ss << ", ";
+        }
+        first_value = false;
+        ss << std::boolalpha << val;
+      }
+      ss << "]";
+    }
+    return ss.str();
+  }
+  EXPAND_ATTRIBUTE_MACRO(std::string)
+  EXPAND_ATTRIBUTE_MACRO(int)
+  EXPAND_ATTRIBUTE_MACRO(float)
+  EXPAND_ATTRIBUTE_MACRO(int64_t)
+  EXPAND_ATTRIBUTE_MACRO(double)
+
+  ss << "Unknown_Dtype:" << attr.type().name();
+#undef EXPAND_ATTRIBUTE_MACRO
+  return ss.str();
+}
+
 }  // namespace paddle::framework::paddle2cinn
diff --git a/paddle/fluid/framework/paddle2cinn/transform_type.h b/paddle/fluid/framework/paddle2cinn/transform_type.h
index f0b08ba1e00..53ee2ca9735 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_type.h
+++ b/paddle/fluid/framework/paddle2cinn/transform_type.h
@@ -13,6 +13,9 @@
 // limitations under the License.
 
 #pragma once
+#include <string>
+
+#include "paddle/fluid/framework/type_defs.h"
 #include "paddle/phi/common/data_type.h"
 
 // type declaration forward
@@ -27,4 +30,6 @@ namespace paddle::framework::paddle2cinn {
 
 ::phi::DataType TransToPaddleDataType(const cinn_type_t& type);
 
+std::string PaddleAttributeToString(const framework::Attribute& attr);
+
 }  // namespace paddle::framework::paddle2cinn
diff --git a/paddle/fluid/framework/paddle2cinn/transform_type_test.cc b/paddle/fluid/framework/paddle2cinn/transform_type_test.cc
index 4456642b3e9..365c25b0544 100644
--- a/paddle/fluid/framework/paddle2cinn/transform_type_test.cc
+++ b/paddle/fluid/framework/paddle2cinn/transform_type_test.cc
@@ -62,4 +62,40 @@ TEST(TransToPaddleDataType, runtime_type) {
                paddle::platform::EnforceNotMet);
 }
 
+TEST(HelperFunction, PaddleAttributeToStringPODValue) {
+  paddle::framework::Attribute attr1 = 1;
+  ASSERT_EQ(PaddleAttributeToString(attr1), std::string("1"));
+
+  paddle::framework::Attribute attr2 = 0.2f;
+  ASSERT_EQ(PaddleAttributeToString(attr2), std::string("0.2"));
+
+  paddle::framework::Attribute attr3 = true;
+  ASSERT_EQ(PaddleAttributeToString(attr3), std::string("true"));
+
+  paddle::framework::Attribute attr4 = std::string("string_attribute");
+  ASSERT_EQ(PaddleAttributeToString(attr4), std::string("string_attribute"));
+}
+
+TEST(HelperFunction, PaddleAttributeToStringVectorValue) {
+  paddle::framework::Attribute attr1 = std::vector<int>();
+  ASSERT_EQ(PaddleAttributeToString(attr1), std::string(""));
+
+  paddle::framework::Attribute attr2 = std::vector<int>{1, 2, 3, 4, 5};
+  ASSERT_EQ(PaddleAttributeToString(attr2), std::string("[1, 2, 3, 4, 5]"));
+
+  paddle::framework::Attribute attr3 =
+      std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
+  ASSERT_EQ(PaddleAttributeToString(attr3),
+            std::string("[0.1, 0.2, 0.3, 0.4, 0.5]"));
+
+  paddle::framework::Attribute attr4 =
+      std::vector<bool>{true, false, true, false, false};
+  ASSERT_EQ(PaddleAttributeToString(attr4),
+            std::string("[true, false, true, false, false]"));
+
+  paddle::framework::Attribute attr5 =
+      std::vector<std::string>{"a", "b", "c", "d", "e"};
+  ASSERT_EQ(PaddleAttributeToString(attr5), std::string("[a, b, c, d, e]"));
+}
+
 }  // namespace paddle::framework::paddle2cinn
-- 
GitLab
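
Why dtype must be part of the key: before this patch the cache key hashed only the input names and shapes, the graph hash, and the arch string, so two runs whose feeds differ only in dtype (say an FP32 versus an FP64 "X" of shape [1, 2, 3]) produced the same key and reused the same compiled object. The sketch below mirrors the serialize-then-hash scheme of CinnCacheKey::Hash in plain C++17 with no Paddle dependencies; ShapeAndDtype and MakeCacheFingerprint are illustrative names only, not part of the patch.

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

struct ShapeAndDtype {
  std::vector<int64_t> shape;
  std::string dtype;  // e.g. "FP32", mirroring VarDataTypeToString output
};

// Serialize every input's name, shape, and dtype, plus the target arch,
// into one string and hash that string, the same idea CinnCacheKey::Hash
// uses above.
std::size_t MakeCacheFingerprint(
    const std::map<std::string, ShapeAndDtype>& inputs,
    const std::string& arch) {
  std::ostringstream os;
  for (const auto& [name, info] : inputs) {
    os << name << ",[";
    for (std::size_t i = 0; i < info.shape.size(); ++i) {
      os << (i ? "," : "") << info.shape[i];
    }
    os << "]," << info.dtype << ";";
  }
  os << arch;
  return std::hash<std::string>()(os.str());
}

int main() {
  std::map<std::string, ShapeAndDtype> fp32 = {{"X", {{1, 2, 3}, "FP32"}}};
  std::map<std::string, ShapeAndDtype> fp64 = {{"X", {{1, 2, 3}, "FP64"}}};
  // Same shapes, different dtypes: with dtype folded into the serialized
  // key the two fingerprints differ instead of colliding.
  std::cout << (MakeCacheFingerprint(fp32, "x86") !=
                MakeCacheFingerprint(fp64, "x86"))
            << "\n";  // prints 1
}

This is exactly the behavior the new cache_key7 / EXPECT_NE case in cinn_cache_key_test.cc asserts.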
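A second subtlety is the "using ordered map for attributes" item in the commit message: OpDesc::GetAttrMap() returns an unordered_map, and serializing it directly would make the hash string depend on iteration order, so HashGraph first copies the attributes into a std::map. A minimal sketch of that pattern follows (standard C++ only; serialize_attrs is a hypothetical helper, and int stands in for the real Attribute variant).

#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>

// Copy the unordered attribute map into a std::map so iteration order is
// well-defined (sorted by key) before serializing; this mirrors the
// attrs_map copy in CinnCacheKeyByStructure::HashGraph.
std::string serialize_attrs(
    const std::unordered_map<std::string, int>& attrs) {
  std::map<std::string, int> ordered(attrs.begin(), attrs.end());
  std::ostringstream os;
  for (const auto& [name, value] : ordered) {
    os << name << "=" << value << ",";
  }
  return os.str();
}

int main() {
  // However the unordered_map happens to bucket these entries, the
  // serialized form is always "axis=1,use_cudnn=0," -- a stable hash input.
  std::cout << serialize_attrs({{"use_cudnn", 0}, {"axis", 1}}) << "\n";
}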