提交 c856ac87 编写于 作者: T Tao Luo

add OpHasAttr in node.h, update is_test_pass and mkldnn_placement_pass

test=develop
上级 44debca8
...@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl( ...@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto* op = n->Op();
if (op->HasAttr("is_test")) { if (n->OpHasAttr("is_test")) {
op->SetAttr("is_test", true); op->SetAttr("is_test", true);
} else if (std::find(begin(op_list), end(op_list), op->Type()) != } else if (std::find(begin(op_list), end(op_list), op->Type()) !=
end(op_list)) { end(op_list)) {
......
...@@ -104,9 +104,9 @@ TEST(IsTestPass, basic) { ...@@ -104,9 +104,9 @@ TEST(IsTestPass, basic) {
auto* op = node->Op(); auto* op = node->Op();
auto op_name = boost::get<std::string>(op->GetAttr("name")); auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "conv3") { if (op_name == "conv3") {
ASSERT_FALSE(op->HasAttr("is_test")); ASSERT_FALSE(node->OpHasAttr("is_test"));
} else { } else {
ASSERT_TRUE(op->HasAttr("is_test")); ASSERT_TRUE(node->OpHasAttr("is_test"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test"))); EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test")));
} }
} }
......
...@@ -22,7 +22,7 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl( ...@@ -22,7 +22,7 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Aplies MKL-DNN placement strategy."; VLOG(3) << "Aplies MKL-DNN placement strategy.";
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) { if (n->IsOp() && n->OpHasAttr("use_mkldnn")) {
n->Op()->SetAttr("use_mkldnn", true); n->Op()->SetAttr("use_mkldnn", true);
} }
} }
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -24,10 +25,33 @@ constexpr char Node::kControlDepVarName[]; ...@@ -24,10 +25,33 @@ constexpr char Node::kControlDepVarName[];
const char Node::kControlDepVarName[] = "__control_var"; const char Node::kControlDepVarName[] = "__control_var";
#endif #endif
std::unique_ptr<Node> CreateNodeForTest(const std::string& name, std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
Node::Type type) { Node::Type type) {
return std::unique_ptr<Node>(new Node(name, type)); return std::unique_ptr<Node>(new Node(name, type));
} }
bool Node::OpHasAttr(const std::string &name) const {
if (Op()->HasAttr(name)) {
return true;
} else {
auto &op_info = OpInfoMap::Instance();
auto op_type = Op()->Type();
if (op_info.Has(op_type)) {
auto op_info_ptr = op_info.Get(op_type);
if (op_info_ptr.HasOpProtoAndChecker()) {
const proto::OpProto &proto = op_info_ptr.Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return true;
}
}
}
}
}
return false;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -108,6 +108,8 @@ class Node { ...@@ -108,6 +108,8 @@ class Node {
Name().find(ir::Node::kControlDepVarName) != std::string::npos; Name().find(ir::Node::kControlDepVarName) != std::string::npos;
} }
bool OpHasAttr(const std::string& name) const;
std::vector<Node*> inputs; std::vector<Node*> inputs;
std::vector<Node*> outputs; std::vector<Node*> outputs;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册