Commit 79be0604 authored by Yu Yang

Support CPU/GPU mixture for ParallelExecutor

Parent 7083c2a6
......@@ -116,6 +116,19 @@ void NCCLAllReduceOpHandle::RunImpl() {
      // Reduce All Tensor to trg in CPU
      ReduceLoDTensor func(lod_tensors, &trg);
      VisitDataType(ToDataType(lod_tensors[0].type()), func);

      // Scatter the CPU-reduced result back to every participating place.
      for (size_t i = 0; i < local_scopes_.size(); ++i) {
        auto &scope = local_scopes_[i];
        auto &p = places_[i];
        auto *var = scope->FindVar(var_name);
        auto *dev_ctx = dev_ctxes_[p];

        RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
          auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
          auto &tensor_cpu = trg;
          TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
        });
      }
    }
  }
}
......
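In the hunk above the gradient tensors live on the CPU, so they are reduced once into the host tensor trg and the result is then copied out to every participating place through the new per-place RunAndRecordEvent overload. The sketch below is a framework-agnostic illustration of that reduce-then-scatter pattern, not Paddle code: Device and run_and_record are hypothetical stand-ins for Paddle's device contexts and event-recording helper.

#include <functional>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a per-device context; a real implementation
// would enqueue a host-to-device memcpy on the device's stream.
struct Device {
  int id;
  void CopyFromHost(const std::vector<float> &src,
                    std::vector<float> *dst) const {
    *dst = src;  // simulate the host-to-device copy
  }
};

int main() {
  // Per-device partial gradients, already gathered on the host.
  std::vector<std::vector<float>> partials = {{1, 2}, {3, 4}, {5, 6}};
  std::vector<Device> devices = {{0}, {1}, {2}};
  std::vector<std::vector<float>> dev_buffers(devices.size());

  // Step 1: reduce all partials into one CPU buffer ("trg" in the diff).
  std::vector<float> trg(partials[0].size(), 0.f);
  for (auto &t : partials)
    for (size_t j = 0; j < t.size(); ++j) trg[j] += t[j];

  // Step 2: scatter the reduced result back, one task per place.
  for (size_t i = 0; i < devices.size(); ++i) {
    auto run_and_record = [](const std::function<void()> &cb) { cb(); };
    run_and_record([&, i] { devices[i].CopyFromHost(trg, &dev_buffers[i]); });
  }

  std::cout << "reduced[0] on device 0: " << dev_buffers[0][0] << "\n";  // 9
  return 0;
}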
......@@ -107,6 +107,22 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#endif
}

void OpHandleBase::RunAndRecordEvent(platform::Place p,
                                     const std::function<void()> &callback) {
  if (platform::is_cpu_place(p) || events_.empty()) {
    // CPU place, or no events to manage: run the callback synchronously.
    callback();
  } else {
#ifdef PADDLE_WITH_CUDA
    // GPU place: run the callback on that place's CUDA context and record
    // the per-device event so downstream op handles can wait on it.
    auto *ctx = dev_ctxes_.at(p);
    auto *cuda_ctx = static_cast<platform::CUDADeviceContext *>(ctx);
    cuda_ctx->RecordEvent(events_.at(boost::get<platform::CUDAPlace>(p).device),
                          callback);
#else
    PADDLE_THROW("Not implemented");
#endif
  }
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
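The new overload dispatches on the place: for a CPU place (or when the handle manages no events) the callback simply runs inline, while for a GPU place it runs through that place's CUDADeviceContext, which also records the per-device event so later handles can synchronize on it. Below is a minimal, compilable sketch of that dispatch rule only; Place, FakeEvent and RunAndRecordEventSketch are hypothetical stand-ins, not the real Paddle types.

#include <functional>
#include <iostream>
#include <map>

struct Place { bool is_gpu; int device; };     // stand-in for platform::Place
struct FakeEvent { bool recorded = false; };   // stand-in for a CUDA event

void RunAndRecordEventSketch(const Place &p, std::map<int, FakeEvent> &events,
                             const std::function<void()> &callback) {
  if (!p.is_gpu || events.empty()) {
    callback();  // CPU path: just run, nothing to record.
  } else {
    callback();                           // GPU path: run the work…
    events.at(p.device).recorded = true;  // …then record that device's event.
  }
}

int main() {
  std::map<int, FakeEvent> events;
  events[0] = FakeEvent{};
  events[1] = FakeEvent{};
  RunAndRecordEventSketch({false, -1}, events, [] { std::cout << "cpu op\n"; });
  RunAndRecordEventSketch({true, 1}, events, [] { std::cout << "gpu op\n"; });
  std::cout << std::boolalpha << events.at(1).recorded << "\n";  // true
  return 0;
}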
......@@ -64,6 +64,9 @@ class OpHandleBase {
 protected:
  void RunAndRecordEvent(const std::function<void()> &callback);

  void RunAndRecordEvent(platform::Place p,
                         const std::function<void()> &callback);

  virtual void RunImpl() = 0;
};
......
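The header hunk exposes the per-place overload as a protected helper, so any op handle subclass can wrap its device-specific work in it from RunImpl(). A hedged usage sketch follows, built on simplified stand-in types; OpHandleBaseSketch, BroadcastHandleSketch and Place are hypothetical, not the real Paddle classes.

#include <functional>
#include <iostream>
#include <utility>
#include <vector>

struct Place { bool is_gpu; int device; };

// Stand-in for OpHandleBase, keeping only the pieces relevant here.
class OpHandleBaseSketch {
 public:
  virtual ~OpHandleBaseSketch() = default;
  void Run() { RunImpl(); }

 protected:
  // Simplified: the real overload also records a CUDA event on GPU places.
  void RunAndRecordEvent(const Place & /*p*/,
                         const std::function<void()> &cb) { cb(); }
  virtual void RunImpl() = 0;
};

// Hypothetical derived handle: RunImpl() wraps per-place work in the helper.
class BroadcastHandleSketch : public OpHandleBaseSketch {
 public:
  explicit BroadcastHandleSketch(std::vector<Place> places)
      : places_(std::move(places)) {}

 protected:
  void RunImpl() override {
    for (auto &p : places_) {
      RunAndRecordEvent(p, [&p] {
        std::cout << (p.is_gpu ? "GPU place " : "CPU place ") << p.device
                  << "\n";
      });
    }
  }

 private:
  std::vector<Place> places_;
};

int main() {
  BroadcastHandleSketch h({{false, -1}, {true, 0}, {true, 1}});
  h.Run();
  return 0;
}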