未验证 提交 29467060 编写于 作者: W WangXi 提交者: GitHub

[cherry-pick 2.0.1] [kunlun] fix xpu bind threaded executor (#31116)

* [Kunlun] Add condition_variable and notify() in BindThreadedSSAGraphExecutor (#30586)

* [Kunlun] fix dead lock for exec_op_count_ (#30718)

* Fix the problem that the number of ops executed by xpu is wrong (#30961)
Co-authored-by: Nliuyuhui <liuyuhui@baidu.com>
上级 b582be2d
...@@ -30,9 +30,6 @@ namespace paddle { ...@@ -30,9 +30,6 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
static std::atomic<unsigned int> exec_op_count_;
static std::atomic<int> error_state;
BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor( BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes, const std::vector<Scope *> &local_exec_scopes,
...@@ -126,18 +123,21 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ...@@ -126,18 +123,21 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
ready_ops->Push(cur_op); ready_ops->Push(cur_op);
} }
exec_op_count_ = 0; {
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_ = 0;
}
platform::XPUPlace cur_place; platform::XPUPlace cur_place;
std::size_t cur_count = 0; std::size_t cur_count = 0;
while (cur_count < op_deps_.size()) { while (cur_count < op_deps->size()) {
cur_count++; cur_count++;
auto cur_op = ready_ops->Pop(); auto cur_op = ready_ops->Pop();
// when execption, get cur_op == nullptr
if (cur_op == nullptr) { if (cur_op == nullptr) {
// sleep a while to make sure worker thread quit std::lock_guard<std::mutex> lock(mutex_);
sleep(10); exec_op_count_ = op_deps->size();
exec_op_count_ = op_deps_.size();
break; break;
} }
auto dev_ctxes_ = cur_op->DeviceContext(); auto dev_ctxes_ = cur_op->DeviceContext();
...@@ -151,14 +151,17 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ...@@ -151,14 +151,17 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index); RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
} }
} }
while (exec_op_count_ < op_deps_.size()) { {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return exec_op_count_ >= op_deps->size(); });
} }
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
if (exception_.IsCaught()) { if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops); ExecutionFinal(&fetch_ops);
} }
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches; return fetches;
} }
...@@ -222,7 +225,8 @@ void BindThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -222,7 +225,8 @@ void BindThreadedSSAGraphExecutor::InsertFetchOps(
} }
} }
} }
// RunMultiDeviceOpAsync function is used for Communicated OPs
// like all_reduce\broadcast among multicards.
void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
OpHandleBase *op, OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps, std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
...@@ -256,10 +260,14 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( ...@@ -256,10 +260,14 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
ready_ops->Push(nullptr); ready_ops->Push(nullptr);
exception_.Catch(std::current_exception()); exception_.Catch(std::current_exception());
} }
exec_op_count_++; {
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_++;
cv_.notify_all();
}
}); });
} }
// RunOpAsyncMainStream function is used for computed OPs
void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
OpHandleBase *op, OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps, std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
...@@ -285,7 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( ...@@ -285,7 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
ready_ops->Push(nullptr); ready_ops->Push(nullptr);
exception_.Catch(std::current_exception()); exception_.Catch(std::current_exception());
} }
exec_op_count_++; {
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_++;
cv_.notify_all();
}
}); });
} }
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#include <ThreadPool.h> #include <ThreadPool.h>
#include <condition_variable> // NOLINT
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -76,6 +78,11 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -76,6 +78,11 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
::ThreadPool prepare_pool_; ::ThreadPool prepare_pool_;
::ThreadPool multi_device_op_pool_; ::ThreadPool multi_device_op_pool_;
std::mutex mutex_;
std::condition_variable cv_;
uint32_t exec_op_count_;
std::atomic<int> error_state;
void RunOpAsyncMainStream( void RunOpAsyncMainStream(
OpHandleBase *op, OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps, std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册