提交 98dda08a 编写于 作者: D dongdaxiang

fix pull sparse slow problem

test=develop
上级 93c3c7f9
......@@ -153,11 +153,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) {
th.join();
}
// TODO(guru4elephant): we don't need this
/*
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
_pull_dense_thread->stop();
}
#endif
*/
VLOG(3) << "start to run from files in async_executor";
VLOG(3) << "Drop current scope kids";
root_scope_->DropKids();
......
......@@ -210,6 +210,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
timeline.Pause();
pull_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
timeline.Start();
CollectLabelInfo(i);
timeline.Pause();
collect_label_time += timeline.ElapsedSec();
......@@ -336,6 +337,16 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "pull sparse time percent: %f\n",
pull_sparse_time / total_time * 100);
fprintf(stderr, "collect label time percent: %f\n",
collect_label_time / total_time * 100);
fprintf(stderr, "fill sparse time percent: %f\n",
fill_sparse_time / total_time * 100);
fprintf(stderr, "push sparse time percent: %f\n",
push_sparse_time / total_time * 100);
fprintf(stderr, "push dense time percent: %f\n",
push_dense_time / total_time * 100);
fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
}
}
......
......@@ -142,6 +142,17 @@ void FleetWrapper::PullSparseVarsSync(
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
/*
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
*/
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
......@@ -153,7 +164,6 @@ void FleetWrapper::PullSparseVarsSync(
auto status = pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_sparse_status.push_back(std::move(status));
}
for (auto& t : pull_sparse_status) {
t.wait();
auto status = t.get();
......@@ -207,7 +217,7 @@ void FleetWrapper::PullDenseVarsSync(
}
void FleetWrapper::PushDenseParamSync(
const ProgramDesc& program, const uint64_t table_id,
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
auto place = platform::CPUPlace();
......
......@@ -73,7 +73,7 @@ class FleetWrapper {
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status);
void PushDenseParamSync(const ProgramDesc& program, const uint64_t table_id,
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push dense variables to server in async mode
......
......@@ -41,7 +41,7 @@ void print_lod_tensor(const std::string& var_name,
void PrintVar(framework::Scope* scope, const std::string& var_name,
const std::string& print_info) {
framework::Variable* var = scope->FindVar(var_name);
if (tensor == nullptr) {
if (var == nullptr) {
VLOG(1) << "Variable Name " << var_name << " does not exist in your scope";
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册