提交 d0ac9253 编写于 作者: X Xin Pan

Improve ParallelExecutor performance

上级 dd75fbde
......@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
}
std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; }
std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
......@@ -34,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
std::string Name() const override;
bool IsDelayedOp() override { return true; };
protected:
void RunImpl() override;
};
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -53,6 +55,8 @@ class OpHandleBase {
void AddOutput(VarHandleBase *out);
virtual bool IsDelayedOp() { return false; }
protected:
virtual void RunImpl() = 0;
};
......
......@@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
use_event_(use_event) {}
use_event_(use_event),
running_ops_(0) {}
void ThreadedSSAGraphExecutor::RunDelayedOps(
const std::unordered_set<OpHandleBase *> &delayed_ops) {
for (auto op : delayed_ops) {
op->Run(use_event_);
}
}
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
std::unordered_set<OpHandleBase *> delayed_ops;
std::unordered_set<OpHandleBase *> after_delayed_ops;
std::unordered_set<VarHandleBase *> delayed_vars;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var);
if (var.generated_op_ == nullptr) {
......@@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) {
RunOp(ready_vars, op);
if (op->IsDelayedOp()) {
delayed_ops.insert(op);
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
ready_vars.Extend(op->outputs_);
continue;
}
running_ops_++;
RunOp(&ready_vars, op);
}
ready_ops.clear();
};
......@@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// 2. Find ready variable
bool timeout;
auto cur_ready_vars = ready_vars.PopAll(1000, &timeout);
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) {
if (exception_) {
......@@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
if (delayed_vars.find(ready_var) != delayed_vars.end()) {
after_delayed_ops.insert(op);
} else {
ready_ops.insert(op);
}
}
}
}
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
RunDelayedOps(delayed_ops);
delayed_ops.clear();
for (auto *op : after_delayed_ops) {
ready_ops.insert(op);
}
after_delayed_ops.clear();
}
// Keep loop until all vars are ready.
}
++computation_count_;
auto sync_computation = [&] {
......@@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
auto op_run = [&ready_var_q, op, this] {
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] {
try {
VLOG(10) << op->Name() << " : " << op->DebugString();
op->Run(use_event_);
ready_var_q.Extend(op->outputs_);
running_ops_--;
ready_var_q->Extend(op->outputs_);
} catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
......
......@@ -14,7 +14,12 @@
#pragma once
#include <chrono>
#include <deque>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
......@@ -79,9 +84,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() {}
private:
void RunOp(BlockingQueue<VarHandleBase *> &ready_var_q,
void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
details::OpHandleBase *op);
void RunDelayedOps(const std::unordered_set<OpHandleBase *> &delayed_ops);
private:
std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_;
......@@ -89,6 +96,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
platform::DeviceContextPool fetch_ctxs_;
const bool use_event_;
std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;
size_t computation_count_{0};
size_t max_async_computation{100};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string>
#include <vector>
......@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
platform::RecordBlock b(0);
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
......
......@@ -16,6 +16,7 @@ import core
import multiprocessing
import framework
import executor
import sys
__all__ = ['ParallelExecutor']
......@@ -35,7 +36,7 @@ class ParallelExecutor(object):
places.append(p)
if num_threads is None:
num_threads = min(len(places) * 2, multiprocessing.cpu_count())
num_threads = len(places)
startup = framework.default_startup_program()
main = framework.default_main_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册