提交 8149a07a 编写于 作者: M minqiyang

Fix bug of waiting on computational streams twice

test=develop
上级 6fabbd8f
...@@ -25,7 +25,7 @@ struct ExecutionStrategy { ...@@ -25,7 +25,7 @@ struct ExecutionStrategy {
size_t num_threads_{0}; size_t num_threads_{0};
bool use_cuda_{true}; bool use_cuda_{true};
bool allow_op_delay_{false}; bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{1}; size_t num_iteration_per_drop_scope_{100};
ExecutorType type_{kDefault}; ExecutorType type_{kDefault};
bool dry_run_{false}; bool dry_run_{false};
}; };
......
...@@ -66,17 +66,15 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -66,17 +66,15 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr); platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
++drop_scope_counter_; ++drop_scope_counter_;
bool stream_end = false;
if (!fetch_tensors.empty()) { if (!fetch_tensors.empty()) {
// Wait All computational streams WaitComputationalStreams();
for (auto p : places_) { stream_end = true;
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
} }
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
// Wait All computational streams if (!stream_end) {
for (auto p : places_) { WaitComputationalStreams();
platform::DeviceContextPool::Instance().Get(p)->Wait();
} }
for (auto &scope : local_scopes_) { for (auto &scope : local_scopes_) {
......
...@@ -47,6 +47,14 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -47,6 +47,14 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override; FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
inline void WaitComputationalStreams() {
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
private: private:
size_t drop_scope_counter_{0}; size_t drop_scope_counter_{0};
......
...@@ -815,7 +815,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -815,7 +815,7 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is INT, num_iteration_per_drop_scope indicates how R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
many iterations to clean up the temp variables which many iterations to clean up the temp variables which
is generated during execution. It may make the execution faster, is generated during execution. It may make the execution faster,
because the temp variable's shape may be the same between two iterations. Default 1. because the temp variable's shape may be the same between two iterations. Default 100.
NOTES: NOTES:
1. If you fetch data when calling the 'run', the ParallelExecutor 1. If you fetch data when calling the 'run', the ParallelExecutor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册