提交 78a29bba 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!972 Fix CSE bug for some operations like `DropoutGenMask` which may have random effect

Merge pull request !972 from seatea/fix-cse-bug-for-random-effect
......@@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return changed;
}
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool has_random_effect = false;
auto prim_main = GetCNodePrimitive(main);
auto prim_node = GetCNodePrimitive(node);
if (prim_main == prim_node) {
return false;
}
if (prim_main != nullptr) {
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
has_random_effect = GetValue<bool>(effect_val);
}
}
return has_random_effect;
}
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
......@@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
break;
}
}
if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) {
if (CheckRandomEffect(c_main, c_node)) {
appsame = false;
}
replace = appsame;
......
......@@ -43,6 +43,8 @@ class CSE {
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
private:
......
......@@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
} // namespace mindspore
......@@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
} // namespace mindspore
......
......@@ -1872,6 +1872,7 @@ class DropoutGenMask(Primitive):
self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
validator.check_value_type("Seed0", Seed0, [int], self.name)
validator.check_value_type("Seed1", Seed1, [int], self.name)
self.add_prim_attr("_random_effect", True)
class DropoutDoMask(PrimitiveWithInfer):
......
......@@ -162,7 +162,7 @@ def test_bert_tdt():
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, 12.40294, 12.621653]
print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册