From cf45c5434038f0b6ec320fa30b4bc407e4493fd2 Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Wed, 13 Mar 2019 22:04:17 +0800 Subject: [PATCH] add distributed optimizer factory --- paddle/fluid/framework/dist_multi_trainer.cc | 1 + python/paddle/fluid/device_worker.py | 19 ++----------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 4f177574b63..4f8d15adc38 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -40,6 +40,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, workers_[i]->Initialize(trainer_desc); } + VLOG(3) << "going to initialize pull dense worker"; pull_dense_worker_ = PullDenseWorker::GetInstance(); pull_dense_worker_->Initialize(trainer_desc); VLOG(3) << "initialize pull dense worker"; diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index 02435f0fd30..547db086379 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -85,8 +85,8 @@ class DownpourSGD(DeviceWorker): opt_info = self.program_._fleet_opt program_configs = opt_info["program_configs"] - for program_id in program_configs: - if program_configs[program_id] == program_id: + for pid in program_configs: + if pid == program_id: pc = downpour.program_config.add() pc.program_id = program_id for i in program_configs[program_id]["push_sparse"]: @@ -98,21 +98,6 @@ class DownpourSGD(DeviceWorker): for i in program_configs[program_id]["pull_dense"]: pc.pull_dense_table_id.extend([i]) break - ''' - for program_config in self.fleet_desc_.trainer_param.program_config: - if program_config.program_id == program_id: - pc = downpour.program_config.add() - pc.program_id = program_config.program_id - for i in program_config.push_sparse_table_id: - pc.push_sparse_table_id.extend([i]) - for i in program_config.push_dense_table_id: - pc.push_dense_table_id.extend([i]) - for i in program_config.pull_sparse_table_id: - pc.pull_sparse_table_id.extend([i]) - for i in program_config.pull_dense_table_id: - pc.pull_dense_table_id.extend([i]) - break - ''' class DeviceWorkerFactory(object): -- GitLab