提交 e1620821 编写于 作者: L leaves-zwx 提交者: Jinhui Yuan

Bugfix parallel desc shared ptr (#996)

* fix bug

* reduce one copy
上级 3a85b18a
......@@ -45,24 +45,27 @@ void LogicalGraph::NaiveBuildFwStruct(
HashMap<std::string, std::vector<LogicalNode*>>* op_name2nodes) {
const DLNetConf& dlnet_conf = Global<JobDesc>::Get()->dlnet_conf();
const Placement& placement = Global<JobDesc>::Get()->placement();
HashMap<std::string, ParallelDesc*> name2parallel_desc;
HashMap<std::string, std::shared_ptr<ParallelDesc>> name2parallel_desc;
for (const PlacementGroup& p_group : placement.placement_group()) {
for (const std::string& op_name : p_group.op_set().op_name()) {
auto parallel_desc_raw_ptr = new ParallelDesc(p_group.parallel_conf());
CHECK(name2parallel_desc.emplace(op_name, parallel_desc_raw_ptr).second);
CHECK(name2parallel_desc
.emplace(op_name, std::make_shared<ParallelDesc>(p_group.parallel_conf()))
.second);
}
}
HashMap<LogicalBlobId, LogicalNode*> lbi2producer;
for (OperatorConf cur_op_conf : dlnet_conf.op()) {
ParallelDesc* parallel_desc_raw_ptr = name2parallel_desc.at(cur_op_conf.name());
cur_op_conf.set_device_type(parallel_desc_raw_ptr->device_type());
auto parallel_desc_ptr_it = name2parallel_desc.find(cur_op_conf.name());
CHECK(parallel_desc_ptr_it != name2parallel_desc.end());
const std::shared_ptr<ParallelDesc>& parallel_desc_ptr = parallel_desc_ptr_it->second;
cur_op_conf.set_device_type(parallel_desc_ptr->device_type());
std::shared_ptr<Operator> cur_op = ConstructOp(cur_op_conf);
LogicalNode* cur_node = cur_op->NewProperLogicalNode();
AddAllocatedNode(cur_node);
cur_node->mut_op_vec() = {cur_op};
cur_node->SoleOp()->FixParallelDesc(parallel_desc_raw_ptr);
cur_node->mut_parallel_desc().reset(parallel_desc_raw_ptr);
cur_node->SoleOp()->FixParallelDesc(parallel_desc_ptr.get());
cur_node->mut_parallel_desc() = parallel_desc_ptr;
for (const std::string& obn : cur_node->SoleOp()->output_bns()) {
const LogicalBlobId& lbi = cur_node->SoleOp()->BnInOp2Lbi(obn);
CHECK(lbi2producer.emplace(lbi, cur_node).second);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册