提交 6cde5f6d 编写于 作者: Y yao_yf

auto parallel strategy checkpoint

上级 420ef2a3
......@@ -52,7 +52,11 @@ class Primitive : public Named {
: Named(name), signatures_(), prim_type_(prim_type) {}
Primitive(const Primitive &prim)
: Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
: Named(prim),
attrs_(prim.attrs_),
signatures_(prim.signatures_),
instance_name_(prim.instance_name_),
prim_type_(prim.prim_type_) {}
MS_DECLARE_PARENT(Primitive, Named);
......
......@@ -56,6 +56,8 @@ void ParallelContext::Reset() {
parameter_broadcast_ = false;
parameter_broadcast_is_set_ = false;
enable_all_reduce_fusion_ = false;
strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = "";
}
void ParallelContext::set_device_num(int32_t device_num) {
......@@ -103,6 +105,14 @@ void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
parameter_broadcast_is_set_ = true;
}
void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
strategy_ckpt_load_file_ = strategy_ckpt_load_file;
}
void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
strategy_ckpt_save_file_ = strategy_ckpt_save_file;
}
void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) {
all_reduce_fusion_split_indices_ = indices;
}
......
......@@ -85,6 +85,11 @@ class ParallelContext {
}
bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
void Reset();
private:
......@@ -105,6 +110,8 @@ class ParallelContext {
bool enable_all_reduce_fusion_;
std::vector<uint32_t> all_reduce_fusion_split_indices_;
std::vector<uint32_t> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_;
};
} // namespace parallel
} // namespace mindspore
......
......@@ -40,6 +40,7 @@
#include "parallel/context.h"
#include "parallel/ops_info/tmp_identity_info.h"
#include "parallel/step_parallel.h"
#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "pipeline/parse/python_adapter.h"
#include "pipeline/pipeline.h"
......@@ -339,7 +340,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
}
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(cnode);
auto attrs = prim->attrs();
......@@ -385,9 +386,15 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_input_value(input_value);
operator_info->set_outputs_dtype(cnode->Type());
operator_info->set_cnode(cnode);
// key of strategy map
std::string instance_name = prim->instance_name();
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
// If no strategy has been configured for this operator, then candidate strategies are generated for
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
if (!StrategyFound(attrs) || prim->name() == CAST) {
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
// if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
// BatchParallelInfo operator
operator_info->ComputeBatchSplitFlagList();
......@@ -397,7 +404,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
}
} else {
// In this case, the configured strategy should be extracted to help setting cost
StrategyPtr strategyPtr = parallel::ExtractStrategy(attrs);
StrategyPtr strategyPtr;
if (load_strategy_from_ckpt) {
strategyPtr = (*stra_map)[strategy_key_name];
} else {
strategyPtr = parallel::ExtractStrategy(attrs);
}
if (strategyPtr != nullptr) {
if (prim->name() == RESHAPE) {
MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
......@@ -433,7 +445,13 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueId to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// extract strategy from checkpoint for multi-train
StrategyMap stra_map;
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
// Step 1
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
......@@ -451,7 +469,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
if (search_cnode == from_cnode_to_info.end()) {
auto operator_info = CreateTheOperatorInfo(prim, cnode);
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
if (operator_info == nullptr) {
return FAILED;
}
......@@ -486,7 +504,13 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// extract strategy from checkpoint for multi-train
StrategyMap stra_map;
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
......@@ -504,7 +528,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
if (search_cnode == from_cnode_to_info.end()) {
// In this case, the corresponding OperatorInfo is not created, create the new one.
auto operator_info = CreateTheOperatorInfo(prim, cnode);
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
if (operator_info == nullptr) {
return FAILED;
}
......
......@@ -1378,6 +1378,13 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
}
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
// load strategy map from checkpoint
StrategyMap stra_map;
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
......@@ -1414,7 +1421,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
(void)cnode->set_operator_info(operator_);
continue;
}
if (!StrategyFound(attrs)) {
// load strategy checkpoint
// key of strategy map
std::string instance_name = prim->instance_name();
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
std::shared_ptr<std::vector<Dimensions>> strategy_v_ptr = operator_->GenerateBatchStrategies();
......@@ -1432,6 +1446,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
<< attrs[GEN_STRATEGY]->ToString();
strategyPtr = NewStrategy(0, *strategy_v_ptr);
} else if (load_strategy_from_ckpt) {
strategyPtr = stra_map[strategy_key_name];
} else {
strategyPtr = ExtractStrategy(attrs);
}
......@@ -2022,53 +2038,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
}
}
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "Save strategy to checkpoint begin";
StrategyMap straMap;
auto ret = func_graph->get_return();
auto all_nodes = DeepScopedGraphSearch(ret);
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) {
if (prim->instance_name().empty()) {
continue;
bool NodeWithParameter(const CNodePtr &node) {
std::vector<AnfNodePtr> node_inputs{node->inputs()};
for (auto input : node_inputs) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
}
std::string instance_name = prim->instance_name();
StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope());
std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
straMap[node_name] = strategyPtr;
}
}
if (StrategyCheckpoint::GetInstance().Save(straMap) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
}
return false;
}
void RestoreStrategy(const FuncGraphPtr &func_graph) {
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "Extract strategy from checkpoint begin";
StrategyMap straMap;
if (StrategyCheckpoint::GetInstance().Load(&straMap) != SUCCESS) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
if (StrategyCheckpoint::GetInstance().RemoveCheckPoint() != SUCCESS) {
MS_LOG(EXCEPTION) << "Remove strategy checkpoint failed";
}
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
StrategyMap stra_map;
auto ret = func_graph->get_return();
auto all_nodes = DeepScopedGraphSearch(ret);
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
......@@ -2076,18 +2068,18 @@ void RestoreStrategy(const FuncGraphPtr &func_graph) {
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) {
if (prim->instance_name().empty()) {
continue;
MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
}
std::string instance_name = prim->instance_name();
StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope());
std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
MS_LOG(INFO) << "Node name is " << node_name;
if (straMap.find(node_name) != straMap.end()) {
StrategyPtr strategyPtr = straMap[node_name];
operator_info->set_strategy(strategyPtr);
}
stra_map[node_name] = strategyPtr;
}
}
if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
}
}
void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
......@@ -2264,14 +2256,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// extract shape and strategy, set operator_info
ExtractInformation(all_nodes);
ReshapeInit(all_nodes);
// extract strategy from checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().CheckPointOn() && StrategyCheckpoint::GetInstance().CheckPointExit()) {
RestoreStrategy(root);
}
}
// save strategy as checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().CheckPointOn() &&
StrategyCheckpoint::GetInstance().GetCurrentTrainTime() < StrategyCheckpoint::GetInstance().GetTrainTimes()) {
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
CheckpointStrategy(root);
}
......
......@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);
void RestoreStrategy(const FuncGraphPtr &func_graph);
bool NodeWithParameter(const CNodePtr &node);
void CheckpointStrategy(const FuncGraphPtr &func_graph);
......
......@@ -29,30 +29,32 @@ namespace mindspore {
namespace parallel {
StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
static StrategyCheckpoint instance = StrategyCheckpoint();
if (ParallelContext::GetInstance() != nullptr) {
instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
}
return instance;
}
bool StrategyCheckpoint::CheckPointExit() const {
std::ifstream fin(path_);
bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
std::ifstream fin(path);
if (fin) {
return true;
}
return false;
}
Status StrategyCheckpoint::RemoveCheckPoint() const {
if (std::remove(common::SafeCStr(path_)) == 0) {
return SUCCESS;
}
return FAILED;
}
Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
if (strategy_map == nullptr) {
MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
}
if (!CheckPointExit(load_file_)) {
MS_LOG(EXCEPTION) << "CheckPoint file is not found";
}
straspb::ParallelStrategyMap parallel_strategy_map;
std::fstream input(path_, std::ios::in | std::ios::binary);
std::fstream input(load_file_, std::ios::in | std::ios::binary);
if (!parallel_strategy_map.ParseFromIstream(&input)) {
MS_LOG(ERROR) << "Load strategy file failed";
return FAILED;
......@@ -77,14 +79,14 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
(*strategy_map)[node_name] = strategy;
current_train_time_ = (int32_t)parallel_strategy_map.train_time();
current_stage_ = (int32_t)parallel_strategy_map.current_stage();
}
return SUCCESS;
}
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
straspb::ParallelStrategyMap parallel_strategy_map;
parallel_strategy_map.set_train_time(IntToUint(++current_train_time_));
parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
for (auto &node_stra : strategy_map) {
straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
MS_EXCEPTION_IF_NULL(parallel_strategy_item);
......@@ -100,7 +102,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
}
}
}
std::fstream output(path_, std::ios::out | std::ios::trunc | std::ios::binary);
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
if (!parallel_strategy_map.SerializeToOstream(&output)) {
MS_LOG(ERROR) << "Save strategy file failed";
return FAILED;
......
......@@ -21,43 +21,37 @@
#include <unordered_map>
#include "parallel/ops_info/ops_utils.h"
#include "parallel/strategy.h"
#include "parallel/context.h"
namespace mindspore {
namespace parallel {
constexpr char DEFAULT_CHECKPOINT_PATH[] = "./strategys.ckpt";
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
class StrategyCheckpoint {
public:
StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) {
train_times_ = 1;
checkpoint_on_ = false;
const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES");
if (train_times_str != nullptr && std::stoi(train_times_str) > 0) {
train_times_ = std::stoi(train_times_str);
}
const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON");
if (checkpoint_on_str != nullptr) {
checkpoint_on_ = (std::string(checkpoint_on_str) == "on");
}
StrategyCheckpoint() {
current_stage_ = 0;
load_file_ = "";
load_checkpoint_on_ = false;
save_file_ = "";
save_checkpoint_on_ = false;
}
~StrategyCheckpoint() = default;
bool CheckPointExit() const;
Status RemoveCheckPoint() const;
Status Load(StrategyMap *strategy_map);
Status Save(const StrategyMap &strategy_map);
static StrategyCheckpoint &GetInstance();
int32_t GetTrainTimes() const { return train_times_; }
int32_t GetCurrentTrainTime() const { return current_train_time_; }
bool CheckPointOn() const { return checkpoint_on_; }
bool LoadCheckPointOn() const { return load_checkpoint_on_; }
bool SaveCheckPointOn() const { return save_checkpoint_on_; }
private:
std::string path_;
bool checkpoint_on_;
// total train times for a train, get from Environmental variable:TRAIN_TIME, please export it
int32_t train_times_;
int32_t current_train_time_;
std::string load_file_;
std::string save_file_;
bool load_checkpoint_on_;
bool save_checkpoint_on_;
bool CheckPointExit(const std::string path) const;
int32_t current_stage_;
};
} // namespace parallel
} // namespace mindspore
......
......@@ -191,6 +191,12 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
"Get parameter broadcast is set.")
.def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
.def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
"Set strategy checkpoint load file.")
.def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
"Set strategy checkpoint save file.")
.def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
......
......@@ -33,6 +33,6 @@ message ParallelStrategyItem {
}
message ParallelStrategyMap {
required uint32 train_time = 1;
required uint32 current_stage = 1;
repeated ParallelStrategyItem parallel_strategy_item = 2;
}
\ No newline at end of file
......@@ -404,7 +404,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
parameter_broadcast=bool)
parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
def set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
......@@ -436,6 +436,8 @@ def set_auto_parallel_context(**kwargs):
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
Raises:
ValueError: If input key is not attribute in auto parallel context.
......@@ -447,6 +449,8 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(cast_before_mirror=False)
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
"""
_set_auto_parallel_context(**kwargs)
......@@ -477,6 +481,8 @@ def reset_auto_parallel_context():
- cast_before_mirror: True.
- parallel_mode: "stand_alone".
- parameter_broadcast: False.
- strategy_ckpt_load_file: "".
- strategy_ckpt_save_file: "".
"""
_reset_auto_parallel_context()
......
......@@ -88,6 +88,8 @@ class Primitive(Primitive_):
for name in self.attrs:
value = self.attrs[name]
cloned.add_prim_attr(name, value)
if hasattr(self, 'instance_name'):
cloned.set_prim_instance_name(self.instance_name)
return cloned
def add_prim_attr(self, name, value):
......
......@@ -208,6 +208,36 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_parameter_broadcast()
def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
"""
Set strategy checkpoint load path.
Args:
strategy_ckpt_load_file (bool): Path to load parallel strategy checkpoint.
"""
self.check_context_handle()
self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
def get_strategy_ckpt_load_file(self):
"""Get strategy checkpoint load path."""
self.check_context_handle()
return self._context_handle.get_strategy_ckpt_load_file()
def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
"""
Set strategy checkpoint save path.
Args:
strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
"""
self.check_context_handle()
self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
def get_strategy_ckpt_save_file(self):
"""Get strategy checkpoint save path."""
self.check_context_handle()
return self._context_handle.get_strategy_ckpt_save_file()
def get_parameter_broadcast_is_set(self):
"""Get parameter broadcast is set or not."""
self.check_context_handle()
......@@ -315,7 +345,9 @@ _set_auto_parallel_context_func_map = {
"cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
"parallel_mode": auto_parallel_context().set_parallel_mode,
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast}
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}
_get_auto_parallel_context_func_map = {
......@@ -325,11 +357,14 @@ _get_auto_parallel_context_func_map = {
"cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
"parallel_mode": auto_parallel_context().get_parallel_mode,
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast}
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool)
loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
def _set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
......@@ -360,6 +395,8 @@ def _set_auto_parallel_context(**kwargs):
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
Raises:
ValueError: If input key is not attribute in auto parallel context.
......@@ -400,5 +437,7 @@ def _reset_auto_parallel_context():
- cast_before_mirror: True.
- parallel_mode: "stand_alone".
- parameter_broadcast: False.
- strategy_ckpt_load_file: ""
- strategy_ckpt_save_file: ""
"""
auto_parallel_context().reset()
......@@ -25,9 +25,7 @@ StrategyCheckpoint& StrategyCheckpoint::GetInstance() {
return instance;
}
bool StrategyCheckpoint::CheckPointExit() const { return false; }
Status StrategyCheckpoint::RemoveCheckPoint() const { return SUCCESS; }
bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return false; }
Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
......
......@@ -14,10 +14,10 @@
import numpy as np
from mindspore import context
from mindspore.context import set_auto_parallel_context
from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import Tensor, Parameter
from tests.ut.python.ops.test_math_ops import VirtualLoss
import mindspore as ms
from mindspore.common.api import _executor
......@@ -25,17 +25,15 @@ from mindspore.ops import composite as C
# model_parallel test
# export PARALLEL_CHECKPOINT_ON=on
# export PARALLEL_TRAIN_TIMES=4
def test_six_matmul():
def test_six_matmul_save():
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x1, x2, x3, x4, x5, x6, x7):
predict = self.network(x1, x2, x3, x4, x5, x6, x7)
def construct(self, x1, x6):
predict = self.network(x1, x6)
return self.loss(predict)
......@@ -44,8 +42,8 @@ def test_six_matmul():
super(GradWrap, self).__init__()
self.network = network
def construct(self, x1, x2, x3, x4, x5, x6, x7):
return C.grad_all(self.network)(x1, x2, x3, x4, x5, x6, x7)
def construct(self, x1, x6):
return C.grad_all(self.network)(x1, x6)
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
......@@ -56,45 +54,46 @@ def test_six_matmul():
self.matmul4 = P.MatMul().set_strategy(strategy4)
self.matmul5 = P.MatMul().set_strategy(strategy5)
self.matmul6 = P.MatMul().set_strategy(strategy6)
def construct(self, x1, x2, x3, x4, x5, x6, x7):
out = self.matmul1(x1, x2)
out = self.matmul2(out, x3)
out = self.matmul3(out, x4)
out = self.matmul4(out, x5)
out = self.matmul5(out, x6)
out = self.matmul6(out, x7)
self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1)
out = self.matmul2(out, self.weight2)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = self.matmul6(out, x6)
return out
set_auto_parallel_context(device_num=512, global_rank=0)
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((1, 8), (8, 1))
strategy3 = ((2, 2), (2, 2))
strategy4 = ((4, 2), (2, 4))
strategy5 = ((2, 4), (4, 2))
strategy6 = ((4, 4), (4, 4))
strategy4 = ((1, 1), (1, 8))
strategy5 = ((4, 2), (2, 1))
strategy6 = ((4, 1), (1, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
x3 = Tensor(np.ones([64, 64]), dtype=ms.float32)
x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_executor.compile(net, x1, x2, x3, x4, x5, x6, x7)
x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x1, x6)
# remove matmul2
def test_six_matmul_repeated1():
# remove matmul2, add matmul7
def test_six_matmul_load():
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x1, x2, x4, x5, x6, x7):
predict = self.network(x1, x2, x4, x5, x6, x7)
def construct(self, x1, x6, x7):
predict = self.network(x1, x6, x7)
return self.loss(predict)
......@@ -103,53 +102,58 @@ def test_six_matmul_repeated1():
super(GradWrap, self).__init__()
self.network = network
def construct(self, x1, x2, x4, x5, x6, x7):
return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7)
def construct(self, x1, x6, x7):
return C.grad_all(self.network)(x1, x6, x7)
class Net(nn.Cell):
def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6):
def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul3 = P.MatMul().set_strategy(strategy3)
self.matmul4 = P.MatMul().set_strategy(strategy4)
self.matmul5 = P.MatMul().set_strategy(strategy5)
self.matmul6 = P.MatMul().set_strategy(strategy6)
def construct(self, x1, x2, x4, x5, x6, x7):
out = self.matmul1(x1, x2)
out = self.matmul3(out, x4)
out = self.matmul4(out, x5)
out = self.matmul5(out, x6)
out = self.matmul6(out, x7)
self.matmul7 = P.MatMul().set_strategy(strategy7)
self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = self.matmul6(out, x6)
out = self.matmul7(out, x7)
return out
set_auto_parallel_context(device_num=512, global_rank=0)
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
strategy1 = ((8, 1), (1, 1))
strategy3 = ((8, 1), (1, 1))
strategy4 = ((8, 1), (1, 1))
strategy5 = ((8, 1), (1, 1))
strategy6 = ((8, 1), (1, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6)))
strategy7 = ((8, 1), (1, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_executor.compile(net, x1, x2, x4, x5, x6, x7)
_executor.compile(net, x1, x6, x7)
# add matmul7
def test_six_matmul_repeated2():
# model_parallel test
def test_six_matmul_save_auto():
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x1, x2, x4, x5, x6, x7, x8):
predict = self.network(x1, x2, x4, x5, x6, x7, x8)
def construct(self, x1, x6):
predict = self.network(x1, x6)
return self.loss(predict)
......@@ -158,60 +162,52 @@ def test_six_matmul_repeated2():
super(GradWrap, self).__init__()
self.network = network
def construct(self, x1, x2, x4, x5, x6, x7, x8):
return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8)
def construct(self, x1, x6):
return C.grad_all(self.network)(x1, x6)
class Net(nn.Cell):
def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
def __init__(self):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul3 = P.MatMul().set_strategy(strategy3)
self.matmul4 = P.MatMul().set_strategy(strategy4)
self.matmul5 = P.MatMul().set_strategy(strategy5)
self.matmul6 = P.MatMul().set_strategy(strategy6)
self.matmul7 = P.MatMul().set_strategy(strategy7)
def construct(self, x1, x2, x4, x5, x6, x7, x8):
out = self.matmul1(x1, x2)
out = self.matmul3(out, x4)
out = self.matmul4(out, x5)
out = self.matmul5(out, x6)
out = self.matmul6(out, x7)
out = self.matmul7(out, x8)
self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul()
self.matmul3 = P.MatMul()
self.matmul4 = P.MatMul()
self.matmul5 = P.MatMul()
self.matmul6 = P.MatMul()
self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1)
out = self.matmul2(out, self.weight2)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = self.matmul6(out, x6)
return out
set_auto_parallel_context(device_num=512, global_rank=0)
strategy1 = ((8, 1), (1, 1))
strategy3 = ((8, 1), (1, 1))
strategy4 = ((8, 1), (1, 1))
strategy5 = ((8, 1), (1, 1))
strategy6 = ((8, 1), (1, 1))
strategy7 = ((8, 1), (1, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
_executor.compile(net, x1, x2, x4, x5, x6, x7, x8)
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x1, x6)
# add scope2
def test_six_matmul_repeated3():
# remove matmul2, add matmul7
def test_six_matmul_load_auto():
class NetWithLoss(nn.Cell):
def __init__(self, network1, network2):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network1
self.network2 = network2
self.network = network
def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
predict = self.network(x1, x2, x4, x5, x6, x7, x8)
predict = self.network2(predict, x9, x10)
def construct(self, x1, x6, x7):
predict = self.network(x1, x6, x7)
return self.loss(predict)
......@@ -220,62 +216,42 @@ def test_six_matmul_repeated3():
super(GradWrap, self).__init__()
self.network = network
def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8, x9, x10)
def construct(self, x1, x6, x7):
return C.grad_all(self.network)(x1, x6, x7)
class Net(nn.Cell):
def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
def __init__(self, strategy1, strategy3, strategy4, strategy5):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul3 = P.MatMul().set_strategy(strategy3)
self.matmul4 = P.MatMul().set_strategy(strategy4)
self.matmul5 = P.MatMul().set_strategy(strategy5)
self.matmul6 = P.MatMul().set_strategy(strategy6)
self.matmul7 = P.MatMul().set_strategy(strategy7)
def construct(self, x1, x2, x4, x5, x6, x7, x8):
out = self.matmul1(x1, x2)
out = self.matmul3(out, x4)
out = self.matmul4(out, x5)
out = self.matmul5(out, x6)
out = self.matmul6(out, x7)
out = self.matmul7(out, x8)
return out
class Net1(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul2 = P.MatMul().set_strategy(strategy2)
def construct(self, x1, x2, x3):
out = self.matmul1(x1, x2)
out = self.matmul2(out, x3)
self.matmul6 = P.MatMul()
self.matmul7 = P.MatMul()
self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = self.matmul6(out, x6)
out = self.matmul7(out, x7)
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
strategy1 = ((2, 2), (2, 2))
strategy3 = ((2, 2), (2, 2))
strategy4 = ((2, 2), (2, 2))
strategy5 = ((2, 2), (2, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
set_auto_parallel_context(device_num=512, global_rank=0)
strategy1 = ((8, 1), (1, 1))
strategy3 = ((8, 1), (1, 1))
strategy4 = ((8, 1), (1, 1))
strategy5 = ((8, 1), (1, 1))
strategy6 = ((8, 1), (1, 1))
strategy7 = ((8, 1), (1, 1))
strategy8 = ((8, 1), (1, 1))
strategy9 = ((8, 1), (1, 1))
net1 = Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)
net2 = Net1(strategy8, strategy9)
net = GradWrap(NetWithLoss(net1, net2))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
x9 = Tensor(np.ones([128, 64]), dtype=ms.float32)
x10 = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x1, x2, x4, x5, x6, x7, x8, x9, x10)
_executor.compile(net, x1, x6, x7)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册