Unverified · Commit 21c6eccf authored by Aurelius84, committed by GitHub

[CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------
Co-authored-by: jiangcheng <thisjiang@qq.com>
Parent: 6e37a2c0
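To make the change concrete before diving into the diff: a cached CINN compilation is only reusable when the feed tensors' dtypes match, not just their shapes, so the dtypes must participate in both CinnCacheKey::operator== and the hash. A minimal, self-contained sketch of the idea (Shape and DType are simplified stand-ins for phi::DDim and phi::DataType, not the actual Paddle types):

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for phi::DDim
enum class DType { FP32, FP64 };     // stand-in for phi::DataType

// Mirrors the patched Hash::operator(): serialize name, shape, and (newly)
// dtype for every input, plus arch and graph hash, then hash the string once.
size_t CacheKeyHash(const std::map<std::string, Shape>& shapes,
                    const std::map<std::string, DType>& dtypes,
                    const std::string& arch,
                    size_t graph_hash) {
  std::ostringstream ss;
  for (const auto& [name, shape] : shapes) {
    ss << name << ",[";
    for (int64_t d : shape) ss << d << ' ';
    ss << "],";
    ss << static_cast<int>(dtypes.at(name)) << ";";  // the new ingredient
  }
  ss << arch << "," << graph_hash;
  return std::hash<std::string>()(ss.str());
}

int main() {
  std::map<std::string, Shape> shapes = {{"X", {1, 2, 3}}};
  // Same graph, same shapes, different dtype: without dtype in the key the
  // two keys collide, so (presumably) an FP64 feed could hit the program
  // compiled for FP32. With dtype included, the keys differ.
  size_t k32 = CacheKeyHash(shapes, {{"X", DType::FP32}}, "x86", 42);
  size_t k64 = CacheKeyHash(shapes, {{"X", DType::FP64}}, "x86", 42);
  std::cout << std::boolalpha << (k32 != k64) << '\n';  // true
}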
paddle/fluid/framework/paddle2cinn/CMakeLists.txt

@@ -8,14 +8,6 @@ pass_library(
   errors
   enforce)
-cc_library(
-  cinn_cache_key
-  SRCS cinn_cache_key.cc
-  DEPS graph graph_helper lod_tensor proto_desc)
-cc_library(
-  cinn_subgraph_detector
-  SRCS cinn_subgraph_detector.cc
-  DEPS graph graph_helper subgraph_detector lod_tensor proto_desc)
 cc_library(
   transform_desc
   SRCS transform_desc.cc
@@ -24,6 +16,14 @@ cc_library(
   transform_type
   SRCS transform_type.cc
   DEPS errors enforce cinn)
+cc_library(
+  cinn_cache_key
+  SRCS cinn_cache_key.cc
+  DEPS graph graph_helper lod_tensor proto_desc transform_type)
+cc_library(
+  cinn_subgraph_detector
+  SRCS cinn_subgraph_detector.cc
+  DEPS graph graph_helper subgraph_detector lod_tensor proto_desc)
 cc_library(
   cinn_graph_symbolization
   SRCS cinn_graph_symbolization.cc
...
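Note: the cinn_cache_key target is moved after transform_type and gains transform_type in its DEPS because the reworked HashGraph below calls PaddleAttributeToString, which this patch adds to transform_type.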
paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc

@@ -23,6 +23,7 @@
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
 #include "paddle/phi/core/ddim.h"

 namespace paddle {
@@ -45,10 +46,11 @@ CinnCacheKey::CinnCacheKey(
 CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
                            const std::map<std::string, DDim>& input_shapes,
+                           const std::map<std::string, DataType>& input_dtypes,
                            const std::string& arch_str,
                            GraphHashStrategy graph_hash)
     : graph_hash_(graph_hash) {
-  this->SetKey(graph, input_shapes, arch_str);
+  this->SetKey(graph, input_shapes, input_dtypes, arch_str);
 }

 void CinnCacheKey::SetKey(
@@ -58,15 +60,24 @@ void CinnCacheKey::SetKey(
   graph_hash_val_ = graph_hash_(graph);
   for (const auto& name_tensor : input_tensors) {
     input_shapes_[name_tensor.first] = name_tensor.second->dims();
+    input_dtypes_[name_tensor.first] = name_tensor.second->dtype();
   }
   arch_str_ = arch_str;
 }

 void CinnCacheKey::SetKey(const ir::Graph& graph,
                           const std::map<std::string, DDim>& input_shapes,
+                          const std::map<std::string, DataType>& input_dtypes,
                           const std::string& arch_str) {
+  PADDLE_ENFORCE_EQ(
+      input_shapes.size(),
+      input_dtypes.size(),
+      platform::errors::PreconditionNotMet(
+          "input_shapes is required to have the same length as input_dtypes."));
   graph_hash_val_ = graph_hash_(graph);
   input_shapes_ = input_shapes;
+  input_dtypes_ = input_dtypes;
   arch_str_ = arch_str;
 }

@@ -76,19 +87,26 @@ bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {
 bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
   return graph_hash_val_ == other.graph_hash_val_ &&
-         input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_;
+         input_shapes_ == other.input_shapes_ &&
+         input_dtypes_ == other.input_dtypes_ && arch_str_ == other.arch_str_;
 }

 size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
   std::ostringstream has_str;

   for (const auto& name_shape : key.input_shapes_) {
-    has_str << name_shape.first;
-    has_str << std::hash<phi::DDim>()(name_shape.second);
+    has_str << name_shape.first << ",";
+    has_str << "[" << name_shape.second << "],";
+    PADDLE_ENFORCE_NE(key.input_dtypes_.find(name_shape.first),
+                      key.input_dtypes_.end(),
+                      platform::errors::PreconditionNotMet(
+                          "%s is not in key.input_dtypes_.", name_shape.first));
+    has_str << key.input_dtypes_.at(name_shape.first) << ";";
   }

+  has_str << key.arch_str_ << ",";
   has_str << key.graph_hash_val_;
-  has_str << key.arch_str_;
+  VLOG(1) << "CinnCacheKey : " << has_str.str();
   return std::hash<std::string>()(has_str.str());
 }

@@ -101,24 +119,45 @@ size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
   // graph.Nodes() return unordered_set, here using set to avoid the same graph
   // may return different result
-  std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare),
-      output_set(compare);
-  node_set.insert(graph.Nodes().begin(), graph.Nodes().end());
-
-  std::string hash_str;
-  for (ir::Node* n : node_set) {
-    hash_str.append(n->Name());
-
-    output_set.clear();
-    output_set.insert(n->outputs.begin(), n->outputs.end());
-    for (auto* out : output_set) {
-      hash_str.append(out->Name());
-    }
-  }
+  std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare);
+  for (ir::Node* node : graph.Nodes()) {
+    if (node->IsOp()) {
+      // only need to cache graphs with the same ops
+      node_set.insert(node);
+    }
+  }
+  static std::unordered_set<std::string> ignore_attr = {"op_callstack",
+                                                        "op_device",
+                                                        "op_namescope",
+                                                        "op_role",
+                                                        "op_role_var",
+                                                        "with_quant_attr"};
+  std::ostringstream hash_str;
+  for (ir::Node* op : node_set) {
+    hash_str << op->Name() << ":";
+    hash_str << "input_num=" << op->inputs.size() << ",";
+    hash_str << "output_num=" << op->outputs.size() << ",";
+    const auto& attrs_unordered_map = op->Op()->GetAttrMap();
+    std::map<std::string, Attribute> attrs_map(attrs_unordered_map.begin(),
+                                               attrs_unordered_map.end());
+    for (const auto& attr : attrs_map) {
+      if (ignore_attr.count(attr.first)) {
+        continue;
+      }
+      const auto& attr_str = PaddleAttributeToString(attr.second);
+      if (!attr_str.empty()) {
+        hash_str << attr.first << "=" << attr_str << ",";
+      }
+    }
+    hash_str << ";";
+  }

-  VLOG(1) << "The hash graph:\n" << hash_str;
-  size_t hash_val = std::hash<std::string>()(hash_str);
+  VLOG(1) << "The hash graph:\n" << hash_str.str();
+  size_t hash_val = std::hash<std::string>()(hash_str.str());
   VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
   return hash_val;
 }
...
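A note on the "using ordered map for attributes" bullet from the commit message: GetAttrMap returns an unordered_map, whose iteration order is unspecified, so serializing it directly could hash the same op differently across processes or builds. Copying into a std::map first pins the order by attribute name. A small self-contained illustration of the trick (the attribute names are illustrative, not Paddle API):

#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>

// Serializing an unordered_map directly may produce different strings (and
// therefore different hashes) depending on bucket layout; copying into a
// std::map first makes the serialization deterministic.
std::string SerializeOrdered(
    const std::unordered_map<std::string, int>& attrs) {
  std::map<std::string, int> ordered(attrs.begin(), attrs.end());
  std::ostringstream ss;
  for (const auto& [name, value] : ordered) {
    ss << name << "=" << value << ",";
  }
  return ss.str();
}

int main() {
  std::unordered_map<std::string, int> attrs = {
      {"axis", 1}, {"use_mkldnn", 0}, {"trans_x", 0}};
  // Always "axis=1,trans_x=0,use_mkldnn=0," regardless of hash-table order.
  std::cout << SerializeOrdered(attrs) << "\n";
}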
paddle/fluid/framework/paddle2cinn/cinn_cache_key.h

@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/ddim.h"

 namespace paddle {
@@ -45,6 +46,7 @@ class CinnCacheKey {
                GraphHashStrategy graph_hash);

   CinnCacheKey(const ir::Graph& graph,
                const std::map<std::string, DDim>& input_shapes,
+               const std::map<std::string, DataType>& input_dtypes,
                const std::string& arch_str,
                GraphHashStrategy graph_hash);
@@ -56,6 +58,7 @@ class CinnCacheKey {
               const std::string& arch_str);
   void SetKey(const ir::Graph& graph,
               const std::map<std::string, DDim>& input_shapes,
+              const std::map<std::string, DataType>& input_dtypes,
               const std::string& arch_str);

   bool operator==(const CinnCacheKey& other) const;
@@ -69,6 +72,7 @@ class CinnCacheKey {
   GraphHashStrategy graph_hash_;
   size_t graph_hash_val_;
   std::map<std::string, DDim> input_shapes_;
+  std::map<std::string, DataType> input_dtypes_;
   std::string arch_str_;
 };
@@ -84,8 +88,10 @@ class CinnCacheKey {
                                                                    \
   NAME(const ir::Graph& graph,                                     \
        const std::map<std::string, DDim>& input_shapes,            \
+       const std::map<std::string, DataType>& input_dtypes,        \
        const std::string& arch_str)                                \
-      : CinnCacheKey(graph, input_shapes, arch_str, HashGraph) {}  \
+      : CinnCacheKey(                                              \
+            graph, input_shapes, input_dtypes, arch_str, HashGraph) {} \
                                                                    \
  private:                                                          \
   static size_t HashGraph(const ir::Graph& graph);                 \
...
paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc

@@ -39,7 +39,9 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByStructure) {
   x->SetType(proto::VarType::LOD_TENSOR);
   ir::Graph graph(program);

+  DataType fp32 = DataType::FLOAT32;
   phi::DenseTensor tensor;
+  tensor.set_type(fp32);
   tensor.Resize({1, 2, 3});
   const phi::DenseTensor *tensor_pointer = &tensor;
   std::map<std::string, const phi::DenseTensor *> feed_tensors = {
@@ -47,21 +49,25 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByStructure) {
   DDim ddim = phi::make_ddim({1, 2, 3});
   std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
+  std::map<std::string, DataType> feed_dtypes = {{"X", fp32}};

   CinnCacheKeyByStructure cache_key0(empty_graph, feed_tensors, "x86");
-  CinnCacheKeyByStructure cache_key1(empty_graph, feed_shapes, "x86");
+  CinnCacheKeyByStructure cache_key1(
+      empty_graph, feed_shapes, feed_dtypes, "x86");
   EXPECT_EQ(cache_key0, cache_key1);

-  CinnCacheKeyByStructure cache_key2(graph, feed_shapes, "x86");
-  CinnCacheKeyByStructure cache_key3(graph, feed_shapes, "nvgpu");
+  CinnCacheKeyByStructure cache_key2(graph, feed_shapes, feed_dtypes, "x86");
+  CinnCacheKeyByStructure cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu");
   CinnCacheKeyByStructure cache_key4(graph, feed_tensors, "nvgpu");
   EXPECT_NE(cache_key2, cache_key3);
   EXPECT_EQ(cache_key3, cache_key4);

   CinnCacheKeyByStructure cache_key5(
       empty_graph, std::map<std::string, const phi::DenseTensor *>(), "unk");
-  CinnCacheKeyByStructure cache_key6(
-      empty_graph, std::map<std::string, DDim>(), "unk");
+  CinnCacheKeyByStructure cache_key6(empty_graph,
+                                     std::map<std::string, DDim>(),
+                                     std::map<std::string, DataType>(),
+                                     "unk");
   EXPECT_EQ(cache_key5, cache_key6);

   EXPECT_NE(cache_key1, cache_key3);
@@ -112,6 +118,7 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) {
   x->SetType(proto::VarType::LOD_TENSOR);
   ir::Graph graph(program);

+  DataType fp32 = DataType::FLOAT32;
   phi::DenseTensor tensor;
   tensor.Resize({1, 2, 3});
   const phi::DenseTensor *tensor_pointer = &tensor;
@@ -120,21 +127,29 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) {
   DDim ddim = phi::make_ddim({1, 2, 3});
   std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
+  std::map<std::string, DataType> feed_dtypes = {{"X", fp32}};
+  std::map<std::string, DataType> new_dtypes = {{"X", DataType::FLOAT64}};

   CinnCacheKeyByAddress cache_key0(empty_graph, feed_tensors, "x86");
-  CinnCacheKeyByAddress cache_key1(empty_graph, feed_shapes, "x86");
+  CinnCacheKeyByAddress cache_key1(
+      empty_graph, feed_shapes, feed_dtypes, "x86");
   EXPECT_EQ(cache_key0, cache_key1);

+  CinnCacheKeyByAddress cache_key7(empty_graph, feed_shapes, new_dtypes, "x86");
+  EXPECT_NE(cache_key1, cache_key7);
+
-  CinnCacheKeyByAddress cache_key2(graph, feed_shapes, "x86");
-  CinnCacheKeyByAddress cache_key3(graph, feed_shapes, "nvgpu");
+  CinnCacheKeyByAddress cache_key2(graph, feed_shapes, feed_dtypes, "x86");
+  CinnCacheKeyByAddress cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu");
   CinnCacheKeyByAddress cache_key4(graph, feed_tensors, "nvgpu");
   EXPECT_NE(cache_key2, cache_key3);
   EXPECT_EQ(cache_key3, cache_key4);

   CinnCacheKeyByAddress cache_key5(
       empty_graph, std::map<std::string, const phi::DenseTensor *>(), "unk");
-  CinnCacheKeyByAddress cache_key6(
-      empty_graph, std::map<std::string, DDim>(), "unk");
+  CinnCacheKeyByAddress cache_key6(empty_graph,
+                                   std::map<std::string, DDim>(),
+                                   std::map<std::string, DataType>(),
+                                   "unk");
   EXPECT_EQ(cache_key5, cache_key6);

   EXPECT_NE(cache_key1, cache_key3);
@@ -186,7 +201,9 @@ TEST(CinnCacheKeyTest, TestSameGraph) {
   x2->SetType(proto::VarType::LOD_TENSOR);
   ir::Graph graph2(program2);

+  DataType fp32 = DataType::FLOAT32;
   phi::DenseTensor tensor;
+  tensor.set_type(fp32);
   tensor.Resize({1, 2, 3});
   const phi::DenseTensor *tensor_pointer = &tensor;
   std::map<std::string, const phi::DenseTensor *> feed_tensors = {
...
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc

@@ -39,6 +39,7 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
+#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/inference/analysis/dot.h"
@@ -78,9 +79,11 @@ const CinnCompiledObject &CinnCompiler::Compile(
   CinnCacheKeyByStructure cur_key_by_struct;

   if (!cache_by_address_.count(cur_key_by_address)) {
+    VLOG(4) << "Not found CinnCompiledObject in cache_by_address_.";
     // generate the structure cache key
     cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
     if (!cache_by_struct_.count(cur_key_by_struct)) {
+      VLOG(4) << "Not found CinnCompiledObject in cache_by_struct_.";
       std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
       auto compiled_res =
           CompileGraph(graph, input_tensors, target, compiled_num, stream);
@@ -180,7 +183,8 @@ std::string CinnCompiler::VizGraph(const Graph &graph) const {
         shape.begin(), shape.end(), shape_str.begin(), [](const auto &val) {
           return std::to_string(val);
         });
-    label += "\n" + string::join_strings(shape_str, ',');
+    label += "\n[" + string::join_strings(shape_str, ',') + "]";
+    label += "\n" + VarDataTypeToString(n->Var()->GetDataType());
   }
   dot.AddNode(
       node_id,
...
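For context on the two VLOG(4) lines: Compile keeps two caches, a cheap one keyed by graph address and a fallback keyed by graph structure, and only recompiles when both miss. A hedged, self-contained sketch of that control flow (string keys and a trivial CompiledObject stand in for the real CinnCacheKey types; this is an illustration, not the actual implementation):

#include <iostream>
#include <string>
#include <unordered_map>

struct CompiledObject { int id; };

std::unordered_map<std::string, CompiledObject> cache_by_address;
std::unordered_map<std::string, CompiledObject> cache_by_struct;
int real_compiled_num = 0;

// Two-level lookup: probe the address cache, then the structure cache, and
// compile only on a double miss; each miss would be logged via VLOG(4).
const CompiledObject& Compile(const std::string& addr_key,
                              const std::string& struct_key) {
  if (!cache_by_address.count(addr_key)) {
    // "Not found CinnCompiledObject in cache_by_address_."
    if (!cache_by_struct.count(struct_key)) {
      // "Not found CinnCompiledObject in cache_by_struct_." -- really compile.
      cache_by_struct[struct_key] = CompiledObject{real_compiled_num++};
    }
    cache_by_address[addr_key] = cache_by_struct[struct_key];
  }
  return cache_by_address[addr_key];
}

int main() {
  Compile("graph@0x1", "matmul;relu;");
  Compile("graph@0x2", "matmul;relu;");  // address miss, struct hit: no rebuild
  std::cout << real_compiled_num << "\n";  // prints 1
}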
paddle/fluid/framework/paddle2cinn/transform_desc.cc

@@ -97,6 +97,33 @@ namespace cpp = ::cinn::frontend::paddle::cpp;
 #undef SET_DATA_TYPE_CASE_ITEM
 }

+std::string VarDataTypeToString(
+    const ::paddle::framework::proto::VarType::Type &type) {
+#define SET_DATA_TYPE_CASE_ITEM(type__)               \
+  case ::paddle::framework::proto::VarType::type__:   \
+    return std::string(#type__);                      \
+    break;
+
+  switch (type) {
+    SET_DATA_TYPE_CASE_ITEM(BOOL);
+    SET_DATA_TYPE_CASE_ITEM(SIZE_T);
+    SET_DATA_TYPE_CASE_ITEM(UINT8);
+    SET_DATA_TYPE_CASE_ITEM(INT8);
+    SET_DATA_TYPE_CASE_ITEM(INT16);
+    SET_DATA_TYPE_CASE_ITEM(INT32);
+    SET_DATA_TYPE_CASE_ITEM(INT64);
+    SET_DATA_TYPE_CASE_ITEM(FP16);
+    SET_DATA_TYPE_CASE_ITEM(FP32);
+    SET_DATA_TYPE_CASE_ITEM(FP64);
+    SET_DATA_TYPE_CASE_ITEM(BF16);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX64);
+    SET_DATA_TYPE_CASE_ITEM(COMPLEX128);
+    default:
+      PADDLE_THROW(platform::errors::NotFound("Cannot find var data type"));
+  }
+#undef SET_DATA_TYPE_CASE_ITEM
+}
+
 ::paddle::framework::proto::VarType::Type TransformVarDataTypeFromCpp(
     const ::cinn::frontend::paddle::cpp::VarDescAPI::Type &type) {
 #define SET_DATA_TYPE_CASE_ITEM(type__)
...
paddle/fluid/framework/paddle2cinn/transform_desc.h

@@ -74,6 +74,10 @@ void TransformProgramDescFromCinn(
     const ::cinn::frontend::paddle::cpp::ProgramDesc& cpp_desc,
     framework::ProgramDesc* pb_desc);

+// debug function
+std::string VarDataTypeToString(
+    const ::paddle::framework::proto::VarType::Type& type);
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/paddle2cinn/transform_desc_test.cc

@@ -232,6 +232,16 @@ TEST(TransformProgramDesc, pb2cpp) {
   ASSERT_EQ(cpp_prog.BlocksSize(), correct_prog.BlocksSize());
 }

+TEST(HelperFunction, VarDataTypeToString) {
+  const auto &pd_fp32_var = CreatePbVarDesc();
+  const auto &debug_fp32_string =
+      VarDataTypeToString(pd_fp32_var.GetDataType());
+  ASSERT_EQ(debug_fp32_string, std::string("FP32"));
+
+  ASSERT_EQ(VarDataTypeToString(::paddle::framework::proto::VarType::INT32),
+            std::string("INT32"));
+}
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/paddle2cinn/transform_type.cc

@@ -18,6 +18,7 @@
 #include "cinn/runtime/cinn_runtime.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
+#include "paddle/utils/string/string_helper.h"

 namespace paddle::framework::paddle2cinn {
@@ -78,4 +79,52 @@ namespace paddle::framework::paddle2cinn {
 #undef SET_TYPE_CASE_ITEM
 }

+std::string PaddleAttributeToString(const framework::Attribute& attr) {
+  std::ostringstream ss;
+#define EXPAND_ATTRIBUTE_MACRO(TYPE_)                               \
+  if (attr.type() == typeid(TYPE_)) {                               \
+    ss << PADDLE_GET_CONST(TYPE_, attr);                            \
+    return ss.str();                                                \
+  }                                                                 \
+  if (attr.type() == typeid(std::vector<TYPE_>)) {                  \
+    const auto& vals = PADDLE_GET_CONST(std::vector<TYPE_>, attr);  \
+    if (!vals.empty()) {                                            \
+      ss << "[" << string::join_strings(vals, ", ") << "]";         \
+    }                                                               \
+    return ss.str();                                                \
+  }
+
+  if (attr.type() == typeid(bool)) {
+    ss << std::boolalpha << PADDLE_GET_CONST(bool, attr);
+    return ss.str();
+  }
+  if (attr.type() == typeid(std::vector<bool>)) {
+    // join_strings<bool> fails to compile:
+    // cannot bind non-const lvalue reference of type 'bool&'
+    const auto& vals = PADDLE_GET_CONST(std::vector<bool>, attr);
+    if (!vals.empty()) {
+      ss << "[";
+      bool first_value = true;
+      for (bool val : vals) {
+        if (!first_value) {
+          ss << ", ";
+        }
+        first_value = false;
+        ss << std::boolalpha << val;
+      }
+      ss << "]";
+    }
+    return ss.str();
+  }
+  EXPAND_ATTRIBUTE_MACRO(std::string)
+  EXPAND_ATTRIBUTE_MACRO(int)
+  EXPAND_ATTRIBUTE_MACRO(float)
+  EXPAND_ATTRIBUTE_MACRO(int64_t)
+  EXPAND_ATTRIBUTE_MACRO(double)
+
+  ss << "Unknown_Dtype:" << attr.type().name();
+#undef EXPAND_ATTRIBUTE_MACRO
+  return ss.str();
+}
+
 }  // namespace paddle::framework::paddle2cinn
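The special-cased std::vector<bool> branch above exists because vector<bool> is a bit-packed specialization: its iterators yield proxy objects rather than bool&, so a generic joiner that binds elements to a non-const lvalue reference (as string::join_strings apparently does) will not compile for it. A standalone illustration of the manual-loop workaround:

#include <iostream>
#include <sstream>
#include <vector>

int main() {
  std::vector<bool> flags = {true, false, true};
  std::ostringstream ss;
  bool first = true;
  for (bool v : flags) {  // copies each proxy's value into a real bool
    if (!first) ss << ", ";
    first = false;
    ss << std::boolalpha << v;
  }
  std::cout << "[" << ss.str() << "]\n";  // prints: [true, false, true]
}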
paddle/fluid/framework/paddle2cinn/transform_type.h

@@ -13,6 +13,9 @@
 // limitations under the License.

 #pragma once
+
+#include <string>
+#include "paddle/fluid/framework/type_defs.h"
 #include "paddle/phi/common/data_type.h"

 // type declaration forward
@@ -27,4 +30,6 @@ namespace paddle::framework::paddle2cinn {
 ::phi::DataType TransToPaddleDataType(const cinn_type_t& type);

+std::string PaddleAttributeToString(const framework::Attribute& attr);
+
 }  // namespace paddle::framework::paddle2cinn
paddle/fluid/framework/paddle2cinn/transform_type_test.cc

@@ -62,4 +62,40 @@ TEST(TransToPaddleDataType, runtime_type) {
                paddle::platform::EnforceNotMet);
 }

+TEST(HelperFunction, PaddleAttributeToStringPODValue) {
+  paddle::framework::Attribute attr1 = 1;
+  ASSERT_EQ(PaddleAttributeToString(attr1), std::string("1"));
+
+  paddle::framework::Attribute attr2 = 0.2f;
+  ASSERT_EQ(PaddleAttributeToString(attr2), std::string("0.2"));
+
+  paddle::framework::Attribute attr3 = true;
+  ASSERT_EQ(PaddleAttributeToString(attr3), std::string("true"));
+
+  paddle::framework::Attribute attr4 = std::string("string_attribute");
+  ASSERT_EQ(PaddleAttributeToString(attr4), std::string("string_attribute"));
+}
+
+TEST(HelperFunction, PaddleAttributeToStringVectorValue) {
+  paddle::framework::Attribute attr1 = std::vector<int>();
+  ASSERT_EQ(PaddleAttributeToString(attr1), std::string(""));
+
+  paddle::framework::Attribute attr2 = std::vector<int>{1, 2, 3, 4, 5};
+  ASSERT_EQ(PaddleAttributeToString(attr2), std::string("[1, 2, 3, 4, 5]"));
+
+  paddle::framework::Attribute attr3 =
+      std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
+  ASSERT_EQ(PaddleAttributeToString(attr3),
+            std::string("[0.1, 0.2, 0.3, 0.4, 0.5]"));
+
+  paddle::framework::Attribute attr4 =
+      std::vector<bool>{true, false, true, false, false};
+  ASSERT_EQ(PaddleAttributeToString(attr4),
+            std::string("[true, false, true, false, false]"));
+
+  paddle::framework::Attribute attr5 =
+      std::vector<std::string>{"a", "b", "c", "d", "e"};
+  ASSERT_EQ(PaddleAttributeToString(attr5), std::string("[a, b, c, d, e]"));
+}
+
 }  // namespace paddle::framework::paddle2cinn