提交 84498198 编写于 作者: B baolei.an

[LITE][BM] multi device ok,test=develop

上级 10ce7d3c
......@@ -24,16 +24,17 @@ std::map<int, void*> TargetWrapperBM::bm_hds_;
// Returns the device count written by bm_dev_getcount().
// The status of the second call is compared with BM_SUCCESS through CHECK_EQ;
// on mismatch the numeric status code is appended to the failure message.
// NOTE(review): bm_dev_getcount() appears twice below, once unchecked and once
// checked. This text looks like a rendered diff, so the unchecked call is
// presumably the pre-change line -- confirm before treating it as live code.
size_t TargetWrapperBM::num_devices() {
int count = 0;
bm_dev_getcount(&count);
bm_status_t ret = bm_dev_getcount(&count);
CHECK_EQ(ret, BM_SUCCESS) << "Failed with error code: "
<< static_cast<int>(ret);
// count is an int and is converted to size_t by the return statement.
return count;
}
// Returns the current value of device_id_.
int TargetWrapperBM::GetDevice() {
  return device_id_;
}
void TargetWrapperBM::SetDevice(int id) {
/*
if (id < 0 || (size_t)id >= num_devices()) {
LOG(FATAL) << "Failed with invalid device id " << id;
}
*/
device_id_ = id;
if (bm_hds_.find(id) == bm_hds_.end()) {
bm_handle_t bm_handle;
......
......@@ -31,6 +31,7 @@ class TargetWrapper<TARGET(kBM)> {
// Always reports a stream count of zero.
static size_t maximum_stream() {
  return 0;
}
// Sets the active device id (declaration only; no body in this class view).
static void SetDevice(int id);
// Returns the active device id (declaration only; no body in this class view).
static int GetDevice();
// Stream-creation hook with no effect: `stream` is neither read nor written.
static void CreateStream(stream_t* stream) {
  static_cast<void>(stream);  // parameter is unused
}
// Stream-destruction hook with no effect: `stream` is not used.
static void DestroyStream(const stream_t& stream) {
  static_cast<void>(stream);  // parameter is unused
}
......
......@@ -110,9 +110,8 @@ class Context<TargetType::kBM> {
// Default constructor with an empty body.
Context() {}
// Constructs from an existing BMContext (declaration only; body not shown here).
explicit Context(const BMContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler
// NOTE(review): two InitOnce definitions appear back to back. This text looks
// like a rendered diff, so presumably the first one (Init(0)) is the
// pre-change line and the second is its replacement -- confirm against the
// real header.
// First form: always initializes with device 0.
void InitOnce() { Init(0); }
// Second form: passes the id reported by GetDevice() back to SetDevice()
// instead of hard-coding device 0.
void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); }
// Forwards `dev_id` to TargetWrapperBM::SetDevice.
void Init(int dev_id) {
  TargetWrapperBM::SetDevice(dev_id);
}
// Nothing is copied into `ctx`: the body has no effect.
void CopySharedTo(BMContext* ctx) {
  static_cast<void>(ctx);  // parameter is unused
}
// Returns whatever TargetWrapperBM::GetHandle() yields, as a void*.
void* GetHandle() {
  return TargetWrapperBM::GetHandle();
}
......@@ -151,23 +150,14 @@ class Context<TargetType::kXPU> {
if (_tls_raw_ctx == nullptr) {
_tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
}
return _tls_raw_ctx;
}
// Stores `l3_size` in _workspace_l3_size_per_thread and passes it, together
// with the context returned by GetRawContext(), to
// xdnn::set_workspace_l3_size(). The default argument is 0xfffc00.
// NOTE(review): the result of xdnn::set_workspace_l3_size() is discarded here.
// NOTE(review): this text looks like a rendered diff, so one of the two
// statements below may be a removed line -- confirm against the real header.
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
_workspace_l3_size_per_thread = l3_size;
xdnn::set_workspace_l3_size(GetRawContext(), l3_size);
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
......@@ -182,7 +172,6 @@ class Context<TargetType::kXPU> {
private:
static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
};
#endif
......@@ -350,17 +339,27 @@ class Context<TargetType::kX86> {
// Context specialization for TargetType::kOpenCL.
// Holds a CLContext and an event wait list, each behind a std::shared_ptr.
// NOTE(review): CopySharedTo appears twice below (a one-line form and a
// multi-line form). This text looks like a rendered diff, so presumably the
// one-line form is the pre-change line -- confirm against the real header.
template <>
class Context<TargetType::kOpenCL> {
std::shared_ptr<CLContext> cl_context_;
// Wait list type: maps a const void* key to a shared cl::Event.
using WaitListType =
std::unordered_map<decltype(static_cast<const void*>(nullptr)),
std::shared_ptr<cl::Event>>;
std::shared_ptr<WaitListType> cl_wait_list_;
public:
// Raw, non-owning accessors; each returns nullptr while the corresponding
// shared_ptr has not been assigned.
CLContext* cl_context() { return cl_context_.get(); }
WaitListType* cl_wait_list() { return cl_wait_list_.get(); }
// Checks that the global CLRuntime reports a successful init, then creates a
// new CLContext and a new, empty wait list.
void InitOnce() {
// Init cl runtime.
CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed";
cl_context_ = std::make_shared<CLContext>();
cl_wait_list_ = std::make_shared<WaitListType>();
}
// One-line form: shares only the CLContext with `ctx`.
void CopySharedTo(OpenCLContext* ctx) { ctx->cl_context_ = cl_context_; }
// Multi-line form: shares both the CLContext and the wait list with `ctx` by
// copying the shared_ptrs, so both contexts refer to the same objects.
void CopySharedTo(OpenCLContext* ctx) {
ctx->cl_context_ = cl_context_;
ctx->cl_wait_list_ = cl_wait_list_;
}
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册