From c856ac8721d6ce41e4cb89c9d6fb369fb0f8d739 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Thu, 29 Nov 2018 20:30:44 +0800 Subject: [PATCH] add OpHasAttr in node.h, update is_test_pass and mkldnn_placement_pass test=develop --- paddle/fluid/framework/ir/is_test_pass.cc | 2 +- .../fluid/framework/ir/is_test_pass_tester.cc | 4 +-- .../framework/ir/mkldnn_placement_pass.cc | 2 +- paddle/fluid/framework/ir/node.cc | 26 ++++++++++++++++++- paddle/fluid/framework/ir/node.h | 2 ++ 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index 292f232ffce..a61bd5f2913 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -38,7 +38,7 @@ std::unique_ptr IsTestPass::ApplyImpl( for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); - if (op->HasAttr("is_test")) { + if (n->OpHasAttr("is_test")) { op->SetAttr("is_test", true); } else if (std::find(begin(op_list), end(op_list), op->Type()) != end(op_list)) { diff --git a/paddle/fluid/framework/ir/is_test_pass_tester.cc b/paddle/fluid/framework/ir/is_test_pass_tester.cc index 9696441a216..a5fb0abb3c2 100644 --- a/paddle/fluid/framework/ir/is_test_pass_tester.cc +++ b/paddle/fluid/framework/ir/is_test_pass_tester.cc @@ -104,9 +104,9 @@ TEST(IsTestPass, basic) { auto* op = node->Op(); auto op_name = boost::get(op->GetAttr("name")); if (op_name == "conv3") { - ASSERT_FALSE(op->HasAttr("is_test")); + ASSERT_FALSE(node->OpHasAttr("is_test")); } else { - ASSERT_TRUE(op->HasAttr("is_test")); + ASSERT_TRUE(node->OpHasAttr("is_test")); EXPECT_TRUE(boost::get(op->GetAttr("is_test"))); } } diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc index 65be69b7f5b..366057b01e7 100644 --- a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc @@ -22,7 +22,7 @@ std::unique_ptr MKLDNNPlacementPass::ApplyImpl( std::unique_ptr graph) const { VLOG(3) << "Aplies MKL-DNN placement strategy."; for (const Node* n : graph->Nodes()) { - if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) { + if (n->IsOp() && n->OpHasAttr("use_mkldnn")) { n->Op()->SetAttr("use_mkldnn", true); } } diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 50d91130889..4c4da10b04d 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_info.h" namespace paddle { namespace framework { @@ -24,10 +25,33 @@ constexpr char Node::kControlDepVarName[]; const char Node::kControlDepVarName[] = "__control_var"; #endif -std::unique_ptr CreateNodeForTest(const std::string& name, +std::unique_ptr CreateNodeForTest(const std::string &name, Node::Type type) { return std::unique_ptr(new Node(name, type)); } + +bool Node::OpHasAttr(const std::string &name) const { + if (Op()->HasAttr(name)) { + return true; + } else { + auto &op_info = OpInfoMap::Instance(); + auto op_type = Op()->Type(); + if (op_info.Has(op_type)) { + auto op_info_ptr = op_info.Get(op_type); + if (op_info_ptr.HasOpProtoAndChecker()) { + const proto::OpProto &proto = op_info_ptr.Proto(); + for (int i = 0; i != proto.attrs_size(); ++i) { + const proto::OpProto::Attr &attr = proto.attrs(i); + if (attr.name() == name) { + return true; + } + } + } + } + } + return false; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d2a393b3f19..ac08006a495 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -108,6 +108,8 @@ class Node { Name().find(ir::Node::kControlDepVarName) != std::string::npos; } + bool OpHasAttr(const std::string& name) const; + std::vector inputs; std::vector outputs; -- GitLab