提交 aa866acb 编写于 作者: Y yejianwu

Do not round up the global work size to a multiple of the local work size on Qualcomm OpenCL 2.0

上级 0e078d19
......@@ -142,7 +142,6 @@ OpenCLRuntime::OpenCLRuntime(GPUPerfHint gpu_perf_hint,
}
bool gpu_detected = false;
bool is_adreno_gpu = false;
device_ = std::make_shared<cl::Device>();
for (auto device : all_devices) {
if (device.getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_GPU) {
......@@ -150,10 +149,18 @@ OpenCLRuntime::OpenCLRuntime(GPUPerfHint gpu_perf_hint,
gpu_detected = true;
const std::string device_name = device.getInfo<CL_DEVICE_NAME>();
constexpr const char *kQualcommAdrenoGPUStr = "QUALCOMM Adreno(TM)";
constexpr const char *kMaliGPUStr = "Mali";
if (device_name == kQualcommAdrenoGPUStr) {
is_adreno_gpu = true;
gpu_type_ = GPU_TYPE::QUALCOMM_ADRENO;
} else if (device_name.find(kMaliGPUStr) != std::string::npos) {
gpu_type_ = GPU_TYPE::MALI;
} else {
gpu_type_ = GPU_TYPE::UNKNOWN;
}
const std::string device_version = device.getInfo<CL_DEVICE_VERSION>();
opencl_version_ = device_version.substr(7, 3);
VLOG(1) << "Using device: " << device_name;
break;
}
......@@ -171,7 +178,7 @@ OpenCLRuntime::OpenCLRuntime(GPUPerfHint gpu_perf_hint,
}
cl_int err;
if (is_adreno_gpu) {
if (gpu_type_ == GPU_TYPE::QUALCOMM_ADRENO) {
std::vector<cl_context_properties> context_properties;
context_properties.reserve(5);
GetAdrenoContextProperties(&context_properties, gpu_perf_hint,
......@@ -350,4 +357,12 @@ uint64_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) {
return size;
}
// Returns the GPU family stored in gpu_type_.
// The constructor assigns it while scanning devices: QUALCOMM_ADRENO when the
// device name equals "QUALCOMM Adreno(TM)", MALI when the name contains
// "Mali", and UNKNOWN for any other GPU.
// NOTE(review): gpu_type_ is only assigned inside the GPU-detected branch of
// the constructor in the code visible here; confirm its value when no GPU
// device is found.
const GPU_TYPE OpenCLRuntime::GetGPUType() const {
return gpu_type_;
}
// Returns a reference to the cached OpenCL version string.
// The constructor fills opencl_version_ with the three characters starting at
// offset 7 of the device's CL_DEVICE_VERSION string; for a string of the form
// "OpenCL 2.0 <vendor info>" that is "2.0".
// NOTE(review): the returned reference aliases a member, so it presumably must
// not outlive this runtime object -- confirm against callers.
const std::string &OpenCLRuntime::GetOpenclVersion() {
return opencl_version_;
}
} // namespace mace
......@@ -18,6 +18,12 @@
namespace mace {
// GPU family detected by OpenCLRuntime from the OpenCL device name.
enum GPU_TYPE {
// Device name is exactly "QUALCOMM Adreno(TM)".
QUALCOMM_ADRENO,
// Device name contains the substring "Mali".
MALI,
// Any other GPU device name.
UNKNOWN,
};
class OpenCLProfilingTimer : public Timer {
public:
explicit OpenCLProfilingTimer(const cl::Event *event)
......@@ -49,6 +55,8 @@ class OpenCLRuntime {
uint64_t GetDeviceMaxWorkGroupSize();
uint64_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel);
uint64_t GetKernelWaveSize(const cl::Kernel &kernel);
const GPU_TYPE GetGPUType() const;
const std::string &GetOpenclVersion();
cl::Kernel BuildKernel(const std::string &program_name,
const std::string &kernel_name,
const std::set<std::string> &build_options);
......@@ -74,6 +82,8 @@ class OpenCLRuntime {
std::map<std::string, cl::Program> built_program_map_;
std::mutex program_build_mutex_;
std::string kernel_path_;
GPU_TYPE gpu_type_;
std::string opencl_version_;
static GPUPerfHint gpu_perf_hint_;
static GPUPriorityHint gpu_priority_hint_;
......
......@@ -26,14 +26,18 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
auto runtime = OpenCLRuntime::Global();
if (kernel_.get() == nullptr) {
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation");
built_options.emplace("-Dactivation=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
switch (activation_) {
case RELU:
tuning_key_prefix_ = "relu_opencl_kernel_";
......
......@@ -26,6 +26,8 @@ void AddNFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
for (int i = 1; i < size; ++i) {
MACE_CHECK_NOTNULL(input_tensors[i]);
MACE_CHECK(batch == input_tensors[i]->dim(0));
......@@ -45,6 +47,10 @@ void AddNFunctor<DeviceType::OPENCL, T>::operator()(
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DINPUT_NUM=", input_tensors.size()));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("addn", kernel_name, built_options);
}
......
......@@ -36,6 +36,8 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -43,6 +45,9 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
built_options.emplace("-Dbatch_norm=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (folded_constant_) {
built_options.emplace("-DFOLDED_CONSTANT");
}
......
......@@ -28,6 +28,9 @@ void BiasAddFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
static_cast<uint32_t>(height * batch)};
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -35,6 +38,9 @@ void BiasAddFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
built_options.emplace("-Dbias_add=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("bias_add", kernel_name, built_options);
}
if (!IsVecEqual(input_shape_, input->shape())) {
......@@ -52,15 +58,22 @@ void BiasAddFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
const std::vector<uint32_t> lws = {8, kwg_size / 64, 8};
cl::Event event;
cl_int error;
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
} else {
std::vector<uint32_t> roundup_gws(lws.size());
for (size_t i = 0; i < lws.size(); ++i) {
roundup_gws[i] = RoundUp(gws[i], lws[i]);
}
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
error = runtime->command_queue().enqueueNDRangeKernel(
kernel_, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
......
......@@ -59,11 +59,19 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
: "winograd_filter_buffer_to_image";
break;
}
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
std::set<std::string> built_options;
std::stringstream kernel_name_ss;
kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
built_options.emplace(kernel_name_ss.str());
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (buffer->dtype() == image->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" +
......@@ -74,7 +82,6 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
built_options.emplace("-DCMD_DATA_TYPE=" +
DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
}
auto runtime = OpenCLRuntime::Global();
auto b2f_kernel = runtime->BuildKernel("buffer_to_image",
obfuscated_kernel_name, built_options);
......@@ -105,17 +112,24 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(
const uint32_t kwg_size =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(b2f_kernel));
const std::vector<uint32_t> lws = {16, kwg_size / 16};
cl::Event event;
cl_int error;
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
b2f_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1]),
cl::NDRange(lws[0], lws[1]), nullptr, &event);
} else {
std::vector<uint32_t> roundup_gws(lws.size());
for (size_t i = 0; i < lws.size(); ++i) {
roundup_gws[i] = RoundUp(gws[i], lws[i]);
}
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
error = runtime->command_queue().enqueueNDRangeKernel(
b2f_kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1]),
cl::NDRange(lws[0], lws[1]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
......
......@@ -36,6 +36,8 @@ void ChannelShuffleFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("channel_shuffle");
......@@ -43,6 +45,9 @@ void ChannelShuffleFunctor<DeviceType::OPENCL, T>::operator()(
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("channel_shuffle", kernel_name,
built_options);
}
......
......@@ -5,19 +5,28 @@ __kernel void activation(__read_only image2d_t input,
__read_only image2d_t alpha,
#endif
__private const float relux_max_limit,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (ch_blk >= global_size_dim0 || w >= global_size_dim1
|| hb >= global_size_dim2) {
return;
}
const int width = global_size_dim1;
#else
const int width = get_global_size(1);
#endif
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
......
......@@ -8,12 +8,20 @@ __kernel void addn(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */
#if INPUT_NUM > 3
__read_only image2d_t input3,
#endif
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
const int w = get_global_id(0);
const int hb = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || hb >= global_size_dim1) return;
#endif
DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb));
......
......@@ -9,19 +9,28 @@ __kernel void batch_norm(__read_only image2d_t input,
__private const float epsilon,
#endif
__write_only image2d_t output,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const float relux_max_limit,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const float relux_max_limit) {
#endif
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (ch_blk >= global_size_dim0 || w >= global_size_dim1
|| hb >= global_size_dim2) {
return;
}
const int width = global_size_dim1;
#else
const int width = get_global_size(1);
#endif
#ifdef FOLDED_CONSTANT
DATA_TYPE4 bn_scale = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0));
......
......@@ -2,19 +2,27 @@
// Supported data types: half/float
__kernel void bias_add(__read_only image2d_t input,
__read_only image2d_t bias,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (ch_blk >= global_size_dim0 || w >= global_size_dim1
|| hb >= global_size_dim2) {
return;
}
const int width = global_size_dim1;
#else
const int width = get_global_size(1);
#endif
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
......
......@@ -5,14 +5,22 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, o
__private const int filter_w,
__private const int out_channel,
__private const int in_channel,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int out_channel_idx = h * 4;
const int rounded_in_channel = ((in_channel + 3) / 4) * 4;
......@@ -51,14 +59,22 @@ __kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic
__private const int filter_w,
__private const int out_channel,
__private const int in_channel,
#ifndef USE_QUALCOMM_OPENCL_2_0
__read_only image2d_t input,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__read_only image2d_t input) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int out_channel_idx = h * 4;
const int rounded_in_channel = ((in_channel + 3) / 4) * 4;
......@@ -96,14 +112,22 @@ __kernel void dw_filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w
__private const int filter_w,
__private const int in_channel,
__private const int multiplier,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) { /* ic%4 * kh * kw * m, ic/4 */
#else
__write_only image2d_t output) {
#endif
const int w = get_global_id(0);
const int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
DATA_TYPE4 values = 0;
if (multiplier == 1) {
......@@ -151,14 +175,22 @@ __kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
__private const int height,
__private const int width,
__private const int channels,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int batch_idx = h / height;
const int height_idx = h % height;
......@@ -189,14 +221,22 @@ __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */
__private const int height,
__private const int width,
__private const int channels,
#ifndef USE_QUALCOMM_OPENCL_2_0
__read_only image2d_t input,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__read_only image2d_t input) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int batch_idx = h / height;
const int height_idx = h % height;
......@@ -225,14 +265,22 @@ __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */
__kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
__private const int input_offset,
__private const int count,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int offset = input_offset + w * 4;
const int size = count - w * 4;
......@@ -257,14 +305,23 @@ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */
__kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */
__private const int count,
#ifndef USE_QUALCOMM_OPENCL_2_0
__read_only image2d_t input,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__read_only image2d_t input) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int offset = w * 4;
int2 coord = (int2)(w, h);
......@@ -290,14 +347,22 @@ __kernel void in_out_height_buffer_to_image(__global const DATA_TYPE *input, //n
__private const int height,
__private const int width,
__private const int channels,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int wc = width * channels;
const int height_blks = (height + 3) / 4;
......@@ -329,14 +394,22 @@ __kernel void in_out_height_image_to_buffer(__global DATA_TYPE *output, //nhwc
__private const int height,
__private const int width,
__private const int channels,
#ifndef USE_QUALCOMM_OPENCL_2_0
__read_only image2d_t input,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__read_only image2d_t input) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int height_blks = (height + 3) / 4;
const int batch_idx = h / height_blks;
......@@ -366,14 +439,22 @@ __kernel void in_out_width_buffer_to_image(__global const DATA_TYPE *input, /* n
__private const int height,
__private const int width,
__private const int channels,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int width_blks = (width + 3) / 4;
const int batch_idx = h / height;
......@@ -406,16 +487,26 @@ __kernel void winograd_filter_buffer_to_image(__global const DATA_TYPE *input, /
__private const int in_channels,
__private const int height,
__private const int width,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
const int out_channels = global_size_dim1;
#else
const int out_channels = get_global_size(1);
#endif
const int out_channel_idx = h;
const int in_channel_idx = w << 2;
const int offset = input_offset + (out_channel_idx * in_channels + in_channel_idx) * height * width;
......@@ -492,14 +583,22 @@ __kernel void winograd_filter_image_to_buffer(__global DATA_TYPE *output, //Oc,
__private const int height,
__private const int width,
__private const int channel,
#ifndef USE_QUALCOMM_OPENCL_2_0
__read_only image2d_t input,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__read_only image2d_t input) {
#endif
const int w = get_global_id(0);
const int h = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
const int width_idx = w << 2;
const int size = width - width_idx;
......
......@@ -4,19 +4,29 @@
__kernel void channel_shuffle(__read_only image2d_t input,
__private const int groups,
__private const int channels_per_group,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int group_chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (group_chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
const int width = global_size_dim1;
#else
const int width = get_global_size(1);
#endif
const int group_blks = groups / 4;
const int groups_blks_width = group_blks * width;
const int channels_per_group_blks = channels_per_group / 4;
......
......@@ -25,19 +25,29 @@ DATA_TYPE4 stitch_vector(DATA_TYPE4 left,
__kernel void concat_channel(__read_only image2d_t input0,
__read_only image2d_t input1,
__private const int input0_chan,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
const int width = global_size_dim1;
#else
const int width = get_global_size(1);
#endif
const int input0_chan_blk = (input0_chan + 3) >> 2;
DATA_TYPE4 data = 0;
......@@ -82,19 +92,29 @@ __kernel void concat_channel(__read_only image2d_t input0,
// Required: All input channels are divisible by 4
__kernel void concat_channel_multi(__read_only image2d_t input,
__private const int chan_blk_offset,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
const int width = global_size_dim1;
#else
const int width = get_global_size(1);
#endif
DATA_TYPE4 data = 0;
data = READ_IMAGET(input,
SAMPLER,
......
......@@ -18,20 +18,29 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__private const int padding_top,
__private const int padding_left,
__private const int dilation_h,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int dilation_w,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const int dilation_w) {
#endif
const int out_ch_blk = get_global_id(0);
const int out_w_blk = get_global_id(1);
const int out_hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1
|| out_hb >= global_size_dim2) {
return;
}
const int out_w_blks = global_size_dim1;
#else
const int out_w_blks = get_global_size(1);
#endif
const int rounded_in_ch = in_ch_blks << 2;
#ifdef BIAS
......
......@@ -12,20 +12,28 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
__private const int in_ch_blks,
__private const int height,
__private const int width,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int stride,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const int stride) {
#endif
const int out_ch_blk = get_global_id(0);
const int out_w_blk = get_global_id(1);
const int out_hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1
|| out_hb >= global_size_dim2) {
return;
}
const int out_w_blks = global_size_dim1;
#else
const int out_w_blks = get_global_size(1);
#endif
#ifdef BIAS
DATA_TYPE4 out0 = READ_IMAGET(bias, SAMPLER, (int2)(out_ch_blk, 0));
......
......@@ -16,20 +16,29 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
__private const int padding_top,
__private const int padding_left,
__private const int dilation_h,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int dilation_w,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const int dilation_w) {
#endif
const int out_ch_blk = get_global_id(0);
const int out_w_blk = get_global_id(1);
const int out_hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1
|| out_hb >= global_size_dim2) {
return;
}
const int out_w_blks = global_size_dim1;
#else
const int out_w_blks = get_global_size(1);
#endif
const int rounded_in_ch = in_ch_blks << 2;
#ifdef BIAS
......
......@@ -18,19 +18,29 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
__private const short padding_top,
__private const short padding_left,
__private const short dilation_h,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const short dilation_w,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const short dilation_w) {
#endif
const short out_ch_blk = get_global_id(0);
const short out_w_blk = get_global_id(1);
const short out_hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1
|| out_hb >= global_size_dim2) {
return;
}
const short out_w_blks = global_size_dim1;
#else
const short out_w_blks = get_global_size(1);
#endif
const short rounded_in_ch = in_ch_blks << 2;
const short in_ch_blk = out_ch_blk; // multiplier = 1
......@@ -149,17 +159,25 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4
__private const short filter_height,
__private const short filter_width,
__private const short padding_top,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const short padding_left,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const short padding_left) {
#endif
const short out_ch_blk = get_global_id(0);
const short out_w_blk = get_global_id(1) << 2;
const short out_hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (out_ch_blk >= global_size_dim0 || get_global_id(1) >= global_size_dim1
|| out_hb >= global_size_dim2) {
return;
}
#endif
const short rounded_in_ch = in_ch_blks << 2;
const short in_ch_blk = out_ch_blk; // multiplier = 1
......
......@@ -6,12 +6,20 @@ __kernel void eltwise(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */
__private const float coeff0,
__private const float coeff1,
#endif
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__write_only image2d_t output) {
#endif
const int w = get_global_id(0);
const int hb = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (w >= global_size_dim0 || hb >= global_size_dim1) return;
#endif
DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb));
......
......@@ -10,14 +10,22 @@ __kernel void fully_connected(__read_only image2d_t input,
__private const int input_height,
__private const int input_width,
__private const int input_channel,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const float relux_max_limit,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__private const float relux_max_limit) {
#endif
const int batch_idx = get_global_id(0);
const int out_blk_idx = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (batch_idx >= global_size_dim0 || out_blk_idx >= global_size_dim1) {
return;
}
#endif
const int input_chan_blk = (input_channel + 3) >> 2;
......@@ -74,19 +82,28 @@ __kernel void fully_connected_width(__read_only image2d_t input,
__private const int input_width,
__private const int in_chan_blks,
__private const int out_blks,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const float relux_max_limit,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const float relux_max_limit) {
#endif
const int inter_out_idx = get_global_id(0);
const int width_blk_idx = get_global_id(1);
const int batch_out_blk_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (inter_out_idx >= global_size_dim0 || width_blk_idx >= global_size_dim1
|| batch_out_blk_idx >= global_size_dim2) {
return;
}
const int width_blk_count = global_size_dim1;
#else
const int width_blk_count = get_global_size(1);
#endif
const int batch_idx = batch_out_blk_idx / out_blks;
const int out_blk_idx = batch_out_blk_idx % out_blks;
......
......@@ -8,12 +8,20 @@ __kernel void matmul(__read_only image2d_t A,
__private const int N,
__private const int K,
__private const int height_blocks,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int k_blocks,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__private const int k_blocks) {
#endif
const int gx = get_global_id(0) << 2;
const int hb = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (get_global_id(0) >= global_size_dim0 || hb >= global_size_dim1) return;
#endif
const int batch = hb / height_blocks;
const int ty = (hb % height_blocks);
......
......@@ -27,19 +27,29 @@ __kernel void pooling(__read_only image2d_t input,
__private const int pad_left,
__private const int stride,
__private const int pooling_size,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int out_chan_idx = get_global_id(0);
const int out_width_idx = get_global_id(1);
const int out_hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (out_chan_idx >= global_size_dim0 || out_width_idx >= global_size_dim1
|| out_hb_idx >= global_size_dim2) {
return;
}
const int out_width = global_size_dim1;
#else
const int out_width = get_global_size(1);
#endif
const int batch_idx = mul24((out_hb_idx / out_height), in_height);
const int in_height_start = mul24((out_hb_idx % out_height), stride) - pad_top;
const int in_width_start = mul24(out_width_idx, stride) - pad_left;
......
......@@ -6,19 +6,30 @@ __kernel void resize_bilinear_nocache(__read_only image2d_t input, /* [c%4 * w *
__private const float width_scale,
__private const int in_height,
__private const int in_width,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int out_height,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const int out_height) {
#endif
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (ch_blk >= global_size_dim0 || w >= global_size_dim1
|| hb >= global_size_dim2) {
return;
}
const int ch_blks = global_size_dim0;
const int out_width = global_size_dim1;
#else
const int ch_blks = get_global_size(0);
const int out_width = get_global_size(1);
#endif
const int b = hb / out_height;
const int h = hb % out_height;
......
......@@ -2,19 +2,28 @@
__kernel void slice(__read_only image2d_t input,
__private const int chan_blk_offset,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
const int width = global_size_dim1;
#else
const int width = get_global_size(1);
#endif
DATA_TYPE4 data = READ_IMAGET(input, SAMPLER,
(int2)(mad24(chan_blk_idx + chan_blk_offset,
......
......@@ -3,20 +3,30 @@
__kernel void softmax(__read_only image2d_t input,
__private const int channels,
__private const int remain_channels,
#ifndef USE_QUALCOMM_OPENCL_2_0
__write_only image2d_t output,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__write_only image2d_t output) {
#endif
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
const int chan_blks = global_size_dim0 - 1;
const int width = global_size_dim1;
#else
const int chan_blks = get_global_size(0) - 1;
const int width = get_global_size(1);
#endif
int pos = width_idx;
DATA_TYPE max_value = -FLT_MAX;
......
......@@ -9,17 +9,25 @@ __kernel void space_to_batch(__read_only image2d_t space_data,
__private const int space_height,
__private const int space_width,
__private const int batch_height,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int batch_width,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const int batch_width) {
#endif
const int chan_idx = get_global_id(0);
const int batch_w_idx = get_global_id(1);
const int batch_hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (chan_idx >= global_size_dim0 || batch_w_idx >= global_size_dim1
|| batch_hb_idx >= global_size_dim2) {
return;
}
#endif
const int batch_b_idx = batch_hb_idx / batch_height;
const int batch_h_idx = batch_hb_idx % batch_height;
......@@ -55,17 +63,25 @@ __kernel void batch_to_space(__read_only image2d_t batch_data,
__private const int space_height,
__private const int space_width,
__private const int batch_height,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int batch_width,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const int batch_width) {
#endif
const int chan_idx = get_global_id(0);
const int batch_w_idx = get_global_id(1);
const int batch_hb_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (chan_idx >= global_size_dim0 || batch_w_idx >= global_size_dim1
|| batch_hb_idx >= global_size_dim2) {
return;
}
#endif
const int batch_b_idx = batch_hb_idx / batch_height;
const int batch_h_idx = batch_hb_idx % batch_height;
......
......@@ -8,16 +8,25 @@ __kernel void winograd_transform_2x2(__read_only image2d_t input,
__private const int round_hw,
__private const int round_w,
__private const int padding_top,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const int padding_left,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__private const int padding_left) {
#endif
int out_width_idx = get_global_id(0);
int chan_blk_idx = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (out_width_idx >= global_size_dim0 || chan_blk_idx >= global_size_dim1) {
return;
}
const int chan_blk_size = global_size_dim1;
#else
const int chan_blk_size = get_global_size(1);
#endif
const int batch_idx = out_width_idx / round_hw;
const int t_idx = out_width_idx % round_hw;
......@@ -121,16 +130,26 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
__private const int out_width,
__private const int round_hw,
__private const int round_w,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const float relux_max_limit,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__private const float relux_max_limit) {
#endif
const int width_idx = get_global_id(0);
const int height_idx = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (width_idx >= global_size_dim0 || height_idx >= global_size_dim1) {
return;
}
const int out_channel = global_size_dim1;
#else
const int out_channel = get_global_size(1);
#endif
int width = width_idx;
int height = height_idx;
......
......@@ -31,10 +31,15 @@ static void Concat2(cl::Kernel *kernel,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("concat_channel");
built_options.emplace("-Dconcat_channel=" + kernel_name);
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (input0->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
......@@ -83,12 +88,18 @@ static void ConcatN(cl::Kernel *kernel,
const index_t channel = output->dim(3);
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("concat_channel_multi");
built_options.emplace("-Dconcat_channel_multi=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
*kernel = runtime->BuildKernel("concat", kernel_name, built_options);
}
......
......@@ -37,6 +37,9 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
MACE_CHECK(input_batch == batch);
......@@ -45,6 +48,9 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
built_options.emplace("-Dconv_2d_1x1=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
......
......@@ -37,12 +37,17 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d_3x3");
built_options.emplace("-Dconv_2d_3x3=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
switch (activation) {
case NOOP:
......
......@@ -37,12 +37,17 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d");
built_options.emplace("-Dconv_2d=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
switch (activation) {
case NOOP:
......
......@@ -42,6 +42,8 @@ void DepthwiseConv2d(cl::Kernel *kernel,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depthwise_conv2d");
......@@ -51,6 +53,9 @@ void DepthwiseConv2d(cl::Kernel *kernel,
} else {
built_options.emplace("-Ddepthwise_conv2d=" + kernel_name);
}
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
......
......@@ -29,6 +29,8 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -37,6 +39,9 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DELTWISE_TYPE=", type_));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (!coeff_.empty()) built_options.emplace("-DCOEFF_SUM");
kernel_ = runtime->BuildKernel("eltwise", kernel_name, built_options);
}
......
......@@ -24,8 +24,11 @@ void FCWXKernel(cl::Kernel *kernel,
<< "FC width kernel only support input with 4x channel.";
MACE_CHECK_NOTNULL(gws);
MACE_CHECK_NOTNULL(lws);
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -34,6 +37,9 @@ void FCWXKernel(cl::Kernel *kernel,
built_options.emplace("-Dfully_connected_width=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
......@@ -133,14 +139,21 @@ void FCWTXKernel(cl::Kernel *kernel,
StatsFuture *future) {
MACE_CHECK_NOTNULL(gws);
MACE_CHECK_NOTNULL(lws);
if (kernel->get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fully_connected");
built_options.emplace("-Dfully_connected=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
......
......@@ -194,12 +194,25 @@ std::string DtToUpstreamCLCMDDt(const DataType dt) {
}
}
// Tells whether the global OpenCL runtime is a Qualcomm Adreno GPU whose
// OpenCL version string is exactly "2.0".
//
// Returns true only when both conditions hold; any other GPU type or any
// other version string yields false.
const bool IsQualcommOpenCL200() {
  auto opencl_runtime = OpenCLRuntime::Global();
  // Guard clause: the version is only queried for Adreno devices.
  if (opencl_runtime->GetGPUType() != GPU_TYPE::QUALCOMM_ADRENO) {
    return false;
  }
  return opencl_runtime->GetOpenclVersion() == "2.0";
}
void TuningOrRun3DKernel(const cl::Kernel &kernel,
const std::string tuning_key,
const uint32_t *gws,
const std::vector<uint32_t> &lws,
StatsFuture *future) {
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
const uint32_t kwg_size =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel));
......@@ -236,9 +249,11 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel,
<< "Tuning parameters of 3D kernel must be 4D";
cl_int error = CL_SUCCESS;
std::vector<uint32_t> roundup_gws(3);
if(!is_qualcomm_opencl200) {
for (size_t i = 0; i < 3; ++i) {
roundup_gws[i] = RoundUp(gws[i], params[i]);
}
}
if (timer == nullptr) {
uint32_t num_blocks = params[3];
......@@ -247,18 +262,31 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel,
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws2 =
(i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, 0, i * block_size),
cl::NDRange(gws[0], gws[1], gws2),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
} else {
uint32_t roundup_gws2 = RoundUp(gws2, params[2]);
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, 0, i * block_size),
cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws2),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
}
} else {
timer->ClearTiming();
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
} else {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
tuning_result->assign(params.begin(), params.end());
......@@ -274,11 +302,18 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel,
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws2 =
(i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, 0, i * block_size),
cl::NDRange(gws[0], gws[1], gws2),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
} else {
uint32_t roundup_gws2 = RoundUp(gws2, params[2]);
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, 0, i * block_size),
cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws2),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
}
......@@ -306,6 +341,8 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
const std::vector<uint32_t> &lws,
StatsFuture *future) {
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
const uint32_t kwg_size =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel));
......@@ -330,9 +367,11 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
<< "Tuning parameters of 2D kernel must be 3d";
cl_int error = CL_SUCCESS;
std::vector<uint32_t> roundup_gws(2);
if (!is_qualcomm_opencl200) {
for (size_t i = 0; i < 2; ++i) {
roundup_gws[i] = RoundUp(gws[i], params[i]);
}
}
if (timer == nullptr) {
uint32_t num_blocks = params[2];
......@@ -341,17 +380,29 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws1 =
(i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, i * block_size), cl::NDRange(gws[0], gws1),
cl::NDRange(params[0], params[1]), nullptr, &event);
} else {
uint32_t roundup_gws1 = RoundUp(gws1, params[1]);
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, i * block_size), cl::NDRange(roundup_gws[0], roundup_gws1),
cl::NDRange(params[0], params[1]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
}
} else {
timer->ClearTiming();
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NullRange, cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]), nullptr, &event);
} else {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1]),
cl::NDRange(params[0], params[1]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
tuning_result->assign(params.begin(), params.end());
......@@ -367,10 +418,16 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws1 =
(i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
if (is_qualcomm_opencl200) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, i * block_size), cl::NDRange(gws[0], gws1),
cl::NDRange(params[0], params[1]), nullptr, &event);
} else {
uint32_t roundup_gws1 = RoundUp(gws1, params[1]);
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NDRange(0, i * block_size), cl::NDRange(roundup_gws[0], roundup_gws1),
cl::NDRange(params[0], params[1]), nullptr, &event);
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
}
......
......@@ -102,6 +102,8 @@ std::string Concat(Args... args) {
return ss.str();
}
// Returns true when the global OpenCLRuntime reports a Qualcomm Adreno GPU
// and an OpenCL version equal to "2.0"; returns false otherwise.
const bool IsQualcommOpenCL200();
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_OPENCL_HELPER_H_
......@@ -33,6 +33,8 @@ void MatMulFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *A,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -40,6 +42,9 @@ void MatMulFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *A,
built_options.emplace("-Dmatmul=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("matmul", kernel_name, built_options);
}
uint32_t idx = 0;
......
......@@ -20,11 +20,14 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
const DataType dt = DataTypeToEnum<T>::value;
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("pooling");
built_options.emplace("-Dpooling=" + kernel_name);
if (pooling_type_ == MAX && input->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
......@@ -36,6 +39,9 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
if (pooling_type_ == AVG) {
built_options.emplace("-DPOOL_AVG");
}
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("pooling", kernel_name, built_options);
}
......
......@@ -30,6 +30,8 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bilinear_nocache");
......@@ -37,6 +39,9 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ =
runtime->BuildKernel("resize_bilinear", kernel_name, built_options);
}
......
......@@ -31,6 +31,8 @@ void SliceFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice");
......@@ -38,6 +40,9 @@ void SliceFunctor<DeviceType::OPENCL, T>::operator()(
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE="
+ DtToCLCMDDt(DataTypeToEnum<T>::value));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("slice", kernel_name, built_options);
}
const index_t channel_blk = RoundUpDiv4(output_channels);
......
......@@ -28,6 +28,9 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
static_cast<uint32_t>(height * batch)};
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("softmax");
......@@ -35,6 +38,9 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("softmax", kernel_name, built_options);
}
if (!IsVecEqual(input_shape_, logits->shape())) {
......
......@@ -38,6 +38,8 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
std::set<std::string> built_options;
......@@ -47,6 +49,9 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" +
DtToCLCMDDt(DataTypeToEnum<T>::value));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ =
runtime->BuildKernel("space_to_batch", kernel_name, built_options);
}
......
......@@ -17,6 +17,8 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::string obfuscated_kernel_name =
MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2");
......@@ -26,6 +28,9 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(
DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" +
DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name,
built_options);
}
......@@ -90,6 +95,8 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel_.get() == nullptr) {
std::string obfuscated_kernel_name =
MACE_OBFUSCATE_SYMBOL("winograd_inverse_transform_2x2");
......@@ -100,6 +107,9 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(
DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" +
DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
switch (activation_) {
case NOOP:
......
......@@ -18,8 +18,8 @@ BAZEL_BIN_PATH=${BAZEL_BIN_PATH#//}
BAZEL_BIN_PATH=bazel-bin/$BAZEL_BIN_PATH
BIN_NAME=`echo $BAZEL_TARGET | cut -d: -f2`
ANDROID_ABI=armeabi-v7a
ANDROID_ABI=arm64-v8a
ANDROID_ABI=armeabi-v7a
STRIP="--strip always"
VLOG_LEVEL=0
PROFILING="1"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册