Commit 23d3929a, authored by Qiao Longfei

optimize merge vars

Parent d3a14377
@@ -18,12 +18,15 @@ limitations under the License. */
 #include <chrono>  // NOLINT
 #include <thread>  // NOLINT
 
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/operators/distributed/parameter_recv.h"
 #include "paddle/fluid/operators/distributed/parameter_send.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/device_context.h"
 
 DEFINE_bool(communicator_independent_recv_thread, true,
             "use an independent to recv vars from parameter server");
@@ -40,28 +43,54 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+inline double GetCurrentUS() {
+  struct timeval time;
+  gettimeofday(&time, NULL);
+  return 1e+6 * time.tv_sec + time.tv_usec;
+}
+
 static inline void MergeVars(const std::string &var_name,
                              const std::vector<std::shared_ptr<Variable>> &vars,
                              Scope *scope) {
+  VLOG(3) << "merge " << vars.size() << " vars " << var_name << " to 1";
   PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
   auto cpu_place = platform::CPUPlace();
   auto &var0 = vars[0];
   auto *out_var = scope->Var(var_name);
   if (var0->IsType<framework::LoDTensor>()) {
+    VLOG(3) << "merge " << var_name << " LoDTensor"
+            << var0->Get<framework::LoDTensor>().dims();
+
+    // init output tensor
     auto *out_t = out_var->GetMutable<framework::LoDTensor>();
     auto *out_ptr = out_t->mutable_data<float>(
         var0->Get<framework::LoDTensor>().dims(), cpu_place);
     auto numel = out_t->numel();
-    for (auto i = 0; i < numel; ++i) {
-      out_ptr[i] = 0;
-      for (auto &var : vars) {
-        auto &var_t = var->Get<framework::LoDTensor>();
-        PADDLE_ENFORCE_EQ(var_t.numel(), numel, "should have the same dims");
-        out_ptr[i] += var_t.data<float>()[i];
-      }
-    }
+
+    // check the input dims
+    for (auto &var : vars) {
+      auto &var_t = var->Get<framework::LoDTensor>();
+      PADDLE_ENFORCE_EQ(var_t.numel(), numel, "should have the same dims");
+    }
+
+    // set output tensor to 0.
+    auto cpu_ctx = paddle::platform::CPUDeviceContext();
+    math::SetConstant<paddle::platform::CPUDeviceContext, float>
+        constant_functor;
+    constant_functor(cpu_ctx, out_t, static_cast<float>(0));
+
+    // sum all vars to out
+    auto result = EigenVector<float>::Flatten(*out_t);
+    for (auto &var : vars) {
+      auto &in_t = var->Get<framework::LoDTensor>();
+      auto in = EigenVector<float>::Flatten(in_t);
+      result.device(*cpu_ctx.eigen_device()) = result + in;
+    }
   } else if (var0->IsType<framework::SelectedRows>()) {
+    auto &slr0 = var0->Get<framework::SelectedRows>();
     auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
     out_slr->mutable_rows()->clear();
     out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
@@ -74,6 +103,8 @@ static inline void MergeVars(const std::string &var_name,
         merge_add;
     auto dev_ctx = paddle::platform::CPUDeviceContext();
     merge_add(dev_ctx, inputs, out_slr, false);
+    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
+            << " dims: " << slr0.value().dims();
   } else {
     PADDLE_THROW("unsupported var type!");
   }
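
Note: the key change in the hunks above replaces the element-by-element accumulation loop in the dense (LoDTensor) branch with one zero-fill via math::SetConstant followed by a vectorized Eigen sum over the flattened tensors. The snippet below is a minimal standalone sketch of the same pattern using plain Eigen vectors instead of Paddle's LoDTensor/EigenVector wrappers; the sizes and values are made up for illustration.

// Standalone sketch (plain Eigen, not Paddle types) of the merge pattern above:
// zero-fill the output once, then accumulate each input with one vectorized
// expression instead of an element-by-element inner loop.
#include <iostream>
#include <vector>
#include <Eigen/Dense>

int main() {
  const int numel = 8;  // stands in for out_t->numel()
  std::vector<Eigen::VectorXf> grads(3, Eigen::VectorXf::Ones(numel));

  // analogous to constant_functor(cpu_ctx, out_t, static_cast<float>(0))
  Eigen::VectorXf merged = Eigen::VectorXf::Zero(numel);

  // analogous to: result.device(*cpu_ctx.eigen_device()) = result + in;
  for (const auto &g : grads) {
    merged += g;
  }

  std::cout << merged.transpose() << std::endl;  // prints a row of 3s
  return 0;
}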
@@ -123,12 +154,13 @@ void Communicator::SendThread() {
     std::vector<std::future<void>> task_futures;
     task_futures.reserve(send_varname_to_ctx_.size());
     VLOG(3) << "run send graph";
+    auto before_run_send_graph = GetCurrentUS();
     for (auto &iter : send_varname_to_queue_) {
       auto &var_name = iter.first;
       auto &var_queue = iter.second;
       if (var_queue->Size() > 0) {
         auto send_task = [this, &var_name, &var_queue] {
-          VLOG(3) << "merge var " << var_name << " and send";
+          VLOG(3) << var_name << " merge and send";
           std::vector<std::shared_ptr<Variable>> vars;
           size_t merged_var_num = 0;
           while (var_queue->Size() > 0 &&
@@ -136,12 +168,19 @@ void Communicator::SendThread() {
             vars.push_back(var_queue->Pop());
             merged_var_num++;
           }
+          auto before_merge = GetCurrentUS();
           MergeVars(var_name, vars, send_scope_.get());
+          auto after_merge = GetCurrentUS();
+          VLOG(3) << "merge " << var_name << " use time "
+                  << after_merge - before_merge;
           auto send_functor = distributed::ParameterSend<float>();
           auto &ctx = send_varname_to_ctx_.at(var_name);
           if (!FLAGS_communicator_fake_rpc) {
             send_functor(ctx, *send_scope_, true);
           }
+          auto after_send = GetCurrentUS();
+          VLOG(3) << "send " << var_name << " use time "
+                  << after_send - after_merge;
         };
         task_futures.emplace_back(
             send_threadpool_->enqueue(std::move(send_task)));
@@ -152,7 +191,9 @@ void Communicator::SendThread() {
     for (auto &task_f : task_futures) {
       task_f.wait();
     }
-    VLOG(3) << "run send graph done";
+    auto after_run_send_graph = GetCurrentUS();
+    VLOG(3) << "run send graph use time "
+            << after_run_send_graph - before_run_send_graph;
     if (!FLAGS_communicator_independent_recv_thread) {
       RecvAll();
     }
@@ -161,6 +202,7 @@ void Communicator::SendThread() {
 
 void Communicator::RecvAll() {
   VLOG(3) << "parallel run recv graph";
+  auto before_send = GetCurrentUS();
   std::vector<std::future<void>> task_futures;
   task_futures.reserve(recv_varname_to_ctx_.size());
   for (auto &iter : recv_varname_to_ctx_) {
@@ -177,7 +219,8 @@ void Communicator::RecvAll() {
   for (auto &task : task_futures) {
     task.wait();
   }
-  VLOG(3) << "run recv graph done";
+  auto after_recv = GetCurrentUS();
+  VLOG(3) << "run recv graph use time " << after_recv - before_send;
 }
 
 void Communicator::RecvThread() {
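
Note: SendThread and RecvAll above share the same fan-out/wait shape: enqueue one task per variable, collect the std::future handles, then wait on all of them before logging the elapsed time for the whole graph. Below is a generic sketch of that pattern, using std::async as a stand-in for the send_threadpool_->enqueue(...) call in the diff; the loop bound and log text are made up for illustration.

// Generic fan-out/wait sketch: one task per variable, then block on every future.
#include <future>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(4);
  for (int i = 0; i < 4; ++i) {
    task_futures.emplace_back(std::async(std::launch::async, [i] {
      std::cout << "handle var " << i << "\n";  // stand-in for merge-and-send or recv
    }));
  }
  // Mirrors the task_f.wait() / task.wait() loops: proceed only after every task is done.
  for (auto &f : task_futures) {
    f.wait();
  }
  return 0;
}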
@@ -191,17 +234,15 @@ void Communicator::RecvThread() {
 
 void Communicator::Send(const std::string &var_name,
                         const framework::Scope &scope) {
-  if (!FLAGS_communicator_fake_rpc) {
-    VLOG(3) << "communicator send " << var_name;
-    // push var into send queue by var_name
-    auto *grad_var = scope.FindVar(var_name);
-    PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
-    auto tmp_grad_var = std::make_shared<Variable>();
-    framework::CopyVariable(*grad_var, tmp_grad_var.get());
-    auto &queue = send_varname_to_queue_.at(var_name);
-    VLOG(3) << "send " << var_name << " queue size " << queue->Size();
-    queue->Push(tmp_grad_var);
-  }
+  VLOG(3) << "communicator send " << var_name;
+  // push var into send queue by var_name
+  auto *grad_var = scope.FindVar(var_name);
+  PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
+  auto tmp_grad_var = std::make_shared<Variable>();
+  framework::CopyVariable(*grad_var, tmp_grad_var.get());
+  auto &queue = send_varname_to_queue_.at(var_name);
+  VLOG(3) << "send " << var_name << " queue size " << queue->Size();
+  queue->Push(tmp_grad_var);
 }
 
 Communicator *Communicator::GetInstance() { return communicator_.get(); }
...
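
Note: Communicator::Send above only copies the gradient and pushes it into a per-variable queue; the expensive work (merge plus RPC) happens later in SendThread, which pops up to a configured number of entries, merges them with MergeVars, and issues a single send. The sketch below shows that producer/consumer shape with a small hand-rolled bounded queue; it is not Paddle's queue implementation, and the capacity and merge limit are made-up values.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>

// Minimal bounded blocking queue, standing in for the per-variable queue
// that Communicator::Send pushes into and SendThread pops from.
template <typename T>
class BoundedQueue {
 public:
  explicit BoundedQueue(size_t cap) : cap_(cap) {}
  void Push(T v) {
    std::unique_lock<std::mutex> lk(mu_);
    not_full_.wait(lk, [this] { return q_.size() < cap_; });
    q_.push(std::move(v));
    not_empty_.notify_one();
  }
  T Pop() {
    std::unique_lock<std::mutex> lk(mu_);
    not_empty_.wait(lk, [this] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop();
    not_full_.notify_one();
    return v;
  }
  size_t Size() {
    std::lock_guard<std::mutex> lk(mu_);
    return q_.size();
  }

 private:
  std::mutex mu_;
  std::condition_variable not_full_, not_empty_;
  std::queue<T> q_;
  size_t cap_;
};

int main() {
  BoundedQueue<float> grad_queue(/*cap=*/8);          // hypothetical capacity
  for (int i = 0; i < 5; ++i) grad_queue.Push(1.0f);   // producer side: Communicator::Send

  // Consumer side: SendThread pops up to a merge limit, merges, then sends once.
  const size_t max_merge_var_num = 4;                  // made-up limit for the sketch
  float merged = 0.0f;
  for (size_t n = 0; grad_queue.Size() > 0 && n < max_merge_var_num; ++n) {
    merged += grad_queue.Pop();                        // MergeVars would sum tensors here
  }
  std::cout << "merged value: " << merged << std::endl;  // a single RPC send would follow
  return 0;
}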