From 84498198ae1ed29f1f2da899cabf72fd666f8a80 Mon Sep 17 00:00:00 2001 From: "baolei.an" Date: Tue, 28 Apr 2020 18:23:05 +0800 Subject: [PATCH] [LITE][BM] multi device ok,test=develop --- lite/backends/bm/target_wrapper.cc | 13 +++++++------ lite/backends/bm/target_wrapper.h | 1 + lite/core/context.h | 27 +++++++++++++-------------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/lite/backends/bm/target_wrapper.cc b/lite/backends/bm/target_wrapper.cc index c75c714522..6dab2a574d 100644 --- a/lite/backends/bm/target_wrapper.cc +++ b/lite/backends/bm/target_wrapper.cc @@ -24,16 +24,17 @@ std::map TargetWrapperBM::bm_hds_; size_t TargetWrapperBM::num_devices() { int count = 0; - bm_dev_getcount(&count); + bm_status_t ret = bm_dev_getcount(&count); + CHECK_EQ(ret, BM_SUCCESS) << "Failed with error code: " + << static_cast(ret); return count; } +int TargetWrapperBM::GetDevice() { return device_id_; } void TargetWrapperBM::SetDevice(int id) { - /* - if (id < 0 || (size_t)id >= num_devices()) { - LOG(FATAL) << "Failed with invalid device id " << id; - } - */ + if (id < 0 || (size_t)id >= num_devices()) { + LOG(FATAL) << "Failed with invalid device id " << id; + } device_id_ = id; if (bm_hds_.find(id) == bm_hds_.end()) { bm_handle_t bm_handle; diff --git a/lite/backends/bm/target_wrapper.h b/lite/backends/bm/target_wrapper.h index 2674ffe161..db65b598b5 100644 --- a/lite/backends/bm/target_wrapper.h +++ b/lite/backends/bm/target_wrapper.h @@ -31,6 +31,7 @@ class TargetWrapper { static size_t maximum_stream() { return 0; } static void SetDevice(int id); + static int GetDevice(); static void CreateStream(stream_t* stream) {} static void DestroyStream(const stream_t& stream) {} diff --git a/lite/core/context.h b/lite/core/context.h index d0c1bd93cc..51605678d5 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -110,9 +110,8 @@ class Context { Context() {} explicit Context(const BMContext& ctx); // NOTE: InitOnce should only be used by ContextScheduler - void InitOnce() { Init(0); } + void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); } - void Init(int dev_id) { TargetWrapperBM::SetDevice(dev_id); } void CopySharedTo(BMContext* ctx) {} void* GetHandle() { return TargetWrapperBM::GetHandle(); } @@ -151,23 +150,14 @@ class Context { if (_tls_raw_ctx == nullptr) { _tls_raw_ctx = xdnn::create_context(); CHECK(_tls_raw_ctx); - int r = xdnn::set_workspace_l3_size(_tls_raw_ctx, - _workspace_l3_size_per_thread); - if (r != 0) { - LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r - << ", _workspace_l3_size_per_thread = " - << _workspace_l3_size_per_thread; - } } return _tls_raw_ctx; } static void SetWorkspaceL3Size(int l3_size = 0xfffc00) { - _workspace_l3_size_per_thread = l3_size; + xdnn::set_workspace_l3_size(GetRawContext(), l3_size); } - // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker - // thread static void SetDev(int dev_no = 0) { const char* dev_env = getenv("LITE_XPU_DEV"); if (dev_env) { @@ -182,7 +172,6 @@ class Context { private: static thread_local xdnn::Context* _tls_raw_ctx; - static int _workspace_l3_size_per_thread; }; #endif @@ -350,17 +339,27 @@ class Context { template <> class Context { std::shared_ptr cl_context_; + using WaitListType = + std::unordered_map(nullptr)), + std::shared_ptr>; + std::shared_ptr cl_wait_list_; public: CLContext* cl_context() { return cl_context_.get(); } + WaitListType* cl_wait_list() { return cl_wait_list_.get(); } void InitOnce() { // Init cl runtime. CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed"; + cl_context_ = std::make_shared(); + cl_wait_list_ = std::make_shared(); } - void CopySharedTo(OpenCLContext* ctx) { ctx->cl_context_ = cl_context_; } + void CopySharedTo(OpenCLContext* ctx) { + ctx->cl_context_ = cl_context_; + ctx->cl_wait_list_ = cl_wait_list_; + } }; #endif -- GitLab