diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index 43857dddfd8a6bdfc9acb3ea93ba0b06ffbdde5d..c1cb308338d56fe7749d4ebd449a8277252115c1 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -37,7 +37,7 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { std::string op_name = AnfAlgo::GetCNodeName(cnode); auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); // deal ref op - if (op_info->is_ref()) { + if (op_info != nullptr && op_info->is_ref()) { auto ref_infos = op_info->ref_infos(); if (ref_infos.count(cur_out_index) != 0) { auto in_index = ref_infos.at(cur_out_index); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc index 5e265f2cf19d7ed11971b277b16567c7fa18a2b7..fa2815ff62e14f9d71d08203ff1274ea90f64c98 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc @@ -23,33 +23,33 @@ namespace mindspore { namespace opt { const BaseRef RemoveReshapePair::DefinePattern() const { - const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); - VectorRef reshape({prim_reshape, input_varptr_}); - - return VectorRef({prim::kPrimReshape, reshape}); + VarPtr X = std::make_shared(); + MS_EXCEPTION_IF_NULL(X); + return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})}); } const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); MS_EXCEPTION_IF_NULL(reshape_op_1); // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly - auto users = manager->node_users()[reshape_op_1]; - if (users.size() > 1) { + if (IsUsedByOthers(func_graph, reshape_op_1)) { return nullptr; } auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); MS_EXCEPTION_IF_NULL(reshape_op_2); - users = manager->node_users()[reshape_op_2]; - if (users.size() > 1) { + if (IsUsedByOthers(func_graph, reshape_op_2)) { return nullptr; } - auto input_node = reshape_op_2->input(1); - return input_node; + auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0); + auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0); + if (input_shape == output_shape) { + auto input_node = reshape_op_2->input(1); + return input_node; + } + return nullptr; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h index a284f4eaa95ab730d7a90ab2e8ed205778be982d..ddb25df70c650f74140e1b752c8e48be8c5173d0 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h @@ -28,15 +28,10 @@ namespace mindspore { namespace opt { class RemoveReshapePair : public PatternProcessPass { public: - explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) { - input_varptr_ = std::make_shared(); - } + explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {} ~RemoveReshapePair() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; }; } // namespace opt } // namespace mindspore