diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
index 4fd5f34cf250cf39a7db74529be4b4d8b50c7683..f8e1d62d0ab3252bbf213bbfcbddae8d63792e10 100644
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/parallel/dynamic_creator.h
@@ -121,6 +121,7 @@ REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
 REGISTER(AssignSubInfo);
 REGISTER(ReLUInfo);
 REGISTER(GatherV2Info);
+REGISTER(SparseGatherV2Info);
 REGISTER(SqrtInfo);
 REGISTER(SigmoidInfo);
 REGISTER(GetNextInfo);
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index 3d9470e7d8ee91c3db50dff250e69adf3bf55dab..1c40350e6ad7ca0a333df8c0dce61949f7aa122e 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -399,7 +399,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
   auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
   auto gather_v2 =
-    gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
+    gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
   auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
   auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
   auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)});
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index 22aff16b493c2837fc396ca88bf4b9706fa12c8a..b139ee215cea3a04fa3282ce140ed6cba28902f3 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -63,6 +63,7 @@ class GatherV2PInfo : public OperatorInfo {
   int32_t axis_;
   std::string target_;
+  std::string replace_op_name_ = GATHERV2;
   int32_t bias_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
@@ -70,6 +71,17 @@ class GatherV2PInfo : public OperatorInfo {
   bool reduce_scatter_flag_ = false;
   int32_t split_num_ = 1;
 };
+
+class SparseGatherV2Info : public GatherV2PInfo {
+ public:
+  SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                     const PrimitiveAttrs &attrs)
+      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
+  ~SparseGatherV2Info() override = default;
+
+ private:
+  std::string replace_op_name_ = SPARSE_GATHERV2;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 1110bedc3fb43620506756f045cbf0af42234348..bc0d669baa4c93dfa33effa0eac173a07f12442b 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -205,6 +205,7 @@ constexpr char EQUAL[] = "Equal";
 constexpr char NOT_EQUAL[] = "NotEqual";
 constexpr char LOGICALNOT[] = "LogicalNot";
 constexpr char GATHERV2[] = "GatherV2";
+constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
 constexpr char STRIDEDSLICE[] = "StridedSlice";
 constexpr char BROADCAST[] = "Broadcast";
 constexpr char SQRT[] = "Sqrt";
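Context for the gather_v2_p_info changes above: ComputeReplaceGraph builds the graph that replaces a gather whose parameter is split along the gathered axis, and the edit makes it emit replace_op_name_ (GatherV2 in the base class, SparseGatherV2 in the new subclass) instead of a hard-coded GatherV2. One caveat worth noting: the private replace_op_name_ redeclared in SparseGatherV2Info shadows the base-class member rather than replacing it, so code compiled into GatherV2PInfo, including ComputeReplaceGraph, still reads its own GATHERV2-initialized copy; assigning the base member from the subclass constructor would sidestep the shadowing. The numpy sketch below illustrates, for the axis-0 case, the arithmetic that the Sub/ReLU/Minimum/Equal/Gather/Cast/Mul chain appears to implement on a device holding rows [bias, bias + slice_size) of the parameter. It assumes, based only on the op names, that sub shifts the indices by the shard's bias and relu clamps them at zero; all names are illustrative, not MindSpore APIs.

import numpy as np


def masked_local_gather(param_shard, indices, bias, slice_size):
    """Per-device gather for a parameter row-split along axis 0 (illustrative sketch)."""
    sub = indices - bias                          # Sub: shift indices into this shard's range
    relu = np.maximum(sub, 0)                     # ReLU: clamp below at 0
    minimum = np.minimum(relu, slice_size - 1)    # Minimum: clamp above at slice_size - 1
    equal = (sub == minimum)                      # Equal: true only for indices this shard owns
    gathered = np.take(param_shard, minimum, axis=0)  # GatherV2 / SparseGatherV2 on the local slice
    # Cast + ExpandDims + Mul: zero out rows that belong to other shards
    return gathered * equal[..., np.newaxis].astype(param_shard.dtype)

Indices owned by another device are clamped to a valid local position and then zeroed by the mask, so combining the per-device results recovers the full gather.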
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 7d1ff623b9e77398f08fc9ab242bc10b07e3d389..429241c8b73521ee7225290573aa4229b75b2195 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -261,7 +261,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
-    STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE,
+    STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2,
     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
   // clang-format on
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 166ce6b03827356863f59c634e6588a9f4969574..4528ff8639c7e5ea6cd412a7fecd847c3b1fa11d 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -535,7 +535,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
   auto prim = GetValueNode<PrimitivePtr>(node->input(0));
-  if (prim->name() == GATHERV2) {
+  if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
   }
   if (!params.empty()) {
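The step_parallel.cc change above ensures that when the replacement operator is spliced in, a SparseGatherV2 node keeps all three of its original inputs (parameter, indices, axis), exactly as GatherV2 already did. The reduce_scatter_flag_ in the header diff suggests the per-device partial results are recombined collectively; the self-contained numpy check below verifies that summing the masked per-shard gathers (the AllReduce-style recombination) reproduces an unsharded gather. The two-device setup and all variable names are assumptions for illustration, not MindSpore code.

import numpy as np

np.random.seed(0)
param = np.random.randn(8, 4).astype(np.float32)  # full parameter, row-split over 2 devices
indices = np.array([0, 3, 5, 7])
slice_size = 4                                    # rows held by each device

full = np.take(param, indices, axis=0)            # what an unsharded GatherV2 would return

partial_sum = np.zeros_like(full)
for rank in range(2):                             # one iteration per simulated device
    bias = rank * slice_size
    shard = param[bias:bias + slice_size]
    sub = indices - bias
    minimum = np.minimum(np.maximum(sub, 0), slice_size - 1)
    mask = (sub == minimum).astype(np.float32)
    partial_sum += np.take(shard, minimum, axis=0) * mask[:, np.newaxis]

assert np.allclose(partial_sum, full)             # summing the partials recovers the full gather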
diff --git a/tests/ut/python/parallel/test_sparse_gather_v2.py b/tests/ut/python/parallel/test_sparse_gather_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd0517a08ecfcb43323bfca6874b04560f058c90
--- /dev/null
+++ b/tests/ut/python/parallel/test_sparse_gather_v2.py
@@ -0,0 +1,220 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y):
+        predict = self.network(x, y)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y):
+        return C.grad_all(self.network)(x, y)
+
+
+class Net(nn.Cell):
+    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
+        super().__init__()
+        if shape is None:
+            shape = [64, 64]
+        self.gatherv2 = P.SparseGatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
+        self.mul = P.Mul().set_strategy(strategy2)
+        self.index = Tensor(np.ones(shape), dtype=ms.int32)
+        self.axis = axis
+
+    def construct(self, x, y):
+        out = self.gatherv2(x, self.index, self.axis)
+        out = self.mul(out, y)
+        return out
+
+
+def test_gatherv2_semi_auto0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto1():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((2, 4), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto3():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto4():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto5():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((2, 4), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto6():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, None, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto7():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, None, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto8():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8,), (1, 1))
+    strategy2 = ((4, 2), (4, 2))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_auto0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
+    net = GradWrap(NetWithLoss(Net(0)))
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_auto1():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
+    net = GradWrap(NetWithLoss(Net(1)))
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu1():
+    context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((16, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
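A note on the Mul strategies used in these tests: a gather's output shape is the index shape spliced into the parameter shape at the gathered axis, so the [64, 64] parameter gathered with the [64, 64] index tensor above yields a rank-3 output, which is why strategy2 is a pair of 3-tuples in most cases and a pair of 2-tuples in test_gatherv2_semi_auto8, where the parameter is 1-D. A quick numpy sanity check of those shapes (numpy's np.take stands in for the gather here):

import numpy as np

param = np.ones((64, 64), dtype=np.float32)
index = np.ones((64, 64), dtype=np.int32)

# axis-0 gather: output shape = index.shape + param.shape[1:]
assert np.take(param, index, axis=0).shape == (64, 64, 64)   # matches the 3-D Mul strategies

param_1d = np.ones(64, dtype=np.float32)                     # as in test_gatherv2_semi_auto8
assert np.take(param_1d, index, axis=0).shape == (64, 64)    # matches the 2-D strategy ((4, 2), (4, 2))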