Commit 21d936e6 authored by mindspore-ci-bot, committed by Gitee

!728 auto parallel strategy checkpoint full

Merge pull request !728 from yao_yf/strategy_checkpoint_extend
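
Summary: this change replaces the environment-variable driven strategy checkpointing (PARALLEL_CHECKPOINT_ON / PARALLEL_TRAIN_TIMES) with two explicit context settings, strategy_ckpt_save_file and strategy_ckpt_load_file. A minimal sketch of the intended multi-stage workflow, with illustrative paths, based on the updated docstrings and tests below:

    from mindspore import context

    # Stage 1: compile with a save path; strategies of operators that feed
    # trainable parameters are written to the file during compilation.
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      strategy_ckpt_save_file="./strategy_stage1.ckpt")

    # Stage 2: a later (possibly extended) network reuses the saved strategies;
    # operators whose scope + instance name match a saved key load from the file.
    context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
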
@@ -52,7 +52,11 @@ class Primitive : public Named {
       : Named(name), signatures_(), prim_type_(prim_type) {}
   Primitive(const Primitive &prim)
-      : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
+      : Named(prim),
+        attrs_(prim.attrs_),
+        signatures_(prim.signatures_),
+        instance_name_(prim.instance_name_),
+        prim_type_(prim.prim_type_) {}
   MS_DECLARE_PARENT(Primitive, Named);
......
@@ -56,6 +56,8 @@ void ParallelContext::Reset() {
   parameter_broadcast_ = false;
   parameter_broadcast_is_set_ = false;
   enable_all_reduce_fusion_ = false;
+  strategy_ckpt_load_file_ = "";
+  strategy_ckpt_save_file_ = "";
 }

 void ParallelContext::set_device_num(int32_t device_num) {
@@ -103,6 +105,14 @@ void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
   parameter_broadcast_is_set_ = true;
 }

+void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
+  strategy_ckpt_load_file_ = strategy_ckpt_load_file;
+}
+
+void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
+  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
+}
+
 void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) {
   all_reduce_fusion_split_indices_ = indices;
 }
......
@@ -85,6 +85,11 @@ class ParallelContext {
   }
   bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }

+  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
+  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
+  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
+  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
+
   void Reset();

  private:
@@ -105,6 +110,8 @@ class ParallelContext {
   bool enable_all_reduce_fusion_;
   std::vector<uint32_t> all_reduce_fusion_split_indices_;
   std::vector<uint32_t> all_reduce_fusion_split_sizes_;
+  std::string strategy_ckpt_load_file_;
+  std::string strategy_ckpt_save_file_;
 };
 }  // namespace parallel
 }  // namespace mindspore
......
@@ -40,6 +40,7 @@
 #include "parallel/context.h"
 #include "parallel/ops_info/tmp_identity_info.h"
 #include "parallel/step_parallel.h"
+#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "pipeline/parse/python_adapter.h"
 #include "pipeline/pipeline.h"
@@ -339,7 +340,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
   return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
 }

-OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
+OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
   MS_EXCEPTION_IF_NULL(prim);
   MS_EXCEPTION_IF_NULL(cnode);
   auto attrs = prim->attrs();
@@ -385,9 +386,15 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_input_value(input_value);
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
+  // key of strategy map
+  std::string instance_name = prim->instance_name();
+  std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+  bool load_strategy_from_ckpt =
+    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
-  if (!StrategyFound(attrs) || prim->name() == CAST) {
+  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
+  // If the strategy is set to be loaded from a checkpoint, the checkpointed strategy takes precedence.
+  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
     operator_info->ComputeBatchSplitFlagList();
@@ -397,7 +404,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   } else {
     // In this case, the configured strategy should be extracted to help setting cost
-    StrategyPtr strategyPtr = parallel::ExtractStrategy(attrs);
+    StrategyPtr strategyPtr;
+    if (load_strategy_from_ckpt) {
+      strategyPtr = (*stra_map)[strategy_key_name];
+    } else {
+      strategyPtr = parallel::ExtractStrategy(attrs);
+    }
     if (strategyPtr != nullptr) {
       if (prim->name() == RESHAPE) {
         MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
@@ -433,7 +445,13 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
   entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueId to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   // Step 1
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
@@ -451,7 +469,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
     auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
     if (search_cnode == from_cnode_to_info.end()) {
-      auto operator_info = CreateTheOperatorInfo(prim, cnode);
+      auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
       if (operator_info == nullptr) {
         return FAILED;
       }
@@ -486,7 +504,13 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueIdThroughCopy to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast<CNodePtr>();
@@ -504,7 +528,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
     auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
     if (search_cnode == from_cnode_to_info.end()) {
       // In this case, the corresponding OperatorInfo is not created, create the new one.
-      auto operator_info = CreateTheOperatorInfo(prim, cnode);
+      auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
       if (operator_info == nullptr) {
         return FAILED;
       }
......
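
For reference, the checkpoint lookup added to CreateTheOperatorInfo above (and mirrored in ExtractInformation below) reduces to a keyed map probe. A Python sketch of that decision; CONNSYMBOL is a C++ separator constant whose literal value is not shown in this diff, so the "-" here is an assumption:

    def load_strategy_from_ckpt(scope_name, instance_name, stra_map, load_ckpt_on, connsymbol="-"):
        """Mirror of the C++ decision: only load when the checkpoint is enabled
        and the node's key (scope + separator + instance name) was saved."""
        strategy_key_name = scope_name + connsymbol + instance_name
        return load_ckpt_on and strategy_key_name in stra_map
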
@@ -1378,6 +1378,13 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
 }

 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
+  // load strategy map from checkpoint
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
@@ -1414,7 +1421,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
       (void)cnode->set_operator_info(operator_);
       continue;
     }
-    if (!StrategyFound(attrs)) {
+    // load strategy checkpoint
+    // key of strategy map
+    std::string instance_name = prim->instance_name();
+    std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+    bool load_strategy_from_ckpt =
+      StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
+    if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
       MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
                    << " is empty, using batch parallel";
       std::shared_ptr<std::vector<Dimensions>> strategy_v_ptr = operator_->GenerateBatchStrategies();
@@ -1432,6 +1446,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
       MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
                    << attrs[GEN_STRATEGY]->ToString();
       strategyPtr = NewStrategy(0, *strategy_v_ptr);
+    } else if (load_strategy_from_ckpt) {
+      strategyPtr = stra_map[strategy_key_name];
     } else {
       strategyPtr = ExtractStrategy(attrs);
     }
@@ -2022,53 +2038,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
   }
 }

-void CheckpointStrategy(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_LOG(INFO) << "Save strategy to checkpoint begin";
-  StrategyMap straMap;
-  auto ret = func_graph->get_return();
-  auto all_nodes = DeepScopedGraphSearch(ret);
-  for (auto &node : all_nodes) {
-    MS_EXCEPTION_IF_NULL(node);
-    auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
-      continue;
-    }
-    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-    MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->operator_info();
-    if (operator_info) {
-      if (prim->instance_name().empty()) {
-        continue;
+bool NodeWithParameter(const CNodePtr &node) {
+  std::vector<AnfNodePtr> node_inputs{node->inputs()};
+  for (auto input : node_inputs) {
+    if (input->isa<Parameter>()) {
+      auto input_parameter = input->cast<ParameterPtr>();
+      if (input_parameter->has_default()) {
+        return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
       }
-      std::string instance_name = prim->instance_name();
-      StrategyPtr strategyPtr = operator_info->strategy();
-      MS_EXCEPTION_IF_NULL(node->scope());
-      std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      straMap[node_name] = strategyPtr;
     }
   }
-  if (StrategyCheckpoint::GetInstance().Save(straMap) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
-  }
+  return false;
 }

-void RestoreStrategy(const FuncGraphPtr &func_graph) {
+void CheckpointStrategy(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  MS_LOG(INFO) << "Extract strategy from checkpoint begin";
-  StrategyMap straMap;
-  if (StrategyCheckpoint::GetInstance().Load(&straMap) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-  }
-  if (StrategyCheckpoint::GetInstance().RemoveCheckPoint() != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Remove strategy checkpoint failed";
-  }
+  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
+  StrategyMap stra_map;
   auto ret = func_graph->get_return();
   auto all_nodes = DeepScopedGraphSearch(ret);
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) {
       continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@@ -2076,18 +2068,18 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
     OperatorInfoPtr operator_info = cnode->operator_info();
     if (operator_info) {
       if (prim->instance_name().empty()) {
-        continue;
+        MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
       }
       std::string instance_name = prim->instance_name();
       StrategyPtr strategyPtr = operator_info->strategy();
       MS_EXCEPTION_IF_NULL(node->scope());
       std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      MS_LOG(INFO) << "Node name is " << node_name;
-      if (straMap.find(node_name) != straMap.end()) {
-        StrategyPtr strategyPtr = straMap[node_name];
-        operator_info->set_strategy(strategyPtr);
-      }
+      stra_map[node_name] = strategyPtr;
     }
   }
+  if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
+  }
 }

 void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
@@ -2264,14 +2256,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
     // extract shape and strategy, set operator_info
     ExtractInformation(all_nodes);
     ReshapeInit(all_nodes);
-    // extract strategy from checkpoint for multi-train
-    if (StrategyCheckpoint::GetInstance().CheckPointOn() && StrategyCheckpoint::GetInstance().CheckPointExit()) {
-      RestoreStrategy(root);
-    }
   }

   // save strategy as checkpoint for multi-train
-  if (StrategyCheckpoint::GetInstance().CheckPointOn() &&
-      StrategyCheckpoint::GetInstance().GetCurrentTrainTime() < StrategyCheckpoint::GetInstance().GetTrainTimes()) {
+  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
     CheckpointStrategy(root);
   }
......
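
The former RestoreStrategy pass is gone: loading now happens inline in ExtractInformation / CreateTheOperatorInfo, and saving is gated only by SaveCheckPointOn(). The new NodeWithParameter filter decides which operators get checkpointed; roughly, in Python terms (the attribute names here are illustrative stand-ins for the anf node API, not the real one):

    def node_with_parameter(node):
        # A node's strategy is saved only if one of its inputs is a Parameter
        # with a default value; the result is that parameter's requires_grad flag.
        for inp in node.inputs:
            if isinstance(inp, Parameter) and inp.has_default:
                return inp.default_param.requires_grad
        return False
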
@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager);

-void RestoreStrategy(const FuncGraphPtr &func_graph);
+bool NodeWithParameter(const CNodePtr &node);

 void CheckpointStrategy(const FuncGraphPtr &func_graph);
......
@@ -29,30 +29,32 @@ namespace mindspore {
 namespace parallel {
 StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
   static StrategyCheckpoint instance = StrategyCheckpoint();
+  if (ParallelContext::GetInstance() != nullptr) {
+    instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
+    instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
+    instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
+    instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
+  }
   return instance;
 }

-bool StrategyCheckpoint::CheckPointExit() const {
-  std::ifstream fin(path_);
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
+  std::ifstream fin(path);
   if (fin) {
     return true;
   }
   return false;
 }

-Status StrategyCheckpoint::RemoveCheckPoint() const {
-  if (std::remove(common::SafeCStr(path_)) == 0) {
-    return SUCCESS;
-  }
-  return FAILED;
-}
-
 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
   if (strategy_map == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
   }
+  if (!CheckPointExit(load_file_)) {
+    MS_LOG(EXCEPTION) << "CheckPoint file is not found";
+  }
   straspb::ParallelStrategyMap parallel_strategy_map;
-  std::fstream input(path_, std::ios::in | std::ios::binary);
+  std::fstream input(load_file_, std::ios::in | std::ios::binary);
   if (!parallel_strategy_map.ParseFromIstream(&input)) {
     MS_LOG(ERROR) << "Load strategy file failed";
     return FAILED;
@@ -77,14 +79,14 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
     StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
     (*strategy_map)[node_name] = strategy;
-    current_train_time_ = (int32_t)parallel_strategy_map.train_time();
+    current_stage_ = (int32_t)parallel_strategy_map.current_stage();
   }
   return SUCCESS;
 }

 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
   straspb::ParallelStrategyMap parallel_strategy_map;
-  parallel_strategy_map.set_train_time(IntToUint(++current_train_time_));
+  parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
   for (auto &node_stra : strategy_map) {
     straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
     MS_EXCEPTION_IF_NULL(parallel_strategy_item);
@@ -100,7 +102,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
       }
     }
   }
-  std::fstream output(path_, std::ios::out | std::ios::trunc | std::ios::binary);
+  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
   if (!parallel_strategy_map.SerializeToOstream(&output)) {
     MS_LOG(ERROR) << "Save strategy file failed";
     return FAILED;
......
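
Two behaviors worth noting in the rewritten singleton: GetInstance() re-reads the load/save paths from ParallelContext on every call, deriving the on/off switches from path emptiness, and Save() bumps the stage counter that Load() restored. A compressed sketch of that bookkeeping (field names mirror the C++; the dict stands in for the protobuf message):

    class StrategyCkptSketch:
        def __init__(self):
            self.current_stage = 0

        def load(self, msg):
            # current_stage_ = (int32_t)parallel_strategy_map.current_stage();
            self.current_stage = msg["current_stage"]

        def save(self, msg):
            # parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
            self.current_stage += 1
            msg["current_stage"] = self.current_stage

So a save -> load -> save pipeline across training stages writes stage 1, then stage 2, and so on.
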
@@ -21,43 +21,37 @@
 #include <unordered_map>
 #include "parallel/ops_info/ops_utils.h"
 #include "parallel/strategy.h"
+#include "parallel/context.h"

 namespace mindspore {
 namespace parallel {
-constexpr char DEFAULT_CHECKPOINT_PATH[] = "./strategys.ckpt";
-
 using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
 class StrategyCheckpoint {
  public:
-  StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) {
-    train_times_ = 1;
-    checkpoint_on_ = false;
-    const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES");
-    if (train_times_str != nullptr && std::stoi(train_times_str) > 0) {
-      train_times_ = std::stoi(train_times_str);
-    }
-    const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON");
-    if (checkpoint_on_str != nullptr) {
-      checkpoint_on_ = (std::string(checkpoint_on_str) == "on");
-    }
+  StrategyCheckpoint() {
+    current_stage_ = 0;
+    load_file_ = "";
+    load_checkpoint_on_ = false;
+    save_file_ = "";
+    save_checkpoint_on_ = false;
   }
   ~StrategyCheckpoint() = default;
-  bool CheckPointExit() const;
-  Status RemoveCheckPoint() const;
   Status Load(StrategyMap *strategy_map);
   Status Save(const StrategyMap &strategy_map);
   static StrategyCheckpoint &GetInstance();
-  int32_t GetTrainTimes() const { return train_times_; }
-  int32_t GetCurrentTrainTime() const { return current_train_time_; }
-  bool CheckPointOn() const { return checkpoint_on_; }
+  bool LoadCheckPointOn() const { return load_checkpoint_on_; }
+  bool SaveCheckPointOn() const { return save_checkpoint_on_; }

  private:
-  std::string path_;
-  bool checkpoint_on_;
-  // total train times for a train, get from Environmental variable:TRAIN_TIME, please export it
-  int32_t train_times_;
-  int32_t current_train_time_;
+  std::string load_file_;
+  std::string save_file_;
+  bool load_checkpoint_on_;
+  bool save_checkpoint_on_;
+  bool CheckPointExit(const std::string path) const;
+  int32_t current_stage_;
 };
 }  // namespace parallel
 }  // namespace mindspore
......
@@ -189,6 +189,12 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
          "Get parameter broadcast is set.")
     .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
+    .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
+         "Set strategy checkpoint load file.")
+    .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
+         "Set strategy checkpoint save file.")
+    .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
+    .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
......
@@ -33,6 +33,6 @@ message ParallelStrategyItem {
 }

 message ParallelStrategyMap {
-  required uint32 train_time = 1;
+  required uint32 current_stage = 1;
   repeated ParallelStrategyItem parallel_strategy_item = 2;
 }
\ No newline at end of file
@@ -396,7 +396,7 @@ def _context():

 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
-                 parameter_broadcast=bool)
+                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -428,6 +428,8 @@ def set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -439,6 +441,8 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(cast_before_mirror=False)
        >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
        >>> context.set_auto_parallel_context(parameter_broadcast=False)
+       >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
+       >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
     """
     _set_auto_parallel_context(**kwargs)
@@ -469,6 +473,8 @@ def reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: "".
+    - strategy_ckpt_save_file: "".
     """
     _reset_auto_parallel_context()
......
@@ -88,6 +88,8 @@ class Primitive(Primitive_):
         for name in self.attrs:
             value = self.attrs[name]
             cloned.add_prim_attr(name, value)
+        if hasattr(self, 'instance_name'):
+            cloned.set_prim_instance_name(self.instance_name)
         return cloned

     def add_prim_attr(self, name, value):
......
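
Copying instance_name in clone() matters because the checkpoint key is the scope name joined with the instance name; a clone that dropped the name could no longer match its saved strategy. Illustrative only:

    matmul = P.MatMul()
    matmul.set_prim_instance_name("matmul1")  # normally assigned by the cell framework
    cloned = matmul.clone()
    assert cloned.instance_name == "matmul1"  # keeps the strategy key stable across clones
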
@@ -208,6 +208,36 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_parameter_broadcast()

+    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
+        """
+        Set strategy checkpoint load path.
+
+        Args:
+            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
+
+    def get_strategy_ckpt_load_file(self):
+        """Get strategy checkpoint load path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_load_file()
+
+    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
+        """
+        Set strategy checkpoint save path.
+
+        Args:
+            strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
+
+    def get_strategy_ckpt_save_file(self):
+        """Get strategy checkpoint save path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_save_file()
+
     def get_parameter_broadcast_is_set(self):
         """Get parameter broadcast is set or not."""
         self.check_context_handle()
@@ -315,7 +345,9 @@ _set_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}


 _get_auto_parallel_context_func_map = {
@@ -325,11 +357,14 @@ _get_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}


 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
-                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool)
+                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
+                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -360,6 +395,8 @@ def _set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -400,5 +437,7 @@ def _reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: ""
+    - strategy_ckpt_save_file: ""
     """
     auto_parallel_context().reset()
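
With the new entries in both func maps, the public context API routes straight through to the C++ setters registered above. A quick round-trip check, assuming the existing get_auto_parallel_context(key) accessor dispatches through _get_auto_parallel_context_func_map:

    from mindspore import context
    context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
    assert context.get_auto_parallel_context("strategy_ckpt_save_file") == "./strategy_stage1.ckpt"
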
@@ -25,9 +25,7 @@ StrategyCheckpoint& StrategyCheckpoint::GetInstance() {
   return instance;
 }

-bool StrategyCheckpoint::CheckPointExit() const { return false; }
-
-Status StrategyCheckpoint::RemoveCheckPoint() const { return SUCCESS; }
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return false; }

 Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
......
@@ -14,10 +14,10 @@
 import numpy as np
 from mindspore import context
-from mindspore.context import set_auto_parallel_context
+from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from tests.ut.python.ops.test_math_ops import VirtualLoss
 import mindspore as ms
 from mindspore.common.api import _executor
@@ -25,17 +25,15 @@ from mindspore.ops import composite as C

 # model_parallel test
-# export PARALLEL_CHECKPOINT_ON=on
-# export PARALLEL_TRAIN_TIMES=4
-def test_six_matmul():
+def test_six_matmul_save():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network

-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            predict = self.network(x1, x6)
             return self.loss(predict)

@@ -44,8 +42,8 @@ def test_six_matmul():
             super(GradWrap, self).__init__()
             self.network = network

-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)

     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
@@ -56,45 +54,46 @@ def test_six_matmul():
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")

-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
             return out

-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((1, 8), (8, 1))
     strategy3 = ((2, 2), (2, 2))
-    strategy4 = ((4, 2), (2, 4))
-    strategy5 = ((2, 4), (4, 2))
-    strategy6 = ((4, 4), (4, 4))
+    strategy4 = ((1, 1), (1, 8))
+    strategy5 = ((4, 2), (2, 1))
+    strategy6 = ((4, 1), (1, 2))
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x3 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x3, x4, x5, x6, x7)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)


-# remove matmul2
-def test_six_matmul_repeated1():
+# remove matmul2, add matmul7
+def test_six_matmul_load():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network

-        def construct(self, x1, x2, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x4, x5, x6, x7)
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
             return self.loss(predict)
@@ -103,53 +102,58 @@ def test_six_matmul_repeated1():
             super(GradWrap, self).__init__()
             self.network = network

-        def construct(self, x1, x2, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7)
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)

     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6):
+        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
             super().__init__()
             self.matmul1 = P.MatMul().set_strategy(strategy1)
             self.matmul3 = P.MatMul().set_strategy(strategy3)
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
+            self.matmul7 = P.MatMul().set_strategy(strategy7)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")

-        def construct(self, x1, x2, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out

-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy3 = ((8, 1), (1, 1))
     strategy4 = ((8, 1), (1, 1))
     strategy5 = ((8, 1), (1, 1))
     strategy6 = ((8, 1), (1, 1))
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6)))
+    strategy7 = ((8, 1), (1, 1))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
     x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7)
+    _executor.compile(net, x1, x6, x7)


-# add matmul7
-def test_six_matmul_repeated2():
+# model_parallel test
+def test_six_matmul_save_auto():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network

-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            predict = self.network(x1, x2, x4, x5, x6, x7, x8)
+        def construct(self, x1, x6):
+            predict = self.network(x1, x6)
             return self.loss(predict)
@@ -158,60 +162,52 @@ def test_six_matmul_repeated2():
             super(GradWrap, self).__init__()
             self.network = network

-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)

     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
+        def __init__(self):
             super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul3 = P.MatMul().set_strategy(strategy3)
-            self.matmul4 = P.MatMul().set_strategy(strategy4)
-            self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-            self.matmul7 = P.MatMul().set_strategy(strategy7)
+            self.matmul1 = P.MatMul()
+            self.matmul2 = P.MatMul()
+            self.matmul3 = P.MatMul()
+            self.matmul4 = P.MatMul()
+            self.matmul5 = P.MatMul()
+            self.matmul6 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")

-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
             return out

-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    strategy7 = ((8, 1), (1, 1))
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7, x8)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")

+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)

-# add scope2
-def test_six_matmul_repeated3():
+
+# remove matmul2, add matmul7
+def test_six_matmul_load_auto():
     class NetWithLoss(nn.Cell):
-        def __init__(self, network1, network2):
+        def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
-            self.network = network1
-            self.network2 = network2
+            self.network = network

-        def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
-            predict = self.network(x1, x2, x4, x5, x6, x7, x8)
-            predict = self.network2(predict, x9, x10)
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
             return self.loss(predict)
@@ -220,62 +216,42 @@ def test_six_matmul_repeated3():
             super(GradWrap, self).__init__()
             self.network = network

-        def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8, x9, x10)
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)

     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
+        def __init__(self, strategy1, strategy3, strategy4, strategy5):
             super().__init__()
             self.matmul1 = P.MatMul().set_strategy(strategy1)
             self.matmul3 = P.MatMul().set_strategy(strategy3)
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-            self.matmul7 = P.MatMul().set_strategy(strategy7)
+            self.matmul6 = P.MatMul()
+            self.matmul7 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")

-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out

-    class Net1(nn.Cell):
-        def __init__(self, strategy1, strategy2):
-            super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul2 = P.MatMul().set_strategy(strategy2)
-
-        def construct(self, x1, x2, x3):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
-            return out
-
-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    strategy7 = ((8, 1), (1, 1))
-    strategy8 = ((8, 1), (1, 1))
-    strategy9 = ((8, 1), (1, 1))
-    net1 = Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)
-    net2 = Net1(strategy8, strategy9)
-    net = GradWrap(NetWithLoss(net1, net2))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
+    strategy1 = ((2, 2), (2, 2))
+    strategy3 = ((2, 2), (2, 2))
+    strategy4 = ((2, 2), (2, 2))
+    strategy5 = ((2, 2), (2, 2))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")

-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
     x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    x9 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x10 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7, x8, x9, x10)
+    _executor.compile(net, x1, x6, x7)
\ No newline at end of file
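
Note on test ordering: the load tests read the files the save tests write ("./strategy_stage1.ckpt" and "./strategy_stage1_auto.ckpt"), and Load() now raises "CheckPoint file is not found" when the path is missing, so test_six_matmul_load and test_six_matmul_load_auto must run after their save counterparts within the same test session; the definition order in this file preserves that.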