diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 91d1461803b739af4511c53b13b3d950d8c8cef2..17a622855258ed0da6b664c6760c13b9f4a5ba40 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -345,7 +345,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (!IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
@@ -903,9 +902,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   }
 }
 
-void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node) {
+void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
+                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(node);
+
+  bool is_loss_cnode =
+    std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
+                [node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
+
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
   // insert mirror op
@@ -914,7 +919,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
     InsertMirrorOps(mirror_ops, node);
   }
   // insert virtual div op
-  if (!virtual_div_op.empty() && is_loss_node) {
+  if (!virtual_div_op.empty() && is_loss_cnode) {
     MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name();
     InsertVirtualDivOp(virtual_div_op, node);
   }
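The hunk above moves the loss-node decision into BackwardCommunication itself: instead of trusting a caller-supplied bool, it scans the sens/loss pairs for the current node. A minimal standalone sketch of that std::any_of lookup, using stand-in node types rather than MindSpore's real AnfNode hierarchy (all names here are illustrative only):

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

// Stand-ins for the MindSpore node types; the real CNodePtr is a shared_ptr.
struct CNode {};
using CNodePtr = std::shared_ptr<CNode>;

// True when `node` is recorded as the loss cnode (the .second member) of any pair.
bool IsLossCNode(const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs, const CNodePtr &node) {
  return std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
                     [&node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
}

int main() {
  auto sens = std::make_shared<CNode>();
  auto loss = std::make_shared<CNode>();
  auto other = std::make_shared<CNode>();
  std::vector<std::pair<CNodePtr, CNodePtr>> pairs = {{sens, loss}};
  return (IsLossCNode(pairs, loss) && !IsLossCNode(pairs, other)) ? 0 : 1;
}

The linear scan is cheap here because a graph rarely contains more than a handful of losses.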
@@ -986,10 +991,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
       Dimensions dim;
       if (elements[index]->isa<ValueSequeue>()) {
         ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>();
-        if (value_tuple == nullptr) {
-          MS_LOG(EXCEPTION) << "Failure:value_tuple is nullptr";
-        }
-
         std::vector<ValuePtr> value_vector = value_tuple->value();
         (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
                              [](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
@@ -1013,7 +1014,6 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   BaseShapePtr base_shape_ptr = node->Shape();
   if (node->isa<CNode>()) {
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (IsValueNode<Primitive>(cnode->input(0))) {
       PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
       MS_EXCEPTION_IF_NULL(prim);
@@ -1190,7 +1190,7 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
       continue;
     }
     CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
-    if ((graph_cnode_inp0 == nullptr) || !IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
+    if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
       continue;
     }
     FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
@@ -1692,14 +1692,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   return pre_cnode;
 }
 
-TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
+TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
   TensorLayouts ret;
-  if (!IsValueNode<FuncGraph>(cnode->input(1))) {
-    MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
-  }
-  auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
-  auto loss_cnode = FindLossCNode(func_graph);
   MS_EXCEPTION_IF_NULL(loss_cnode);
   AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
@@ -1735,16 +1729,16 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
   return ret;
 }
 
-void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
+void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
   MS_EXCEPTION_IF_NULL(grad_sens_node);
-
-  auto cnode = grad_sens_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  AnfNodePtr sens_tensor_node = cnode->input(1);
+  if (grad_sens_node->size() <= 1) {
+    MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2";
+  }
+  AnfNodePtr sens_tensor_node = grad_sens_node->input(1);
   MS_EXCEPTION_IF_NULL(sens_tensor_node);
   Shapes sens_shapes = GetNodeShape(sens_tensor_node);
   if (sens_shapes.size() != 1) {
-    MS_LOG(EXCEPTION) << "SplitSens: GetNodeShape for sens_tensor_node, output size is not 1";
+    MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1";
   }
   // If the shape of sens tensor is [] or [1], no need to split it.
   Shape sens_shape = sens_shapes[0];
@@ -1780,14 +1774,14 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
       sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
       return;
     }
-    MS_LOG(EXCEPTION) << "SplitSens: the type of sens node is not Tensor or Parameter, it is unsupported now.";
+    MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
   }
 
   // Use _GetTensorSlice operator to split the sens tensor
-  FuncGraphPtr func_graph = cnode->func_graph();  // only cnode can get the graph
+  FuncGraphPtr func_graph = grad_sens_node->func_graph();  // only cnode can get the graph
   MS_EXCEPTION_IF_NULL(func_graph);
   Operator op = CreateGetTensorSliceOp(loss_grad_layout);
-  InsertGetTensorSliceOp(op, cnode, func_graph, 1, SPLIT_SENS);
+  InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS);
 }
 
 void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
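SplitSens only slices a sens tensor that actually has a distributed shape; as the comment in the hunk above notes, [] and [1] are left alone. A hypothetical helper that isolates just that shape test (not part of the patch, and MindSpore's real Shape alias may differ):

#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for the Shape alias used above

// A sens tensor of shape [] or [1] is effectively a scalar, so no slicing is needed.
bool SensNeedsSplit(const Shape &sens_shape) {
  return !(sens_shape.empty() || (sens_shape.size() == 1 && sens_shape[0] == 1));
}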
@@ -1853,7 +1847,6 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
     }
 
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
@@ -1870,55 +1863,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
   return graph_set;
 }
 
-// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-void StepSplitSens(const AnfNodePtr &node) {
-  if (!node->isa<CNode>()) {
-    return;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)
-  auto cnode = node->cast<CNodePtr>();
-  AnfNodePtr expect_tuple_getitem = cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
-  if (!expect_tuple_getitem->isa<CNode>()) {
-    return;
-  }
-  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
-  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
-    return;
-  }
-  auto expect_tuple_getitem_prim = GetValueNode<PrimitivePtr>(expect_tuple_getitem_cnode->input(0));
-  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
-    return;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode
-  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(expect_anonymous);
-  if (!expect_anonymous->isa<CNode>()) {
-    return;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
-  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_j);
-  if (!expect_j->isa<CNode>()) {
-    return;
-  }
-  auto expect_j_cnode = expect_j->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_cnode);
-  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
-    return;
-  }
-  auto expect_j_prim = GetValueNode<PrimitivePtr>(expect_j_cnode->input(0));
-  if (expect_j_prim->name() == J) {
-    auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode);
-    if (!loss_grad_layout.empty()) {
-      SplitSens(node, loss_grad_layout[0]);
-    }
+void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
+  CNodePtr sens_node = sens_loss_pair.first;
+  CNodePtr loss_node = sens_loss_pair.second;
+  auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
+  if (!loss_grad_layout.empty()) {
+    SplitSens(sens_node, loss_grad_layout[0]);
   }
 }
 
@@ -1937,26 +1887,77 @@ std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
   return loss_node;
 }
 
+// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
+  for (auto &node : root->nodes()) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)
+    auto sens_cnode = node->cast<CNodePtr>();
+    AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
+    MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
+    if (!expect_tuple_getitem->isa<CNode>()) {
+      continue;
+    }
+
+    auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
+    if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)-->cnode
+    AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
+    MS_EXCEPTION_IF_NULL(expect_anonymous);
+    if (!expect_anonymous->isa<CNode>()) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+    auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
+    AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
+    MS_EXCEPTION_IF_NULL(expect_j);
+    if (!expect_j->isa<CNode>()) {
+      continue;
+    }
+    auto expect_j_cnode = expect_j->cast<CNodePtr>();
+    if (!IsSomePrimitive(expect_j_cnode, J)) {
+      continue;
+    }
+
+    if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
+      MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
+    }
+    auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
+    auto loss_cnode = FindLossCNode(func_graph);
+    std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
+    sens_loss_pairs.push_back(sens_loss_pair);
+  }
+  return sens_loss_pairs;
+}
+
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(manager);
   TensorRedistribution tensor_redistribution;
-  AnfNodePtr grad_sens_node = nullptr;
-  std::vector<CNodePtr> loss_cnode = FindLossCNodeFromRoot(root);
+  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
+  bool has_backward = !sens_loss_pairs.empty();
   // split sens must before inserting the operators.
-  for (auto &node : all_nodes) {
+  for (auto &pair : sens_loss_pairs) {
     // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
     // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
-    StepSplitSens(node);
+    StepSplitSens(pair);
   }
 
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(cnode);
       if (!IsValueNode<Primitive>(cnode->input(0))) {
         continue;
       }
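GetSensLossPairs collapses the old hand-rolled primitive checks into an IsSomePrimitive helper that this diff does not define. A plausible reconstruction, assuming it merely compares a cnode's first input against a primitive name (the helper actually shipped in step_parallel.cc may differ in details such as null handling):

// Hypothetical sketch of the IsSomePrimitive helper used by GetSensLossPairs above.
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
  if (anf_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
  return (prim != nullptr) && (prim->name() == name);
}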
@@ -1965,11 +1966,6 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (!IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
@@ -2117,7 +2114,6 @@ void SetForwardFlag(const AnfNodeSet &all_nodes) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (!IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
@@ -2146,7 +2142,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     auto root_node_id = node->UniqueIdThroughCopy();
     if (loss_cnode_id == root_node_id) {
       root_forward_nodes = DeepLinkedGraphSearch(cnode);
diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h
index b0d128f515158fe263fc1cc9a1618d396d8c9450..745794912b9a6f688d450a74e48b8bcdf5a2b05f 100644
--- a/mindspore/ccsrc/parallel/step_parallel.h
+++ b/mindspore/ccsrc/parallel/step_parallel.h
@@ -82,7 +82,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
 
 void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
 
-void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node);
+void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
+                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs);
 
 // Generate and init parallel operator
 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
diff --git a/tests/ut/python/parallel/test_forward_graph.py b/tests/ut/python/parallel/test_forward_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..76cd5b417895f9702280da9f7a378aecc85746e0
--- /dev/null
+++ b/tests/ut/python/parallel/test_forward_graph.py
@@ -0,0 +1,82 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+from mindspore.common.api import _executor
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.neg = P.Neg().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.neg(out)
+        return out, b
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile(net):
+    _executor.compile(net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_forward_graph_data_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)
+
+
+def test_forward_graph_model_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 1, 16), (1, 1, 16))
+    strategy2 = ((1, 1, 16), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)
+
+
+def test_forward_graph_hybrid_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((2, 2, 4), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)
+
+
+def test_forward_graph_auto_parallel():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
+    net = Net(_w1)
+    compile(net)
+
+
+def test_forward_graph_repeat_calc():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((1, 2, 2), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)
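A note on the last test: with 16 devices, strategy2 = ((1, 2, 2), ) cuts the Neg input into only 1 * 2 * 2 = 4 slices, so each slice is computed redundantly by 4 devices; that is the repeated calculation the test name refers to. The arithmetic, as a self-contained sketch (hypothetical helper, not part of the patch):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// How many devices compute the same slice when a strategy cuts a tensor
// into fewer pieces than there are devices.
int64_t RepeatCount(int64_t device_num, const std::vector<int64_t> &strategy) {
  int64_t slices = std::accumulate(strategy.begin(), strategy.end(), int64_t{1}, std::multiplies<int64_t>());
  return device_num / slices;
}

int main() {
  assert(RepeatCount(16, {1, 2, 2}) == 4);  // test_forward_graph_repeat_calc
  assert(RepeatCount(16, {2, 2, 4}) == 1);  // hybrid parallel: no repetition
  return 0;
}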