Commit 77ea99f5 authored by liuqi

Add dynamic OpenCL kernel build logic.

Parent 156c128d
......@@ -83,6 +83,7 @@ bool BuildProgram(OpenCLRuntime *runtime,
} // namespace
OpenCLRuntime *OpenCLRuntime::Get() {
static std::once_flag init_once;
static OpenCLRuntime *instance = nullptr;
......@@ -140,7 +141,10 @@ OpenCLRuntime *OpenCLRuntime::Get() {
OpenCLRuntime::OpenCLRuntime(cl::Context context,
cl::Device device,
cl::CommandQueue command_queue)
: context_(context), device_(device), command_queue_(command_queue) {}
: context_(context), device_(device), command_queue_(command_queue) {
const char *kernel_path = getenv("MACE_KERNEL_PATH");
kernel_path_ = std::string(kernel_path == nullptr ? "" : kernel_path) + "/";
}
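Note that kernel_path_ is captured once in this constructor, so MACE_KERNEL_PATH has to be in the environment before the first OpenCLRuntime::Get() call. A minimal sketch of the intended setup, assuming a POSIX environment and a made-up on-device path:
#include <cstdlib>  // setenv (POSIX)
// "/data/local/tmp/mace/cl" is an assumed example location for the .cl files.
setenv("MACE_KERNEL_PATH", "/data/local/tmp/mace/cl", 1 /* overwrite */);
OpenCLRuntime *runtime = OpenCLRuntime::Get();  // kernel_path_ now points there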
OpenCLRuntime::~OpenCLRuntime() {}
......@@ -162,6 +166,65 @@ cl::Program &OpenCLRuntime::program() {
return program_;
}
const std::unordered_map<std::string, std::string>
    OpenCLRuntime::kernel_program_map_ = {
  {"batch_norm", "batch_norm.cl"}
};
bool OpenCLRuntime::BuildProgram(const std::string &program_file_name,
                                 const std::string &build_options,
                                 cl::Program *program) {
  MACE_CHECK_NOTNULL(program);
  cl::Program::Sources sources;
  std::string filename = kernel_path_ + program_file_name;
  std::string kernel_source;
  MACE_CHECK(ReadSourceFile(filename, &kernel_source));
  sources.push_back({kernel_source.c_str(), kernel_source.length()});
  *program = cl::Program(this->context(), sources);
  // build_options is const; compose the final flag string in a local copy.
  std::string build_options_str = build_options +
      " -Werror -cl-mad-enable -cl-fast-relaxed-math -I" + kernel_path_;
  // TODO(heliangliang) -cl-unsafe-math-optimizations -cl-fast-relaxed-math
  cl_int ret = program->build({device()}, build_options_str.c_str());
  if (ret != CL_SUCCESS) {
    if (program->getBuildInfo<CL_PROGRAM_BUILD_STATUS>(device()) ==
        CL_BUILD_ERROR) {
      std::string build_log =
          program->getBuildInfo<CL_PROGRAM_BUILD_LOG>(device());
      LOG(INFO) << "Program build log: " << build_log;
    }
    LOG(FATAL) << "Build program failed: " << ret;
  }
  return true;
}
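For example, if a caller collects built_options = {"-DDataType=float"}, the flags handed to the OpenCL compiler here would be " -DDataType=float -Werror -cl-mad-enable -cl-fast-relaxed-math -I" followed by kernel_path_.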
cl::Kernel OpenCLRuntime::BuildKernel(const std::string &kernel_name,
                                      const std::set<std::string> &build_options) {
  auto kernel_program_it = kernel_program_map_.find(kernel_name);
  if (kernel_program_it == kernel_program_map_.end()) {
    MACE_CHECK(false, kernel_name, " OpenCL kernel doesn't exist.");
  }
  std::string program_name = kernel_program_it->second;
  std::string build_options_str;
  for (const auto &option : build_options) {
    build_options_str += " " + option;
  }
  std::string built_program_key = program_name + build_options_str;
  auto built_program_it = built_program_map_.find(built_program_key);
  cl::Program program;
  if (built_program_it != built_program_map_.end()) {
    program = built_program_it->second;
  } else {
    this->BuildProgram(program_name, build_options_str, &program);
    // cl::Program is a cheap reference-counted handle, so cache a copy
    // rather than moving from `program`, which is still used below.
    built_program_map_.emplace(built_program_key, program);
  }
  return cl::Kernel(program, kernel_name.c_str());
}
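Because built programs are cached by program name plus option string, building the same kernel twice with identical options compiles batch_norm.cl only once. A minimal usage sketch (the option value is an assumed example):
auto runtime = OpenCLRuntime::Get();
std::set<std::string> options = {"-DDataType=float"};  // assumed example option
cl::Kernel k1 = runtime->BuildKernel("batch_norm", options);  // reads and compiles batch_norm.cl
cl::Kernel k2 = runtime->BuildKernel("batch_norm", options);  // cache hit in built_program_map_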
uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
  // CL_DEVICE_MAX_WORK_GROUP_SIZE is defined as size_t; querying it into an
  // unsigned long long is only safe where the two types happen to match.
  size_t size = 0;
  device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size);
......
......@@ -7,6 +7,7 @@
#include <map>
#include <mutex>
#include <set>
#include <unordered_map>
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_wrapper.h"
......@@ -17,12 +18,15 @@ class OpenCLRuntime {
public:
static OpenCLRuntime *Get();
uint32_t GetDeviceMaxWorkGroupSize();
uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
cl::Context &context();
cl::Device &device();
cl::CommandQueue &command_queue();
cl::Program &program();
uint32_t GetDeviceMaxWorkGroupSize();
uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
cl::Kernel BuildKernel(const std::string &kernel_name,
const std::set<std::string> &build_options);
private:
OpenCLRuntime(cl::Context context,
cl::Device device,
......@@ -31,12 +35,21 @@ class OpenCLRuntime {
OpenCLRuntime(const OpenCLRuntime&) = delete;
OpenCLRuntime &operator=(const OpenCLRuntime&) = delete;
bool BuildProgram(const std::string &program_file_name,
                  const std::string &build_options,
                  cl::Program *program);
private:
cl::Context context_;
cl::Device device_;
cl::CommandQueue command_queue_;
cl::Program program_;
std::once_flag build_flag_;
std::string kernel_path_;
static const std::unordered_map<std::string, std::string> kernel_program_map_;
mutable std::unordered_map<std::string, cl::Program> built_program_map_;
};
} // namespace mace
......
......@@ -76,15 +76,17 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *epsilon,
Tensor *output);
template <>
void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const Tensor *epsilon,
Tensor *output);
template <typename T>
struct BatchNormFunctor<DeviceType::OPENCL, T> {
void operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const Tensor *epsilon,
Tensor *output);
};
} // namespace kernels
} // namespace mace
......
......@@ -10,8 +10,8 @@
namespace mace {
namespace kernels {
template <>
void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
template <typename T>
void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
......@@ -27,10 +27,10 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
static_cast<uint32_t>(input->dim(1)),
static_cast<uint32_t>(blocks)};
auto runtime = OpenCLRuntime::Get();
auto program = runtime->program();
auto bm_kernel = cl::Kernel(program, "batch_norm");
std::set<std::string> built_options;
built_options.emplace("-DDataType=" + GetDataTypeFromEnum(input->dtype()));
auto bm_kernel = runtime->BuildKernel("batch_norm", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
const std::vector<uint32_t> lws = {1, 1, kwg_size};
......
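The -DDataType build option only takes effect if the kernel source uses the macro; batch_norm.cl is not part of this diff, but a hypothetical fragment shows the pattern:
// Hypothetical sketch only; the real batch_norm.cl (including its folding of
// mean, var and epsilon into a fused scale/offset) is not shown in this diff.
__kernel void batch_norm(__global const DataType *input,
                         __global const DataType *scale,
                         __global const DataType *offset,
                         __global DataType *output) {
  const int i = get_global_id(0);
  output[i] = input[i] * scale[i] + offset[i];
}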
......@@ -23,6 +23,7 @@ enum DataType {
DT_INT64 = 8;
DT_UINT16 = 9;
DT_BOOL = 10;
DT_HALF = 19;
}
message TensorProto {
......