diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc index 4645167191407baa13b6733ed6abeeb3e9e72da0..59be003b150ff6c63b23aa8d78f6bb7c887cbda8 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc @@ -109,6 +109,9 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto new_node = CreateAdamApplyOneNode(func_graph, equiv); MS_EXCEPTION_IF_NULL(new_node); new_node->set_scope(node->scope()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc index 7dc13ee7a7d33dec98788400955d8c3bf850d103..f6077c95f25ac9c4498b2cd746abaecf83d63988 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc @@ -146,7 +146,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c if (graph == nullptr || node == nullptr || equiv == nullptr) { return nullptr; } - + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } std::vector inputs = GetFusionNodeInputs(equiv); auto fusion_node = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(fusion_node); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc index 5f0b86964476197badbdc6ea089fcf869d89a929..42e37df3e499dd677e2abbca028731ca09104a39 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc @@ -108,6 +108,9 @@ bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2 const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const { + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } std::vector old_pattern_outputs; if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) { return nullptr; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc index e0389309a10dac4b803f452b66b08ed2da234086..0e3cd28a665fb2a99bd098ebbe64da0489010a27 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc @@ -88,6 +88,9 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); MS_EXCEPTION_IF_NULL(mul4); // Get add3 and match the add3 pattern diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc index 9efd5033639b6d88e6d19278d4bfbe1e0f78a921..26828f2137c150c6cbf981a251969a6c99a9dc8d 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc @@ -153,6 +153,9 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra if (func_graph == nullptr || node == nullptr || equiv == nullptr) { return nullptr; } + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } AnfNodePtr mul4 = nullptr; AnfNodePtr real_div0 = nullptr; AnfNodePtr real_div1 = nullptr; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc index 68baeeed992c055fd968f9a336e6326c65b1af7b..5065c4c5bab2ceaa4524a81c38e35043a4e18e83 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc @@ -61,6 +61,9 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto new_node = CreateLambNextRightNode(func_graph, equiv); MS_EXCEPTION_IF_NULL(new_node); // Set abstract of new node diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc index 16a43e207207f32fae2c8aed7dd6034bdd42ec30..b5b6d2bb085af34ee6471e7cc9fb6e880e2bda6f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc @@ -50,6 +50,9 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph, MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto input0 = utils::cast((*equiv)[input0_]); auto input1 = utils::cast((*equiv)[input1_]); auto input2 = utils::cast((*equiv)[input2_]); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc index 069581b6e4a391a80c0aa90352f50885c98580ea..43e18721630fd10d031da4e73aace51ae109e845 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc @@ -42,6 +42,9 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto prim = std::make_shared(kLambUpdateWithLrV2OpName); std::vector inputs = {NewValueNode(prim)}; (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs), diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc index 290fd24b59d74d44f5e047960ced8edf19d4c9bf..c59260564abf708aee316c06b3e0ad9bd777329b 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ b/mindspore/ccsrc/pre_activate/common/helper.cc @@ -765,5 +765,15 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) { MS_EXCEPTION_IF_NULL(cnode); return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr(node, attr_name); } + +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set) { + MS_EXCEPTION_IF_NULL(node); + TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0); + if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) { + return true; + } + MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString(); + return false; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index ead338d0af7abe5e244cad4334f980fe0f825fb7..59fba21d55bcec5cfe74807e343669f9b52b4cc0 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "ir/func_graph.h" #include "session/kernel_graph.h" @@ -189,6 +190,9 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2); // Get attr which is bool from cnode bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name); + +// Check node's data type is in supported data type set +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index f70527b391544aca100d9f92226d6ca7fed96c17..53afb150cafa761a5152c02432b3ad9af64e314a 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -25,6 +25,7 @@ #include #include "utils/log_adapter.h" +#include "ir/dtype/type.h" namespace mindspore { // op name. Op which not exists in operator/ops.h, so define it's name here @@ -270,6 +271,8 @@ const std::set kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFo kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; +const std::set kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; + static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { try { if (chmod(file_name.c_str(), mode) != 0) {