From 981b013f818578d96a355f7d20d540ae8fffec4e Mon Sep 17 00:00:00 2001 From: seatea Date: Thu, 7 May 2020 18:58:56 +0800 Subject: [PATCH] Fix CSE bug for some operations like `DropoutGenMask` which should not be optimized as it will generate different values each time. --- mindspore/ccsrc/optimizer/cse.cc | 18 +++++++++++++++++- mindspore/ccsrc/optimizer/cse.h | 2 ++ mindspore/ccsrc/pybind_api/export_flags.cc | 1 + mindspore/ccsrc/pybind_api/export_flags.h | 1 + mindspore/ops/operations/nn_ops.py | 1 + .../networks/models/bert/bert_tdt_lossscale.py | 2 +- 6 files changed, 23 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc index ff7291cd8..1af08ea3e 100644 --- a/mindspore/ccsrc/optimizer/cse.cc +++ b/mindspore/ccsrc/optimizer/cse.cc @@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { return changed; } +bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { + bool has_random_effect = false; + auto prim_main = GetCNodePrimitive(main); + auto prim_node = GetCNodePrimitive(node); + if (prim_main == prim_node) { + return false; + } + if (prim_main != nullptr) { + auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); + if (effect_val != nullptr && effect_val->isa<BoolImm>()) { + has_random_effect = GetValue<bool>(effect_val); + } + } + return has_random_effect; +} + bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); @@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { break; } } - if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) { + if (CheckRandomEffect(c_main, c_node)) { appsame = false; } replace = appsame; diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h index 544e6cb6a..fd90f61ee 100644 --- a/mindspore/ccsrc/optimizer/cse.h +++ b/mindspore/ccsrc/optimizer/cse.h @@ -43,6 
+43,8 @@ class CSE { virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const; + virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; + bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; private: diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc index 931e9e17b..83392784f 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.cc +++ b/mindspore/ccsrc/pybind_api/export_flags.cc @@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; +const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h index ed68da17d..74c27ff35 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.h +++ b/mindspore/ccsrc/pybind_api/export_flags.h @@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[]; extern const char GRAPH_FLAG_HAS_EFFECT[]; extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; +extern const char GRAPH_FLAG_RANDOM_EFFECT[]; } // namespace mindspore diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a80500c0e..545e1e9f2 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1877,6 +1877,7 @@ class DropoutGenMask(Primitive): self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output']) validator.check_value_type("Seed0", Seed0, [int], self.name) validator.check_value_type("Seed1", Seed1, [int], self.name) + self.add_prim_attr("_random_effect", True) class DropoutDoMask(PrimitiveWithInfer): diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py 
b/tests/st/networks/models/bert/bert_tdt_lossscale.py index 1bd72d022..e6578af74 100644 --- a/tests/st/networks/models/bert/bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py @@ -162,7 +162,7 @@ def test_bert_tdt(): # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) - expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661] + expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, 12.40294, 12.621653] print("loss value: {}".format(loss_value)) assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001) -- GitLab