提交 a23f1ee8 编写于 作者: Q Qiao Longfei

optimize code

上级 a0bb18be
...@@ -23,6 +23,7 @@ namespace details { ...@@ -23,6 +23,7 @@ namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos, inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) { Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope(); Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() = *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope; &local_scope;
...@@ -43,12 +44,15 @@ inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos, ...@@ -43,12 +44,15 @@ inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
// get RpcContext and remote send and recv op // get RpcContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
using RpcCtxMap = operators::distributed::RpcCtxMap; using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx; RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx; RpcCtxMap recv_varname_to_ctx;
for (auto i = 0; i < graphs.size(); ++i) { for (auto i = 0; i < graphs.size(); ++i) {
for (auto &node : graphs[i]->Nodes()) { for (auto &node : graphs[i]->Nodes()) {
if (node->IsOp()) { VLOG(3) << "node name " << node->Name();
if (node->Op()->Type() == "send") { std::vector<ir::Node *> nodes_to_delete;
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0]; auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>( auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames")); node->Op()->GetNullableAttr("send_varnames"));
...@@ -61,8 +65,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -61,8 +65,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
epmap, height_section); epmap, height_section);
VLOG(3) << "find and init an send op: " VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name]; << send_varname_to_ctx[send_var_name];
} else if (node->Op()->Type() == "recv") { } else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Input("X")[0]; auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>( auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames")); node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>( auto epmap = boost::get<std::vector<std::string>>(
...@@ -70,18 +74,23 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -70,18 +74,23 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
recv_varname_to_ctx[recv_var_name] = recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames, operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {}); epmap, {});
graphs[i]->RemoveNode(node); nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: " VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name]; << recv_varname_to_ctx[recv_var_name];
} }
VLOG(3) << "delete all recv ops";
for (auto *node : nodes_to_delete) {
graphs[i]->RemoveNode(node);
}
} }
} }
} }
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) { if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use "; VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx, operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope); recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
} }
} }
......
...@@ -277,7 +277,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -277,7 +277,7 @@ ParallelExecutor::ParallelExecutor(
// ncclOp // ncclOp
std::vector<ir::Graph *> async_graphs(places.size()); std::vector<ir::Graph *> async_graphs(places.size());
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode"; VLOG(3) << "use local async mode";
temp_owned_graph = temp_owned_graph =
build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]}, build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]},
...@@ -298,7 +298,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -298,7 +298,7 @@ ParallelExecutor::ParallelExecutor(
member_->nccl_ctxs_.get()); member_->nccl_ctxs_.get());
} }
#else #else
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode"; VLOG(3) << "use local async mode";
temp_owned_graph = build_strategy.Apply( temp_owned_graph = build_strategy.Apply(
std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name, std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
...@@ -358,7 +358,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -358,7 +358,7 @@ ParallelExecutor::ParallelExecutor(
} }
} }
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { if (build_strategy.async_mode_) {
VLOG(3) << "use AsyncSSAGraphExecutor"; VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor( member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, async_graphs)); exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
......
...@@ -14,6 +14,9 @@ limitations under the License. */ ...@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h" #include "paddle/fluid/operators/distributed/communicator.h"
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
...@@ -28,6 +31,7 @@ namespace distributed { ...@@ -28,6 +31,7 @@ namespace distributed {
static inline void MergeVars(const std::string &var_name, static inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars, const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope) { Scope *scope) {
VLOG(3) << "merge " << vars.size() << " vars " << var_name << " to one";
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!"); PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0]; auto &var0 = vars[0];
...@@ -67,14 +71,16 @@ std::unique_ptr<Communicator> Communicator::communicator_(nullptr); ...@@ -67,14 +71,16 @@ std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
std::once_flag Communicator::init_flag_; std::once_flag Communicator::init_flag_;
void Communicator::SendThread() { void Communicator::SendThread() {
VLOG("SendThread start!");
while (running_) { while (running_) {
std::vector<std::future<void>> task_futures; std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size()); task_futures.reserve(send_varname_to_ctx_.size());
for (auto &iter : send_varname_to_queue_) { for (auto &iter : send_varname_to_queue_) {
auto send_task = [this, &iter] {
auto &var_name = iter.first; auto &var_name = iter.first;
VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second; auto &var_queue = iter.second;
if (var_queue->NotEmpty()) { // will block if queue is empty
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << "merge var " << var_name << " and send";
std::vector<std::shared_ptr<Variable>> vars; std::vector<std::shared_ptr<Variable>> vars;
// TODO(qiao): need to be configurable // TODO(qiao): need to be configurable
const size_t max_merge_var_num = 20; const size_t max_merge_var_num = 20;
...@@ -91,6 +97,7 @@ void Communicator::SendThread() { ...@@ -91,6 +97,7 @@ void Communicator::SendThread() {
task_futures.emplace_back( task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task))); send_threadpool_->enqueue(std::move(send_task)));
} }
}
for (auto &task_f : task_futures) { for (auto &task_f : task_futures) {
task_f.wait(); task_f.wait();
} }
...@@ -98,6 +105,7 @@ void Communicator::SendThread() { ...@@ -98,6 +105,7 @@ void Communicator::SendThread() {
} }
void Communicator::RecvThread() { void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!";
while (running_) { while (running_) {
// parallel run recv graph // parallel run recv graph
std::vector<std::future<void>> task_futures; std::vector<std::future<void>> task_futures;
...@@ -115,6 +123,8 @@ void Communicator::RecvThread() { ...@@ -115,6 +123,8 @@ void Communicator::RecvThread() {
for (auto &task : task_futures) { for (auto &task : task_futures) {
task.wait(); task.wait();
} }
// TODO(qiao) need to be configuable
std::this_thread::sleep_for(std::chrono::milliseconds(200));
} }
} }
......
...@@ -68,6 +68,12 @@ class BlockingQueue { ...@@ -68,6 +68,12 @@ class BlockingQueue {
return rc; return rc;
} }
bool NotEmpty() {
std::unique_lock<std::mutex> lock(mutex_);
recv_cv_.wait(lock, [=] { return !queue_.empty(); });
return true;
}
size_t Cap() const { size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return capacity_; return capacity_;
......
...@@ -60,12 +60,14 @@ class VariableResponse { ...@@ -60,12 +60,14 @@ class VariableResponse {
bool create_scope = false) bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) { if (create_scope) {
local_scope_ = scope->NewTmpScope(); local_scope_ = &scope->NewScope();
} }
} }
virtual ~VariableResponse() { virtual ~VariableResponse() {
if (local_scope_) delete local_scope_; if (local_scope_) {
scope_->DeleteScope(local_scope_);
}
} }
int Parse(Source* source, const sendrecv::VariableMessage& meta) { int Parse(Source* source, const sendrecv::VariableMessage& meta) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册