提交 981b013f 编写于 作者: S seatea

Fix CSE bug for some operations like `DropoutGenMask` which should not

be optimized as it will generate different values each time.
上级 0040764d
...@@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { ...@@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return changed; return changed;
} }
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool has_random_effect = false;
auto prim_main = GetCNodePrimitive(main);
auto prim_node = GetCNodePrimitive(node);
if (prim_main == prim_node) {
return false;
}
if (prim_main != nullptr) {
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
has_random_effect = GetValue<bool>(effect_val);
}
}
return has_random_effect;
}
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
...@@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { ...@@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
break; break;
} }
} }
if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) { if (CheckRandomEffect(c_main, c_node)) {
appsame = false; appsame = false;
} }
replace = appsame; replace = appsame;
......
...@@ -43,6 +43,8 @@ class CSE { ...@@ -43,6 +43,8 @@ class CSE {
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const; virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
private: private:
......
...@@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; ...@@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
} // namespace mindspore } // namespace mindspore
...@@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; ...@@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[]; extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern const char GRAPH_FLAG_HAS_EFFECT[]; extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
} // namespace mindspore } // namespace mindspore
......
...@@ -1877,6 +1877,7 @@ class DropoutGenMask(Primitive): ...@@ -1877,6 +1877,7 @@ class DropoutGenMask(Primitive):
self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output']) self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
validator.check_value_type("Seed0", Seed0, [int], self.name) validator.check_value_type("Seed0", Seed0, [int], self.name)
validator.check_value_type("Seed1", Seed1, [int], self.name) validator.check_value_type("Seed1", Seed1, [int], self.name)
self.add_prim_attr("_random_effect", True)
class DropoutDoMask(PrimitiveWithInfer): class DropoutDoMask(PrimitiveWithInfer):
......
...@@ -162,7 +162,7 @@ def test_bert_tdt(): ...@@ -162,7 +162,7 @@ def test_bert_tdt():
# assertion occurs while the loss value, overflow state or loss_scale value is wrong # assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list) loss_value = np.array(callback.loss_list)
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661] expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, 12.40294, 12.621653]
print("loss value: {}".format(loss_value)) print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001) assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册