提交 8463731b 编写于 作者: H huanghui

make those AdamXX and LambXX fusion pass not work for unexpect data type

上级 ef698a93
......@@ -109,6 +109,9 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto new_node = CreateAdamApplyOneNode(func_graph, equiv);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(node->scope());
......
......@@ -146,7 +146,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
if (graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv);
auto fusion_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fusion_node);
......
......@@ -108,6 +108,9 @@ bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2
const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
std::vector<AnfNodePtr> old_pattern_outputs;
if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
return nullptr;
......
......@@ -88,6 +88,9 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_);
MS_EXCEPTION_IF_NULL(mul4);
// Get add3 and match the add3 pattern
......
......@@ -153,6 +153,9 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
AnfNodePtr mul4 = nullptr;
AnfNodePtr real_div0 = nullptr;
AnfNodePtr real_div1 = nullptr;
......
......@@ -61,6 +61,9 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto new_node = CreateLambNextRightNode(func_graph, equiv);
MS_EXCEPTION_IF_NULL(new_node);
// Set abstract of new node
......
......@@ -50,6 +50,9 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
auto input2 = utils::cast<AnfNodePtr>((*equiv)[input2_]);
......
......@@ -42,6 +42,9 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kLambUpdateWithLrV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
(void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs),
......
......@@ -765,5 +765,15 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
MS_EXCEPTION_IF_NULL(cnode);
return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
}
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
MS_EXCEPTION_IF_NULL(node);
TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
return true;
}
MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
return false;
}
} // namespace opt
} // namespace mindspore
......@@ -20,6 +20,7 @@
#include <memory>
#include <utility>
#include <string>
#include <set>
#include <unordered_set>
#include "ir/func_graph.h"
#include "session/kernel_graph.h"
......@@ -189,6 +190,9 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);
// Get attr which is bool from cnode
bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
// Check node's data type is in supported data type set
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
......@@ -25,6 +25,7 @@
#include <set>
#include "utils/log_adapter.h"
#include "ir/dtype/type.h"
namespace mindspore {
// op name. Op which not exists in operator/ops.h, so define it's name here
......@@ -270,6 +271,8 @@ const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFo
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04};
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
try {
if (chmod(file_name.c_str(), mode) != 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册