提交 845377f2 编写于 作者: L liuqi

Optimize the batch norm opencl kernel.

上级 dff4b94c
......@@ -160,4 +160,16 @@ cl::Program &OpenCLRuntime::program() {
return program_;
}
int OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
unsigned long long size = 0;
device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size);
return static_cast<int>(size);
}
// Queries the per-kernel work-group size limit (CL_KERNEL_WORK_GROUP_SIZE)
// for `kernel` when run on this runtime's device.
int OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) {
  // CL_KERNEL_WORK_GROUP_SIZE is a size_t parameter; an `unsigned long long`
  // out-arg mismatches sizeof(size_t) on 32-bit targets, which makes
  // clGetKernelWorkGroupInfo return CL_INVALID_VALUE and leaves size at 0.
  size_t size = 0;
  kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size);
  return static_cast<int>(size);
}
} // namespace mace
......@@ -21,6 +21,8 @@ class OpenCLRuntime {
public:
static OpenCLRuntime *Get();
int GetDeviceMaxWorkGroupSize();
int GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
cl::Context &context();
cl::Device &device();
cl::CommandQueue &command_queue();
......
......@@ -136,6 +136,8 @@ class OpenCLLibraryImpl final {
using clRetainDeviceFunc = cl_int (*)(cl_device_id);
using clReleaseDeviceFunc = cl_int (*)(cl_device_id);
using clRetainEventFunc = cl_int (*)(cl_event);
using clGetKernelWorkGroupInfoFunc =
cl_int (*)(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *);
#define DEFINE_FUNC_PTR(func) func##Func func = nullptr
......@@ -177,6 +179,7 @@ class OpenCLLibraryImpl final {
DEFINE_FUNC_PTR(clRetainDevice);
DEFINE_FUNC_PTR(clReleaseDevice);
DEFINE_FUNC_PTR(clRetainEvent);
DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo);
#undef DEFINE_FUNC_PTR
......@@ -296,6 +299,7 @@ void *OpenCLLibraryImpl::LoadFromPath(const std::string &path) {
ASSIGN_FROM_DLSYM(clRetainDevice);
ASSIGN_FROM_DLSYM(clReleaseDevice);
ASSIGN_FROM_DLSYM(clRetainEvent);
ASSIGN_FROM_DLSYM(clGetKernelWorkGroupInfo);
#undef ASSIGN_FROM_DLSYM
......@@ -782,3 +786,18 @@ cl_int clRetainEvent(cl_event event) {
return CL_OUT_OF_RESOURCES;
}
}
// Trampoline that forwards clGetKernelWorkGroupInfo to the symbol resolved
// from the dynamically loaded OpenCL library. Reports CL_OUT_OF_RESOURCES
// when the symbol was not found at load time.
cl_int clGetKernelWorkGroupInfo(cl_kernel kernel,
                                cl_device_id device,
                                cl_kernel_work_group_info param_name,
                                size_t param_value_size,
                                void *param_value,
                                size_t *param_value_size_ret) {
  const auto func = mace::OpenCLLibraryImpl::Get().clGetKernelWorkGroupInfo;
  if (func == nullptr) {
    // Symbol missing from the loaded libOpenCL — surface a CL error code.
    return CL_OUT_OF_RESOURCES;
  }
  return func(kernel, device, param_name, param_value_size, param_value,
              param_value_size_ret);
}
......@@ -18,27 +18,36 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
const Tensor *var,
const Tensor *epsilon,
Tensor *output) {
// NOTE(review): this span is a rendered diff with the +/- markers stripped;
// both the pre-change lines (n/channel/sample_size, `_kernel`) and the
// post-change lines (gws/lws, `bm_kernel`) appear below.
// -- pre-change: scalar dims, superseded by gws[] below.
const index_t n = input->dim(0);
const index_t channel = input->dim(1);
const index_t sample_size = input->dim(2) * input->dim(3);
// Global work sizes: one work-item per (batch, channel, pixel) of the
// NCHW input.
const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)),
static_cast<uint32_t>(input->dim(1)),
static_cast<uint32_t>(input->dim(2) * input->dim(3))};
// Local work-group shape; product is 1 * 2 * 128 = 256 work-items.
const uint32_t lws[3] = {1, 2, 128};
auto runtime = OpenCLRuntime::Get();
auto program = runtime->program();
// -- pre-change: kernel setup with hard-coded argument indices.
auto _kernel = cl::Kernel(program, "batch_norm");
_kernel.setArg(0, *(static_cast<const cl::Buffer *>(input->buffer())));
_kernel.setArg(1, *(static_cast<cl::Buffer *>(scale->buffer())));
_kernel.setArg(2, *(static_cast<cl::Buffer *>(offset->buffer())));
_kernel.setArg(3, *(static_cast<cl::Buffer *>(mean->buffer())));
_kernel.setArg(4, *(static_cast<cl::Buffer *>(var->buffer())));
_kernel.setArg(5, *(static_cast<cl::Buffer *>(epsilon->buffer())));
_kernel.setArg(6, static_cast<int>(sample_size));
_kernel.setArg(7, *(static_cast<cl::Buffer *>(output->buffer())));
_kernel.setArg(8, 32u, nullptr);
_kernel.setArg(9, 32u, nullptr);
// -- post-change: kernel setup with an auto-incrementing argument index.
auto bm_kernel = cl::Kernel(program, "batch_norm");
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(scale->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(offset->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(var->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(epsilon->buffer())));
// Pixel count per (batch, channel) slice, consumed as `pixels` by the kernel.
bm_kernel.setArg(idx++, gws[2]);
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
// Local scratch buffers (lws[1] floats each) for the per-channel
// new_scale / new_offset values computed inside the kernel.
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
// NOTE(review): OpenCL permits the local size product to EQUAL the kernel's
// CL_KERNEL_WORK_GROUP_SIZE; the strict '<' here rejects that legal boundary
// case — '<=' looks intended. TODO confirm.
MACE_CHECK(std::accumulate(lws, lws+3, 1, std::multiplies<uint32_t>())
< runtime->GetKernelMaxWorkGroupSize(bm_kernel));
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
// -- pre-change launch arguments (old kernel handle, literal ranges):
_kernel, cl::NullRange,
cl::NDRange(n, channel, sample_size),
cl::NDRange(1, 1, 128));
// -- post-change launch arguments (gws/lws arrays):
bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]));
MACE_CHECK(error == CL_SUCCESS);
}
......
......@@ -4,7 +4,7 @@ void kernel batch_norm(global const float *input,
global const float *mean,
global const float *var,
global const float *epsilon,
private const int pixels,
private const uint pixels,
global float *output,
__local float *new_scale,
__local float *new_offset) {
......@@ -23,7 +23,6 @@ void kernel batch_norm(global const float *input,
barrier(CLK_LOCAL_MEM_FENCE);
const int sample_offset = (batch * channels + channel) * pixels + pixel_offset;
const float *input_ptr = input + sample_offset;
float *output_ptr = output + sample_offset;
*output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel];
......
......@@ -17,12 +17,12 @@ class BatchNormOp : public Operator<D, T> {
: Operator<D, T>(operator_def, ws), functor_() {}
bool Run() override {
const Tensor *input = this->Input(0);
const Tensor *scale = this->Input(1);
const Tensor *offset = this->Input(2);
const Tensor *mean = this->Input(3);
const Tensor *var = this->Input(4);
const Tensor *epsilon = this->Input(5);
const Tensor *input = this->Input(INPUT);
const Tensor *scale = this->Input(SCALE);
const Tensor *offset = this->Input(OFFSET);
const Tensor *mean = this->Input(MEAN);
const Tensor *var = this->Input(VAR);
const Tensor *epsilon = this->Input(EPSILON);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
......@@ -37,7 +37,7 @@ class BatchNormOp : public Operator<D, T> {
MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ",
epsilon->dim_size());
Tensor *output = this->Output(0);
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
functor_(input, scale, offset, mean, var, epsilon, output);
......@@ -46,6 +46,10 @@ class BatchNormOp : public Operator<D, T> {
private:
kernels::BatchNormFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR, EPSILON);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册