提交 44d4903d 编写于 作者: Y yejianwu

Remove OpenCL 1.1/1.2 compatibility code for fully_connected

上级 bdd6ff45
......@@ -10,23 +10,9 @@ __kernel void fully_connected(__read_only image2d_t input,
__private const int input_height,
__private const int input_width,
__private const int input_channel,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const float relux_max_limit,
__private const int global_size_dim0,
__private const int global_size_dim1) {
#else
__private const float relux_max_limit) {
#endif
const int batch_idx = get_global_id(0);
const int out_blk_idx = get_global_id(1);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (batch_idx >= global_size_dim0 || out_blk_idx >= global_size_dim1) {
return;
}
#endif
const int input_chan_blk = (input_channel + 3) >> 2;
float4 input_value;
......@@ -82,29 +68,11 @@ __kernel void fully_connected_width(__read_only image2d_t input,
__private const int input_width,
__private const int in_chan_blks,
__private const int out_blks,
#ifndef USE_QUALCOMM_OPENCL_2_0
__private const float relux_max_limit,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2) {
#else
__private const float relux_max_limit) {
#endif
const int inter_out_idx = get_global_id(0);
const int width_blk_idx = get_global_id(1);
const int batch_out_blk_idx = get_global_id(2);
#ifndef USE_QUALCOMM_OPENCL_2_0
if (inter_out_idx >= global_size_dim0 || width_blk_idx >= global_size_dim1
|| batch_out_blk_idx >= global_size_dim2) {
return;
}
const int width_blk_count = global_size_dim1;
#else
const int width_blk_count = get_global_size(1);
#endif
const int batch_out_blk_idx = get_global_id(2);
const int batch_idx = batch_out_blk_idx / out_blks;
const int out_blk_idx = batch_out_blk_idx % out_blks;
......
......@@ -24,11 +24,8 @@ void FCWXKernel(cl::Kernel *kernel,
<< "FC width kernel only support input with 4x channel.";
MACE_CHECK_NOTNULL(gws);
MACE_CHECK_NOTNULL(lws);
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -37,9 +34,6 @@ void FCWXKernel(cl::Kernel *kernel,
built_options.emplace("-Dfully_connected_width=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
......@@ -81,7 +75,6 @@ void FCWXKernel(cl::Kernel *kernel,
if (!IsVecEqual(*prev_input_shape, input->shape())) {
const index_t batch = output->dim(0);
const index_t output_blocks = RoundUpDiv4(output->dim(3));
(*gws)[2] = static_cast<uint32_t>(batch * output_blocks);
uint32_t idx = 0;
kernel->setArg(idx++, *(input->opencl_image()));
......@@ -97,22 +90,14 @@ void FCWXKernel(cl::Kernel *kernel,
kernel->setArg(idx++, static_cast<int>(RoundUpDiv4(input->dim(3))));
kernel->setArg(idx++, static_cast<int>(output_blocks));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, (*gws)[0]);
kernel->setArg(idx++, (*gws)[1]);
kernel->setArg(idx++, (*gws)[2]);
*prev_input_shape = input->shape();
}
(*gws)[2] = static_cast<uint32_t>(batch * output_blocks);
std::vector<uint32_t> roundup_gws(lws->size());
for (size_t i = 0; i < lws->size(); ++i) {
roundup_gws[i] = RoundUp((*gws)[i], (*lws)[i]);
*prev_input_shape = input->shape();
}
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
*kernel, cl::NullRange,
cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
*kernel, cl::NullRange, cl::NDRange((*gws)[0], (*gws)[1], (*gws)[2]),
cl::NDRange((*lws)[0], (*lws)[1], (*lws)[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
......@@ -140,21 +125,14 @@ void FCWTXKernel(cl::Kernel *kernel,
StatsFuture *future) {
MACE_CHECK_NOTNULL(gws);
MACE_CHECK_NOTNULL(lws);
auto runtime = OpenCLRuntime::Global();
const bool is_qualcomm_opencl200 = IsQualcommOpenCL200();
if (kernel->get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fully_connected");
built_options.emplace("-Dfully_connected=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (is_qualcomm_opencl200) {
built_options.emplace("-DUSE_QUALCOMM_OPENCL_2_0");
}
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
......@@ -183,13 +161,6 @@ void FCWTXKernel(cl::Kernel *kernel,
}
if (!IsVecEqual(*prev_input_shape, input->shape())) {
uint32_t idx = 0;
const index_t batch = output->dim(0);
const index_t output_blocks = RoundUpDiv4(output->dim(3));
*gws = {
static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
};
kernel->setArg(idx++, *(input->opencl_image()));
kernel->setArg(idx++, *(weight->opencl_image()));
if (bias != nullptr) {
......@@ -201,9 +172,13 @@ void FCWTXKernel(cl::Kernel *kernel,
kernel->setArg(idx++, static_cast<int>(input->dim(3)));
// FIXME handle flexible data type: half not supported
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, (*gws)[0]);
kernel->setArg(idx++, (*gws)[1]);
const index_t batch = output->dim(0);
const index_t output_blocks = RoundUpDiv4(output->dim(3));
*gws = {
static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
};
*prev_input_shape = input->shape();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册