提交 6f0dfd89 编写于 作者: Y Yu Yang

Single GPU ParallelExecutor complete

上级 d84ddcf1
......@@ -146,6 +146,7 @@ include(external/cares)
include(external/grpc)
include(external/snappy) # download snappy
include(external/snappystream)
include(external/threadpool)
include(cudnn) # set cudnn libraries, must before configure
include(cupti)
......
INCLUDE(ExternalProject)
SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool)
SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR}/src/extern_threadpool)
INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR})
ExternalProject_Add(
extern_threadpool
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/progschj/ThreadPool.git"
GIT_TAG 9a42ec1329f259a5f4881a291db1dcb8f2ad9040
PREFIX ${THREADPOOL_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/threadpool_dummy.c)
file(WRITE ${dummyfile} "const char *dummy_threadpool = \"${dummyfile}\";")
add_library(simple_threadpool STATIC ${dummyfile})
else()
add_library(simple_threadpool INTERFACE)
endif()
add_dependencies(simple_threadpool extern_threadpool)
LIST(APPEND external_project_dependencies simple_threadpool)
......@@ -87,7 +87,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method executor)
framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "ThreadPool.h"
#include "executor.h"
#include "lod_tensor.h"
#include "op_registry.h"
#include "threadpool.h"
namespace paddle {
namespace framework {
......@@ -49,7 +50,7 @@ struct VarHandle : public VarHandleBase {
};
struct DependencyVarHandle : public VarHandleBase {
std::string DebugString() const override { return "Deps var"; }
std::string DebugString() const override { return "Dependency Variable"; }
};
struct OpHandle {
......@@ -75,7 +76,7 @@ struct OpHandle {
virtual ~OpHandle() {}
virtual void Run() {}
virtual void Run() { PADDLE_THROW("Not implemented"); }
virtual void Wait() {}
};
......@@ -84,14 +85,15 @@ struct ComputationOpHandle : public OpHandle {
Scope *scope_;
platform::Place place_;
explicit ComputationOpHandle(const OpDesc &op_desc, platform::Place place)
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
scope_(nullptr),
scope_(scope),
place_(place) {}
void Run() override {
// Wait other op if necessary
LOG(INFO) << DebugString();
LOG(INFO) << "Run " << this << " " << DebugString();
auto *cur_ctx = dev_ctx_[place_];
for (auto *in : inputs_) {
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
......@@ -100,12 +102,49 @@ struct ComputationOpHandle : public OpHandle {
}
op_->Run(*scope_, place_);
LOG(INFO) << "Done " << this;
}
};
struct ScaleLossGradOpHandle : public OpHandle {};
struct ScaleLossGradOpHandle : public OpHandle {
float coeff_;
Scope *scope_;
platform::Place place_;
explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place)
: coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {}
void Run() override {
LOG(INFO) << "Run Scale Loss Grad";
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
struct NCCLAllReduceOpHandle : public OpHandle {};
float *tmp = scope_->FindVar(var_name)
->GetMutable<framework::LoDTensor>()
->mutable_data<float>(make_ddim({1}), place_);
if (platform::is_cpu_place(place_)) {
*tmp = coeff_;
} else {
memory::Copy(
boost::get<platform::CUDAPlace>(place_), tmp, platform::CPUPlace(),
&coeff_, sizeof(float),
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
->stream());
}
}
};
struct NCCLAllReduceOpHandle : public OpHandle {
void Run() override {
if (this->inputs_.size() == 1) {
return; // No need to all reduce when GPU count = 1;
}
}
};
class ParallelExecutorPrivate {
public:
......@@ -182,7 +221,10 @@ class ParallelExecutorPrivate {
std::vector<std::unique_ptr<OpHandle>> ops_;
// Use a simpler thread pool, might be faster.
ThreadPool pool_;
std::unique_ptr<platform::EnforceNotMet> exception_;
};
// TODO(yy): Move this function somewhere
......@@ -217,6 +259,19 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
ConstructDependencyGraph(params, main_program, loss_var_name);
// Step 3. Create vars in each scope;
for (auto &pair : member_->local_scopes_) {
auto *scope = pair.second;
for (auto *var : main_program.Block(0).AllVars()) {
if (scope->FindVar(var->Name()) != nullptr) {
continue;
}
InitializeVariable(scope->Var(var->Name()), var->GetType());
}
}
}
void ParallelExecutor::ConstructDependencyGraph(
......@@ -240,7 +295,8 @@ void ParallelExecutor::ConstructDependencyGraph(
}
for (auto &pair : member_->local_scopes_) {
member_->ops_.emplace_back(new ComputationOpHandle(*op, pair.first));
member_->ops_.emplace_back(
new ComputationOpHandle(*op, pair.second, pair.first));
auto *op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(pair.first));
......@@ -263,16 +319,20 @@ void ParallelExecutor::ConstructDependencyGraph(
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name) {
// Insert ScaleCost OpHandle
member_->ops_.emplace_back(new ScaleLossGradOpHandle());
member_->ops_.emplace_back(new ScaleLossGradOpHandle(
this->member_->local_scopes_.size(), pair.second, pair.first));
op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[pair.first] =
member_->CommunicationDevCtx(pair.first);
auto &place = pair.first;
VarHandle *loss = GetVarHandle(loss_var_name, place);
loss->pending_ops_.emplace_back(op_handle);
op_handle->inputs_.emplace_back(loss);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
GenerateVar(op_handle, loss_var_name + "@GRAD", place);
change_forward = true;
LOG(INFO) << "Scale Loss " << op_handle->DebugString();
......@@ -341,11 +401,25 @@ void ParallelExecutor::ConstructDependencyGraph(
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
auto *ex_write_op = it_old->second.generated_op_;
if (ex_write_op == nullptr) { // Nobody write this var.
continue;
}
LOG(INFO) << "Link " << it_new->second.DebugString() << " From "
<< it_old->second.version_ << " To "
<< it_new->second.version_;
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
auto *dep_var = new DependencyVarHandle();
dep_var->generated_op_ = read_op;
read_op->outputs_.emplace_back(dep_var);
......@@ -448,7 +522,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
std::vector<LoDTensor> ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// Version --> VarHandle
member_->exception_.reset();
std::unordered_map<VarHandleBase *, bool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops;
......@@ -465,9 +539,19 @@ std::vector<LoDTensor> ParallelExecutor::Run(
pending_vars[var.get()] = var->generated_op_ == nullptr;
}
std::vector<OpHandle *> to_run;
for (auto &op : member_->ops_) {
if (op->inputs_.empty()) { // Special case, Op has no input.
to_run.emplace_back(op.get());
} else {
pending_ops.insert({op.get(), op->inputs_.size()});
}
}
for (auto *op : to_run) {
RunOp(pending_vars, op);
}
while (!pending_ops.empty()) {
VarHandleBase *ready_var = nullptr;
......@@ -478,13 +562,19 @@ std::vector<LoDTensor> ParallelExecutor::Run(
}
if (ready_var == nullptr) {
member_->pool_.Wait(); // Wait thread pool;
// FIXME use conditional var instead of busy wait.
if (member_->exception_) {
throw * member_->exception_;
}
std::this_thread::yield();
continue;
}
pending_vars.erase(ready_var);
std::vector<OpHandle *> to_run;
to_run.clear();
for (auto *op : ready_var->pending_ops_) {
auto &deps = pending_ops[op];
......@@ -496,24 +586,35 @@ std::vector<LoDTensor> ParallelExecutor::Run(
for (auto *op : to_run) {
pending_ops.erase(op);
RunOp(pending_vars, op);
}
}
return std::vector<LoDTensor>();
}
void ParallelExecutor::RunOp(
std::unordered_map<VarHandleBase *, bool> &pending_vars,
OpHandle *op) const {
std::vector<bool *> ready_buffer;
for (auto *var : op->outputs_) {
ready_buffer.emplace_back(&pending_vars[var]);
}
auto op_run = [ready_buffer, op] {
auto op_run = [ready_buffer, op, this] {
try {
// TODO(yy) Check Previous Op has same dev ctx.
op->Run();
for (auto *ready : ready_buffer) {
*ready = true;
}
} catch (platform::EnforceNotMet ex) {
member_->exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
LOG(FATAL) << "Unknown exception catched";
}
};
member_->pool_.Run(op_run);
}
}
return std::vector<LoDTensor>();
member_->pool_.enqueue(op_run);
}
} // namespace framework
} // namespace paddle
......@@ -31,6 +31,7 @@ namespace framework {
class ParallelExecutorPrivate;
class VarHandle;
class OpHandle;
class VarHandleBase;
class ParallelExecutor {
public:
explicit ParallelExecutor(const std::vector<platform::Place>& places,
......@@ -57,6 +58,9 @@ class ParallelExecutor {
const std::string& loss_var_name) const;
void BuildNCCLCommunicator() const;
void RunOp(std::unordered_map<VarHandleBase*, bool>& pending_vars,
OpHandle* op) const;
};
} // namespace framework
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle {
namespace operators {
......@@ -59,7 +60,9 @@ class ReadOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
detail::Ref(scope.FindVar(Input("Reader")),
"Cannot find reader variable %s", Input("Reader"))
.GetMutable<framework::ReaderHolder>();
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册