diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index f7e7027664d2d2e067c5e9fa96e2edd3a18b2a99..6f63503678b5f246130929bc2f22bfe69ffff72d 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -81,8 +81,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   reset_defer_inline_ =
     MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
   depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
-  all_reduce_const_elim_ =
-    MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
 
   // Env Item Eliminate
   env_get_item_eliminate_ =
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index afb485ead89ef17ebe1d90129b89c978614cad81..e703f18e7d7ef1793e23c9cf003151080c187f84 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -50,7 +50,6 @@ class OptimizeIRPassLib {
   SubstitutionPtr check_bprop_eliminate_;
   SubstitutionPtr reset_defer_inline_;
   SubstitutionPtr depend_value_elim_;
-  SubstitutionPtr all_reduce_const_elim_;
 
   // Env Item Eliminate
   SubstitutionPtr env_get_item_eliminate_;
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
index a25e6a104bd08edb052daf1adac7111e6931443e..9ad50b8b33f63e983f0e6fbcb5b095cd33b9fe79 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
@@ -29,8 +29,6 @@
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/irpass/prim_eliminate.h"
 #include "frontend/optimizer/optimizer.h"
-#include "utils/comm_manager.h"
-#include "frontend/parallel/context.h"
 
 namespace mindspore {
 namespace opt {
@@ -205,57 +203,6 @@ class DependValueElim : public OptimizerCaller {
     return nullptr;
   }
 };
-
-class AllReduceConstElim : public OptimizerCaller {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    PatternNode<AnfNodePtr> x;
-    auto pattern = PPrimitive(prim::kPrimAllReduce, x);
-    // If AllReduce takes constant value as input and values across devices are all the same (ensured by parallel mode)
-    if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) &&
-        (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) ||
-         pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) {
-      auto cur_func_graph = pattern.GetFuncGraph();
-      // If reduce operation is sum, then multiply constant by number of devices, otherwise just return the constant
-      auto prim_cnode = pattern.GetOriginalNode();
-      MS_EXCEPTION_IF_NULL(prim_cnode);
-      auto primitive = GetCNodePrimitive(prim_cnode);
-      auto reduce_op = primitive->GetAttr("op");
-      auto group = primitive->GetAttr("group")->ToString();
-      // For sum operation, multiply constant tensor by number of devices
-      if (reduce_op->ToString() == "sum") {
-        unsigned int num_of_devices;
-        // Get number of devices
-        if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) {
-          MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" + group + "]";
-        }
-        // Multiply constant by number of devices then return
-        std::vector<AnfNodePtr> mul_inputs;
-        auto constant_node = x.GetNode(node);
-        MS_EXCEPTION_IF_NULL(constant_node);
-        auto constant_value_node = constant_node->cast<ValueNodePtr>();
-        MS_EXCEPTION_IF_NULL(constant_value_node);
-        if (!constant_value_node->value()->isa<tensor::Tensor>()) {
-          MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. Got " +
-                                 constant_value_node->value()->ToString();
-        }
-        auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>();
-        auto tensor_dtype = constant_tensor->Dtype();
-        auto num_of_device_node = NewValueNode(std::make_shared<tensor::Tensor>((int64_t)num_of_devices, tensor_dtype));
-        // Multiply nodes
-        auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional");
-        MS_EXCEPTION_IF_NULL(mul_prim);
-        mul_inputs.push_back(NewValueNode(mul_prim));
-        mul_inputs.push_back(constant_node);
-        mul_inputs.push_back(num_of_device_node);
-        return cur_func_graph->NewCNode(mul_inputs);
-      } else {
-        return x.GetNode(node);
-      }
-    }
-    return nullptr;
-  }
-};
 } // namespace irpass
 } // namespace opt
 } // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 113545491f8a4c072ea63f8d10e342f21d61e379..9131f3d1858e0e13294d348fc3e35a55afacee01 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -128,7 +128,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.incorporate_env_getitem_switch_,
     irpass.new_env_get_item_,
     irpass.depend_value_elim_,
-    irpass.all_reduce_const_elim_,
   });
   opt::OptPassConfig a_after_grad = opt::OptPassConfig({
     irpass.inline_without_move_,
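For reference, the AllReduceConstElim pass removed above folded an AllReduce whose input is a constant: under AUTO_PARALLEL/SEMI_AUTO_PARALLEL every rank holds the same constant, so an AllReduce with op "sum" over a group of N devices equals the constant multiplied by N, while other reduce ops leave the constant unchanged. A minimal standalone sketch of that arithmetic, independent of the MindSpore IR APIs (the helper names simulate_all_reduce_sum and folded_constant are illustrative only, not part of the codebase):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Result of AllReduce(op="sum") when every rank contributes the same constant.
    static int64_t simulate_all_reduce_sum(const std::vector<int64_t> &per_rank_values) {
      int64_t sum = 0;
      for (int64_t v : per_rank_values) sum += v;
      return sum;
    }

    // Constant the eliminated pass substituted for the AllReduce node: constant * group size.
    static int64_t folded_constant(int64_t constant, uint32_t num_of_devices) {
      return constant * static_cast<int64_t>(num_of_devices);
    }

    int main() {
      const uint32_t num_of_devices = 8;
      const int64_t constant = 3;
      // Under (semi-)auto parallel the constant input is replicated on every rank.
      std::vector<int64_t> per_rank(num_of_devices, constant);
      assert(simulate_all_reduce_sum(per_rank) == folded_constant(constant, num_of_devices));
      return 0;
    }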