提交 36deb84b 编写于 作者: B baolei.an

[LITE][BM] support multi cards,test=develop

上级 84498198
...@@ -111,7 +111,6 @@ class Context<TargetType::kBM> { ...@@ -111,7 +111,6 @@ class Context<TargetType::kBM> {
explicit Context(const BMContext& ctx); explicit Context(const BMContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); } void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); }
void CopySharedTo(BMContext* ctx) {} void CopySharedTo(BMContext* ctx) {}
void* GetHandle() { return TargetWrapperBM::GetHandle(); } void* GetHandle() { return TargetWrapperBM::GetHandle(); }
...@@ -150,14 +149,23 @@ class Context<TargetType::kXPU> { ...@@ -150,14 +149,23 @@ class Context<TargetType::kXPU> {
if (_tls_raw_ctx == nullptr) { if (_tls_raw_ctx == nullptr) {
_tls_raw_ctx = xdnn::create_context(); _tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx); CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
} }
return _tls_raw_ctx; return _tls_raw_ctx;
} }
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) { static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
xdnn::set_workspace_l3_size(GetRawContext(), l3_size); _workspace_l3_size_per_thread = l3_size;
} }
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) { static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV"); const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) { if (dev_env) {
...@@ -172,6 +180,7 @@ class Context<TargetType::kXPU> { ...@@ -172,6 +180,7 @@ class Context<TargetType::kXPU> {
private: private:
static thread_local xdnn::Context* _tls_raw_ctx; static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
}; };
#endif #endif
...@@ -339,27 +348,17 @@ class Context<TargetType::kX86> { ...@@ -339,27 +348,17 @@ class Context<TargetType::kX86> {
template <> template <>
class Context<TargetType::kOpenCL> { class Context<TargetType::kOpenCL> {
std::shared_ptr<CLContext> cl_context_; std::shared_ptr<CLContext> cl_context_;
using WaitListType =
std::unordered_map<decltype(static_cast<const void*>(nullptr)),
std::shared_ptr<cl::Event>>;
std::shared_ptr<WaitListType> cl_wait_list_;
public: public:
CLContext* cl_context() { return cl_context_.get(); } CLContext* cl_context() { return cl_context_.get(); }
WaitListType* cl_wait_list() { return cl_wait_list_.get(); }
void InitOnce() { void InitOnce() {
// Init cl runtime. // Init cl runtime.
CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed"; CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed";
cl_context_ = std::make_shared<CLContext>(); cl_context_ = std::make_shared<CLContext>();
cl_wait_list_ = std::make_shared<WaitListType>();
} }
void CopySharedTo(OpenCLContext* ctx) { void CopySharedTo(OpenCLContext* ctx) { ctx->cl_context_ = cl_context_; }
ctx->cl_context_ = cl_context_;
ctx->cl_wait_list_ = cl_wait_list_;
}
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册