Commit 23d3929a authored by Qiao Longfei

optimize merge vars

Parent d3a14377
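This change speeds up MergeVars for dense gradients: instead of zeroing and accumulating the output tensor element by element in a scalar loop, the output is zeroed once with math::SetConstant and the inputs are summed with Eigen vector expressions. It also adds GetCurrentUS-based timing logs around the merge, send, and recv phases. A minimal standalone sketch of the before/after merge strategy (plain C++ with hypothetical function names, not taken from the commit):

#include <cstddef>
#include <vector>

// Before: one pass over the output, revisiting every input per element.
void MergeScalar(const std::vector<const float *> &ins, float *out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = 0.0f;
    for (const float *in : ins) {
      out[i] += in[i];
    }
  }
}

// After: zero the whole output once, then add each input buffer as a whole,
// mirroring the SetConstant + Eigen "result = result + in" pattern.
void MergeBulk(const std::vector<const float *> &ins, float *out, size_t n) {
  for (size_t i = 0; i < n; ++i) out[i] = 0.0f;
  for (const float *in : ins) {
    for (size_t i = 0; i < n; ++i) out[i] += in[i];
  }
}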
@@ -18,12 +18,15 @@ limitations under the License. */
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
DEFINE_bool(communicator_independent_recv_thread, true,
"use an independent thread to recv vars from the parameter server");
@@ -40,28 +43,54 @@ namespace paddle {
namespace operators {
namespace distributed {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
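// Illustrative usage sketch (not part of the diff): the timing logs added in
// this commit bracket each phase with two GetCurrentUS() calls and report the
// elapsed time in microseconds, e.g.
//   auto before = GetCurrentUS();
//   // ... merge / send / recv work ...
//   auto after = GetCurrentUS();
//   VLOG(3) << "phase use time " << after - before;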
static inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope) {
VLOG(3) << "merge " << vars.size() << " vars " << var_name << " to 1";
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
auto *out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
VLOG(3) << "merge " << var_name << " LoDTensor"
<< var0->Get<framework::LoDTensor>().dims();
// init output tensor
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<float>(var0->Get<framework::LoDTensor>().dims(),
cpu_place);
auto numel = out_t->numel();
// check the input dims
for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(var_t.numel(), numel, "should have the same dims");
}
// set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext();
math::SetConstant<paddle::platform::CPUDeviceContext, float>
constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<float>(0));
// sum all vars to out
auto result = EigenVector<float>::Flatten(*out_t);
for (auto &var : vars) {
auto &in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<float>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in;
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto &slr0 = var0->Get<framework::SelectedRows>();
auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
@@ -74,6 +103,8 @@ static inline void MergeVars(const std::string &var_name,
merge_add;
auto dev_ctx = paddle::platform::CPUDeviceContext();
merge_add(dev_ctx, inputs, out_slr, false);
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims();
} else {
PADDLE_THROW("unsupported var type!");
}
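// Sketch of the SelectedRows merge assumed above (illustrative row ids, not
// taken from the commit): scatter::MergeAdd concatenates the input rows and
// sums values that share a row id, e.g.
//   input A: rows {0, 2}, values {a0, a2}
//   input B: rows {2, 3}, values {b2, b3}
//   merged : rows {0, 2, 3}, values {a0, a2 + b2, b3}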
@@ -123,12 +154,13 @@ void Communicator::SendThread() {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
auto before_run_send_graph = GetCurrentUS();
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
if (var_queue->Size() > 0) {
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << "merge var " << var_name << " and send";
VLOG(3) << var_name << " merge and send";
std::vector<std::shared_ptr<Variable>> vars;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 &&
@@ -136,12 +168,19 @@ void Communicator::SendThread() {
vars.push_back(var_queue->Pop());
merged_var_num++;
}
auto before_merge = GetCurrentUS();
MergeVars(var_name, vars, send_scope_.get());
auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << var_name << " use time "
<< after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) {
send_functor(ctx, *send_scope_, true);
}
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - after_merge;
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
@@ -152,7 +191,9 @@ void Communicator::SendThread() {
for (auto &task_f : task_futures) {
task_f.wait();
}
VLOG(3) << "run send graph done";
auto after_run_send_graph = GetCurrentUS();
VLOG(3) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
if (!FLAGS_communicator_independent_recv_thread) {
RecvAll();
}
@@ -161,6 +202,7 @@ void Communicator::SendThread() {
void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
auto before_send = GetCurrentUS();
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
@@ -177,7 +219,8 @@ void Communicator::RecvAll() {
for (auto &task : task_futures) {
task.wait();
}
VLOG(3) << "run recv graph done";
auto after_recv = GetCurrentUS();
VLOG(3) << "run recv graph use time " << after_recv - before_send;
}
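// The fan-out/join pattern used by SendThread and RecvAll above, sketched with
// the standard library (illustrative; the real code enqueues into a ThreadPool
// member instead of std::async):
//   std::vector<std::future<void>> futures;
//   for (auto &ctx : contexts) {
//     futures.emplace_back(std::async(std::launch::async, [&ctx] { Run(ctx); }));
//   }
//   for (auto &f : futures) f.wait();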
void Communicator::RecvThread() {
@@ -191,7 +234,6 @@ void Communicator::RecvThread() {
void Communicator::Send(const std::string &var_name,
const framework::Scope &scope) {
if (!FLAGS_communicator_fake_rpc) {
VLOG(3) << "communicator send " << var_name;
// push var into send queue by var_name
auto *grad_var = scope.FindVar(var_name);
@@ -201,7 +243,6 @@ void Communicator::Send(const std::string &var_name,
auto &queue = send_varname_to_queue_.at(var_name);
VLOG(3) << "send " << var_name << " queue size " << queue->Size();
queue->Push(tmp_grad_var);
}
}
Communicator *Communicator::GetInstance() { return communicator_.get(); }
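// Hypothetical caller-side sketch (not part of this commit): a trainer thread
// hands a freshly computed gradient to the communicator through the
// process-wide singleton, and the send thread later merges and sends it:
//   Communicator::GetInstance()->Send("w1@GRAD", *local_scope);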
......