提交 a0713b7b 编写于 作者: xiebaiyuan's avatar xiebaiyuan

add init check interface

上级 76ca8f8b
...@@ -27,9 +27,9 @@ bool CLEngine::Init() { ...@@ -27,9 +27,9 @@ bool CLEngine::Init() {
return true; return true;
} }
cl_int status; cl_int status;
SetPlatform(); bool is_setplatform_success = SetPlatform();
SetClDeviceId(); bool is_setcldeviceid_success = SetClDeviceId();
is_init_success_ = is_setplatform_success && is_setcldeviceid_success;
initialized_ = true; initialized_ = true;
return initialized_; return initialized_;
// setClCommandQueue(); // setClCommandQueue();
...@@ -44,11 +44,14 @@ CLEngine *CLEngine::Instance() { ...@@ -44,11 +44,14 @@ CLEngine *CLEngine::Instance() {
return &cl_engine_; return &cl_engine_;
} }
bool CLEngine::isInitSuccess() { return is_init_success_; }
bool CLEngine::SetPlatform() { bool CLEngine::SetPlatform() {
platform_ = NULL; // the chosen platform platform_ = NULL; // the chosen platform
cl_uint numPlatforms; // the NO. of platforms cl_uint numPlatforms; // the NO. of platforms
cl_int status = clGetPlatformIDs(0, NULL, &numPlatforms); cl_int status = clGetPlatformIDs(0, NULL, &numPlatforms);
if (status != CL_SUCCESS) {
return false;
}
/**For clarity, choose the first available platform. */ /**For clarity, choose the first available platform. */
if (numPlatforms > 0) { if (numPlatforms > 0) {
cl_platform_id *platforms = reinterpret_cast<cl_platform_id *>( cl_platform_id *platforms = reinterpret_cast<cl_platform_id *>(
...@@ -56,10 +59,10 @@ bool CLEngine::SetPlatform() { ...@@ -56,10 +59,10 @@ bool CLEngine::SetPlatform() {
status = clGetPlatformIDs(numPlatforms, platforms, NULL); status = clGetPlatformIDs(numPlatforms, platforms, NULL);
platform_ = platforms[0]; platform_ = platforms[0];
free(platforms); free(platforms);
return true; return status == CL_SUCCESS;
} else {
return false;
} }
return false;
} }
bool CLEngine::SetClDeviceId() { bool CLEngine::SetClDeviceId() {
...@@ -67,13 +70,15 @@ bool CLEngine::SetClDeviceId() { ...@@ -67,13 +70,15 @@ bool CLEngine::SetClDeviceId() {
devices_ = NULL; devices_ = NULL;
cl_int status = cl_int status =
clGetDeviceIDs(platform_, CL_DEVICE_TYPE_GPU, 0, NULL, &numDevices); clGetDeviceIDs(platform_, CL_DEVICE_TYPE_GPU, 0, NULL, &numDevices);
if (status != CL_SUCCESS) {
return false;
}
if (numDevices > 0) { if (numDevices > 0) {
devices_ = reinterpret_cast<cl_device_id *>( devices_ = reinterpret_cast<cl_device_id *>(
malloc(numDevices * sizeof(cl_device_id))); malloc(numDevices * sizeof(cl_device_id)));
status = clGetDeviceIDs(platform_, CL_DEVICE_TYPE_GPU, numDevices, devices_, status = clGetDeviceIDs(platform_, CL_DEVICE_TYPE_GPU, numDevices, devices_,
NULL); NULL);
return true; return status == CL_SUCCESS;
} }
return false; return false;
} }
......
...@@ -31,7 +31,7 @@ class CLEngine { ...@@ -31,7 +31,7 @@ class CLEngine {
static CLEngine *Instance(); static CLEngine *Instance();
bool Init(); bool Init();
bool isInitSuccess();
std::unique_ptr<_cl_context, CLContextDeleter> CreateContext() { std::unique_ptr<_cl_context, CLContextDeleter> CreateContext() {
cl_int status; cl_int status;
cl_context c = clCreateContext(NULL, 1, devices_, NULL, NULL, &status); cl_context c = clCreateContext(NULL, 1, devices_, NULL, NULL, &status);
...@@ -51,6 +51,20 @@ class CLEngine { ...@@ -51,6 +51,20 @@ class CLEngine {
return std::move(command_queue_ptr); return std::move(command_queue_ptr);
} }
cl_context getContext() {
if (context_ == nullptr) {
context_ = CreateContext();
}
return context_.get();
}
cl_command_queue getClCommandQueue() {
if (command_queue_ == nullptr) {
command_queue_ = CreateClCommandQueue(getContext());
}
return command_queue_.get();
}
std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith( std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith(
cl_context context, std::string file_name) { cl_context context, std::string file_name) {
FILE *file = fopen(file_name.c_str(), "rb"); FILE *file = fopen(file_name.c_str(), "rb");
...@@ -137,6 +151,11 @@ class CLEngine { ...@@ -137,6 +151,11 @@ class CLEngine {
std::string cl_path_; std::string cl_path_;
std::unique_ptr<_cl_program, CLProgramDeleter> program_; std::unique_ptr<_cl_program, CLProgramDeleter> program_;
std::unique_ptr<_cl_context, CLContextDeleter> context_ = nullptr;
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_ =
nullptr;
// bool SetClContext(); // bool SetClContext();
// bool SetClCommandQueue(); // bool SetClCommandQueue();
...@@ -144,6 +163,7 @@ class CLEngine { ...@@ -144,6 +163,7 @@ class CLEngine {
// bool LoadKernelFromFile(const char *kernel_file); // bool LoadKernelFromFile(const char *kernel_file);
// bool BuildProgram(); // bool BuildProgram();
bool is_init_success_ = false;
}; };
} // namespace framework } // namespace framework
......
...@@ -29,12 +29,12 @@ namespace framework { ...@@ -29,12 +29,12 @@ namespace framework {
class CLScope { class CLScope {
public: public:
CLScope() { CLScope() {
CLEngine *engin = CLEngine::Instance(); CLEngine *engine = CLEngine::Instance();
context_ = engin->CreateContext(); context_ = engine->getContext();
command_queue_ = engin->CreateClCommandQueue(context_.get()); command_queue_ = engine->getClCommandQueue();
} }
cl_command_queue CommandQueue() { return command_queue_.get(); } cl_command_queue CommandQueue() { return command_queue_; }
std::unique_ptr<_cl_kernel, CLKernelDeleter> GetKernel( std::unique_ptr<_cl_kernel, CLKernelDeleter> GetKernel(
const std::string &kernel_name, const std::string &file_name) { const std::string &kernel_name, const std::string &file_name) {
...@@ -49,7 +49,7 @@ class CLScope { ...@@ -49,7 +49,7 @@ class CLScope {
return std::move(kernel); return std::move(kernel);
} }
cl_context Context() { return context_.get(); } cl_context Context() { return context_; }
cl_program Program(const std::string &file_name) { cl_program Program(const std::string &file_name) {
auto it = programs_.find(file_name); auto it = programs_.find(file_name);
...@@ -58,7 +58,7 @@ class CLScope { ...@@ -58,7 +58,7 @@ class CLScope {
} }
auto program = CLEngine::Instance()->CreateProgramWith( auto program = CLEngine::Instance()->CreateProgramWith(
context_.get(), context_,
CLEngine::Instance()->GetCLPath() + "/cl_kernel/" + file_name); CLEngine::Instance()->GetCLPath() + "/cl_kernel/" + file_name);
DLOG << " --- begin build program -> " << file_name << " --- "; DLOG << " --- begin build program -> " << file_name << " --- ";
...@@ -72,8 +72,8 @@ class CLScope { ...@@ -72,8 +72,8 @@ class CLScope {
private: private:
cl_int status_; cl_int status_;
std::unique_ptr<_cl_context, CLContextDeleter> context_; cl_context context_;
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_; cl_command_queue command_queue_;
std::unordered_map<std::string, std::unordered_map<std::string,
std::unique_ptr<_cl_program, CLProgramDeleter>> std::unique_ptr<_cl_program, CLProgramDeleter>>
programs_; programs_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册