diff --git a/lite/backends/opencl/cl_context.cc b/lite/backends/opencl/cl_context.cc index 153c0620035377afac065e8049a9ebbc0a6f0c15..75a34795418a60108a41690c7293d2ecd46c0545 100644 --- a/lite/backends/opencl/cl_context.cc +++ b/lite/backends/opencl/cl_context.cc @@ -45,7 +45,6 @@ cl::Program &CLContext::GetProgram(const std::string &file_name, } auto program = CLRuntime::Global()->CreateProgram(GetContext(), file_name); - VLOG(3) << " --- begin build program -> " << program_key << " --- "; CLRuntime::Global()->BuildProgram(program.get(), options); VLOG(3) << " --- end build program -> " << program_key << " --- "; @@ -77,6 +76,23 @@ void CLContext::AddKernel(const std::string &kernel_name, kernel_offset_map[kernel_key.str()] = kernels.size() - 1; } +std::shared_ptr CLContext::CreateKernel( + const std::string &kernel_name, + const std::string &file_name, + const std::string &options, + const std::string &time_stamp) { + cl_int status{CL_SUCCESS}; + VLOG(3) << " --- to get program " << file_name << " --- "; + auto program = GetProgram(file_name, options); + VLOG(3) << " --- end get program --- "; + VLOG(3) << " --- to create kernel: " << kernel_name << " --- "; + std::shared_ptr kernel( + new cl::Kernel(program, kernel_name.c_str(), &status)); + CL_CHECK_FATAL(status); + VLOG(3) << " --- end create kernel --- "; + return kernel; +} + cl::Kernel &CLContext::GetKernel(const int index) { auto &kernels = CLRuntime::Global()->kernels(); VLOG(3) << " --- kernel count: " << kernels.size() << " --- "; diff --git a/lite/backends/opencl/cl_context.h b/lite/backends/opencl/cl_context.h index b12473ccf5b4238f4ee95b7848a0842ee5b2ffe0..4d609ab3a4fa3e12e87e3c3f8c1345dc6c1b5326 100644 --- a/lite/backends/opencl/cl_context.h +++ b/lite/backends/opencl/cl_context.h @@ -39,6 +39,11 @@ class CLContext { const std::string &options = "", const std::string &time_stamp = ""); + std::shared_ptr CreateKernel(const std::string &kernel_name, + const std::string &file_name, + const std::string &options = "", + const std::string &time_stamp = ""); + cl::Kernel &GetKernel(const int index); cl::Kernel &GetKernel(const std::string &name);