提交 e72637dd 编写于 作者: Q Qiao Longfei

ThreadedSSAGraphExecutor support num_iteration_per_run test=develop

上级 b1fe8d45
...@@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( ...@@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
VLOG(3) << "build AsyncSSAGraphExecutor"; VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
for (auto *node : graphs_[0]->Nodes()) {
if (node->IsOp() && node->Name() == "read") {
read_op_num++;
}
}
if (read_op_num == 0) {
LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
"should use pyreader to feed data!";
}
}
// set the correct size of thread pool to each device. // set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size() strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
? 1UL ? 1UL
...@@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run( ...@@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto call = [this, i, &fetch_tensors]() -> FeedFetchList { auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
try { try {
for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
executors_[i]->Run(fetch_tensors);
}
return executors_[i]->Run(fetch_tensors); return executors_[i]->Run(fetch_tensors);
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
......
...@@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ...@@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
places_(places), places_(places),
fetch_ctxs_(places), fetch_ctxs_(places),
running_ops_(0), running_ops_(0),
strategy_(strategy) {} strategy_(strategy) {
if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
for (auto *node : graph_->Nodes()) {
if (node->IsOp() && node->Name() == "read") {
read_op_num++;
}
}
if (read_op_num == 0) {
LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
"should use pyreader to feed data!";
}
}
}
FeedFetchList ThreadedSSAGraphExecutor::Run( inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
std::unique_ptr<platform::RecordEvent> event( std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr)); new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr));
...@@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return fetch_data; return fetch_data;
} }
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
RunImpl({});
}
return RunImpl(fetch_tensors);
}
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
......
...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() final = default; ~ThreadedSSAGraphExecutor() final = default;
private: private:
inline FeedFetchList RunImpl(const std::vector<std::string> &fetch_tensors);
void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q, void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op); details::OpHandleBase *op);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册