提交 d5090c89 编写于 作者: Y Yancey1989

polish code test=develop

上级 0f8bd73c
......@@ -34,7 +34,7 @@ namespace details {
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
// Should fix the allreduce op order if scheduling
// them in multiple threads or processes to avoid hang.
// NOTE: ParallelExecutor would execute this pass on each graph, so
// NOTE: ParallelGraph would execute this pass on each graph, so
// don't need to append it here.
return (!strategy.enable_sequential_execution_ &&
strategy.num_trainers_ > 1) &&
......
......@@ -389,8 +389,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&](
std::vector<Scope *> &scopes,
std::vector<platform::Place> &places) -> OpHandleBase * {
const std::vector<Scope *> &scopes,
const std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
......@@ -407,13 +407,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
op_handle = append_allreduce_op(local_scopes_, places_);
for (size_t i = 0; i < places_.size(); ++i) {
auto p = places_[i];
std::vector<Scope *> ss{local_scopes_[i]};
std::vector<platform::Place> ps{p};
if (strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(ss, ps);
if (strategy_.enable_parallel_graph_) {
op_handle = append_allreduce_op({local_scopes_[i]}, {places_[i]});
}
SetCommunicationContext(op_handle, p);
SetCommunicationContext(op_handle, places_[i]);
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
......@@ -421,7 +419,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, p);
vars.size(), i, og, places_[i]);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
......
......@@ -32,8 +32,9 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
g->Set(kGraphDepVars, new GraphDepVars);
g->Set(kGraphOps, new GraphOps);
}
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
for (auto &op : op_handles) {
auto &dev_ctx = op->DeviceContext();
auto &p = dev_ctx.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(p).device;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册