提交 cf45c543 编写于 作者: D dongdaxiang

add distributed optimizer factory

上级 b7a202aa
...@@ -40,6 +40,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -40,6 +40,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
workers_[i]->Initialize(trainer_desc); workers_[i]->Initialize(trainer_desc);
} }
VLOG(3) << "going to initialize pull dense worker";
pull_dense_worker_ = PullDenseWorker::GetInstance(); pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc); pull_dense_worker_->Initialize(trainer_desc);
VLOG(3) << "initialize pull dense worker"; VLOG(3) << "initialize pull dense worker";
......
...@@ -85,8 +85,8 @@ class DownpourSGD(DeviceWorker): ...@@ -85,8 +85,8 @@ class DownpourSGD(DeviceWorker):
opt_info = self.program_._fleet_opt opt_info = self.program_._fleet_opt
program_configs = opt_info["program_configs"] program_configs = opt_info["program_configs"]
for program_id in program_configs: for pid in program_configs:
if program_configs[program_id] == program_id: if pid == program_id:
pc = downpour.program_config.add() pc = downpour.program_config.add()
pc.program_id = program_id pc.program_id = program_id
for i in program_configs[program_id]["push_sparse"]: for i in program_configs[program_id]["push_sparse"]:
...@@ -98,21 +98,6 @@ class DownpourSGD(DeviceWorker): ...@@ -98,21 +98,6 @@ class DownpourSGD(DeviceWorker):
for i in program_configs[program_id]["pull_dense"]: for i in program_configs[program_id]["pull_dense"]:
pc.pull_dense_table_id.extend([i]) pc.pull_dense_table_id.extend([i])
break break
'''
for program_config in self.fleet_desc_.trainer_param.program_config:
if program_config.program_id == program_id:
pc = downpour.program_config.add()
pc.program_id = program_config.program_id
for i in program_config.push_sparse_table_id:
pc.push_sparse_table_id.extend([i])
for i in program_config.push_dense_table_id:
pc.push_dense_table_id.extend([i])
for i in program_config.pull_sparse_table_id:
pc.pull_sparse_table_id.extend([i])
for i in program_config.pull_dense_table_id:
pc.pull_dense_table_id.extend([i])
break
'''
class DeviceWorkerFactory(object): class DeviceWorkerFactory(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册