提交 98dda08a 编写于 作者: D dongdaxiang

fix pull sparse slow problem

test=develop
上级 93c3c7f9
...@@ -153,11 +153,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -153,11 +153,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
// TODO(guru4elephant): we don't need this
/*
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") { if (mode == "mpi") {
_pull_dense_thread->stop(); _pull_dense_thread->stop();
} }
#endif #endif
*/
VLOG(3) << "start to run from files in async_executor"; VLOG(3) << "start to run from files in async_executor";
VLOG(3) << "Drop current scope kids"; VLOG(3) << "Drop current scope kids";
root_scope_->DropKids(); root_scope_->DropKids();
......
...@@ -210,6 +210,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -210,6 +210,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
timeline.Pause(); timeline.Pause();
pull_sparse_time += timeline.ElapsedSec(); pull_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
timeline.Start();
CollectLabelInfo(i); CollectLabelInfo(i);
timeline.Pause(); timeline.Pause();
collect_label_time += timeline.ElapsedSec(); collect_label_time += timeline.ElapsedSec();
...@@ -336,6 +337,16 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -336,6 +337,16 @@ void DownpourWorker::TrainFilesWithProfiler() {
} }
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt); fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100); fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "pull sparse time percent: %f\n",
pull_sparse_time / total_time * 100);
fprintf(stderr, "collect label time percent: %f\n",
collect_label_time / total_time * 100);
fprintf(stderr, "fill sparse time percent: %f\n",
fill_sparse_time / total_time * 100);
fprintf(stderr, "push sparse time percent: %f\n",
push_sparse_time / total_time * 100);
fprintf(stderr, "push dense time percent: %f\n",
push_dense_time / total_time * 100);
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time); fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
} }
} }
......
...@@ -142,6 +142,7 @@ void FleetWrapper::PullSparseVarsSync( ...@@ -142,6 +142,7 @@ void FleetWrapper::PullSparseVarsSync(
} }
fea_keys->push_back(static_cast<uint64_t>(ids[i])); fea_keys->push_back(static_cast<uint64_t>(ids[i]));
} }
/*
fea_values->resize(fea_keys->size() + 1); fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) { for (auto& t : *fea_values) {
t.resize(fea_value_dim); t.resize(fea_value_dim);
...@@ -150,10 +151,19 @@ void FleetWrapper::PullSparseVarsSync( ...@@ -150,10 +151,19 @@ void FleetWrapper::PullSparseVarsSync(
for (auto& t : *fea_values) { for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data()); pull_result_ptr.push_back(t.data());
} }
auto status = pslib_ptr_->_worker_ptr->pull_sparse( */
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_sparse_status.push_back(std::move(status));
} }
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
auto status = pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_sparse_status.push_back(std::move(status));
for (auto& t : pull_sparse_status) { for (auto& t : pull_sparse_status) {
t.wait(); t.wait();
auto status = t.get(); auto status = t.get();
...@@ -207,7 +217,7 @@ void FleetWrapper::PullDenseVarsSync( ...@@ -207,7 +217,7 @@ void FleetWrapper::PullDenseVarsSync(
} }
void FleetWrapper::PushDenseParamSync( void FleetWrapper::PushDenseParamSync(
const ProgramDesc& program, const uint64_t table_id, const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names) { const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
......
...@@ -73,7 +73,7 @@ class FleetWrapper { ...@@ -73,7 +73,7 @@ class FleetWrapper {
const std::vector<std::string>& var_names, const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status); std::vector<::std::future<int32_t>>* pull_dense_status);
void PushDenseParamSync(const ProgramDesc& program, const uint64_t table_id, void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names); const std::vector<std::string>& var_names);
// Push dense variables to server in async mode // Push dense variables to server in async mode
......
...@@ -41,7 +41,7 @@ void print_lod_tensor(const std::string& var_name, ...@@ -41,7 +41,7 @@ void print_lod_tensor(const std::string& var_name,
void PrintVar(framework::Scope* scope, const std::string& var_name, void PrintVar(framework::Scope* scope, const std::string& var_name,
const std::string& print_info) { const std::string& print_info) {
framework::Variable* var = scope->FindVar(var_name); framework::Variable* var = scope->FindVar(var_name);
if (tensor == nullptr) { if (var == nullptr) {
VLOG(1) << "Variable Name " << var_name << " does not exist in your scope"; VLOG(1) << "Variable Name " << var_name << " does not exist in your scope";
return; return;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册