Commit addf50a5 authored by 刘琦

Merge branch 'batch_norm_opencl' into 'master'

Optimize the batch norm OpenCL kernel: query the kernel's maximum work-group size at runtime, derive the local work size and local-memory allocation from it instead of hard-coding them, and replace magic input/output indices in BatchNormOp with named tags.

See merge request !78
@@ -160,4 +160,16 @@ cl::Program &OpenCLRuntime::program() {
   return program_;
 }
+
+uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
+  unsigned long long size = 0;
+  device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size);
+  return static_cast<uint32_t>(size);
+}
+
+uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) {
+  unsigned long long size = 0;
+  kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size);
+  return static_cast<uint32_t>(size);
+}
 }  // namespace mace
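Note on types: the OpenCL spec defines the query results for CL_DEVICE_MAX_WORK_GROUP_SIZE and CL_KERNEL_WORK_GROUP_SIZE as size_t; the helpers above read them through the C++ wrapper and narrow to uint32_t. For reference, a minimal sketch of the same kernel query through the raw C API (standard OpenCL, not MACE code; kernel and device are assumed to be valid raw handles):

  // Per-kernel work-group limit via the C API; the result type for
  // CL_KERNEL_WORK_GROUP_SIZE is size_t.
  size_t kwg_size = 0;
  cl_int err = clGetKernelWorkGroupInfo(kernel, device,
                                        CL_KERNEL_WORK_GROUP_SIZE,
                                        sizeof(kwg_size), &kwg_size, NULL);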
@@ -21,6 +21,8 @@ class OpenCLRuntime {
  public:
   static OpenCLRuntime *Get();
 
+  uint32_t GetDeviceMaxWorkGroupSize();
+  uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
   cl::Context &context();
   cl::Device &device();
   cl::CommandQueue &command_queue();
@@ -136,6 +136,8 @@ class OpenCLLibraryImpl final {
   using clRetainDeviceFunc = cl_int (*)(cl_device_id);
   using clReleaseDeviceFunc = cl_int (*)(cl_device_id);
   using clRetainEventFunc = cl_int (*)(cl_event);
+  using clGetKernelWorkGroupInfoFunc =
+      cl_int (*)(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *);
 
 #define DEFINE_FUNC_PTR(func) func##Func func = nullptr
@@ -177,6 +179,7 @@ class OpenCLLibraryImpl final {
   DEFINE_FUNC_PTR(clRetainDevice);
   DEFINE_FUNC_PTR(clReleaseDevice);
   DEFINE_FUNC_PTR(clRetainEvent);
+  DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo);
 
 #undef DEFINE_FUNC_PTR
@@ -296,6 +299,7 @@ void *OpenCLLibraryImpl::LoadFromPath(const std::string &path) {
   ASSIGN_FROM_DLSYM(clRetainDevice);
   ASSIGN_FROM_DLSYM(clReleaseDevice);
   ASSIGN_FROM_DLSYM(clRetainEvent);
+  ASSIGN_FROM_DLSYM(clGetKernelWorkGroupInfo);
 
 #undef ASSIGN_FROM_DLSYM
@@ -782,3 +786,18 @@ cl_int clRetainEvent(cl_event event) {
     return CL_OUT_OF_RESOURCES;
   }
 }
+
+cl_int clGetKernelWorkGroupInfo(cl_kernel kernel,
+                                cl_device_id device,
+                                cl_kernel_work_group_info param_name,
+                                size_t param_value_size,
+                                void *param_value,
+                                size_t *param_value_size_ret) {
+  auto func = mace::OpenCLLibraryImpl::Get().clGetKernelWorkGroupInfo;
+  if (func != nullptr) {
+    return func(kernel, device, param_name, param_value_size,
+                param_value, param_value_size_ret);
+  } else {
+    return CL_OUT_OF_RESOURCES;
+  }
+}
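DEFINE_FUNC_PTR is given in the hunk above; ASSIGN_FROM_DLSYM is outside this diff, but in this lazy-loading pattern it presumably resolves each symbol from the dlopen'd libOpenCL handle. A sketch under that assumption (hypothetical macro body, not the actual MACE definition):

  // DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo) expands, per the #define above, to:
  //   clGetKernelWorkGroupInfoFunc clGetKernelWorkGroupInfo = nullptr;
  // ASSIGN_FROM_DLSYM plausibly looks like:
  #define ASSIGN_FROM_DLSYM(func) \
    func = reinterpret_cast<func##Func>(dlsym(handle, #func))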
@@ -18,27 +18,35 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
     const Tensor *var,
     const Tensor *epsilon,
     Tensor *output) {
-  const index_t n = input->dim(0);
-  const index_t channel = input->dim(1);
-  const index_t sample_size = input->dim(2) * input->dim(3);
+  const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)),
+                           static_cast<uint32_t>(input->dim(1)),
+                           static_cast<uint32_t>(input->dim(2) * input->dim(3))};
 
   auto runtime = OpenCLRuntime::Get();
   auto program = runtime->program();
-  auto _kernel = cl::Kernel(program, "batch_norm");
-  _kernel.setArg(0, *(static_cast<const cl::Buffer *>(input->buffer())));
-  _kernel.setArg(1, *(static_cast<cl::Buffer *>(scale->buffer())));
-  _kernel.setArg(2, *(static_cast<cl::Buffer *>(offset->buffer())));
-  _kernel.setArg(3, *(static_cast<cl::Buffer *>(mean->buffer())));
-  _kernel.setArg(4, *(static_cast<cl::Buffer *>(var->buffer())));
-  _kernel.setArg(5, *(static_cast<cl::Buffer *>(epsilon->buffer())));
-  _kernel.setArg(6, static_cast<int>(sample_size));
-  _kernel.setArg(7, *(static_cast<cl::Buffer *>(output->buffer())));
-  _kernel.setArg(8, 32u, nullptr);
-  _kernel.setArg(9, 32u, nullptr);
+  auto bm_kernel = cl::Kernel(program, "batch_norm");
+
+  const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
+  const uint32_t lws[3] = {1, kwg_size/128, 128};
+
+  uint32_t idx = 0;
+  bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
+  bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(scale->buffer())));
+  bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(offset->buffer())));
+  bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(mean->buffer())));
+  bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(var->buffer())));
+  bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(epsilon->buffer())));
+  bm_kernel.setArg(idx++, gws[2]);
+  bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
+  bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
+  bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
 
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
-      _kernel, cl::NullRange,
-      cl::NDRange(n, channel, sample_size),
-      cl::NDRange(1, 1, 128));
+      bm_kernel, cl::NullRange,
+      cl::NDRange(gws[0], gws[1], gws[2]),
+      cl::NDRange(lws[0], lws[1], lws[2]));
   MACE_CHECK(error == CL_SUCCESS);
 }
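Worked example of the new launch geometry: for an NCHW input of shape 1x32x128x128, gws = {1, 32, 16384}; if the kernel reports kwg_size = 512, then lws = {1, 4, 128} (512 work-items per group), and each group allocates lws[1] floats of local memory for new_scale and for new_offset. A standalone sketch with those hypothetical numbers, mirroring the arithmetic above:

  #include <cstdint>
  #include <cstdio>

  int main() {
    const uint32_t n = 1, c = 32, h = 128, w = 128;  // hypothetical NCHW shape
    const uint32_t kwg_size = 512;  // assumed GetKernelMaxWorkGroupSize() result
    const uint32_t gws[3] = {n, c, h * w};             // {1, 32, 16384}
    const uint32_t lws[3] = {1, kwg_size / 128, 128};  // {1, 4, 128}
    // Each work-group gets lws[1] floats of local memory per cached array.
    std::printf("gws={%u,%u,%u} lws={%u,%u,%u} local=%zu bytes per array\n",
                gws[0], gws[1], gws[2], lws[0], lws[1], lws[2],
                lws[1] * sizeof(float));
    return 0;
  }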
@@ -4,7 +4,7 @@ void kernel batch_norm(global const float *input,
                        global const float *mean,
                        global const float *var,
                        global const float *epsilon,
-                       private const int pixels,
+                       private const uint pixels,
                        global float *output,
                        __local float *new_scale,
                        __local float *new_offset) {
@@ -23,7 +23,6 @@ void kernel batch_norm(global const float *input,
   barrier(CLK_LOCAL_MEM_FENCE);
 
   const int sample_offset = (batch * channels + channel) * pixels + pixel_offset;
   const float *input_ptr = input + sample_offset;
   float *output_ptr = output + sample_offset;
   *output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel];
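The final line is the folded batch-norm transform; the per-channel precomputation happens earlier in the kernel, outside this hunk. Assuming the standard folding, which matches the names new_scale/new_offset:

  \[
  \text{new\_scale}_c = \frac{\text{scale}_c}{\sqrt{\text{var}_c + \epsilon}},
  \qquad
  \text{new\_offset}_c = \text{offset}_c - \text{mean}_c \cdot \text{new\_scale}_c
  \]
  \[
  \text{output} = \text{new\_scale}_c \cdot \text{input} + \text{new\_offset}_c
  = \frac{\text{scale}_c \,(\text{input} - \text{mean}_c)}{\sqrt{\text{var}_c + \epsilon}}
  + \text{offset}_c
  \]

so each work-item performs a single multiply-add per pixel, with the division and square root done once per channel.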
@@ -17,12 +17,12 @@ class BatchNormOp : public Operator<D, T> {
       : Operator<D, T>(operator_def, ws), functor_() {}
 
   bool Run() override {
-    const Tensor *input = this->Input(0);
-    const Tensor *scale = this->Input(1);
-    const Tensor *offset = this->Input(2);
-    const Tensor *mean = this->Input(3);
-    const Tensor *var = this->Input(4);
-    const Tensor *epsilon = this->Input(5);
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *scale = this->Input(SCALE);
+    const Tensor *offset = this->Input(OFFSET);
+    const Tensor *mean = this->Input(MEAN);
+    const Tensor *var = this->Input(VAR);
+    const Tensor *epsilon = this->Input(EPSILON);
 
     MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
                input->dim_size());
@@ -37,7 +37,7 @@ class BatchNormOp : public Operator<D, T> {
     MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ",
                epsilon->dim_size());
 
-    Tensor *output = this->Output(0);
+    Tensor *output = this->Output(OUTPUT);
     output->ResizeLike(input);
 
     functor_(input, scale, offset, mean, var, epsilon, output);
@@ -46,6 +46,10 @@ class BatchNormOp : public Operator<D, T> {
  private:
   kernels::BatchNormFunctor<D, T> functor_;
 
+ protected:
+  OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR, EPSILON);
+  OP_OUTPUT_TAGS(OUTPUT);
+
 };
 
 }  // namespace mace
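OP_INPUT_TAGS and OP_OUTPUT_TAGS are not defined in this diff; in Caffe2-style operator frameworks such tag macros conventionally expand to enums, so that Input(INPUT) resolves to Input(0), Input(SCALE) to Input(1), and so on. A sketch under that assumption (hypothetical expansion, not the actual MACE macro):

  // Plausible expansion of the tag macros used above:
  // #define OP_INPUT_TAGS(first, ...)  enum _InputTags { first = 0, __VA_ARGS__ }
  // #define OP_OUTPUT_TAGS(first, ...) enum _OutputTags { first = 0, __VA_ARGS__ }
  enum _InputTags { INPUT = 0, SCALE, OFFSET, MEAN, VAR, EPSILON };
  enum _OutputTags { OUTPUT = 0 };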