提交 60a9fb00 编写于 作者: Y yao_yf

add_tensor_layout_in_stra_ckpt

上级 57fd31b2
...@@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo { ...@@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo {
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override; std::shared_ptr<Strategys> GenerateBatchStrategies() override;
const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; }
const std::vector<int64_t> &index_offsets() const { return index_offsets_; }
protected: protected:
Status CheckStrategy(const StrategyPtr &strategy) override; Status CheckStrategy(const StrategyPtr &strategy) override;
......
...@@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & ...@@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_outputs_dtype(cnode->Type()); operator_info->set_outputs_dtype(cnode->Type());
operator_info->set_cnode(cnode); operator_info->set_cnode(cnode);
// key of strategy map // key of strategy map
std::string strategy_key_name = NodeParameterName(cnode); std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
if (!param_names.empty()) {
strategy_key_name = param_names[0].first;
}
bool load_strategy_from_ckpt = bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
// If no strategy has been configured for this operator, then candidate strategies are generated for // If no strategy has been configured for this operator, then candidate strategies are generated for
......
...@@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
} }
// load strategy checkpoint // load strategy checkpoint
// key of strategy map // key of strategy map
std::string strategy_key_name = NodeParameterName(cnode); std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
if (!param_names.empty()) {
strategy_key_name = param_names[0].first;
}
bool load_strategy_from_ckpt = bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
...@@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo ...@@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
} }
} }
std::string NodeParameterName(const CNodePtr &node) { std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node) {
std::vector<AnfNodePtr> node_inputs{node->inputs()}; std::vector<AnfNodePtr> node_inputs{node->inputs()};
for (auto input : node_inputs) { std::vector<std::pair<std::string, int>> param_names;
for (int i = 0; i < UintToInt(node_inputs.size()); ++i) {
auto input = node_inputs[i];
if (input->isa<Parameter>()) { if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>(); auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) { if (input_parameter->has_default()) {
input_parameter->name(); if (ParameterRequireGrad(input_parameter)) {
param_names.push_back({input_parameter->name(), i});
}
} }
} }
} }
return ""; return param_names;
} }
void CheckpointStrategy(const FuncGraphPtr &func_graph) { void CheckpointStrategy(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
StrategyMap stra_map; StrategyMap stra_map;
TensorInfoMap tensor_info_map;
ManualShapeMap manual_shape_map;
auto ret = func_graph->get_return(); auto ret = func_graph->get_return();
auto all_nodes = DeepScopedGraphSearch(ret); auto all_nodes = DeepScopedGraphSearch(ret);
for (auto &node : all_nodes) { for (auto &node : all_nodes) {
...@@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { ...@@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue; continue;
} }
std::string param_name = NodeParameterName(cnode); auto param_names = NodeParameterName(cnode);
if (param_name.empty()) { if (param_names.empty()) {
continue; continue;
} }
string param_name = param_names[0].first;
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
...@@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { ...@@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue; continue;
} }
std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
StrategyPtr strategyPtr = operator_info->strategy(); StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope()); MS_EXCEPTION_IF_NULL(node->scope());
stra_map[param_name] = strategyPtr; stra_map[param_name] = strategyPtr;
for (auto param_name_pair : param_names) {
if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
continue;
}
tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
}
if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos ||
operator_info->name().find(GATHERV2) != std::string::npos) {
auto gatherv2_info = std::dynamic_pointer_cast<GatherV2PInfo>(operator_info);
auto param_split_shapes = gatherv2_info->param_split_shapes();
auto index_offsets = gatherv2_info->index_offsets();
if (param_split_shapes.size() != index_offsets.size()) {
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same.";
}
std::vector<std::pair<int32_t, int32_t>> manual_shape;
for (int i = 0; i < UintToInt(param_split_shapes.size()); ++i) {
manual_shape.push_back({param_split_shapes[i], index_offsets[i]});
}
manual_shape_map[param_name] = manual_shape;
}
} }
} }
if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
} }
} }
......
...@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes); ...@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager); const FuncGraphManagerPtr &manager);
std::string NodeParameterName(const CNodePtr &node); std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
void CheckpointStrategy(const FuncGraphPtr &func_graph); void CheckpointStrategy(const FuncGraphPtr &func_graph);
......
...@@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { ...@@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
return SUCCESS; return SUCCESS;
} }
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
ManualShapeMap *manual_shape_map) {
straspb::ParallelStrategyMap parallel_strategy_map; straspb::ParallelStrategyMap parallel_strategy_map;
parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
for (auto &node_stra : strategy_map) { for (auto &node_stra : strategy_map) {
...@@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { ...@@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
} }
} }
} }
for (auto &node_tensor_info : tensor_info_map) {
TensorInfo tensor_info = node_tensor_info.second;
TensorLayout tensor_layout = tensor_info.tensor_layout();
straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
MS_EXCEPTION_IF_NULL(parallel_layout_item);
parallel_layout_item->set_param_name(node_tensor_info.first);
straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
MS_EXCEPTION_IF_NULL(dev_matrix);
for (auto dim : tensor_layout.device_arrangement().array()) {
dev_matrix->add_dim(IntToUint(dim));
}
straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
MS_EXCEPTION_IF_NULL(tensor_map);
for (auto dim : tensor_layout.tensor_map().array()) {
tensor_map->add_dim(dim);
}
straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
MS_EXCEPTION_IF_NULL(manual_shape_map);
auto manual_shape = (*manual_shape_map)[node_tensor_info.first];
for (auto dim_pair : manual_shape) {
param_split_shape->add_dim(dim_pair.first);
indices_offset->add_dim(dim_pair.second);
}
}
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
if (!parallel_strategy_map.SerializeToOstream(&output)) { if (!parallel_strategy_map.SerializeToOstream(&output)) {
MS_LOG(ERROR) << "Save strategy file failed"; MS_LOG(ERROR) << "Save strategy file failed";
......
...@@ -19,13 +19,19 @@ ...@@ -19,13 +19,19 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector>
#include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy.h" #include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>; using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int32_t, int32_t>>>;
class StrategyCheckpoint { class StrategyCheckpoint {
public: public:
StrategyCheckpoint() { StrategyCheckpoint() {
...@@ -38,7 +44,7 @@ class StrategyCheckpoint { ...@@ -38,7 +44,7 @@ class StrategyCheckpoint {
~StrategyCheckpoint() = default; ~StrategyCheckpoint() = default;
Status Load(StrategyMap *strategy_map); Status Load(StrategyMap *strategy_map);
Status Save(const StrategyMap &strategy_map); Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
static StrategyCheckpoint &GetInstance(); static StrategyCheckpoint &GetInstance();
bool LoadCheckPointOn() const { return load_checkpoint_on_; } bool LoadCheckPointOn() const { return load_checkpoint_on_; }
......
...@@ -32,7 +32,36 @@ message ParallelStrategyItem { ...@@ -32,7 +32,36 @@ message ParallelStrategyItem {
required ParallelStrategys parallel_strategys = 2; required ParallelStrategys parallel_strategys = 2;
} }
message DevMatrix {
repeated uint32 dim = 1;
}
message TensorMap {
repeated int32 dim = 1;
}
message ParamSplitShape {
repeated int64 dim = 1;
}
message IndicesOffset {
repeated int64 dim = 1;
}
message ParallelLayouts {
repeated DevMatrix dev_matrix = 1;
repeated TensorMap tensor_map = 2;
repeated ParamSplitShape param_split_shape = 3;
repeated IndicesOffset indices_offset = 4;
}
message ParallelLayoutItem {
required string param_name = 1;
required ParallelLayouts parallel_layouts = 2;
}
message ParallelStrategyMap { message ParallelStrategyMap {
required uint32 current_stage = 1; required uint32 current_stage = 1;
repeated ParallelStrategyItem parallel_strategy_item = 2; repeated ParallelStrategyItem parallel_strategy_item = 2;
repeated ParallelLayoutItem parallel_layout_item = 3;
} }
\ No newline at end of file
...@@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f ...@@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f
Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { return SUCCESS; } Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
ManualShapeMap *manual_shape_map) { return SUCCESS; }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册