提交 ae88fdef 编写于 作者: Y Yu Yang

Use thread pool

上级 692a0f74
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include "lod_tensor.h" #include "lod_tensor.h"
#include "op_registry.h" #include "op_registry.h"
#include "threadpool.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -34,7 +35,6 @@ struct VarHandle { ...@@ -34,7 +35,6 @@ struct VarHandle {
struct OpHandle { struct OpHandle {
std::vector<VarHandle *> inputs_; std::vector<VarHandle *> inputs_;
std::vector<VarHandle *> outputs_; std::vector<VarHandle *> outputs_;
platform::DeviceContext *dev_ctx_;
std::string DebugString() { std::string DebugString() {
std::stringstream ss; std::stringstream ss;
...@@ -66,6 +66,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {}; ...@@ -66,6 +66,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {};
class ParallelExecutorPrivate { class ParallelExecutorPrivate {
public: public:
explicit ParallelExecutorPrivate(size_t num_threads = 12)
: pool_(num_threads) {}
std::unordered_map<platform::Place, Scope *, platform::PlaceHash> std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
local_scopes_; local_scopes_;
std::unordered_map<platform::Place, platform::CUDADeviceContext, std::unordered_map<platform::Place, platform::CUDADeviceContext,
...@@ -78,6 +81,8 @@ class ParallelExecutorPrivate { ...@@ -78,6 +81,8 @@ class ParallelExecutorPrivate {
platform::PlaceHash> platform::PlaceHash>
vars_; vars_;
std::vector<std::unique_ptr<OpHandle>> ops_; std::vector<std::unique_ptr<OpHandle>> ops_;
ThreadPool pool_;
}; };
// TODO(yy): Move this function somewhere // TODO(yy): Move this function somewhere
...@@ -285,13 +290,15 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -285,13 +290,15 @@ void ParallelExecutor::BCastParamsToGPUs(
std::vector<LoDTensor> ParallelExecutor::Run( std::vector<LoDTensor> ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
// Version --> VarHandle // Version --> VarHandle
std::unordered_set<VarHandle *> pending_vars;
std::unordered_map<VarHandle *, bool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops; std::unordered_map<OpHandle *, size_t> pending_ops;
for (auto &place_pair : member_->vars_) { for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) { for (auto &name_pair : place_pair.second) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
pending_vars.insert(&version_pair.second); pending_vars[&version_pair.second] =
version_pair.second.generated_op_ == nullptr;
} }
} }
} }
...@@ -300,56 +307,50 @@ std::vector<LoDTensor> ParallelExecutor::Run( ...@@ -300,56 +307,50 @@ std::vector<LoDTensor> ParallelExecutor::Run(
pending_ops.insert({op.get(), op->inputs_.size()}); pending_ops.insert({op.get(), op->inputs_.size()});
} }
std::unordered_set<OpHandle *> complete_op; while (!pending_ops.empty()) {
VarHandle *ready_var = nullptr;
size_t num_op = pending_ops.size(); for (auto &pair : pending_vars) {
if (pair.second) {
while (complete_op.size() != num_op) { ready_var = pair.first;
std::vector<VarHandle *> to_remove;
for (auto &var : pending_vars) {
if (var->generated_op_ == nullptr ||
complete_op.count(var->generated_op_) != 0) {
to_remove.push_back(var);
} }
} }
for (auto *var : to_remove) {
pending_vars.erase(var); if (ready_var == nullptr) {
member_->pool_.Wait(); // Wait thread pool;
continue;
} }
pending_vars.erase(ready_var);
std::vector<OpHandle *> to_run; std::vector<OpHandle *> to_run;
for (auto *var : to_remove) {
for (auto *op : var->pending_ops_) { for (auto *op : ready_var->pending_ops_) {
if (var->name_ == "mean_0.tmp_0@GRAD") { auto &deps = pending_ops[op];
LOG(INFO) << op->DebugString(); --deps;
} if (deps == 0) {
auto &num = pending_ops[op];
--num;
if (num == 0) {
to_run.emplace_back(op); to_run.emplace_back(op);
} }
} }
}
for (auto *op : to_run) { for (auto *op : to_run) {
pending_ops.erase(op); pending_ops.erase(op);
complete_op.insert(op);
}
if (to_run.empty()) break; std::vector<bool *> ready_buffer;
for (auto *var : op->outputs_) {
ready_buffer.emplace_back(&pending_vars[var]);
}
// TODO(yy): Use thead pool to run OpHandle. Operators in ToRun can be auto op_run = [ready_buffer, op] {
// paralleled. We can also use another schedule method. Just a demo here. // TODO(yy) Check Previous Op has same dev ctx.
LOG(INFO) << "Run " << op->DebugString();
for (auto *ready : ready_buffer) {
*ready = true;
}
};
std::stringstream ss; member_->pool_.Run(op_run);
ss << "\n";
for (auto *op : to_run) {
ss << op->DebugString() << "\n";
} }
ss << std::endl;
LOG(INFO) << ss.str();
} }
PADDLE_ENFORCE_EQ(complete_op.size(), num_op);
return std::vector<LoDTensor>(); return std::vector<LoDTensor>();
} }
} // namespace framework } // namespace framework
......
...@@ -32,6 +32,8 @@ namespace framework { ...@@ -32,6 +32,8 @@ namespace framework {
// number of threads. // number of threads.
class ThreadPool { class ThreadPool {
public: public:
explicit ThreadPool(int num_threads);
using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>; using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
// Returns the singleton of ThreadPool. // Returns the singleton of ThreadPool.
...@@ -103,8 +105,6 @@ class ThreadPool { ...@@ -103,8 +105,6 @@ class ThreadPool {
DISABLE_COPY_AND_ASSIGN(ThreadPool); DISABLE_COPY_AND_ASSIGN(ThreadPool);
explicit ThreadPool(int num_threads);
// If the task queue is empty and avaialbe is equal to the number of // If the task queue is empty and avaialbe is equal to the number of
// threads, means that all tasks are completed. Note: this function // threads, means that all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed. // is not thread-safe. Returns true if all tasks are completed.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册