diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b7f732a6a29934b4a04801689e1a47085fc98ff --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include +#include +#include +#include +#include "session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "pipeline/static_analysis/abstract_value.h" +#include "pre_activate/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kConfusionMulGradOutputNum = 2; + +CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf, + const AnfNodePtr &input3) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(reduce_sum); + MS_EXCEPTION_IF_NULL(mul0_anf); + MS_EXCEPTION_IF_NULL(input3); + auto mul0 = mul0_anf->cast(); + MS_EXCEPTION_IF_NULL(mul0); + + auto prim = std::make_shared(kConfusionMulGradOpName); + std::vector inputs = {NewValueNode(prim), mul0->input(1), mul0->input(2), input3}; + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(reduce_sum->scope()); + AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); + auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + return fusion_node; +} + +AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const AnfNodePtr &mul1) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input2); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(input2) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + + AnfNodePtr mul0 = nullptr; + const AnfNodeIndexSet &outputs_set = manager->node_users()[input2]; + // input2 must be the 2rd input of mul0 + auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul1](const std::pair &node_index) { + return node_index.first != mul1 && node_index.second == 2; + }); + if (it != outputs_set.end() && AnfAlgo::GetCNodeName(it->first) == prim::kPrimMul->name()) { + mul0 = it->first; + } + return mul0; +} +} // namespace + +const BaseRef ConfusionMulGradFusion::DefinePattern() const { + VectorRef mul1({prim::kPrimMul, input3_, input2_}); + VectorRef reduce_sum({prim::kPrimReduceSum, mul1}); + return reduce_sum; +} + +const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto input2 = utils::cast((*equiv)[input2_]); + auto input3 = utils::cast((*equiv)[input3_]); + auto reduce_sum = node->cast(); + MS_EXCEPTION_IF_NULL(reduce_sum); + auto mul1 = reduce_sum->input(1); + if (IsUsedByOthers(graph, mul1)) { + MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; + return nullptr; + } + auto mul0 = GetMul0(graph, input2, mul1); + if (mul0 == nullptr) { + MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; + return nullptr; + } + + auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3); + std::vector fusion_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, fusion_node, kConfusionMulGradOutputNum, &fusion_node_outputs); + + auto manage = graph->manager(); + MS_EXCEPTION_IF_NULL(manage); + manage->Replace(mul0, fusion_node_outputs[0]); + return fusion_node_outputs[1]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..170df5b0e429982fb3225308bd75ed074d608151 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ + +#include +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConfusionMulGradFusion : public PatternProcessPass { + public: + explicit ConfusionMulGradFusion(bool multigraph = true) + : PatternProcessPass("confusion_mul_grad_fusion", multigraph) { + input2_ = std::make_shared(); + input3_ = std::make_shared(); + } + ~ConfusionMulGradFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input2_; + VarPtr input3_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 2b35168ec350a58d3b33e08c9406913ddacafc74..39b4b7a16001400e492a7cf2fe3a421bb11eeac1 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -111,6 +111,7 @@ constexpr auto kFusedMulAddOpName = "FusedMulAdd"; constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum"; constexpr auto kBiasAddOpName = "BiasAdd"; +constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; // attr key name constexpr auto kAttrInputNames = "input_names"; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3bf09d2cbd594face80e492367c386ae1199b1a --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "pre_activate/common/optimizer.h" +#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +class TestHWOptimizeConfusionMulGradFusion : public BackendCommon { + public: + TestHWOptimizeConfusionMulGradFusion() : get_py_fun_("gtest_input.pre_activate.confusion_mul_grad_fusion", true) {} + ~TestHWOptimizeConfusionMulGradFusion() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{1, 1, 1, 1}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 3; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_mul_grad_fusion.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_mul_grad_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f7bcc9960ad77dafa268cf6d8b71e83ac590ed --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_mul_grad_fusion.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore.ops import operations as P +from mindspore.ops import Primitive + +mul = P.Mul() +reduce_sum = P.ReduceSum() +confusion_mul_grad = Primitive('ConfusionMulGrad') +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') +axis = 2 + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + +def test_confusion_mul_grad_fusion(tag): + fns = FnDict() + + @fns + def before(input1, input2, input3): + output1 = mul(input1, input2) + mul1 = mul(input3, input2) + # input axis will be convert to attr in step ConstructKernelGraph + output2 = reduce_sum(mul1, axis) + res = make_tuple(output1, output2) + return res + + @fns + def after(input1, input2, input3): + res = confusion_mul_grad(input1, input2, input3) + item0 = tuple_getitem(res, 0) + item1 = tuple_getitem(res, 1) + res = make_tuple(item0, item1) + return make_tuple(res) + + return fns[tag]