Commit 77ea99f5 authored by liuqi

Add dynamic OpenCL kernel build logic.

Parent 156c128d
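In short, this change makes the OpenCL runtime load kernel source files from the directory named by the MACE_KERNEL_PATH environment variable, compile them on demand with per-call build options, and cache the built cl::Program objects. A minimal sketch of the intended call pattern, assuming a deployed batch_norm.cl; the path value and the surrounding function are illustrative, not part of this commit:

    #include <cstdlib>
    #include <set>
    #include <string>
    #include "mace/core/runtime/opencl/opencl_runtime.h"

    void ExampleUsage() {
      // MACE_KERNEL_PATH is read in the OpenCLRuntime constructor, so it
      // must be set before the singleton is first obtained (POSIX setenv).
      setenv("MACE_KERNEL_PATH", "/data/local/tmp/mace_kernels", 1);

      auto runtime = mace::OpenCLRuntime::Get();
      std::set<std::string> build_options;
      build_options.emplace("-DDataType=float");
      // The first call compiles batch_norm.cl; repeated calls with the
      // same options reuse the cached cl::Program.
      cl::Kernel kernel = runtime->BuildKernel("batch_norm", build_options);
    }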
@@ -83,6 +83,7 @@ bool BuildProgram(OpenCLRuntime *runtime,
 }  // namespace
 
 OpenCLRuntime *OpenCLRuntime::Get() {
   static std::once_flag init_once;
   static OpenCLRuntime *instance = nullptr;
@@ -140,7 +141,10 @@ OpenCLRuntime *OpenCLRuntime::Get() {
 OpenCLRuntime::OpenCLRuntime(cl::Context context,
                              cl::Device device,
                              cl::CommandQueue command_queue)
-    : context_(context), device_(device), command_queue_(command_queue) {}
+    : context_(context), device_(device), command_queue_(command_queue) {
+  const char *kernel_path = getenv("MACE_KERNEL_PATH");
+  kernel_path_ = std::string(kernel_path == nullptr ? "" : kernel_path) + "/";
+}
 
 OpenCLRuntime::~OpenCLRuntime() {}
@@ -162,6 +166,65 @@ cl::Program &OpenCLRuntime::program() {
   return program_;
 }
 
+const std::unordered_map<std::string, std::string>
+    OpenCLRuntime::kernel_program_map_ = {
+  {"batch_norm", "batch_norm.cl"}
+};
+
+bool OpenCLRuntime::BuildProgram(const std::string &program_file_name,
+                                 const std::string &build_options,
+                                 cl::Program *program) {
+  MACE_CHECK_NOTNULL(program);
+
+  cl::Program::Sources sources;
+  std::string filename = kernel_path_ + program_file_name;
+  std::string kernel_source;
+  MACE_CHECK(ReadSourceFile(filename, &kernel_source));
+  sources.push_back({kernel_source.c_str(), kernel_source.length()});
+
+  *program = cl::Program(this->context(), sources);
+  // The parameter is const, so compose the final option string locally;
+  // -I lets kernel sources include headers from the kernel directory.
+  std::string build_options_str = build_options +
+      " -Werror -cl-mad-enable -cl-fast-relaxed-math -I" + kernel_path_;
+  // TODO(heliangliang) -cl-unsafe-math-optimizations -cl-fast-relaxed-math
+  cl_int ret = program->build({device_}, build_options_str.c_str());
+  if (ret != CL_SUCCESS) {
+    if (program->getBuildInfo<CL_PROGRAM_BUILD_STATUS>(device_) ==
+        CL_BUILD_ERROR) {
+      std::string build_log =
+          program->getBuildInfo<CL_PROGRAM_BUILD_LOG>(device_);
+      LOG(INFO) << "Program build log: " << build_log;
+    }
+    LOG(FATAL) << "Build program failed: " << ret;
+  }
+
+  return true;
+}
+
+cl::Kernel OpenCLRuntime::BuildKernel(
+    const std::string &kernel_name,
+    const std::set<std::string> &build_options) {
+  auto kernel_program_it = kernel_program_map_.find(kernel_name);
+  if (kernel_program_it == kernel_program_map_.end()) {
+    MACE_CHECK(false, kernel_name, " opencl kernel doesn't exist.");
+  }
+  std::string program_name = kernel_program_it->second;
+  std::string build_options_str;
+  for (auto &option : build_options) {
+    build_options_str += " " + option;
+  }
+  std::string built_program_key = program_name + build_options_str;
+
+  auto built_program_it = built_program_map_.find(built_program_key);
+  cl::Program program;
+  if (built_program_it != built_program_map_.end()) {
+    program = built_program_it->second;
+  } else {
+    // Build from the program file, then cache a copy; do not std::move the
+    // local, since it is still needed to construct the kernel below.
+    this->BuildProgram(program_name, build_options_str, &program);
+    built_program_map_.emplace(built_program_key, program);
+  }
+  return cl::Kernel(program, kernel_name.c_str());
+}
+
 uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
   unsigned long long size = 0;
   device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size);
......
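Note how the two new methods compose: the cache key is the program file name plus the option string accumulated from a std::set, so the set's sorted iteration makes the key independent of the order in which options were inserted at the call site. A small sketch of the key scheme (the -DFOO option is illustrative):

    // std::set iterates in sorted order, so any insertion order of the
    // same options produces the same key and hits the same cached program.
    std::set<std::string> opts = {"-DDataType=half", "-DFOO"};
    std::string key = "batch_norm.cl";
    for (const auto &option : opts) {
      key += " " + option;
    }
    // key == "batch_norm.cl -DDataType=half -DFOO"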
@@ -7,6 +7,7 @@
 #include <map>
 #include <mutex>
+#include <set>
+#include <unordered_map>
 
 #include "mace/core/runtime/opencl/cl2_header.h"
 #include "mace/core/runtime/opencl/opencl_wrapper.h"
@@ -17,12 +18,15 @@ class OpenCLRuntime {
 public:
  static OpenCLRuntime *Get();
 
-  uint32_t GetDeviceMaxWorkGroupSize();
-  uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
   cl::Context &context();
   cl::Device &device();
   cl::CommandQueue &command_queue();
   cl::Program &program();
+  uint32_t GetDeviceMaxWorkGroupSize();
+  uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
+  cl::Kernel BuildKernel(const std::string &kernel_name,
+                         const std::set<std::string> &build_options);
 
 private:
  OpenCLRuntime(cl::Context context,
                cl::Device device,
@@ -31,12 +35,21 @@ class OpenCLRuntime {
  OpenCLRuntime(const OpenCLRuntime&) = delete;
  OpenCLRuntime &operator=(const OpenCLRuntime&) = delete;
 
+  bool BuildProgram(const std::string &program_file_name,
+                    const std::string &build_options,
+                    cl::Program *program);
+
 private:
   cl::Context context_;
   cl::Device device_;
   cl::CommandQueue command_queue_;
   cl::Program program_;
   std::once_flag build_flag_;
+  std::string kernel_path_;
+
+  static const std::unordered_map<std::string,
+                                  std::string> kernel_program_map_;
+  mutable std::unordered_map<std::string,
+                             cl::Program> built_program_map_;
 };
 
 }  // namespace mace
......
@@ -76,8 +76,9 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
     const Tensor *epsilon,
     Tensor *output);
 
-template <>
-void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
+template <typename T>
+struct BatchNormFunctor<DeviceType::OPENCL, T> {
+  void operator()(
     const Tensor *input,
     const Tensor *scale,
     const Tensor *offset,
@@ -85,6 +86,7 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
     const Tensor *var,
     const Tensor *epsilon,
     Tensor *output);
+};
 
 }  // namespace kernels
 }  // namespace mace
......
@@ -10,8 +10,8 @@
 namespace mace {
 namespace kernels {
 
-template <>
-void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
+template <typename T>
+void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
     const Tensor *input,
     const Tensor *scale,
     const Tensor *offset,
@@ -27,10 +27,10 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
       static_cast<uint32_t>(input->dim(1)),
       static_cast<uint32_t>(blocks)};
 
   auto runtime = OpenCLRuntime::Get();
-  auto program = runtime->program();
-  auto bm_kernel = cl::Kernel(program, "batch_norm");
+  std::set<std::string> built_options;
+  built_options.emplace("-DDataType=" + GetDataTypeFromEnum(input->dtype()));
+  auto bm_kernel = runtime->BuildKernel("batch_norm", built_options);
 
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
   const std::vector<uint32_t> lws = {1, 1, kwg_size};
......
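GetDataTypeFromEnum is not shown in this diff; from the call site it presumably maps the framework's DataType enum to the OpenCL C type name spliced into the -DDataType= define, which the kernel source can then use as a generic element type. A plausible sketch, with the body entirely an assumption (only the call appears in the commit):

    // Assumed helper: returns the OpenCL C type name for a DataType value.
    std::string GetDataTypeFromEnum(DataType dtype) {
      switch (dtype) {
        case DT_FLOAT:
          return "float";
        case DT_HALF:  // enum value added by this commit
          return "half";
        default:
          LOG(FATAL) << "Unsupported OpenCL data type: " << dtype;
          return "";
      }
    }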
@@ -23,6 +23,7 @@ enum DataType {
   DT_INT64 = 8;
   DT_UINT16 = 9;
   DT_BOOL = 10;
+  DT_HALF = 19;
 }
 
 message TensorProto {
......