Commit 6af697ad authored by D dongdaxiang

add trainfileswithprofiler for downpour worker

Parent 2644b886
@@ -53,6 +53,18 @@ void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
   VLOG(3) << "init other env done.";
 }
 
+void DistMultiTrainer::Run() {
+  for (int thidx = 0; thidx < thread_num_; ++thidx) {
+    if (!debug_) {
+      threads_.push_back(
+          std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
+    } else {
+      threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
+                                     workers_[thidx].get()));
+    }
+  }
+}
+
 void DistMultiTrainer::Finalize() {
   for (auto& th : threads_) {
     th.join();
...
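The new Run() dispatches every worker thread to either the plain training loop or the profiled one, keyed on the trainer's debug_ flag, so profiling costs nothing unless it is switched on. A minimal Python sketch of the same dispatch pattern (the Worker class and method names here are illustrative, not Paddle API):

import threading

class Worker(object):
    def train_files(self):
        pass  # normal per-thread training loop

    def train_files_with_profiler(self):
        pass  # same loop, plus per-op timing

def run(workers, debug=False):
    # Choose the profiled entry point only when debug is on.
    threads = []
    for w in workers:
        target = w.train_files_with_profiler if debug else w.train_files
        t = threading.Thread(target=target)
        t.start()
        threads.append(t)
    for t in threads:
        t.join()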
@@ -82,14 +82,10 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
   auto& feature = features_[table_id];
   auto& feature_label = feature_labels_[table_id];
   feature_label.resize(feature.size());
-  VLOG(3) << "going to get label_var_name " << label_var_name_[table_id];
   Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
-  VLOG(3) << "going to get tensor";
   LoDTensor* tensor = var->GetMutable<LoDTensor>();
-  VLOG(3) << "going to get ptr";
   int64_t* label_ptr = tensor->data<int64_t>();
-  VLOG(3) << "lele";
   int global_index = 0;
   for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
     VLOG(3) << "sparse_key_names_[" << i
@@ -98,7 +94,6 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
     LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
     int64_t* ids = tensor->data<int64_t>();
     int fea_idx = 0;
-    VLOG(3) << "Haha";
     // tensor->lod()[0].size() == batch_size + 1
     for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
       for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
@@ -110,7 +105,6 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
             static_cast<float>(label_ptr[lod_idx - 1]);
       }
     }
-    VLOG(3) << "EE";
   }
   CHECK(global_index == feature.size())
       << "expect fea info size:" << feature.size() << " real:" << global_index;
@@ -163,6 +157,174 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
 void DownpourWorker::TrainFilesWithProfiler() {
   VLOG(3) << "Begin to train files with profiler";
   platform::SetNumThreads(1);
+  device_reader_->Start();
+  std::vector<double> op_total_time;
+  std::vector<std::string> op_name;
+  for (auto& op : ops_) {
+    bool need_skip = false;
+    for (auto t = 0u; t < skip_ops_.size(); ++t) {
+      if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+        need_skip = true;
+        break;
+      }
+    }
+    if (!need_skip) {
+      op_name.push_back(op->Type());
+    }
+  }
+  VLOG(3) << "op name size: " << op_name.size();
+  op_total_time.resize(op_name.size());
+  for (size_t i = 0; i < op_total_time.size(); ++i) {
+    op_total_time[i] = 0.0;
+  }
+  platform::Timer timeline;
+  double total_time = 0.0;
+  double read_time = 0.0;
+  double pull_sparse_time = 0.0;
+  double collect_label_time = 0.0;
+  double fill_sparse_time = 0.0;
+  double push_sparse_time = 0.0;
+  double push_dense_time = 0.0;
+  int cur_batch;
+  int batch_cnt = 0;
+  timeline.Start();
+  while ((cur_batch = device_reader_->Next()) > 0) {
+    timeline.Pause();
+    read_time += timeline.ElapsedSec();
+    total_time += timeline.ElapsedSec();
+    VLOG(3) << "program config size: " << param_.program_config_size();
+    for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
+         ++i) {
+      uint64_t tid = static_cast<uint64_t>(
+          param_.program_config(0).pull_sparse_table_id(i));
+      TableParameter table;
+      for (auto i : param_.sparse_table()) {
+        if (i.table_id() == tid) {
+          table = i;
+          break;
+        }
+      }
+      timeline.Start();
+      fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
+                                     sparse_key_names_[tid], &features_[tid],
+                                     &feature_values_[tid], table.fea_dim());
+      timeline.Pause();
+      pull_sparse_time += timeline.ElapsedSec();
+      timeline.Start();
+      CollectLabelInfo(i);
+      timeline.Pause();
+      collect_label_time += timeline.ElapsedSec();
+      timeline.Start();
+      FillSparseValue(i);
+      timeline.Pause();
+      fill_sparse_time += timeline.ElapsedSec();
+    }
+    VLOG(3) << "Fill sparse value for all sparse table done.";
+    int run_op_idx = 0;
+    for (auto& op : ops_) {
+      bool need_skip = false;
+      for (auto t = 0u; t < skip_ops_.size(); ++t) {
+        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+          need_skip = true;
+          break;
+        }
+      }
+      if (!need_skip) {
+        timeline.Start();
+        op->Run(*thread_scope_, place_);
+        timeline.Pause();
+        op_total_time[run_op_idx++] += timeline.ElapsedSec();
+        total_time += timeline.ElapsedSec();
+      }
+    }
+    for (size_t i = 0; i < param_.program_config(0).push_sparse_table_id_size();
+         ++i) {
+      uint64_t tid = static_cast<uint64_t>(
+          param_.program_config(0).push_sparse_table_id(i));
+      TableParameter table;
+      for (auto i : param_.sparse_table()) {
+        if (i.table_id() == tid) {
+          table = i;
+          break;
+        }
+      }
+      timeline.Start();
+      fleet_ptr_->PushSparseVarsWithLabelAsync(
+          *thread_scope_, tid, features_[tid], feature_labels_[tid],
+          sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
+          &feature_grads_[tid], &push_sparse_status_);
+      timeline.Pause();
+      push_sparse_time += timeline.ElapsedSec();
+    }
+    timeline.Start();
+    for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
+         ++i) {
+      uint64_t tid = static_cast<uint64_t>(
+          param_.program_config(0).push_dense_table_id(i));
+      fleet_ptr_->PushDenseVarsAsync(
+          *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
+    }
+    timeline.Pause();
+    push_dense_time += timeline.ElapsedSec();
+    VLOG(3) << "push sparse and dense gradient done.";
+    int32_t tmp_push_dense_wait_times = -1;
+    int32_t tmp_push_sparse_wait_times = -1;
+    static uint32_t push_dense_wait_times =
+        static_cast<uint32_t>(tmp_push_dense_wait_times);
+    static uint32_t push_sparse_wait_times =
+        static_cast<uint32_t>(tmp_push_sparse_wait_times);
+    if (push_dense_status_.size() >= push_dense_wait_times) {
+      for (auto& t : push_dense_status_) {
+        t.wait();
+      }
+      push_dense_status_.resize(0);
+    }
+    if (tmp_push_dense_wait_times == -1) {
+      push_dense_status_.resize(0);
+    }
+    if (push_sparse_status_.size() >= push_sparse_wait_times) {
+      for (auto& t : push_sparse_status_) {
+        t.wait();
+      }
+      push_sparse_status_.resize(0);
+    }
+    if (tmp_push_sparse_wait_times == -1) {
+      push_sparse_status_.resize(0);
+    }
+    VLOG(3) << "going to increase thread version";
+    VLOG(3) << "push dense table id size: "
+            << param_.program_config(0).push_dense_table_id_size();
+    for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
+         ++i) {
+      uint64_t tid = static_cast<uint64_t>(
+          param_.program_config(0).push_dense_table_id(i));
+      pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
+    }
+    thread_scope_->DropKids();
+    ++batch_cnt;
+    if (thread_id_ == 0) {
+      // should be configured here
+      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
+        for (size_t i = 0; i < op_total_time.size(); ++i) {
+          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
+                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
+        }
+        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
+        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
+      }
+    }
+    timeline.Start();
+  }
+}
 void DownpourWorker::TrainFiles() {
...
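The profiled loop follows one pattern throughout: Start() before a stage, Pause() after it, add ElapsedSec() into that stage's accumulator, and restart the timer before the next read so the reader's time is measured too; thread 0 prints per-op means every 100 batches. A condensed Python sketch of the same bookkeeping (time.perf_counter stands in for platform::Timer; the stage names are illustrative):

import time
from collections import defaultdict

def train_with_profiler(reader, stages):
    # stages: list of (name, fn) pairs run once per batch, in order.
    stage_time = defaultdict(float)
    total_time = 0.0
    batch_cnt = 0
    start = time.perf_counter()
    for batch in reader:
        read_sec = time.perf_counter() - start
        stage_time["read"] += read_sec
        total_time += read_sec
        for name, fn in stages:
            t0 = time.perf_counter()
            fn(batch)
            elapsed = time.perf_counter() - t0
            stage_time[name] += elapsed
            total_time += elapsed
        batch_cnt += 1
        if batch_cnt % 100 == 0:
            for name, sec in stage_time.items():
                print("%s mean time: %fs" % (name, sec / batch_cnt))
            print("IO percent: %f" % (stage_time["read"] / total_time * 100))
        start = time.perf_counter()  # restart before timing the next read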
@@ -90,7 +90,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
   int batch_cnt = 0;
   timeline.Start();
   while ((cur_batch = device_reader_->Next()) > 0) {
-    LOG(WARNING) << "read a batch in thread " << thread_id_;
+    VLOG(3) << "read a batch in thread " << thread_id_;
     timeline.Pause();
     read_time += timeline.ElapsedSec();
     total_time += timeline.ElapsedSec();
...
@@ -83,6 +83,7 @@ class DistMultiTrainer : public MultiTrainer {
   virtual ~DistMultiTrainer() {}
   virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
   virtual void InitOtherEnv(const ProgramDesc& main_program);
+  virtual void Run();
   virtual void Finalize();
 
  protected:
...
@@ -627,7 +627,7 @@ class Executor(object):
                           fetch_list=None,
                           scope=None,
                           thread=0,
-                          opt_info=None):
+                          debug=False):
         if scope is None:
             scope = global_scope()
         if fetch_list is None:
@@ -636,6 +636,8 @@ class Executor(object):
         if not compiled:
             trainer = TrainerFactory().create_trainer(program._fleet_opt)
             trainer.set_program(program)
+            with open("fleet_desc.prototxt", "w") as fout:
+                fout.write(str(program._fleet_opt["fleet_desc"]))
         else:
             trainer = TrainerFactory().create_trainer(
                 program.program._fleet_opt)
@@ -644,8 +646,11 @@ class Executor(object):
             trainer.set_thread(dataset.thread_num)
         else:
             trainer.set_thread(thread)
+        trainer.set_debug(debug)
         trainer.gen_trainer_desc()
        dataset._prepare_to_run()
+        with open("trainer_desc.prototxt", "w") as fout:
+            fout.write(trainer._desc())
         self._default_executor.run_from_dataset(program.desc, scope,
                                                 dataset.dataset,
                                                 trainer._desc())
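With the new debug parameter, callers opt into the profiled DeviceWorker loop from Python. A hypothetical call site (the method name train_from_dataset, the place setup, and the pre-built dataset object are assumptions; only the signature above is shown in the diff):

import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# dataset is assumed to be a fluid dataset whose thread_num drives
# trainer.set_thread(); debug=True routes each worker thread into
# TrainFilesWithProfiler instead of TrainFiles.
exe.train_from_dataset(program=fluid.default_main_program(),
                       dataset=dataset,
                       thread=4,
                       debug=True)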
@@ -101,10 +101,10 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return self.get_size()
 
     def worker_index(self):
-        return self.rank / self.proc_per_node_
+        return self.rank_ / self.proc_per_node_
 
     def server_index(self):
-        return self.rank / self.proc_per_node_
+        return self.rank_ / self.proc_per_node_
 
     def barrier_worker(self):
         if self.is_worker():
...
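The fix replaces the undefined self.rank with the self.rank_ attribute the role maker actually stores. In the symmetric layout each node hosts one worker and one server, so with proc_per_node_ = 2 the (Python 2 integer) division maps two consecutive MPI ranks to the same index. A standalone illustration of that mapping (not the Paddle class; the even/odd worker-server split is an assumption about the symmetric role maker):

PROC_PER_NODE = 2  # one worker plus one server per node

for rank in range(6):
    role = "worker" if rank % PROC_PER_NODE == 0 else "server"
    index = rank // PROC_PER_NODE  # written as '/' in the Python 2 source
    print("rank %d -> %s #%d" % (rank, role, index))
# rank 0 -> worker #0, rank 1 -> server #0, rank 2 -> worker #1, ...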
@@ -36,6 +36,9 @@ class TrainerDesc(object):
         self.device_worker_ = None
         self.program_ = None
 
+    def set_debug(self, debug):
+        self.proto_desc.debug = debug
+
     def set_thread(self, thread_num):
         self.proto_desc.thread_num = thread_num
@@ -60,6 +63,10 @@ class MultiTrainer(TrainerDesc):
         super(MultiTrainer, self).__init__()
         pass
 
+    def set_program(self, program):
+        super(MultiTrainer, self).set_program(program)
+        self.program_ = program
+
     def gen_trainer_desc(self):
         super(MultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "MultiTrainer"
@@ -71,8 +78,14 @@ class DistMultiTrainer(TrainerDesc):
         super(DistMultiTrainer, self).__init__()
         pass
 
+    def set_program(self, program):
+        super(DistMultiTrainer, self).set_program(program)
+        self.program_ = program
+
     def gen_trainer_desc(self):
         super(DistMultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
+        if self.program_ == None:
+            print("None program")
         self.device_worker_.set_program(self.program_)
         self.device_worker_.gen_worker_desc(self.proto_desc)
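Taken together, the Python-side flow is: create a trainer from the fleet config, attach the program (now retained on the trainer via set_program), toggle debug, then generate the proto that the C++ DistMultiTrainer consumes. A schematic of that sequence (TrainerFactory, program, and dataset come from the surrounding Paddle code, not from this diff):

trainer = TrainerFactory().create_trainer(program._fleet_opt)
trainer.set_program(program)            # stored in self.program_ for the worker
trainer.set_thread(dataset.thread_num)  # one DeviceWorker per thread
trainer.set_debug(True)                 # proto_desc.debug -> TrainFilesWithProfiler
trainer.gen_trainer_desc()              # fills class_name and the worker desc
desc_text = trainer._desc()             # serialized TrainerDesc handed to C++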