提交 d87ba58c 编写于 作者: D dongdaxiang

refine document of python API, make device_worker and trainer's API private

test=develop
上级 5687f234
......@@ -279,11 +279,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
total_time += timeline.ElapsedSec();
VLOG(3) << "push sparse and dense gradient done.";
int32_t tmp_push_dense_wait_times = -1;
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
......@@ -297,6 +294,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
if (need_to_push_sparse_) {
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
......@@ -311,6 +311,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
VLOG(3) << "going to increase thread version";
VLOG(3) << "push dense table id size: "
<< param_.program_config(0).push_dense_table_id_size();
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
......@@ -381,69 +384,78 @@ void DownpourWorker::TrainFiles() {
}
}
// push gradients here
for (size_t i = 0; i < param_.program_config(0).push_sparse_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
if (need_to_push_sparse_) {
// push gradients here
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
for (auto i : param_.sparse_table()) {
if (i.table_id() == tid) {
table = i;
break;
}
}
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
}
fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_);
}
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
}
VLOG(3) << "push dense gradient done.";
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
VLOG(3) << "push sparse and dense gradient done.";
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times);
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
if (push_dense_status_.size() >= push_dense_wait_times) {
for (auto& t : push_dense_status_) {
t.wait();
}
push_dense_status_.resize(0);
}
push_dense_status_.resize(0);
}
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
}
}
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
if (need_to_push_sparse_) {
VLOG(3) << "push sparse gradient done.";
int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_sparse_wait_times =
static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
}
push_sparse_status_.resize(0);
}
push_sparse_status_.resize(0);
}
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
}
}
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
if (need_to_push_dense_) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
}
PrintFetchVars();
......
......@@ -154,10 +154,8 @@ class AsyncExecutor(object):
with open("trainer_desc.proto", "w") as fout:
fout.write(trainer._desc())
# define a trainer and a device_worker here
self.executor.run_from_files(program_desc,
trainer._desc(), debug)
self.executor.run_from_files(program_desc, trainer._desc(), debug)
'''
def run(self,
program,
data_feed,
......@@ -228,8 +226,8 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num,
fetch_var_names, mode, debug, str(id(program_desc)))
'''
fetch_var_names, mode, debug,
str(id(program_desc)))
def download_data(self,
afs_path,
......
......@@ -19,7 +19,10 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class DeviceWorker(object):
"""
DeviceWorker is a abstract class, which generates worker desc.
This class is an inner class that we do computation logics within
the implementation. For example, execution of a program or a graph.
"""
def __init__(self):
"""
Init.
......@@ -27,10 +30,16 @@ class DeviceWorker(object):
self.program_ = None
self.infer_ = None
def set_infer(self, infer=False):
def _set_infer(self, infer=False):
"""
set inference flag for current device worker
Args:
infer(bool): whether to do inference
"""
self.infer_ = infer
def set_fleet_desc(self, fleet_desc):
def _set_fleet_desc(self, fleet_desc):
"""
Set fleet desc.
......@@ -39,7 +48,7 @@ class DeviceWorker(object):
"""
self.fleet_desc_ = fleet_desc
def set_program(self, program):
def _set_program(self, program):
"""
Set program.
......@@ -48,7 +57,7 @@ class DeviceWorker(object):
"""
self.program_ = program
def gen_worker_desc(self, trainer_desc):
def _gen_worker_desc(self, trainer_desc):
"""
Generator worker desc.
......@@ -65,13 +74,14 @@ class Hogwild(DeviceWorker):
Hogwild is a kind of SGD algorithm.
"""
def __init__(self):
"""
Init.
"""
super(Hogwild, self).__init__()
def gen_worker_desc(self, trainer_desc):
def _gen_worker_desc(self, trainer_desc):
"""
Generator worker desc, which device worker is HogwildWorker.
......@@ -85,13 +95,15 @@ class DownpourSGD(DeviceWorker):
"""
DownpourSGD is a kind of distributed SGD algorithm.
"""
def __init__(self):
"""
Init.
initialize downpourSGD device worker
"""
super(DownpourSGD, self).__init__()
def gen_worker_desc(self, trainer_desc):
def _gen_worker_desc(self, trainer_desc):
"""
Generator worker desc, which device worker is DownpourWorker.
......@@ -162,6 +174,6 @@ class DownpourSGD(DeviceWorker):
class DeviceWorkerFactory(object):
def create_device_worker(self, worker_type):
def _create_device_worker(self, worker_type):
classname = worker_type.capitalize()
return globals()[classname]()
......@@ -637,23 +637,23 @@ class Executor(object):
assert len(fetch_list) == len(fetch_info)
compiled = isinstance(program, compiler.CompiledProgram)
if not compiled:
trainer = TrainerFactory().create_trainer(program._fleet_opt)
trainer.set_program(program)
trainer = TrainerFactory()._create_trainer(program._fleet_opt)
trainer._set_program(program)
else:
trainer = TrainerFactory().create_trainer(
trainer = TrainerFactory()._create_trainer(
program.program._fleet_opt)
trainer.set_program(program.program)
trainer._set_program(program.program)
if thread <= 0:
if dataset.thread_num <= 0:
raise RuntimeError(
"You should set thread num first, either in Dataset"
"or in Executor.train_from_dataset")
else:
trainer.set_thread(dataset.thread_num)
trainer._set_thread(dataset.thread_num)
else:
trainer.set_thread(thread)
trainer.set_debug(debug)
trainer.set_fetch_var_and_info(fetch_list, fetch_info, print_period)
trainer._set_thread(thread)
trainer._set_debug(debug)
trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
return trainer
def infer_from_dataset(self,
......@@ -679,7 +679,7 @@ class Executor(object):
for each run. default is global_scope
thread(int): number of thread a user wants to run in this function. The actual number
of thread will be min(Dataset.thread_num, thread)
debug(bool): whether a user wants to run train_from_dataset
debug(bool): whether a user wants to run infer_from_dataset
fetch_list(Variable List): fetch variable list, each variable
will be printed during training
fetch_info(String List): print information for each variable
......@@ -711,8 +711,8 @@ class Executor(object):
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer.gen_trainer_desc()
trainer.set_infer(True)
trainer._gen_trainer_desc()
trainer._set_infer(True)
dataset._prepare_to_run()
if debug:
self._dump_debug_info(program=program, trainer=trainer)
......@@ -784,7 +784,7 @@ class Executor(object):
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer.gen_trainer_desc()
trainer._gen_trainer_desc()
dataset._prepare_to_run()
if debug:
self._dump_debug_info(program=program, trainer=trainer)
......
......@@ -37,32 +37,32 @@ class TrainerDesc(object):
self.program_ = None
self.infer_ = False
def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
def _set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
for i, v in enumerate(fetch_vars):
self.proto_desc.fetch_config.fetch_var_names.extend([v.name])
self.proto_desc.fetch_config.fetch_var_str_format.extend(
[fetch_info[i]])
self.proto_desc.fetch_config.print_period = print_period
def set_debug(self, debug):
def _set_debug(self, debug):
self.proto_desc.debug = debug
def set_thread(self, thread_num):
def _set_thread(self, thread_num):
self.proto_desc.thread_num = thread_num
def set_device_worker(self, device_worker):
def _set_device_worker(self, device_worker):
self.device_worker_ = device_worker
def set_infer(self, infer):
def _set_infer(self, infer):
self.infer_ = infer
def set_fleet_desc(self, fleet_desc):
def _set_fleet_desc(self, fleet_desc):
self.fleet_desc_ = fleet_desc
def gen_trainer_desc(self):
def _gen_trainer_desc(self):
pass
def set_program(self, program):
def _set_program(self, program):
self.program_ = program
def _desc(self):
......@@ -74,11 +74,11 @@ class MultiTrainer(TrainerDesc):
super(MultiTrainer, self).__init__()
pass
def set_program(self, program):
def _set_program(self, program):
super(MultiTrainer, self).set_program(program)
self.program_ = program
def gen_trainer_desc(self):
def _gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.set_infer(self.infer_)
......@@ -90,11 +90,11 @@ class DistMultiTrainer(TrainerDesc):
super(DistMultiTrainer, self).__init__()
pass
def set_program(self, program):
def _set_program(self, program):
super(DistMultiTrainer, self).set_program(program)
self.program_ = program
def gen_trainer_desc(self):
def _gen_trainer_desc(self):
super(DistMultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None:
......
......@@ -22,7 +22,7 @@ class TrainerFactory(object):
def __init__(self):
pass
def create_trainer(self, opt_info=None):
def _create_trainer(self, opt_info=None):
trainer = None
device_worker = None
if opt_info == None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册