Commit e657c127 authored by dongdaxiang

hide opt_info in distributed optimizer

Parent ecfc7df9
......@@ -28,8 +28,8 @@ namespace framework {
class Dataset {
public:
Dataset() {};
virtual ~Dataset() {};
Dataset() {}
virtual ~Dataset() {}
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
virtual void SetThreadNum(int thread_num) = 0;
virtual void SetTrainerNum(int trainer_num) = 0;
......@@ -39,18 +39,19 @@ class Dataset {
virtual int GetTrainerNum() = 0;
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0;
GetReaders() = 0;
virtual void LoadIntoMemory() = 0;
virtual void LocalShuffle() = 0;
virtual void GlobalShuffle() = 0;
virtual void CreateReaders() = 0;
virtual void DestroyReaders() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
};
template<typename T>
template <typename T>
class DatasetImpl : public Dataset {
public:
DatasetImpl();
......@@ -69,7 +70,7 @@ class DatasetImpl : public Dataset {
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
GetReaders();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
......@@ -82,8 +83,10 @@ class DatasetImpl : public Dataset {
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<T> memory_data_;
std::mutex mutex_for_update_memory_data_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_vec_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_out_vec_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>>
shuffled_ins_vec_;
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>>
shuffled_ins_out_vec_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
std::vector<std::string> filelist_;
......@@ -96,6 +99,5 @@ class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
virtual ~MultiSlotDataset() {}
};
} // end namespace framework
} // end namespace paddle
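Read together, the virtual methods above define the dataset lifecycle the trainer relies on: configure, load, shuffle, then hand readers to the worker threads. Below is a rough Python sketch of that call order; the snake_case method names and the `run_dataset_lifecycle` helper are hypothetical stand-ins for a bound MultiSlotDataset, not the actual binding.

```python
def run_dataset_lifecycle(dataset, filelist, thread_num, trainer_num):
    """Sketch of the call order the Dataset interface above implies.

    `dataset` stands in for a bound MultiSlotDataset; the snake_case
    wrappers mirror the C++ virtuals (SetFileList, LoadIntoMemory, ...)
    but the names here are made up for this sketch.
    """
    dataset.set_filelist(filelist)        # SetFileList
    dataset.set_thread_num(thread_num)    # SetThreadNum
    dataset.set_trainer_num(trainer_num)  # SetTrainerNum
    dataset.load_into_memory()            # LoadIntoMemory: fill memory_data_
    dataset.local_shuffle()               # LocalShuffle: shuffle within this trainer
    dataset.global_shuffle()              # GlobalShuffle: exchange instances across trainers
    dataset.create_readers()              # CreateReaders: one DataFeed per thread
    return dataset.get_readers()          # GetReaders: handed to the worker threads
```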
......@@ -118,7 +118,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset,
MultiSlotDataset* dataset,
const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
......
......@@ -19,8 +19,6 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include <unordered_map>
#include <memory>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -115,7 +113,8 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset, const std::string& trainer_desc_str);
MultiSlotDataset* dataset,
const std::string& trainer_desc_str);
private:
const platform::Place place_;
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class DeviceWorker(object):
def __init__(self):
pass
def gen_worker_desc(self, trainer_desc, fleet_desc):
def gen_worker_desc(self, trainer_desc):
pass
......@@ -25,7 +27,7 @@ class Hogwild(DeviceWorker):
def __init__(self):
super(Hogwild, self).__init__()
def gen_worker_desc(self, trainer_desc, fleet_desc):
def gen_worker_desc(self, trainer_desc):
trainer_desc.device_worker_name = "HogwildWorker"
......@@ -33,7 +35,7 @@ class DownpourSGD(DeviceWorker):
def __init__(self):
super(DownpourSGD, self).__init__()
def gen_worker_desc(self, trainer_desc, fleet_desc):
def gen_worker_desc(self, trainer_desc):
trainer_desc.device_worker_name = "DownpourWorker"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
......
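The signature change above is the worker-side half of hiding opt_info: `gen_worker_desc` now receives only the trainer_desc proto, so anything a worker needs (thread_num, pull_dense_param, ...) must already live on that proto. A minimal, self-contained mock of that contract; the `Fake*` classes and `DownpourLikeWorker` are made-up stand-ins for the protobuf messages and the real worker.

```python
class FakePullDense(object):
    """Stand-in for the pull_dense_param message; illustration only."""
    device_num = 0


class FakeTrainerDesc(object):
    """Stand-in for the TrainerDesc proto; illustration only."""

    def __init__(self, thread_num):
        self.device_worker_name = ""
        self.thread_num = thread_num
        self.pull_dense_param = FakePullDense()


class DownpourLikeWorker(object):
    def gen_worker_desc(self, trainer_desc):
        # Everything the worker needs is read off trainer_desc itself,
        # mirroring the DownpourSGD.gen_worker_desc body above.
        trainer_desc.device_worker_name = "DownpourWorker"
        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num


desc = FakeTrainerDesc(thread_num=8)
DownpourLikeWorker().gen_worker_desc(desc)
assert desc.pull_dense_param.device_num == 8
```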
......@@ -33,6 +33,9 @@ class DownpourSGD(object):
Examples:
.. code-block:: python
opt = fluid.DistributedOptimizer(sgd_opt)
opt.minimize()
downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
downpour_sgd.minimize(cost)
"""
......@@ -87,6 +90,7 @@ class DownpourSGD(object):
prefetch_slots, prefetch_slots_emb)
dense_table_index = 1
program_configs = []
param_grads_list = []
for loss_index in range(len(losses)):
program_config = ps_param.trainer_param.program_config.add()
program_config.program_id = str(
......@@ -97,6 +101,7 @@ class DownpourSGD(object):
append_backward(losses[loss_index], parameter_list,
no_grad_set),
key=lambda x: x[0].name)
param_grads_list.append(params_grads)
params = []
grads = []
data_norm_params = []
......@@ -156,4 +161,8 @@ class DownpourSGD(object):
opt_info["optimizer"] = "DownpourSGD"
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
return opt_info
for loss in losses:
loss.block.program._fleet_opt = opt_info
return None, param_grads_list
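With this change, `minimize()` no longer hands opt_info back for the caller to thread through to the executor: it attaches opt_info to each loss's Program as `_fleet_opt` and returns `(None, param_grads_list)` so the call shape matches a regular optimizer. A usage sketch based on the docstring above; the toy regression network is only there to make the snippet self-contained, and the fleet/parameter-server environment setup is omitted, so treat this as illustrative rather than runnable end to end.

```python
import paddle.fluid as fluid

# Toy network so the snippet has a loss to minimize; feeding and fleet
# initialization are intentionally left out.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
y_pred = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.mean(fluid.layers.square_error_cost(input=y_pred, label=y))

downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
_, param_grads_list = downpour_sgd.minimize(cost)

# minimize() has stashed the configuration on the Program itself, so nothing
# needs to be passed to the executor explicitly anymore.
assert cost.block.program._fleet_opt is not None
assert cost.block.program._fleet_opt["optimizer"] == "DownpourSGD"
```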
......@@ -612,31 +612,40 @@ class Executor(object):
def _run_inference(self, exe, feed):
return exe.run(feed)
def run_from_dataset(self,
program=None,
dataset=None,
fetch_list=None,
scope=None,
thread=0,
opt_info=None):
def infer_from_dataset(self,
program=None,
dataset=None,
fetch_list=None,
scope=None,
thread=0,
opt_info=None):
pass
def train_from_dataset(self,
program=None,
dataset=None,
fetch_list=None,
scope=None,
thread=0,
opt_info=None):
if scope is None:
scope = global_scope()
if fetch_list is None:
fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram)
if not compiled:
trainer = TrainerFactory().create_trainer(opt_info)
if thread <= 0:
trainer.set_thread(dataset.thread_num)
else:
trainer.set_thread(thread)
trainer = TrainerFactory().create_trainer(program._fleet_opt)
else:
trainer = TrainerFactory().create_trainer(
program.program._fleet_opt)
if thread <= 0:
trainer.set_thread(dataset.thread_num)
else:
trainer.set_thread(thread)
trainer.gen_trainer_desc()
dataset._prepare_to_run()
print("run_from_dataset called")
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
else:
# For compiled program, more runtime should be implemented
print("run_from_dataset current does not support compiled program"
", we will support this later", sys.stderr)
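On the executor side, the fleet configuration is now read off the Program (`program._fleet_opt`, or `program.program._fleet_opt` for a CompiledProgram) instead of being supplied by the caller; the `opt_info` parameter remains in the signature but is no longer consulted. A sketch of the resulting call, assuming `main_program` is the Program that `minimize()` ran on and `dataset` exposes `thread_num`, `_prepare_to_run()` and `.dataset` as in the code above.

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

exe.train_from_dataset(program=main_program,  # _fleet_opt is read off this Program
                       dataset=dataset,
                       thread=0)              # thread <= 0 falls back to dataset.thread_num
```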
......@@ -2704,6 +2704,10 @@ class Program(object):
# whether the program is optimized by memory_optimize_transpiler
self.__is_mem_optimized = False
# if this program has been optimized by distributed optimizer
# fleet_opt will be given a value
self._fleet_opt = None
@property
def _is_mem_optimized(self):
# if the program is optimized, operator input/outputs
......
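The new attribute is the marker the executor keys off: it stays `None` for ordinary programs and is only populated once a distributed optimizer's `minimize()` has run on a loss from this program. A tiny hypothetical helper to make that explicit:

```python
def is_fleet_optimized(program):
    # _fleet_opt is None on a fresh Program and only set by a distributed
    # optimizer such as DownpourSGD.minimize(); the helper name is made up.
    return getattr(program, "_fleet_opt", None) is not None
```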
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -59,7 +59,7 @@ class MultiTrainer(TrainerDesc):
def gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
self.device_worker_.gen_worker_desc(self.proto_desc)
class DistMultiTrainer(TrainerDesc):
......@@ -70,7 +70,7 @@ class DistMultiTrainer(TrainerDesc):
def gen_trainer_desc(self):
super(DistMultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
self.device_worker_.gen_worker_desc(self.proto_desc)
def set_program_config(self, fleet_desc, program_id):
for program_config in fleet_desc.trainer_param.program_config:
......
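This is the trainer-side half of the same change: `fleet_desc_` stays on the trainer and the device worker is handed only the proto. A cut-down, purely illustrative mock (not the real TrainerDesc subclass):

```python
class MiniDistMultiTrainer(object):
    """Hypothetical cut-down trainer used only to illustrate the new flow."""

    def __init__(self, device_worker, fleet_desc, proto_desc):
        self.device_worker_ = device_worker
        self.fleet_desc_ = fleet_desc   # kept private to the trainer now
        self.proto_desc = proto_desc

    def gen_trainer_desc(self):
        self.proto_desc.class_name = "DistMultiTrainer"
        # Only the proto crosses the boundary; no fleet_desc argument.
        self.device_worker_.gen_worker_desc(self.proto_desc)
```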
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from trainer_desc import *
from device_worker import *
from .trainer_desc import MultiTrainer
from .device_worker import Hogwild
__all__ = ["TrainerFactory"]
......@@ -38,5 +38,5 @@ class TrainerFactory(object):
device_worker = globals()[device_worker_class]()
trainer.set_device_worker(device_worker)
trainer.set_fleet_desc(opt_info["fleet_desc"])
trainer.gen_trainer_desc(fleet_desc=opt_info["fleet_desc"])
trainer.gen_trainer_desc()
return trainer
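End to end, the wiring the executor performs with this factory now looks roughly like the function below. It is a hypothetical standalone rendering of the non-compiled branch shown earlier: the helper name and the `trainer_factory` import path are assumptions, and `exe`, `program`, `dataset`, and `scope` are supplied by the caller.

```python
from paddle.fluid.trainer_factory import TrainerFactory  # assumed import path


def run_with_fleet_opt(exe, program, dataset, scope, thread=0):
    """Hypothetical restatement of Executor.train_from_dataset's core path."""
    trainer = TrainerFactory().create_trainer(program._fleet_opt)
    trainer.set_thread(dataset.thread_num if thread <= 0 else thread)
    trainer.gen_trainer_desc()
    dataset._prepare_to_run()
    exe._default_executor.run_from_dataset(program.desc, scope,
                                           dataset.dataset, trainer._desc())
```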