提交 6ebc6bf5 编写于 作者: Y Yu Yang

ReorganizeCode

上级 a478a11e
add_subdirectory(details)
# ddim lib
proto_library(framework_proto SRCS framework.proto)
......@@ -87,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool concat)
framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool var_handle)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
cc_library(var_handle SRCS var_handle.cc DEPS place)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/var_handle.h"
namespace paddle {
namespace framework {
namespace details {
VarHandleBase::~VarHandleBase() {}
std::string VarHandle::DebugString() const {
std::stringstream ss;
ss << name_ << ":" << place_;
return ss.str();
}
std::string DummyVarHandle::DebugString() const { return "dummy"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sstream>
#include <string>
#include <unordered_set>
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
struct OpHandleBase;
namespace details {
// VarHandleBase is the var node in the dependency graph.
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
struct VarHandleBase {
virtual ~VarHandleBase();
virtual std::string DebugString() const = 0;
// The operator who generate this variable. nullptr if the variable
// is a root node.
OpHandleBase *generated_op_;
// Operators which depend on this variable ready.
std::unordered_set<OpHandleBase *> pending_ops_;
};
// VarHandle is actually a single version of Runtime Variable.
// Variable in Runtime mapped to many VarHandles in Graph.
// Each assignment will generate a new var handle with newer version.
//
// NOTE: runtime variables have place.
struct VarHandle : public VarHandleBase {
std::string DebugString() const override;
// version field currently is not used, however, just store the version to
// debug easily.
size_t version_;
std::string name_;
platform::Place place_;
};
// Dummy Variable. It is used to represent dependencies between operators
struct DummyVarHandle : public VarHandleBase {
std::string DebugString() const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "lod_tensor.h"
#include "lod_tensor_array.h"
#include "op_registry.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/nccl_helper.h"
......@@ -25,35 +26,11 @@ limitations under the License. */
namespace paddle {
namespace framework {
struct OpHandle;
using details::DummyVarHandle;
using details::VarHandle;
using details::VarHandleBase;
struct VarHandleBase {
virtual ~VarHandleBase() {}
virtual std::string DebugString() const = 0;
OpHandle *generated_op_;
std::unordered_set<OpHandle *> pending_ops_;
};
struct VarHandle : public VarHandleBase {
std::string DebugString() const override {
std::stringstream ss;
ss << name_ << ":" << place_;
return ss.str();
}
// version field currently is not used, however, just store the version to
// debug easily.
size_t version_;
std::string name_;
platform::Place place_;
};
struct DummyVarHandle : public VarHandleBase {
std::string DebugString() const override { return "dummy"; }
};
struct OpHandle {
struct OpHandleBase {
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
......@@ -76,7 +53,7 @@ struct OpHandle {
return ss.str();
}
virtual ~OpHandle() {}
virtual ~OpHandleBase() {}
void Run(bool use_event) {
if (events_.empty() && use_event) {
......@@ -117,7 +94,7 @@ struct OpHandle {
virtual void RunImpl() = 0;
};
struct ScaleLossGradOpHandle : public OpHandle {
struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_;
Scope *scope_;
platform::Place place_;
......@@ -150,7 +127,7 @@ struct ScaleLossGradOpHandle : public OpHandle {
}
};
struct FetchOpHandle : public OpHandle {
struct FetchOpHandle : public OpHandleBase {
FeedFetchList *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
......@@ -216,51 +193,13 @@ class ParallelExecutorPrivate {
std::vector<Scope *> local_scopes_;
Scope *global_scope_;
#ifdef PADDLE_WITH_CUDA
struct NCCLContext {
std::unique_ptr<platform::CUDADeviceContext> ctx_;
ncclComm_t comm;
explicit NCCLContext(int dev_id) {
ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
}
cudaStream_t stream() const { return ctx_->stream(); }
int device_id() const {
return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
}
static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
const std::vector<platform::Place> &places) {
std::vector<ncclComm_t> comms;
std::vector<int> devs;
comms.resize(contexts.size());
devs.reserve(contexts.size());
for (auto &p : places) {
devs.push_back(boost::get<platform::CUDAPlace>(p).device);
}
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(contexts.size()), &devs[0]));
int i = 0;
for (auto &dev_id : devs) {
contexts.at(dev_id).comm = comms[i++];
}
}
};
std::unordered_map<int, NCCLContext> communication_streams_;
std::unordered_map<int, platform::NCCLContext> communication_streams_;
NCCLContext &GetNCCLCtx(platform::Place p) {
platform::NCCLContext &GetNCCLCtx(platform::Place p) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
return communication_streams_.at(dev_id);
}
#endif
platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
return const_cast<platform::DeviceContext *>(
......@@ -282,27 +221,95 @@ class ParallelExecutorPrivate {
vars_;
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandle>> ops_;
std::vector<std::unique_ptr<OpHandleBase>> ops_;
// Use a simpler thread pool, might be faster.
std::unique_ptr<ThreadPool> pool_;
std::unique_ptr<platform::EnforceNotMet> exception_;
};
struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_;
VarHandle *GetVarHandle(const std::string &each_var_name,
const platform::Place &place) {
auto &var_holders = vars_[place];
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
} else {
var = &var_holder.rbegin()->second;
}
return var;
}
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) {}
void RunOp(
bool use_event,
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
OpHandleBase *op) {
std::vector<std::atomic<bool> *> *ready_buffer =
new std::vector<std::atomic<bool> *>();
for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]);
}
auto op_run = [ready_buffer, op, this, use_event] {
try {
VLOG(10) << op->DebugString();
op->Run(use_event);
for (auto *ready : *ready_buffer) {
ready->store(true, std::memory_order_release);
}
delete ready_buffer;
} catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
LOG(FATAL) << "Unknown exception catched";
}
};
if (pool_) {
pool_->enqueue(op_run);
} else {
op_run();
}
}
void GenerateVar(OpHandleBase *op_handle, const std::string &each_var_name,
const platform::Place &place) {
auto &vars = vars_[place][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.generated_op_ = op_handle;
var.name_ = each_var_name;
var.place_ = place;
op_handle->outputs_.emplace_back(&var);
}
}; // namespace framework
struct NCCLAllReduceOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
const std::unordered_map<int, platform::NCCLContext> &communication_ctxs_;
explicit NCCLAllReduceOpHandle(
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const std::unordered_map<int, platform::NCCLContext> &ctxs)
: local_scopes_(local_scopes),
places_(places),
communication_ctxs_(ctxs) {}
void Wait(platform::DeviceContext *waited_dev) override {
OpHandle::Wait(waited_dev);
OpHandleBase::Wait(waited_dev);
}
protected:
void RunImpl() override {
if (this->inputs_.size() == 1) {
if (inputs_.size() == 1) {
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
......@@ -317,9 +324,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
auto &p = member_->places_[i];
auto *s = member_->local_scopes_[i];
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
......@@ -336,16 +343,16 @@ struct NCCLAllReduceOpHandle : public OpHandle {
if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel());
}
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
auto &nccl_ctx = communication_ctxs_.at(dev_id);
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()));
nccl_ctx.comm_, nccl_ctx.stream()));
}
}
}
};
struct ComputationOpHandle : public OpHandle {
struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
......@@ -443,14 +450,14 @@ void ParallelExecutor::ConstructDependencyGraph(
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var = GetVarHandle(each_var_name, p);
VarHandle *var = member_->GetVarHandle(each_var_name, p);
op_handle->inputs_.emplace_back(var);
var->pending_ops_.emplace(op_handle);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
GenerateVar(op_handle, each_var_name, p);
member_->GenerateVar(op_handle, each_var_name, p);
}
if (is_forwarding) {
......@@ -468,7 +475,7 @@ void ParallelExecutor::ConstructDependencyGraph(
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
GenerateVar(op_handle, loss_var_name + "@GRAD", p);
member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p);
change_forward = true;
}
}
......@@ -483,7 +490,9 @@ void ParallelExecutor::ConstructDependencyGraph(
for (auto &og : var_names) {
if (grads.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
member_->ops_.emplace_back(new NCCLAllReduceOpHandle(
member_->local_scopes_, member_->places_,
member_->communication_streams_));
auto *op_handle = member_->ops_.back().get();
for (size_t i = 0; i < member_->places_.size(); ++i) {
......@@ -562,37 +571,6 @@ void ParallelExecutor::PolishGraphToSupportDataHazards() const {
}
}
void ParallelExecutor::GenerateVar(OpHandle *op_handle,
const std::string &each_var_name,
const platform::Place &place) const {
auto &vars = member_->vars_[place][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.generated_op_ = op_handle;
var.name_ = each_var_name;
var.place_ = place;
op_handle->outputs_.emplace_back(&var);
}
VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
const platform::Place &place) const {
auto &var_holders = member_->vars_[place];
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
} else {
var = &var_holder.rbegin()->second;
}
return var;
}
void ParallelExecutor::BCastParamsToGPUs(
const ProgramDesc &startup_program) const {
#ifdef PADDLE_WITH_CUDA
......@@ -621,8 +599,8 @@ void ParallelExecutor::BCastParamsToGPUs(
}
auto &nccl_ctx = member_->GetNCCLCtx(place);
platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm,
nccl_ctx.stream());
platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
}
}
......@@ -640,12 +618,12 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
for (auto &place : member_->places_) {
int dev_id = boost::get<platform::CUDAPlace>(place).device;
member_->communication_streams_.emplace(
dev_id, ParallelExecutorPrivate::NCCLContext(dev_id));
member_->communication_streams_.emplace(dev_id,
platform::NCCLContext(dev_id));
}
ParallelExecutorPrivate::NCCLContext::InitNCCLContext(
member_->communication_streams_, member_->places_);
platform::NCCLContext::InitNCCLContext(member_->communication_streams_,
member_->places_);
#endif
}
......@@ -656,7 +634,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
// Version --> VarHandle
member_->exception_.reset();
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops;
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::vector<DummyVarHandle> dummy_vars;
for (auto &place_pair : member_->vars_) {
......@@ -672,7 +650,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
pending_vars[var.get()] = var->generated_op_ == nullptr;
}
std::vector<OpHandle *> to_run;
std::vector<OpHandleBase *> to_run;
for (auto &op : member_->ops_) {
if (op->inputs_.empty()) { // Special case, Op has no input.
......@@ -722,7 +700,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
for (auto *op : to_run) {
RunOp(use_event, pending_vars, op);
member_->RunOp(use_event, pending_vars, op);
}
while (!pending_vars.empty()) {
......@@ -750,7 +728,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
for (auto *op : to_run) {
pending_ops.erase(op);
RunOp(use_event, pending_vars, op);
member_->RunOp(use_event, pending_vars, op);
}
}
......@@ -762,35 +740,5 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
fetched_data;
}
void ParallelExecutor::RunOp(
bool use_event,
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
OpHandle *op) const {
std::vector<std::atomic<bool> *> *ready_buffer =
new std::vector<std::atomic<bool> *>();
for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]);
}
auto op_run = [ready_buffer, op, this, use_event] {
try {
VLOG(10) << op->DebugString();
op->Run(use_event);
for (auto *ready : *ready_buffer) {
ready->store(true, std::memory_order_release);
}
delete ready_buffer;
} catch (platform::EnforceNotMet ex) {
member_->exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
LOG(FATAL) << "Unknown exception catched";
}
};
if (member_->pool_) {
member_->pool_->enqueue(op_run);
} else {
op_run();
}
}
} // namespace framework
} // namespace paddle
......@@ -29,9 +29,6 @@ namespace paddle {
namespace framework {
class ParallelExecutorPrivate;
class VarHandle;
class OpHandle;
class VarHandleBase;
class ParallelExecutor {
public:
......@@ -50,23 +47,12 @@ class ParallelExecutor {
void BCastParamsToGPUs(const ProgramDesc& startup_program) const;
VarHandle* GetVarHandle(const std::string& each_var_name,
const platform::Place& place) const;
void GenerateVar(OpHandle* op_handle, const std::string& each_var_name,
const platform::Place& place) const;
void ConstructDependencyGraph(const std::unordered_set<std::string>& params,
const ProgramDesc& main_program,
const std::string& loss_var_name) const;
void BuildNCCLCommunicator() const;
void RunOp(
bool use_event,
std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
OpHandle* op) const;
void PolishGraphToSupportDataHazards() const;
};
......
......@@ -47,11 +47,45 @@ class NCCLGroupGuard {
}
private:
static std::mutex& mutex() {
static std::mutex &mutex() {
static std::mutex mtx;
return mtx;
}
};
struct NCCLContext {
std::unique_ptr<CUDADeviceContext> ctx_;
ncclComm_t comm_;
explicit NCCLContext(int dev_id)
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {}
cudaStream_t stream() const { return ctx_->stream(); }
int device_id() const {
return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
}
static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
const std::vector<platform::Place> &places) {
std::vector<ncclComm_t> comms;
std::vector<int> devs;
comms.resize(contexts.size());
devs.reserve(contexts.size());
for (auto &p : places) {
devs.push_back(boost::get<platform::CUDAPlace>(p).device);
}
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(contexts.size()), &devs[0]));
int i = 0;
for (auto &dev_id : devs) {
contexts.at(dev_id).comm_ = comms[i++];
}
}
};
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册