未验证 提交 14d039e4 编写于 作者: W WangXi 提交者: GitHub

Fix the problem that the number of ops executed by xpu is wrong (#30961)

上级 8e72e031
...@@ -131,13 +131,13 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ...@@ -131,13 +131,13 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
platform::XPUPlace cur_place; platform::XPUPlace cur_place;
std::size_t cur_count = 0; std::size_t cur_count = 0;
while (cur_count < op_deps_.size()) { while (cur_count < op_deps->size()) {
cur_count++; cur_count++;
auto cur_op = ready_ops->Pop(); auto cur_op = ready_ops->Pop();
// when execption, get cur_op == nullptr // when execption, get cur_op == nullptr
if (cur_op == nullptr) { if (cur_op == nullptr) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_ = op_deps_.size(); exec_op_count_ = op_deps->size();
break; break;
} }
auto dev_ctxes_ = cur_op->DeviceContext(); auto dev_ctxes_ = cur_op->DeviceContext();
...@@ -153,7 +153,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ...@@ -153,7 +153,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
} }
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); }); cv_.wait(lock, [&] { return exec_op_count_ >= op_deps->size(); });
} }
if (exception_.IsCaught()) { if (exception_.IsCaught()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册