diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index 530f67ba4534fa7cd32db7d086279aed90ef16f7..31de9f4456f0d39d37351c0fa6a8c07d160ae03f 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -44,6 +44,7 @@ namespace parallel { #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 #define DEFAULT_FULLY_USE_DEVICES true #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false +#define DEFAULT_IS_MULTI_SUBGRAPHS false class CostGraph; using CostGraphPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/parallel/costmodel_context.cc b/mindspore/ccsrc/parallel/costmodel_context.cc index 82b260f96700209ed021c4c384ee82532a211f54..591fa737aa1cd9481ff111ec49abc51305a40e09 100644 --- a/mindspore/ccsrc/parallel/costmodel_context.cc +++ b/mindspore/ccsrc/parallel/costmodel_context.cc @@ -46,6 +46,7 @@ void CostModelContext::ResetCostModel() { costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; + is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; @@ -84,6 +85,7 @@ void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } +void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) { costmodel_allreduce_fusion_algorithm_ = algorithm; } diff --git a/mindspore/ccsrc/parallel/costmodel_context.h b/mindspore/ccsrc/parallel/costmodel_context.h index 99374830517d415082a74c516316fd9b3bae7faf..ebb0d00008ad382490e60d42c29246ed5d694803 100644 --- a/mindspore/ccsrc/parallel/costmodel_context.h +++ b/mindspore/ccsrc/parallel/costmodel_context.h @@ -67,6 +67,9 @@ class CostModelContext { void set_costmodel_communi_bias(double); double costmodel_communi_bias() const { return costmodel_communi_bias_; } + void set_multi_subgraphs(bool); + bool is_multi_subgraphs() const { return is_multi_subgraphs_; } + void set_costmodel_allreduce_fusion_algorithm(int32_t); int32_t costmodel_allreduce_fusion_algorithm() const { return costmodel_allreduce_fusion_algorithm_; } @@ -138,6 +141,8 @@ class CostModelContext { // COST_MODEL_COMMUNI_BIAS double costmodel_communi_bias_; + bool is_multi_subgraphs_; + int32_t costmodel_allreduce_fusion_algorithm_; int32_t costmodel_allreduce_fusion_times_; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 7d37bafe985bf605612e22731c83a025773d9f16..269e624efa77579edf1a2290e13cd02ef374c28d 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -426,13 +426,13 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & return operator_info; } -Status ConstructCostGraphNodes(const std::vector &all_nodes, const FuncGraphPtr &) { +// Using CNode's UniqueIds to construct nodes +Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &) { MS_LOG(INFO) << "Constructing nodes for cost graph begins."; entire_costgraph = std::make_shared(); entire_costgraph->SetDeviceMemoryAndCostParameter(); - bool new_operator = true, first_operator = true; - std::string first_operator_cnode; - size_t current_op_index = 0; + // The map from CNode's UniqueId to its operatorInfo + std::map from_cnode_to_info; // Step 1 for (auto &node : all_nodes) { @@ -449,12 +449,8 @@ Status ConstructCostGraphNodes(const std::vector &all_nodes, const F PrimitivePtr prim = GetValueNode(prim_anf_node); MS_EXCEPTION_IF_NULL(prim); - // When visiting the second subgraph, use the corresponding operatorInfo which already created - bool modify_new_operator = (new_operator) && (!first_operator) && (cnode->UniqueId() == first_operator_cnode); - if (modify_new_operator) { - new_operator = false; - } - if (new_operator) { + auto search_cnode = from_cnode_to_info.find(cnode->UniqueId()); + if (search_cnode == from_cnode_to_info.end()) { auto operator_info = CreateTheOperatorInfo(prim, cnode); if (operator_info == nullptr) { return FAILED; @@ -465,14 +461,67 @@ Status ConstructCostGraphNodes(const std::vector &all_nodes, const F entire_costgraph->AddOperator(operator_info); (void)cnode->set_operator_info(operator_info); - if (first_operator) { - first_operator_cnode = cnode->UniqueId(); - first_operator = false; + MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); + (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); + // Needed by rec_parser + entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); + } else { + // Two CNODEs' UniqueIds should not be equal + MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name(); + } + } + + MS_LOG(INFO) << "Constructing nodes for cost graph ends."; + return SUCCESS; +} + +// Using CNode's UniqueIdThroughCopys to construct nodes +Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &) { + MS_LOG(INFO) << "Constructing nodes for cost graph begins."; + entire_costgraph = std::make_shared(); + entire_costgraph->SetDeviceMemoryAndCostParameter(); + // The map from CNode's UniqueIdThroughCopy to its operatorInfo + std::map from_cnode_to_info; + + for (auto &node : all_nodes) { + // NOTE: we only care about splittable Primitive operators + auto cnode = node->cast(); + bool bool_result = (cnode == nullptr) || (!IsValueNode(cnode->input(0))); + if (bool_result) { + continue; + } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + if (!IsAutoParallelCareNode(cnode)) { + continue; + } + PrimitivePtr prim = GetValueNode(prim_anf_node); + + // Find the operatorInfo if it exists + auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy()); + if (search_cnode == from_cnode_to_info.end()) { + // In this case, the corresponding OperatorInfo is not created, create the new one. + auto operator_info = CreateTheOperatorInfo(prim, cnode); + if (operator_info == nullptr) { + return FAILED; } // Needed by rec_parser + operator_info->set_type(prim->name()); + std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); + + entire_costgraph->AddOperator(operator_info); + (void)cnode->set_operator_info(operator_info); + MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); + (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); + // Needed by rec_parser entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); } else { - auto current_op_ptr = entire_costgraph->FindOperatorByIndex(current_op_index); + auto current_op_ptr = search_cnode->second; if (current_op_ptr == nullptr) { MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; } else { @@ -484,14 +533,12 @@ Status ConstructCostGraphNodes(const std::vector &all_nodes, const F << " does not match the Prim: " << prim->name(); } (void)cnode->set_operator_info(current_op_ptr); - current_op_index++; + MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); } } } - if ((!new_operator) && (current_op_index != entire_costgraph->GetOperators().size())) { - MS_LOG(EXCEPTION) << "The second subgraph's operator number: " << current_op_index - << " does not match the first ones: " << entire_costgraph->GetOperators().size(); - } MS_LOG(INFO) << "Constructing nodes for cost graph ends."; return SUCCESS; @@ -844,11 +891,20 @@ Status ParallelStrategySearch(const std::vector &all_nodes, const Fu // OUTPUT: the determined strategy for each operator. // Step 1 - if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) { - MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() - << " operators."; + if (CostModelContext::GetInstance()->is_multi_subgraphs()) { + if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { + MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " + << entire_costgraph->GetOperators().size() << " operators."; + } else { + MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; + } } else { - MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; + if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { + MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " + << entire_costgraph->GetOperators().size() << " operators."; + } else { + MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; + } } // Step 2 @@ -916,7 +972,7 @@ std::vector> RecInputTensorNames(const std::map &all_nodes, const FuncGraphPtr &root) { - if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) { + if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() << " operators."; } else { diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.h b/mindspore/ccsrc/parallel/step_auto_parallel.h index f120edcc61685028b730bbd5b932fc520b665920..fff9dfa4c37b3077f4dba7a61dda18ffdb13b704 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.h +++ b/mindspore/ccsrc/parallel/step_auto_parallel.h @@ -43,7 +43,9 @@ std::vector ExtractInputTypeLengthByNode(const CNodePtr &node); std::vector ExtractOutputTypeByNode(const CNodePtr &node); -Status ConstructCostGraphNodes(const std::vector &all_nodes, const FuncGraphPtr &root); +Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &root); + +Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root); void ConstructCostGraphEdges(const std::vector &all_nodes); diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index e8723e66a4bca7f4cc3bbe5638de77c37ab9394d..778600dc0a91d92fecfb2dd17e309e57b6775581 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -24,6 +24,7 @@ #include #include "ir/func_graph_cloner.h" +#include "parallel/costmodel_context.h" #include "pipeline/pass.h" #include "pipeline/parse/parse_base.h" #include "pipeline/parse/data_converter.h" @@ -341,7 +342,10 @@ static std::vector CommonPipeline() { // Resolve the python func actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); - actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); + auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs(); + if (!multi_graphs) { + actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); + } actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); // Evaluate type and shape, and specialize actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 04e6edc5c846d1faa6603822c8715f37f85c534c..868255a359821590b967f35a72e742ca6ea30509 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -222,6 +222,8 @@ PYBIND11_MODULE(_c_expression, m) { "Set the parameter cost_model_communi_bias of the DP algorithm.") .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias, "Get the parameter cost_model_communi_bias of the DP algorithm.") + .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") + .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, "Set the parameter gradient AllReduce fusion algorithm.") .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, diff --git a/mindspore/parallel/_cost_model_context.py b/mindspore/parallel/_cost_model_context.py index 54cca5516b1e935e47ab40c4f50c6c1a91518488..2790aed855d8737171796ae03351cf006b3e9095 100644 --- a/mindspore/parallel/_cost_model_context.py +++ b/mindspore/parallel/_cost_model_context.py @@ -214,6 +214,31 @@ class _CostModelContext: raise ValueError("Context handle is none in context!!!") return self._context_handle.get_costmodel_communi_bias() + def set_multi_subgraphs(self, multi_subgraph): + """ + Set the flag of ANF graph containing multiple subgraphs. + + Args: + multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag. + + Raises: + ValueError: If context handle is none. + """ + if self._context_handle is None: + raise ValueError("Context handle is none in context!!!") + self._context_handle.set_multi_subgraphs(multi_subgraph) + + def get_multi_subgraphs(self): + """ + Get the flag of ANF graph containing multiple subgraphs. + + Raises: + ValueError: If context handle is none. + """ + if self._context_handle is None: + raise ValueError("Context handle is none in context!!!") + return self._context_handle.get_multi_subgraphs() + def set_costmodel_allreduce_fusion_algorithm(self, algorithm): """ Set costmodel allreduce fusion algorithm. @@ -427,6 +452,7 @@ set_cost_model_context_func_map = { "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold, "costmodel_communi_const": cost_model_context().set_costmodel_communi_const, "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias, + "multi_subgraphs": cost_model_context().set_multi_subgraphs, "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm, "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times, "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent, @@ -447,6 +473,7 @@ get_cost_model_context_func_map = { "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold, "costmodel_communi_const": cost_model_context().get_costmodel_communi_const, "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias, + "multi_subgraphs": cost_model_context().get_multi_subgraphs(), "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm, "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times, "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent, @@ -461,6 +488,7 @@ get_cost_model_context_func_map = { @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float, costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float, + multi_subgraphs=bool, costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int, costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float, costmodel_allreduce_fusion_allreduce_inherent_time=float, @@ -481,6 +509,7 @@ def set_cost_model_context(**kwargs): costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice. costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice. costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice. + multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs. costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm. 0: bypass allreduce fusion; 1: only use backward computation time to group allreduce; diff --git a/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8c89de25519f4e95bed42683b29b9316027029 --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py @@ -0,0 +1,101 @@ +import numpy as np +from mindspore import context +import mindspore as ms +import mindspore.nn as nn +from mindspore.nn.optim import Adam, FTRL +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor, Parameter, ParameterTuple +from mindspore.ops import composite as C +from mindspore.parallel import _cost_model_context as cost_model_context +from mindspore.common.api import _executor +from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters +from mindspore.parallel._utils import _reset_op_id as reset_op_id + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.mul = P.Mul() + self.relu = P.ReLU() + self.wd = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide") + self.wt = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="l") + def construct(self, x): + out = self.mul(x, self.wd) + out = self.mul(out, self.wt) + out = self.relu(out) + return out + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.sum = P.ReduceSum() + self.mean = P.ReduceMean() + self.net = network + + def construct(self, x): + predict = self.net(x) + loss1 = self.sum(predict, -1) + loss2 = self.mean(predict, -1) + return loss1, loss2 + +class IthOutputCell(nn.Cell): + def __init__(self, network, output_index): + super(IthOutputCell, self).__init__() + self.network = network + self.output_index = output_index + + def construct(self, x): + predict = self.network(x)[self.output_index] + return predict + +class TrainStepWarp(nn.Cell): + def __init__(self, network, sens=1000.0): + super(TrainStepWarp, self).__init__() + self.network = network + self.network.set_train() + self.trainable_params = network.trainable_params() + weights_w = [] + weights_d = [] + for params in self.trainable_params: + weights_w.append(params) + weights_d.append(params) + self.weights_w = ParameterTuple(weights_w) + self.weights_d = ParameterTuple(weights_d) + self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, l1=1e-8, + l2=1e-8, initial_accum=1.0) + self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8, + loss_scale=sens) + self.hyper_map = C.HyperMap() + self.grad_w = C.GradOperation('grad_w', get_by_list=True, sens_param=True) + self.grad_d = C.GradOperation('grad_d', get_by_list=True, sens_param=True) + self.sens = sens + self.loss_net_w = IthOutputCell(network, output_index=0) + self.loss_net_d = IthOutputCell(network, output_index=1) + + def construct(self, x): + weights_w = self.weights_w + weights_d = self.weights_d + loss_w, loss_d = self.network(x) + sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) + sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) + grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w) + grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d) + return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d)) + +def test_double_subgraphs(): + cost_model_context.set_cost_model_context(multi_subgraphs=True) + context.set_context(save_graphs=True) + context.set_auto_parallel_context(device_num=8, global_rank=0) + net = TrainStepWarp(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + + x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32) + reset_op_id() + _executor.compile(net, x, phase='train') + strategies = _executor._get_strategy(net) + expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]], + 'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]], + 'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]], + 'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]], + 'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]} + assert strategies == expected_strategies diff --git a/tests/ut/python/parallel/test_auto_parallel_two_bn.py b/tests/ut/python/parallel/test_auto_parallel_two_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb6074f9fc3aef8e1c32c3caf18d82951fb4fdd --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_two_bn.py @@ -0,0 +1,70 @@ +import numpy as np +from mindspore import context +import mindspore as ms +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore import Tensor +from mindspore.common.api import _executor +from tests.ut.python.ops.test_math_ops import VirtualLoss +from mindspore.parallel import set_algo_parameters +from mindspore.parallel._utils import _reset_op_id as reset_op_id +import re + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x): + predict = self.network(x) + return self.loss(predict) + +class Blockcell(nn.Cell): + def __init__(self): + super(Blockcell, self).__init__() + self.bn = nn.BatchNorm2d(64, momentum=0.9) + + def construct(self, x): + out = self.bn(x) + return out + +def getBlock(): + return Blockcell() + +def test_two_bn(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.block1 = getBlock() + self.block2 = getBlock() + self.relu = P.ReLU() + self.add = P.TensorAdd() + self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32) + + def construct(self, x): + out = self.block1(x) + out = self.relu(out) + out = self.add(out, self.bias) + out = self.block2(out) + return out + + net = NetWithLoss(Net()) + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + context.set_context(save_graphs=True) + context.set_auto_parallel_context(device_num=8, global_rank=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + set_algo_parameters(elementwise_op_strategy_follow=True) + reset_op_id() + + _executor.compile(net, x, phase='train') + strategies = _executor._get_strategy(net) + assert len(strategies) == 4 + + for (k, v) in strategies.items(): + if re.search('BatchNorm-op', k) is not None: + assert v == [[8, 1], [1], [1], [1], [1]] + elif re.search('TensorAdd-op', k) is not None: + assert v == [[8, 1], [8, 1]] + elif re.search('ReLU-op', k) is not None: + assert v == [[8, 1]]