提交 f52d78d1 编写于 作者: Y Yancey1989

update by comment

上级 6d752baf
......@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
balance_vars_.resize(places_.size(), 0);
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
......@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker(op.InputArgumentNames(), recv_vars);
}
size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}
auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars[var->Name()] = var;
all_vars_.emplace(var->Name(), var);
}
auto graph = new SSAGraph();
......@@ -165,71 +185,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
int64_t numel_all = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_all += numel;
}
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel_all;
return dev_id;
};
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
if (op->Type() == "send_vars") {
int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
if (op_dev_id == -1) {
op_dev_id = get_appropriate_dev(op->InputArgumentNames());
for (auto &varname : op->InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
CreateRPCOp(&result, *op, op_dev_id);
} else if (op->Type() == "recv") {
int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
CreateRPCOp(&result, *op, op_dev_id);
} else {
// send_barrier and fetch_barrier op would run on device 0
CreateRPCOp(&result, *op, 0);
}
CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
if (op->Type() == "split_byref") {
int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
for (auto &varname : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
CreateDistTrainOp(&result, *op, op_dev_id);
} else if (op->Type() == "concat") {
int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place to concatenate received var.");
CreateDistTrainOp(&result, *op, op_dev_id);
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}
CreateDistTrainOp(&result, *op);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
......@@ -267,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev({g_name});
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(all_vars, g_name)) {
if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
......@@ -310,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph);
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const {
PADDLE_ENFORCE(all_vars.count(og) != 0);
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true;
}
return false;
......@@ -498,18 +460,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
}
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op,
int place_id) const {
CreateComputationalOp(result, op, place_id);
const OpDesc &op) const {
int op_dev_id = -1;
if (op.Type() == "split_byref") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}
PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type());
CreateComputationalOp(result, op, op_dev_id);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
int device_id) const {
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[device_id],
op.Type(), places_[device_id]));
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
int op_dev_id = -1;
if (op.Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
// the variable name which contains .block means it was splited by
// split_byref op
// so that we can balance the variable blocks to all the pserver instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
} else if (op.Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
// send_barrier and fetch_barrier op can be scheduled on device 0
op_dev_id = 0;
}
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());
result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
op.Type(), places_[op_dev_id]));
if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send");
......@@ -525,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
"send, send_barrier. recv, fetch_barrier]");
}
// TODO(Yancey1989): schedule rpc op on different place may
// increate throughput
CreateOpHandleIOs(result, op, device_id);
CreateOpHandleIOs(result, op, op_dev_id);
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
......
......@@ -65,9 +65,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op,
int place_id) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
/**
* Is this operator as the end-point operator before/after send operator.
......@@ -105,13 +104,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;
bool IsSparseGradient(
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const;
bool IsSparseGradient(const std::string &og) const;
size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;
private:
BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册