diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
index bfbb6b092cae55a7dd684a62ae20be931aeb5436..899ba73db755bb61afb0bdff20865fd2a290344e 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
@@ -32,12 +32,13 @@ namespace parallel {
 class GatherV2PInfo : public OperatorInfo {
  public:
   GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                const PrimitiveAttrs &attrs)
+                const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
       : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
         axis_(0),
         bias_(0),
         index_offset_(0),
-        slice_size_(0) {}
+        slice_size_(0),
+        replace_op_name_(replace_op_name) {}
   ~GatherV2PInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
@@ -69,10 +70,10 @@ class GatherV2PInfo : public OperatorInfo {
 
   int32_t axis_;
   std::string target_ = DEVICE;
-  std::string replace_op_name_ = GATHERV2;
   int64_t bias_;
   int64_t index_offset_;
   int64_t slice_size_;
+  std::string replace_op_name_ = GATHERV2;
   Shape out_dev_matrix_shape_;
   Group group_;
   bool manual_split_ = false;
@@ -83,12 +84,9 @@ class GatherV2PInfo : public OperatorInfo {
 class SparseGatherV2Info : public GatherV2PInfo {
  public:
   SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                     const PrimitiveAttrs &attrs)
-      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
+                     const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
+      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
   ~SparseGatherV2Info() override = default;
-
- private:
-  std::string replace_op_name_ = SPARSE_GATHERV2;
 };
 
 class EmbeddingLookupInfo : public GatherV2PInfo {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 5bdc94e12dce7175e7b5271f48b9195af3684500..3adf234eeb228d6f596ed464d0f7533696ade146 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -197,6 +197,7 @@ constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue";
 constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
 constexpr char CONV2D[] = "Conv2D";
 constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
+constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
 constexpr char BATCH_NORM[] = "BatchNorm";
 constexpr char LAYER_NORM[] = "LayerNorm";
 constexpr char POOLING[] = "Pooling";
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 8944fa011eb2062dd9012b1a7e000ab52b933d0e..295bc27333c409ca0afd97397fa79d602c95817f 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -263,7 +263,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
-    EMBEDDING_LOOKUP};
+    EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX};
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index b69b83f257607b8a40a339f6c762666292187bab..b1cf7793b6e3a4ac5918ec054a09fb38d1c87a0e 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -570,8 +570,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
     MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
-  auto prim = GetValueNode<PrimitivePtr>(node->input(0));
-  if (prim->name() == EMBEDDING_LOOKUP) {
+  if (replace_op.first == EMBEDDING_LOOKUP) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {
diff --git a/mindspore/ccsrc/utils/comm_manager.cc b/mindspore/ccsrc/utils/comm_manager.cc
index de165c4aac681626ebfcdc304ac73dc0a0945e55..6411479f77c5f6180f99b8493d4c6a0e93f434d1 100644
--- a/mindspore/ccsrc/utils/comm_manager.cc
+++ b/mindspore/ccsrc/utils/comm_manager.cc
@@ -40,7 +40,7 @@ CommManager &CommManager::GetInstance() noexcept {
 #define HCCL_RUN_CHECK(op_name, group, op)                      \
   do {                                                          \
     auto hccl_result = (op);                                    \
-    if (hccl_result != tagHcclResult::HCCL_SUCCESS) {           \
+    if (hccl_result != 0) {                                     \
       MS_LOG(ERROR) << op_name << " failed: #" << group << "#"; \
       return false;                                             \
     }                                                           \
diff --git a/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py b/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..4655ec89acab7683cf0f6462f4f87c5c32109f4f
--- /dev/null
+++ b/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py
@@ -0,0 +1,76 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.common.parameter import Parameter
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+grad_all = C.GradOperation(get_all=True)
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y, b):
+        predict = self.network(x, y, b)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y, b):
+        return grad_all(self.network)(x, y, b)
+
+
+# model_parallel test
+def test_two_matmul_batchnorm_ex():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul1 = P.MatMul().set_strategy(strategy1)
+            self.norm = P.FusedBatchNormEx()
+            self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
+            self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
+            self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
+            self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
+            self.matmul2 = P.MatMul().set_strategy(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.matmul1(x, y)
+            out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0]
+            out = self.matmul2(out, b)
+            return out
+
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)
+    strategy1 = ((4, 2), (2, 1))
+    strategy2 = ((1, 8), (8, 1))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    net.set_auto_parallel()
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y, b)
diff --git a/tests/ut/python/parallel/test_sparse_gather_v2.py b/tests/ut/python/parallel/test_sparse_gather_v2.py
index 30957dcf9dc3c3384b70dd1e81de27caac7c84ad..da1c5891792820fadbe9b01f598d3ac272ab1753 100644
--- a/tests/ut/python/parallel/test_sparse_gather_v2.py
+++ b/tests/ut/python/parallel/test_sparse_gather_v2.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-import pytest
 
 import mindspore as ms
 import mindspore.nn as nn
@@ -158,18 +157,6 @@ def test_gatherv2_semi_auto7():
     _executor.compile(net, x, y)
 
 
-def test_gatherv2_semi_auto8():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8,), (1, 1))
-    strategy2 = ((4, 2), (4, 2))
-    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
 def test_gatherv2_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net(0)))
@@ -188,7 +175,6 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
 def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
@@ -201,7 +187,6 @@ def test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
 def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
@@ -214,7 +199,6 @@ def test_gatherv2_cpu2():
     _executor.compile(net, x, y)
 
 
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
 def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))