Commit 2031710d authored by hongxing

fix bug and optimize code

Parent 26d05be8
@@ -34,37 +34,38 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
 std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                        const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                             const size_t iter_ops, std::vector<int32_t> s);
+std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
+                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                               const size_t iter_graph, const size_t iter_ops);
+std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                           const size_t iter_ops);
+std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
+                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_graph, const size_t iter_ops);
 std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                   const size_t iter_graph, const size_t iter_ops);
 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                          const std::shared_ptr<std::vector<size_t>> index_list);
-int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, const size_t iter_ops);
+size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
+                                   const size_t iter_ops);
 std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops, const size_t iter_graph);
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                          const int incoming_op_index);
+                                                          const size_t incoming_op_index);
 std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
 std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                     const int incoming_op_index, std::vector<int32_t> s);
+                                                     const size_t incoming_op_index, std::vector<int32_t> s);
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                    const int incoming_op_index, std::vector<int32_t> s);
+                                                    const size_t incoming_op_index, std::vector<int32_t> s);
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                       const int incoming_op_index, const size_t iter_ops,
-                                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+                                                       const size_t iter_ops, const size_t incoming_op_index);
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                  const size_t iter_ops, std::vector<int32_t> s);
 void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
@@ -75,14 +76,17 @@ void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
 std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
-std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                                    const size_t iter_ops, std::vector<int32_t> s);
 std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const std::vector<std::vector<std::string>> &input_tensor_names,
                                                        const size_t iter_ops);
 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
                                                 const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
+                                       const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                       const std::vector<std::vector<std::string>> &input_tensor_names,
+                                       const std::shared_ptr<std::vector<size_t>> index_list,
+                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
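The reworked header drops the VirtualDataset/scalar-input helpers in favour of PreparePReLU and PrepareBiasAdd. As a rough illustration of what a BiasAdd helper with this signature can do (assumed convention: the 1-D bias follows the channel split s[1] of the data input; this is a sketch, not the commit's implementation):

    #include <cstdint>
    #include <vector>

    // Sketch: derive BiasAdd's per-input strategies from the strategy s of its
    // first input. The bias input is 1-D, so it only carries the channel cut.
    // Illustrative only; the helper name is hypothetical.
    std::vector<std::vector<int32_t>> PrepareBiasAddSketch(const std::vector<int32_t> &s) {
      std::vector<std::vector<int32_t>> strategies;
      strategies.push_back(s);          // data input keeps the incoming strategy
      strategies.push_back({s.at(1)});  // bias input is cut along the channel dim only
      return strategies;
    }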
@@ -47,7 +47,8 @@ enum OperatorType {
   kRecDiv,
   kRecSqueeze,
   kRecCast,
-  kRecReduce
+  kRecReduce,
+  kRecPReLU
 };
 enum InfoType { kApplication, kConstant };
...
@@ -199,7 +199,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
     OperatorType::kRecOneHot, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp,
     OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
     OperatorType::kRecMul, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce,
-    OperatorType::kRecCast};
+    OperatorType::kRecCast, OperatorType::kRecReshape};
   for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
     auto type = graph->nodes[node_index].apply.op_type;
     if (type_list.find(type) != type_list.end()) {
...
@@ -55,7 +55,8 @@ const std::map<std::string, OperatorType> DictOpType{
   {"HSigmoid", OperatorType::kRecReLU},
   {GELU, OperatorType::kRecReLU},
   {TANH, OperatorType::kRecReLU},
-  {PRELU, OperatorType::kRecReLU},
+  {PRELU, OperatorType::kRecPReLU},
   {TENSOR_ADD, OperatorType::kRecElmWiseOp},
   {SUB, OperatorType::kRecElmWiseOp},
...
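The dictionary change gives PReLU its own rec operator type instead of folding it into kRecReLU. A self-contained stand-in for the name-to-type lookup, with the enum and map trimmed to the entries shown in this hunk:

    #include <iostream>
    #include <map>
    #include <string>

    // Stand-in for the rec OperatorType enum and DictOpType map, reduced to the
    // members relevant here. Unknown primitives fall back to kRecUnkownType.
    enum class OperatorType { kRecReLU, kRecPReLU, kRecUnkownType };

    const std::map<std::string, OperatorType> kDictOpTypeSketch{
        {"ReLU", OperatorType::kRecReLU},
        {"PReLU", OperatorType::kRecPReLU},  // mapped to kRecReLU before this commit
    };

    OperatorType ResolveOpType(const std::string &prim_name) {
      auto it = kDictOpTypeSketch.find(prim_name);
      return it == kDictOpTypeSketch.end() ? OperatorType::kRecUnkownType : it->second;
    }

    int main() {
      std::cout << (ResolveOpType("PReLU") == OperatorType::kRecPReLU) << std::endl;  // prints 1
      return 0;
    }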
@@ -83,7 +83,7 @@ double GetWeights(const Graph::NodeType &node) {
     auto cost_ptr = std::make_shared<CostCommon>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == OperatorType::kRecUnkownType) {
+  } else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU) {
     // For unknown type
     return 0.0;
   } else {
@@ -177,7 +177,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostCommon>();
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
-  } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
+  } else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU) {
     // For unknown type
     StrategyRec default_strategy;
     return default_strategy;
...
@@ -464,6 +464,11 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
     }
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     if (!IsAutoParallelCareNode(cnode)) {
+      // Needed by rec_parser
+      PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+      if (prim->name() == TUPLE_GETITEM) {
+        entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
+      }
       continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
@@ -522,6 +527,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
     }
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     if (!IsAutoParallelCareNode(cnode)) {
+      // Needed by rec_parser
+      PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+      if (prim->name() == TUPLE_GETITEM) {
+        entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
+      }
       continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
...
@@ -1153,6 +1163,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
     MS_LOG(ERROR) << "Constructing nodes for cost graph failed.";
     return FAILED;
   }
+
   auto ops = entire_costgraph->GetOperators();
   std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
   auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
...