提交 2fef9af5 编写于 作者: L liuqi

Change the calculation of local work group size.

上级 845377f2
...@@ -160,16 +160,16 @@ cl::Program &OpenCLRuntime::program() { ...@@ -160,16 +160,16 @@ cl::Program &OpenCLRuntime::program() {
return program_; return program_;
} }
int OpenCLRuntime::GetDeviceMaxWorkGroupSize() { uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
unsigned long long size = 0; unsigned long long size = 0;
device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size); device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size);
return static_cast<int>(size); return static_cast<uint32_t>(size);
} }
int OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) { uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) {
unsigned long long size = 0; unsigned long long size = 0;
kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size); kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size);
return static_cast<int>(size); return static_cast<uint32_t>(size);
} }
} // namespace mace } // namespace mace
...@@ -21,8 +21,8 @@ class OpenCLRuntime { ...@@ -21,8 +21,8 @@ class OpenCLRuntime {
public: public:
static OpenCLRuntime *Get(); static OpenCLRuntime *Get();
int GetDeviceMaxWorkGroupSize(); uint32_t GetDeviceMaxWorkGroupSize();
int GetKernelMaxWorkGroupSize(const cl::Kernel& kernel); uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
cl::Context &context(); cl::Context &context();
cl::Device &device(); cl::Device &device();
cl::CommandQueue &command_queue(); cl::CommandQueue &command_queue();
......
...@@ -22,13 +22,15 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -22,13 +22,15 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)), const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)),
static_cast<uint32_t>(input->dim(1)), static_cast<uint32_t>(input->dim(1)),
static_cast<uint32_t>(input->dim(2) * input->dim(3))}; static_cast<uint32_t>(input->dim(2) * input->dim(3))};
const uint32_t lws[3] = {1, 2, 128};
auto runtime = OpenCLRuntime::Get(); auto runtime = OpenCLRuntime::Get();
auto program = runtime->program(); auto program = runtime->program();
auto bm_kernel = cl::Kernel(program, "batch_norm"); auto bm_kernel = cl::Kernel(program, "batch_norm");
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
const uint32_t lws[3] = {1, kwg_size/128, 128};
uint32_t idx = 0; uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer()))); bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(scale->buffer()))); bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(scale->buffer())));
...@@ -41,9 +43,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -41,9 +43,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
MACE_CHECK(std::accumulate(lws, lws+3, 1, std::multiplies<uint32_t>())
< runtime->GetKernelMaxWorkGroupSize(bm_kernel));
cl_int error = runtime->command_queue().enqueueNDRangeKernel( cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange, bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]), cl::NDRange(gws[0], gws[1], gws[2]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册