Commit 6ebc6bf5 authored by Yu Yang

ReorganizeCode

Parent a478a11e
paddle/fluid/framework/CMakeLists.txt
+add_subdirectory(details)
 # ddim lib
 proto_library(framework_proto SRCS framework.proto)
@@ -87,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
 cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
     framework_proto backward glog lod_rank_table feed_fetch_method)
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
-    framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool concat)
+    framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool var_handle)
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
...

paddle/fluid/framework/details/CMakeLists.txt (new file)
+cc_library(var_handle SRCS var_handle.cc DEPS place)
paddle/fluid/framework/details/var_handle.cc (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/var_handle.h"
namespace paddle {
namespace framework {
namespace details {
VarHandleBase::~VarHandleBase() {}
std::string VarHandle::DebugString() const {
std::stringstream ss;
ss << name_ << ":" << place_;
return ss.str();
}
std::string DummyVarHandle::DebugString() const { return "dummy"; }
} // namespace details
} // namespace framework
} // namespace paddle
paddle/fluid/framework/details/var_handle.h (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sstream>
#include <string>
#include <unordered_set>

#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {
struct OpHandleBase;

namespace details {

// VarHandleBase is the var node in the dependency graph.
// A variable can only be generated by a single operator, i.e.,
// this is a single-assignment graph.
struct VarHandleBase {
  virtual ~VarHandleBase();
  virtual std::string DebugString() const = 0;

  // The operator that generates this variable. nullptr if the variable
  // is a root node.
  OpHandleBase *generated_op_;

  // Operators that wait for this variable to become ready.
  std::unordered_set<OpHandleBase *> pending_ops_;
};

// VarHandle is a single version of a runtime variable: one runtime variable
// maps to many VarHandles in the graph, and each assignment generates a new
// VarHandle with a newer version.
//
// NOTE: runtime variables have a place.
struct VarHandle : public VarHandleBase {
  std::string DebugString() const override;

  // The version field is currently unused; we store it only to ease
  // debugging.
  size_t version_;
  std::string name_;
  platform::Place place_;
};

// Dummy variable. It is used to represent dependencies between operators.
struct DummyVarHandle : public VarHandleBase {
  std::string DebugString() const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
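
To make the single-assignment scheme above concrete, here is a minimal standalone sketch (illustrative only, not part of this commit): each write to a runtime variable appends a handle whose version equals the number of prior writes, and readers attach to the newest handle — the same pattern GenerateVar and GetVarHandle implement later in this diff.

// Standalone sketch; FakeVarHandle is a stand-in for details::VarHandle.
#include <iostream>
#include <map>
#include <string>

struct FakeVarHandle {
  size_t version_;
  std::string name_;
};

int main() {
  // One runtime variable -> a chain of versioned handles, keyed by version.
  std::map<size_t, FakeVarHandle> chain;

  auto write = [&chain](const std::string &name) -> FakeVarHandle & {
    size_t version = chain.size();  // mirrors GenerateVar
    auto &var = chain[version];
    var.version_ = version;
    var.name_ = name;
    return var;
  };

  auto read_latest = [&chain]() -> FakeVarHandle & {
    return chain.rbegin()->second;  // mirrors GetVarHandle
  };

  write("w@GRAD");  // version 0
  write("w@GRAD");  // version 1, created by a second assignment
  std::cout << read_latest().name_ << " v" << read_latest().version_ << "\n";
  // prints: w@GRAD v1
  return 0;
}
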
paddle/fluid/framework/parallel_executor.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "lod_tensor.h"
 #include "lod_tensor_array.h"
 #include "op_registry.h"
+#include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/nccl_helper.h"
@@ -25,35 +26,11 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-struct OpHandle;
+using details::DummyVarHandle;
+using details::VarHandle;
+using details::VarHandleBase;

-struct VarHandleBase {
-  virtual ~VarHandleBase() {}
-  virtual std::string DebugString() const = 0;
-  OpHandle *generated_op_;
-  std::unordered_set<OpHandle *> pending_ops_;
-};
-
-struct VarHandle : public VarHandleBase {
-  std::string DebugString() const override {
-    std::stringstream ss;
-    ss << name_ << ":" << place_;
-    return ss.str();
-  }
-
-  // version field currently is not used, however, just store the version to
-  // debug easily.
-  size_t version_;
-  std::string name_;
-  platform::Place place_;
-};
-
-struct DummyVarHandle : public VarHandleBase {
-  std::string DebugString() const override { return "dummy"; }
-};
-
-struct OpHandle {
+struct OpHandleBase {
   std::vector<VarHandleBase *> inputs_;
   std::vector<VarHandleBase *> outputs_;
   std::unordered_map<platform::Place, platform::DeviceContext *,
@@ -76,7 +53,7 @@ struct OpHandle {
     return ss.str();
   }

-  virtual ~OpHandle() {}
+  virtual ~OpHandleBase() {}

   void Run(bool use_event) {
     if (events_.empty() && use_event) {
@@ -117,7 +94,7 @@ struct OpHandle {
   virtual void RunImpl() = 0;
 };

-struct ScaleLossGradOpHandle : public OpHandle {
+struct ScaleLossGradOpHandle : public OpHandleBase {
   float coeff_;
   Scope *scope_;
   platform::Place place_;
@@ -150,7 +127,7 @@ struct ScaleLossGradOpHandle : public OpHandle {
   }
 };

-struct FetchOpHandle : public OpHandle {
+struct FetchOpHandle : public OpHandleBase {
   FeedFetchList *data_;
   size_t offset_;
   std::vector<Scope *> *local_scopes_;
@@ -216,51 +193,13 @@ class ParallelExecutorPrivate {
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;

-#ifdef PADDLE_WITH_CUDA
-  struct NCCLContext {
-    std::unique_ptr<platform::CUDADeviceContext> ctx_;
-    ncclComm_t comm;
-
-    explicit NCCLContext(int dev_id) {
-      ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
-    }
-
-    cudaStream_t stream() const { return ctx_->stream(); }
-
-    int device_id() const {
-      return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
-    }
-
-    static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
-                                const std::vector<platform::Place> &places) {
-      std::vector<ncclComm_t> comms;
-      std::vector<int> devs;
-      comms.resize(contexts.size());
-      devs.reserve(contexts.size());
-      for (auto &p : places) {
-        devs.push_back(boost::get<platform::CUDAPlace>(p).device);
-      }
-      PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
-          &comms[0], static_cast<int>(contexts.size()), &devs[0]));
-
-      int i = 0;
-      for (auto &dev_id : devs) {
-        contexts.at(dev_id).comm = comms[i++];
-      }
-    }
-  };
-
-  std::unordered_map<int, NCCLContext> communication_streams_;
-
-  NCCLContext &GetNCCLCtx(platform::Place p) {
+  std::unordered_map<int, platform::NCCLContext> communication_streams_;
+
+  platform::NCCLContext &GetNCCLCtx(platform::Place p) {
     int dev_id = boost::get<platform::CUDAPlace>(p).device;
     return communication_streams_.at(dev_id);
   }
-#endif

   platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
     if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
       return const_cast<platform::DeviceContext *>(
@@ -282,27 +221,95 @@ class ParallelExecutorPrivate {
       vars_;
   std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;

-  std::vector<std::unique_ptr<OpHandle>> ops_;
+  std::vector<std::unique_ptr<OpHandleBase>> ops_;

   // Use a simpler thread pool, might be faster.
   std::unique_ptr<ThreadPool> pool_;

   std::unique_ptr<platform::EnforceNotMet> exception_;
-};
-
-struct NCCLAllReduceOpHandle : public OpHandle {
-  ParallelExecutorPrivate *member_;
-
-  explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
-      : member_(member) {}
+  VarHandle *GetVarHandle(const std::string &each_var_name,
+                          const platform::Place &place) {
+    auto &var_holders = vars_[place];
+    auto &var_holder = var_holders[each_var_name];
+    VarHandle *var = nullptr;
+    if (var_holder.empty()) {
+      auto &init_var = var_holder[0];
+      init_var.place_ = place;
+      init_var.name_ = each_var_name;
+      init_var.generated_op_ = nullptr;
+      init_var.version_ = 0;
+      var = &init_var;
+    } else {
+      var = &var_holder.rbegin()->second;
+    }
+    return var;
+  }
+
+  void RunOp(
+      bool use_event,
+      std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
+      OpHandleBase *op) {
+    std::vector<std::atomic<bool> *> *ready_buffer =
+        new std::vector<std::atomic<bool> *>();
+    for (auto *var : op->outputs_) {
+      ready_buffer->emplace_back(&pending_vars[var]);
+    }
+
+    auto op_run = [ready_buffer, op, this, use_event] {
+      try {
+        VLOG(10) << op->DebugString();
+        op->Run(use_event);
+        for (auto *ready : *ready_buffer) {
+          ready->store(true, std::memory_order_release);
+        }
+        delete ready_buffer;
+      } catch (platform::EnforceNotMet ex) {
+        exception_.reset(new platform::EnforceNotMet(ex));
+      } catch (...) {
+        LOG(FATAL) << "Unknown exception catched";
+      }
+    };
+    if (pool_) {
+      pool_->enqueue(op_run);
+    } else {
+      op_run();
+    }
+  }
+
+  void GenerateVar(OpHandleBase *op_handle, const std::string &each_var_name,
+                   const platform::Place &place) {
+    auto &vars = vars_[place][each_var_name];
+    size_t version = vars.size();
+    auto &var = vars[version];
+    var.version_ = version;
+    var.generated_op_ = op_handle;
+    var.name_ = each_var_name;
+    var.place_ = place;
+
+    op_handle->outputs_.emplace_back(&var);
+  }
+};  // namespace framework
+struct NCCLAllReduceOpHandle : public OpHandleBase {
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
+  const std::unordered_map<int, platform::NCCLContext> &communication_ctxs_;
+
+  explicit NCCLAllReduceOpHandle(
+      const std::vector<Scope *> &local_scopes,
+      const std::vector<platform::Place> &places,
+      const std::unordered_map<int, platform::NCCLContext> &ctxs)
+      : local_scopes_(local_scopes),
+        places_(places),
+        communication_ctxs_(ctxs) {}

   void Wait(platform::DeviceContext *waited_dev) override {
-    OpHandle::Wait(waited_dev);
+    OpHandleBase::Wait(waited_dev);
   }
  protected:
   void RunImpl() override {
-    if (this->inputs_.size() == 1) {
+    if (inputs_.size() == 1) {
       return;  // No need to all reduce when GPU count = 1;
     } else {
       // Wait input done
@@ -317,9 +324,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {
       platform::NCCLGroupGuard guard;

-      for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
-        auto &p = member_->places_[i];
-        auto *s = member_->local_scopes_[i];
+      for (size_t i = 0; i < local_scopes_.size(); ++i) {
+        auto &p = places_[i];
+        auto *s = local_scopes_[i];
         int dev_id = boost::get<platform::CUDAPlace>(p).device;
         auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
@@ -336,16 +343,16 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         if (numel == 0) {
           numel = static_cast<size_t>(lod_tensor.numel());
         }
-        auto &nccl_ctx = member_->communication_streams_.at(dev_id);
+        auto &nccl_ctx = communication_ctxs_.at(dev_id);
         PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-            nccl_ctx.comm, nccl_ctx.stream()));
+            nccl_ctx.comm_, nccl_ctx.stream()));
       }
     }
   }
 };
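
For context, a minimal sketch of the grouped all-reduce pattern RunImpl relies on. The helper name, pre-built comms/streams/buffers, and the device-id-equals-rank assumption are illustrative, not Paddle APIs; wrapping the per-device ncclAllReduce calls in a group (what platform::NCCLGroupGuard does in RAII form) lets one thread drive all ranks without deadlock.

#include <vector>
#include <cuda_runtime.h>
#include <nccl.h>

void AllReduceGrads(const std::vector<ncclComm_t> &comms,
                    const std::vector<cudaStream_t> &streams,
                    const std::vector<float *> &buffers, size_t numel) {
  ncclGroupStart();  // NCCLGroupGuard's constructor does this
  for (size_t i = 0; i < comms.size(); ++i) {
    cudaSetDevice(static_cast<int>(i));  // assumes device id == rank index
    // Send buffer == recv buffer: the gradient is reduced in place.
    ncclAllReduce(buffers[i], buffers[i], numel, ncclFloat, ncclSum,
                  comms[i], streams[i]);
  }
  ncclGroupEnd();  // the grouped collectives are kicked off here
}
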
-struct ComputationOpHandle : public OpHandle {
+struct ComputationOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
   platform::Place place_;
@@ -443,14 +450,14 @@ void ParallelExecutor::ConstructDependencyGraph(
       auto var_names = op->InputArgumentNames();

       for (auto &each_var_name : var_names) {
-        VarHandle *var = GetVarHandle(each_var_name, p);
+        VarHandle *var = member_->GetVarHandle(each_var_name, p);
         op_handle->inputs_.emplace_back(var);
         var->pending_ops_.emplace(op_handle);
       }
       var_names = op->OutputArgumentNames();

       for (auto &each_var_name : var_names) {
-        GenerateVar(op_handle, each_var_name, p);
+        member_->GenerateVar(op_handle, each_var_name, p);
       }

       if (is_forwarding) {
@@ -468,7 +475,7 @@ void ParallelExecutor::ConstructDependencyGraph(
           //   loss->pending_ops_.emplace_back(op_handle);
           //   op_handle->inputs_.emplace_back(loss);

-          GenerateVar(op_handle, loss_var_name + "@GRAD", p);
+          member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p);
           change_forward = true;
         }
       }
@@ -483,7 +490,9 @@ void ParallelExecutor::ConstructDependencyGraph(
       for (auto &og : var_names) {
         if (grads.count(og) != 0) {  // is param grad
           // Insert NCCL AllReduce Op
-          member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
+          member_->ops_.emplace_back(new NCCLAllReduceOpHandle(
+              member_->local_scopes_, member_->places_,
+              member_->communication_streams_));
           auto *op_handle = member_->ops_.back().get();

           for (size_t i = 0; i < member_->places_.size(); ++i) {
@@ -562,37 +571,6 @@ void ParallelExecutor::PolishGraphToSupportDataHazards() const {
   }
 }
-void ParallelExecutor::GenerateVar(OpHandle *op_handle,
-                                   const std::string &each_var_name,
-                                   const platform::Place &place) const {
-  auto &vars = member_->vars_[place][each_var_name];
-  size_t version = vars.size();
-  auto &var = vars[version];
-  var.version_ = version;
-  var.generated_op_ = op_handle;
-  var.name_ = each_var_name;
-  var.place_ = place;
-
-  op_handle->outputs_.emplace_back(&var);
-}
-
-VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
-                                          const platform::Place &place) const {
-  auto &var_holders = member_->vars_[place];
-  auto &var_holder = var_holders[each_var_name];
-  VarHandle *var = nullptr;
-  if (var_holder.empty()) {
-    auto &init_var = var_holder[0];
-    init_var.place_ = place;
-    init_var.name_ = each_var_name;
-    init_var.generated_op_ = nullptr;
-    init_var.version_ = 0;
-    var = &init_var;
-  } else {
-    var = &var_holder.rbegin()->second;
-  }
-  return var;
-}
 void ParallelExecutor::BCastParamsToGPUs(
     const ProgramDesc &startup_program) const {
 #ifdef PADDLE_WITH_CUDA
@@ -621,8 +599,8 @@ void ParallelExecutor::BCastParamsToGPUs(
       }

       auto &nccl_ctx = member_->GetNCCLCtx(place);
-      platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm,
-                                   nccl_ctx.stream());
+      platform::dynload::ncclBcast(buffer, numel, data_type, 0,
+                                   nccl_ctx.comm_, nccl_ctx.stream());
     }
   }
@@ -640,12 +618,12 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
   for (auto &place : member_->places_) {
     int dev_id = boost::get<platform::CUDAPlace>(place).device;

-    member_->communication_streams_.emplace(
-        dev_id, ParallelExecutorPrivate::NCCLContext(dev_id));
+    member_->communication_streams_.emplace(dev_id,
+                                            platform::NCCLContext(dev_id));
   }

-  ParallelExecutorPrivate::NCCLContext::InitNCCLContext(
-      member_->communication_streams_, member_->places_);
+  platform::NCCLContext::InitNCCLContext(member_->communication_streams_,
+                                         member_->places_);
 #endif
 }
@@ -656,7 +634,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   // Version --> VarHandle
   member_->exception_.reset();
   std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
-  std::unordered_map<OpHandle *, size_t> pending_ops;
+  std::unordered_map<OpHandleBase *, size_t> pending_ops;
   std::vector<DummyVarHandle> dummy_vars;

   for (auto &place_pair : member_->vars_) {
@@ -672,7 +650,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     pending_vars[var.get()] = var->generated_op_ == nullptr;
   }

-  std::vector<OpHandle *> to_run;
+  std::vector<OpHandleBase *> to_run;

   for (auto &op : member_->ops_) {
     if (op->inputs_.empty()) {  // Special case, Op has no input.
@@ -722,7 +700,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   }

   for (auto *op : to_run) {
-    RunOp(use_event, pending_vars, op);
+    member_->RunOp(use_event, pending_vars, op);
   }

   while (!pending_vars.empty()) {
@@ -750,7 +728,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     }
     for (auto *op : to_run) {
       pending_ops.erase(op);
-      RunOp(use_event, pending_vars, op);
+      member_->RunOp(use_event, pending_vars, op);
     }
   }
@@ -762,35 +740,5 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
       fetched_data;
 }
-void ParallelExecutor::RunOp(
-    bool use_event,
-    std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
-    OpHandle *op) const {
-  std::vector<std::atomic<bool> *> *ready_buffer =
-      new std::vector<std::atomic<bool> *>();
-  for (auto *var : op->outputs_) {
-    ready_buffer->emplace_back(&pending_vars[var]);
-  }
-
-  auto op_run = [ready_buffer, op, this, use_event] {
-    try {
-      VLOG(10) << op->DebugString();
-      op->Run(use_event);
-      for (auto *ready : *ready_buffer) {
-        ready->store(true, std::memory_order_release);
-      }
-      delete ready_buffer;
-    } catch (platform::EnforceNotMet ex) {
-      member_->exception_.reset(new platform::EnforceNotMet(ex));
-    } catch (...) {
-      LOG(FATAL) << "Unknown exception catched";
-    }
-  };
-  if (member_->pool_) {
-    member_->pool_->enqueue(op_run);
-  } else {
-    op_run();
-  }
-}
 }  // namespace framework
 }  // namespace paddle
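
The scheduler that Run() and ParallelExecutorPrivate::RunOp implement is easiest to see in isolation. A minimal standalone sketch (illustrative names, not Paddle APIs): every variable gets an atomic ready flag, an op fires once all of its input flags are set, and finishing an op releases its output flags, which unblocks downstream ops.

#include <atomic>
#include <cstdio>
#include <map>
#include <vector>

struct FakeOp {
  const char *name;
  std::vector<int> inputs, outputs;  // variable ids
};

int main() {
  std::map<int, std::atomic<bool>> ready;  // analogue of pending_vars
  for (int v : {0, 1, 2}) ready[v] = false;
  ready[0] = true;  // var 0 is a root: generated_op_ == nullptr

  std::vector<FakeOp> ops = {{"compute", {0}, {1}}, {"all_reduce", {1}, {2}}};

  bool progress = true;
  while (progress) {
    progress = false;
    for (auto &op : ops) {
      bool runnable = true;
      for (int v : op.inputs)
        runnable &= ready[v].load(std::memory_order_acquire);
      if (!runnable || ready[op.outputs[0]]) continue;  // blocked or done
      std::printf("run %s\n", op.name);
      for (int v : op.outputs)  // analogue of RunOp's ready_buffer stores
        ready[v].store(true, std::memory_order_release);
      progress = true;
    }
  }
  return 0;  // prints: run compute, then run all_reduce
}
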
paddle/fluid/framework/parallel_executor.h
@@ -29,9 +29,6 @@ namespace paddle {
 namespace framework {

 class ParallelExecutorPrivate;
-class VarHandle;
-class OpHandle;
-class VarHandleBase;

 class ParallelExecutor {
  public:
@@ -50,23 +47,12 @@ class ParallelExecutor {
   void BCastParamsToGPUs(const ProgramDesc& startup_program) const;

-  VarHandle* GetVarHandle(const std::string& each_var_name,
-                          const platform::Place& place) const;
-
-  void GenerateVar(OpHandle* op_handle, const std::string& each_var_name,
-                   const platform::Place& place) const;
-
   void ConstructDependencyGraph(const std::unordered_set<std::string>& params,
                                 const ProgramDesc& main_program,
                                 const std::string& loss_var_name) const;

   void BuildNCCLCommunicator() const;

-  void RunOp(
-      bool use_event,
-      std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
-      OpHandle* op) const;
-
   void PolishGraphToSupportDataHazards() const;
 };
...
paddle/fluid/platform/nccl_helper.h
@@ -47,11 +47,45 @@ class NCCLGroupGuard {
   }

  private:
-  static std::mutex& mutex() {
+  static std::mutex &mutex() {
     static std::mutex mtx;
     return mtx;
   }
 };

+struct NCCLContext {
+  std::unique_ptr<CUDADeviceContext> ctx_;
+  ncclComm_t comm_;
+
+  explicit NCCLContext(int dev_id)
+      : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {}
+
+  cudaStream_t stream() const { return ctx_->stream(); }
+
+  int device_id() const {
+    return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
+  }
+
+  static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
+                              const std::vector<platform::Place> &places) {
+    std::vector<ncclComm_t> comms;
+    std::vector<int> devs;
+    comms.resize(contexts.size());
+    devs.reserve(contexts.size());
+    for (auto &p : places) {
+      devs.push_back(boost::get<platform::CUDAPlace>(p).device);
+    }
+    PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
+        &comms[0], static_cast<int>(contexts.size()), &devs[0]));
+
+    int i = 0;
+    for (auto &dev_id : devs) {
+      contexts.at(dev_id).comm_ = comms[i++];
+    }
+  }
+};
+
 }  // namespace platform
 }  // namespace paddle
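
A minimal usage sketch for the relocated platform::NCCLContext (assuming a CUDA build with two visible GPUs), mirroring what ParallelExecutor::BuildNCCLCommunicator does above: construct one context per device, then initialize all communicators with a single ncclCommInitAll via InitNCCLContext.

#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/nccl_helper.h"

void BuildComms() {
  std::vector<paddle::platform::Place> places = {
      paddle::platform::CUDAPlace(0), paddle::platform::CUDAPlace(1)};

  std::unordered_map<int, paddle::platform::NCCLContext> ctxs;
  for (auto &place : places) {
    int dev_id = boost::get<paddle::platform::CUDAPlace>(place).device;
    ctxs.emplace(dev_id, paddle::platform::NCCLContext(dev_id));
  }
  // Fills ctxs[dev].comm_ for every device in one shot.
  paddle::platform::NCCLContext::InitNCCLContext(ctxs, places);
  // ctxs.at(0).comm_ / ctxs.at(0).stream() are now ready for collectives.
}
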