From d22f5064314dd80eaaca8bb6b041f19beea3b6d7 Mon Sep 17 00:00:00 2001
From: lichenever
Date: Thu, 3 Sep 2020 19:06:42 +0800
Subject: [PATCH] add BatchNormEx op

---
 .../parallel/ops_info/gather_v2_p_info.h      | 14 ++--
 .../frontend/parallel/ops_info/ops_utils.h    |  1 +
 .../frontend/parallel/step_auto_parallel.cc   |  2 +-
 .../ccsrc/frontend/parallel/step_parallel.cc  |  3 +-
 mindspore/ccsrc/utils/comm_manager.cc         |  2 +-
 .../test_batchnorm_ex_batch_parallel.py       | 76 +++++++++++++++++++
 .../python/parallel/test_sparse_gather_v2.py  | 16 ----
 7 files changed, 86 insertions(+), 28 deletions(-)
 create mode 100644 tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
index bfbb6b092..899ba73db 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
@@ -32,12 +32,13 @@ namespace parallel {
 class GatherV2PInfo : public OperatorInfo {
  public:
   GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                const PrimitiveAttrs &attrs)
+                const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
       : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
         axis_(0),
         bias_(0),
         index_offset_(0),
-        slice_size_(0) {}
+        slice_size_(0),
+        replace_op_name_(replace_op_name) {}
   ~GatherV2PInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
@@ -69,10 +70,10 @@ class GatherV2PInfo : public OperatorInfo {
   int32_t axis_;
   std::string target_ = DEVICE;
-  std::string replace_op_name_ = GATHERV2;
   int64_t bias_;
   int64_t index_offset_;
   int64_t slice_size_;
+  std::string replace_op_name_ = GATHERV2;
   Shape out_dev_matrix_shape_;
   Group group_;
   bool manual_split_ = false;
@@ -83,12 +84,9 @@ class GatherV2PInfo : public OperatorInfo {
 class SparseGatherV2Info : public GatherV2PInfo {
  public:
   SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                     const PrimitiveAttrs &attrs)
-      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
+                     const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
+      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
   ~SparseGatherV2Info() override = default;
-
- private:
-  std::string replace_op_name_ = SPARSE_GATHERV2;
 };
 
 class EmbeddingLookupInfo : public GatherV2PInfo {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 5bdc94e12..3adf234ee 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -197,6 +197,7 @@ constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue";
 constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
 constexpr char CONV2D[] = "Conv2D";
 constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
+constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
 constexpr char BATCH_NORM[] = "BatchNorm";
 constexpr char LAYER_NORM[] = "LayerNorm";
 constexpr char POOLING[] = "Pooling";
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 8944fa011..295bc2733 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -263,7 +263,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
-    EMBEDDING_LOOKUP};
+    EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX};
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index b69b83f25..b1cf7793b 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -570,8 +570,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
     MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
-  auto prim = GetValueNode<PrimitivePtr>(node->input(0));
-  if (prim->name() == EMBEDDING_LOOKUP) {
+  if (replace_op.first == EMBEDDING_LOOKUP) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {
diff --git a/mindspore/ccsrc/utils/comm_manager.cc b/mindspore/ccsrc/utils/comm_manager.cc
index de165c4aa..6411479f7 100644
--- a/mindspore/ccsrc/utils/comm_manager.cc
+++ b/mindspore/ccsrc/utils/comm_manager.cc
@@ -40,7 +40,7 @@ CommManager &CommManager::GetInstance() noexcept {
 #define HCCL_RUN_CHECK(op_name, group, op)                        \
   do {                                                            \
     auto hccl_result = (op);                                      \
-    if (hccl_result != tagHcclResult::HCCL_SUCCESS) {             \
+    if (hccl_result != 0) {                                       \
       MS_LOG(ERROR) << op_name << " failed: #" << group << "#";   \
       return false;                                               \
     }                                                             \
diff --git a/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py b/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py
new file mode 100644
index 000000000..4655ec89a
--- /dev/null
+++ b/tests/ut/python/parallel/test_batchnorm_ex_batch_parallel.py
@@ -0,0 +1,76 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.common.parameter import Parameter
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+grad_all = C.GradOperation(get_all=True)
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y, b):
+        predict = self.network(x, y, b)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y, b):
+        return grad_all(self.network)(x, y, b)
+
+
+# model_parallel test
+def test_two_matmul_batchnorm_ex():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul1 = P.MatMul().set_strategy(strategy1)
+            self.norm = P.FusedBatchNormEx()
+            self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
+            self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
+            self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
+            self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
+            self.matmul2 = P.MatMul().set_strategy(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.matmul1(x, y)
+            out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0]
+            out = self.matmul2(out, b)
+            return out
+
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)
+    strategy1 = ((4, 2), (2, 1))
+    strategy2 = ((1, 8), (8, 1))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    net.set_auto_parallel()
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y, b)
diff --git a/tests/ut/python/parallel/test_sparse_gather_v2.py b/tests/ut/python/parallel/test_sparse_gather_v2.py
index 30957dcf9..da1c58917 100644
--- a/tests/ut/python/parallel/test_sparse_gather_v2.py
+++ b/tests/ut/python/parallel/test_sparse_gather_v2.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-import pytest
 
 import mindspore as ms
 import mindspore.nn as nn
@@ -158,18 +157,6 @@ def test_gatherv2_semi_auto7():
     _executor.compile(net, x, y)
 
 
-def test_gatherv2_semi_auto8():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8,), (1, 1))
-    strategy2 = ((4, 2), (4, 2))
-    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
 def test_gatherv2_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net(0)))
@@ -188,7 +175,6 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
 def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
@@ -201,7 +187,6 @@ def test_gatherv2_cpu0():
     _executor.compile(net, x, y)
 
 
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
 def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
@@ -214,7 +199,6 @@ def test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
 def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
--
GitLab