Commit e657c127 authored by dongdaxiang

hide opt_info in distributed optimizer

Parent: ecfc7df9
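In brief: the distributed optimizer no longer returns opt_info for the caller to thread through to the executor; minimize() now attaches it to the loss's Program as _fleet_opt, and the executor reads it from there. A minimal end-to-end sketch of the resulting flow (network and dataset setup are illustrative placeholders, not part of this commit; losses are passed as a list, matching the loop in the diff below):

    import paddle.fluid as fluid

    cost = ...     # loss variable from a user-defined network (placeholder)
    dataset = ...  # dataset object as consumed by train_from_dataset below

    downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
    # minimize() now stores opt_info on cost.block.program._fleet_opt and
    # returns (None, param_grads_list) instead of returning opt_info.
    _, param_grads_list = downpour_sgd.minimize([cost])

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    # No opt_info argument: the trainer desc is built from program._fleet_opt.
    exe.train_from_dataset(program=fluid.default_main_program(),
                           dataset=dataset)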
@@ -28,8 +28,8 @@ namespace framework {

 class Dataset {
  public:
-  Dataset() {};
-  virtual ~Dataset() {};
+  Dataset() {}
+  virtual ~Dataset() {}
   virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
   virtual void SetThreadNum(int thread_num) = 0;
   virtual void SetTrainerNum(int trainer_num) = 0;
@@ -45,12 +45,13 @@ class Dataset {
   virtual void GlobalShuffle() = 0;
   virtual void CreateReaders() = 0;
   virtual void DestroyReaders() = 0;
+
  protected:
   virtual int ReceiveFromClient(int msg_type, int client_id,
                                 const std::string& msg) = 0;
 };

-template<typename T>
+template <typename T>
 class DatasetImpl : public Dataset {
  public:
   DatasetImpl();
@@ -82,8 +83,10 @@ class DatasetImpl : public Dataset {
   std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
   std::vector<T> memory_data_;
   std::mutex mutex_for_update_memory_data_;
-  std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_vec_;
-  std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_out_vec_;
+  std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>>
+      shuffled_ins_vec_;
+  std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>>
+      shuffled_ins_out_vec_;
   int thread_num_;
   paddle::framework::DataFeedDesc data_feed_desc_;
   std::vector<std::string> filelist_;
@@ -96,6 +99,5 @@ class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
   virtual ~MultiSlotDataset() {}
 };
 }  // end namespace framework
 }  // end namespace paddle
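For orientation, the Python-side dataset object that executor.py (further down in this commit) drives is a thin wrapper over this C++ interface; a hedged sketch of such a wrapper follows — only attributes that executor.py actually touches are shown, and the method names on the core object are assumptions modeled on the virtuals above:

    # Hypothetical wrapper over the C++ Dataset interface above.
    class DatasetWrapper(object):
        def __init__(self, core_dataset, filelist, thread_num):
            self.dataset = core_dataset    # handed to run_from_dataset()
            self.filelist = filelist
            self.thread_num = thread_num   # read by train_from_dataset()

        def _prepare_to_run(self):
            # Mirror the C++ call order: configure first, then build readers.
            self.dataset.set_filelist(self.filelist)      # SetFileList
            self.dataset.set_thread_num(self.thread_num)  # SetThreadNum
            self.dataset.create_readers()                 # CreateReaders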
@@ -118,7 +118,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
 }

 void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
-                              Dataset* dataset,
+                              MultiSlotDataset* dataset,
                               const std::string& trainer_desc_str) {
   VLOG(3) << "Start to RunFromDataset in executor";
   TrainerDesc trainer_desc;
......
@@ -19,8 +19,6 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include <unordered_map>
-#include <memory>
 #include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/op_info.h"
@@ -115,7 +113,8 @@ class Executor {
   void EnableMKLDNN(const ProgramDesc& program);

   void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
-                      Dataset* dataset, const std::string& trainer_desc_str);
+                      MultiSlotDataset* dataset,
+                      const std::string& trainer_desc_str);

  private:
   const platform::Place place_;
......
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
+
 class DeviceWorker(object):
     def __init__(self):
         pass

-    def gen_worker_desc(self, trainer_desc, fleet_desc):
+    def gen_worker_desc(self, trainer_desc):
         pass
@@ -25,7 +27,7 @@ class Hogwild(DeviceWorker):
     def __init__(self):
         super(Hogwild, self).__init__()

-    def gen_worker_desc(self, trainer_desc, fleet_desc):
+    def gen_worker_desc(self, trainer_desc):
         trainer_desc.device_worker_name = "HogwildWorker"
@@ -33,7 +35,7 @@ class DownpourSGD(DeviceWorker):
     def __init__(self):
         super(Downpour, self).__init__()

-    def gen_worker_desc(self, trainer_desc, fleet_desc):
+    def gen_worker_desc(self, trainer_desc):
         trainer_desc.device_worker_name = "DownpourWorker"
         pull_thread = trainer_desc.pull_dense_param
         pull_thread.device_num = trainer_desc.thread_num
......
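With fleet_desc dropped from the signature, a device worker now fills in the trainer proto from trainer_desc alone. A minimal sketch of a custom worker under the new contract (the subclass itself is hypothetical):

    class EchoWorker(DeviceWorker):
        def __init__(self):
            super(EchoWorker, self).__init__()

        def gen_worker_desc(self, trainer_desc):
            # Only trainer_desc is available here now; fleet/server settings
            # travel separately via opt_info / program._fleet_opt.
            trainer_desc.device_worker_name = "HogwildWorker"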
@@ -33,6 +33,9 @@ class DownpourSGD(object):
     Examples:
         .. code-block:: python

+            opt = fluid.DistributedOptimizer(sgd_opt)
+            opt.minimize()
+
             downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
             downpour_sgd.minimize(cost)
     """
@@ -87,6 +90,7 @@ class DownpourSGD(object):
                                  prefetch_slots, prefetch_slots_emb)
         dense_table_index = 1
         program_configs = []
+        param_grads_list = []
         for loss_index in range(len(losses)):
             program_config = ps_param.trainer_param.program_config.add()
             program_config.program_id = str(
@@ -97,6 +101,7 @@ class DownpourSGD(object):
                 append_backward(losses[loss_index], parameter_list,
                                 no_grad_set),
                 key=lambda x: x[0].name)
+            param_grads_list.append(params_grads)
             params = []
             grads = []
             data_norm_params = []
@@ -156,4 +161,8 @@ class DownpourSGD(object):
         opt_info["optimizer"] = "DownpourSGD"
         opt_info["fleet_desc"] = ps_param
         opt_info["worker_skipped_ops"] = worker_skipped_ops
-        return opt_info
+
+        for loss in losses:
+            loss.block.program._fleet_opt = opt_info
+        return None, param_grads_list
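The resulting minimize() contract, sketched (losses passed as a list, matching the loop above; variable names illustrative):

    # opt_info is no longer returned; it is attached to each loss's Program,
    # and the return shape now resembles Optimizer.minimize's
    # (optimize_ops, params_grads) pair.
    _, param_grads_list = downpour_sgd.minimize([cost])
    assert cost.block.program._fleet_opt is not None  # set by minimize()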
@@ -612,7 +612,16 @@ class Executor(object):
     def _run_inference(self, exe, feed):
         return exe.run(feed)

-    def run_from_dataset(self,
-                         program=None,
-                         dataset=None,
-                         fetch_list=None,
+    def infer_from_dataset(self,
+                           program=None,
+                           dataset=None,
+                           fetch_list=None,
+                           scope=None,
+                           thread=0,
+                           opt_info=None):
+        pass
+
+    def train_from_dataset(self,
+                           program=None,
+                           dataset=None,
+                           fetch_list=None,
@@ -623,20 +632,20 @@
             scope = global_scope()
         if fetch_list is None:
             fetch_list = []

         compiled = isinstance(program, compiler.CompiledProgram)
         if not compiled:
-            trainer = TrainerFactory().create_trainer(opt_info)
-            if thread <= 0:
-                trainer.set_thread(dataset.thread_num)
-            else:
-                trainer.set_thread(thread)
-            trainer.gen_trainer_desc()
-            dataset._prepare_to_run()
-            print("run_from_dataset called")
-            self._default_executor.run_from_dataset(program.desc, scope,
-                                                    dataset.dataset,
-                                                    trainer._desc())
-        else:
-            # For compiled program, more runtime should be implemented
-            print("run_from_dataset current does not support compiled program"
-                  ", we will support this later", sys.stderr)
+            trainer = TrainerFactory().create_trainer(program._fleet_opt)
+        else:
+            trainer = TrainerFactory().create_trainer(
+                program.program._fleet_opt)
+        if thread <= 0:
+            trainer.set_thread(dataset.thread_num)
+        else:
+            trainer.set_thread(thread)
+        trainer.gen_trainer_desc()
+        dataset._prepare_to_run()
+        self._default_executor.run_from_dataset(program.desc, scope,
+                                                dataset.dataset,
+                                                trainer._desc())
@@ -2704,6 +2704,10 @@ class Program(object):
         # whether the program is optimized by memory_optimize_transpiler
         self.__is_mem_optimized = False

+        # if this program has been optimized by distributed optimizer,
+        # fleet_opt will be given a value
+        self._fleet_opt = None
+
     @property
     def _is_mem_optimized(self):
         # if the program is optimized, operator input/outputs
......
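The new Program attribute in one line of usage (sketch):

    prog = fluid.Program()
    # Default is None; a distributed optimizer's minimize() sets it later.
    assert prog._fleet_opt is None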
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -59,7 +59,7 @@ class MultiTrainer(TrainerDesc):
     def gen_trainer_desc(self):
         super(MultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "MultiTrainer"
-        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
+        self.device_worker_.gen_worker_desc(self.proto_desc)


 class DistMultiTrainer(TrainerDesc):
@@ -70,7 +70,7 @@ class DistMultiTrainer(TrainerDesc):
     def gen_trainer_desc(self):
         super(DistMultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
-        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
+        self.device_worker_.gen_worker_desc(self.proto_desc)

     def set_program_config(self, fleet_desc, program_id):
         for program_config in fleet_desc.trainer_param.program_config:
......
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from trainer_desc import *
-from device_worker import *
+from .trainer_desc import MultiTrainer
+from .device_worker import Hogwild

 __all__ = ["TrainerFactory"]
@@ -38,5 +38,5 @@ class TrainerFactory(object):
             device_worker = globals()[device_worker_class]()
             trainer.set_device_worker(device_worker)
             trainer.set_fleet_desc(opt_info["fleet_desc"])
-            trainer.gen_trainer_desc(fleet_desc=opt_info["fleet_desc"])
+            trainer.gen_trainer_desc()
             return trainer
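Why the explicit imports matter: create_trainer instantiates classes by looking their names up in this module's globals(), so with the star imports gone only the names imported above can resolve. A sketch of the lookup pattern (the string value is illustrative):

    from .trainer_desc import MultiTrainer
    from .device_worker import Hogwild

    # Same pattern as `globals()[device_worker_class]()` above.
    device_worker_class = "Hogwild"
    device_worker = globals()[device_worker_class]()
    assert isinstance(device_worker, Hogwild)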