Commit 8149a07a authored by minqiyang

Fix wait stream two times bug

test=develop
Parent 6fabbd8f
......@@ -25,7 +25,7 @@ struct ExecutionStrategy {
   size_t num_threads_{0};
   bool use_cuda_{true};
   bool allow_op_delay_{false};
-  size_t num_iteration_per_drop_scope_{1};
+  size_t num_iteration_per_drop_scope_{100};
   ExecutorType type_{kDefault};
   bool dry_run_{false};
 };
......
......@@ -66,17 +66,15 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
   platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
   ++drop_scope_counter_;
+  bool stream_end = false;
   if (!fetch_tensors.empty()) {
-    // Wait All computational streams
-    for (auto p : places_) {
-      platform::DeviceContextPool::Instance().Get(p)->Wait();
-    }
+    WaitComputationalStreams();
+    stream_end = true;
   }
   if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
-    // Wait All computational streams
-    for (auto p : places_) {
-      platform::DeviceContextPool::Instance().Get(p)->Wait();
+    if (!stream_end) {
+      WaitComputationalStreams();
     }
     for (auto &scope : local_scopes_) {
......
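For readers skimming the hunk above: the bug was that an iteration which both fetches tensors and hits the scope-drop threshold waited on every device stream twice. Below is a condensed, hypothetical sketch of the guard-flag pattern this commit introduces; `FakeDeviceContext`, `ScopeBufferedExecutorSketch`, and the two boolean parameters are illustrative stand-ins, not the real Paddle classes.

```cpp
// Condensed sketch (not the real Paddle types): an end-of-iteration hook that
// must synchronize device streams before fetching results and before dropping
// local scopes, but should never wait twice in the same iteration.
#include <iostream>
#include <vector>

struct FakeDeviceContext {
  void Wait() const { std::cout << "waiting on one device stream\n"; }
};

class ScopeBufferedExecutorSketch {
 public:
  void AfterRun(bool has_fetch, bool drop_scopes_now) {
    bool stream_end = false;           // set once the streams have been waited on
    if (has_fetch) {
      WaitComputationalStreams();      // fetched tensors must be ready
      stream_end = true;
    }
    if (drop_scopes_now) {
      if (!stream_end) {               // the fix: skip the second, redundant wait
        WaitComputationalStreams();
      }
      std::cout << "dropping local scopes\n";
    }
  }

 private:
  void WaitComputationalStreams() {
    for (const auto& ctx : device_contexts_) ctx.Wait();
  }
  std::vector<FakeDeviceContext> device_contexts_ = std::vector<FakeDeviceContext>(2);
};

int main() {
  ScopeBufferedExecutorSketch exec;
  // Before the fix, this combination waited on every stream twice per iteration.
  exec.AfterRun(/*has_fetch=*/true, /*drop_scopes_now=*/true);
  return 0;
}
```

In the real executor the two booleans correspond to `!fetch_tensors.empty()` and `drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_`, as the hunk shows.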
......@@ -47,6 +47,14 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
  private:
+  inline void WaitComputationalStreams() {
+    // Wait All computational streams
+    for (auto p : places_) {
+      platform::DeviceContextPool::Instance().Get(p)->Wait();
+    }
+  }
+
+ private:
   size_t drop_scope_counter_{0};
......
......@@ -815,7 +815,7 @@ All parameter, weight, gradient are variables in Paddle.
         R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
             many iterations to clean up the temp variables which
             is generated during execution. It may make the execution faster,
-            because the temp variable's shape maybe the same between two iterations. Default 1.
+            because the temp variable's shape maybe the same between two iterations. Default 100.
         NOTES:
             1. If you fetch data when calling the 'run', the ParallelExecutor
......
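The docstring above describes the knob as it is exposed on the Python side. A minimal usage sketch against the fluid-era API follows; the `ParallelExecutor` construction is commented out and the surrounding program is omitted, so treat the exact setup as an assumption rather than a verified recipe.

```python
import paddle.fluid as fluid

exec_strategy = fluid.ExecutionStrategy()
# With the new default, local scopes are cleaned up every 100 iterations rather
# than after every single one; fetching data still synchronizes the streams.
exec_strategy.num_iteration_per_drop_scope = 100

# exe = fluid.ParallelExecutor(use_cuda=True,
#                              loss_name=loss.name,
#                              exec_strategy=exec_strategy)
```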