diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 44aba2c3990595cf323681dfc4e705de9e7f3f12..c1bc8ec638a83b2f0fde37ef1c8b7fb45ba061f8 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -16,7 +16,6 @@
 #include "backend/optimizer/ascend/ascend_backend_optimization.h"
 #include
 #include
-#include
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/ascend/ir_fission/bn_split.h"
 #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
@@ -24,6 +23,7 @@
 #include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h"
 #include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h"
 #include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h"
+#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h"
 #include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
 #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
 #include "backend/optimizer/pass/communication_op_fusion.h"
@@ -111,18 +111,9 @@
 namespace mindspore {
 namespace opt {
 namespace {
-void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
+void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
   MS_EXCEPTION_IF_NULL(ir_fusion_pm);
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
@@ -133,10 +124,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
@@ -146,6 +133,27 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+}
+
+void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
+  MS_EXCEPTION_IF_NULL(ir_fusion_pm);
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
@@ -153,15 +161,12 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>());
 }
 }  // namespace
@@ -265,9 +270,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
     ir_fusion_pm->AddPass(std::make_shared());
   }
   ir_fusion_pm->AddPass(std::make_shared());
-  if (context_ptr->ir_fusion_flag()) {
-    AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
-  }
+  AddAscendIRFusionRulesPass(ir_fusion_pm.get());
+  AddAscendIRFusionPass(ir_fusion_pm.get());
   if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
     ir_fusion_pm->AddPass(std::make_shared());
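
Note on the hunks above: the optional IR-fusion passes used to run only when `context_ptr->ir_fusion_flag()` was set; they are now split into `AddAscendIRFusionRulesPass` and `AddAscendIRFusionPass` and registered unconditionally, and the new `ReduceMinFission` pass is registered alongside the other fission passes. For readers unfamiliar with the registration pattern, a minimal self-contained sketch follows; the `Pass` and `PassManager` types here are hypothetical stand-ins for illustration, not the real MindSpore classes.

#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for an optimizer pass with a name.
struct Pass {
  virtual ~Pass() = default;
  virtual const char *name() const = 0;
};
struct ReduceMinFission : Pass {
  const char *name() const override { return "reduce_min_fission"; }
};

// Hypothetical stand-in for the pass manager: collects passes, runs them in order.
class PassManager {
 public:
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
  void Run() const {
    for (const auto &p : passes_) std::cout << "running pass: " << p->name() << '\n';
  }

 private:
  std::vector<std::shared_ptr<Pass>> passes_;
};

int main() {
  PassManager ir_fusion_pm;
  // Mirrors ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>()) in the diff.
  ir_fusion_pm.AddPass(std::make_shared<ReduceMinFission>());
  ir_fusion_pm.Run();
}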
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2857e5e2b0af20294fb16015887fbfad2495ece7
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h"
+#include <memory>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &old_node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(old_node);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
+  CNodePtr reduce_min = graph->NewCNode(inputs);
+  reduce_min->set_scope(old_node->scope());
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
+  return reduce_min;
+}
+
+bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) {
+  if (dtype != kNumberTypeFloat32) {
+    MS_LOG(INFO) << "ReduceMin's input dtype is not float32, no need to optimize!";
+    return false;
+  }
+  if (shape.size() == 0 || shape.size() == 1) {
+    MS_LOG(INFO) << "ReduceMin's input shape size is " << shape.size() << ", no need to optimize!";
+    return false;
+  }
+  if (axis.size() == 1) {
+    MS_LOG(INFO) << "ReduceMin's axis size is 1, no need to optimize!";
+    return false;
+  }
+  int last_dim = SizeToInt(shape.size() - 1);
+  if (std::find(axis.begin(), axis.end(), -1) == axis.end() &&
+      std::find(axis.begin(), axis.end(), last_dim) == axis.end()) {
+    MS_LOG(INFO) << "The axis attribute does not contain the last axis, not matched!";
+    return false;
+  }
+  return true;
+}
+
+std::vector<int> CalFirstAxis(const std::vector<size_t> &shape, const std::vector<int> &axis) {
+  std::vector<int> axis_first;
+  int last_dim = SizeToInt(shape.size() - 1);
+  std::copy_if(axis.begin(), axis.end(), std::back_inserter(axis_first),
+               [&last_dim](int v) { return v != -1 && v != last_dim; });
+
+  int dim_size = SizeToInt(shape.size());
+  if (axis_first.empty()) {
+    for (int i = 0; i < dim_size - 1; ++i) {
+      axis_first.push_back(i);
+    }
+  }
+
+  for (size_t i = 0; i < axis_first.size(); ++i) {
+    if (axis_first[i] < -dim_size || axis_first[i] > dim_size - 1) {
+      MS_LOG(EXCEPTION) << "Verification of the ReduceMin axis failed, quit optimizing!";
+    }
+    if (axis_first[i] < 0) {
+      axis_first[i] = dim_size + axis_first[i];
+    }
+  }
+  return axis_first;
+}
+
+std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::vector<int> &axis_first,
+                                  bool keep_dims) {
+  std::vector<size_t> shape_first;
+  for (size_t item = 0; item < shape.size(); ++item) {
+    if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), SizeToInt(item))) {
+      if (keep_dims) {
+        // If keep_dims is true, the current dimension is reduced to 1.
+        shape_first.push_back(1);
+      }
+    } else {
+      // item is not in axis_first, so this dimension is kept.
+      shape_first.push_back(shape[item]);
+    }
+  }
+  return shape_first;
+}
+}  // namespace
+
+const BaseRef ReduceMinFission::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  return VectorRef({prim::kPrimReduceMin, X});
+}
+
+const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+  if (graph == nullptr || node == nullptr) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  CheckCNodeInputSize(cnode, 2);
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
+  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0);
+  if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) {
+    MS_LOG(INFO) << "ReduceMin has no axis attribute, no need to optimize!";
+    return nullptr;
+  }
+  auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis);
+  if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
+    MS_LOG(INFO) << "ReduceMin has no keep_dims attribute, no need to optimize!";
+    return nullptr;
+  }
+  auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims);
+
+  if (!NeedOptimize(dtype, shape, axis)) {
+    MS_LOG(INFO) << "No need to optimize this ReduceMin: " << cnode->DebugString();
+    return nullptr;
+  }
+
+  // Create the first ReduceMin, which reduces all target axes except the last one.
+  CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode);
+  std::vector<int> axis_first = CalFirstAxis(shape, axis);
+  std::vector<size_t> shape_first = GetInferShape(shape, axis_first, keep_dims);
+  AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1);
+
+  // Create the second ReduceMin, which reduces the remaining last axis.
+  CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode);
+  reduce_min2->set_abstract(cnode->abstract());
+  std::vector<int> axis_last = {-1};
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_last), reduce_min2);
+  return reduce_min2;
+}
+}  // namespace opt
+}  // namespace mindspore
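
The fission hinges on two helpers: `CalFirstAxis` assigns every requested axis except the last tensor dimension to the first ReduceMin (falling back to all leading axes when only the last axis was requested), and `GetInferShape` derives the intermediate output shape. The stand-alone sketch below replays that logic outside the framework so it can be checked by hand; it assumes nothing from MindSpore, and `SizeToInt` is replaced by a plain cast. For the test's shape (32, 32, 32, 32) with axis (2, 3) and keep_dims=false, the first ReduceMin reduces axis {2} to shape (32, 32, 32), leaving axis -1 for the second, which yields (32, 32).

#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

// Replays CalFirstAxis from the pass above: keep every requested axis except
// the last tensor dimension (written either as -1 or as an explicit index).
std::vector<int> CalFirstAxis(const std::vector<size_t> &shape, const std::vector<int> &axis) {
  std::vector<int> axis_first;
  int last_dim = static_cast<int>(shape.size()) - 1;
  std::copy_if(axis.begin(), axis.end(), std::back_inserter(axis_first),
               [&last_dim](int v) { return v != -1 && v != last_dim; });
  if (axis_first.empty()) {
    // Only the last axis was requested, so the first ReduceMin covers all leading axes.
    for (int i = 0; i < last_dim; ++i) axis_first.push_back(i);
  }
  return axis_first;
}

// Replays GetInferShape: drop every reduced dimension (or keep it as 1).
std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::vector<int> &axis_first,
                                  bool keep_dims) {
  std::vector<size_t> out;
  for (size_t i = 0; i < shape.size(); ++i) {
    bool reduced = std::find(axis_first.begin(), axis_first.end(), static_cast<int>(i)) != axis_first.end();
    if (!reduced) {
      out.push_back(shape[i]);
    } else if (keep_dims) {
      out.push_back(1);
    }
  }
  return out;
}

int main() {
  std::vector<size_t> shape{32, 32, 32, 32};
  std::vector<int> axis{2, 3};
  auto axis_first = CalFirstAxis(shape, axis);               // -> {2}
  auto shape_first = GetInferShape(shape, axis_first, false);  // -> {32, 32, 32}
  for (int a : axis_first) std::cout << "first axis: " << a << '\n';
  for (size_t d : shape_first) std::cout << d << ' ';
  std::cout << '\n';
}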
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.h
new file mode 100644
index 0000000000000000000000000000000000000000..66976cb0b53f7a3b157206e3264fb8dd5d4cadc2
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+class ReduceMinFission : public PatternProcessPass {
+ public:
+  explicit ReduceMinFission(bool multigraph = true) : PatternProcessPass("reduce_min_fission", multigraph) {}
+  ~ReduceMinFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/reduce_min_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/reduce_min_fission_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e1cec41c96682cee4c5192e1730b43d61965b08a
--- /dev/null
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/reduce_min_fission_test.cc
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/backend_common_test.h"
+#include "common/py_func_graph_fetcher.h"
+#include "debug/anf_ir_dump.h"
+#define private public
+#define protected public
+#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h"
+#undef private
+#undef protected
+
+namespace mindspore {
+namespace opt {
+class TestHWOptReduceMinFission : public BackendCommon {
+ public:
+  TestHWOptReduceMinFission() : get_py_fun_("gtest_input.pre_activate.reduce_min_fission_test", true) {}
+  ~TestHWOptReduceMinFission() override = default;
+
+  UT::PyFuncGraphFetcher get_py_fun_;
+};
+
+TEST_F(TestHWOptReduceMinFission, test_fission) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reduce_min_fission", "before");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp{32, 32, 32, 32};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  args_spec_list.push_back(x_abstract);
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  auto reduce_min_fission = std::make_shared<opt::ReduceMinFission>();
+  pm->AddPass(reduce_min_fission);
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reduce_min_fission", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/reduce_min_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/reduce_min_fission_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7690023e011735d164794552baf19692008a711f
--- /dev/null
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/reduce_min_fission_test.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from mindspore.ops import Primitive
+from mindspore.ops import operations as P
+
+make_tuple = Primitive('make_tuple')
+tuple_getitem = Primitive('tuple_getitem')
+reduce_min = P.ReduceMin(keep_dims=False)
+reduce_min1 = Primitive('ReduceMin')
+reduce_min2 = Primitive('ReduceMin')
+
+
+class FnDict:
+    def __init__(self):
+        self.fnDict = {}
+
+    def __call__(self, fn):
+        self.fnDict[fn.__name__] = fn
+
+    def __getitem__(self, name):
+        return self.fnDict[name]
+
+
+def test_reduce_min_fission(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x):
+        res = reduce_min(x, (2, 3))
+        return res
+
+    @fns
+    def after(x):
+        res = reduce_min1(x)
+        res = reduce_min2(res)
+        return make_tuple(res)
+
+    return fns[tag]
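
One way to convince yourself of the rewrite the Python fixture encodes (a single ReduceMin over axes (2, 3) versus a chain of two ReduceMins) is the identity that a minimum over a set of axes equals a minimum over part of the axes followed by a minimum over the rest. Below is a self-contained numerical check of that identity, independent of MindSpore, with a small 4x4x4x4 tensor standing in for the 32x32x32x32 test input:

#include <algorithm>
#include <array>
#include <iostream>
#include <limits>
#include <random>

int main() {
  constexpr int N = 4;  // small stand-in for the 32x32x32x32 test shape
  std::array<float, N * N * N * N> x{};
  std::mt19937 gen(42);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto &v : x) v = dist(gen);

  auto at = [&](int a, int b, int c, int d) { return x[((a * N + b) * N + c) * N + d]; };

  for (int a = 0; a < N; ++a) {
    for (int b = 0; b < N; ++b) {
      // One step: min over axes (2, 3), as in the "before" graph.
      float one_step = std::numeric_limits<float>::infinity();
      for (int c = 0; c < N; ++c)
        for (int d = 0; d < N; ++d) one_step = std::min(one_step, at(a, b, c, d));

      // Two steps: first reduce axis 2, then the last remaining axis,
      // mirroring the reduce_min1 -> reduce_min2 chain in the "after" graph.
      float two_step = std::numeric_limits<float>::infinity();
      for (int d = 0; d < N; ++d) {
        float min_c = std::numeric_limits<float>::infinity();
        for (int c = 0; c < N; ++c) min_c = std::min(min_c, at(a, b, c, d));
        two_step = std::min(two_step, min_c);
      }
      if (one_step != two_step) {
        std::cout << "mismatch\n";
        return 1;
      }
    }
  }
  std::cout << "one-step and two-step ReduceMin agree\n";
}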