diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index f5b2ca325381f632ba3555ddcf0d01372fc6e133..f1e827d6dd512def90594e8f02eb477c3a960d2c 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -90,6 +90,7 @@ static std::map tbe_func_adapter_map = { {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, {"lamb_next_mv", "lamb_next_m_v"}, {"split", "split_d"}, + {"split_v", "split_v_d"}, {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, {"pad", "pad_d"}, diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 6705cd4f8f77a76a815cd1f90099e23298f60311..f4f9d8da14bc7e4af794f020f99dd495ac967b8b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -87,6 +87,7 @@ #include "pre_activate/ascend/ir_fission/addn_fission.h" #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include "pre_activate/ascend/ir_fission/split_fission.h" #include "utils/context/ms_context.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" @@ -141,6 +142,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); } } // namespace diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc new file mode 100644 index 0000000000000000000000000000000000000000..c39a5e01e692638e38e6848f78379a4ecee44614 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc @@ -0,0 +1,191 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/ir_fission/split_fission.h" +#include +#include +#include "session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(input_node); + std::vector splitv_inputs{NewValueNode(std::make_shared(kSplitVOpName)), input_node}; + CNodePtr splitv = func_graph->NewCNode(splitv_inputs); + MS_EXCEPTION_IF_NULL(splitv); + splitv->set_scope(input_node->scope()); + return splitv; +} + +CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { + MS_EXCEPTION_IF_NULL(origin_cnode); + if (origin_cnode->inputs().size() < kSplitInputNum) { + MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " + << kSplitInputNum - 1; + } + return CreateSplitVNode(func_graph, origin_cnode->input(1)); +} + +void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector &size_splits, int split_dim, int num_split) { + AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv); + AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv); + AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv); +} + +size_t GetSmallSplitSize(const AnfNodePtr &split_node, int split_dim, int num_split) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0); + if (split_dim < 0) { + split_dim += input_shape.size(); + } + if (IntToSize(split_dim) >= input_shape.size()) { + MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0"; + } + return input_shape[split_dim] / num_split; +} + +void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int outputs_num, + std::vector *inputs) { + MS_EXCEPTION_IF_NULL(inputs); + std::vector new_splitv_output; + CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, outputs_num, &new_splitv_output); + inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end()); +} + +AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) { + MS_EXCEPTION_IF_NULL(func_graph); + auto idx = NewValueNode(SizeToInt(index)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(index)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx}); + return tuple_getitem; +} + +void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int split_size, int num_split, + std::vector *new_type_ids, + std::vector> *new_output_shapes) { + MS_EXCEPTION_IF_NULL(new_type_ids); + MS_EXCEPTION_IF_NULL(new_output_shapes); + auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); + output_shape[split_dim] = split_size; + TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); + for (int i = 0; i < num_split; ++i) { + new_type_ids->emplace_back(type_id); + new_output_shapes->emplace_back(output_shape); + } +} + +void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, + const std::vector &size_splits_base, int split_dim, int num_split) { + SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); + std::vector base_type_ids; + std::vector> base_output_shapes_base; + auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); + TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); + for (int i = 0; i < num_split; ++i) { + output_shape[split_dim] = size_splits_base[i]; + base_output_shapes_base.emplace_back(output_shape); + base_type_ids.emplace_back(type_id); + } + AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); +} + +AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int num_split, int divisor) { + MS_EXCEPTION_IF_NULL(func_graph); + auto split_dim = AnfAlgo::GetNodeAttr(cnode, kAttrAxis); + CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode); + + // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. + auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split)); + std::vector size_splits_new; + for (int i = 0; i < divisor; ++i) { + size_splits_new.emplace_back(small_split_size); + } + // Create new output shape and new output type id for each new Splitv node which has full inputs. + std::vector new_type_ids; + std::vector> new_output_shapes; + CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); + + // Create make_tuple input to create a make_tuple for replacing the old Split node. + std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; + // Start to divide the outputs of Split. + std::vector size_splits_base; + const auto base_split_size = divisor * small_split_size; + int nodes_num = 0; + int cur_output_index = 0; + while (num_split - cur_output_index > divisor) { + CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); + SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor); + AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get()); + AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs); + cur_output_index += divisor; + size_splits_base.emplace_back(base_split_size); + nodes_num++; + } + if (cur_output_index < num_split) { + auto last_node_num_split = num_split - cur_output_index; + if (last_node_num_split > 1) { + CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); + std::vector size_splits_new_last; + for (int i = 0; i < last_node_num_split; ++i) { + size_splits_new_last.emplace_back(small_split_size); + } + SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); + // Create new output shape and new output type id for the last Splitv node + std::vector last_new_type_ids; + std::vector> last_new_output_shapes; + CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids, + &last_new_output_shapes); + AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get()); + AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs); + size_splits_base.emplace_back(last_node_num_split * small_split_size); + } else { + make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num)); + size_splits_base.emplace_back(small_split_size); + } + nodes_num++; + } + // Set Attr and abstract for the base splitv + SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num); + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} +} // namespace + +const BaseRef SplitFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto split_prim = std::make_shared(kSplitOpName); + return VectorRef({split_prim, Xs}); +} + +const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // Check output num + if (!AnfAlgo::HasNodeAttr(kAttrOutputNum, cnode)) { + return nullptr; + } + auto num_split = AnfAlgo::GetNodeAttr(cnode, kAttrOutputNum); + if (num_split <= outputs_divisor_) { + return nullptr; + } + return DoFission(func_graph, cnode, num_split, outputs_divisor_); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h new file mode 100644 index 0000000000000000000000000000000000000000..c2763bb7141a48b31333fdf08bcb0f287f78ecdc --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ + +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr int kSplitOutputsDivisor = 63; +class SplitFission : public PatternProcessPass { + public: + explicit SplitFission(bool multigraph = true) + : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {} + ~SplitFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + int outputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index 59fba21d55bcec5cfe74807e343669f9b52b4cc0..49a1d47d0c3c6e45f74d9d2a63a898e52a52e189 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -97,6 +97,7 @@ constexpr size_t kBiasAddInputNum = 3; constexpr size_t kTopkInputNum = 3; constexpr size_t kLarsV2InputNum = 5; constexpr size_t kFusedMulApplyMomentumOutputNum = 2; +constexpr size_t kSplitInputNum = 2; enum FusedBatchNormInput { kX = 1, diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index b056b2ccdc0ab1c743955947779e9cac5a1796de..97ffd739bbffd66c3917fa162154e71708a89c54 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -72,6 +72,7 @@ constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin"; constexpr auto kFlattenGradOpName = "FlattenGrad"; constexpr auto kExpandDimsOpName = "ExpandDims"; constexpr auto kSplitOpName = "Split"; +constexpr auto kSplitVOpName = "SplitV"; constexpr auto kSparseApplyAdagradOpName = "SparseApplyAdagrad"; constexpr auto kMomentumOpName = "Momentum"; constexpr auto kApplyMomentumOpName = "ApplyMomentum"; @@ -211,6 +212,10 @@ constexpr auto kAttrWaitEvent = "wait_event"; constexpr auto kAttrRecordEventStream = "record_event_stream"; constexpr auto kAttrWaitEventStream = "wait_event_stream"; constexpr auto kAttrIndex = "index"; +constexpr auto kAttrSplitDim = "split_dim"; +constexpr auto kAttrNumSplit = "num_split"; +constexpr auto kAttrOutputNum = "output_num"; +constexpr auto kAttrSizeSplits = "size_splits"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ab70e83480a790dc7f84614329dd86e549cd9c04 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#define private public +#define protected public +#include "pre_activate/ascend/ir_fission/split_fission.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWSplitFission : public BackendCommon { + public: + TestHWSplitFission() : get_py_fun_("gtest_input.pre_activate.split_fission_test", true) {} + ~TestHWSplitFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWSplitFission, test_split_fission_divided_by_3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_split_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{512, 3, 1}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + args_spec_list.push_back(x_abstract); + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto split_fission = std::make_shared(); + split_fission->outputs_divisor_ = 3; + pm->AddPass(split_fission); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_split_fission", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/split_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/split_fission_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b25fa1f5d00905dad1ba6fd315ee8642e15ab607 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/split_fission_test.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import Primitive +from mindspore.ops import operations as P + +split = P.Split(0, 8) +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') +splitv = Primitive('SplitV') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_split_fission(tag): + """ test_adam_apply_one_with_decay_rule """ + fns = FnDict() + + @fns + def before(x): + return split(x) + + @fns + def after(x): + splitv0 = splitv(x) + splitv1 = splitv(tuple_getitem(splitv0, 0)) + splitv2 = splitv(tuple_getitem(splitv0, 1)) + splitv3 = splitv(tuple_getitem(splitv0, 2)) + make_tuple0 = make_tuple(tuple_getitem(splitv1, 0), tuple_getitem(splitv1, 1), tuple_getitem(splitv1, 2), + tuple_getitem(splitv2, 0), tuple_getitem(splitv2, 1), tuple_getitem(splitv2, 2), + tuple_getitem(splitv3, 0), tuple_getitem(splitv3, 1)) + return make_tuple( + make_tuple(tuple_getitem(make_tuple0, 0), tuple_getitem(make_tuple0, 1), tuple_getitem(make_tuple0, 2), + tuple_getitem(make_tuple0, 3), tuple_getitem(make_tuple0, 4), tuple_getitem(make_tuple0, 5), + tuple_getitem(make_tuple0, 6), tuple_getitem(make_tuple0, 7))) + + return fns[tag]