From b8d7f6d77f99ab414f6e6912dc35eab48baf658a Mon Sep 17 00:00:00 2001 From: huanghui Date: Tue, 23 Jun 2020 21:16:05 +0800 Subject: [PATCH] add UnsortedSegmentSum fission pass --- .../ascend/ascend_backend_optimization.cc | 2 + .../unsorted_segment_sum_fission.cc | 118 ++++++++++++++++++ .../ir_fission/unsorted_segment_sum_fission.h | 37 ++++++ .../ccsrc/backend/optimizer/common/helper.h | 1 + mindspore/ccsrc/utils/utils.h | 5 + .../test_unsorted_segment_sum_fission.py | 47 +++++++ .../unsorted_segment_sum_fission_test.cc | 68 ++++++++++ .../unsorted_segment_sum_fission.py | 63 ++++++++++ 8 files changed, 341 insertions(+) create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h create mode 100644 tests/st/fusion/test_unsorted_segment_sum_fission.py create mode 100644 tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc create mode 100644 tests/ut/cpp/python_input/gtest_input/pre_activate/unsorted_segment_sum_fission.py diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index dcca95fbc..41f271943 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -26,6 +26,7 @@ #include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h" #include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" #include "backend/optimizer/pass/communication_op_fusion.h" #include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" #include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" @@ -172,6 +173,7 @@ 
void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); } } // namespace diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc new file mode 100644 index 000000000..6fd81b537 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(origin_node); + std::vector padding_inputs = {NewValueNode(std::make_shared(kPaddingOpName)), + origin_node->input(1)}; + auto padding = graph->NewCNode(padding_inputs); + MS_EXCEPTION_IF_NULL(padding); + padding->set_scope(origin_node->scope()); + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); + shape[shape.size() - 1] = pad_dim_size; + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)}, {shape}, + padding.get()); + AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToInt(pad_dim_size)), padding); + return padding; +} + +CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding, + const size_t &pad_dim_size) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(origin_node); + MS_EXCEPTION_IF_NULL(padding); + std::vector unsorted_segment_sum8_inputs = { + NewValueNode(std::make_shared(prim::kPrimUnsortedSegmentSum->name())), padding, origin_node->input(2)}; + auto unsorted_segment_sum = graph->NewCNode(unsorted_segment_sum8_inputs); + MS_EXCEPTION_IF_NULL(unsorted_segment_sum); + unsorted_segment_sum->set_scope(origin_node->scope()); + auto shape = AnfAlgo::GetOutputInferShape(origin_node, 0); + shape[shape.size() - 1] = pad_dim_size; + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape}, + unsorted_segment_sum.get()); + AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(SizeToInt(shape[0])), unsorted_segment_sum); + return unsorted_segment_sum; +} + +CNodePtr CreateSlice(const 
FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum, + const CNodePtr &unsorted_segment_sum8) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(unsort_segment_sum); + MS_EXCEPTION_IF_NULL(unsorted_segment_sum8); + std::vector slice_inputs = {NewValueNode(std::make_shared(kSliceOpName)), + unsorted_segment_sum8}; + auto slice = graph->NewCNode(slice_inputs); + MS_EXCEPTION_IF_NULL(slice); + slice->set_scope(unsort_segment_sum->scope()); + slice->set_abstract(unsort_segment_sum->abstract()); + auto unsort_segment_sum_shape = AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0); + std::vector offsets(unsort_segment_sum_shape.size(), 0); + AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Int(offsets)), slice); + AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Int(unsort_segment_sum_shape)), slice); + return slice; +} +} // namespace + +const BaseRef UnsortSegmentSumFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VectorRef pattern({prim::kPrimUnsortedSegmentSum, Xs}); + return pattern; +} + +const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto origin_node = node->cast(); + MS_EXCEPTION_IF_NULL(origin_node); + if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { + MS_LOG(INFO) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum + << ". CNode= " << origin_node->DebugString(); + return nullptr; + } + auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); + if (input0_shape[input0_shape.size() - 1] != 1) { + MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. 
The last value of input0's shape is " + << input0_shape[input0_shape.size() - 1]; + return nullptr; + } + size_t pad_dim_size; + auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0); + if (input_dtype == kNumberTypeFloat32) { + pad_dim_size = 8; + } else if (input_dtype == kNumberTypeFloat16) { + pad_dim_size = 16; + } else { + MS_LOG(INFO) << "UnsortedSegmentSum data type not in (float32, float16), no need change"; + return nullptr; + } + + auto padding = CreatePadding(graph, origin_node, pad_dim_size); + auto unsorted_segment_sum8 = CreateUnsortedSegmentSum(graph, origin_node, padding, pad_dim_size); + return CreateSlice(graph, origin_node, unsorted_segment_sum8); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h new file mode 100644 index 000000000..6d47a3f40 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class UnsortSegmentSumFission : public PatternProcessPass { + public: + explicit UnsortSegmentSumFission(bool multigraph = true) + : PatternProcessPass("unsorted_segment_sum_fission", multigraph) {} + ~UnsortSegmentSumFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index 16bdeb79f..f21637a19 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -98,6 +98,7 @@ constexpr size_t kTopkInputNum = 3; constexpr size_t kLarsV2InputNum = 5; constexpr size_t kFusedMulApplyMomentumOutputNum = 2; constexpr size_t kSplitInputNum = 2; +constexpr size_t kUnsortedSegmentSumInputNum = 2; enum FusedBatchNormInput { kX = 1, diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index c1f551258..c8c9fc229 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -182,6 +182,7 @@ constexpr auto kPushOpName = "Push"; constexpr auto kPullOpName = "Pull"; constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; +constexpr auto kPaddingOpName = "Padding"; // attr key name constexpr auto kAttrInputNames = "input_names"; @@ -253,6 
+254,10 @@ constexpr auto kAttrInputNums = "inputNums"; constexpr auto kAttrT = "T"; constexpr auto kAttrNum = "num"; constexpr auto kAttrRankSize = "rank_size"; +constexpr auto kAttrPadDimSize = "pad_dim_size"; +constexpr auto kAttrNumSegments = "num_segments"; +constexpr auto kAttrBegin = "begin"; +constexpr auto kAttrSize = "size"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/tests/st/fusion/test_unsorted_segment_sum_fission.py b/tests/st/fusion/test_unsorted_segment_sum_fission.py new file mode 100644 index 000000000..628403b76 --- /dev/null +++ b/tests/st/fusion/test_unsorted_segment_sum_fission.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +context.set_context(save_graphs=True) + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.unsorted_segment_sum = P.UnsortedSegmentSum() + self.num_segments = 3 + + def construct(self, x, segment_ids): + x = self.unsorted_segment_sum(x, segment_ids, self.num_segments) + return x + + +def test_net(): + input_x = np.random.randn(3, 39, 1).astype(np.float32) + segment_ids = Tensor([0, 1, 2], mindspore.int32) + net = Net() + output = net(Tensor(input_x), segment_ids) + print("result", output.asnumpy()) + + +if __name__ == "__main__": + test_net() diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc new file mode 100644 index 000000000..02a4aa69f --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +class TestHWUnsortedSegmentSumFission : public BackendCommon { + public: + TestHWUnsortedSegmentSumFission() : get_py_fun_("gtest_input.pre_activate.unsorted_segment_sum_fission", true) {} + ~TestHWUnsortedSegmentSumFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWUnsortedSegmentSumFission, test_fission) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before1"); + EXPECT_NE(g, nullptr); + std::vector shp_x{16, 1}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "after1"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWUnsortedSegmentSumFission, test_no_fission) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before2"); + EXPECT_NE(g, nullptr); + std::vector shp_x{16, 2}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "after2"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} +} // namespace 
opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/unsorted_segment_sum_fission.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/unsorted_segment_sum_fission.py new file mode 100644 index 000000000..c7d6b946f --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/unsorted_segment_sum_fission.py @@ -0,0 +1,63 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore.ops import Primitive +from mindspore.ops import operations as P + +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') +unsorted_segment_sum = P.UnsortedSegmentSum() +num_segments = 4 +padding = Primitive('Padding') +op_slice = Primitive('Slice') +op_unsorted_segment_sum = Primitive('UnsortedSegmentSum') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_unsorted_segment_sum_fission(tag): + fns = FnDict() + + @fns + def before1(input0, input1): + x = unsorted_segment_sum(input0, input1, num_segments) + return x + + @fns + def after1(input0, input1): + x = padding(input0) + x = op_unsorted_segment_sum(x, input1) + x = op_slice(x) + return make_tuple(x) + + @fns + def before2(input0, input1): + x = unsorted_segment_sum(input0, input1, 
num_segments) + return x + + @fns + def after2(input0, input1): + x = op_unsorted_segment_sum(input0, input1) + return make_tuple(x) + + return fns[tag] -- GitLab