From 7a7e499475714703889f6ab51b90eb18f6f9e082 Mon Sep 17 00:00:00 2001 From: BowenK Date: Fri, 21 Aug 2020 20:17:36 +0800 Subject: [PATCH] Revert "Eliminate AllReduce when the input is a constant" This reverts commit f3a9fbdd788b527da46c2fb49aab8870e1569cf4. --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 2 - mindspore/ccsrc/frontend/optimizer/irpass.h | 1 - .../optimizer/irpass/special_op_eliminate.h | 53 ------------------- mindspore/ccsrc/pipeline/jit/pass.cc | 1 - 4 files changed, 57 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index f7e702766..6f6350367 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -81,8 +81,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { reset_defer_inline_ = MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); - all_reduce_const_elim_ = - MakeSubstitution(std::make_shared(), "reduce_all_const_elim", prim::kPrimAllReduce); // Env Item Eliminate env_get_item_eliminate_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index afb485ead..e703f18e7 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -50,7 +50,6 @@ class OptimizeIRPassLib { SubstitutionPtr check_bprop_eliminate_; SubstitutionPtr reset_defer_inline_; SubstitutionPtr depend_value_elim_; - SubstitutionPtr all_reduce_const_elim_; // Env Item Eliminate SubstitutionPtr env_get_item_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index a25e6a104..9ad50b8b3 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -29,8 +29,6 @@ #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass/prim_eliminate.h" #include "frontend/optimizer/optimizer.h" -#include "utils/comm_manager.h" -#include "frontend/parallel/context.h" namespace mindspore { namespace opt { @@ -205,57 +203,6 @@ class DependValueElim : public OptimizerCaller { return nullptr; } }; - -class AllReduceConstElim : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode x; - auto pattern = PPrimitive(prim::kPrimAllReduce, x); - // If AllReduce takes contant value as input and values across devices are all the same(ensured by parallel mode) - if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) && - (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) || - pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) { - auto cur_func_graph = pattern.GetFuncGraph(); - // If reduce operation is sum, then multiply constant by number of devices, otherwise just return the contant - auto prim_cnode = pattern.GetOriginalNode(); - MS_EXCEPTION_IF_NULL(prim_cnode); - auto primitive = GetCNodePrimitive(prim_cnode); - auto reduce_op = primitive->GetAttr("op"); - auto group = primitive->GetAttr("group")->ToString(); - // For sum operation, multiply constant tensor by number of devices - if (reduce_op->ToString() == "sum") { - unsigned int num_of_devices; - // Get number of devices - if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) { - MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" + group + "]"; - } - // Multiply constant by number of devices then return - std::vector mul_inputs; - auto constant_node = x.GetNode(node); - MS_EXCEPTION_IF_NULL(constant_node); - auto constant_value_node = constant_node->cast(); - MS_EXCEPTION_IF_NULL(constant_value_node); - if (!constant_value_node->value()->isa()) { - MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. Got " + - constant_value_node->value()->ToString(); - } - auto constant_tensor = constant_value_node->value()->cast(); - auto tensor_dtype = constant_tensor->Dtype(); - auto num_of_device_node = NewValueNode(std::make_shared((int64_t)num_of_devices, tensor_dtype)); - // Multiply nodes - auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional"); - MS_EXCEPTION_IF_NULL(mul_prim); - mul_inputs.push_back(NewValueNode(mul_prim)); - mul_inputs.push_back(constant_node); - mul_inputs.push_back(num_of_device_node); - return cur_func_graph->NewCNode(mul_inputs); - } else { - return x.GetNode(node); - } - } - return nullptr; - } -}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 113545491..9131f3d18 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -128,7 +128,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.incorporate_env_getitem_switch_, irpass.new_env_get_item_, irpass.depend_value_elim_, - irpass.all_reduce_const_elim_, }); opt::OptPassConfig a_after_grad = opt::OptPassConfig({ irpass.inline_without_move_, -- GitLab