Commit 6f0dfd89 authored by Yu Yang

Single GPU ParallelExecutor complete

Parent d84ddcf1
CMakeLists.txt
@@ -146,6 +146,7 @@ include(external/cares)
 include(external/grpc)
 include(external/snappy)      # download snappy
 include(external/snappystream)
+include(external/threadpool)
 include(cudnn)                # set cudnn libraries, must before configure
 include(cupti)
cmake/external/threadpool.cmake (new file)
INCLUDE(ExternalProject)

SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool)
SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR}/src/extern_threadpool)
INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR})

ExternalProject_Add(
    extern_threadpool
    ${EXTERNAL_PROJECT_LOG_ARGS}
    GIT_REPOSITORY    "https://github.com/progschj/ThreadPool.git"
    GIT_TAG           9a42ec1329f259a5f4881a291db1dcb8f2ad9040
    PREFIX            ${THREADPOOL_SOURCE_DIR}
    UPDATE_COMMAND    ""
    CONFIGURE_COMMAND ""
    BUILD_COMMAND     ""
    INSTALL_COMMAND   ""
    TEST_COMMAND      ""
)

if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
    set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/threadpool_dummy.c)
    file(WRITE ${dummyfile} "const char *dummy_threadpool = \"${dummyfile}\";")
    add_library(simple_threadpool STATIC ${dummyfile})
else()
    add_library(simple_threadpool INTERFACE)
endif()

add_dependencies(simple_threadpool extern_threadpool)

LIST(APPEND external_project_dependencies simple_threadpool)
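Note: the vendored library is progschj's header-only ThreadPool, pinned to a fixed commit. Since it ships no build step, every ExternalProject command is stubbed out, and the dummy static library exists only because add_dependencies() cannot be used on INTERFACE library targets before CMake 3.3. A minimal sketch of the library's enqueue API, which parallel_executor.cc relies on below (the pool size of 4 is arbitrary for illustration):

// Minimal sketch of the progschj/ThreadPool API (header-only "ThreadPool.h").
#include <future>
#include <iostream>

#include "ThreadPool.h"

int main() {
  ThreadPool pool(4);  // spawns 4 worker threads

  // enqueue() accepts any callable plus its arguments and returns a
  // std::future holding the callable's result.
  std::future<int> result = pool.enqueue([](int x) { return x * x; }, 7);

  std::cout << result.get() << std::endl;  // blocks until the task ran; prints 49
  return 0;
}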
paddle/fluid/framework/CMakeLists.txt
@@ -87,7 +87,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
 cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
            framework_proto backward glog lod_rank_table feed_fetch_method)
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
-           framework_proto backward glog lod_rank_table feed_fetch_method)
+           framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool)
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
paddle/fluid/framework/parallel_executor.cc
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/parallel_executor.h"
+#include "ThreadPool.h"
+#include "executor.h"
 #include "lod_tensor.h"
 #include "op_registry.h"
-#include "threadpool.h"

 namespace paddle {
 namespace framework {
@@ -49,7 +50,7 @@ struct VarHandle : public VarHandleBase {
 };

 struct DependencyVarHandle : public VarHandleBase {
-  std::string DebugString() const override { return "Deps var"; }
+  std::string DebugString() const override { return "Dependency Variable"; }
 };

 struct OpHandle {
@@ -75,7 +76,7 @@ struct OpHandle {
   virtual ~OpHandle() {}

-  virtual void Run() {}
+  virtual void Run() { PADDLE_THROW("Not implemented"); }
   virtual void Wait() {}
 };
@@ -84,14 +85,15 @@ struct ComputationOpHandle : public OpHandle {
   Scope *scope_;
   platform::Place place_;

-  explicit ComputationOpHandle(const OpDesc &op_desc, platform::Place place)
+  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
+                               platform::Place place)
       : op_(framework::OpRegistry::CreateOp(op_desc)),
-        scope_(nullptr),
+        scope_(scope),
         place_(place) {}

   void Run() override {
     // Wait other op if necessary
-    LOG(INFO) << DebugString();
+    LOG(INFO) << "Run " << this << " " << DebugString();
     auto *cur_ctx = dev_ctx_[place_];
     for (auto *in : inputs_) {
       if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
@@ -100,12 +102,49 @@ struct ComputationOpHandle : public OpHandle {
     }

     op_->Run(*scope_, place_);
+    LOG(INFO) << "Done " << this;
   }
 };

-struct ScaleLossGradOpHandle : public OpHandle {};
+struct ScaleLossGradOpHandle : public OpHandle {
+  float coeff_;
+  Scope *scope_;
+  platform::Place place_;
+
+  explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
+                                 platform::Place place)
+      : coeff_(static_cast<float>(1.0 / num_dev)),
+        scope_(scope),
+        place_(place) {}
+
+  void Run() override {
+    LOG(INFO) << "Run Scale Loss Grad";
+
+    std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
+
+    float *tmp = scope_->FindVar(var_name)
+                     ->GetMutable<framework::LoDTensor>()
+                     ->mutable_data<float>(make_ddim({1}), place_);
+
+    if (platform::is_cpu_place(place_)) {
+      *tmp = coeff_;
+    } else {
+      memory::Copy(
+          boost::get<platform::CUDAPlace>(place_), tmp, platform::CPUPlace(),
+          &coeff_, sizeof(float),
+          static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
+              ->stream());
+    }
+  }
+};

-struct NCCLAllReduceOpHandle : public OpHandle {};
+struct NCCLAllReduceOpHandle : public OpHandle {
+  void Run() override {
+    if (this->inputs_.size() == 1) {
+      return;  // No need to all reduce when GPU count = 1;
+    }
+  }
+};

 class ParallelExecutorPrivate {
  public:
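Note on why coeff_ is 1.0 / num_dev: with the batch split across num_dev devices, seeding each device's loss gradient with 1/num_dev and letting the later NCCL all-reduce sum the per-device gradients yields the batch-averaged gradient (consistent with the FIXME below that the scale factor is just the device count). A standalone, hypothetical illustration of that arithmetic, not Paddle code:

#include <cstdio>
#include <vector>

int main() {
  const std::size_t num_dev = 4;
  const float coeff = 1.0f / num_dev;  // the seed ScaleLossGradOpHandle writes

  // Pretend each device computed dLoss/dW == 2.0 on its shard of the batch.
  std::vector<float> local_grad(num_dev, 2.0f * coeff);

  float reduced = 0.0f;  // the all-reduce sums the per-device gradients
  for (float g : local_grad) reduced += g;

  std::printf("reduced gradient: %g\n", reduced);  // 2.0, the batch average
  return 0;
}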
@@ -182,7 +221,10 @@ class ParallelExecutorPrivate {
   std::vector<std::unique_ptr<OpHandle>> ops_;

+  // Use a simpler thread pool, might be faster.
   ThreadPool pool_;
+
+  std::unique_ptr<platform::EnforceNotMet> exception_;
 };

 // TODO(yy): Move this function somewhere
@@ -217,6 +259,19 @@ ParallelExecutor::ParallelExecutor(
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
   ConstructDependencyGraph(params, main_program, loss_var_name);
+
+  // Step 3. Create vars in each scope;
+  for (auto &pair : member_->local_scopes_) {
+    auto *scope = pair.second;
+
+    for (auto *var : main_program.Block(0).AllVars()) {
+      if (scope->FindVar(var->Name()) != nullptr) {
+        continue;
+      }
+
+      InitializeVariable(scope->Var(var->Name()), var->GetType());
+    }
+  }
 }

 void ParallelExecutor::ConstructDependencyGraph(
@@ -240,7 +295,8 @@ void ParallelExecutor::ConstructDependencyGraph(
     }

     for (auto &pair : member_->local_scopes_) {
-      member_->ops_.emplace_back(new ComputationOpHandle(*op, pair.first));
+      member_->ops_.emplace_back(
+          new ComputationOpHandle(*op, pair.second, pair.first));
       auto *op_handle = member_->ops_.back().get();
       op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
           platform::DeviceContextPool::Instance().Get(pair.first));
@@ -263,16 +319,20 @@ void ParallelExecutor::ConstructDependencyGraph(
       if (is_forwarding) {
         if (var_names.size() == 1 && var_names[0] == loss_var_name) {
           // Insert ScaleCost OpHandle
-          member_->ops_.emplace_back(new ScaleLossGradOpHandle());
+          member_->ops_.emplace_back(new ScaleLossGradOpHandle(
+              this->member_->local_scopes_.size(), pair.second, pair.first));
           op_handle = member_->ops_.back().get();
           op_handle->dev_ctx_[pair.first] =
               member_->CommunicationDevCtx(pair.first);

           auto &place = pair.first;
-          VarHandle *loss = GetVarHandle(loss_var_name, place);
-          loss->pending_ops_.emplace_back(op_handle);
-          op_handle->inputs_.emplace_back(loss);
+          // FIXME: Currently ScaleLossGradOp only use device_count as scale
+          // factor. So it does not depend on any other operators.
+          // VarHandle *loss = GetVarHandle(loss_var_name, place);
+          // loss->pending_ops_.emplace_back(op_handle);
+          // op_handle->inputs_.emplace_back(loss);
+
           GenerateVar(op_handle, loss_var_name + "@GRAD", place);
           change_forward = true;
           LOG(INFO) << "Scale Loss " << op_handle->DebugString();
@@ -341,11 +401,25 @@ void ParallelExecutor::ConstructDependencyGraph(
       for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
         auto *write_op = it_new->second.generated_op_;
         auto &read_ops = it_old->second.pending_ops_;
+        auto *ex_write_op = it_old->second.generated_op_;
+
+        if (ex_write_op == nullptr) {  // Nobody write this var.
+          continue;
+        }
+
+        LOG(INFO) << "Link " << it_new->second.DebugString() << " From "
+                  << it_old->second.version_ << " To "
+                  << it_new->second.version_;

         for (auto *read_op : read_ops) {
           // Manually add a dependency var from read_op to write_op;
+          if (read_op == write_op) {
+            // Read Write is the same op.
+            continue;
+          }
+
           auto *dep_var = new DependencyVarHandle();
           dep_var->generated_op_ = read_op;
           read_op->outputs_.emplace_back(dep_var);
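Note: this hunk hardens the write-after-read edges. The writer of version N+1 of a variable must wait for every reader of version N, which the graph encodes by giving each reader a dummy DependencyVarHandle output that the writer consumes. A hedged, self-contained sketch of that wiring under simplified stand-in types (these are not the Paddle classes):

#include <vector>

struct Node;

// Simplified stand-in for VarHandleBase: an edge in the dataflow graph.
struct Edge {
  Node *producer = nullptr;       // op that must finish first
  std::vector<Node *> consumers;  // ops blocked on it
};

// Simplified stand-in for OpHandle.
struct Node {
  std::vector<Edge *> inputs, outputs;
};

// Make `writer` wait for `reader` by routing a dummy edge between them,
// the role DependencyVarHandle plays in this commit.
void AddWriteAfterReadEdge(Node *reader, Node *writer) {
  auto *dep = new Edge;  // carries no data, only ordering
  dep->producer = reader;
  reader->outputs.push_back(dep);
  dep->consumers.push_back(writer);
  writer->inputs.push_back(dep);
}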
@@ -448,7 +522,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
 std::vector<LoDTensor> ParallelExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   // Version --> VarHandle
+  member_->exception_.reset();
   std::unordered_map<VarHandleBase *, bool> pending_vars;
   std::unordered_map<OpHandle *, size_t> pending_ops;
@@ -465,8 +539,18 @@ std::vector<LoDTensor> ParallelExecutor::Run(
     pending_vars[var.get()] = var->generated_op_ == nullptr;
   }

+  std::vector<OpHandle *> to_run;
+
   for (auto &op : member_->ops_) {
-    pending_ops.insert({op.get(), op->inputs_.size()});
+    if (op->inputs_.empty()) {  // Special case, Op has no input.
+      to_run.emplace_back(op.get());
+    } else {
+      pending_ops.insert({op.get(), op->inputs_.size()});
+    }
+  }
+
+  for (auto *op : to_run) {
+    RunOp(pending_vars, op);
   }

   while (!pending_ops.empty()) {
@@ -478,13 +562,19 @@ std::vector<LoDTensor> ParallelExecutor::Run(
     }

     if (ready_var == nullptr) {
-      member_->pool_.Wait();  // Wait thread pool;
+      // FIXME use conditional var instead of busy wait.
+      if (member_->exception_) {
+        throw * member_->exception_;
+      }
+      std::this_thread::yield();
       continue;
     }

     pending_vars.erase(ready_var);

-    std::vector<OpHandle *> to_run;
+    to_run.clear();

     for (auto *op : ready_var->pending_ops_) {
       auto &deps = pending_ops[op];
@@ -496,24 +586,35 @@ std::vector<LoDTensor> ParallelExecutor::Run(
     for (auto *op : to_run) {
       pending_ops.erase(op);
-
-      std::vector<bool *> ready_buffer;
-      for (auto *var : op->outputs_) {
-        ready_buffer.emplace_back(&pending_vars[var]);
-      }
-
-      auto op_run = [ready_buffer, op] {
-        // TODO(yy) Check Previous Op has same dev ctx.
-        op->Run();
-        for (auto *ready : ready_buffer) {
-          *ready = true;
-        }
-      };
-
-      member_->pool_.Run(op_run);
+      RunOp(pending_vars, op);
     }
   }

   return std::vector<LoDTensor>();
 }

+void ParallelExecutor::RunOp(
+    std::unordered_map<VarHandleBase *, bool> &pending_vars,
+    OpHandle *op) const {
+  std::vector<bool *> ready_buffer;
+  for (auto *var : op->outputs_) {
+    ready_buffer.emplace_back(&pending_vars[var]);
+  }
+
+  auto op_run = [ready_buffer, op, this] {
+    try {
+      // TODO(yy) Check Previous Op has same dev ctx.
+      op->Run();
+      for (auto *ready : ready_buffer) {
+        *ready = true;
+      }
+    } catch (platform::EnforceNotMet ex) {
+      member_->exception_.reset(new platform::EnforceNotMet(ex));
+    } catch (...) {
+      LOG(FATAL) << "Unknown exception caught";
+    }
+  };
+
+  member_->pool_.enqueue(op_run);
+}
+
 }  // namespace framework
 }  // namespace paddle
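Note: taken together, Run() and RunOp() implement a small dataflow scheduler. Ops whose unmet-input count drops to zero are handed to the thread pool, each finished op flips the ready flags of its outputs, and the main thread busy-waits for the next ready variable, rethrowing any worker exception. A hedged, self-contained sketch of that loop under simplified types (Op and Var here are illustrative, not the Paddle handles; atomics stand in for the commit's plain bool flags, whose busy-wait the FIXME above already flags):

#include <atomic>
#include <cstddef>
#include <thread>
#include <unordered_map>
#include <vector>

#include "ThreadPool.h"

struct Op;

struct Var {
  Op *generated_op = nullptr;      // producer; nullptr for graph inputs
  std::vector<Op *> pending_ops;   // consumers blocked on this value
  std::atomic<bool> ready{false};  // set by the worker that produces it
  bool consumed = false;           // main-thread bookkeeping
};

struct Op {
  std::vector<Var *> inputs, outputs;
  void Run() { /* launch the kernel */ }
};

void Schedule(const std::vector<Op *> &ops, const std::vector<Var *> &vars) {
  ThreadPool pool(4);
  std::unordered_map<Op *, std::size_t> deps;  // op -> unmet input count

  auto launch = [&pool](Op *op) {
    pool.enqueue([op] {
      op->Run();
      for (Var *out : op->outputs) out->ready = true;  // unblock consumers
    });
  };

  for (Var *v : vars) {
    if (v->generated_op == nullptr) v->ready = true;  // graph inputs
  }

  for (Op *op : ops) {
    if (op->inputs.empty()) {
      launch(op);  // same special case as the commit's first to_run pass
    } else {
      deps[op] = op->inputs.size();
    }
  }

  while (!deps.empty()) {  // busy-wait; a condition variable would be cleaner
    for (Var *v : vars) {
      if (!v->ready.load() || v->consumed) continue;
      v->consumed = true;  // handle each readiness event exactly once
      for (Op *op : v->pending_ops) {
        if (--deps[op] == 0) {
          deps.erase(op);
          launch(op);
        }
      }
    }
    std::this_thread::yield();
  }
}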
paddle/fluid/framework/parallel_executor.h
@@ -31,6 +31,7 @@ namespace framework {
 class ParallelExecutorPrivate;
 class VarHandle;
 class OpHandle;
+class VarHandleBase;

 class ParallelExecutor {
  public:
   explicit ParallelExecutor(const std::vector<platform::Place>& places,
@@ -57,6 +58,9 @@ class ParallelExecutor {
       const std::string& loss_var_name) const;

   void BuildNCCLCommunicator() const;
+
+  void RunOp(std::unordered_map<VarHandleBase*, bool>& pending_vars,
+             OpHandle* op) const;
 };

 }  // namespace framework
paddle/fluid/operators/read_op.cc
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"

 namespace paddle {
 namespace operators {
@@ -59,7 +60,9 @@ class ReadOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override {
     framework::ReaderHolder* reader =
-        scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
+        detail::Ref(scope.FindVar(Input("Reader")),
+                    "Cannot find reader variable %s", Input("Reader"))
+            .GetMutable<framework::ReaderHolder>();
     std::vector<std::string> out_arg_names = Outputs("Out");
     std::vector<framework::LoDTensor> ins;
     reader->ReadNext(&ins);
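Note: the switch to detail::Ref turns a potential null-pointer crash into a diagnosable enforce failure. A hedged sketch of the kind of helper safe_ref.h provides, with the signature assumed from this call site rather than copied from the header (PADDLE_ENFORCE is the framework's checking macro):

#include <utility>

// Dereference `ptr` after checking it is non-null; on failure, raise the
// framework's enforce error with the caller-supplied message and arguments.
template <typename T, typename... Args>
inline T& Ref(T* ptr, Args&&... args) {
  PADDLE_ENFORCE(ptr != nullptr, std::forward<Args>(args)...);
  return *ptr;
}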