提交 172a4687 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3968 Ignore node with _side_effect attr in cse

Merge pull request !3968 from YuJianfeng/cse
......@@ -16,6 +16,8 @@
#include "backend/optimizer/pass/common_subexpression_elimination.h"
#include <memory>
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/flags.h"
namespace mindspore {
namespace opt {
......@@ -33,48 +35,60 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
}
return false;
}
bool HasSideEffectAttr(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::HasNodeAttr(GRAPH_FLAG_SIDE_EFFECT, cnode)) {
return false;
}
return AnfAlgo::GetNodeAttr<bool>(cnode, GRAPH_FLAG_SIDE_EFFECT);
}
} // namespace
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
bool replace = false;
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
auto main_value = GetValueNode(main);
auto node_value = GetValueNode(node);
if (main_value->isa<Primitive>() && node_value->isa<Primitive>()) {
replace = false;
return false;
} else if (main_value->isa<tensor::Tensor>() && node_value->isa<tensor::Tensor>()) {
replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node);
return (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node);
} else {
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
}
} else if (main->isa<CNode>() && node->isa<CNode>()) {
if (check_side_effect && HasSideEffectAttr(main)) {
return false;
}
if (!CheckEqualKernelBuildInfo(main, node)) {
replace = false;
} else {
auto c_main = main->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_main);
auto c_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() == inp2.size()) {
bool appsame = true;
for (size_t j = 0; j < inp1.size(); j++) {
MS_EXCEPTION_IF_NULL(inp1[j]);
MS_EXCEPTION_IF_NULL(inp2[j]);
if (!(*inp1[j] == *inp2[j])) {
appsame = false;
break;
}
}
replace = appsame;
return false;
}
auto c_main = main->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_main);
auto c_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() != inp2.size()) {
return false;
}
for (size_t j = 0; j < inp1.size(); j++) {
auto inp1_j = inp1[j];
auto inp2_j = inp2[j];
MS_EXCEPTION_IF_NULL(inp1_j);
MS_EXCEPTION_IF_NULL(inp2_j);
if (!(*inp1_j == *inp2_j)) {
return false;
}
}
return true;
}
return replace;
return false;
}
bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册