提交 2fef9af5 编写于 作者: L liuqi

Change the calculation of local work group size.

上级 845377f2
......@@ -160,16 +160,16 @@ cl::Program &OpenCLRuntime::program() {
return program_;
}
// Queries device_ for CL_DEVICE_MAX_WORK_GROUP_SIZE and returns the value
// narrowed to the function's return type.
//
// NOTE(review): this span is a diff hunk rendered without +/- markers. The
// `int` signature/return lines are presumably the pre-change version and the
// `uint32_t` lines their replacements -- confirm against the commit.
int OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
// Starts at 0; the status returned by getInfo below is not checked, so 0 is
// what gets returned if the query does not write a value.
unsigned long long size = 0;
// NOTE(review): the OpenCL spec appears to define this property as size_t;
// confirm that reading it into an unsigned long long is safe on all targets.
device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size);
return static_cast<int>(size);
return static_cast<uint32_t>(size);
}
// Queries CL_KERNEL_WORK_GROUP_SIZE for `kernel` on device_ and returns the
// value narrowed to the function's return type.
//
// NOTE(review): this span is a diff hunk rendered without +/- markers. The
// `int` signature/return lines are presumably the pre-change version and the
// `uint32_t` lines their replacements -- confirm against the commit.
int OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) {
uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) {
// Starts at 0; the status returned by getWorkGroupInfo below is not checked,
// so 0 is what gets returned if the query does not write a value.
unsigned long long size = 0;
// NOTE(review): the OpenCL spec appears to define this property as size_t;
// confirm that reading it into an unsigned long long is safe on all targets.
kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size);
return static_cast<int>(size);
return static_cast<uint32_t>(size);
}
} // namespace mace
......@@ -21,8 +21,8 @@ class OpenCLRuntime {
public:
static OpenCLRuntime *Get();
int GetDeviceMaxWorkGroupSize();
int GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
uint32_t GetDeviceMaxWorkGroupSize();
uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel);
cl::Context &context();
cl::Device &device();
cl::CommandQueue &command_queue();
......
......@@ -22,13 +22,15 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)),
static_cast<uint32_t>(input->dim(1)),
static_cast<uint32_t>(input->dim(2) * input->dim(3))};
const uint32_t lws[3] = {1, 2, 128};
auto runtime = OpenCLRuntime::Get();
auto program = runtime->program();
auto bm_kernel = cl::Kernel(program, "batch_norm");
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
const uint32_t lws[3] = {1, kwg_size/128, 128};
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(scale->buffer())));
......@@ -41,9 +43,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
MACE_CHECK(std::accumulate(lws, lws+3, 1, std::multiplies<uint32_t>())
< runtime->GetKernelMaxWorkGroupSize(bm_kernel));
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册