diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 981e2255f366249a98c2dd0328d507b3da7e8de7..a4555372821a824f5b0781b6a87a0262153403be 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -94,6 +94,7 @@ #include "pre_activate/ascend/ir_fission/split_fission.h" #include "pre_activate/ascend/format_type/modify_ops_attrs.h" #include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" +#include "pre_activate/ascend/ir_fusion/add_input_to_output.h" #include "utils/context/ms_context.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" @@ -259,6 +260,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); } ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index ad48ca5291a5b2a9639aa44bb030cb6bcd74a4e6..dc88ca2e521cb4ce6b0799d07dc2590d341dc1c7 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -70,6 +70,21 @@ class KernelQuery { } }; using KernelQueryPtr = std::shared_ptr; + +class OpFinder { + public: + OpFinder() = default; + virtual ~OpFinder() = default; + virtual int GetOpRegisteredOutputNum(const std::string &op_name) { + auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE); + if (op_info == nullptr) { + return -1; + } + return op_info->outputs_ptr().size(); + } +}; +using OpFinderPtr = std::shared_ptr; + void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc new file mode 100644 index 0000000000000000000000000000000000000000..867f30b9d27a3571b1f35c01798469590cdfdf7e --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/ir_fusion/add_input_to_output.h" +#include +#include +#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h" +#include "session/anf_runtime_algorithm.h" +#include "kernel/oplib/oplib.h" + +namespace mindspore { +namespace opt { +namespace { +void GetInputOrOutputNames(const CNodePtr &cnode, const std::string &attr_name, std::vector *names_vec) { + MS_EXCEPTION_IF_NULL(names_vec); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr names_value = primitive->GetAttr(attr_name); + if (names_value == nullptr) { + return; + } + *names_vec = GetValue>(names_value); +} + +void AddOutputs(const CNodePtr &cnode, const std::vector &input_indices) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector input_names_vec; + GetInputOrOutputNames(cnode, kAttrInputNames, &input_names_vec); + std::vector output_names_vec; + GetInputOrOutputNames(cnode, kAttrOutputNames, &output_names_vec); + AbstractBasePtrList abstract_list; + auto origin_abstract = cnode->abstract(); + MS_EXCEPTION_IF_NULL(origin_abstract); + if (origin_abstract->isa()) { + auto origin_abstract_tuple = dyn_cast(origin_abstract); + MS_EXCEPTION_IF_NULL(origin_abstract_tuple); + AbstractBasePtrList origin_abstract_list = origin_abstract_tuple->elements(); + (void)std::copy(origin_abstract_list.begin(), origin_abstract_list.end(), std::back_inserter(abstract_list)); + } else { + abstract_list.emplace_back(origin_abstract); + } + + for (size_t i = 0; i < input_indices.size(); ++i) { + size_t index = input_indices[i]; + if (index + 1 >= cnode->inputs().size()) { + MS_LOG(INFO) << "The input index " << index << " for converting to output is out of range, " + << "node: " << cnode->DebugString(); + continue; + } + auto node_to_output = cnode->input(index + 1); + MS_EXCEPTION_IF_NULL(node_to_output); + abstract_list.emplace_back(node_to_output->abstract()); + if (!input_names_vec.empty() && !output_names_vec.empty() && index < input_names_vec.size()) { + output_names_vec.emplace_back(input_names_vec[index]); + } + } + if (!output_names_vec.empty()) { + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names_vec), cnode); + } + auto abstract_tuple = std::make_shared(abstract_list); + cnode->set_abstract(abstract_tuple); +} +} // namespace + +const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::string op_name = AnfAlgo::GetCNodeName(cnode); + InputToOutputRegister reg; + if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, ®)) { + return nullptr; + } + int output_num = op_finder_->GetOpRegisteredOutputNum(op_name); + // No need add output when it is not a tbe op. + if (output_num == -1) { + return nullptr; + } + // No need add output if the output num matches the registered output num for tbe. + if (AnfAlgo::GetOutputTensorNum(cnode) >= IntToSize(output_num)) { + return nullptr; + } + bool is_origin_tuple_output = AnfAlgo::IsTupleOutput(cnode); + AddOutputs(cnode, reg.input_indices()); + // No need to create tuple_getitem if the origin output is a tuple because there has already been some tuple_getitems + // pointed to the outputs. + if (is_origin_tuple_output) { + return nullptr; + } + std::vector new_outputs; + auto new_abstract_tuple = dyn_cast(cnode->abstract()); + MS_EXCEPTION_IF_NULL(new_abstract_tuple); + CreateMultipleOutputsOfAnfNode(func_graph, cnode, new_abstract_tuple->size(), &new_outputs); + if (new_outputs.size() != new_abstract_tuple->size()) { + MS_LOG(EXCEPTION) << "Failed to create outputs of " << cnode->DebugString(); + } + return new_outputs[0]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h new file mode 100644 index 0000000000000000000000000000000000000000..d57b32f370f4d6959f1f1c9cfc4f8a7f667d12b2 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ + +#include +#include +#include "pre_activate/common/optimizer.h" +#include "pre_activate/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class AddInputToOutput : public PatternProcessPass { + public: + explicit AddInputToOutput(bool multigraph = true) + : PatternProcessPass("add_input_to_output", multigraph), op_finder_(std::make_shared()) {} + ~AddInputToOutput() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + OpFinderPtr op_finder_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..b82efdf86a9051930ae81f819f44eaef0398bd4d --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h" +#include +#include "utils/utils.h" +#include "session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +bool ApplyRMSPropPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool FusedMulApplyMomentumPreCheck(const CNodePtr &node) { + TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); + return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); +} + +bool SparseApplyRMSPropPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool ApplyAdagradV2PreCheck(const CNodePtr &node) { + TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); + return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); +} + +bool ApplyKerasMomentumPreCheck(const CNodePtr &node) { + TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); + return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); +} + +bool SparseApplyFtrlPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool SparseApplyFtrlV2PreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool SparseApplyAdagradV2PreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool SparseApplyAdadeltaPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} +} // namespace +InputToOutputRegistry::InputToOutputRegistry() { + Register(kApplyRMSPropOpName, {1, 2}, ApplyRMSPropPreCheck); + Register(kFusedMulApplyMomentumOpName, {1}, FusedMulApplyMomentumPreCheck); + Register(kApplyAdagradOpName, {1}); + Register(kApplyAdagradDAName, {1, 2}); + Register(kApplyAdadeltaOpName, {1, 2}); + Register(kApplyPowerSignOpName, {1}); + Register(kApplyProximalAdagradOpName, {1}); + Register(kApplyAdaMaxOpName, {1, 2}); + Register(kApplyAdagradV2OpName, {1}, ApplyAdagradV2PreCheck); + Register(kApplyKerasMomentumOpName, {1}, ApplyKerasMomentumPreCheck); + Register(kSparseApplyFtrlOpName, {1, 2}, SparseApplyFtrlPreCheck); + Register(kSparseApplyFtrlV2OpName, {1, 2}, SparseApplyFtrlV2PreCheck); + Register(kSparseApplyAdagradV2OpName, {1}, SparseApplyAdagradV2PreCheck); + Register(kSparseApplyProximalAdagradOpName, {1}); + Register(kSparseApplyAdagradOpName, {1}); + Register(kApplyFtrlV2OpName, {1, 2}); + Register(kApplyMomentumOpName, {1}); + Register(kApplyFtrlOpName, {1, 2}); + Register(kApplyAdamOpName, {1, 2}); + Register(kApplyCenteredRMSPropOpName, {1, 2, 3}); + Register(kApplyAddSignOpName, {1}); + Register(kSparseApplyRMSPropOpName, {1, 2}, SparseApplyRMSPropPreCheck); + Register(kSparseApplyAdadeltaOpName, {1, 2}, SparseApplyAdadeltaPreCheck); + Register(kApplyAdamWithAmsgradOpName, {1, 2}); +} + +InputToOutputRegistry &InputToOutputRegistry::Instance() { + static InputToOutputRegistry instance; + return instance; +} + +void InputToOutputRegistry::Register(const InputToOutputRegister ®) { + auto op_name = reg.op_name(); + if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { + (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " input2output register successfully!"; + } +} + +void InputToOutputRegistry::Register(const std::string &op_name, const std::vector &input_indices, + const PreCheckFunc &pre_check_func) { + if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { + InputToOutputRegister reg(op_name, pre_check_func); + reg.set_input_indices(input_indices); + (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " input2output register successfully!"; + } +} + +bool InputToOutputRegistry::GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const { + if (op_input_to_output_map_.find(op_name) != op_input_to_output_map_.end()) { + *reg = op_input_to_output_map_.at(op_name); + MS_LOG(DEBUG) << op_name << " input2output find in registry."; + return true; + } + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..45738c289c963085155d7e7afedfbb5db6c1bf58 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ +#include +#include +#include +#include +#include "ir/anf.h" +#include "common/utils.h" + +namespace mindspore { +namespace opt { +using PreCheckFunc = std::function; +class InputToOutputRegister { + public: + explicit InputToOutputRegister( + const std::string &op_name = "", const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; }) + : op_name_(op_name), pre_check_func_(pre_check_func) {} + virtual ~InputToOutputRegister() = default; + + void set_input_indices(const std::vector &input_indices) { input_indices_ = input_indices; } + + const std::vector &input_indices() const { return input_indices_; } + const std::string &op_name() const { return op_name_; } + + private: + std::string op_name_; + std::vector input_indices_; + PreCheckFunc pre_check_func_; +}; + +class InputToOutputRegistry { + public: + static InputToOutputRegistry &Instance(); + void Register(const InputToOutputRegister ®); + void Register( + const std::string &op_name, const std::vector &input_indices, + const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; }); + bool GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const; + + private: + InputToOutputRegistry(); + ~InputToOutputRegistry() = default; + DISABLE_COPY_AND_ASSIGN(InputToOutputRegistry) + std::unordered_map op_input_to_output_map_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index a5ec56cb2f9f34d39e7de6fffafb5da1bdc3c2a2..b3538a3d745633e32002ff24a5058cd55ee1347c 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -164,6 +164,15 @@ constexpr auto kStridedReadOpName = "StridedRead"; constexpr auto kStridedWriteOpName = "StridedWrite"; constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay"; constexpr auto kFusedAdamName = "FusedAdam"; +constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2"; +constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2"; +constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl"; +constexpr auto kSparseApplyFtrlV2OpName = "SparseApplyFtrlV2"; +constexpr auto kApplyKerasMomentumOpName = "ApplyKerasMomentum"; +constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad"; +constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp"; +constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta"; +constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad"; // attr key name constexpr auto kAttrInputNames = "input_names"; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b44fa6dc423441a148e988dacaaed2dc4749c7b --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "debug/anf_ir_dump.h" + +#define private public +#define protected public +#include "pre_activate/ascend/ir_fusion/add_input_to_output.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWAddInputToOutput : public BackendCommon { + public: + TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {} + ~TestHWAddInputToOutput() override = default; + + public: + UT::PyFuncGraphFetcher getPyFun_; +}; + +class MockOpFinder : public OpFinder { + public: + MockOpFinder() = default; + ~MockOpFinder() override = default; + int GetOpRegisteredOutputNum(const std::string &op_name) override { return 2; } +}; + +TEST_F(TestHWAddInputToOutput, test_add_input_to_output) { + FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 5; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + auto ret = kg->get_return(); + EXPECT_NE(ret, nullptr); + auto make_tuple = ret->input(1); + EXPECT_NE(make_tuple, nullptr); + auto momentum = make_tuple->cast()->input(1); + EXPECT_NE(momentum, nullptr); + EXPECT_NE(momentum->abstract(), nullptr); + EXPECT_FALSE(momentum->abstract()->isa()); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pass->op_finder_ = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kg); + EXPECT_TRUE(momentum->abstract()->isa()); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4fa1fe9630152aab6ab0a95c66e05c6bbe3e01 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import operations as P + +ApplyMomentum = P.ApplyMomentum() + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_add_input_to_output(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4): + return ApplyMomentum(input0, input1, input2, input3, input4) + + return fns[tag]