From 071a7020606ef1d121974443c461eca7704cde86 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Sun, 19 Apr 2020 17:26:52 +0800 Subject: [PATCH] Fix the error misjudgment when there are control nodes in graph. (#23943) --- .../elementwise_group_detector.cc | 50 ++++++++++++------- .../framework/ir/fusion_group/subgraph.h | 2 +- .../fluid/framework/ir/pass_tester_helper.h | 4 +- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc index 6e61f10e091..5de253bb967 100644 --- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc @@ -88,38 +88,52 @@ bool GroupDetector::CheckPrecondition(const Node* n) { return true; }; - return n && n->IsOp() && n->Op() && check_data_type(n->inputs) && - check_data_type(n->outputs); + auto check_running_on_cpu = [&](const Node* n) -> bool { + if (n && n->IsOp() && n->Op()) { + auto* op = n->Op(); + bool is_run_on_cpu = false; + if (op->HasAttr("force_cpu") && + op->GetAttrType("force_cpu") == proto::AttrType::BOOLEAN) { + is_run_on_cpu = op->GetAttrIfExists("force_cpu"); + } + if (op->HasAttr("op_device")) { + is_run_on_cpu = op->GetAttrIfExists("op_device") == "cpu"; + } + return is_run_on_cpu; + } + return false; + }; + + return n && n->IsOp() && n->Op() && !check_running_on_cpu(n) && + check_data_type(n->inputs) && check_data_type(n->outputs); } bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) { if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) { // Check whether all inputs have the same shape. + bool is_first = true; std::vector shape_0; - for (size_t i = 0; i < n->inputs.size(); ++i) { - auto* in_i = n->inputs[i]; - if (!(in_i && in_i->IsVar() && in_i->Var())) { - return false; - } - - std::vector shape_i = in_i->Var()->GetShape(); - if (i == 0U) { - shape_0 = shape_i; - } else { - if (!IsEqualAndNotEmpty(shape_0, shape_i)) { - return false; + for (auto* in_i : n->inputs) { + if (in_i && in_i->IsVar() && in_i->Var()) { + std::vector shape_i = in_i->Var()->GetShape(); + if (is_first) { + shape_0 = shape_i; + is_first = false; + } else { + if (!IsEqualAndNotEmpty(shape_0, shape_i)) { + return false; + } } } } - auto op = n->Op(); std::vector output_names = OperationMap::Instance().Get(op->Type()).output_names; - for (auto& name : output_names) { - if (op->Output(name).size() != 1) return false; + if (op->Output(name).size() < 1U) { + return false; + } } - return true; } return false; diff --git a/paddle/fluid/framework/ir/fusion_group/subgraph.h b/paddle/fluid/framework/ir/fusion_group/subgraph.h index 4cf2bf48d5d..029166cbe17 100644 --- a/paddle/fluid/framework/ir/fusion_group/subgraph.h +++ b/paddle/fluid/framework/ir/fusion_group/subgraph.h @@ -170,7 +170,7 @@ class SubGraph { } for (auto* n : nodes_set_) { - if (n && n->IsVar() && n->Var()) { + if (n && ((n->IsVar() && n->Var()) || n->IsCtrlVar())) { // Set the input of subgraph's input var node to null. std::vector inputs; for (auto* in : n->inputs) { diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 98cfcfa2000..ac438d368de 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -484,7 +484,7 @@ static std::string DebugString(OpDesc* op) { return os.str(); } -static std::string DebugString(Node* node) { +static std::string DebugString(const Node* node) { std::ostringstream os; if (node->IsOp() && node->Op()) { OpDesc* op = node->Op(); @@ -553,7 +553,7 @@ static std::string DebugString(const std::vector& nodes) { for (auto* node : nodes) { if (node->IsOp() && node->Op()) { os << " "; - } else if (node->IsVar() && node->Var()) { + } else if ((node->IsVar() && node->Var()) || node->IsCtrlVar()) { os << " "; } os << DebugString(node) << "\n"; -- GitLab