diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h
index 73941c1058d891cebf6a8296aa9ade0bb9b1117b..08c6b7dc9bbec3c05b36d0bb130f50046832364b 100644
--- a/mindspore/ccsrc/ir/primitive.h
+++ b/mindspore/ccsrc/ir/primitive.h
@@ -52,7 +52,11 @@ class Primitive : public Named {
       : Named(name), signatures_(), prim_type_(prim_type) {}
 
   Primitive(const Primitive &prim)
-      : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
+      : Named(prim),
+        attrs_(prim.attrs_),
+        signatures_(prim.signatures_),
+        instance_name_(prim.instance_name_),
+        prim_type_(prim.prim_type_) {}
 
   MS_DECLARE_PARENT(Primitive, Named);
diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc
index 4eb79772ddbf47e993be1994917ee9fa1307102b..9ba7efd60f8dffb4e9e201a27af45820b64731c7 100644
--- a/mindspore/ccsrc/parallel/context.cc
+++ b/mindspore/ccsrc/parallel/context.cc
@@ -56,6 +56,8 @@ void ParallelContext::Reset() {
   parameter_broadcast_ = false;
   parameter_broadcast_is_set_ = false;
   enable_all_reduce_fusion_ = false;
+  strategy_ckpt_load_file_ = "";
+  strategy_ckpt_save_file_ = "";
 }
 
 void ParallelContext::set_device_num(int32_t device_num) {
@@ -103,6 +105,14 @@ void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
   parameter_broadcast_is_set_ = true;
 }
 
+void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
+  strategy_ckpt_load_file_ = strategy_ckpt_load_file;
+}
+
+void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
+  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
+}
+
 void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector indices) {
   all_reduce_fusion_split_indices_ = indices;
 }
diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h
index 095a50f7b376c3215cc391e3719d8e1d6efa0d15..0e007c92c648abceb1796c995602a33499bbb03d 100644
--- a/mindspore/ccsrc/parallel/context.h
+++ b/mindspore/ccsrc/parallel/context.h
@@ -85,6 +85,11 @@ class ParallelContext {
   }
   bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
 
+  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
+  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
+  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
+  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
+
   void Reset();
 
  private:
@@ -105,6 +110,8 @@ class ParallelContext {
   bool enable_all_reduce_fusion_;
   std::vector all_reduce_fusion_split_indices_;
   std::vector all_reduce_fusion_split_sizes_;
+  std::string strategy_ckpt_load_file_;
+  std::string strategy_ckpt_save_file_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 269e624efa77579edf1a2290e13cd02ef374c28d..f0be47642e5664acb1a29ea20cbd5799f98f997f 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -40,6 +40,7 @@
 #include "parallel/context.h"
 #include "parallel/ops_info/tmp_identity_info.h"
 #include "parallel/step_parallel.h"
+#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "pipeline/parse/python_adapter.h"
 #include "pipeline/pipeline.h"
 
@@ -339,7 +340,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
   return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
 }
 
-OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
+OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
   MS_EXCEPTION_IF_NULL(prim);
   MS_EXCEPTION_IF_NULL(cnode);
   auto attrs = prim->attrs();
@@ -385,9 +386,15 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_input_value(input_value);
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
+  // key of strategy map
+  std::string instance_name = prim->instance_name();
+  std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+  bool load_strategy_from_ckpt =
+    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
-  if (!StrategyFound(attrs) || prim->name() == CAST) {
+  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
+  // If the strategy is set to be loaded from a checkpoint, the strategy from the checkpoint takes precedence.
+  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
     operator_info->ComputeBatchSplitFlagList();
@@ -397,7 +404,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     }
   } else {
     // In this case, the configured strategy should be extracted to help setting cost
-    StrategyPtr strategyPtr = parallel::ExtractStrategy(attrs);
+    StrategyPtr strategyPtr;
+    if (load_strategy_from_ckpt) {
+      strategyPtr = (*stra_map)[strategy_key_name];
+    } else {
+      strategyPtr = parallel::ExtractStrategy(attrs);
+    }
     if (strategyPtr != nullptr) {
       if (prim->name() == RESHAPE) {
         MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
@@ -433,7 +445,13 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node
   entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueId to its operatorInfo
   std::map from_cnode_to_info;
-
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   // Step 1
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
@@ -451,7 +469,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node
 
       auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
       if (search_cnode == from_cnode_to_info.end()) {
-        auto operator_info = CreateTheOperatorInfo(prim, cnode);
+        auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
         if (operator_info == nullptr) {
           return FAILED;
         }
@@ -486,7 +504,13 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no
   entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueIdThroughCopy to its operatorInfo
   std::map from_cnode_to_info;
-
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast();
@@ -504,7 +528,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no
 
       auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
       if (search_cnode == from_cnode_to_info.end()) {
         // In this case, the corresponding OperatorInfo is not created, create the new one.
-        auto operator_info = CreateTheOperatorInfo(prim, cnode);
+        auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
         if (operator_info == nullptr) {
           return FAILED;
         }
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 17a622855258ed0da6b664c6760c13b9f4a5ba40..62fb96c2979a0680b2b88c52785cef3f334036aa 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -1378,6 +1378,13 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
 }
 
 void ExtractInformation(const std::vector &all_nodes) {
+  // load strategy map from checkpoint
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     auto cnode = node->cast();
     if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) {
       continue;
    }
@@ -1414,7 +1421,14 @@ void ExtractInformation(const std::vector &all_nodes) {
       (void)cnode->set_operator_info(operator_);
       continue;
     }
-    if (!StrategyFound(attrs)) {
+    // load strategy checkpoint
+    // key of strategy map
+    std::string instance_name = prim->instance_name();
+    std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+    bool load_strategy_from_ckpt =
+      StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
+
+    if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
       MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
                    << " is empty, using batch parallel";
       std::shared_ptr> strategy_v_ptr = operator_->GenerateBatchStrategies();
@@ -1432,6 +1446,8 @@ void ExtractInformation(const std::vector &all_nodes) {
       MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
                    << attrs[GEN_STRATEGY]->ToString();
       strategyPtr = NewStrategy(0, *strategy_v_ptr);
+    } else if (load_strategy_from_ckpt) {
+      strategyPtr = stra_map[strategy_key_name];
     } else {
       strategyPtr = ExtractStrategy(attrs);
     }
@@ -2022,53 +2038,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector
-  auto ret = func_graph->get_return();
-  auto all_nodes = DeepScopedGraphSearch(ret);
-  for (auto &node : all_nodes) {
-    MS_EXCEPTION_IF_NULL(node);
-    auto cnode = node->cast();
-    if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) {
-      continue;
-    }
-    PrimitivePtr prim = GetValueNode(cnode->input(0));
-    MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->operator_info();
-    if (operator_info) {
-      if (prim->instance_name().empty()) {
-        continue;
+bool NodeWithParameter(const CNodePtr &node) {
+  std::vector node_inputs{node->inputs()};
+  for (auto input : node_inputs) {
+    if (input->isa()) {
+      auto input_parameter = input->cast();
+      if (input_parameter->has_default()) {
+        return py::cast(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
       }
-      std::string instance_name = prim->instance_name();
-      StrategyPtr strategyPtr = operator_info->strategy();
-      MS_EXCEPTION_IF_NULL(node->scope());
-      std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      straMap[node_name] = strategyPtr;
     }
   }
-  if (StrategyCheckpoint::GetInstance().Save(straMap) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
-  }
+  return false;
 }
 
-void RestoreStrategy(const FuncGraphPtr &func_graph) {
+void CheckpointStrategy(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  MS_LOG(INFO) << "Extract strategy from checkpoint begin";
-  StrategyMap straMap;
-  if (StrategyCheckpoint::GetInstance().Load(&straMap) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-  }
-  if (StrategyCheckpoint::GetInstance().RemoveCheckPoint() != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Remove strategy checkpoint failed";
-  }
+  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
+  StrategyMap stra_map;
   auto ret = func_graph->get_return();
   auto all_nodes = DeepScopedGraphSearch(ret);
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     auto cnode = node->cast();
-    if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) {
+    if ((cnode == nullptr) || !IsValueNode(cnode->input(0)) || !NodeWithParameter(cnode)) {
       continue;
     }
     PrimitivePtr prim = GetValueNode(cnode->input(0));
@@ -2076,18 +2068,18 @@ void RestoreStrategy(const FuncGraphPtr &func_graph) {
     OperatorInfoPtr operator_info = cnode->operator_info();
     if (operator_info) {
       if (prim->instance_name().empty()) {
-        continue;
+        MS_LOG(EXCEPTION) << "A node with trainable parameters needs a primitive instance name to checkpoint its strategy";
       }
       std::string instance_name = prim->instance_name();
+      StrategyPtr strategyPtr = operator_info->strategy();
       MS_EXCEPTION_IF_NULL(node->scope());
       std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      MS_LOG(INFO) << "Node name is " << node_name;
-      if (straMap.find(node_name) != straMap.end()) {
-        StrategyPtr strategyPtr = straMap[node_name];
-        operator_info->set_strategy(strategyPtr);
-      }
+      stra_map[node_name] = strategyPtr;
     }
   }
+  if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
+  }
 }
 
 void SetForwardFlag(const std::vector &all_nodes) {
@@ -2264,14 +2256,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
     // extract shape and strategy, set operator_info
     ExtractInformation(all_nodes);
     ReshapeInit(all_nodes);
-    // extract strategy from checkpoint for multi-train
-    if (StrategyCheckpoint::GetInstance().CheckPointOn() && StrategyCheckpoint::GetInstance().CheckPointExit()) {
-      RestoreStrategy(root);
-    }
   }
 
   // save strategy as checkpoint for multi-train
-  if (StrategyCheckpoint::GetInstance().CheckPointOn() &&
-      StrategyCheckpoint::GetInstance().GetCurrentTrainTime() < StrategyCheckpoint::GetInstance().GetTrainTimes()) {
+  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
     CheckpointStrategy(root);
   }
diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h
index 745794912b9a6f688d450a74e48b8bcdf5a2b05f..c26f65ec656425107737eb39bdb5a709b680606a 100644
--- a/mindspore/ccsrc/parallel/step_parallel.h
+++ b/mindspore/ccsrc/parallel/step_parallel.h
@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector &all_nodes);
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes,
                            const FuncGraphManagerPtr &manager);
 
-void RestoreStrategy(const FuncGraphPtr &func_graph);
+bool NodeWithParameter(const CNodePtr &node);
 
 void CheckpointStrategy(const FuncGraphPtr &func_graph);
diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
index 981cf8a11519b14ae8e3e5a78de5ba51306c9b55..de10f4beb405c9717c6f503c33e4bdd670f637ca 100644
--- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
+++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
@@ -29,30 +29,32 @@ namespace mindspore {
 namespace parallel {
 StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
   static StrategyCheckpoint instance = StrategyCheckpoint();
+  if (ParallelContext::GetInstance() != nullptr) {
+    instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
+    instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
+    instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
+    instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
+  }
   return instance;
 }
 
-bool StrategyCheckpoint::CheckPointExit() const {
-  std::ifstream fin(path_);
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
+  std::ifstream fin(path);
   if (fin) {
     return true;
   }
   return false;
 }
 
-Status StrategyCheckpoint::RemoveCheckPoint() const {
-  if (std::remove(common::SafeCStr(path_)) == 0) {
-    return SUCCESS;
-  }
-  return FAILED;
-}
-
 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
   if (strategy_map == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
   }
+  if (!CheckPointExit(load_file_)) {
+    MS_LOG(EXCEPTION) << "Checkpoint file not found";
+  }
   straspb::ParallelStrategyMap parallel_strategy_map;
-  std::fstream input(path_, std::ios::in | std::ios::binary);
+  std::fstream input(load_file_, std::ios::in | std::ios::binary);
   if (!parallel_strategy_map.ParseFromIstream(&input)) {
     MS_LOG(ERROR) << "Load strategy file failed";
     return FAILED;
@@ -77,14 +79,14 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
     StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
     (*strategy_map)[node_name] = strategy;
-    current_train_time_ = (int32_t)parallel_strategy_map.train_time();
+    current_stage_ = (int32_t)parallel_strategy_map.current_stage();
   }
   return SUCCESS;
 }
 
 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
   straspb::ParallelStrategyMap parallel_strategy_map;
-  parallel_strategy_map.set_train_time(IntToUint(++current_train_time_));
+  parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
   for (auto &node_stra : strategy_map) {
     straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
     MS_EXCEPTION_IF_NULL(parallel_strategy_item);
@@ -100,7 +102,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
       }
     }
   }
-  std::fstream output(path_, std::ios::out | std::ios::trunc | std::ios::binary);
+  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
   if (!parallel_strategy_map.SerializeToOstream(&output)) {
     MS_LOG(ERROR) << "Save strategy file failed";
     return FAILED;
diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
index c871ea6eef167374a4937856d4b9c88a3bece339..0cf6229fa363978c86aff242937e459f1c411153 100644
--- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
+++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
@@ -21,43 +21,37 @@
 #include 
 #include "parallel/ops_info/ops_utils.h"
 #include "parallel/strategy.h"
+#include "parallel/context.h"
 
 namespace mindspore {
 namespace parallel {
-constexpr char DEFAULT_CHECKPOINT_PATH[] = "./strategys.ckpt";
 using StrategyMap = std::unordered_map;
 class StrategyCheckpoint {
  public:
-  StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) {
-    train_times_ = 1;
-    checkpoint_on_ = false;
-    const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES");
-    if (train_times_str != nullptr && std::stoi(train_times_str) > 0) {
-      train_times_ = std::stoi(train_times_str);
-    }
-    const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON");
-    if (checkpoint_on_str != nullptr) {
-      checkpoint_on_ = (std::string(checkpoint_on_str) == "on");
-    }
+  StrategyCheckpoint() {
+    current_stage_ = 0;
+    load_file_ = "";
+    load_checkpoint_on_ = false;
+    save_file_ = "";
+    save_checkpoint_on_ = false;
   }
   ~StrategyCheckpoint() = default;
 
-  bool CheckPointExit() const;
-  Status RemoveCheckPoint() const;
+
   Status Load(StrategyMap *strategy_map);
   Status Save(const StrategyMap &strategy_map);
 
   static StrategyCheckpoint &GetInstance();
-  int32_t GetTrainTimes() const { return train_times_; }
-  int32_t GetCurrentTrainTime() const { return current_train_time_; }
-  bool CheckPointOn() const { return checkpoint_on_; }
+  bool LoadCheckPointOn() const { return load_checkpoint_on_; }
+  bool SaveCheckPointOn() const { return save_checkpoint_on_; }
 
 private:
-  std::string path_;
-  bool checkpoint_on_;
-  // total train times for a train, get from Environmental variable:TRAIN_TIME, please export it
-  int32_t train_times_;
-  int32_t current_train_time_;
+  std::string load_file_;
+  std::string save_file_;
+  bool load_checkpoint_on_;
+  bool save_checkpoint_on_;
+  bool CheckPointExit(const std::string path) const;
+  int32_t current_stage_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc
index f5cacc7ed5cba087b46c41b15f82ca7b77671340..5c6727670b1a9459625711d68726c1379c8e9d44 100644
--- a/mindspore/ccsrc/pipeline/init.cc
+++ b/mindspore/ccsrc/pipeline/init.cc
@@ -191,6 +191,12 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
          "Get parameter broadcast is set.")
     .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
+    .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
+         "Set strategy checkpoint load file.")
+    .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
+         "Set strategy checkpoint save file.")
+    .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
+    .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_>(m, "CostModelContext")
diff --git a/mindspore/ccsrc/utils/node_strategy.proto b/mindspore/ccsrc/utils/node_strategy.proto
index dc06482ba1d5cbb9b282f2a00c89fa916d29ec1f..8ec25f21a6a6315648be4a693cbca42d8b9c7b99 100644
--- a/mindspore/ccsrc/utils/node_strategy.proto
+++ b/mindspore/ccsrc/utils/node_strategy.proto
@@ -33,6 +33,6 @@ message ParallelStrategyItem {
 }
 
 message ParallelStrategyMap {
-  required uint32 train_time = 1;
+  required uint32 current_stage = 1;
   repeated ParallelStrategyItem parallel_strategy_item = 2;
 }
\ No newline at end of file
diff --git a/mindspore/context.py b/mindspore/context.py
index 237b2143ed30908c262a202526966649c19f766a..7341db620aa92aa35ca20bac66d0a6594e428bd2 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -404,7 +404,7 @@ def _context():
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
-                 parameter_broadcast=bool)
+                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -436,6 +436,8 @@ def set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -447,6 +449,8 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(cast_before_mirror=False)
        >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
+        >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
+        >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
     """
     _set_auto_parallel_context(**kwargs)
@@ -477,6 +481,8 @@ def reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: "".
+    - strategy_ckpt_save_file: "".
     """
     _reset_auto_parallel_context()
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index d281b4f76ca49b525463d84d7dd5586dfa193daf..24c81003bd252a2d0fd0f4c69d43d71d2a90c9ea 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -88,6 +88,8 @@ class Primitive(Primitive_):
         for name in self.attrs:
             value = self.attrs[name]
             cloned.add_prim_attr(name, value)
+        if hasattr(self, 'instance_name'):
+            cloned.set_prim_instance_name(self.instance_name)
         return cloned
 
     def add_prim_attr(self, name, value):
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index 0608989d944603d35c4fc258b8664f493d0bf8c8..f3f8d443e9ee184c1d86b579cfafd411df62c5fa 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -208,6 +208,36 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_parameter_broadcast()
 
+    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
+        """
+        Set strategy checkpoint load path.
+
+        Args:
+            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
+
+    def get_strategy_ckpt_load_file(self):
+        """Get strategy checkpoint load path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_load_file()
+
+    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
+        """
+        Set strategy checkpoint save path.
+
+        Args:
+            strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
+
+    def get_strategy_ckpt_save_file(self):
+        """Get strategy checkpoint save path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_save_file()
+
     def get_parameter_broadcast_is_set(self):
         """Get parameter broadcast is set or not."""
         self.check_context_handle()
@@ -315,7 +345,9 @@ _set_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}
 
 
 _get_auto_parallel_context_func_map = {
@@ -325,11 +357,14 @@ _get_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
-                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool)
+                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
+                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -360,6 +395,8 @@ def _set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -400,5 +437,7 @@ def _reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: ""
+    - strategy_ckpt_save_file: ""
     """
     auto_parallel_context().reset()
diff --git a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
index 73de5071cd6af3eb5723ef151fa27e060faf86c9..43d0dd4b3fbe681d522267d0bb305ba50b65d967 100644
--- a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
+++ b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
@@ -25,9 +25,7 @@ StrategyCheckpoint& StrategyCheckpoint::GetInstance() {
   return instance;
 }
 
-bool StrategyCheckpoint::CheckPointExit() const { return false; }
-
-Status StrategyCheckpoint::RemoveCheckPoint() const { return SUCCESS; }
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return false; }
 
 Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
diff --git a/tests/ut/python/parallel/test_strategy_checkpoint.py b/tests/ut/python/parallel/test_strategy_checkpoint.py
index 09f4a54cbf92706ce9af4769d95019cada2b0bff..89b6dd1dbbb1f496725d8874f85726e6343434f7 100644
--- a/tests/ut/python/parallel/test_strategy_checkpoint.py
+++ b/tests/ut/python/parallel/test_strategy_checkpoint.py
@@ -14,10 +14,10 @@
 import numpy as np
 from mindspore import context
-from mindspore.context import set_auto_parallel_context
+from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from tests.ut.python.ops.test_math_ops import VirtualLoss
 import mindspore as ms
 from mindspore.common.api import _executor
@@ -25,17 +25,15 @@ from mindspore.ops import composite as C
 
 
 # model_parallel test
-# export PARALLEL_CHECKPOINT_ON=on
-# export PARALLEL_TRAIN_TIMES=4
-def test_six_matmul():
+def test_six_matmul_save():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            predict = self.network(x1, x6)
             return self.loss(predict)
 
 
@@ -44,8 +42,8 @@ def test_six_matmul():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)
 
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
@@ -56,45 +54,46 @@
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
-
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
+
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((1, 8), (8, 1))
     strategy3 = ((2, 2), (2, 2))
-    strategy4 = ((4, 2), (2, 4))
-    strategy5 = ((2, 4), (4, 2))
-    strategy6 = ((4, 4), (4, 4))
+    strategy4 = ((1, 1), (1, 8))
+    strategy5 = ((4, 2), (2, 1))
+    strategy6 = ((4, 1), (1, 2))
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
 
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x3 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x3, x4, x5, x6, x7)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)
 
 
-# remove matmul2
-def test_six_matmul_repeated1():
+# remove matmul2, add matmul7
+def test_six_matmul_load():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x4, x5, x6, x7)
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
             return self.loss(predict)
 
 
@@ -103,53 +102,58 @@
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7)
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)
 
     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6):
+        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
             super().__init__()
             self.matmul1 = P.MatMul().set_strategy(strategy1)
             self.matmul3 = P.MatMul().set_strategy(strategy3)
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
-
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
+            self.matmul7 = P.MatMul().set_strategy(strategy7)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
+
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy3 = ((8, 1), (1, 1))
     strategy4 = ((8, 1), (1, 1))
     strategy5 = ((8, 1), (1, 1))
     strategy6 = ((8, 1), (1, 1))
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6)))
+    strategy7 = ((8, 1), (1, 1))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
 
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
     x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7)
+    _executor.compile(net, x1, x6, x7)
 
 
-# add matmul7
-def test_six_matmul_repeated2():
+# model_parallel test
+def test_six_matmul_save_auto():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            predict = self.network(x1, x2, x4, x5, x6, x7, x8)
+        def construct(self, x1, x6):
+            predict = self.network(x1, x6)
             return self.loss(predict)
 
 
@@ -158,60 +162,52 @@
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)
 
     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
+        def __init__(self):
             super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul3 = P.MatMul().set_strategy(strategy3)
-            self.matmul4 = P.MatMul().set_strategy(strategy4)
-            self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-            self.matmul7 = P.MatMul().set_strategy(strategy7)
-
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
+            self.matmul1 = P.MatMul()
+            self.matmul2 = P.MatMul()
+            self.matmul3 = P.MatMul()
+            self.matmul4 = P.MatMul()
+            self.matmul5 = P.MatMul()
+            self.matmul6 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
+
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    strategy7 = ((8, 1), (1, 1))
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7, x8)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)
 
 
-# add scope2
-def test_six_matmul_repeated3():
+# remove matmul2, add matmul7
+def test_six_matmul_load_auto():
     class NetWithLoss(nn.Cell):
-        def __init__(self, network1, network2):
+        def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
-            self.network = network1
-            self.network2 = network2
+            self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
-            predict = self.network(x1, x2, x4, x5, x6, x7, x8)
-            predict = self.network2(predict, x9, x10)
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
             return self.loss(predict)
 
 
@@ -220,62 +216,42 @@
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8, x9, x10)
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)
 
     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
+        def __init__(self, strategy1, strategy3, strategy4, strategy5):
             super().__init__()
             self.matmul1 = P.MatMul().set_strategy(strategy1)
             self.matmul3 = P.MatMul().set_strategy(strategy3)
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-            self.matmul7 = P.MatMul().set_strategy(strategy7)
-
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
-            return out
-
-    class Net1(nn.Cell):
-        def __init__(self, strategy1, strategy2):
-            super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul2 = P.MatMul().set_strategy(strategy2)
-
-        def construct(self, x1, x2, x3):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
+            self.matmul6 = P.MatMul()
+            self.matmul7 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
+
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out
 
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
+    strategy1 = ((2, 2), (2, 2))
+    strategy3 = ((2, 2), (2, 2))
+    strategy4 = ((2, 2), (2, 2))
+    strategy5 = ((2, 2), (2, 2))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    strategy7 = ((8, 1), (1, 1))
-    strategy8 = ((8, 1), (1, 1))
-    strategy9 = ((8, 1), (1, 1))
-    net1 = Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)
-    net2 = Net1(strategy8, strategy9)
-    net = GradWrap(NetWithLoss(net1, net2))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
     x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    x9 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x10 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7, x8, x9, x10)
-
+    _executor.compile(net, x1, x6, x7)
\ No newline at end of file
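
Usage sketch (not part of the patch): a minimal two-stage flow based on the docstrings and tests above, assuming a build with this patch applied. The first run saves the parallel strategies chosen during compilation for operators whose inputs include trainable parameters; each entry is keyed by the node's scope name joined (via CONNSYMBOL) with the primitive instance name, which is typically set when the primitive is assigned as a Cell attribute. A later run loads that file and reuses the saved strategy for every operator whose key it finds. File names and device numbers mirror the tests and are illustrative only.

    from mindspore import context

    # Stage 1: search strategies and write them to a strategy checkpoint during compilation.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel",
                                      strategy_ckpt_save_file="./strategy_stage1.ckpt")
    # ... build the network and compile/train as usual ...

    # Stage 2: a later (possibly extended) run reloads the saved strategies instead of re-deriving them.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel",
                                      strategy_ckpt_load_file="./strategy_stage1.ckpt")
    # ... build the (modified) network and compile/train again ...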