提交 a23f1ee8 编写于 作者: Q Qiao Longfei

optimize code

上级 a0bb18be
......@@ -23,6 +23,7 @@ namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
......@@ -43,12 +44,15 @@ inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
// get RpcContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto i = 0; i < graphs.size(); ++i) {
for (auto &node : graphs[i]->Nodes()) {
if (node->IsOp()) {
if (node->Op()->Type() == "send") {
VLOG(3) << "node name " << node->Name();
std::vector<ir::Node *> nodes_to_delete;
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
......@@ -61,8 +65,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
epmap, height_section);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Op()->Type() == "recv") {
auto recv_var_name = node->Op()->Input("X")[0];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
......@@ -70,18 +74,23 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
graphs[i]->RemoveNode(node);
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
VLOG(3) << "delete all recv ops";
for (auto *node : nodes_to_delete) {
graphs[i]->RemoveNode(node);
}
}
}
}
// init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use ";
VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
}
}
......
......@@ -277,7 +277,7 @@ ParallelExecutor::ParallelExecutor(
// ncclOp
std::vector<ir::Graph *> async_graphs(places.size());
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode";
temp_owned_graph =
build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]},
......@@ -298,7 +298,7 @@ ParallelExecutor::ParallelExecutor(
member_->nccl_ctxs_.get());
}
#else
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode";
temp_owned_graph = build_strategy.Apply(
std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
......@@ -358,7 +358,7 @@ ParallelExecutor::ParallelExecutor(
}
}
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
if (build_strategy.async_mode_) {
VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
......
......@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h"
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
......@@ -28,6 +31,7 @@ namespace distributed {
static inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope) {
VLOG(3) << "merge " << vars.size() << " vars " << var_name << " to one";
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
......@@ -67,29 +71,32 @@ std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
std::once_flag Communicator::init_flag_;
void Communicator::SendThread() {
VLOG("SendThread start!");
while (running_) {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
for (auto &iter : send_varname_to_queue_) {
auto send_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second;
std::vector<std::shared_ptr<Variable>> vars;
// TODO(qiao): need to be configurable
const size_t max_merge_var_num = 20;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
vars.push_back(var_queue->Pop());
merged_var_num++;
}
MergeVars(var_name, vars, send_scope_.get());
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
send_functor(ctx, *send_scope_, true);
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
auto &var_name = iter.first;
auto &var_queue = iter.second;
if (var_queue->NotEmpty()) { // will block if queue is empty
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << "merge var " << var_name << " and send";
std::vector<std::shared_ptr<Variable>> vars;
// TODO(qiao): need to be configurable
const size_t max_merge_var_num = 20;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
vars.push_back(var_queue->Pop());
merged_var_num++;
}
MergeVars(var_name, vars, send_scope_.get());
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
send_functor(ctx, *send_scope_, true);
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
}
}
for (auto &task_f : task_futures) {
task_f.wait();
......@@ -98,6 +105,7 @@ void Communicator::SendThread() {
}
void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!";
while (running_) {
// parallel run recv graph
std::vector<std::future<void>> task_futures;
......@@ -115,6 +123,8 @@ void Communicator::RecvThread() {
for (auto &task : task_futures) {
task.wait();
}
// TODO(qiao) need to be configuable
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
}
......
......@@ -68,6 +68,12 @@ class BlockingQueue {
return rc;
}
bool NotEmpty() {
std::unique_lock<std::mutex> lock(mutex_);
recv_cv_.wait(lock, [=] { return !queue_.empty(); });
return true;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
......
......@@ -60,12 +60,14 @@ class VariableResponse {
bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) {
local_scope_ = scope->NewTmpScope();
local_scope_ = &scope->NewScope();
}
}
virtual ~VariableResponse() {
if (local_scope_) delete local_scope_;
if (local_scope_) {
scope_->DeleteScope(local_scope_);
}
}
int Parse(Source* source, const sendrecv::VariableMessage& meta) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册