提交 d0ac9253 编写于 作者: X Xin Pan

Improve ParallelExecutor performance

上级 dd75fbde
...@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
} }
} }
std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; } std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#pragma once #pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -34,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -34,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
std::string Name() const override; std::string Name() const override;
bool IsDelayedOp() override { return true; };
protected: protected:
void RunImpl() override; void RunImpl() override;
}; };
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -53,6 +55,8 @@ class OpHandleBase { ...@@ -53,6 +55,8 @@ class OpHandleBase {
void AddOutput(VarHandleBase *out); void AddOutput(VarHandleBase *out);
virtual bool IsDelayedOp() { return false; }
protected: protected:
virtual void RunImpl() = 0; virtual void RunImpl() = 0;
}; };
......
...@@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ...@@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
local_scopes_(local_scopes), local_scopes_(local_scopes),
places_(places), places_(places),
fetch_ctxs_(places), fetch_ctxs_(places),
use_event_(use_event) {} use_event_(use_event),
running_ops_(0) {}
void ThreadedSSAGraphExecutor::RunDelayedOps(
const std::unordered_set<OpHandleBase *> &delayed_ops) {
for (auto op : delayed_ops) {
op->Run(use_event_);
}
}
FeedFetchList ThreadedSSAGraphExecutor::Run( FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars; BlockingQueue<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops; std::unordered_set<OpHandleBase *> ready_ops;
std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> after_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var); pending_vars.insert(&var);
if (var.generated_op_ == nullptr) { if (var.generated_op_ == nullptr) {
...@@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] { auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) { for (auto *op : ready_ops) {
RunOp(ready_vars, op); if (op->IsDelayedOp()) {
delayed_ops.insert(op);
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
ready_vars.Extend(op->outputs_);
continue;
}
running_ops_++;
RunOp(&ready_vars, op);
} }
ready_ops.clear(); ready_ops.clear();
}; };
...@@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// 2. Find ready variable // 2. Find ready variable
bool timeout; bool timeout;
auto cur_ready_vars = ready_vars.PopAll(1000, &timeout); auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) { if (timeout) {
if (exception_) { if (exception_) {
...@@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto &deps = pending_ops[op]; auto &deps = pending_ops[op];
--deps; --deps;
if (deps == 0) { if (deps == 0) {
ready_ops.insert(op); if (delayed_vars.find(ready_var) != delayed_vars.end()) {
after_delayed_ops.insert(op);
} else {
ready_ops.insert(op);
}
} }
} }
} }
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
RunDelayedOps(delayed_ops);
delayed_ops.clear();
for (auto *op : after_delayed_ops) {
ready_ops.insert(op);
}
after_delayed_ops.clear();
}
// Keep loop until all vars are ready. // Keep loop until all vars are ready.
} }
++computation_count_; ++computation_count_;
auto sync_computation = [&] { auto sync_computation = [&] {
...@@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
void ThreadedSSAGraphExecutor::RunOp( void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) { BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [&ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
try { try {
VLOG(10) << op->Name() << " : " << op->DebugString(); VLOG(10) << op->Name() << " : " << op->DebugString();
op->Run(use_event_); op->Run(use_event_);
ready_var_q.Extend(op->outputs_); running_ops_--;
ready_var_q->Extend(op->outputs_);
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex)); exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) { } catch (...) {
......
...@@ -14,7 +14,12 @@ ...@@ -14,7 +14,12 @@
#pragma once #pragma once
#include <chrono> #include <deque>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <functional> #include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party #include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h"
...@@ -79,9 +84,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -79,9 +84,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() {} ~ThreadedSSAGraphExecutor() {}
private: private:
void RunOp(BlockingQueue<VarHandleBase *> &ready_var_q, void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
details::OpHandleBase *op); details::OpHandleBase *op);
void RunDelayedOps(const std::unordered_set<OpHandleBase *> &delayed_ops);
private: private:
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
...@@ -89,6 +96,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -89,6 +96,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
const bool use_event_; const bool use_event_;
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;
size_t computation_count_{0}; size_t computation_count_{0};
size_t max_async_computation{100}; size_t max_async_computation{100};
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) { const std::string &fetched_var_name) {
platform::RecordBlock b(0);
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data; fetch_data;
......
...@@ -16,6 +16,7 @@ import core ...@@ -16,6 +16,7 @@ import core
import multiprocessing import multiprocessing
import framework import framework
import executor import executor
import sys
__all__ = ['ParallelExecutor'] __all__ = ['ParallelExecutor']
...@@ -35,7 +36,7 @@ class ParallelExecutor(object): ...@@ -35,7 +36,7 @@ class ParallelExecutor(object):
places.append(p) places.append(p)
if num_threads is None: if num_threads is None:
num_threads = min(len(places) * 2, multiprocessing.cpu_count()) num_threads = len(places)
startup = framework.default_startup_program() startup = framework.default_startup_program()
main = framework.default_main_program() main = framework.default_main_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册