Commit 93401c98 authored by Yancey1989

overlap rpc op memcpy in distributed training

Parent df87e63b
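The change below schedules RPC-related ops round-robin across the available devices and records which device holds each received variable, so downstream ops (e.g. concat) and the initial parameter broadcast can reuse that placement instead of always running on device 0. A minimal standalone sketch of that bookkeeping follows; the class and member names (RpcPlacement, Record) are illustrative only and are not part of the Paddle API.

```cpp
// Standalone sketch (not Paddle code) of the round-robin placement idea:
// "send_vars"/"recv" RPC ops rotate across devices, and the device chosen for
// each received variable is remembered so later ops and the parameter
// broadcast can run on the same device and avoid an extra memcpy.
#include <iostream>
#include <string>
#include <unordered_map>

class RpcPlacement {
 public:
  explicit RpcPlacement(int num_places) : num_places_(num_places) {}

  // Round-robin over the available places, like the schedule_rpc_op lambda.
  int ScheduleRpcOp() {
    rpc_op_device_id_ = (rpc_op_device_id_ + 1) % num_places_;
    return rpc_op_device_id_;
  }

  // Remember which device a remote (received) variable was placed on.
  void Record(const std::string &var, int place_id) {
    remote_vars_devices_[var] = place_id;
  }

  // -1 means "unknown", mirroring GetRemoteVarDevice in the builder.
  int GetRemoteVarDevice(const std::string &var) const {
    auto it = remote_vars_devices_.find(var);
    return it == remote_vars_devices_.end() ? -1 : it->second;
  }

 private:
  int num_places_;
  int rpc_op_device_id_ = 0;
  std::unordered_map<std::string, int> remote_vars_devices_;
};

int main() {
  RpcPlacement placement(/*num_places=*/2);
  // A "recv" op is scheduled round-robin and its output var is recorded.
  int dev = placement.ScheduleRpcOp();
  placement.Record("recv_var_0", dev);
  // A later "concat" (or the ncclBcast root selection) reuses that device.
  std::cout << placement.GetRemoteVarDevice("recv_var_0") << "\n";  // prints 1
  return 0;
}
```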
......@@ -191,15 +191,54 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
};
bool is_forwarding = true;
std::unordered_map<std::string, int> rpc_var_device_mapping;
int rpc_op_device_id = 0;
auto schedule_rpc_op = [&]() -> void {
rpc_op_device_id++;
if (rpc_op_device_id >= static_cast<int>(places_.size())) {
rpc_op_device_id = 0;
}
};
for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp(&result, *op);
if (op->Type() == "send_vars") {
auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
if (got == remote_vars_devices_.end()) {
schedule_rpc_op();
} else {
rpc_op_device_id = got->second;
}
CreateRPCOp(&result, *op, rpc_op_device_id);
} else if (op->Type() == "recv") {
schedule_rpc_op();
for (auto &varname : op->OutputArgumentNames()) {
remote_vars_devices_.insert({varname, rpc_op_device_id});
}
CreateRPCOp(&result, *op, rpc_op_device_id);
} else {
CreateRPCOp(&result, *op, 0);
}
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op);
if (op->Type() == "split_byref") {
schedule_rpc_op();
for (auto &varname : op->OutputArgumentNames()) {
remote_vars_devices_.insert({varname, rpc_op_device_id});
}
CreateDistTrainOp(&result, *op, rpc_op_device_id);
} else if (op->Type() == "concat") {
auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
PADDLE_ENFORCE(got != remote_vars_devices_.end(),
"can not find the right place to concat the received var.");
CreateDistTrainOp(&result, *op, got->second);
} else {
CreateDistTrainOp(&result, *op, 0);
}
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
......@@ -464,17 +503,18 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
}
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
const OpDesc &op,
int place_id) const {
CreateComputationalOp(result, op, place_id);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
auto &p = places_[0];
auto *s = local_scopes_[0];
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
int place_id) const {
auto &p = places_[place_id];
auto *s = local_scopes_[place_id];
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
if (op.Type() == "send_barrier") {
......@@ -493,7 +533,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
// TODO(Yancey1989): scheduling rpc ops on different places may
// increase throughput
CreateOpHandleIOs(result, op, 0);
CreateOpHandleIOs(result, op, place_id);
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
......
......@@ -48,6 +48,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetRemoteVarDevice(const std::string &var_name) const {
auto got = remote_vars_devices_.find(var_name);
if (got != remote_vars_devices_.end()) {
return got->second;
}
return -1;
}
private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t place_id) const;
......@@ -64,8 +72,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op,
int place_id) const;
/**
* Is this operator the end-point operator before/after the send operator?
......@@ -111,6 +120,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
BuildStrategy strategy_;
mutable std::unordered_map<std::string, int> remote_vars_devices_;
};
} // namespace details
} // namespace framework
......
......@@ -22,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -97,15 +96,17 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder(
builder_.reset(new details::MultiDevSSAGraphBuilder(
member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get(), build_strategy);
member_->nccl_ctxs_.get(), build_strategy));
#else
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_,
build_strategy);
builder_.reset(new details::MultiDevSSAGraphBuilder(
member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy));
#endif
auto graph = builder.Build(main_program);
auto graph = builder_->Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
......@@ -146,9 +147,17 @@ void ParallelExecutor::BCastParamsToGPUs(
buffer = t->mutable_data(place, main_tensor.type());
}
auto &nccl_ctx = member_->nccl_ctxs_->at(place);
if (builder_.get() != nullptr &&
builder_->GetRemoteVarDevice(var) != -1) {
int place_id = builder_->GetRemoteVarDevice(var);
platform::dynload::ncclBcast(buffer, numel, data_type, place_id,
nccl_ctx.comm_, nccl_ctx.stream());
} else {
platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
}
}
} else {
platform::CPUPlace cpu;
for (size_t i = 1; i < member_->places_.size(); ++i) {
......
......@@ -19,12 +19,14 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
......@@ -68,6 +70,7 @@ class ParallelExecutor {
private:
ParallelExecutorPrivate *member_;
std::unique_ptr<details::MultiDevSSAGraphBuilder> builder_;
};
} // namespace framework
......