diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index 292f232ffce48593e1827fe2dfe1b8472360054e..a61bd5f2913f87cfcdaed0fc97c7bdf9a6121d1b 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -38,7 +38,7 @@ std::unique_ptr IsTestPass::ApplyImpl( for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); - if (op->HasAttr("is_test")) { + if (n->OpHasAttr("is_test")) { op->SetAttr("is_test", true); } else if (std::find(begin(op_list), end(op_list), op->Type()) != end(op_list)) { diff --git a/paddle/fluid/framework/ir/is_test_pass_tester.cc b/paddle/fluid/framework/ir/is_test_pass_tester.cc index 9696441a21661db89146c448742a992d1f7df022..a5fb0abb3c23dd064a4c2906905583cc03105bd7 100644 --- a/paddle/fluid/framework/ir/is_test_pass_tester.cc +++ b/paddle/fluid/framework/ir/is_test_pass_tester.cc @@ -104,9 +104,9 @@ TEST(IsTestPass, basic) { auto* op = node->Op(); auto op_name = boost::get(op->GetAttr("name")); if (op_name == "conv3") { - ASSERT_FALSE(op->HasAttr("is_test")); + ASSERT_FALSE(node->OpHasAttr("is_test")); } else { - ASSERT_TRUE(op->HasAttr("is_test")); + ASSERT_TRUE(node->OpHasAttr("is_test")); EXPECT_TRUE(boost::get(op->GetAttr("is_test"))); } } diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc index 65be69b7f5b5e363d5d0753c45f9ff9e3f329fbe..366057b01e764be950cc3f3d9311216d87dca8f4 100644 --- a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc @@ -22,7 +22,7 @@ std::unique_ptr MKLDNNPlacementPass::ApplyImpl( std::unique_ptr graph) const { VLOG(3) << "Aplies MKL-DNN placement strategy."; for (const Node* n : graph->Nodes()) { - if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) { + if (n->IsOp() && n->OpHasAttr("use_mkldnn")) { n->Op()->SetAttr("use_mkldnn", true); } } diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 50d9113088903aa7681d6c6af5cc65f846d32787..4c4da10b04d6387a9d678fbfbf4c12677bb8f786 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_info.h" namespace paddle { namespace framework { @@ -24,10 +25,33 @@ constexpr char Node::kControlDepVarName[]; const char Node::kControlDepVarName[] = "__control_var"; #endif -std::unique_ptr CreateNodeForTest(const std::string& name, +std::unique_ptr CreateNodeForTest(const std::string &name, Node::Type type) { return std::unique_ptr(new Node(name, type)); } + +bool Node::OpHasAttr(const std::string &name) const { + if (Op()->HasAttr(name)) { + return true; + } else { + auto &op_info = OpInfoMap::Instance(); + auto op_type = Op()->Type(); + if (op_info.Has(op_type)) { + auto op_info_ptr = op_info.Get(op_type); + if (op_info_ptr.HasOpProtoAndChecker()) { + const proto::OpProto &proto = op_info_ptr.Proto(); + for (int i = 0; i != proto.attrs_size(); ++i) { + const proto::OpProto::Attr &attr = proto.attrs(i); + if (attr.name() == name) { + return true; + } + } + } + } + } + return false; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d2a393b3f19e9aab79098757dae663d030b0fa2b..ac08006a4953cf1335c8b37316afa36a5157c437 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -108,6 +108,8 @@ class Node { Name().find(ir::Node::kControlDepVarName) != std::string::npos; } + bool OpHasAttr(const std::string& name) const; + std::vector inputs; std::vector outputs;