Commit 254d7ff4 authored by Yu Yang

Refactor local_scopes

Parent b2c7a9b8
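The refactor replaces the place-keyed map of per-device scopes with a plain vector that is kept index-aligned with places_, so local_scopes_[i] always belongs to places_[i]. A minimal, compilable sketch of the before/after shape; Place, Scope, and PlaceHash here are simplified stand-ins, not the Paddle types:

#include <cstddef>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for platform::Place and framework::Scope.
struct Place { int device_id; };
struct Scope {};

struct PlaceHash {
  std::size_t operator()(const Place &p) const {
    return static_cast<std::size_t>(p.device_id);
  }
};
inline bool operator==(const Place &a, const Place &b) {
  return a.device_id == b.device_id;
}

// Before: one scope per device, looked up through a place-keyed hash map.
struct BeforePrivate {
  std::unordered_map<Place, Scope *, PlaceHash> local_scopes_;
  std::vector<Place> places_;
};

// After: the scopes live in a vector kept parallel to places_,
// so local_scopes_[i] is always the scope of places_[i].
struct AfterPrivate {
  std::vector<Place> places_;
  std::vector<Scope *> local_scopes_;
};

int main() {
  AfterPrivate m;
  Scope s0, s1;
  m.places_ = {Place{0}, Place{1}};
  m.local_scopes_ = {&s0, &s1};  // index i pairs with places_[i]
}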
@@ -151,11 +151,10 @@ class ParallelExecutorPrivate {
   explicit ParallelExecutorPrivate(size_t num_threads = 12)
       : pool_(num_threads) {}
 
-  std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
-      local_scopes_;
-
   std::vector<platform::Place> places_;
+
+  std::vector<Scope *> local_scopes_;
 
 #ifdef PADDLE_WITH_CUDA
   struct NCCLContext {
     std::unique_ptr<platform::CUDADeviceContext> ctx_;
@@ -260,10 +259,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
 
       platform::dynload::ncclGroupStart();
 
-      for (auto &p : member_->places_) {
+      for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
+        auto &p = member_->places_[i];
+        auto *s = member_->local_scopes_[i];
         int dev_id = boost::get<platform::CUDAPlace>(p).device;
-        Scope *s = member_->local_scopes_[p];
 
         auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
         void *buffer = const_cast<void *>(lod_tensor.data<void>());
         if (dtype == -1) {
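With the map gone, per-device loops such as the NCCL all-reduce above walk the two vectors in lockstep by index instead of looking the scope up by place. A generic sketch of that pattern, using illustrative names rather than the Paddle API:

#include <cassert>
#include <cstdio>
#include <vector>

struct Place { int device_id; };
struct Scope { const char *name; };

// Walk places and scopes in lockstep; the refactor relies on the two
// vectors having the same length and the same ordering.
void ForEachDevice(const std::vector<Place> &places,
                   const std::vector<Scope *> &scopes) {
  assert(places.size() == scopes.size());
  for (size_t i = 0; i < scopes.size(); ++i) {
    const Place &p = places[i];  // replaces iterating the map's keys
    Scope *s = scopes[i];        // replaces local_scopes_[p]
    std::printf("device %d -> scope %s\n", p.device_id, s->name);
  }
}

int main() {
  Scope s0{"dev0"}, s1{"dev1"};
  std::vector<Place> places{{0}, {1}};
  std::vector<Scope *> scopes{&s0, &s1};
  ForEachDevice(places, scopes);
}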
@@ -302,8 +302,8 @@ ParallelExecutor::ParallelExecutor(
   Executor exe(places[0]);
   exe.Run(startup_program, scope, 0);
   // Create local scopes
-  for (auto &place : places) {
-    member_->local_scopes_[place] = &scope->NewScope();
+  for (size_t i = 0; i < member_->places_.size(); ++i) {
+    member_->local_scopes_.push_back(&scope->NewScope());
   }
 
   member_->main_place_ = places[0];
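Because the scopes are now appended with push_back in the same order as the incoming places, index 0 doubles as the main scope that BCastParamsToGPUs later reads from. A small stand-alone sketch of that construction-order invariant; Scope and NewScope below are simplified stand-ins for Paddle's scope API:

#include <cassert>
#include <memory>
#include <vector>

struct Place { int device_id; };

// Stand-in scope: a parent owns the child scopes it creates.
class Scope {
 public:
  Scope &NewScope() {
    kids_.emplace_back(new Scope());
    return *kids_.back();
  }

 private:
  std::vector<std::unique_ptr<Scope>> kids_;
};

int main() {
  Scope global;
  std::vector<Place> places{{0}, {1}, {2}};
  std::vector<Scope *> local_scopes;

  // One child scope per place, appended in place order.
  for (size_t i = 0; i < places.size(); ++i) {
    local_scopes.push_back(&global.NewScope());
  }

  assert(local_scopes.size() == places.size());
  Scope *main_scope = local_scopes[0];  // the scope of places[0]
  (void)main_scope;
}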
@@ -320,9 +320,7 @@ ParallelExecutor::ParallelExecutor(
   ConstructDependencyGraph(params, main_program, loss_var_name);
 
   // Step 3. Create vars in each scope;
-  for (auto &pair : member_->local_scopes_) {
-    auto *scope = pair.second;
-
+  for (auto *scope : member_->local_scopes_) {
     for (auto *var : main_program.Block(0).AllVars()) {
       if (scope->FindVar(var->Name()) != nullptr) {
         continue;
@@ -353,46 +351,44 @@ void ParallelExecutor::ConstructDependencyGraph(
       }
     }
 
-    for (auto &pair : member_->local_scopes_) {
-      member_->ops_.emplace_back(
-          new ComputationOpHandle(*op, pair.second, pair.first));
+    for (size_t i = 0; i < member_->places_.size(); ++i) {
+      auto &p = member_->places_[i];
+      auto *s = member_->local_scopes_[i];
+
+      member_->ops_.emplace_back(new ComputationOpHandle(*op, s, p));
       auto *op_handle = member_->ops_.back().get();
-      op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
-          platform::DeviceContextPool::Instance().Get(pair.first));
+      op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
+          platform::DeviceContextPool::Instance().Get(p));
 
       auto var_names = op->InputArgumentNames();
 
       for (auto &each_var_name : var_names) {
-        auto &place = pair.first;
-        VarHandle *var = GetVarHandle(each_var_name, place);
+        VarHandle *var = GetVarHandle(each_var_name, p);
         op_handle->inputs_.emplace_back(var);
         var->pending_ops_.emplace_back(op_handle);
       }
       var_names = op->OutputArgumentNames();
 
       for (auto &each_var_name : var_names) {
-        auto &place = pair.first;
-        GenerateVar(op_handle, each_var_name, place);
+        GenerateVar(op_handle, each_var_name, p);
       }
 
       if (is_forwarding) {
         if (var_names.size() == 1 && var_names[0] == loss_var_name) {
           // Insert ScaleCost OpHandle
           member_->ops_.emplace_back(new ScaleLossGradOpHandle(
-              this->member_->local_scopes_.size(), pair.second, pair.first));
+              this->member_->local_scopes_.size(), s, p));
           op_handle = member_->ops_.back().get();
 
-          op_handle->dev_ctx_[pair.first] =
-              member_->CommunicationDevCtx(pair.first);
-
-          auto &place = pair.first;
+          op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
 
           // FIXME: Currently ScaleLossGradOp only use device_count as scale
           // factor. So it does not depend on any other operators.
           // VarHandle *loss = GetVarHandle(loss_var_name, place);
           // loss->pending_ops_.emplace_back(op_handle);
           // op_handle->inputs_.emplace_back(loss);
+
-          GenerateVar(op_handle, loss_var_name + "@GRAD", place);
+          GenerateVar(op_handle, loss_var_name + "@GRAD", p);
           change_forward = true;
           LOG(INFO) << "Scale Loss " << op_handle->DebugString();
         }
@@ -411,9 +407,9 @@ void ParallelExecutor::ConstructDependencyGraph(
         member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
         auto *op_handle = member_->ops_.back().get();
 
-        for (auto &pair : member_->local_scopes_) {
-          auto &place = pair.first;
-          auto &vars = member_->vars_[place][og];
+        for (size_t i = 0; i < member_->places_.size(); ++i) {
+          auto &p = member_->places_[i];
+          auto &vars = member_->vars_[p][og];
 
           if (vars.empty()) {  // This device has no data. continue.
             continue;
@@ -422,16 +418,13 @@ void ParallelExecutor::ConstructDependencyGraph(
           op_handle->inputs_.emplace_back(prev_grad);
           prev_grad->pending_ops_.emplace_back(op_handle);
 
           auto &var = vars[vars.size()];
-          var.place_ = place;
+          var.place_ = p;
           var.generated_op_ = op_handle;
           var.name_ = og;
           var.version_ = vars.size() - 1;
           op_handle->outputs_.emplace_back(&var);
-          for (auto &pair : member_->local_scopes_) {
-            op_handle->dev_ctx_[pair.first] =
-                member_->CommunicationDevCtx(pair.first);
-          }
+          op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
         }
       }
     }
@@ -529,7 +522,7 @@ VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
 void ParallelExecutor::BCastParamsToGPUs(
     const ProgramDesc &startup_program) const {
 #ifdef PADDLE_WITH_CUDA
-  auto *main_scope = member_->local_scopes_[member_->main_place_];
+  auto *main_scope = member_->local_scopes_[0];
 
   for (auto *var_desc : startup_program.Block(0).AllVars()) {
     if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
@@ -547,7 +540,7 @@ void ParallelExecutor::BCastParamsToGPUs(
         if (i == 0) {
           buffer = const_cast<void *>(main_tensor.data<void>());
         } else {
-          auto local_scope = member_->local_scopes_[place];
+          auto local_scope = member_->local_scopes_[i];
           auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
           t->Resize(dims);
           buffer = t->mutable_data(place, main_tensor.type());
@@ -560,18 +553,6 @@ void ParallelExecutor::BCastParamsToGPUs(
       platform::dynload::ncclGroupEnd();
     }
   }
-
-  // Debug code, bias should be 1.0f.
-  for (auto &pair : member_->local_scopes_) {
-    member_->GetNCCLCtx(pair.first).ctx_->Wait();
-
-    auto &b = pair.second->FindVar("fc_0.b_0")->Get<framework::LoDTensor>();
-    framework::LoDTensor cpu;
-    framework::TensorCopy(b, platform::CPUPlace(), &cpu);
-    platform::DeviceContextPool::Instance().Get(b.place())->Wait();
-    LOG(INFO) << *cpu.data<float>();
-  }
-
 #else
   PADDLE_THROW("Not compiled with CUDA");
 #endif
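After this change the parameter broadcast treats index 0 as the source: the initialized tensor is read from local_scopes_[0] and every other index receives a copy. A simplified, NCCL-free sketch of that source-at-index-0 pattern; VarStore and BroadcastFromFirst are invented for illustration (the real code issues NCCL broadcasts rather than host copies):

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Invented stand-in for a per-device scope holding named float buffers.
struct VarStore {
  std::unordered_map<std::string, std::vector<float>> vars;
};

// Copy `name` from the first store into every other store, mirroring the
// "read from local_scopes_[0], write into local_scopes_[i]" structure.
void BroadcastFromFirst(const std::vector<VarStore *> &stores,
                        const std::string &name) {
  const std::vector<float> &src = stores[0]->vars.at(name);
  for (size_t i = 1; i < stores.size(); ++i) {
    stores[i]->vars[name] = src;
  }
}

int main() {
  VarStore s0, s1, s2;
  s0.vars["fc_0.b_0"] = {1.0f, 1.0f};  // initialized only on device 0
  std::vector<VarStore *> stores{&s0, &s1, &s2};
  BroadcastFromFirst(stores, "fc_0.b_0");
  assert(s2.vars.at("fc_0.b_0").size() == 2);
}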
@@ -579,8 +560,7 @@ void ParallelExecutor::BCastParamsToGPUs(
 
 void ParallelExecutor::BuildNCCLCommunicator() const {
 #ifdef PADDLE_WITH_CUDA
-  for (auto &place_pair : member_->local_scopes_) {
-    auto place = place_pair.first;
+  for (auto &place : member_->places_) {
     int dev_id = boost::get<platform::CUDAPlace>(place).device;
     member_->communication_streams_.emplace(
...