diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 46fb15f5800bca0884540d238af0adb2e851151b..dd726f1fab0c033dcbcd4a7a2deb9c62202dbfa9 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" #include "lod_tensor.h" #include "op_registry.h" +#include "threadpool.h" namespace paddle { namespace framework { @@ -34,7 +35,6 @@ struct VarHandle { struct OpHandle { std::vector inputs_; std::vector outputs_; - platform::DeviceContext *dev_ctx_; std::string DebugString() { std::stringstream ss; @@ -66,6 +66,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {}; class ParallelExecutorPrivate { public: + explicit ParallelExecutorPrivate(size_t num_threads = 12) + : pool_(num_threads) {} + std::unordered_map local_scopes_; std::unordered_map vars_; std::vector> ops_; + + ThreadPool pool_; }; // TODO(yy): Move this function somewhere @@ -285,13 +290,15 @@ void ParallelExecutor::BCastParamsToGPUs( std::vector ParallelExecutor::Run( const std::vector &fetch_tensors) { // Version --> VarHandle - std::unordered_set pending_vars; + + std::unordered_map pending_vars; std::unordered_map pending_ops; for (auto &place_pair : member_->vars_) { for (auto &name_pair : place_pair.second) { for (auto &version_pair : name_pair.second) { - pending_vars.insert(&version_pair.second); + pending_vars[&version_pair.second] = + version_pair.second.generated_op_ == nullptr; } } } @@ -300,56 +307,50 @@ std::vector ParallelExecutor::Run( pending_ops.insert({op.get(), op->inputs_.size()}); } - std::unordered_set complete_op; - - size_t num_op = pending_ops.size(); - - while (complete_op.size() != num_op) { - std::vector to_remove; - for (auto &var : pending_vars) { - if (var->generated_op_ == nullptr || - complete_op.count(var->generated_op_) != 0) { - to_remove.push_back(var); + while (!pending_ops.empty()) { + VarHandle *ready_var = nullptr; + for (auto &pair : pending_vars) { + if (pair.second) { + ready_var = pair.first; } } - for (auto *var : to_remove) { - pending_vars.erase(var); + + if (ready_var == nullptr) { + member_->pool_.Wait(); // Wait thread pool; + continue; } + pending_vars.erase(ready_var); + std::vector to_run; - for (auto *var : to_remove) { - for (auto *op : var->pending_ops_) { - if (var->name_ == "mean_0.tmp_0@GRAD") { - LOG(INFO) << op->DebugString(); - } - auto &num = pending_ops[op]; - --num; - if (num == 0) { - to_run.emplace_back(op); - } + + for (auto *op : ready_var->pending_ops_) { + auto &deps = pending_ops[op]; + --deps; + if (deps == 0) { + to_run.emplace_back(op); } } for (auto *op : to_run) { pending_ops.erase(op); - complete_op.insert(op); - } - if (to_run.empty()) break; + std::vector ready_buffer; + for (auto *var : op->outputs_) { + ready_buffer.emplace_back(&pending_vars[var]); + } - // TODO(yy): Use thead pool to run OpHandle. Operators in ToRun can be - // paralleled. We can also use another schedule method. Just a demo here. + auto op_run = [ready_buffer, op] { + // TODO(yy) Check Previous Op has same dev ctx. + LOG(INFO) << "Run " << op->DebugString(); + for (auto *ready : ready_buffer) { + *ready = true; + } + }; - std::stringstream ss; - ss << "\n"; - for (auto *op : to_run) { - ss << op->DebugString() << "\n"; + member_->pool_.Run(op_run); } - ss << std::endl; - LOG(INFO) << ss.str(); } - - PADDLE_ENFORCE_EQ(complete_op.size(), num_op); return std::vector(); } } // namespace framework diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index df51fb24a588c84788d7d0b671f932ff4c40f9c2..f9dce7105e32ff0ba03d03f8faaac3a4ed1a3595 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -32,6 +32,8 @@ namespace framework { // number of threads. class ThreadPool { public: + explicit ThreadPool(int num_threads); + using Task = std::packaged_task()>; // Returns the singleton of ThreadPool. @@ -103,8 +105,6 @@ class ThreadPool { DISABLE_COPY_AND_ASSIGN(ThreadPool); - explicit ThreadPool(int num_threads); - // If the task queue is empty and avaialbe is equal to the number of // threads, means that all tasks are completed. Note: this function // is not thread-safe. Returns true if all tasks are completed.