Commit 0ef9edf5 authored by Yu Yang

Stash

Parent 5e87cd75
@@ -229,8 +229,15 @@ class ParallelExecutorPrivate {
 // TODO(yy): Move this function somewhere
 ncclDataType_t ToNCCLDataType(std::type_index type) {
-  // FIXME!!
-  return ncclFloat;
+  if (type == typeid(float)) {  // NOLINT
+    return ncclFloat;
+  } else if (type == typeid(double)) {  // NOLINT
+    return ncclDouble;
+  } else if (type == typeid(int)) {  // NOLINT
+    return ncclInt;
+  } else {
+    PADDLE_THROW("Not supported");
+  }
 }
 
 ParallelExecutor::ParallelExecutor(
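As a quick sanity check outside Paddle, the same std::type_index dispatch can be reproduced standalone. This is only an illustrative sketch, assuming nccl.h is on the include path and substituting a plain std::runtime_error for PADDLE_THROW; the helper name is hypothetical:

// Standalone sketch of the type_index -> ncclDataType_t dispatch above.
// Assumptions: NCCL headers available; PADDLE_THROW replaced by std::runtime_error.
#include <nccl.h>
#include <stdexcept>
#include <typeindex>

ncclDataType_t ToNCCLDataTypeSketch(std::type_index type) {
  if (type == typeid(float)) return ncclFloat;
  if (type == typeid(double)) return ncclDouble;
  if (type == typeid(int)) return ncclInt;
  throw std::runtime_error("Not supported");
}

int main() {
  // std::type_index is implicitly constructible from std::type_info,
  // so typeid(float) can be passed directly at the call site.
  return ToNCCLDataTypeSketch(typeid(float)) == ncclFloat ? 0 : 1;
}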
@@ -479,30 +486,32 @@ void ParallelExecutor::BCastParamsToGPUs(
       ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
       auto &dims = main_tensor.dims();
       size_t numel = main_tensor.numel();
-      std::vector<std::pair<void *, ParallelExecutorPrivate::NCCLContext *>>
-          mems;
-      mems.emplace_back(const_cast<void *>(main_tensor.data<void>()),
-                        &member_->GetNCCLCtx(member_->main_place_));
-      for (auto &pair : member_->local_scopes_) {
-        if (pair.first == member_->main_place_) {
-          continue;
-        }
+      platform::dynload::ncclGroupStart();
+      for (auto &pair : member_->local_scopes_) {
         auto local_scope = pair.second;
         auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
         t->Resize(dims);
-        mems.emplace_back(t->mutable_data(pair.first, main_tensor.type()),
-                          &member_->GetNCCLCtx(member_->main_place_));
+        auto &nccl_ctx = member_->GetNCCLCtx(pair.first);
+        platform::dynload::ncclBcast(
+            t->mutable_data(pair.first, main_tensor.type()), numel, data_type,
+            0, nccl_ctx.comm, nccl_ctx.stream());
       }
-      // TODO(yy): Invoke ncclBCast here. mems, numel, data_type. The mems[0]
-      // is the src, rests are dests.
-      (void)(data_type);
-      (void)(numel);
+      platform::dynload::ncclGroupEnd();
     }
   }
+  for (auto &pair : member_->local_scopes_) {
+    member_->GetNCCLCtx(pair.first).ctx_->Wait();
+    auto &b = pair.second->FindVar("fc_1.b_0")->Get<framework::LoDTensor>();
+    framework::LoDTensor cpu;
+    framework::TensorCopy(b, platform::CPUPlace(), &cpu);
+    platform::DeviceContextPool::Instance().Get(b.place())->Wait();
+    LOG(INFO) << *cpu.data<float>();
+  }
 #else
   PADDLE_THROW("Not compiled with CUDA");
 #endif
...
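The core of this change is the ncclGroupStart()/ncclBcast()/ncclGroupEnd() pattern: one ncclBcast call is issued per device inside a group, so a single thread can broadcast each parameter from the root device to every other device. Below is a minimal standalone sketch of that pattern (not Paddle code; it assumes NCCL 2.x and the CUDA runtime, and the device count, buffer size, and stream handling are placeholder assumptions):

// Minimal sketch: broadcast a float buffer from device 0 to all devices
// using grouped ncclBcast calls driven by one thread.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  int n_gpus = 0;
  cudaGetDeviceCount(&n_gpus);

  // One communicator per device, all in a single clique.
  std::vector<ncclComm_t> comms(n_gpus);
  std::vector<int> devs(n_gpus);
  for (int i = 0; i < n_gpus; ++i) devs[i] = i;
  ncclCommInitAll(comms.data(), n_gpus, devs.data());

  const size_t numel = 1024;  // placeholder parameter size
  std::vector<float*> bufs(n_gpus);
  std::vector<cudaStream_t> streams(n_gpus);
  for (int i = 0; i < n_gpus; ++i) {
    cudaSetDevice(i);
    cudaMalloc(&bufs[i], numel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }
  // In real use, bufs[0] on the root device already holds the parameter value.
  cudaSetDevice(0);
  cudaMemset(bufs[0], 0, numel * sizeof(float));

  // Group the per-device ncclBcast calls; rank 0 is the source, the rest receive.
  ncclGroupStart();
  for (int i = 0; i < n_gpus; ++i) {
    ncclBcast(bufs[i], numel, ncclFloat, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();

  // Wait for the broadcast to finish on every device.
  for (int i = 0; i < n_gpus; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
  }

  for (int i = 0; i < n_gpus; ++i) {
    ncclCommDestroy(comms[i]);
    cudaSetDevice(i);
    cudaFree(bufs[i]);
    cudaStreamDestroy(streams[i]);
  }
  return 0;
}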
@@ -52,7 +52,7 @@ class ParallelExecutor(unittest.TestCase):
         adam = fluid.optimizer.Adam()
         adam.minimize(loss)
         act_places = []
-        for each in [fluid.CUDAPlace(0)]:
+        for each in [fluid.CUDAPlace(0), fluid.CUDAPlace(1)]:
             p = fluid.core.Place()
             p.set_place(each)
             act_places.append(p)
...