diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index bb2cea3b55dba1d32debfb6d30333f93259b2614..b85b9e48074f0d354aa5c85a12ad2a03e0a4f0a8 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -160,4 +160,16 @@ cl::Program &OpenCLRuntime::program() { return program_; } +uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() { + unsigned long long size = 0; + device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size); + return static_cast(size); +} + +uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) { + unsigned long long size = 0; + kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size); + return static_cast(size); +} + } // namespace mace diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h index e18c9f93dae38ff8307adecbf321e65252d47427..e7c7b180e43ba13dfdfedec9a63b5973bdfcb55a 100644 --- a/mace/core/runtime/opencl/opencl_runtime.h +++ b/mace/core/runtime/opencl/opencl_runtime.h @@ -21,6 +21,8 @@ class OpenCLRuntime { public: static OpenCLRuntime *Get(); + uint32_t GetDeviceMaxWorkGroupSize(); + uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel); cl::Context &context(); cl::Device &device(); cl::CommandQueue &command_queue(); diff --git a/mace/core/runtime/opencl/opencl_wrapper.cc b/mace/core/runtime/opencl/opencl_wrapper.cc index e1d9c12284301b2ad4533287a12a3ab9ddfe4009..49f9293429f5e7e3e0b7531a4301cfd448495777 100644 --- a/mace/core/runtime/opencl/opencl_wrapper.cc +++ b/mace/core/runtime/opencl/opencl_wrapper.cc @@ -136,6 +136,8 @@ class OpenCLLibraryImpl final { using clRetainDeviceFunc = cl_int (*)(cl_device_id); using clReleaseDeviceFunc = cl_int (*)(cl_device_id); using clRetainEventFunc = cl_int (*)(cl_event); + using clGetKernelWorkGroupInfoFunc = + cl_int (*)(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *); #define DEFINE_FUNC_PTR(func) func##Func func = nullptr @@ -177,6 +179,7 @@ class OpenCLLibraryImpl final { DEFINE_FUNC_PTR(clRetainDevice); DEFINE_FUNC_PTR(clReleaseDevice); DEFINE_FUNC_PTR(clRetainEvent); + DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo); #undef DEFINE_FUNC_PTR @@ -296,6 +299,7 @@ void *OpenCLLibraryImpl::LoadFromPath(const std::string &path) { ASSIGN_FROM_DLSYM(clRetainDevice); ASSIGN_FROM_DLSYM(clReleaseDevice); ASSIGN_FROM_DLSYM(clRetainEvent); + ASSIGN_FROM_DLSYM(clGetKernelWorkGroupInfo); #undef ASSIGN_FROM_DLSYM @@ -782,3 +786,18 @@ cl_int clRetainEvent(cl_event event) { return CL_OUT_OF_RESOURCES; } } + +cl_int clGetKernelWorkGroupInfo(cl_kernel kernel, + cl_device_id device, + cl_kernel_work_group_info param_name, + size_t param_value_size, + void *param_value, + size_t *param_value_size_ret) { + auto func = mace::OpenCLLibraryImpl::Get().clGetKernelWorkGroupInfo; + if (func != nullptr) { + return func(kernel, device, param_name, param_value_size, + param_value, param_value_size_ret); + } else { + return CL_OUT_OF_RESOURCES; + } +} diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 67b810f5149ea51e2456071415857bca467abbf0..4d0771f15819e71496100e441799801f35b191f2 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -18,27 +18,35 @@ void BatchNormFunctor::operator()( const Tensor *var, const Tensor *epsilon, Tensor *output) { - const index_t n = input->dim(0); - const index_t channel = input->dim(1); - const index_t sample_size = input->dim(2) * input->dim(3); + + const uint32_t gws[3] = {static_cast(input->dim(0)), + static_cast(input->dim(1)), + static_cast(input->dim(2) * input->dim(3))}; + auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); - auto _kernel = cl::Kernel(program, "batch_norm"); - _kernel.setArg(0, *(static_cast(input->buffer()))); - _kernel.setArg(1, *(static_cast(scale->buffer()))); - _kernel.setArg(2, *(static_cast(offset->buffer()))); - _kernel.setArg(3, *(static_cast(mean->buffer()))); - _kernel.setArg(4, *(static_cast(var->buffer()))); - _kernel.setArg(5, *(static_cast(epsilon->buffer()))); - _kernel.setArg(6, static_cast(sample_size)); - _kernel.setArg(7, *(static_cast(output->buffer()))); - _kernel.setArg(8, 32u, nullptr); - _kernel.setArg(9, 32u, nullptr); + auto bm_kernel = cl::Kernel(program, "batch_norm"); + + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); + const uint32_t lws[3] = {1, kwg_size/128, 128}; + + uint32_t idx = 0; + bm_kernel.setArg(idx++, *(static_cast(input->buffer()))); + bm_kernel.setArg(idx++, *(static_cast(scale->buffer()))); + bm_kernel.setArg(idx++, *(static_cast(offset->buffer()))); + bm_kernel.setArg(idx++, *(static_cast(mean->buffer()))); + bm_kernel.setArg(idx++, *(static_cast(var->buffer()))); + bm_kernel.setArg(idx++, *(static_cast(epsilon->buffer()))); + bm_kernel.setArg(idx++, gws[2]); + bm_kernel.setArg(idx++, *(static_cast(output->buffer()))); + bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); + bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); + cl_int error = runtime->command_queue().enqueueNDRangeKernel( - _kernel, cl::NullRange, - cl::NDRange(n, channel, sample_size), - cl::NDRange(1, 1, 128)); + bm_kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1], gws[2]), + cl::NDRange(lws[0], lws[1], lws[2])); MACE_CHECK(error == CL_SUCCESS); } diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index d5927071f222b6ff0c0cdb7dd32e2b15979983f8..e86c62f336dd62b2aab0d5226f084756eff8350d 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -4,7 +4,7 @@ void kernel batch_norm(global const float *input, global const float *mean, global const float *var, global const float *epsilon, - private const int pixels, + private const uint pixels, global float *output, __local float *new_scale, __local float *new_offset) { @@ -23,7 +23,6 @@ void kernel batch_norm(global const float *input, barrier(CLK_LOCAL_MEM_FENCE); const int sample_offset = (batch * channels + channel) * pixels + pixel_offset; - const float *input_ptr = input + sample_offset; float *output_ptr = output + sample_offset; *output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel]; diff --git a/mace/ops/batch_norm.h b/mace/ops/batch_norm.h index 1452bc726c4807b079bb6dcb24b4319275059bc1..0c5909546f88a0e149a4b0628990c417b5b22630 100644 --- a/mace/ops/batch_norm.h +++ b/mace/ops/batch_norm.h @@ -17,12 +17,12 @@ class BatchNormOp : public Operator { : Operator(operator_def, ws), functor_() {} bool Run() override { - const Tensor *input = this->Input(0); - const Tensor *scale = this->Input(1); - const Tensor *offset = this->Input(2); - const Tensor *mean = this->Input(3); - const Tensor *var = this->Input(4); - const Tensor *epsilon = this->Input(5); + const Tensor *input = this->Input(INPUT); + const Tensor *scale = this->Input(SCALE); + const Tensor *offset = this->Input(OFFSET); + const Tensor *mean = this->Input(MEAN); + const Tensor *var = this->Input(VAR); + const Tensor *epsilon = this->Input(EPSILON); MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size()); @@ -37,7 +37,7 @@ class BatchNormOp : public Operator { MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ", epsilon->dim_size()); - Tensor *output = this->Output(0); + Tensor *output = this->Output(OUTPUT); output->ResizeLike(input); functor_(input, scale, offset, mean, var, epsilon, output); @@ -46,6 +46,10 @@ class BatchNormOp : public Operator { private: kernels::BatchNormFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR, EPSILON); + OP_OUTPUT_TAGS(OUTPUT); }; } // namespace mace