未验证 提交 6fe5547d 编写于 作者: Y Yan Chunwei 提交者: GitHub

switch NodeAttr to boost::varient (#12539)

上级 535a6e92
...@@ -8,7 +8,7 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph ...@@ -8,7 +8,7 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
helper.cc helper.cc
model_store_pass.cc model_store_pass.cc
DEPS framework_proto proto_desc) DEPS framework_proto proto_desc)
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis gflags glog gtest)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
......
...@@ -20,17 +20,6 @@ namespace paddle { ...@@ -20,17 +20,6 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
template <>
std::string &NodeAttr::As<std::string>() {
if (data_.empty()) {
type_index_ = std::type_index(typeid(std::string));
}
PADDLE_ENFORCE_EQ(type_index_, std::type_index(typeid(std::string)));
return data_;
}
std::string &NodeAttr::String() { return As<std::string>(); }
std::vector<Dot::Attr> Value::dot_attrs() const { std::vector<Dot::Attr> Value::dot_attrs() const {
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"), return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
Dot::Attr("shape", "box"), Dot::Attr("shape", "box"),
......
...@@ -29,6 +29,7 @@ limitations under the License. */ ...@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/device.h" #include "paddle/fluid/inference/analysis/device.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -38,39 +39,35 @@ class NodeMap; ...@@ -38,39 +39,35 @@ class NodeMap;
// A helper class to maintain the status from Pass. // A helper class to maintain the status from Pass.
struct NodeAttr { struct NodeAttr {
using any_t =
boost::variant<bool, float, int32_t, int64_t, void *, std::string>;
// NOTE T should be a primary type or a struct combined by several primary // NOTE T should be a primary type or a struct combined by several primary
// types. // types.
// NOTE the STL containers should not use here. // NOTE the STL containers should not use here.
// Some usages // Some usages
// Attr attr; // Attr attr;
// attr.Bool() = true; // attr.Bool() = true;
bool &Bool() { return As<bool>(); } bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); } float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); } int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); } int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); } void *&Pointer() { return As<void *>(); }
std::string &String(); std::string &String() { return As<std::string>(); }
private: private:
template <typename T> template <typename T>
T &As() { T &As() {
// init storage in the first usage. if (type_index_ == typeid(NodeAttr)) {
if (data_.empty()) { type_index_ = typeid(T);
VLOG(4) << "resize data to " << sizeof(T); any_data_ = T();
type_index_ = std::type_index(typeid(T)); } else {
data_.resize(sizeof(T)); PADDLE_ENFORCE(type_index_ == typeid(T), "fetch error type");
} }
PADDLE_ENFORCE(framework::IsType<T>(type_index_), return boost::get<T>(any_data_);
"type not matched, origin is %s, want %s",
DataTypeNamer::Global().repr(type_index_),
DataTypeNamer::Global().repr<T>());
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
return *reinterpret_cast<T *>(&data_[0]);
} }
private: private:
std::string data_; any_t any_data_;
std::type_index type_index_{typeid(NodeAttr)}; std::type_index type_index_{typeid(NodeAttr)};
}; };
......
...@@ -20,6 +20,24 @@ namespace paddle { ...@@ -20,6 +20,24 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST(NodeAttr, bool) {
NodeAttr x;
x.Bool() = true;
ASSERT_EQ(x.Bool(), true);
}
TEST(NodeAttr, int32) {
NodeAttr x;
x.Int32() = 32;
ASSERT_EQ(x.Int32(), 32);
}
TEST(NodeAttr, string) {
NodeAttr x;
x.String() = "Hello";
ASSERT_EQ(x.String(), "Hello");
}
TEST(Node, Attr) { TEST(Node, Attr) {
// Node is an abstract class, use Value instead for they share the same Attr // Node is an abstract class, use Value instead for they share the same Attr
// logic. // logic.
...@@ -27,6 +45,9 @@ TEST(Node, Attr) { ...@@ -27,6 +45,9 @@ TEST(Node, Attr) {
auto* node = nodes.Create(Node::Type::kValue); auto* node = nodes.Create(Node::Type::kValue);
node->attr("v0").Int32() = 2008; node->attr("v0").Int32() = 2008;
ASSERT_EQ(node->attr("v0").Int32(), 2008); ASSERT_EQ(node->attr("v0").Int32(), 2008);
node->attr("str").String() = "hello world";
ASSERT_EQ(node->attr("str").String(), "hello world");
} }
} // namespace analysis } // namespace analysis
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册