diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc index 133a7e764a76e1db4ddfbb98c7389f6b4289571a..f2062b6f39bc37a5d98387eead9aef7a9709c346 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc @@ -16,6 +16,8 @@ #include "backend/optimizer/pass/common_subexpression_elimination.h" #include #include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/flags.h" namespace mindspore { namespace opt { @@ -33,48 +35,60 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { } return false; } + +bool HasSideEffectAttr(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(GRAPH_FLAG_SIDE_EFFECT, cnode)) { + return false; + } + return AnfAlgo::GetNodeAttr(cnode, GRAPH_FLAG_SIDE_EFFECT); +} } // namespace -bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { +bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); - bool replace = false; if (main->isa() && node->isa()) { auto main_value = GetValueNode(main); auto node_value = GetValueNode(node); if (main_value->isa() && node_value->isa()) { - replace = false; + return false; } else if (main_value->isa() && node_value->isa()) { - replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); + return (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); } else { - replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); } } else if (main->isa() && node->isa()) { + if (check_side_effect && HasSideEffectAttr(main)) { + return false; + } if (!CheckEqualKernelBuildInfo(main, node)) { - replace = false; - } else { - auto c_main = main->cast(); - MS_EXCEPTION_IF_NULL(c_main); - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - const auto &inp1 = c_main->inputs(); - const auto &inp2 = c_node->inputs(); - if (inp1.size() == inp2.size()) { - bool appsame = true; - for (size_t j = 0; j < inp1.size(); j++) { - MS_EXCEPTION_IF_NULL(inp1[j]); - MS_EXCEPTION_IF_NULL(inp2[j]); - if (!(*inp1[j] == *inp2[j])) { - appsame = false; - break; - } - } - replace = appsame; + return false; + } + auto c_main = main->cast(); + MS_EXCEPTION_IF_NULL(c_main); + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + const auto &inp1 = c_main->inputs(); + const auto &inp2 = c_node->inputs(); + if (inp1.size() != inp2.size()) { + return false; + } + for (size_t j = 0; j < inp1.size(); j++) { + auto inp1_j = inp1[j]; + auto inp2_j = inp2[j]; + MS_EXCEPTION_IF_NULL(inp1_j); + MS_EXCEPTION_IF_NULL(inp2_j); + if (!(*inp1_j == *inp2_j)) { + return false; } } + return true; } - return replace; + return false; } bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {