“24509f4af942bb250564756ad636691c7921e1df”上不存在“paddle/legacy/gserver/layers/AddtoLayer.h”
提交 fe7ed285 编写于 作者: Y Yu Yang

Extract NCCLCtxMap

上级 6ebc6bf5
...@@ -88,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo ...@@ -88,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method) framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool var_handle) framework_proto backward glog lod_rank_table simple_threadpool var_handle op_handle_base)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
cc_library(var_handle SRCS var_handle.cc DEPS place) cc_library(var_handle SRCS var_handle.cc DEPS place)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
namespace details {
std::string OpHandleBase::DebugString() const {
std::stringstream ss;
ss << "(";
for (auto *var : inputs_) {
ss << var->DebugString() << ", ";
}
ss << ") --> (";
for (auto *var : outputs_) {
ss << var->DebugString() << ", ";
}
ss << ")\n";
return ss.str();
}
OpHandleBase::~OpHandleBase() {}
void OpHandleBase::Run(bool use_event) {
#ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_event) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
cudaSetDevice(dev_id);
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
}
}
#else
PADDLE_ENFORCE(!use_event);
#endif
RunImpl();
#ifdef PADDLE_WITH_CUDA
if (use_event) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
cudaEventRecord(events_.at(dev_id), stream);
}
}
#endif
}
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctx_) {
dev_ctx.second->Wait();
}
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
}
#else
for (auto &dev_ctx : dev_ctx_) {
dev_ctx.second->Wait();
}
#endif
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct OpHandleBase {
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash>
dev_ctx_;
#ifdef PADDLE_WITH_CUDA
std::unordered_map<int, cudaEvent_t> events_;
#endif
std::string DebugString() const;
virtual ~OpHandleBase();
void Run(bool use_event);
virtual void Wait(platform::DeviceContext *waited_dev);
protected:
virtual void RunImpl() = 0;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -21,10 +21,8 @@ ...@@ -21,10 +21,8 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct OpHandleBase;
namespace details { namespace details {
struct OpHandleBase;
// VarHandleBase is the var node in the dependency graph. // VarHandleBase is the var node in the dependency graph.
// A variable can only be generated by a single operator. i.e. // A variable can only be generated by a single operator. i.e.
......
...@@ -14,86 +14,22 @@ limitations under the License. */ ...@@ -14,86 +14,22 @@ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include "ThreadPool.h" #include "ThreadPool.h"
#include "executor.h"
#include "lod_tensor.h" #include "lod_tensor.h"
#include "lod_tensor_array.h" #include "lod_tensor_array.h"
#include "op_registry.h" #include "op_registry.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using details::DummyVarHandle; using details::DummyVarHandle;
using details::OpHandleBase;
using details::VarHandle; using details::VarHandle;
using details::VarHandleBase; using details::VarHandleBase;
struct OpHandleBase {
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash>
dev_ctx_;
std::unordered_map<int, cudaEvent_t> events_;
std::string DebugString() {
std::stringstream ss;
ss << "(";
for (auto *var : inputs_) {
ss << var->DebugString() << ", ";
}
ss << ") --> (";
for (auto *var : outputs_) {
ss << var->DebugString() << ", ";
}
ss << ")\n";
return ss.str();
}
virtual ~OpHandleBase() {}
void Run(bool use_event) {
if (events_.empty() && use_event) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
cudaSetDevice(dev_id);
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
}
}
RunImpl();
if (use_event) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
cudaEventRecord(events_.at(dev_id), stream);
}
}
}
virtual void Wait(platform::DeviceContext *waited_dev) {
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctx_) {
dev_ctx.second->Wait();
}
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
}
}
protected:
virtual void RunImpl() = 0;
};
struct ScaleLossGradOpHandle : public OpHandleBase { struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_; float coeff_;
Scope *scope_; Scope *scope_;
...@@ -193,12 +129,7 @@ class ParallelExecutorPrivate { ...@@ -193,12 +129,7 @@ class ParallelExecutorPrivate {
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
Scope *global_scope_; Scope *global_scope_;
std::unordered_map<int, platform::NCCLContext> communication_streams_; std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
platform::NCCLContext &GetNCCLCtx(platform::Place p) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
return communication_streams_.at(dev_id);
}
platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) { platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
if (platform::is_cpu_place(place) || local_scopes_.size() == 1) { if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
...@@ -206,7 +137,7 @@ class ParallelExecutorPrivate { ...@@ -206,7 +137,7 @@ class ParallelExecutorPrivate {
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
return GetNCCLCtx(place).ctx_.get(); return nccl_ctxs_->DevCtx(place);
#else #else
PADDLE_THROW("Not compiled with CUDA") PADDLE_THROW("Not compiled with CUDA")
#endif #endif
...@@ -293,15 +224,12 @@ class ParallelExecutorPrivate { ...@@ -293,15 +224,12 @@ class ParallelExecutorPrivate {
struct NCCLAllReduceOpHandle : public OpHandleBase { struct NCCLAllReduceOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_; const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_; const std::vector<platform::Place> &places_;
const std::unordered_map<int, platform::NCCLContext> &communication_ctxs_; const platform::NCCLContextMap &nccl_ctxs_;
explicit NCCLAllReduceOpHandle( explicit NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::unordered_map<int, platform::NCCLContext> &ctxs) const platform::NCCLContextMap &ctxs)
: local_scopes_(local_scopes), : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {}
places_(places),
communication_ctxs_(ctxs) {}
void Wait(platform::DeviceContext *waited_dev) override { void Wait(platform::DeviceContext *waited_dev) override {
OpHandleBase::Wait(waited_dev); OpHandleBase::Wait(waited_dev);
...@@ -343,7 +271,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -343,7 +271,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
if (numel == 0) { if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel()); numel = static_cast<size_t>(lod_tensor.numel());
} }
auto &nccl_ctx = communication_ctxs_.at(dev_id); auto &nccl_ctx = nccl_ctxs_.at(dev_id);
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm_, nccl_ctx.stream())); nccl_ctx.comm_, nccl_ctx.stream()));
...@@ -491,8 +419,7 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -491,8 +419,7 @@ void ParallelExecutor::ConstructDependencyGraph(
if (grads.count(og) != 0) { // is param grad if (grads.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op // Insert NCCL AllReduce Op
member_->ops_.emplace_back(new NCCLAllReduceOpHandle( member_->ops_.emplace_back(new NCCLAllReduceOpHandle(
member_->local_scopes_, member_->places_, member_->local_scopes_, member_->places_, *member_->nccl_ctxs_));
member_->communication_streams_));
auto *op_handle = member_->ops_.back().get(); auto *op_handle = member_->ops_.back().get();
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
...@@ -598,15 +525,12 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -598,15 +525,12 @@ void ParallelExecutor::BCastParamsToGPUs(
buffer = t->mutable_data(place, main_tensor.type()); buffer = t->mutable_data(place, main_tensor.type());
} }
auto &nccl_ctx = member_->GetNCCLCtx(place); auto &nccl_ctx = member_->nccl_ctxs_->at(place);
platform::dynload::ncclBcast(buffer, numel, data_type, 0, platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream()); nccl_ctx.comm_, nccl_ctx.stream());
} }
} }
member_->nccl_ctxs_->WaitAll();
for (auto &stream : member_->communication_streams_) {
stream.second.ctx_->Wait();
}
} }
#else #else
PADDLE_THROW("Not compiled with CUDA"); PADDLE_THROW("Not compiled with CUDA");
...@@ -615,15 +539,7 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -615,15 +539,7 @@ void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::BuildNCCLCommunicator() const { void ParallelExecutor::BuildNCCLCommunicator() const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
for (auto &place : member_->places_) { member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
int dev_id = boost::get<platform::CUDAPlace>(place).device;
member_->communication_streams_.emplace(dev_id,
platform::NCCLContext(dev_id));
}
platform::NCCLContext::InitNCCLContext(member_->communication_streams_,
member_->places_);
#endif #endif
} }
...@@ -682,7 +598,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -682,7 +598,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
op->offset_ = i; op->offset_ = i;
op->local_scopes_ = &member_->local_scopes_; op->local_scopes_ = &member_->local_scopes_;
for (auto &p : member_->places_) { for (auto &p : member_->places_) {
op->dev_ctx_[p] = member_->GetNCCLCtx(p).ctx_.get(); op->dev_ctx_[p] = member_->nccl_ctxs_->DevCtx(p);
} }
for (auto *var : vars) { for (auto *var : vars) {
......
...@@ -87,5 +87,51 @@ struct NCCLContext { ...@@ -87,5 +87,51 @@ struct NCCLContext {
} }
}; };
struct NCCLContextMap {
std::unordered_map<int, NCCLContext> contexts_;
std::vector<int> order_;
NCCLContextMap(const std::vector<platform::Place> &places) {
order_.reserve(places.size());
for (auto &p : places) {
int dev_id = boost::get<CUDAPlace>(p).device;
order_.emplace_back(dev_id);
contexts_.emplace(dev_id, NCCLContext(dev_id));
}
PADDLE_ENFORCE_EQ(
order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device");
std::vector<ncclComm_t> comms;
comms.resize(order_.size());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(order_.size()), &order_[0]));
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
}
}
CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
CUDADeviceContext *DevCtx(platform::Place p) const {
return DevCtx(boost::get<CUDAPlace>(p).device);
}
const NCCLContext &at(platform::Place p) const {
return this->at(boost::get<CUDAPlace>(p).device);
}
const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
void WaitAll() {
for (auto &p : contexts_) {
p.second.ctx_->Wait();
}
}
};
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册