提交 28a6fc98 编写于 作者: J Juncheng 提交者: Jinhui Yuan

fix order of shared model nodes (#1180)

上级 40c299bc
......@@ -117,38 +117,33 @@ void LogicalGraph::NaiveBuildFwStruct(
void LogicalGraph::FixSharedModelNodes(
const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes) {
HashSet<std::string> all_shared_model_op_names;
const DLNetConf& dlnet_conf = Global<JobDesc>::Get()->dlnet_conf();
for (const OpNameSet& op_name_set : dlnet_conf.shared_model_group()) {
std::vector<std::string> shared_model_op_names(op_name_set.op_name().begin(),
op_name_set.op_name().end());
SortAndRemoveDuplication(&shared_model_op_names);
CHECK_GE(shared_model_op_names.size(), 2);
auto shared_model_nodes = std::make_shared<std::vector<LogicalNode*>>();
for (const std::string& op_name : op_name_set.op_name()) {
shared_model_nodes->reserve(shared_model_op_names.size());
for (const std::string& op_name : shared_model_op_names) {
CHECK(all_shared_model_op_names.insert(op_name).second);
CHECK_EQ(op_name2nodes.at(op_name).size(), 1);
shared_model_nodes->push_back(op_name2nodes.at(op_name).front());
}
SortAndRemoveDuplication(shared_model_nodes.get());
for (LogicalNode* cur_node : *shared_model_nodes) {
cur_node->mut_shared_model_nodes() = shared_model_nodes;
}
const std::string& shared_op_name = shared_model_nodes->front()->SoleOp()->op_name();
const ParallelDesc* shared_parallel_desc = shared_model_nodes->front()->parallel_desc().get();
FOR_RANGE(size_t, i, 1, shared_model_nodes->size()) {
shared_model_nodes->at(i)->SoleOp()->FixLbiWhenShareModel(shared_op_name);
CHECK(shared_model_nodes->at(i)->parallel_desc()->Equal(shared_parallel_desc));
}
}
ForEachNode([&](LogicalNode* cur_node) {
if (cur_node->shared_model_nodes()) {
for (LogicalNode* shared_node : *(cur_node->shared_model_nodes())) {
if (shared_node->parallel_desc() == nullptr) { continue; }
if (cur_node->parallel_desc()) {
CHECK(cur_node->parallel_desc()->Equal(shared_node->parallel_desc().get()));
} else {
cur_node->mut_parallel_desc() = shared_node->parallel_desc();
}
}
} else {
// do nothing
}
CHECK(cur_node->parallel_desc())
<< "Please set the placement of " << cur_node->SoleOp()->op_name();
});
}
void LogicalGraph::SetMainModelParallel() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册