Unverified commit fbf9564f authored by 123malin, committed by GitHub

【paddle.distributed.fleet】Optimize ParameterServer's Async Mode (#28442)

* test=develop, optimize global_step
Parent 98adc8f0
@@ -65,6 +65,7 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
} else {
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
if (iter.first == STEP_COUNTER && !need_global_step_) continue;
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
send_queue_size_);
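
For readers following the hunk above: InitImpl now creates one bounded blocking queue per send variable and skips the STEP_COUNTER queue entirely when need_global_step_ is false. Below is a minimal, self-contained sketch of that bounded-queue pattern; BlockingQueue, Push, Pop, and Size are modeled on how the diff uses them, not taken from the real Paddle headers, so treat it as an illustration only.

#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>

// Bounded blocking queue: Push blocks when full, Pop blocks when empty.
template <typename T>
class BlockingQueue {
 public:
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {}

  void Push(const T &item) {
    std::unique_lock<std::mutex> lock(mu_);
    not_full_.wait(lock, [this] { return queue_.size() < capacity_; });
    queue_.push_back(item);
    not_empty_.notify_one();
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    not_empty_.wait(lock, [this] { return !queue_.empty(); });
    T item = std::move(queue_.front());
    queue_.pop_front();
    not_full_.notify_one();
    return item;
  }

  size_t Size() {
    std::lock_guard<std::mutex> lock(mu_);
    return queue_.size();
  }

 private:
  const size_t capacity_;
  std::deque<T> queue_;
  std::mutex mu_;
  std::condition_variable not_full_;
  std::condition_variable not_empty_;
};

int main() {
  // The capacity plays the role of send_queue_size_ in the diff.
  BlockingQueue<int> grad_queue(4);
  grad_queue.Push(1);
  grad_queue.Push(2);
  std::cout << "size=" << grad_queue.Size() << " front=" << grad_queue.Pop()
            << std::endl;
  return 0;
}
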
@@ -108,21 +109,87 @@ void AsyncCommunicator::SendGlobalStep(int batches) {
send_functor(ctx, *send_scope_, true, 1);
}
void AsyncCommunicator::SendByCommunicator(int batches) {
void AsyncCommunicator::SendByCommunicator() {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
auto before_run_send_graph = GetCurrentUS();
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
auto send_task = [this, batches, &var_name, &var_queue] {
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << var_name << " merge and send; ";
std::vector<std::shared_ptr<Variable>> vars;
int merged_var_num = 0;
int wait_times = 0;
while (merged_var_num < max_merge_var_num_) {
if (var_queue->Size() == 0) {
VLOG(4) << "wait_times -> " << wait_times;
if (wait_times >= send_wait_times_) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
} else {
wait_times = 0;
vars.push_back(var_queue->Pop());
merged_var_num++;
}
}
auto before_merge = GetCurrentUS();
if (var_name == STEP_COUNTER) {
SendGlobalStep(merged_var_num);
auto after_merge = GetCurrentUS();
VLOG(3) << "merge and send " << merged_var_num << " " << var_name
<< " use time " << after_merge - before_merge;
return;
}
VLOG(3) << var_name << " merge and send";
auto &ctx = send_varname_to_ctx_.at(var_name);
MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << merged_var_num << " " << var_name << " use time "
<< after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
send_functor(ctx, *send_scope_, true, 1);
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - after_merge;
};
task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
}
for (auto &task_f : task_futures) {
task_f.wait();
}
auto after_run_send_graph = GetCurrentUS();
VLOG(3) << "run send graph use time "
<< (after_run_send_graph - before_run_send_graph);
}
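
The loop above implements a merge-with-bounded-wait policy: pop up to max_merge_var_num_ pending variables, but stop after send_wait_times_ consecutive empty polls, each backed by a 10 ms sleep. Here is a standalone sketch of just that polling logic, with a plain mutex-guarded std::queue standing in for Paddle's BlockingQueue:

#include <chrono>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

std::mutex mu;
std::queue<int> pending;  // stands in for one var_queue

// Non-blocking pop; returns false when the queue is currently empty.
bool TryPop(int *out) {
  std::lock_guard<std::mutex> lock(mu);
  if (pending.empty()) return false;
  *out = pending.front();
  pending.pop();
  return true;
}

// Drain up to max_merge items, giving up after wait_limit empty polls.
// Mirrors the max_merge_var_num_ / send_wait_times_ loop in the diff.
std::vector<int> MergeUpTo(int max_merge, int wait_limit) {
  std::vector<int> merged;
  int wait_times = 0;
  while (static_cast<int>(merged.size()) < max_merge) {
    int item;
    if (!TryPop(&item)) {
      if (wait_times >= wait_limit) break;  // queue stayed empty too long
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      ++wait_times;
      continue;
    }
    wait_times = 0;  // reset the backoff after every successful pop
    merged.push_back(item);
  }
  return merged;
}

int main() {
  for (int i = 0; i < 3; ++i) pending.push(i);
  std::cout << "merged " << MergeUpTo(20, 5).size() << " vars" << std::endl;
  return 0;
}
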
void HalfAsyncCommunicator::SendByCommunicator() {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
int batches = BatchesCounter();
if (batches <= 0) return;
auto before_run_send_graph = GetCurrentUS();
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
auto send_task = [this, batches, &var_name, &var_queue] {
VLOG(3) << var_name << " merge and send; ";
auto before_task = GetCurrentUS();
std::vector<std::shared_ptr<Variable>> vars;
vars.reserve(batches);
@@ -130,6 +197,14 @@ void AsyncCommunicator::SendByCommunicator(int batches) {
vars.push_back(var_queue->Pop());
}
if (var_name == STEP_COUNTER) {
SendGlobalStep(batches);
auto end_task = GetCurrentUS();
VLOG(3) << "merge " << batches << " " << var_name << " use time "
<< end_task - before_task;
return;
}
auto &ctx = send_varname_to_ctx_.at(var_name);
auto before_merge = GetCurrentUS();
@@ -142,7 +217,20 @@ void AsyncCommunicator::SendByCommunicator(int batches) {
send_functor(ctx, *send_scope_, true, 1);
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - after_merge;
<< after_send - before_task;
if (var_name.rfind("@GRAD") != var_name.size() - 5) return;
auto recv_param = var_name.substr(0, var_name.size() - 5);
if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end())
return;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_);
auto after_recv = GetCurrentUS();
VLOG(3) << "recv " << recv_param << " use time "
<< after_recv - after_send;
return;
};
task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
}
@@ -152,7 +240,7 @@ void AsyncCommunicator::SendByCommunicator(int batches) {
auto after_run_send_graph = GetCurrentUS();
VLOG(3) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
<< (after_run_send_graph - before_run_send_graph);
}
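
A notable addition in HalfAsyncCommunicator::SendByCommunicator above: after sending a gradient, it derives the parameter name by stripping the "@GRAD" suffix and immediately pulls the updated parameter back. A small sketch of that suffix-based name mapping follows; the helper GradToParam is hypothetical, written here only to isolate the rfind check used in the diff:

#include <iostream>
#include <string>

// Maps "<param>@GRAD" to "<param>"; returns false for non-gradient names.
bool GradToParam(const std::string &grad_name, std::string *param) {
  const std::string kSuffix = "@GRAD";
  if (grad_name.size() <= kSuffix.size()) return false;
  if (grad_name.rfind(kSuffix) != grad_name.size() - kSuffix.size())
    return false;  // name does not end in @GRAD
  *param = grad_name.substr(0, grad_name.size() - kSuffix.size());
  return true;
}

int main() {
  std::string param;
  if (GradToParam("fc_0.w_0@GRAD", &param))
    std::cout << "recv " << param << std::endl;  // prints "recv fc_0.w_0"
  return 0;
}
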
void AsyncCommunicator::MainThread() {
@@ -164,20 +252,28 @@ void AsyncCommunicator::MainThread() {
}
while (running_) {
int batches = BatchesCounter();
if (batches > 0) {
SendGlobalStep(batches);
SendByCommunicator(batches);
BarrierSend();
RecvByCommunicator();
BarrierRecv();
BarrierWeakUp();
} else {
VLOG(1) << "get nothing from sending queue, will skip send/recv";
}
SendByCommunicator();
BarrierSend();
}
VLOG(1) << "communicator stopped, send thread exit";
VLOG(3) << "communicator stopped, send thread exit";
}
void HalfAsyncCommunicator::MainThread() {
VLOG(3) << "MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
SendByCommunicator();
BarrierSend();
RecvByCommunicator();
BarrierRecv();
BarrierWeakUp();
}
VLOG(3) << "communicator stopped, send thread exit";
}
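
Both MainThread variants above share the same lifecycle: spin while waiting_ is set, run communication rounds while running_ is set, and exit once Stop() clears running_. Below is a compact sketch of that flag-driven worker thread; std::atomic<bool> is an assumption about how the flags behave (their actual member types are not shown in this diff), and the round body is a placeholder:

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

std::atomic<bool> waiting{true};
std::atomic<bool> running{true};

void MainThread() {
  // Phase 1: wait until the first Send() clears the waiting flag.
  while (waiting.load() && running.load()) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  // Phase 2: run rounds until Stop() clears the running flag.
  while (running.load()) {
    // one communication round: send, barrier, recv, barrier, wake-up
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
  std::cout << "communicator stopped, send thread exit" << std::endl;
}

int main() {
  std::thread t(MainThread);
  waiting = false;  // analogous to the first Send() arriving
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  running = false;  // analogous to Stop()
  t.join();
  return 0;
}
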
void AsyncCommunicator::RecvByCommunicator() {
@@ -193,10 +289,13 @@ void AsyncCommunicator::RecvNoBarrier() {
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto before_task = GetCurrentUS();
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(iter.second, *recv_scope_);
auto end_task = GetCurrentUS();
VLOG(1) << "recv var " << var_name << " use time "
<< (end_task - before_task);
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
@@ -206,37 +305,12 @@ void AsyncCommunicator::RecvNoBarrier() {
}
}
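
RecvNoBarrier above fans one recv task per variable out to a thread pool, keeps the returned futures, and waits on all of them, which acts as a local barrier. The sketch below reproduces that pattern with std::async standing in for recv_threadpool_->enqueue; ParameterRecv is replaced by a placeholder print:

#include <future>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> recv_vars = {"fc_0.w_0", "fc_0.b_0"};
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_vars.size());
  for (const auto &name : recv_vars) {
    // Capture by value: the task may outlive this loop iteration.
    task_futures.emplace_back(std::async(std::launch::async, [name] {
      // Placeholder for ParameterRecv: pull `name` from the pserver.
      std::cout << "recv var " << name << std::endl;
    }));
  }
  for (auto &f : task_futures) f.wait();  // local barrier over all recv tasks
  return 0;
}
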
int AsyncCommunicator::BatchesCounter() {
auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER);
size_t merged_var_num = 0;
size_t wait_times = 0;
while (merged_var_num < static_cast<size_t>(max_merge_var_num_)) {
if (step_queue->Size() == 0) {
VLOG(3) << "wait_times -> " << wait_times;
if (wait_times >= static_cast<size_t>(send_wait_times_)) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
} else {
step_queue->Pop();
wait_times = 0;
merged_var_num++;
}
}
return merged_var_num;
}
void AsyncCommunicator::Start() {
VLOG(1) << "Communicator start";
VLOG(3) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
VLOG(1) << "start send thread and recv thread";
VLOG(3) << "start send thread and recv thread";
waiting_ = true;
running_ = true;
BarrierTriggerReset(max_merge_var_num_);
@@ -247,18 +321,18 @@ void AsyncCommunicator::Start() {
}
void AsyncCommunicator::Stop() {
VLOG(1) << "Communicator stop";
VLOG(3) << "Communicator stop";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
if (main_thread_) {
VLOG(1) << "stop send thread";
VLOG(3) << "stop send thread";
main_thread_->join();
main_thread_.reset(nullptr);
}
}
VLOG(1) << "Communicator stop done";
VLOG(3) << "Communicator stop done";
}
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
@@ -271,6 +345,10 @@ void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
if (table_name == STEP_COUNTER && !need_global_step_) return;
auto before_send_op = GetCurrentUS();
auto &queue = send_varname_to_queue_.at(table_name);
if (table_name == STEP_COUNTER) {
@@ -279,7 +357,6 @@ void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
tensor->Resize(framework::make_ddim({1}));
auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
out_d[0] = 1;
VLOG(3) << "send to " << table_name << " with queue size " << queue->Size();
queue->Push(tmp_var);
} else {
PADDLE_ENFORCE_GE(var_names.size(), 1,
@@ -295,21 +372,20 @@ void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
auto tmp_var = std::make_shared<Variable>();
if (var->IsType<framework::SelectedRows>()) {
framework::CopyVariable(*var, tmp_var.get());
VLOG(3) << "send to " << table_name << " with queue size "
<< queue->Size();
queue->Push(tmp_var);
} else if (var->IsType<framework::LoDTensor>()) {
// push var into send queue by var_name
auto var_name = var_names[0];
framework::CopyVariable(*var, tmp_var.get());
VLOG(3) << "send to " << table_name << " with queue size "
<< queue->Size();
queue->Push(tmp_var);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown var type to copy, only support LoDTensor/SelectedRows"));
}
}
auto after_send_op = GetCurrentUS();
VLOG(3) << "send to " << table_name << " with queue size " << queue->Size()
<< ", use time " << (after_send_op - before_send_op);
}
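
Most of the new VLOG lines in this commit time a code path by subtracting two GetCurrentUS() readings, as Send() does above around the queue push. A sketch of that measurement idiom follows; the steady_clock-based GetCurrentUS here is an assumption made for illustration (the real helper is defined elsewhere in Paddle), and std::cerr replaces VLOG:

#include <chrono>
#include <iostream>
#include <thread>

// Returns a monotonic timestamp in microseconds.
double GetCurrentUS() {
  auto now = std::chrono::steady_clock::now().time_since_epoch();
  return std::chrono::duration<double, std::micro>(now).count();
}

int main() {
  auto before_send_op = GetCurrentUS();
  std::this_thread::sleep_for(std::chrono::milliseconds(2));  // the work
  auto after_send_op = GetCurrentUS();
  std::cerr << "send use time " << (after_send_op - before_send_op) << " us"
            << std::endl;
  return 0;
}
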
void HalfAsyncCommunicator::Clean() {
......
@@ -302,16 +302,13 @@ class AsyncCommunicator : public Communicator {
const std::vector<std::string> &var_tables,
const framework::Scope &scope) override;
virtual void SendByCommunicator(int batches);
virtual void SendByCommunicator();
virtual void SendGlobalStep(int batches);
virtual void RecvByCommunicator();
virtual void RecvNoBarrier();
virtual int BatchesCounter();
virtual void BarrierSend() {}
virtual void BarrierRecv() {}
@@ -359,6 +356,10 @@ class HalfAsyncCommunicator : public AsyncCommunicator {
VLOG(0) << "HalfAsyncCommunicator Initialized";
}
void MainThread() override;
void SendByCommunicator() override;
void Clean() override;
void Barrier() override;
@@ -438,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator {
const std::vector<std::string> &var_tables,
const framework::Scope &scope) override;
void SendByCommunicator(int batches) { return; }
void SendByCommunicator() { return; }
std::vector<int64_t> MergeSparseIds(const std::string &send_varname);
@@ -475,6 +476,7 @@ class GeoCommunicator : public AsyncCommunicator {
std::shared_ptr<Scope> pserver_scope_;
int send_var_nums_ = 0;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::unordered_map<
......
@@ -207,6 +207,7 @@ class ParameterServerRuntime(RuntimeBase):
SyncStrategy, GeoStrategy
trainer_config = self.async_strategy.get_trainer_runtime_config()
print(trainer_config)
dist_strategy = self.context["valid_strategy"]
launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
......