Commit 93401c98 authored by Yancey1989

overlap rpc op memcpy in distributed training

Parent df87e63b
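As background for the diff below: the scheduling idea is to rotate RPC-related ops (recv, split_byref) across the available places round-robin and to remember which place each received variable landed on, so that later ops and broadcasts touching the same variable reuse that place instead of always using device 0. A minimal, self-contained C++ sketch of just that bookkeeping; names such as num_places and var_to_place are illustrative only, not Paddle APIs:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      const int num_places = 4;                           // e.g. 4 GPUs
      std::unordered_map<std::string, int> var_to_place;  // remote var -> device id
      int rpc_op_device_id = 0;

      // Same rotation as the schedule_rpc_op lambda in the diff below.
      auto schedule_rpc_op = [&]() {
        rpc_op_device_id = (rpc_op_device_id + 1) % num_places;
      };

      // Pretend these variables are produced by "recv" ops, in program order.
      std::vector<std::string> recv_vars = {"w1", "w2", "w3", "w4", "w5"};
      for (const auto &v : recv_vars) {
        schedule_rpc_op();                   // pick the next device round-robin
        var_to_place[v] = rpc_op_device_id;  // remember where the recv landed
      }

      // A later send/concat/broadcast for the same variable reuses the recorded
      // device instead of always defaulting to device 0.
      for (const auto &v : recv_vars) {
        std::cout << v << " -> place " << var_to_place[v] << "\n";
      }
      return 0;
    }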
@@ -191,15 +191,54 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   };
   bool is_forwarding = true;
+  std::unordered_map<std::string, int> rpc_var_device_mapping;
+  int rpc_op_device_id = 0;
+  auto schedule_rpc_op = [&]() -> void {
+    rpc_op_device_id++;
+    if (rpc_op_device_id >= static_cast<int>(places_.size())) {
+      rpc_op_device_id = 0;
+    }
+  };
 
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
             op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
       // append rpc op if program is distributed trainer main program.
       // always use the first device
-      CreateRPCOp(&result, *op);
+      if (op->Type() == "send_vars") {
+        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
+        if (got == remote_vars_devices_.end()) {
+          schedule_rpc_op();
+        } else {
+          rpc_op_device_id = got->second;
+        }
+        CreateRPCOp(&result, *op, rpc_op_device_id);
+      } else if (op->Type() == "recv") {
+        schedule_rpc_op();
+        for (auto &varname : op->OutputArgumentNames()) {
+          remote_vars_devices_.insert({varname, rpc_op_device_id});
+        }
+        CreateRPCOp(&result, *op, rpc_op_device_id);
+      } else {
+        CreateRPCOp(&result, *op, 0);
+      }
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
-      CreateDistTrainOp(&result, *op);
+      if (op->Type() == "split_byref") {
+        schedule_rpc_op();
+        for (auto &varname : op->OutputArgumentNames()) {
+          remote_vars_devices_.insert({varname, rpc_op_device_id});
+        }
+        CreateDistTrainOp(&result, *op, rpc_op_device_id);
+      } else if (op->Type() == "concat") {
+        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
+        PADDLE_ENFORCE(got != remote_vars_devices_.end(),
+                       "can not find right place to concat received var.");
+        CreateDistTrainOp(&result, *op, got->second);
+      } else {
+        CreateDistTrainOp(&result, *op, 0);
+      }
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -464,17 +503,18 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
 }
 
 void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
-                                                const OpDesc &op) const {
-  CreateComputationalOp(result, op, 0);
+                                                const OpDesc &op,
+                                                int place_id) const {
+  CreateComputationalOp(result, op, place_id);
   if (op.Type() == "concat") {
     ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
-                                          const OpDesc &op) const {
-  auto &p = places_[0];
-  auto *s = local_scopes_[0];
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
+                                          int place_id) const {
+  auto &p = places_[place_id];
+  auto *s = local_scopes_[place_id];
   result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
 
   if (op.Type() == "send_barrier") {
@@ -493,7 +533,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   // TODO(Yancey1989): schedule rpc op on different place may
   // increase throughput
-  CreateOpHandleIOs(result, op, 0);
+  CreateOpHandleIOs(result, op, place_id);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
...
@@ -48,6 +48,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
 
+  int GetRemoteVarDevice(const std::string &var_name) const {
+    auto got = remote_vars_devices_.find(var_name);
+    if (got != remote_vars_devices_.end()) {
+      return got->second;
+    }
+    return -1;
+  }
+
  private:
   void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
                          size_t place_id) const;
@@ -64,8 +72,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;
 
-  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
-  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const;
+  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op,
+                         int place_id) const;
 
   /**
    * Is this operator an end-point operator before/after the send operator?
@@ -111,6 +120,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
  private:
   BuildStrategy strategy_;
+  mutable std::unordered_map<std::string, int> remote_vars_devices_;
 };
 
 }  // namespace details
 }  // namespace framework
...
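The GetRemoteVarDevice accessor added above is what the executor changes below consult when choosing an NCCL broadcast root. A hypothetical, self-contained sketch of that lookup and its device-0 fallback (stand-in function, not Paddle's actual API surface):

    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Stand-in for MultiDevSSAGraphBuilder::GetRemoteVarDevice: returns the
    // device a remote variable was received on, or -1 if it is unknown.
    int GetRemoteVarDevice(const std::unordered_map<std::string, int> &devices,
                           const std::string &var) {
      auto it = devices.find(var);
      return it == devices.end() ? -1 : it->second;
    }

    int main() {
      // Assume "fc_0.w_0" was received on device 2 during graph building.
      std::unordered_map<std::string, int> remote_vars_devices = {{"fc_0.w_0", 2}};

      for (const std::string var : {"fc_0.w_0", "fc_1.w_0"}) {
        int device = GetRemoteVarDevice(remote_vars_devices, var);
        int root = (device != -1) ? device : 0;  // fall back to device 0
        std::cout << "broadcast " << var << " from root device " << root << "\n";
      }
      return 0;
    }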
@@ -22,7 +22,6 @@ limitations under the License. */
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
 
-#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -97,15 +96,17 @@ ParallelExecutor::ParallelExecutor(
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
 #ifdef PADDLE_WITH_CUDA
-  details::MultiDevSSAGraphBuilder builder(
-      member_->places_, loss_var_name, params, member_->local_scopes_,
-      member_->nccl_ctxs_.get(), build_strategy);
+  builder_.reset(new details::MultiDevSSAGraphBuilder(
+      member_->places_, loss_var_name, params, member_->local_scopes_,
+      member_->nccl_ctxs_.get(), build_strategy));
 #else
-  details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
-                                           params, member_->local_scopes_,
-                                           build_strategy);
+  builder_.reset(new details::MultiDevSSAGraphBuilder(
+      member_->places_, loss_var_name, params, member_->local_scopes_,
+      build_strategy));
 #endif
-  auto graph = builder.Build(main_program);
+  auto graph = builder_.get()->Build(main_program);
 
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, places, std::move(graph)));
@@ -146,8 +147,16 @@ void ParallelExecutor::BCastParamsToGPUs(
         buffer = t->mutable_data(place, main_tensor.type());
       }
       auto &nccl_ctx = member_->nccl_ctxs_->at(place);
-      platform::dynload::ncclBcast(buffer, numel, data_type, 0,
-                                   nccl_ctx.comm_, nccl_ctx.stream());
+
+      if (builder_.get() != nullptr &&
+          builder_->GetRemoteVarDevice(var) != -1) {
+        int place_id = builder_->GetRemoteVarDevice(var);
+        platform::dynload::ncclBcast(buffer, numel, data_type, place_id,
+                                     nccl_ctx.comm_, nccl_ctx.stream());
+      } else {
+        platform::dynload::ncclBcast(buffer, numel, data_type, 0,
+                                     nccl_ctx.comm_, nccl_ctx.stream());
+      }
     }
   } else {
     platform::CPUPlace cpu;
...
@@ -19,12 +19,14 @@ limitations under the License. */
 #include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/details/execution_strategy.h"
+#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace framework {
@@ -68,6 +70,7 @@ class ParallelExecutor {
  private:
   ParallelExecutorPrivate *member_;
+  std::unique_ptr<details::MultiDevSSAGraphBuilder> builder_;
 };
 
 }  // namespace framework
...