提交 0ef9edf5 编写于 作者: Y Yu Yang

Stash

上级 5e87cd75
......@@ -229,8 +229,15 @@ class ParallelExecutorPrivate {
// TODO(yy): Move this function somewhere
ncclDataType_t ToNCCLDataType(std::type_index type) {
// FIXME!!
return ncclFloat;
if (type == typeid(float)) { // NOLINT
return ncclFloat;
} else if (type == typeid(double)) { // NOLINT
return ncclDouble;
} else if (type == typeid(int)) { // NOLINT
return ncclInt;
} else {
PADDLE_THROW("Not supported");
}
}
ParallelExecutor::ParallelExecutor(
......@@ -479,30 +486,32 @@ void ParallelExecutor::BCastParamsToGPUs(
ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
auto &dims = main_tensor.dims();
size_t numel = main_tensor.numel();
std::vector<std::pair<void *, ParallelExecutorPrivate::NCCLContext *>>
mems;
mems.emplace_back(const_cast<void *>(main_tensor.data<void>()),
&member_->GetNCCLCtx(member_->main_place_));
for (auto &pair : member_->local_scopes_) {
if (pair.first == member_->main_place_) {
continue;
}
platform::dynload::ncclGroupStart();
for (auto &pair : member_->local_scopes_) {
auto local_scope = pair.second;
auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
t->Resize(dims);
mems.emplace_back(t->mutable_data(pair.first, main_tensor.type()),
&member_->GetNCCLCtx(member_->main_place_));
auto &nccl_ctx = member_->GetNCCLCtx(pair.first);
platform::dynload::ncclBcast(
t->mutable_data(pair.first, main_tensor.type()), numel, data_type,
0, nccl_ctx.comm, nccl_ctx.stream());
}
platform::dynload::ncclGroupEnd();
}
}
// TODO(yy): Invoke ncclBCast here. mems, numel, data_type. The mems[0]
// is the src, rests are dests.
for (auto &pair : member_->local_scopes_) {
member_->GetNCCLCtx(pair.first).ctx_->Wait();
(void)(data_type);
(void)(numel);
}
auto &b = pair.second->FindVar("fc_1.b_0")->Get<framework::LoDTensor>();
framework::LoDTensor cpu;
framework::TensorCopy(b, platform::CPUPlace(), &cpu);
platform::DeviceContextPool::Instance().Get(b.place())->Wait();
LOG(INFO) << *cpu.data<float>();
}
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
......
......@@ -52,7 +52,7 @@ class ParallelExecutor(unittest.TestCase):
adam = fluid.optimizer.Adam()
adam.minimize(loss)
act_places = []
for each in [fluid.CUDAPlace(0)]:
for each in [fluid.CUDAPlace(0), fluid.CUDAPlace(1)]:
p = fluid.core.Place()
p.set_place(each)
act_places.append(p)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册