提交 6d752baf 编写于 作者: Y Yancey1989

use get_appropriate_dev to schedule rpc op

上级 4444e79e
......@@ -142,7 +142,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
VLOG(3) << "Building ....";
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars[var->Name()] = var;
......@@ -162,36 +161,32 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
int64_t numel_all = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
PADDLE_ENFORCE_GT(numel, 0);
numel_all += numel;
}
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
balance_grads[dev_id] += numel_all;
return dev_id;
};
bool is_forwarding = true;
int rpc_op_device_id = 0;
auto schedule_rpc_op = [&]() -> void {
rpc_op_device_id++;
if (rpc_op_device_id >= static_cast<int>(places_.size())) {
rpc_op_device_id = 0;
}
};
for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>(
......@@ -200,37 +195,40 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// append rpc op if program is distributed trainer main program.
// always use the first device
if (op->Type() == "send_vars") {
auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
if (got == remote_vars_devices_.end()) {
schedule_rpc_op();
} else {
rpc_op_device_id = got->second;
int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
if (op_dev_id == -1) {
op_dev_id = get_appropriate_dev(op->InputArgumentNames());
for (auto &varname : op->InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
CreateRPCOp(&result, *op, rpc_op_device_id);
CreateRPCOp(&result, *op, op_dev_id);
} else if (op->Type() == "recv") {
schedule_rpc_op();
int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) {
remote_vars_devices_.insert({varname, rpc_op_device_id});
var_name_on_devices_.emplace(varname, op_dev_id);
}
CreateRPCOp(&result, *op, rpc_op_device_id);
CreateRPCOp(&result, *op, op_dev_id);
} else {
// send_barrier and fetch_barrier op would run on device 0
CreateRPCOp(&result, *op, 0);
}
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
if (op->Type() == "split_byref") {
schedule_rpc_op();
int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) {
remote_vars_devices_.insert({varname, rpc_op_device_id});
}
CreateDistTrainOp(&result, *op, rpc_op_device_id);
var_name_on_devices_.emplace(varname, op_dev_id);
}
if (op->Type() == "concat") {
auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
PADDLE_ENFORCE(got != remote_vars_devices_.end(),
CreateDistTrainOp(&result, *op, op_dev_id);
} else if (op->Type() == "concat") {
int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place to concatenate received var.");
CreateDistTrainOp(&result, *op, got->second);
CreateDistTrainOp(&result, *op, op_dev_id);
} else {
CreateDistTrainOp(&result, *op, 0);
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
......@@ -240,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name);
var_name_on_devices_.emplace(var_name, op_dev_id);
}
}
if (!is_forwarding && places_.size() > 1) {
......@@ -269,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name);
cur_device_id = get_appropriate_dev({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name);
var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
......@@ -402,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once;
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break;
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
for (auto &varname : op.InputArgumentNames()) {
int dev_id = GetVarDeviceID(varname);
if (dev_id != -1) {
return dev_id;
}
}
return var_dev_id;
return -1;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
auto got = var_name_on_devices_.find(varname);
return got == var_name_on_devices_.end() ? -1 : got->second;
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
......
......@@ -47,14 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetRemoteVarDeviceId(const std::string &var_name) const override {
auto got = remote_vars_devices_.find(var_name);
if (got != remote_vars_devices_.end()) {
return got->second;
}
return -1;
}
int GetVarDeviceID(const std::string &varname) const;
private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
......@@ -105,9 +98,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
int GetOpDeviceID(const OpDesc &op) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
......@@ -120,7 +111,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
BuildStrategy strategy_;
mutable std::unordered_map<std::string, int> remote_vars_devices_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
......
......@@ -30,9 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetRemoteVarDeviceId(const std::string &var_name) const {
return -1;
}
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
......@@ -161,9 +161,8 @@ void ParallelExecutor::BCastParamsToGPUs(
}
auto &nccl_ctx = member_->nccl_ctxs_->at(place);
if (builder_.get() != nullptr &&
builder_->GetRemoteVarDeviceId(var) != -1) {
int place_id = builder_->GetRemoteVarDeviceId(var);
if (builder_.get() != nullptr && builder_->GetVarDeviceID(var) != -1) {
int place_id = builder_->GetVarDeviceID(var);
platform::dynload::ncclBcast(buffer, numel, data_type, place_id,
nccl_ctx.comm_, nccl_ctx.stream());
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册