Commit ff87698a authored by dongdaxiang

refactor downpour optimization

Parent b66f0074
@@ -61,8 +61,7 @@ class MultiTrainer : public TrainerBase {
public:
MultiTrainer() {}
virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc,
Dataset* data_set);
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
@@ -80,14 +79,12 @@ class DistMultiTrainer : public MultiTrainer {
public:
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc,
Dataset* data_set);
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Finalize();
protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
};
} // namespace framework
......
@@ -29,7 +29,7 @@ class Hogwild(DeviceWorker):
trainer_desc.device_worker_name = "HogwildWorker"
class Downpour(DeviceWorker):
class DownpourSGD(DeviceWorker):
def __init__(self):
super(Downpour, self).__init__()
@@ -55,6 +55,7 @@ class Downpour(DeviceWorker):
sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param[
0].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
# TODO(guru4elephant): hard code here, need to improve
sparse_table.label_var_name = "click"
dense_table = downpour.dense_table.add()
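Note on the dimension arithmetic above: emb_dim is derived from the accessor's fea_dim minus 2, and fea_dim is then emb_dim plus 2 again. A minimal sketch of that relationship, assuming (not stated in the diff) that the two reserved slots carry per-feature statistics such as show/click counts; the numeric value is illustrative only:

    # Illustrative values; only the "+2"/"-2" relationship comes from the diff.
    accessor_fea_dim = 11              # hypothetical value read from a fleet_desc
    emb_dim = accessor_fea_dim - 2     # embedding width used by the worker -> 9
    fea_dim = emb_dim + 2              # full pulled/pushed feature width   -> 11
    assert fea_dim == accessor_fea_dim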
@@ -70,6 +71,4 @@ class Downpour(DeviceWorker):
class DeviceWorkerFactory(object):
def create_device_worker(self, worker_type):
classname = worker_type.capitalize()
print("------------")
print(classname)
return globals()[classname]()
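The factory above resolves a device worker purely by name: the worker_type string is capitalized and looked up in the module's globals(). A self-contained sketch of that pattern with a stand-in class (not the real DeviceWorker hierarchy):

    class Hogwild(object):
        """Stand-in for the real Hogwild device worker."""
        pass

    def create_device_worker(worker_type):
        # Same lookup pattern as DeviceWorkerFactory.create_device_worker:
        # normalize the first letter, then resolve the class from globals().
        classname = worker_type.capitalize()
        return globals()[classname]()

    worker = create_device_worker("hogwild")   # "hogwild".capitalize() == "Hogwild"
    print(type(worker).__name__)               # Hogwild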
@@ -142,4 +142,18 @@ class DownpourSGD(object):
# currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
return [ps_param, worker_skipped_ops]
# all fleet operations should be defined in operators in the future
# we want to return an object here containing:
# 1) worker execution strategy
# 2) pserver execution strategy
# 3) fleet configurations
# 4) skipped operators in runtime
# 5) distributed optimization
opt_info = {}
opt_info["trainer"] = "DistMultiTrainer"
opt_info["device_worker"] = "DownpourSGD"
opt_info["optimizer"] = "DownpourSGD"
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
return opt_info
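The dictionary returned above is the hand-off point between the distributed optimizer and the training runtime. A sketch of its shape with placeholder values; only the keys and the string values shown here appear in the diff:

    opt_info = {
        "trainer": "DistMultiTrainer",       # trainer class to instantiate
        "device_worker": "DownpourSGD",      # device worker class to attach
        "optimizer": "DownpourSGD",          # optimizer that produced this dict
        "fleet_desc": None,                  # placeholder; ps_param protobuf in the real code
        "worker_skipped_ops": ["lookup_table", "lookup_table_grad"],
    }
    # Downstream (see TrainerFactory further below) the dict drives construction:
    #   trainer = globals()[opt_info["trainer"]]()
    #   worker  = globals()[opt_info["device_worker"]]()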
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import sys
from .. import core
from . import ps_instance
@@ -33,9 +34,15 @@ class Fleet(object):
self.instance_.barrier_all()
self.instance.finalize()
def init_pserver(self, dist_desc):
self.dist_desc_str_ = text_format.MessageToString(dist_desc)
self.dist_desc = dist_desc
def init_pserver(self, opt_info):
if "fleet_desc" in opt_info:
self.dist_desc_str_ = text_format.MessageToString(opt_info[
"fleet_desc"])
self.dist_desc_ = opt_info["fleet_desc"]
else:
print(
"You should run distributed optimization to get opt_info first")
sys.exit(-1)
self.fleet_.init_server(self.dist_desc_str_)
ip = self.fleet_.start_server()
self.instance_.set_ip(ip)
@@ -44,10 +51,15 @@ class Fleet(object):
self.fleet.gather_servers(ips, self.instance_.get_node_cnt())
self.instance_.barrier_all()
def init_worker(self, dist_desc):
self.dist_desc_str_ = text_format.MessageToString(dist_desc)
self.dist_desc_ = dist_desc
def init_worker(self, opt_info):
if "fleet_desc" in opt_info:
self.dist_desc_str_ = text_format.MessageToString(opt_info[
"fleet_desc"])
self.dist_desc_ = opt_info["fleet_desc"]
else:
print(
"You should run distributed optimization to get opt_info first")
sys.exit(-1)
self.instance_.barrier_all()
ips = self.instance.gather_ips()
self.fleet_.init_worker(self.dist_desc_str_, ips,
......
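With this change both init_pserver and init_worker take the whole opt_info dict instead of a bare fleet_desc. A usage sketch; the helper name and the role flag are assumptions for illustration, only the init_* calls and their opt_info argument come from the diff:

    def launch_node(fleet, opt_info, is_pserver):
        # Hypothetical driver helper around the Fleet instance shown above.
        if is_pserver:
            fleet.init_pserver(opt_info)   # starts a server and registers its IP
        else:
            fleet.init_worker(opt_info)    # gathers server IPs, connects the worker
        # If opt_info lacks "fleet_desc", either call prints an error and
        # exits with -1, per the guards added in this commit.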
@@ -630,6 +630,7 @@ class Executor(object):
trainer.set_thread(dataset.thread_num)
else:
trainer.set_thread(thread)
trainer.gen_trainer_desc()
dataset._prepare_to_run()
print("run_from_dataset called")
self._default_executor.run_from_dataset(program.desc, scope,
......
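The executor hunk adds a gen_trainer_desc() call before handing off to the C++ runtime. A condensed sketch of that preparation order; the function name and the thread condition are assumptions, and the TrainerFactory import path is assumed from this commit's python/paddle/fluid layout:

    from paddle.fluid.trainer_factory import TrainerFactory   # path assumed

    def prepare_trainer(opt_info, dataset, thread):
        trainer = TrainerFactory().create_trainer(opt_info)   # opt_info may be None
        if thread <= 0:                                       # condition assumed
            trainer.set_thread(dataset.thread_num)            # default to the dataset's threads
        else:
            trainer.set_thread(thread)
        trainer.gen_trainer_desc()     # new in this commit: build the proto here
        dataset._prepare_to_run()
        return trainer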
@@ -32,19 +32,19 @@ class TrainerDesc(object):
import multiprocessing as mp
# set default thread num == cpu count
self.proto_desc.thread_num = mp.cpu_count()
self.fleet_desc_ = None
self.device_worker_ = None
def set_thread(self, thread_num):
self.proto_desc.thread_num = thread_num
def set_filelist(self, filelist):
self.proto_desc.filelist.extend(filelist)
self.proto_desc.thread_num = min(
len(filelist), self.proto_desc.thread_num)
def set_device_worker(self, device_worker):
self.device_worker_ = device_worker
def set_data_feed(self, datafeed):
self.proto_desc.data_desc.CopyFrom(datafeed.proto_desc)
def set_fleet_desc(self, fleet_desc):
self.fleet_desc_ = fleet_desc
def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker=None):
def gen_trainer_desc(self):
pass
def _desc(self):
@@ -52,17 +52,14 @@
class MultiTrainer(TrainerDesc):
def __init__(self, dataset=None, worker="Hogwild"):
def __init__(self):
super(MultiTrainer, self).__init__()
if worker == "Hogwild":
self.proto_desc.device_worker_name = worker + "Worker"
self.proto_desc.class_name = "MultiTrainer"
else:
raise ValueError('ValueError: DeviceWorker %s '
'is not supported in MultiTrainer' % worker)
pass
def gen_trainer_desc(self, dataset=None, fleet_desc=None, worker="Hogwild"):
super(MultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
def gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.gen_worker_desc(self.proto_desc, fleet_desc_)
class DistMultiTrainer(TrainerDesc):
@@ -70,14 +67,10 @@ class DistMultiTrainer(TrainerDesc):
super(DistMultiTrainer, self).__init__()
pass
def gen_trainer_desc(self, dataset=None, fleet_desc=None,
worker="Downpour"):
super(DistMultiTrainer, self).gen_trainer_desc(fleet_desc, worker)
def gen_trainer_desc(self):
super(DistMultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
self.proto_desc.data_desc.CopyFrom(dataset.proto_desc)
worker_builder = DeviceWorkerFactory()
device_worker = worker_builder.create_device_worker("Downpour")
device_worker.gen_worker_desc(self.proto_desc, fleet_desc)
self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
def set_program_config(self, fleet_desc, program_id):
for program_config in fleet_desc.trainer_param.program_config:
......
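After this refactor, gen_trainer_desc() takes no arguments: the device worker and fleet_desc are injected through setters and the proto is generated last. A minimal wiring sketch; the import paths are assumed from this commit's python/paddle/fluid layout and the helper name is hypothetical:

    from paddle.fluid.trainer_desc import DistMultiTrainer    # paths assumed
    from paddle.fluid.device_worker import DownpourSGD

    def build_dist_trainer(ps_param):
        trainer = DistMultiTrainer()
        trainer.set_device_worker(DownpourSGD())   # stored as device_worker_
        trainer.set_fleet_desc(ps_param)           # stored as fleet_desc_
        trainer.gen_trainer_desc()
        # gen_trainer_desc() now only sets proto_desc.class_name and delegates:
        #   self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
        return trainer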
@@ -20,13 +20,20 @@ class TrainerFactory(object):
pass
def create_trainer(self, opt_info=None):
trainer = None
device_worker = None
if opt_info == None:
return MultiTrainer()
# default is MultiTrainer + Hogwild
trainer = MultiTrainer()
device_worker = Hogwild()
trainer.set_device_worker(device_worker)
trainer.gen_trainer_desc()
else:
if opt_info["optimizer"] == "DownpourSGD":
trainer = DistMultiTrainer()
trainer.gen_trainer_desc(
fleet_desc=opt_info["fleet"], worker="downpour")
return trainer
else:
print("Currently only support DownpourSGD")
trainer_class = opt_info["trainer"]
device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]()
trainer.set_device_worker(device_worker)
trainer.set_fleet_desc(opt_info["fleet_desc"])
trainer.gen_trainer_desc(fleet_desc=opt_info["fleet_desc"])
return trainer
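End to end, the factory replaces the old hard-coded branches with a lookup driven by opt_info. A self-contained sketch that mirrors only the selection logic, using stand-in classes rather than the real trainer and device worker types:

    class MultiTrainer(object):      pass   # stand-ins, not the real classes
    class DistMultiTrainer(object):  pass
    class Hogwild(object):           pass
    class DownpourSGD(object):       pass

    def create_trainer_sketch(opt_info=None):
        # Mirrors the branch structure of TrainerFactory.create_trainer above:
        # no opt_info -> default MultiTrainer + Hogwild; otherwise the classes
        # named in opt_info are instantiated by name.
        if opt_info is None:
            return MultiTrainer(), Hogwild()
        return (globals()[opt_info["trainer"]](),
                globals()[opt_info["device_worker"]]())

    print(create_trainer_sketch({"trainer": "DistMultiTrainer",
                                 "device_worker": "DownpourSGD"}))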