提交 ae88fdef 编写于 作者: Y Yu Yang

Use thread pool

上级 692a0f74
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "lod_tensor.h"
#include "op_registry.h"
#include "threadpool.h"
namespace paddle {
namespace framework {
......@@ -34,7 +35,6 @@ struct VarHandle {
struct OpHandle {
std::vector<VarHandle *> inputs_;
std::vector<VarHandle *> outputs_;
platform::DeviceContext *dev_ctx_;
std::string DebugString() {
std::stringstream ss;
......@@ -66,6 +66,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {};
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(size_t num_threads = 12)
: pool_(num_threads) {}
std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
local_scopes_;
std::unordered_map<platform::Place, platform::CUDADeviceContext,
......@@ -78,6 +81,8 @@ class ParallelExecutorPrivate {
platform::PlaceHash>
vars_;
std::vector<std::unique_ptr<OpHandle>> ops_;
ThreadPool pool_;
};
// TODO(yy): Move this function somewhere
......@@ -285,13 +290,15 @@ void ParallelExecutor::BCastParamsToGPUs(
std::vector<LoDTensor> ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// Version --> VarHandle
std::unordered_set<VarHandle *> pending_vars;
std::unordered_map<VarHandle *, bool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops;
for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) {
for (auto &version_pair : name_pair.second) {
pending_vars.insert(&version_pair.second);
pending_vars[&version_pair.second] =
version_pair.second.generated_op_ == nullptr;
}
}
}
......@@ -300,56 +307,50 @@ std::vector<LoDTensor> ParallelExecutor::Run(
pending_ops.insert({op.get(), op->inputs_.size()});
}
std::unordered_set<OpHandle *> complete_op;
size_t num_op = pending_ops.size();
while (complete_op.size() != num_op) {
std::vector<VarHandle *> to_remove;
for (auto &var : pending_vars) {
if (var->generated_op_ == nullptr ||
complete_op.count(var->generated_op_) != 0) {
to_remove.push_back(var);
while (!pending_ops.empty()) {
VarHandle *ready_var = nullptr;
for (auto &pair : pending_vars) {
if (pair.second) {
ready_var = pair.first;
}
}
for (auto *var : to_remove) {
pending_vars.erase(var);
if (ready_var == nullptr) {
member_->pool_.Wait(); // Wait thread pool;
continue;
}
pending_vars.erase(ready_var);
std::vector<OpHandle *> to_run;
for (auto *var : to_remove) {
for (auto *op : var->pending_ops_) {
if (var->name_ == "mean_0.tmp_0@GRAD") {
LOG(INFO) << op->DebugString();
}
auto &num = pending_ops[op];
--num;
if (num == 0) {
to_run.emplace_back(op);
}
for (auto *op : ready_var->pending_ops_) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
to_run.emplace_back(op);
}
}
for (auto *op : to_run) {
pending_ops.erase(op);
complete_op.insert(op);
}
if (to_run.empty()) break;
std::vector<bool *> ready_buffer;
for (auto *var : op->outputs_) {
ready_buffer.emplace_back(&pending_vars[var]);
}
// TODO(yy): Use thead pool to run OpHandle. Operators in ToRun can be
// paralleled. We can also use another schedule method. Just a demo here.
auto op_run = [ready_buffer, op] {
// TODO(yy) Check Previous Op has same dev ctx.
LOG(INFO) << "Run " << op->DebugString();
for (auto *ready : ready_buffer) {
*ready = true;
}
};
std::stringstream ss;
ss << "\n";
for (auto *op : to_run) {
ss << op->DebugString() << "\n";
member_->pool_.Run(op_run);
}
ss << std::endl;
LOG(INFO) << ss.str();
}
PADDLE_ENFORCE_EQ(complete_op.size(), num_op);
return std::vector<LoDTensor>();
}
} // namespace framework
......
......@@ -32,6 +32,8 @@ namespace framework {
// number of threads.
class ThreadPool {
public:
explicit ThreadPool(int num_threads);
using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
// Returns the singleton of ThreadPool.
......@@ -103,8 +105,6 @@ class ThreadPool {
DISABLE_COPY_AND_ASSIGN(ThreadPool);
explicit ThreadPool(int num_threads);
// If the task queue is empty and avaialbe is equal to the number of
// threads, means that all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册