提交 84498198 编写于 作者: B baolei.an

[LITE][BM] multi device ok,test=develop

上级 10ce7d3c
...@@ -24,16 +24,17 @@ std::map<int, void*> TargetWrapperBM::bm_hds_; ...@@ -24,16 +24,17 @@ std::map<int, void*> TargetWrapperBM::bm_hds_;
size_t TargetWrapperBM::num_devices() { size_t TargetWrapperBM::num_devices() {
int count = 0; int count = 0;
bm_dev_getcount(&count); bm_status_t ret = bm_dev_getcount(&count);
CHECK_EQ(ret, BM_SUCCESS) << "Failed with error code: "
<< static_cast<int>(ret);
return count; return count;
} }
int TargetWrapperBM::GetDevice() { return device_id_; }
void TargetWrapperBM::SetDevice(int id) { void TargetWrapperBM::SetDevice(int id) {
/* if (id < 0 || (size_t)id >= num_devices()) {
if (id < 0 || (size_t)id >= num_devices()) { LOG(FATAL) << "Failed with invalid device id " << id;
LOG(FATAL) << "Failed with invalid device id " << id; }
}
*/
device_id_ = id; device_id_ = id;
if (bm_hds_.find(id) == bm_hds_.end()) { if (bm_hds_.find(id) == bm_hds_.end()) {
bm_handle_t bm_handle; bm_handle_t bm_handle;
......
...@@ -31,6 +31,7 @@ class TargetWrapper<TARGET(kBM)> { ...@@ -31,6 +31,7 @@ class TargetWrapper<TARGET(kBM)> {
static size_t maximum_stream() { return 0; } static size_t maximum_stream() { return 0; }
static void SetDevice(int id); static void SetDevice(int id);
static int GetDevice();
static void CreateStream(stream_t* stream) {} static void CreateStream(stream_t* stream) {}
static void DestroyStream(const stream_t& stream) {} static void DestroyStream(const stream_t& stream) {}
......
...@@ -110,9 +110,8 @@ class Context<TargetType::kBM> { ...@@ -110,9 +110,8 @@ class Context<TargetType::kBM> {
Context() {} Context() {}
explicit Context(const BMContext& ctx); explicit Context(const BMContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() { Init(0); } void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); }
void Init(int dev_id) { TargetWrapperBM::SetDevice(dev_id); }
void CopySharedTo(BMContext* ctx) {} void CopySharedTo(BMContext* ctx) {}
void* GetHandle() { return TargetWrapperBM::GetHandle(); } void* GetHandle() { return TargetWrapperBM::GetHandle(); }
...@@ -151,23 +150,14 @@ class Context<TargetType::kXPU> { ...@@ -151,23 +150,14 @@ class Context<TargetType::kXPU> {
if (_tls_raw_ctx == nullptr) { if (_tls_raw_ctx == nullptr) {
_tls_raw_ctx = xdnn::create_context(); _tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx); CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
} }
return _tls_raw_ctx; return _tls_raw_ctx;
} }
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) { static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
_workspace_l3_size_per_thread = l3_size; xdnn::set_workspace_l3_size(GetRawContext(), l3_size);
} }
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) { static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV"); const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) { if (dev_env) {
...@@ -182,7 +172,6 @@ class Context<TargetType::kXPU> { ...@@ -182,7 +172,6 @@ class Context<TargetType::kXPU> {
private: private:
static thread_local xdnn::Context* _tls_raw_ctx; static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
}; };
#endif #endif
...@@ -350,17 +339,27 @@ class Context<TargetType::kX86> { ...@@ -350,17 +339,27 @@ class Context<TargetType::kX86> {
template <> template <>
class Context<TargetType::kOpenCL> { class Context<TargetType::kOpenCL> {
std::shared_ptr<CLContext> cl_context_; std::shared_ptr<CLContext> cl_context_;
using WaitListType =
std::unordered_map<decltype(static_cast<const void*>(nullptr)),
std::shared_ptr<cl::Event>>;
std::shared_ptr<WaitListType> cl_wait_list_;
public: public:
CLContext* cl_context() { return cl_context_.get(); } CLContext* cl_context() { return cl_context_.get(); }
WaitListType* cl_wait_list() { return cl_wait_list_.get(); }
void InitOnce() { void InitOnce() {
// Init cl runtime. // Init cl runtime.
CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed"; CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed";
cl_context_ = std::make_shared<CLContext>(); cl_context_ = std::make_shared<CLContext>();
cl_wait_list_ = std::make_shared<WaitListType>();
} }
void CopySharedTo(OpenCLContext* ctx) { ctx->cl_context_ = cl_context_; } void CopySharedTo(OpenCLContext* ctx) {
ctx->cl_context_ = cl_context_;
ctx->cl_wait_list_ = cl_wait_list_;
}
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册