diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc index ff7291cd86b4a534df33fa8493036a2b2d560163..1af08ea3e127f781781f798946f3b9a8c52ee1be 100644 --- a/mindspore/ccsrc/optimizer/cse.cc +++ b/mindspore/ccsrc/optimizer/cse.cc @@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { return changed; } +bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { + bool has_random_effect = false; + auto prim_main = GetCNodePrimitive(main); + auto prim_node = GetCNodePrimitive(node); + if (prim_main == prim_node) { + return false; + } + if (prim_main != nullptr) { + auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); + if (effect_val != nullptr && effect_val->isa<BoolImm>()) { + has_random_effect = GetValue<bool>(effect_val); + } + } + return has_random_effect; +} + bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); @@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { break; } } - if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) { + if (CheckRandomEffect(c_main, c_node)) { appsame = false; } replace = appsame; diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h index 544e6cb6a36f7c14dc99b781d1258522ea043c68..fd90f61eebc634258f82ca861cf329466d91211a 100644 --- a/mindspore/ccsrc/optimizer/cse.h +++ b/mindspore/ccsrc/optimizer/cse.h @@ -43,6 +43,8 @@ class CSE { virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const; + virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; + bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; private: diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc index 931e9e17b11fbb2708e208c3aa2537e827aa3764..83392784f3e1d059f52b09cd0b37e522087a282d 100644 --- 
a/mindspore/ccsrc/pybind_api/export_flags.cc +++ b/mindspore/ccsrc/pybind_api/export_flags.cc @@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; +const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h index ed68da17dad8f7231af679f6330dd3053d55c877..74c27ff35d20cecf6227db8fa4cb48be6d6c4213 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.h +++ b/mindspore/ccsrc/pybind_api/export_flags.h @@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[]; extern const char GRAPH_FLAG_HAS_EFFECT[]; extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; +extern const char GRAPH_FLAG_RANDOM_EFFECT[]; } // namespace mindspore diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a80500c0e6414fc649eb3794e05914aea53b2200..545e1e9f2da669fe0bcd5bf49a2fd3737152e7b2 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1877,6 +1877,7 @@ class DropoutGenMask(Primitive): self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output']) validator.check_value_type("Seed0", Seed0, [int], self.name) validator.check_value_type("Seed1", Seed1, [int], self.name) + self.add_prim_attr("_random_effect", True) class DropoutDoMask(PrimitiveWithInfer): diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_tdt_lossscale.py index 1bd72d0221746514561bbdacc4e1abf7614841b3..e6578af74912bc0476f8c310e3490b41ccd03dac 100644 --- a/tests/st/networks/models/bert/bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py @@ -162,7 +162,7 @@ def test_bert_tdt(): # assertion occurs 
while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) - expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661] + expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, 12.40294, 12.621653] print("loss value: {}".format(loss_value)) assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)