提交 d5090c89 编写于 作者: Y Yancey1989

polish code test=develop

上级 0f8bd73c
...@@ -34,7 +34,7 @@ namespace details { ...@@ -34,7 +34,7 @@ namespace details {
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) { static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
// Should fix the allreduce op order if scheduling // Should fix the allreduce op order if scheduling
// them in multiple threads or processes to avoid hang. // them in multiple threads or processes to avoid hang.
// NOTE: ParallelExecutor would execute this pass on each graph, so // NOTE: ParallelGraph would execute this pass on each graph, so
// don't need to append it here. // don't need to append it here.
return (!strategy.enable_sequential_execution_ && return (!strategy.enable_sequential_execution_ &&
strategy.num_trainers_ > 1) && strategy.num_trainers_ > 1) &&
......
...@@ -389,8 +389,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( ...@@ -389,8 +389,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
OpHandleBase *op_handle = nullptr; OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&]( auto append_allreduce_op = [&](
std::vector<Scope *> &scopes, const std::vector<Scope *> &scopes,
std::vector<platform::Place> &places) -> OpHandleBase * { const std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
...@@ -407,13 +407,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( ...@@ -407,13 +407,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
op_handle = append_allreduce_op(local_scopes_, places_); op_handle = append_allreduce_op(local_scopes_, places_);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto p = places_[i]; if (strategy_.enable_parallel_graph_) {
std::vector<Scope *> ss{local_scopes_[i]}; op_handle = append_allreduce_op({local_scopes_[i]}, {places_[i]});
std::vector<platform::Place> ps{p}; }
if (strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(ss, ps);
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, places_[i]);
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
...@@ -421,7 +419,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( ...@@ -421,7 +419,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
auto var = auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, p); vars.size(), i, og, places_[i]);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
......
...@@ -32,8 +32,9 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph( ...@@ -32,8 +32,9 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
g->Set(kGraphDepVars, new GraphDepVars); g->Set(kGraphDepVars, new GraphDepVars);
g->Set(kGraphOps, new GraphOps); g->Set(kGraphOps, new GraphOps);
} }
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : graph->Get<GraphOps>(kGraphOps)) { for (auto &op : op_handles) {
auto &dev_ctx = op->DeviceContext(); auto &dev_ctx = op->DeviceContext();
auto &p = dev_ctx.begin()->first; auto &p = dev_ctx.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册