From 8463731bcc9bd34d3e6c8ab6615a2a4d6cba9777 Mon Sep 17 00:00:00 2001 From: huanghui Date: Thu, 18 Jun 2020 19:34:15 +0800 Subject: [PATCH] make those AdamXX and LambXX fusion pass not work for unexpect data type --- .../ascend/ir_fusion/adam_apply_one_fusion.cc | 3 +++ .../ascend/ir_fusion/adam_apply_one_with_decay_rule.cc | 4 +++- .../pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc | 3 +++ .../ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc | 3 +++ .../ir_fusion/lamb_next_mv_with_decay_v1_rule.cc | 3 +++ .../ascend/ir_fusion/lamb_next_right_rule.cc | 3 +++ .../ir_fusion/lamb_update_with_lr_rule_fusion.cc | 3 +++ .../ascend/ir_fusion/lamb_update_with_lr_v2.cc | 3 +++ mindspore/ccsrc/pre_activate/common/helper.cc | 10 ++++++++++ mindspore/ccsrc/pre_activate/common/helper.h | 4 ++++ mindspore/ccsrc/utils/utils.h | 3 +++ 11 files changed, 41 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc index 464516719..59be003b1 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc @@ -109,6 +109,9 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto new_node = CreateAdamApplyOneNode(func_graph, equiv); MS_EXCEPTION_IF_NULL(new_node); new_node->set_scope(node->scope()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc index 7dc13ee7a..f6077c95f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc @@ -146,7 +146,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c if (graph == nullptr || node == nullptr || equiv == nullptr) { return nullptr; } - + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } std::vector inputs = GetFusionNodeInputs(equiv); auto fusion_node = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(fusion_node); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc index 5f0b86964..42e37df3e 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc @@ -108,6 +108,9 @@ bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2 const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const { + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } std::vector old_pattern_outputs; if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) { return nullptr; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc index e0389309a..0e3cd28a6 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc @@ -88,6 +88,9 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); MS_EXCEPTION_IF_NULL(mul4); // Get add3 and match the add3 pattern diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc index 9efd50336..26828f213 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc @@ -153,6 +153,9 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra if (func_graph == nullptr || node == nullptr || equiv == nullptr) { return nullptr; } + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } AnfNodePtr mul4 = nullptr; AnfNodePtr real_div0 = nullptr; AnfNodePtr real_div1 = nullptr; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc index 68baeeed9..5065c4c5b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc @@ -61,6 +61,9 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto new_node = CreateLambNextRightNode(func_graph, equiv); MS_EXCEPTION_IF_NULL(new_node); // Set abstract of new node diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc index 16a43e207..b5b6d2bb0 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc @@ -50,6 +50,9 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph, MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto input0 = utils::cast((*equiv)[input0_]); auto input1 = utils::cast((*equiv)[input1_]); auto input2 = utils::cast((*equiv)[input2_]); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc index 069581b6e..43e187216 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc @@ -42,6 +42,9 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } auto prim = std::make_shared(kLambUpdateWithLrV2OpName); std::vector inputs = {NewValueNode(prim)}; (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs), diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc index 290fd24b5..c59260564 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ b/mindspore/ccsrc/pre_activate/common/helper.cc @@ -765,5 +765,15 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) { MS_EXCEPTION_IF_NULL(cnode); return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr(node, attr_name); } + +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set) { + MS_EXCEPTION_IF_NULL(node); + TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0); + if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) { + return true; + } + MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString(); + return false; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index ead338d0a..59fba21d5 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "ir/func_graph.h" #include "session/kernel_graph.h" @@ -189,6 +190,9 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2); // Get attr which is bool from cnode bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name); + +// Check node's data type is in supported data type set +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index f70527b39..53afb150c 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -25,6 +25,7 @@ #include #include "utils/log_adapter.h" +#include "ir/dtype/type.h" namespace mindspore { // op name. Op which not exists in operator/ops.h, so define it's name here @@ -270,6 +271,8 @@ const std::set kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFo kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; +const std::set kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; + static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { try { if (chmod(file_name.c_str(), mode) != 0) { -- GitLab