diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 6fd84dd3641eca4dc54b757bf73e48bf624aaf9a..9705862a4401e14064f7d9f4a640c9e405a9fb8f 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -649,108 +649,13 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
   MS_LOG(INFO) << "Constructing edges for cost graph ends.";
 }
 
-std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  std::vector<AnfNodePtr> refkeys;
-  if (cnode->isa<CNode>()) {
-    auto cnode_ptr = cnode->cast<CNodePtr>();
-    auto inputs = cnode_ptr->inputs();
-    for (auto &one_input : inputs) {
-      if (IsValueNode<RefKey>(one_input)) {
-        refkeys.push_back(one_input);
-      }
-    }
-    if (refkeys.size() >= 1) {
-      return std::make_pair(cnode, refkeys);
-    }
-  }
-  return {nullptr, refkeys};
-}
-
 void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   // Step 3
   for (auto &node : all_nodes) {
-    auto cnode_with_refkeys = CNodeWithRefKeys(node);
-    if ((!node->isa<Parameter>()) && (cnode_with_refkeys.first == nullptr)) {
-      continue;
-    }
-    std::string parameter_name;
-    AnfNodePtr target_parameter = nullptr;
-    AnfNodeIndexSet target_set;
-
-    if (cnode_with_refkeys.first != nullptr) {
-      // Dealing with the RefKey case
-      auto refkeys = cnode_with_refkeys.second;
-      auto cnode = cnode_with_refkeys.first;
-
-      auto cnode_ptr = cnode->cast<CNodePtr>();
-      if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
-        continue;
-      }
-      if (!IsAutoParallelCareNode(cnode_ptr)) {
-        continue;
-      }
-
-      if (refkeys.size() > 1) {
-        MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
-      }
-      MS_EXCEPTION_IF_NULL(cnode->func_graph());
-      auto cnode_func_graph = cnode->func_graph();
-      MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
-
-      // Find the RefKey being used
-      auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
-      for (auto &candidate : candidate_set_by_refkey) {
-        auto candidate_node = candidate.first;
-        auto c = candidate_node->cast<CNodePtr>();
-        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
-          continue;
-        }
-        if (!IsAutoParallelCareNode(c)) {
-          continue;
-        }
-        target_set.add(candidate);
-      }
-
-      // Find the corresponding Parameter being used
-      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
-      if (parameters.size() != 1) {
-        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
-      }
-      parameter_name = parameters[0]->cast<ParameterPtr>()->name();
-      target_parameter = parameters[0];
-      auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
-      for (auto &candidate : candidate_set_by_para) {
-        auto candidate_node = candidate.first;
-        auto c = candidate_node->cast<CNodePtr>();
-        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
-          continue;
-        }
-        if (!IsAutoParallelCareNode(c)) {
-          continue;
-        }
-        (void)target_set.insert(candidate);
-      }
-    } else if (node->isa<Parameter>()) {
-      // Dealing with the Parameter case
-      MS_EXCEPTION_IF_NULL(node->func_graph());
-      MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
-      auto candidate_set = node->func_graph()->manager()->node_users()[node];
-      for (auto &candidate : candidate_set) {
-        auto candidate_node = candidate.first;
-        auto c = candidate_node->cast<CNodePtr>();
-        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
-          continue;
-        }
-        if (!IsAutoParallelCareNode(c)) {
-          continue;
-        }
-        (void)target_set.insert(candidate);
-      }
-      // In this case, node is a Parameter
-      parameter_name = node->cast<ParameterPtr>()->name();
-      target_parameter = node;
-    }
+    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode);
+    auto parameter_name = parameter_users_info.first;
+    auto target_parameter = parameter_users_info.second.first;
+    auto target_set = parameter_users_info.second.second;
     if (target_set.size() <= 1) {
       continue;
     }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 38fdf4e55208ea98f4611ef69d3742fed667b0da..1d046c04b0c1e2fc29c42b72d78ad58da7ca6a04 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -2499,6 +2499,149 @@ void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
+RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<AnfNodePtr> refkeys;
+  if (cnode->isa<CNode>()) {
+    auto cnode_ptr = cnode->cast<CNodePtr>();
+    auto inputs = cnode_ptr->inputs();
+    for (auto &one_input : inputs) {
+      if (IsValueNode<RefKey>(one_input)) {
+        refkeys.push_back(one_input);
+      }
+    }
+    if (refkeys.size() >= 1) {
+      return std::make_pair(cnode, refkeys);
+    }
+  }
+  return {nullptr, refkeys};
+}
+
+ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
+  // In this case, node is a Parameter
+  ParameterUsersInfo parameter_user_info;
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+  MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
+  auto candidate_set = node->func_graph()->manager()->node_users()[node];
+  for (auto &candidate : candidate_set) {
+    auto candidate_node = candidate.first;
+    auto c = candidate_node->cast<CNodePtr>();
+    if (c == nullptr || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
+      continue;
+    }
+    (void)parameter_user_info.second.second.insert(candidate);
+  }
+
+  parameter_user_info.first = node->cast<ParameterPtr>()->name();
+  parameter_user_info.second.first = node;
+  return parameter_user_info;
+}
+
+ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
+  // Dealing with the RefKey case
+  ParameterUsersInfo parameter_user_info;
+  auto refkeys = ref_key_pair.second;
+  auto cnode = ref_key_pair.first;
+
+  auto cnode_ptr = cnode->cast<CNodePtr>();
+  if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
+    return parameter_user_info;
+  }
+
+  if (refkeys.size() > 1) {
+    MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than 1 RefKeys";
+  }
+  MS_EXCEPTION_IF_NULL(cnode->func_graph());
+  auto cnode_func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
+
+  // Find the RefKey being used
+  auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
+  for (auto &candidate : candidate_set_by_refkey) {
+    auto candidate_node = candidate.first;
+    auto c = candidate_node->cast<CNodePtr>();
+    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
+      continue;
+    }
+    parameter_user_info.second.second.add(candidate);
+  }
+
+  // Find the corresponding Parameter being used
+  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
+  if (parameters.size() != 1) {
+    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+  }
+  parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
+  parameter_user_info.second.first = parameters[0];
+  auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
+  for (auto &candidate : candidate_set_by_para) {
+    auto candidate_node = candidate.first;
+    auto c = candidate_node->cast<CNodePtr>();
+    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
+      continue;
+    }
+    (void)parameter_user_info.second.second.insert(candidate);
+  }
+  return parameter_user_info;
+}
+
+ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
+  ParameterUsersInfo parameter_users_info;
+  auto cnode_with_refkeys = CNodeWithRefKeys(node);
+
+  if (cnode_with_refkeys.first != nullptr) {
+    // the node is a ref key node
+    return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
+  } else if (node->isa<Parameter>()) {
+    // the node is a parameter node
+    return FindParameterNodeUsers(node, IsCareNode);
+  }
+
+  return parameter_users_info;
+}
+
+Shape ParameterSliceShape(const std::pair<AnfNodePtr, int> &param_info) {
+  auto user_cnode = param_info.first->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(user_cnode);
+  auto user_input_index = param_info.second;
+  OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
+  MS_EXCEPTION_IF_NULL(op_info);
+
+  size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
+  if (SizeToInt(input_tensor_info_size) <= user_input_index - 1) {
+    MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
+                      << ", but the index is " << user_input_index - 1;
+  }
+  TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
+  MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1
+                << ", the slice shape is " << ShapeToString(tensor_info.slice_shape()) << ", the origin shape is "
+                << ShapeToString(tensor_info.shape());
+  return tensor_info.slice_shape();
+}
+
+void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto &node : all_nodes) {
+    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
+    auto users_set = parameter_users_info.second.second;
+    if (users_set.size() <= 1) {
+      continue;
+    }
+
+    auto parameter_name = parameter_users_info.first;
+    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
+    auto first_user = users_set.pop();
+    Shape first_user_slice_shape = ParameterSliceShape(first_user);
+
+    for (auto &user : users_set) {
+      Shape user_slice_shape = ParameterSliceShape(user);
+      if (first_user_slice_shape != user_slice_shape) {
+        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
+                          << " has multiple users, but the split strategies are different";
+      }
+    }
+  }
+}
+
 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(optimizer);
@@ -2556,6 +2699,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 
   HandleForwardMakeTuple(all_nodes);
 
+  // if the input or parameter has multiple users, check whether its split strategies are consistent.
+  CheckParameterSplit(all_nodes);
+
   // save strategy as checkpoint for multi-train
   if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
     CheckpointStrategy(all_nodes);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index a90d740d58c83f874ba79c0439a705dc19c24eb1..ca049d1704dcbee258458b52bb8e8b91f1cd0ceb 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -150,6 +150,13 @@ std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 
 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
 
 bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
+
+using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
+using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;
+
+RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);
+
+ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/tests/ut/python/parallel/test_auto_parallel_reshape.py b/tests/ut/python/parallel/test_auto_parallel_reshape.py
index 2f4c4efb6e25e8c1d97756a18c354fbf3bd7434c..b2606505e65ffb5f5fe2854bca829f599d19220b 100644
--- a/tests/ut/python/parallel/test_auto_parallel_reshape.py
+++ b/tests/ut/python/parallel/test_auto_parallel_reshape.py
@@ -245,51 +245,3 @@ def test_reshape_auto_5():
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x, y)
-
-
-def test_reshape_auto_6():
-    class NetWithLoss6(nn.Cell):
-        def __init__(self, network):
-            super(NetWithLoss6, self).__init__()
-            self.loss = VirtualLoss()
-            self.network = network
-
-        def construct(self, x, y):
-            predict = self.network(x, y)
-            return self.loss(predict)
-
-    class GradWrap6(nn.Cell):
-        def __init__(self, network):
-            super(GradWrap6, self).__init__()
-            self.network = network
-
-        def construct(self, x, y):
-            return grad_all(self.network)(x, y)
-
-    class Net(nn.Cell):
-        def __init__(self):
-            super().__init__()
-            self.relu = P.ReLU()
-            self.mul = P.Mul()
-            self.reshape = P.Reshape()
-            self.reduce_mean = P.ReduceMean()
-            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
-
-        def construct(self, x, y):
-            out1 = x + self.wide_w
-            w = self.reshape(self.wide_w, (4, 1024))
-            out1 = self.reduce_mean(out1, 1)
-            out1 = out1 - w
-            out2 = self.mul(y, w)
-            out = out1 + out2
-            return out
-
-    size = 8
-    context.set_auto_parallel_context(device_num=size, global_rank=0)
-    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
-    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
-
-    net = GradWrap6(NetWithLoss6(Net()))
-    context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    net.set_auto_parallel()
-    _executor.compile(net, x, y)
diff --git a/tests/ut/python/parallel/test_parameter_multi_users.py b/tests/ut/python/parallel/test_parameter_multi_users.py
new file mode 100644
index 0000000000000000000000000000000000000000..e977966eb172d812fc9b54c2fb0d50c55e57fe81
--- /dev/null
+++ b/tests/ut/python/parallel/test_parameter_multi_users.py
@@ -0,0 +1,94 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.mul2 = P.Mul().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.mul2(out, self.mul_weight)
+        return out
+
+
+class Net2(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.mul2 = P.Mul().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.mul2(x, out)
+        return out
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_parameter_same_split():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1), (16, 1, 1))
+    net = Net(_w, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_parameter_different_split():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((4, 4, 1), (4, 4, 1))
+    net = Net(_w, strategy1, strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_input_same_split():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1), (16, 1, 1))
+    net = Net(_w, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_input_different_split():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((4, 4, 1), (4, 4, 1))
+    net = Net2(_w, strategy1, strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
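
Note on the rule this patch enforces: FindParameterUsers returns (parameter_name, (parameter_node, users)), and CheckParameterSplit rejects the graph whenever two care-node users of the same parameter derive different slice shapes from their strategies, taking the first user's shape as the reference. A minimal standalone sketch of that rule follows; Shape, CheckSliceShapesConsistent, and the sample data are hypothetical stand-ins for illustration, not MindSpore APIs.

#include <cstdint>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for parallel::Shape (a list of dimension sizes).
using Shape = std::vector<int64_t>;

// Every user of a shared parameter must see the same slice shape; the first
// user's shape is the reference, mirroring CheckParameterSplit in the patch.
void CheckSliceShapesConsistent(const std::map<std::string, std::vector<Shape>> &users_slice_shapes) {
  for (const auto &[param_name, shapes] : users_slice_shapes) {
    if (shapes.size() <= 1) {
      continue;  // a single user cannot conflict
    }
    const Shape &reference = shapes.front();
    for (size_t i = 1; i < shapes.size(); ++i) {
      if (shapes[i] != reference) {
        throw std::runtime_error("The parameter: " + param_name +
                                 " has multiple users, but the split strategies are different");
      }
    }
  }
}

int main() {
  // Mirrors test_parameter_different_split: on 16 devices, strategy
  // ((16, 1, 1), (16, 1, 1)) slices [128, 64, 32] into [8, 64, 32], while
  // ((4, 4, 1), (4, 4, 1)) yields [32, 16, 32]; the two users of w1 conflict.
  std::map<std::string, std::vector<Shape>> users = {{"w1", {{8, 64, 32}, {32, 16, 32}}}};
  try {
    CheckSliceShapesConsistent(users);
  } catch (const std::runtime_error &e) {
    std::cout << e.what() << std::endl;  // surfaces as RuntimeError in the Python tests
  }
  return 0;
}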