Commit d87ba58c authored by dongdaxiang

refine document of python API, make device_worker and trainer's API private

test=develop
Parent 5687f234
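This commit makes the trainer and device-worker setters private (underscore-prefixed) so they read as framework-internal plumbing rather than user-facing API. Below is a minimal sketch of how the renamed calls chain together, based only on the names visible in the diff that follows; the import path, function wrapper, and argument values are illustrative assumptions, not part of the commit.

```python
# Sketch of the privatized call chain; method names mirror the diff below,
# while the import path and surrounding control flow are assumptions.
from paddle.fluid.trainer_factory import TrainerFactory


def prepare_trainer(program, thread, debug, fetch_list, fetch_info, print_period):
    trainer = TrainerFactory()._create_trainer(program._fleet_opt)  # was create_trainer()
    trainer._set_program(program)                                   # was set_program()
    trainer._set_thread(thread)                                     # was set_thread()
    trainer._set_debug(debug)                                       # was set_debug()
    trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
    trainer._gen_trainer_desc()                                     # was gen_trainer_desc()
    return trainer
```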
@@ -279,11 +279,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
       total_time += timeline.ElapsedSec();
       VLOG(3) << "push sparse and dense gradient done.";
       int32_t tmp_push_dense_wait_times = -1;
-      int32_t tmp_push_sparse_wait_times = -1;
       static uint32_t push_dense_wait_times =
           static_cast<uint32_t>(tmp_push_dense_wait_times);
-      static uint32_t push_sparse_wait_times =
-          static_cast<uint32_t>(tmp_push_sparse_wait_times);
       if (push_dense_status_.size() >= push_dense_wait_times) {
         for (auto& t : push_dense_status_) {
           t.wait();
@@ -297,6 +294,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
     }

     if (need_to_push_sparse_) {
+      int32_t tmp_push_sparse_wait_times = -1;
+      static uint32_t push_sparse_wait_times =
+          static_cast<uint32_t>(tmp_push_sparse_wait_times);
       if (push_sparse_status_.size() >= push_sparse_wait_times) {
         for (auto& t : push_sparse_status_) {
           t.wait();
@@ -311,6 +311,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
       VLOG(3) << "going to increase thread version";
       VLOG(3) << "push dense table id size: "
               << param_.program_config(0).push_dense_table_id_size();
+    }
+
+    if (need_to_push_dense_) {
       for (size_t i = 0;
            i < param_.program_config(0).push_dense_table_id_size(); ++i) {
         uint64_t tid = static_cast<uint64_t>(
@@ -381,69 +384,78 @@ void DownpourWorker::TrainFiles() {
       }
     }
-    // push gradients here
-    for (size_t i = 0; i < param_.program_config(0).push_sparse_table_id_size();
-         ++i) {
-      uint64_t tid = static_cast<uint64_t>(
-          param_.program_config(0).push_sparse_table_id(i));
-      TableParameter table;
-      for (auto i : param_.sparse_table()) {
-        if (i.table_id() == tid) {
-          table = i;
-          break;
-        }
-      }
-      fleet_ptr_->PushSparseVarsWithLabelAsync(
-          *thread_scope_, tid, features_[tid], feature_labels_[tid],
-          sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
-          &feature_grads_[tid], &push_sparse_status_);
-    }
-    for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
-         ++i) {
-      uint64_t tid = static_cast<uint64_t>(
-          param_.program_config(0).push_dense_table_id(i));
-      fleet_ptr_->PushDenseVarsAsync(
-          *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
-    }
-    VLOG(3) << "push sparse and dense gradient done.";
-    // the following code should be more precise and clean
-    // TODO(guru4elephant)
-    int32_t tmp_push_dense_wait_times = -1;
-    int32_t tmp_push_sparse_wait_times = -1;
-    static uint32_t push_dense_wait_times =
-        static_cast<uint32_t>(tmp_push_dense_wait_times);
-    static uint32_t push_sparse_wait_times =
-        static_cast<uint32_t>(tmp_push_sparse_wait_times);
-    if (push_dense_status_.size() >= push_dense_wait_times) {
-      for (auto& t : push_dense_status_) {
-        t.wait();
-      }
-      push_dense_status_.resize(0);
-    }
-    if (tmp_push_dense_wait_times == -1) {
-      push_dense_status_.resize(0);
-    }
-    if (push_sparse_status_.size() >= push_sparse_wait_times) {
-      for (auto& t : push_sparse_status_) {
-        t.wait();
-      }
-      push_sparse_status_.resize(0);
-    }
-    if (tmp_push_sparse_wait_times == -1) {
-      push_sparse_status_.resize(0);
-    }
-    for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
-         ++i) {
-      uint64_t tid = static_cast<uint64_t>(
-          param_.program_config(0).push_dense_table_id(i));
-      pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
-    }
+    if (need_to_push_sparse_) {
+      // push gradients here
+      for (size_t i = 0;
+           i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_sparse_table_id(i));
+        TableParameter table;
+        for (auto i : param_.sparse_table()) {
+          if (i.table_id() == tid) {
+            table = i;
+            break;
+          }
+        }
+        fleet_ptr_->PushSparseVarsWithLabelAsync(
+            *thread_scope_, tid, features_[tid], feature_labels_[tid],
+            sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
+            &feature_grads_[tid], &push_sparse_status_);
+      }
+    }
+    if (need_to_push_dense_) {
+      for (size_t i = 0;
+           i < param_.program_config(0).push_dense_table_id_size(); ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_dense_table_id(i));
+        fleet_ptr_->PushDenseVarsAsync(
+            *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
+      }
+      VLOG(3) << "push dense gradient done.";
+      // the following code should be more precise and clean
+      // TODO(guru4elephant)
+      int32_t tmp_push_dense_wait_times = -1;
+      static uint32_t push_dense_wait_times =
+          static_cast<uint32_t>(tmp_push_dense_wait_times);
+      if (push_dense_status_.size() >= push_dense_wait_times) {
+        for (auto& t : push_dense_status_) {
+          t.wait();
+        }
+        push_dense_status_.resize(0);
+      }
+      if (tmp_push_dense_wait_times == -1) {
+        push_dense_status_.resize(0);
+      }
+    }
+    if (need_to_push_sparse_) {
+      VLOG(3) << "push sparse gradient done.";
+      int32_t tmp_push_sparse_wait_times = -1;
+      static uint32_t push_sparse_wait_times =
+          static_cast<uint32_t>(tmp_push_sparse_wait_times);
+      if (push_sparse_status_.size() >= push_sparse_wait_times) {
+        for (auto& t : push_sparse_status_) {
+          t.wait();
+        }
+        push_sparse_status_.resize(0);
+      }
+      if (tmp_push_sparse_wait_times == -1) {
+        push_sparse_status_.resize(0);
+      }
+    }
+    if (need_to_push_dense_) {
+      for (size_t i = 0;
+           i < param_.program_config(0).push_dense_table_id_size(); ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_dense_table_id(i));
+        pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
+      }
+    }
     PrintFetchVars();
...
@@ -154,10 +154,8 @@ class AsyncExecutor(object):
         with open("trainer_desc.proto", "w") as fout:
             fout.write(trainer._desc())
         # define a trainer and a device_worker here
-        self.executor.run_from_files(program_desc,
-                                     trainer._desc(), debug)
+        self.executor.run_from_files(program_desc, trainer._desc(), debug)

-    '''
     def run(self,
             program,
             data_feed,
@@ -228,8 +226,8 @@ class AsyncExecutor(object):
         self.executor.run_from_files(program_desc,
                                      data_feed.desc(), filelist, thread_num,
-                                     fetch_var_names, mode, debug, str(id(program_desc)))
-    '''
+                                     fetch_var_names, mode, debug,
+                                     str(id(program_desc)))

     def download_data(self,
                       afs_path,
...
@@ -19,7 +19,10 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
 class DeviceWorker(object):
     """
     DeviceWorker is a abstract class, which generates worker desc.
+    This class is an inner class that we do computation logics within
+    the implementation. For example, execution of a program or a graph.
     """

     def __init__(self):
         """
         Init.
@@ -27,10 +30,16 @@ class DeviceWorker(object):
         self.program_ = None
         self.infer_ = None

-    def set_infer(self, infer=False):
+    def _set_infer(self, infer=False):
+        """
+        set inference flag for current device worker
+        Args:
+            infer(bool): whether to do inference
+        """
         self.infer_ = infer

-    def set_fleet_desc(self, fleet_desc):
+    def _set_fleet_desc(self, fleet_desc):
         """
         Set fleet desc.
@@ -39,7 +48,7 @@ class DeviceWorker(object):
         """
         self.fleet_desc_ = fleet_desc

-    def set_program(self, program):
+    def _set_program(self, program):
         """
         Set program.
@@ -48,7 +57,7 @@ class DeviceWorker(object):
         """
         self.program_ = program

-    def gen_worker_desc(self, trainer_desc):
+    def _gen_worker_desc(self, trainer_desc):
         """
         Generator worker desc.
@@ -65,13 +74,14 @@ class Hogwild(DeviceWorker):
     Hogwild is a kind of SGD algorithm.
     """

     def __init__(self):
         """
         Init.
         """
         super(Hogwild, self).__init__()

-    def gen_worker_desc(self, trainer_desc):
+    def _gen_worker_desc(self, trainer_desc):
         """
         Generator worker desc, which device worker is HogwildWorker.
@@ -85,13 +95,15 @@ class DownpourSGD(DeviceWorker):
     """
     DownpourSGD is a kind of distributed SGD algorithm.
     """

     def __init__(self):
         """
         Init.
+        initialize downpourSGD device worker
         """
         super(DownpourSGD, self).__init__()

-    def gen_worker_desc(self, trainer_desc):
+    def _gen_worker_desc(self, trainer_desc):
         """
         Generator worker desc, which device worker is DownpourWorker.
@@ -162,6 +174,6 @@ class DownpourSGD(DeviceWorker):
 class DeviceWorkerFactory(object):
-    def create_device_worker(self, worker_type):
+    def _create_device_worker(self, worker_type):
         classname = worker_type.capitalize()
         return globals()[classname]()
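For reference, the factory above resolves a worker class by capitalizing the worker-type string, and a trainer then drives the renamed underscore-prefixed hooks. A hedged sketch of that flow, assuming `paddle.fluid.device_worker` is the importable module for the classes shown in this diff; the commented arguments are placeholders, not values defined here.

```python
# Sketch only: method names come from the diff above; the module path and the
# fleet_desc / trainer_desc arguments are assumptions for illustration.
from paddle.fluid.device_worker import DeviceWorkerFactory

worker = DeviceWorkerFactory()._create_device_worker("hogwild")  # capitalize() -> Hogwild()
worker._set_infer(False)                  # was set_infer()
# worker._set_fleet_desc(fleet_desc)      # needed by DownpourSGD
# worker._gen_worker_desc(trainer_desc)   # fills the device-worker part of the trainer desc
```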
@@ -637,23 +637,23 @@ class Executor(object):
         assert len(fetch_list) == len(fetch_info)
         compiled = isinstance(program, compiler.CompiledProgram)
         if not compiled:
-            trainer = TrainerFactory().create_trainer(program._fleet_opt)
-            trainer.set_program(program)
+            trainer = TrainerFactory()._create_trainer(program._fleet_opt)
+            trainer._set_program(program)
         else:
-            trainer = TrainerFactory().create_trainer(
+            trainer = TrainerFactory()._create_trainer(
                 program.program._fleet_opt)
-            trainer.set_program(program.program)
+            trainer._set_program(program.program)
         if thread <= 0:
             if dataset.thread_num <= 0:
                 raise RuntimeError(
                     "You should set thread num first, either in Dataset"
                     "or in Executor.train_from_dataset")
             else:
-                trainer.set_thread(dataset.thread_num)
+                trainer._set_thread(dataset.thread_num)
         else:
-            trainer.set_thread(thread)
-        trainer.set_debug(debug)
-        trainer.set_fetch_var_and_info(fetch_list, fetch_info, print_period)
+            trainer._set_thread(thread)
+        trainer._set_debug(debug)
+        trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
         return trainer

     def infer_from_dataset(self,
@@ -679,7 +679,7 @@ class Executor(object):
                 for each run. default is global_scope
             thread(int): number of thread a user wants to run in this function. The actual number
                 of thread will be min(Dataset.thread_num, thread)
-            debug(bool): whether a user wants to run train_from_dataset
+            debug(bool): whether a user wants to run infer_from_dataset
             fetch_list(Variable List): fetch variable list, each variable
                 will be printed during training
             fetch_info(String List): print information for each variable
@@ -711,8 +711,8 @@ class Executor(object):
             fetch_list=fetch_list,
             fetch_info=fetch_info,
             print_period=print_period)
-        trainer.gen_trainer_desc()
-        trainer.set_infer(True)
+        trainer._gen_trainer_desc()
+        trainer._set_infer(True)
         dataset._prepare_to_run()
         if debug:
             self._dump_debug_info(program=program, trainer=trainer)
@@ -784,7 +784,7 @@ class Executor(object):
             fetch_list=fetch_list,
             fetch_info=fetch_info,
             print_period=print_period)
-        trainer.gen_trainer_desc()
+        trainer._gen_trainer_desc()
         dataset._prepare_to_run()
         if debug:
             self._dump_debug_info(program=program, trainer=trainer)
...
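The private helpers above are driven by the public entry points `Executor.train_from_dataset` and `Executor.infer_from_dataset`. A hypothetical end-to-end call, assuming the fluid `DatasetFactory` API of this period; the toy network, file list, and hyper-parameters are illustrative only and do not come from this commit.

```python
import paddle.fluid as fluid

# toy network; names and shapes are illustrative
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# dataset setup assumes the fluid Dataset API; train_data.txt is a placeholder file
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y])
dataset.set_thread(2)
dataset.set_batch_size(32)
dataset.set_filelist(["train_data.txt"])

# keyword names mirror the docstring shown in the diff above
exe.train_from_dataset(
    program=fluid.default_main_program(),
    dataset=dataset,
    thread=2,
    debug=False,
    fetch_list=[loss],
    fetch_info=["loss"],
    print_period=100)
```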
@@ -37,32 +37,32 @@ class TrainerDesc(object):
         self.program_ = None
         self.infer_ = False

-    def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
+    def _set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
         for i, v in enumerate(fetch_vars):
             self.proto_desc.fetch_config.fetch_var_names.extend([v.name])
             self.proto_desc.fetch_config.fetch_var_str_format.extend(
                 [fetch_info[i]])
             self.proto_desc.fetch_config.print_period = print_period

-    def set_debug(self, debug):
+    def _set_debug(self, debug):
         self.proto_desc.debug = debug

-    def set_thread(self, thread_num):
+    def _set_thread(self, thread_num):
         self.proto_desc.thread_num = thread_num

-    def set_device_worker(self, device_worker):
+    def _set_device_worker(self, device_worker):
         self.device_worker_ = device_worker

-    def set_infer(self, infer):
+    def _set_infer(self, infer):
         self.infer_ = infer

-    def set_fleet_desc(self, fleet_desc):
+    def _set_fleet_desc(self, fleet_desc):
         self.fleet_desc_ = fleet_desc

-    def gen_trainer_desc(self):
+    def _gen_trainer_desc(self):
         pass

-    def set_program(self, program):
+    def _set_program(self, program):
         self.program_ = program

     def _desc(self):
@@ -74,11 +74,11 @@ class MultiTrainer(TrainerDesc):
         super(MultiTrainer, self).__init__()
         pass

-    def set_program(self, program):
+    def _set_program(self, program):
         super(MultiTrainer, self).set_program(program)
         self.program_ = program

-    def gen_trainer_desc(self):
+    def _gen_trainer_desc(self):
         super(MultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "MultiTrainer"
         self.device_worker_.set_infer(self.infer_)
@@ -90,11 +90,11 @@ class DistMultiTrainer(TrainerDesc):
         super(DistMultiTrainer, self).__init__()
         pass

-    def set_program(self, program):
+    def _set_program(self, program):
         super(DistMultiTrainer, self).set_program(program)
         self.program_ = program

-    def gen_trainer_desc(self):
+    def _gen_trainer_desc(self):
         super(DistMultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
         if self.program_ == None:
...
@@ -22,7 +22,7 @@ class TrainerFactory(object):
     def __init__(self):
         pass

-    def create_trainer(self, opt_info=None):
+    def _create_trainer(self, opt_info=None):
         trainer = None
         device_worker = None
         if opt_info == None: