diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index ff864401b13a528328c7e31009493eed7551dbcf..01387b15da778cd1d0cc99882adfcaebbc8f5388 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -23,6 +23,7 @@
 #include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h"
 #include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h"
 #include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h"
+#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h"
 #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
 #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
 #include "pre_activate/pass/communication_op_fusion.h"
@@ -149,6 +150,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
   ir_fusion_pm->AddPass(std::make_shared());
 }
 }  // namespace
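Note: registering the pass in AddAscendBackendOptionalIRFusion is all that is needed for it to run during the Ascend backend's IR-fusion phase. (The surrounding context lines lost their template arguments during extraction; only the added TensorScatterUpdateFission registration is restored above.) As a rough mental model only -- this is not MindSpore's actual PassManager API, and every name below is invented for illustration -- a pass manager holds a list of node-rewrite callbacks and substitutes a node whenever a callback returns a replacement:

// Conceptual sketch only; all names are invented for illustration.
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct SimpleNode {
  std::string op;                                   // operator name, e.g. "TensorScatterUpdate"
  std::vector<std::shared_ptr<SimpleNode>> inputs;  // operands of the node
};
using SimpleNodePtr = std::shared_ptr<SimpleNode>;
// A rewrite pass returns a replacement node, or nullptr to keep the original.
using RewritePass = std::function<SimpleNodePtr(const SimpleNodePtr &)>;

class SimplePassManager {
 public:
  void AddPass(RewritePass pass) { passes_.push_back(std::move(pass)); }
  // Apply every registered pass to every node once; the real optimizer also
  // walks nested graphs and repeats until no pass changes anything.
  void Run(std::vector<SimpleNodePtr> *graph) const {
    for (const auto &pass : passes_) {
      for (auto &node : *graph) {
        if (auto replaced = pass(node)) {
          node = replaced;
        }
      }
    }
  }

 private:
  std::vector<RewritePass> passes_;
};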
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6e6cea5ae555f851eb9199ad871a29e00f3cf22b
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h"
+#include <memory>
+#include <vector>
+#include "session/anf_runtime_algorithm.h"
+#include "pre_activate/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(tensor_scatter_update);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kTensorMoveOpName)),
+                                    tensor_scatter_update->input(1)};
+  auto tensor_move = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(tensor_move);
+  tensor_move->set_scope(tensor_scatter_update->scope());
+  tensor_move->set_abstract(tensor_scatter_update->abstract());
+  AnfAlgo::SetNodeAttr(kAttrUseLocking, MakeValue(false), tensor_move);
+  return tensor_move;
+}
+
+CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update,
+                               const CNodePtr &tensor_move) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(tensor_scatter_update);
+  MS_EXCEPTION_IF_NULL(tensor_move);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kScatterNdUpdateOpName)), tensor_move,
+                                    tensor_scatter_update->input(2), tensor_scatter_update->input(3)};
+  auto scatter_nd_update = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(scatter_nd_update);
+  scatter_nd_update->set_scope(tensor_scatter_update->scope());
+  scatter_nd_update->set_abstract(tensor_scatter_update->abstract());
+  return scatter_nd_update;
+}
+}  // namespace
+
+const BaseRef TensorScatterUpdateFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  auto prim = std::make_shared<Primitive>(kTensorScatterUpdateOpName);
+  return VectorRef({prim, Xs});
+}
+
+const AnfNodePtr TensorScatterUpdateFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto tensor_scatter_update = node->cast<CNodePtr>();
+  if (tensor_scatter_update == nullptr || tensor_scatter_update->size() != 4) {
+    return nullptr;
+  }
+  auto tensor_move = CreateTensorMove(func_graph, tensor_scatter_update);
+  return CreateScatterNdUpdate(func_graph, tensor_scatter_update, tensor_move);
+}
+}  // namespace opt
+}  // namespace mindspore
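For readers unfamiliar with the ANF representation used here: a CNode stores the primitive as input 0 and the call arguments after it, which is why Process requires size() == 4 and reads input(1) (x) for the TensorMove copy and input(2)/input(3) (indices, updates) for the ScatterNdUpdate. The copy is needed because TensorScatterUpdate is a functional op that returns a new tensor, while ScatterNdUpdate writes into its first operand in place. A minimal standalone model of that layout (SimpleCNode is an invented type, not a MindSpore class):

// Minimal stand-in for the CNode layout assumed above; not MindSpore code.
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct SimpleCNode {
  // inputs[0] names the primitive; inputs[1..] are the call arguments.
  std::vector<std::string> inputs;
  std::size_t size() const { return inputs.size(); }
};

int main() {
  SimpleCNode tensor_scatter_update{{"TensorScatterUpdate", "x", "indices", "updates"}};
  assert(tensor_scatter_update.size() == 4);  // primitive + 3 operands, the guard in Process()
  // The fission copies inputs[1] (x) via TensorMove and feeds inputs[2] and
  // inputs[3] (indices, updates) to the in-place ScatterNdUpdate.
  return 0;
}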
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h
new file mode 100644
index 0000000000000000000000000000000000000000..0ada93ac7086c30fd1757a936ffcd7b3990d642e
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
+
+#include "pre_activate/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class TensorScatterUpdateFission : public PatternProcessPass {
+ public:
+  explicit TensorScatterUpdateFission(bool multigraph = true)
+      : PatternProcessPass("tensor_scatter_update_fission", multigraph) {}
+  ~TensorScatterUpdateFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 477ac350a86ab406fd48d6bc61de215638660f82..e28adb6e2164a62865772fbed200e319e6085c6c 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -164,6 +164,18 @@ constexpr auto kStridedReadOpName = "StridedRead";
 constexpr auto kStridedWriteOpName = "StridedWrite";
 constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
 constexpr auto kFusedAdamName = "FusedAdam";
+constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
+constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
+constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
+constexpr auto kSparseApplyFtrlV2OpName = "SparseApplyFtrlV2";
+constexpr auto kApplyKerasMomentumOpName = "ApplyKerasMomentum";
+constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad";
+constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
+constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
+constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
+constexpr auto kTensorMoveOpName = "TensorMove";
+constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
+constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
@@ -224,6 +236,9 @@ constexpr auto kAttrOutputNum = "output_num";
 constexpr auto kAttrSizeSplits = "size_splits";
 constexpr auto kAttrOutputDefault = "output_default";
 constexpr auto kAttrPrimitiveTarget = "primitive_target";
+constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
+constexpr auto kAttrOffset = "offset";
+constexpr auto kAttrUseLocking = "use_locking";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..faebe0e4a0144babdcba0f83b154009a032b7763
--- /dev/null
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/backend_common_test.h"
+#include "common/py_func_graph_fetcher.h"
+#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h"
+#include "debug/anf_ir_dump.h"
+
+namespace mindspore {
+namespace opt {
+class TestHWOptTensorScatterUpdateFission : public BackendCommon {
+ public:
+  TestHWOptTensorScatterUpdateFission()
+      : get_py_fun_("gtest_input.pre_activate.tensor_scatter_update_fission_test", true) {}
+  ~TestHWOptTensorScatterUpdateFission() override = default;
+
+  UT::PyFuncGraphFetcher get_py_fun_;
+};
+
+TEST_F(TestHWOptTensorScatterUpdateFission, test_fission) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "before");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp1{2, 3};
+  std::vector<int> shp2{2, 2};
+  std::vector<int> shp3{2};
+  auto inputx = std::make_shared<abstract::AbstractTensor>(kFloat32, shp1);
+  auto indices = std::make_shared<abstract::AbstractTensor>(kInt32, shp2);
+  auto update = std::make_shared<abstract::AbstractTensor>(kFloat32, shp3);
+  AbstractBasePtrList args_spec_list{inputx, indices, update};
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::TensorScatterUpdateFission>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "after");
+  EXPECT_NE(g_after, nullptr);
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/tensor_scatter_update_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/tensor_scatter_update_fission_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a84f34607208cbc970fa91a47872bc533790331
--- /dev/null
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/tensor_scatter_update_fission_test.py
@@ -0,0 +1,50 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore.ops import Primitive
+from mindspore.ops import operations as P
+
+tensor_scatter_update = P.TensorScatterUpdate()
+tensor_move = Primitive('TensorMove')
+scatter_nd_update = Primitive('ScatterNdUpdate')
+make_tuple = Primitive('make_tuple')
+tuple_getitem = Primitive('tuple_getitem')
+
+
+class FnDict:
+    def __init__(self):
+        self.fnDict = {}
+
+    def __call__(self, fn):
+        self.fnDict[fn.__name__] = fn
+
+    def __getitem__(self, name):
+        return self.fnDict[name]
+
+
+def test_tensor_scatter_update_fission(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x, indices, updates):
+        res = tensor_scatter_update(x, indices, updates)
+        return res
+
+    @fns
+    def after(x, indices, updates):
+        res = tensor_move(x)
+        res = scatter_nd_update(res, indices, updates)
+        return make_tuple(res)
+
+    return fns[tag]
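To see what the rewrite preserves, the following self-contained sketch mimics the shapes used in the unit test (x: 2x3 float32, indices: 2x2 int32, updates: 2 float32). It is a plain reference implementation of the scatter semantics, not MindSpore code; copying x by value plays the role of TensorMove, and the in-place loop plays the role of ScatterNdUpdate:

// Plain reference implementation of the scatter semantics; not MindSpore code.
#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

// Taking x by value makes a copy (the TensorMove step); the loop then updates
// the copy in place (the ScatterNdUpdate step), so the caller's tensor is
// left untouched, exactly as the functional TensorScatterUpdate requires.
std::vector<float> TensorScatterUpdateRef(std::vector<float> x,
                                          const std::vector<std::array<int, 2>> &indices,
                                          const std::vector<float> &updates, int cols) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    x[indices[i][0] * cols + indices[i][1]] = updates[i];
  }
  return x;
}

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6};                       // shape (2, 3)
  std::vector<std::array<int, 2>> indices = {{{0, 1}}, {{1, 2}}};  // shape (2, 2)
  std::vector<float> updates = {9.0f, 8.0f};                       // shape (2,)
  for (float v : TensorScatterUpdateRef(x, indices, updates, 3)) {
    std::cout << v << ' ';  // prints: 1 9 3 4 5 8
  }
  std::cout << '\n';
  return 0;
}

Because before and after compute the same values, the pass only changes which kernels run: an in-place ScatterNdUpdate on a copied buffer instead of the functional TensorScatterUpdate. The make_tuple(res) in the expected after graph mirrors how the compared kernel graph packages its output.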