未验证 提交 071a7020 编写于 作者: Y Yiqun Liu 提交者: GitHub

Fix the error misjudgment when there are control nodes in graph. (#23943)

上级 490db7f3
......@@ -88,38 +88,52 @@ bool GroupDetector::CheckPrecondition(const Node* n) {
return true;
};
return n && n->IsOp() && n->Op() && check_data_type(n->inputs) &&
check_data_type(n->outputs);
auto check_running_on_cpu = [&](const Node* n) -> bool {
if (n && n->IsOp() && n->Op()) {
auto* op = n->Op();
bool is_run_on_cpu = false;
if (op->HasAttr("force_cpu") &&
op->GetAttrType("force_cpu") == proto::AttrType::BOOLEAN) {
is_run_on_cpu = op->GetAttrIfExists<bool>("force_cpu");
}
if (op->HasAttr("op_device")) {
is_run_on_cpu = op->GetAttrIfExists<std::string>("op_device") == "cpu";
}
return is_run_on_cpu;
}
return false;
};
return n && n->IsOp() && n->Op() && !check_running_on_cpu(n) &&
check_data_type(n->inputs) && check_data_type(n->outputs);
}
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
// Check whether all inputs have the same shape.
bool is_first = true;
std::vector<int64_t> shape_0;
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto* in_i = n->inputs[i];
if (!(in_i && in_i->IsVar() && in_i->Var())) {
return false;
}
std::vector<int64_t> shape_i = in_i->Var()->GetShape();
if (i == 0U) {
shape_0 = shape_i;
} else {
if (!IsEqualAndNotEmpty(shape_0, shape_i)) {
return false;
for (auto* in_i : n->inputs) {
if (in_i && in_i->IsVar() && in_i->Var()) {
std::vector<int64_t> shape_i = in_i->Var()->GetShape();
if (is_first) {
shape_0 = shape_i;
is_first = false;
} else {
if (!IsEqualAndNotEmpty(shape_0, shape_i)) {
return false;
}
}
}
}
auto op = n->Op();
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
if (op->Output(name).size() != 1) return false;
if (op->Output(name).size() < 1U) {
return false;
}
}
return true;
}
return false;
......
......@@ -170,7 +170,7 @@ class SubGraph {
}
for (auto* n : nodes_set_) {
if (n && n->IsVar() && n->Var()) {
if (n && ((n->IsVar() && n->Var()) || n->IsCtrlVar())) {
// Set the input of subgraph's input var node to null.
std::vector<Node*> inputs;
for (auto* in : n->inputs) {
......
......@@ -484,7 +484,7 @@ static std::string DebugString(OpDesc* op) {
return os.str();
}
static std::string DebugString(Node* node) {
static std::string DebugString(const Node* node) {
std::ostringstream os;
if (node->IsOp() && node->Op()) {
OpDesc* op = node->Op();
......@@ -553,7 +553,7 @@ static std::string DebugString(const std::vector<Node*>& nodes) {
for (auto* node : nodes) {
if (node->IsOp() && node->Op()) {
os << " ";
} else if (node->IsVar() && node->Var()) {
} else if ((node->IsVar() && node->Var()) || node->IsCtrlVar()) {
os << " ";
}
os << DebugString(node) << "\n";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册