提交 1fba0c57 编写于 作者: T typhoonzero

fix multi gpu dist train

上级 6c0356e4
...@@ -77,6 +77,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, ...@@ -77,6 +77,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
} }
} }
bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
OpDesc *send_op) const {
if (send_op == nullptr) {
return false;
}
auto checker = [&](const std::vector<std::string> opvars,
const std::vector<std::string> sendvars) -> bool {
bool is_dist_train_op = false;
for (auto &var : opvars) {
if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
is_dist_train_op = true;
break;
}
}
return is_dist_train_op;
};
if (op.Type() == "split") {
return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
}
return false;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
auto graph = new SSAGraph(); auto graph = new SSAGraph();
...@@ -88,17 +115,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -88,17 +115,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size()); places_.size());
// Find "send" op first for split is in front of send.
OpDesc *send_op = nullptr;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
send_op = op;
break;
}
}
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") { if (op->Type() == "send") {
// append send op if program is distributed trainer main program. // append send op if program is distributed trainer main program.
// always use the first device // always use the first device
CreateSendOp(&result, *op); CreateSendOp(&result, *op);
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
is_forwarding = false; is_forwarding = false;
} else { } else {
CreateComputationalOps(&result, *op); CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding) { if (!is_forwarding) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. But there are no // broadcast, and each gradient is only broadcast once. But there are no
...@@ -196,8 +234,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { ...@@ -196,8 +234,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
} }
void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op,
for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) { size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
......
...@@ -62,7 +62,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -62,7 +62,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const; bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const; void CreateScaleLossGradOp(SSAGraph *result) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册