提交 2d94697a 编写于 作者: C chengduoZH

code refine

上级 5a3c8bf8
......@@ -297,21 +297,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_.at(i).at(p_name);
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_.at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
}
}
// Attaches the device context for place `p` to a communication op handle.
//
// op_handle: the op handle whose per-place device context is being set.
// p:         the place whose context is looked up in the global
//            platform::DeviceContextPool.
//
// With PADDLE_WITH_CUDA defined, the pooled context is installed only when
// nccl_ctxs_ is null; when nccl_ctxs_ is non-null this function does nothing
// (NOTE(review): presumably the op handle then uses the NCCL-provided
// contexts -- confirm against the op handle implementations).
// Without PADDLE_WITH_CUDA, the pooled context is always installed.
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
  if (nccl_ctxs_ == nullptr) {
    op_handle->SetDeviceContext(p,
                                platform::DeviceContextPool::Instance().Get(p));
  }
#else
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
#endif
}
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
......@@ -334,24 +339,12 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_ == nullptr) {
op_handle->SetDeviceContext(
p, platform::DeviceContextPool::Instance().Get(p));
}
#else
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
VLOG(4) << "NCCL - - - " << p;
op_handle->DeviceContext(p)->Wait();
VLOG(4) << "NCCL - - - " << p << " " << op_handle->DeviceContext(p);
auto var = new VarHandle(vars.size() - 1, i, og, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
......@@ -441,17 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
auto *op_handle = result->ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_[i][og];
auto &p = places_[i];
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_ == nullptr) {
op_handle->SetDeviceContext(
p, platform::DeviceContextPool::Instance().Get(p));
}
#else
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
......
......@@ -111,6 +111,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
BuildStrategy strategy_;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
};
} // namespace details
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册