Commit e72637dd authored by Qiao Longfei

ThreadedSSAGraphExecutor support num_iteration_per_run test=develop

Parent b1fe8d45
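This commit moves the num_iteration_per_run handling out of AsyncSSAGraphExecutor and into ThreadedSSAGraphExecutor: Run() now executes the graph num_iteration_per_run times per call and only fetches results on the final pass, which is why the data should come from a reader op rather than per-call feeds. Below is a minimal standalone sketch of that control flow; ToyExecutor and its members are illustrative stand-ins, not the real Paddle classes.

#include <cstddef>
#include <string>
#include <vector>

// Sketch of the pattern the commit introduces: run the graph several times
// per Run() call, but only fetch results on the last pass.
struct ToyExecutor {
  size_t num_iteration_per_run = 4;  // plays the role of strategy_.num_iteration_per_run_

  std::vector<std::string> RunImpl(const std::vector<std::string> &fetch_tensors) {
    // A real executor would execute one pass over the SSA graph here; input
    // data must come from a reader op, since no new feed arrives between passes.
    return fetch_tensors;
  }

  std::vector<std::string> Run(const std::vector<std::string> &fetch_tensors) {
    for (size_t j = 0; j < num_iteration_per_run - 1; ++j) {
      RunImpl({});  // intermediate passes: nothing to fetch
    }
    return RunImpl(fetch_tensors);  // final pass returns the fetch results
  }
};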
@@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
   VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
-  if (strategy_.num_iteration_per_run_ > 1) {
-    int read_op_num = 0;
-    for (auto *node : graphs_[0]->Nodes()) {
-      if (node->IsOp() && node->Name() == "read") {
-        read_op_num++;
-      }
-    }
-    if (read_op_num == 0) {
-      LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
-                      "should use pyreader to feed data!";
-    }
-  }
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                                ? 1UL
@@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
   for (size_t i = 0; i < places_.size(); ++i) {
     auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
       try {
-        for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
-          executors_[i]->Run(fetch_tensors);
-        }
         return executors_[i]->Run(fetch_tensors);
       } catch (...) {
         exception_holder_.Catch(std::current_exception());
@@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
       places_(places),
       fetch_ctxs_(places),
       running_ops_(0),
-      strategy_(strategy) {}
+      strategy_(strategy) {
+  if (strategy_.num_iteration_per_run_ > 1) {
+    int read_op_num = 0;
+    for (auto *node : graph_->Nodes()) {
+      if (node->IsOp() && node->Name() == "read") {
+        read_op_num++;
+      }
+    }
+    if (read_op_num == 0) {
+      LOG(WARNING) << "when num_iteration_per_run_ is larger than 1, the model "
+                      "should use pyreader to feed data!";
+    }
+  }
+}

-FeedFetchList ThreadedSSAGraphExecutor::Run(
+inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
     const std::vector<std::string> &fetch_tensors) {
   std::unique_ptr<platform::RecordEvent> event(
       new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr));
@@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   return fetch_data;
 }

+FeedFetchList ThreadedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
+    RunImpl({});
+  }
+  return RunImpl(fetch_tensors);
+}

 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
     std::vector<FetchOpHandle *> *fetch_ops,
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ~ThreadedSSAGraphExecutor() final = default;

  private:
+  inline FeedFetchList RunImpl(const std::vector<std::string> &fetch_tensors);
   void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
              details::OpHandleBase *op);