Commit ff87698a authored by dongdaxiang

refactor downpour optimization

Parent b66f0074
@@ -61,8 +61,7 @@ class MultiTrainer : public TrainerBase {
  public:
   MultiTrainer() {}
   virtual ~MultiTrainer() {}
-  virtual void Initialize(const TrainerDesc& trainer_desc,
-                          Dataset* data_set);
+  virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
   virtual void InitTrainerEnv(const ProgramDesc& main_program,
                               const platform::Place& place);
   virtual void InitOtherEnv(const ProgramDesc& main_program) {}

@@ -80,14 +79,12 @@ class DistMultiTrainer : public MultiTrainer {
  public:
   DistMultiTrainer() {}
   virtual ~DistMultiTrainer() {}
-  virtual void Initialize(const TrainerDesc& trainer_desc,
-                          Dataset* data_set);
+  virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
   virtual void InitOtherEnv(const ProgramDesc& main_program);
   virtual void Finalize();

  protected:
   std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
-  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
 };

 } // namespace framework
@@ -29,7 +29,7 @@ class Hogwild(DeviceWorker):
         trainer_desc.device_worker_name = "HogwildWorker"


-class Downpour(DeviceWorker):
+class DownpourSGD(DeviceWorker):
     def __init__(self):
-        super(Downpour, self).__init__()
+        super(DownpourSGD, self).__init__()

@@ -55,6 +55,7 @@ class Downpour(DeviceWorker):
         sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param[
             0].accessor.fea_dim - 2
         sparse_table.fea_dim = sparse_table.emb_dim + 2
+        # TODO(guru4elephant): hard code here, need to improve
         sparse_table.label_var_name = "click"
         dense_table = downpour.dense_table.add()

@@ -70,6 +71,4 @@ class Downpour(DeviceWorker):
 class DeviceWorkerFactory(object):
     def create_device_worker(self, worker_type):
         classname = worker_type.capitalize()
-        print("------------")
-        print(classname)
         return globals()[classname]()
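
For reference, a minimal sketch of how the factory resolves a device worker from its type string inside this module; the call site and the "hogwild" argument are illustrative, not part of this commit:

# Illustrative only: "hogwild".capitalize() yields "Hogwild",
# which is looked up in globals() of device_worker.py and instantiated.
factory = DeviceWorkerFactory()
worker = factory.create_device_worker("hogwild")  # -> Hogwild instance
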
@@ -142,4 +142,18 @@ class DownpourSGD(object):
         # currently only support lookup_table
         worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
         ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
-        return [ps_param, worker_skipped_ops]
+        # all fleet operations should be defined in operators in the future
+        # we want to return an object here containing:
+        # 1) worker execution strategy
+        # 2) pserver execution strategy
+        # 3) fleet configurations
+        # 4) skipped operators in runtime
+        # 5) distributed optimization
+        opt_info = {}
+        opt_info["trainer"] = "DistMultiTrainer"
+        opt_info["device_worker"] = "DownpourSGD"
+        opt_info["optimizer"] = "DownpourSGD"
+        opt_info["fleet_desc"] = ps_param
+        opt_info["worker_skipped_ops"] = worker_skipped_ops
+        return opt_info
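
For orientation, a sketch of the opt_info dictionary the DownpourSGD optimizer now returns in place of the old [ps_param, worker_skipped_ops] list; the values mirror the assignments above, while the downstream calls named in the comment correspond to the Fleet and TrainerFactory changes later in this commit (the fleet instance name is assumed):

# Shape of the returned configuration (values as assigned above):
opt_info = {
    "trainer": "DistMultiTrainer",        # TrainerDesc subclass to build
    "device_worker": "DownpourSGD",       # DeviceWorker subclass to build
    "optimizer": "DownpourSGD",
    "fleet_desc": ps_param,               # parameter-server description built above
    "worker_skipped_ops": ["lookup_table", "lookup_table_grad"],
}
# Assumed wiring downstream: fleet.init_pserver(opt_info) / fleet.init_worker(opt_info),
# and TrainerFactory().create_trainer(opt_info).
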
@@ -10,6 +10,7 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
+import sys
 from .. import core
 from . import ps_instance

@@ -33,9 +34,15 @@ class Fleet(object):
         self.instance_.barrier_all()
         self.instance.finalize()

-    def init_pserver(self, dist_desc):
-        self.dist_desc_str_ = text_format.MessageToString(dist_desc)
-        self.dist_desc = dist_desc
+    def init_pserver(self, opt_info):
+        if "fleet_desc" in opt_info:
+            self.dist_desc_str_ = text_format.MessageToString(opt_info[
+                "fleet_desc"])
+            self.dist_desc_ = opt_info["fleet_desc"]
+        else:
+            print(
+                "You should run distributed optimization to get opt_info first")
+            sys.exit(-1)
         self.fleet_.init_server(self.dist_desc_str_)
         ip = self.fleet_.start_server()
         self.instance_.set_ip(ip)

@@ -44,10 +51,15 @@ class Fleet(object):
         self.fleet.gather_servers(ips, self.instance_.get_node_cnt())
         self.instance_.barrier_all()

-    def init_worker(self, dist_desc):
-        self.dist_desc_str_ = text_format.MessageToString(dist_desc)
-        self.dist_desc_ = dist_desc
+    def init_worker(self, opt_info):
+        if "fleet_desc" in opt_info:
+            self.dist_desc_str_ = text_format.MessageToString(opt_info[
+                "fleet_desc"])
+            self.dist_desc_ = opt_info["fleet_desc"]
+        else:
+            print(
+                "You should run distributed optimization to get opt_info first")
+            sys.exit(-1)
         self.instance_.barrier_all()
         ips = self.instance.gather_ips()
         self.fleet_.init_worker(self.dist_desc_str_, ips,
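
A hedged sketch of how init_pserver/init_worker are now driven by opt_info rather than a raw fleet_desc proto; only the opt_info-based calls come from this diff, while the Fleet construction, the optimizer entry point, and the role check are assumptions for illustration:

# Hypothetical startup flow; names outside this diff are placeholders.
fleet = Fleet()
opt_info = downpour_sgd.minimize(loss)   # assumed: the DownpourSGD optimizer call that builds opt_info
if node_is_pserver:                      # assumed role detection
    fleet.init_pserver(opt_info)         # requires opt_info["fleet_desc"], otherwise exits
else:
    fleet.init_worker(opt_info)          # same check, then gathers server ips and connects
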
@@ -630,6 +630,7 @@ class Executor(object):
             trainer.set_thread(dataset.thread_num)
         else:
             trainer.set_thread(thread)
+        trainer.gen_trainer_desc()
         dataset._prepare_to_run()
         print("run_from_dataset called")
         self._default_executor.run_from_dataset(program.desc, scope,
@@ -32,19 +32,19 @@ class TrainerDesc(object):
         import multiprocessing as mp
         # set default thread num == cpu count
         self.proto_desc.thread_num = mp.cpu_count()
+        self.fleet_desc_ = None
+        self.device_worker_ = None

     def set_thread(self, thread_num):
         self.proto_desc.thread_num = thread_num

-    def set_filelist(self, filelist):
-        self.proto_desc.filelist.extend(filelist)
-        self.proto_desc.thread_num = min(
-            len(filelist), self.proto_desc.thread_num)
+    def set_device_worker(self, device_worker):
+        self.device_worker_ = device_worker

-    def set_data_feed(self, datafeed):
-        self.proto_desc.data_desc.CopyFrom(datafeed.proto_desc)
+    def set_fleet_desc(self, fleet_desc):
+        self.fleet_desc_ = fleet_desc

-    def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker=None):
+    def gen_trainer_desc(self):
         pass

     def _desc(self):

@@ -52,17 +52,14 @@ class TrainerDesc(object):
 class MultiTrainer(TrainerDesc):
-    def __init__(self, dataset=None, worker="Hogwild"):
+    def __init__(self):
         super(MultiTrainer, self).__init__()
-        if worker == "Hogwild":
-            self.proto_desc.device_worker_name = worker + "Worker"
-            self.proto_desc.class_name = "MultiTrainer"
-        else:
-            raise ValueError('ValueError: DeviceWorker %s '
-                             'is not supported in MultiTrainer' % worker)
+        pass

-    def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker="Hogwild"):
-        super(MultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
+    def gen_trainer_desc(self):
+        super(MultiTrainer, self).gen_trainer_desc()
+        self.proto_desc.class_name = "MultiTrainer"
+        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)


 class DistMultiTrainer(TrainerDesc):

@@ -70,14 +67,10 @@ class DistMultiTrainer(TrainerDesc):
         super(DistMultiTrainer, self).__init__()
         pass

-    def gen_trainer_desc(self, dataset=None, fleet_desc=None,
-                         worker="Downpour"):
-        super(DistMultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
+    def gen_trainer_desc(self):
+        super(DistMultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
-        self.proto_desc.data_desc.CopyFrom(dataset.proto_desc)
-        worker_builder = DeviceWorkerFactory()
-        device_worker = worker_builder.create_device_worker("Downpour")
-        device_worker.gen_worker_desc(self.proto_desc, fleet_desc)
+        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)

     def set_program_config(self, fleet_desc, program_id):
         for program_config in fleet_desc.trainer_param.program_config:
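
Taken together, the TrainerDesc refactor replaces constructor arguments with explicit setters; a minimal sketch of the new wiring, where ps_param stands in for a fleet description proto and DownpourSGD is the device-worker class from device_worker.py above:

# Sketch of the setter-based wiring introduced above.
trainer = DistMultiTrainer()
trainer.set_device_worker(DownpourSGD())   # device worker class, not the optimizer
trainer.set_fleet_desc(ps_param)           # placeholder: parameter-server description
trainer.gen_trainer_desc()                 # fills proto_desc via the device worker
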
@@ -20,13 +20,20 @@ class TrainerFactory(object):
         pass

     def create_trainer(self, opt_info=None):
+        trainer = None
+        device_worker = None
         if opt_info == None:
-            return MultiTrainer()
+            # default is MultiTrainer + Hogwild
+            trainer = MultiTrainer()
+            device_worker = Hogwild()
+            trainer.set_device_worker(device_worker)
+            trainer.gen_trainer_desc()
         else:
-            if opt_info["optimizer"] == "DownpourSGD":
-                trainer = DistMultiTrainer()
-                trainer.gen_trainer_desc(
-                    fleet_desc=opt_info["fleet"], worker="downpour")
-                return trainer
-            else:
-                print("Currently only support DownpourSGD")
+            trainer_class = opt_info["trainer"]
+            device_worker_class = opt_info["device_worker"]
+            trainer = globals()[trainer_class]()
+            device_worker = globals()[device_worker_class]()
+            trainer.set_device_worker(device_worker)
+            trainer.set_fleet_desc(opt_info["fleet_desc"])
+            trainer.gen_trainer_desc()
+        return trainer
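
A usage sketch of the factory after the refactor; opt_info is the dictionary produced by the DownpourSGD optimizer above, and calling with no arguments falls back to MultiTrainer plus Hogwild:

factory = TrainerFactory()
local_trainer = factory.create_trainer()          # default: MultiTrainer + Hogwild
dist_trainer = factory.create_trainer(opt_info)   # trainer/device worker resolved by name via globals()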