提交 067ed70f 编写于 作者: T Tao Luo

add HasProtoAttr function in op_desc.h, clean node.h

test=develop
上级 e99597d3
...@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl( ...@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto* op = n->Op();
if (n->RuntimeHasAttr("is_test")) { if (op->HasAttr("is_test") || op->HasProtoAttr("is_test")) {
op->SetAttr("is_test", true); op->SetAttr("is_test", true);
} else if (std::find(begin(op_list), end(op_list), op->Type()) != } else if (std::find(begin(op_list), end(op_list), op->Type()) !=
end(op_list)) { end(op_list)) {
......
...@@ -104,9 +104,9 @@ TEST(IsTestPass, basic) { ...@@ -104,9 +104,9 @@ TEST(IsTestPass, basic) {
auto* op = node->Op(); auto* op = node->Op();
auto op_name = boost::get<std::string>(op->GetAttr("name")); auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "conv3") { if (op_name == "conv3") {
ASSERT_FALSE(node->RuntimeHasAttr("is_test")); ASSERT_FALSE(op->HasAttr("is_test"));
} else { } else {
ASSERT_TRUE(node->RuntimeHasAttr("is_test")); ASSERT_TRUE(op->HasAttr("is_test"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test"))); EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test")));
} }
} }
......
...@@ -25,12 +25,15 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl( ...@@ -25,12 +25,15 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
const auto& op_types_list = const auto& op_types_list =
Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types"); Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) { if (n->IsOp()) {
if (op_types_list.empty()) { auto* op = n->Op();
n->Op()->SetAttr("use_mkldnn", true); if (op->HasAttr("use_mkldnn") || op->HasProtoAttr("use_mkldnn")) {
} else if (std::find(op_types_list.begin(), op_types_list.end(), if (op_types_list.empty()) {
n->Name()) != op_types_list.end()) { op->SetAttr("use_mkldnn", true);
n->Op()->SetAttr("use_mkldnn", true); } else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) {
op->SetAttr("use_mkldnn", true);
}
} }
} }
} }
......
...@@ -30,28 +30,6 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name, ...@@ -30,28 +30,6 @@ std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
return std::unique_ptr<Node>(new Node(name, type)); return std::unique_ptr<Node>(new Node(name, type));
} }
bool Node::RuntimeHasAttr(const std::string &name) const {
if (Op()->HasAttr(name)) {
return true;
} else {
auto &op_info = OpInfoMap::Instance();
auto op_type = Op()->Type();
if (op_info.Has(op_type)) {
auto op_info_ptr = op_info.Get(op_type);
if (op_info_ptr.HasOpProtoAndChecker()) {
const proto::OpProto &proto = op_info_ptr.Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return true;
}
}
}
}
}
return false;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -108,18 +108,6 @@ class Node { ...@@ -108,18 +108,6 @@ class Node {
Name().find(ir::Node::kControlDepVarName) != std::string::npos; Name().find(ir::Node::kControlDepVarName) != std::string::npos;
} }
// RuntimeHasAttr is different with HasAttr now.
// 1. For Op()->HasAttr(), it judges whether a stored program_desc_ has attr,
// thus, if stored program_desc_ are old which don't have an attr, a new
// library which adds the attr already will fail on this function.
// Details:
// https://github.com/PaddlePaddle/Paddle/pull/14608#issuecomment-442309087
// 2. For Op()->RuntimeHasAttr, it judges the attr in runtime to avoid above
// problem.
// TODO(luotao): Maybe we should enhance HasAttr later, instead of adding
// RuntimeHasAttr.
bool RuntimeHasAttr(const std::string& name) const;
std::vector<Node*> inputs; std::vector<Node*> inputs;
std::vector<Node*> outputs; std::vector<Node*> outputs;
......
...@@ -237,20 +237,16 @@ void OpDesc::SetOutput(const std::string &param_name, ...@@ -237,20 +237,16 @@ void OpDesc::SetOutput(const std::string &param_name,
this->outputs_[param_name] = args; this->outputs_[param_name] = args;
} }
bool OpDesc::HasAttr(const std::string &name) const { bool OpDesc::HasProtoAttr(const std::string &name) const {
if (attrs_.find(name) != attrs_.end()) { auto &op_info = OpInfoMap::Instance();
return true; if (op_info.Has(desc_.type())) {
} else { auto op_info_ptr = op_info.Get(desc_.type());
auto &op_info = OpInfoMap::Instance(); if (op_info_ptr.HasOpProtoAndChecker()) {
if (op_info.Has(desc_.type())) { const proto::OpProto &proto = op_info_ptr.Proto();
auto op_info_ptr = op_info.Get(desc_.type()); for (int i = 0; i != proto.attrs_size(); ++i) {
if (op_info_ptr.HasOpProtoAndChecker()) { const proto::OpProto::Attr &attr = proto.attrs(i);
const proto::OpProto &proto = op_info_ptr.Proto(); if (attr.name() == name) {
for (int i = 0; i != proto.attrs_size(); ++i) { return true;
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return true;
}
} }
} }
} }
......
...@@ -61,7 +61,11 @@ class OpDesc { ...@@ -61,7 +61,11 @@ class OpDesc {
void SetOutput(const std::string &param_name, void SetOutput(const std::string &param_name,
const std::vector<std::string> &args); const std::vector<std::string> &args);
bool HasAttr(const std::string &name) const; bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
bool HasProtoAttr(const std::string &name) const;
proto::AttrType GetAttrType(const std::string &name) const; proto::AttrType GetAttrType(const std::string &name) const;
......
...@@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray( ...@@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray(
paddle::platform::CPUPlace place) { paddle::platform::CPUPlace place) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (int i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i])); dims.push_back(static_cast<int>(array.shape()[i]));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册