提交 2031710d 编写于 作者: H hongxing

fix bug and optimize code

上级 26d05be8
......@@ -34,37 +34,38 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops);
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> index_list);
int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, const size_t iter_ops);
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops);
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_graph);
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index);
const size_t incoming_op_index);
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s);
const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s);
const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, const size_t iter_ops,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
const size_t iter_ops, const size_t incoming_op_index);
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
......@@ -75,14 +76,17 @@ void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops);
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
......@@ -47,7 +47,8 @@ enum OperatorType {
kRecDiv,
kRecSqueeze,
kRecCast,
kRecReduce
kRecReduce,
kRecPReLU
};
enum InfoType { kApplication, kConstant };
......
......@@ -199,7 +199,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
OperatorType::kRecOneHot, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp,
OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
OperatorType::kRecMul, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce,
OperatorType::kRecCast};
OperatorType::kRecCast, OperatorType::kRecReshape};
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
auto type = graph->nodes[node_index].apply.op_type;
if (type_list.find(type) != type_list.end()) {
......
......@@ -55,7 +55,8 @@ const std::map<std::string, OperatorType> DictOpType{
{"HSigmoid", OperatorType::kRecReLU},
{GELU, OperatorType::kRecReLU},
{TANH, OperatorType::kRecReLU},
{PRELU, OperatorType::kRecReLU},
{PRELU, OperatorType::kRecPReLU},
{TENSOR_ADD, OperatorType::kRecElmWiseOp},
{SUB, OperatorType::kRecElmWiseOp},
......
......@@ -83,7 +83,7 @@ double GetWeights(const Graph::NodeType &node) {
auto cost_ptr = std::make_shared<CostCommon>();
return cost_ptr->GetMinCostIn();
} else if (op.op_type == OperatorType::kRecUnkownType) {
} else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU) {
// For unknown type
return 0.0;
} else {
......@@ -177,7 +177,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
auto cost_ptr = std::make_shared<CostCommon>();
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
} else if (node.apply.op_type == OperatorType::kRecUnkownType) {
} else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU) {
// For unknown type
StrategyRec default_strategy;
return default_strategy;
......
......@@ -464,6 +464,11 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsAutoParallelCareNode(cnode)) {
// Needed by rec_parser
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == TUPLE_GETITEM) {
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
}
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
......@@ -522,6 +527,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsAutoParallelCareNode(cnode)) {
// Needed by rec_parser
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == TUPLE_GETITEM) {
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
}
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
......@@ -1153,6 +1163,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
MS_LOG(ERROR) << "Constructing nodes for cost graph failed.";
return FAILED;
}
auto ops = entire_costgraph->GetOperators();
std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册