提交 6d752baf 编写于 作者: Y Yancey1989

use get_appropriate_dev to schedule rpc op

上级 4444e79e
...@@ -142,7 +142,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -142,7 +142,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
VLOG(3) << "Building ....";
std::unordered_map<std::string, VarDesc *> all_vars; std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
all_vars[var->Name()] = var; all_vars[var->Name()] = var;
...@@ -162,36 +161,32 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -162,36 +161,32 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(program);
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0; size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0); std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t { auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
auto var_desc = all_vars.at(g_name); int64_t numel_all = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc); PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape()); auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim); int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0); PADDLE_ENFORCE_GT(numel, 0);
numel_all += numel;
}
auto smallest = auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads)); std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id = size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest)); static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel; balance_grads[dev_id] += numel_all;
return dev_id; return dev_id;
}; };
bool is_forwarding = true; bool is_forwarding = true;
int rpc_op_device_id = 0;
auto schedule_rpc_op = [&]() -> void {
rpc_op_device_id++;
if (rpc_op_device_id >= static_cast<int>(places_.size())) {
rpc_op_device_id = 0;
}
};
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>( if (boost::get<int>(
...@@ -200,37 +195,40 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -200,37 +195,40 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// append rpc op if program is distributed trainer main program. // append rpc op if program is distributed trainer main program.
// always use the first device // always use the first device
if (op->Type() == "send_vars") { if (op->Type() == "send_vars") {
auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]); int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
if (got == remote_vars_devices_.end()) { if (op_dev_id == -1) {
schedule_rpc_op(); op_dev_id = get_appropriate_dev(op->InputArgumentNames());
} else { for (auto &varname : op->InputArgumentNames()) {
rpc_op_device_id = got->second; var_name_on_devices_.emplace(varname, op_dev_id);
}
} }
CreateRPCOp(&result, *op, rpc_op_device_id); CreateRPCOp(&result, *op, op_dev_id);
} else if (op->Type() == "recv") { } else if (op->Type() == "recv") {
schedule_rpc_op(); int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) { for (auto &varname : op->OutputArgumentNames()) {
remote_vars_devices_.insert({varname, rpc_op_device_id}); var_name_on_devices_.emplace(varname, op_dev_id);
} }
CreateRPCOp(&result, *op, rpc_op_device_id); CreateRPCOp(&result, *op, op_dev_id);
} else { } else {
// send_barrier and fetch_barrier op would run on device 0
CreateRPCOp(&result, *op, 0); CreateRPCOp(&result, *op, 0);
} }
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) { } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
if (op->Type() == "split_byref") { if (op->Type() == "split_byref") {
schedule_rpc_op(); int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) { for (auto &varname : op->OutputArgumentNames()) {
remote_vars_devices_.insert({varname, rpc_op_device_id}); var_name_on_devices_.emplace(varname, op_dev_id);
}
CreateDistTrainOp(&result, *op, rpc_op_device_id);
} }
if (op->Type() == "concat") { CreateDistTrainOp(&result, *op, op_dev_id);
auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]); } else if (op->Type() == "concat") {
PADDLE_ENFORCE(got != remote_vars_devices_.end(), int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place to concatenate received var."); "can not find right place to concatenate received var.");
CreateDistTrainOp(&result, *op, got->second); CreateDistTrainOp(&result, *op, op_dev_id);
} else { } else {
CreateDistTrainOp(&result, *op, 0); PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
} }
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
...@@ -240,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -240,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
} else { } else {
CreateComputationalOp(&result, *op, op_dev_id); CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) { for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name); var_name_on_devices_.emplace(var_name, op_dev_id);
} }
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
...@@ -269,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -269,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name); cur_device_id = get_appropriate_dev({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name); var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
...@@ -402,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -402,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID( int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
int var_dev_id = -1; for (auto &varname : op.InputArgumentNames()) {
for (auto &var_name : op.InputArgumentNames()) { int dev_id = GetVarDeviceID(varname);
if (var_dev_id != -1) break; if (dev_id != -1) {
for (size_t i = 0; i < var_name_on_devices.size(); ++i) { return dev_id;
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
} }
} }
return var_dev_id; return -1;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
auto got = var_name_on_devices_.find(varname);
return got == var_name_on_devices_.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
......
...@@ -47,14 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -47,14 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const;
int GetRemoteVarDeviceId(const std::string &var_name) const override {
auto got = remote_vars_devices_.find(var_name);
if (got != remote_vars_devices_.end()) {
return got->second;
}
return -1;
}
private: private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
...@@ -105,9 +98,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -105,9 +98,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID( int GetOpDeviceID(const OpDesc &op) const;
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
...@@ -120,7 +111,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -120,7 +111,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
mutable std::unordered_map<std::string, int> remote_vars_devices_; mutable std::unordered_map<std::string, int> var_name_on_devices_;
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const; const platform::Place &p) const;
......
...@@ -30,9 +30,7 @@ class SSAGraphBuilder { ...@@ -30,9 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetRemoteVarDeviceId(const std::string &var_name) const { virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
return -1;
}
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
...@@ -161,9 +161,8 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -161,9 +161,8 @@ void ParallelExecutor::BCastParamsToGPUs(
} }
auto &nccl_ctx = member_->nccl_ctxs_->at(place); auto &nccl_ctx = member_->nccl_ctxs_->at(place);
if (builder_.get() != nullptr && if (builder_.get() != nullptr && builder_->GetVarDeviceID(var) != -1) {
builder_->GetRemoteVarDeviceId(var) != -1) { int place_id = builder_->GetVarDeviceID(var);
int place_id = builder_->GetRemoteVarDeviceId(var);
platform::dynload::ncclBcast(buffer, numel, data_type, place_id, platform::dynload::ncclBcast(buffer, numel, data_type, place_id,
nccl_ctx.comm_, nccl_ctx.stream()); nccl_ctx.comm_, nccl_ctx.stream());
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册