Commit 8f0590e7 authored by Yu Yang

Add ncclAllReduce

Parent commit c15d2c9e
@@ -138,14 +138,6 @@ struct ScaleLossGradOpHandle : public OpHandle {
  }
};
struct NCCLAllReduceOpHandle : public OpHandle {
  void Run() override {
    if (this->inputs_.size() == 1) {
      return;  // No need to all reduce when GPU count = 1;
    }
  }
};
class ParallelExecutorPrivate {
 public:
  explicit ParallelExecutorPrivate(size_t num_threads = 12)
@@ -243,6 +235,46 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
  }
}
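(The hunk above shows only the closing braces of ToNCCLDataType, which maps a framework std::type_index to the matching NCCL enum. Below is a minimal sketch of what such a mapping typically looks like; the exact branches and error handling are assumptions, not the commit's code.)

// Illustrative sketch only: maps a std::type_index to ncclDataType_t.
// The real body is collapsed in this diff; the branches here are assumptions.
#include <nccl.h>
#include <stdexcept>
#include <typeindex>

ncclDataType_t ToNCCLDataType(std::type_index type) {
  if (type == typeid(float)) {
    return ncclFloat;
  } else if (type == typeid(double)) {
    return ncclDouble;
  } else if (type == typeid(int)) {
    return ncclInt;
  } else {
    throw std::runtime_error("Type not supported by NCCL all-reduce");
  }
}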
struct NCCLAllReduceOpHandle : public OpHandle {
  ParallelExecutorPrivate *member_;

  explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
      : member_(member) {}

  void Run() override {
    if (this->inputs_.size() == 1) {
      return;  // No need to all reduce when GPU count = 1;
    } else {
      auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
      int dtype = -1;
      size_t numel = 0;

      // Issue one all-reduce per device inside a NCCL group; ncclGroupStart()
      // pairs with the ncclGroupEnd() below so a single host thread can queue
      // all per-device collectives without blocking on each call.
      ncclGroupStart();
      for (auto &p : member_->places_) {
        int dev_id = boost::get<platform::CUDAPlace>(p).device;
        Scope *s = member_->local_scopes_[p];
        auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
        void *buffer = const_cast<void *>(lod_tensor.data<void>());
        if (dtype == -1) {
          dtype = ToNCCLDataType(lod_tensor.type());
        }
        if (numel == 0) {
          numel = static_cast<size_t>(lod_tensor.numel());
        }
        auto &nccl_ctx = member_->communication_streams_.at(dev_id);
        ncclAllReduce(buffer, buffer, numel,
                      static_cast<ncclDataType_t>(dtype), ncclSum,
                      nccl_ctx.comm, nccl_ctx.stream());
      }
      ncclGroupEnd();
    }
  }
};
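(For readers unfamiliar with the grouped-call pattern used in Run(), here is a small self-contained sketch, independent of the Paddle classes in this diff, of an in-place sum all-reduce across all visible GPUs. The buffer size and float data type are arbitrary choices for illustration.)

// Standalone NCCL sketch (not code from this commit): one in-place
// ncclAllReduce per device, bracketed by ncclGroupStart()/ncclGroupEnd().
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  int ndev = 0;
  cudaGetDeviceCount(&ndev);
  if (ndev < 1) return 0;

  // One communicator per visible GPU, analogous to communication_streams_.
  std::vector<ncclComm_t> comms(ndev);
  std::vector<int> devs(ndev);
  for (int i = 0; i < ndev; ++i) devs[i] = i;
  ncclCommInitAll(comms.data(), ndev, devs.data());

  const size_t numel = 1 << 20;
  std::vector<float *> buffers(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaMalloc(reinterpret_cast<void **>(&buffers[i]), numel * sizeof(float));
    cudaMemset(buffers[i], 0, numel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // In-place sum: each buffer is both send and receive side, mirroring the
  // const_cast'ed LoDTensor data in NCCLAllReduceOpHandle::Run.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    ncclAllReduce(buffers[i], buffers[i], numel, ncclFloat, ncclSum, comms[i],
                  streams[i]);
  }
  ncclGroupEnd();

  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaStreamDestroy(streams[i]);
    cudaFree(buffers[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}

The group bracket is what allows a single host thread to queue one collective per device without deadlocking, which is why Run() wraps its loop the same way.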
ParallelExecutor::ParallelExecutor(
    const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &params,
@@ -361,7 +393,7 @@ void ParallelExecutor::ConstructDependencyGraph(
     for (auto &og : var_names) {
       if (grads.count(og) != 0) {  // is param grad
         // Insert NCCL AllReduce Op
-        member_->ops_.emplace_back(new NCCLAllReduceOpHandle());
+        member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
         auto *op_handle = member_->ops_.back().get();
         for (auto &pair : member_->local_scopes_) {
...