diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index c3e3f5893ea883a9bceae1fd5bdd99665ec6c733..400e985f18445b30e427af18bc449465ddcfa7df 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -346,6 +346,8 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) { } OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(prim); + MS_EXCEPTION_IF_NULL(cnode); auto attrs = prim->attrs(); std::vector shape_list = ExtractShape(cnode); if (shape_list.empty()) { @@ -381,8 +383,8 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & operator_info->set_outputs_dtype(cnode->Type()); operator_info->set_cnode(cnode); // If no strategy has been configured for this operator, then candidate strategies are generated for - // auto-strategy searching - if (!StrategyFound(attrs)) { + // auto-strategy searching; if this primitive is Cast, we ignore the user-specified strategy + if (!StrategyFound(attrs) || prim->name() == CAST) { // Compute split_flag_list_, indicating which input has batch dimension. 
This is ONLY used for preparation for // BatchParallelInfo operator operator_info->ComputeBatchSplitFlagList(); diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index ad194a7961aa31a9cd8f0e210f607753c5b67bba..e2b4d55aadae1738ef2d6907d95f3d7d6b3db453 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -371,7 +371,6 @@ bool IsParallelCareNode(const CNodePtr& cnode) { if (prim == nullptr) { return false; } - auto attrs = prim->attrs(); if (IsInBlackList(prim)) { MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); return false; @@ -380,10 +379,8 @@ bool IsParallelCareNode(const CNodePtr& cnode) { if (prim->name() == GET_NEXT) { return true; } - if ((prim->name() == CAST)) { - if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) { - return false; - } + if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { + return false; } return cnode->in_forward_flag(); @@ -654,6 +651,14 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) { LossNodeInfo node_info; + // return -> cast + auto pre_cnode = pre_node->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + auto pre_prim = GetValueNode(pre_cnode->input(0)); + if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_node = pre_cnode->input(1); + } + // return -> loss if (pre_node == loss_node) { node_info.has_tuple_getitem = false; @@ -1948,6 +1953,14 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { MS_EXCEPTION_IF_NULL(current_value); PrimitivePtr current_prim = current_value->value()->cast(); MS_EXCEPTION_IF_NULL(current_prim); + + // return -> cast + if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_cnode = pre_cnode->input(1)->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + current_prim = GetValueNode(pre_cnode->input(0)); + } + // notice: the GetNext op has not input if (INVALID_LOSS_OPS.find(current_prim->name()) != 
INVALID_LOSS_OPS.end()) { MS_LOG(INFO) << "The loss is: " << current_prim->name(); diff --git a/tests/ut/python/parallel/test_element_wise_function.py b/tests/ut/python/parallel/test_element_wise_function.py index dfcebdc5abc20a85e063517267d2b4ac1e823d63..2eb3a22ed2a6d1e45c8b45adf402e659dd924b18 100644 --- a/tests/ut/python/parallel/test_element_wise_function.py +++ b/tests/ut/python/parallel/test_element_wise_function.py @@ -192,7 +192,6 @@ def test_cast_before_mirror(): net = GradWrap(NetWithLoss(Net(strategy1))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") - x = Tensor(np.ones([128, 32]), dtype=ms.float32) y = Tensor(np.ones([32, 64]), dtype=ms.float32) b = Tensor(np.ones([64, 64]), dtype=ms.float16) @@ -217,7 +216,6 @@ def test_cast_before_mirror1(): net = GradWrap(NetWithLoss(Net(strategy1))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") - x = Tensor(np.ones([128, 32]), dtype=ms.float16) y = Tensor(np.ones([32, 64]), dtype=ms.float16) b = Tensor(np.ones([64, 64]), dtype=ms.float32) @@ -242,7 +240,6 @@ def test_cast_before_mirror2(): net = GradWrap(NetWithLoss(Net(strategy1))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") - x = Tensor(np.ones([128, 32]), dtype=ms.float16) y = Tensor(np.ones([32, 64]), dtype=ms.float16) b = Tensor(np.ones([64, 64]), dtype=ms.float32) @@ -267,8 +264,36 @@ def test_cast_before_mirror3(): net = GradWrap(NetWithLoss(Net(strategy1))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") - x = Tensor(np.ones([128, 32]), dtype=ms.float16) y = Tensor(np.ones([32, 64]), dtype=ms.float16) b = Tensor(np.ones([64, 64]), dtype=ms.float32) _executor.compile(net, x, y, b) + + +def test_mul_two_cast(): + class Net(nn.Cell): + def __init__(self, strategy1, strategy2, strategy3): + super().__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.mul2 = P.Mul().set_strategy(strategy2) + self.cast = P.Cast().set_strategy(strategy3) + self.cast2 = 
P.Cast().set_strategy(strategy3) + + def construct(self, x, y, b): + out = self.mul(x, y) + out = self.mul2(out, b) + out = self.cast(out, ms.int32) + out = self.cast2(out, ms.bool_) + return out + + context.set_auto_parallel_context(device_num=8, global_rank=0) + strategy1 = ((2, 2), (2, 2)) + strategy2 = ((8, 1), (8, 1)) + strategy3 = ((8, 1), ) + net = GradWrap(Net(strategy1, strategy2, strategy3)) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + + x = Tensor(np.ones([128, 32]), dtype=ms.float32) + y = Tensor(np.ones([128, 32]), dtype=ms.float32) + b = Tensor(np.ones([128, 32]), dtype=ms.float32) + _executor.compile(net, x, y, b)