提交 201983c6 编写于 作者: L liuqi

Change the some opencl runtime api return type to int64_t

上级 46c7414d
...@@ -323,21 +323,21 @@ void OpenCLRuntime::GetCallStats(const cl::Event &event, CallStats *stats) { ...@@ -323,21 +323,21 @@ void OpenCLRuntime::GetCallStats(const cl::Event &event, CallStats *stats) {
} }
} }
uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() { uint64_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() {
uint64_t size = 0; uint64_t size = 0;
device_->getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size); device_->getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size);
return static_cast<uint32_t>(size); return size;
} }
uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel &kernel) { uint64_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel &kernel) {
uint64_t size = 0; uint64_t size = 0;
kernel.getWorkGroupInfo(*device_, CL_KERNEL_WORK_GROUP_SIZE, &size); kernel.getWorkGroupInfo(*device_, CL_KERNEL_WORK_GROUP_SIZE, &size);
return static_cast<uint32_t>(size); return size;
} }
// TODO(liuqi): not compatible with mali gpu. // TODO(liuqi): not compatible with mali gpu.
uint32_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) { uint64_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) {
uint32_t size = 0; uint64_t size = 0;
kernel.getWorkGroupInfo(*device_, CL_KERNEL_WAVE_SIZE_QCOM, &size); kernel.getWorkGroupInfo(*device_, CL_KERNEL_WAVE_SIZE_QCOM, &size);
return size; return size;
} }
......
...@@ -46,9 +46,9 @@ class OpenCLRuntime { ...@@ -46,9 +46,9 @@ class OpenCLRuntime {
cl::CommandQueue &command_queue(); cl::CommandQueue &command_queue();
void GetCallStats(const cl::Event &event, CallStats *stats); void GetCallStats(const cl::Event &event, CallStats *stats);
uint32_t GetDeviceMaxWorkGroupSize(); uint64_t GetDeviceMaxWorkGroupSize();
uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel); uint64_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel);
uint32_t GetKernelWaveSize(const cl::Kernel &kernel); uint64_t GetKernelWaveSize(const cl::Kernel &kernel);
cl::Kernel BuildKernel(const std::string &program_name, cl::Kernel BuildKernel(const std::string &program_name,
const std::string &kernel_name, const std::string &kernel_name,
const std::set<std::string> &build_options); const std::set<std::string> &build_options);
......
...@@ -62,11 +62,13 @@ void FCWXKernel(cl::Kernel *kernel, ...@@ -62,11 +62,13 @@ void FCWXKernel(cl::Kernel *kernel,
const index_t batch = output->dim(0); const index_t batch = output->dim(0);
const index_t output_size = output->dim(3); const index_t output_size = output->dim(3);
const index_t output_blocks = RoundUpDiv4(output_size); const index_t output_blocks = RoundUpDiv4(output_size);
const uint32_t wave_size = runtime->GetKernelWaveSize(*kernel); const uint32_t wave_size =
static_cast<uint32_t>(runtime->GetKernelWaveSize(*kernel));
*gws = {4, (wave_size / 4), static_cast<uint32_t>(batch * output_blocks)}; *gws = {4, (wave_size / 4), static_cast<uint32_t>(batch * output_blocks)};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(*kernel); const uint32_t kwg_size =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(*kernel));
const uint32_t inter_local_blks = kwg_size / ((*gws)[0] * (*gws)[1]); const uint32_t inter_local_blks = kwg_size / ((*gws)[0] * (*gws)[1]);
*lws = {(*gws)[0], (*gws)[1], inter_local_blks}; *lws = {(*gws)[0], (*gws)[1], inter_local_blks};
} }
......
...@@ -201,7 +201,8 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel, ...@@ -201,7 +201,8 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel,
StatsFuture *future) { StatsFuture *future) {
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> { auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel); const uint32_t kwg_size =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel));
std::vector<uint32_t> local_ws(3, 0); std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size); local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]); local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
...@@ -304,7 +305,8 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel, ...@@ -304,7 +305,8 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
StatsFuture *future) { StatsFuture *future) {
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> { auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel); const uint32_t kwg_size =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel));
uint32_t local_ws[2]; uint32_t local_ws[2];
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size); local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]); local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册