Commit 3ff4f62b authored by Liangliang He

Merge branch 'mali-tuning' into 'master'

New OpenCL kernel time-limit strategy to support OpenCL 1.1/1.2.

See merge request !405
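This merge changes what the last element of each tuning-parameter vector (`lws`) means: it used to be a block count (fixed at 1) and is now a block size along the outermost global-work-size dimension, with 0 meaning "run the whole range as a single block". Kernels whose measured time exceeds `kMaxKernelExeTime` are split into several smaller enqueues, which keeps each command's execution time bounded (presumably to stay responsive on devices such as Mali, per the branch name), and the splitting also works on OpenCL 1.1/1.2, where non-uniform work-groups are unavailable and the global size must first be rounded up to a multiple of the local size. Below is a minimal self-contained sketch of the splitting arithmetic, assuming `RoundUp`/`RoundUpDiv` are the usual ceiling helpers; the free-standing `SplitBlocks` function and its signature are illustrative, not MACE's exact API:

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative ceiling helpers (assumed to match MACE's RoundUp/RoundUpDiv).
template <typename T>
T RoundUp(T value, T multiple) {            // smallest multiple of `multiple` >= value
  return (value + multiple - 1) / multiple * multiple;
}
template <typename T>
T RoundUpDiv(T numerator, T denominator) {  // ceiling division
  return (numerator + denominator - 1) / denominator;
}

// Hypothetical helper: compute the {offset, size} pairs enqueued along the
// outermost gws dimension. block_size_param is the tuning vector's last
// element; 0 means "one block covering everything".
std::vector<std::pair<uint32_t, uint32_t>> SplitBlocks(
    uint32_t gws_outer, uint32_t lws_outer, uint32_t block_size_param,
    bool non_uniform_supported) {
  // OpenCL 1.1/1.2 require the global size to be a multiple of the local size.
  const uint32_t internal_gws =
      non_uniform_supported ? gws_outer : RoundUp(gws_outer, lws_outer);
  const uint32_t block_size =
      block_size_param == 0 ? internal_gws : block_size_param;
  const uint32_t num_blocks = RoundUpDiv(internal_gws, block_size);
  std::vector<std::pair<uint32_t, uint32_t>> blocks;
  for (uint32_t i = 0; i < num_blocks; ++i) {
    uint32_t size = block_size;
    // With non-uniform work-groups the last block shrinks to the exact
    // remainder; otherwise the rounded-up size keeps work-groups uniform.
    if (non_uniform_supported && i == num_blocks - 1) {
      size = internal_gws - i * block_size;
    }
    blocks.emplace_back(i * block_size, size);
  }
  return blocks;
}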
......@@ -110,7 +110,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
input_shape_ = input->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::string tuning_key =
Concat(tuning_key_prefix_, output->dim(0), output->dim(1), output->dim(2),
output->dim(3));
......
......@@ -105,7 +105,7 @@ void AddNFunctor<DeviceType::OPENCL, T>::operator()(
input_shape_ = input_tensors[0]->shape();
}
- const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 1};
+ const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 0};
std::stringstream ss;
ss << "addn_opencl_kernel_" << output_shape[0] << "_" << output_shape[1]
<< "_" << output_shape[2] << "_" << output_shape[3];
......
......@@ -116,7 +116,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
input_shape_ = input->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::string tuning_key =
Concat("batch_norm_opencl_kernel_", activation_, output->dim(0),
output->dim(1), output->dim(2), output->dim(3), folded_constant_);
......
......@@ -90,7 +90,7 @@ void ChannelShuffleFunctor<DeviceType::OPENCL, T>::operator()(
input_shape_ = input->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::stringstream ss;
ss << "channel_shuffle_opencl_kernel_"
<< output->dim(0) << "_"
......
......@@ -95,7 +95,7 @@ static void Concat2(cl::Kernel *kernel,
*prev_input_shape = input0->shape();
}
- const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 0};
std::stringstream ss;
ss << "concat_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
......
......@@ -130,7 +130,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
*prev_input_shape = input->shape();
}
- const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 0};
std::string tuning_key =
Concat("conv2d_1x1_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
......
......@@ -128,7 +128,7 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
*prev_input_shape = input->shape();
}
- const std::vector<uint32_t> lws = {4, *kwg_size / 32, 8, 1};
+ const std::vector<uint32_t> lws = {4, *kwg_size / 32, 8, 0};
std::string tuning_key =
Concat("conv2d_3x3_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
......
......@@ -130,7 +130,7 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
*prev_input_shape = input->shape();
}
- const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 0};
std::string tuning_key =
Concat("conv2d_general_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
......
......@@ -76,7 +76,7 @@ void CWiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
input_shape_ = input->shape();
}
- const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 1};
+ const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 0};
std::stringstream ss;
ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
......
......@@ -134,7 +134,7 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
input_shape_ = input->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
......
......@@ -149,7 +149,7 @@ static void DepthwiseConv2d(cl::Kernel *kernel,
*prev_input_shape = input->shape();
}
- const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 0};
std::string tuning_key = Concat("depthwise_conv2d_ocl_kernel_", activation,
batch, height, width, channels, multiplier);
TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);
......
......@@ -85,7 +85,7 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
input_shape_ = input0->shape();
}
- const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 1};
+ const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 0};
std::stringstream ss;
ss << "eltwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
......
......@@ -233,7 +233,7 @@ void FCWTXKernel(cl::Kernel *kernel,
uint32_t kwg_size =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(*kernel));
- *lws = {16, kwg_size/16, 1};
+ *lws = {16, kwg_size/16, 0};
}
if (!IsVecEqual(*prev_input_shape, input->shape())) {
const index_t batch = output->dim(0);
......
......@@ -223,23 +223,23 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel,
std::min<uint32_t>(gws[2], kwg_size / (local_ws[0] * local_ws[1]));
return {
// TODO(heliangliang): tuning these magic numbers
- {local_ws[0], local_ws[1], local_ws[2], 1},
- {kwg_size / 16, 4, 4, 1},
- {kwg_size / 32, 4, 8, 1},
- {kwg_size / 32, 8, 4, 1},
- {kwg_size / 64, 8, 8, 1},
- {kwg_size / 64, 16, 4, 1},
- {kwg_size / 128, 8, 16, 1},
- {kwg_size / 128, 16, 8, 1},
- {kwg_size / 128, 32, 4, 1},
- {1, kwg_size / 32, 32, 1},
- {1, kwg_size / 64, 64, 1},
- {1, kwg_size / 128, 128, 1},
- {4, kwg_size / 16, 4, 1},
- {4, kwg_size / 28, 7, 1},
- {4, kwg_size / 32, 8, 1},
- {4, kwg_size / 56, 14, 1},
- {1, kwg_size, 1, 1},
+ {local_ws[0], local_ws[1], local_ws[2], 0},
+ {kwg_size / 16, 4, 4, 0},
+ {kwg_size / 32, 4, 8, 0},
+ {kwg_size / 32, 8, 4, 0},
+ {kwg_size / 64, 8, 8, 0},
+ {kwg_size / 64, 16, 4, 0},
+ {kwg_size / 128, 8, 16, 0},
+ {kwg_size / 128, 16, 8, 0},
+ {kwg_size / 128, 32, 4, 0},
+ {1, kwg_size / 32, 32, 0},
+ {1, kwg_size / 64, 64, 0},
+ {1, kwg_size / 128, 128, 0},
+ {4, kwg_size / 16, 4, 0},
+ {4, kwg_size / 28, 7, 0},
+ {4, kwg_size / 32, 8, 0},
+ {4, kwg_size / 56, 14, 0},
+ {1, kwg_size, 1, 0},
};
};
cl::Event event;
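Note that every candidate local-work-size in the table above now ends in 0: a candidate no longer pre-commits to a block count. The tuner times each candidate as a single enqueue first and derives a block size from the measurement afterwards (see the timed path below).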
......@@ -248,46 +248,35 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel,
MACE_CHECK(params.size() == 4)
<< "Tuning parameters of 3D kernel must be 4D";
cl_int error = CL_SUCCESS;
- std::vector<uint32_t> roundup_gws(3);
+ std::vector<uint32_t> internal_gws(gws, gws+3);
if (!runtime->IsNonUniformWorkgroupsSupported()) {
for (size_t i = 0; i < 3; ++i) {
- roundup_gws[i] = RoundUp(gws[i], params[i]);
+ internal_gws[i] = RoundUp(gws[i], params[i]);
}
}
if (timer == nullptr) {
- uint32_t num_blocks = params[3];
- const uint32_t block_size = gws[2] / num_blocks;
- if (gws[2] % num_blocks > 0) num_blocks++;
+ uint32_t block_size = params[3] == 0 ? internal_gws[2] : params[3];
+ const uint32_t num_blocks = RoundUpDiv<uint32_t>(internal_gws[2],
+ block_size);
for (uint32_t i = 0; i < num_blocks; ++i) {
- uint32_t gws2 =
- (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
- if (runtime->IsNonUniformWorkgroupsSupported()) {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, 0, i * block_size),
- cl::NDRange(gws[0], gws[1], gws2),
- cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
- } else {
- uint32_t roundup_gws2 = RoundUp(gws2, params[2]);
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, 0, i * block_size),
- cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws2),
- cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
+ uint32_t gws2 = block_size;
+ if (runtime->IsNonUniformWorkgroupsSupported()
+ && (i == num_blocks - 1)) {
+ gws2 = (internal_gws[2] - (i * block_size));
}
+ error = runtime->command_queue().enqueueNDRangeKernel(
+ kernel, cl::NDRange(0, 0, i * block_size),
+ cl::NDRange(internal_gws[0], internal_gws[1], gws2),
+ cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK_CL_SUCCESS(error);
}
} else {
timer->ClearTiming();
- if (runtime->IsNonUniformWorkgroupsSupported()) {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
- cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
- } else {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NullRange,
- cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
- cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
- }
+ error = runtime->command_queue().enqueueNDRangeKernel(
+ kernel, cl::NullRange,
+ cl::NDRange(internal_gws[0], internal_gws[1], internal_gws[2]),
+ cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK_CL_SUCCESS(error);
timer->AccumulateTiming();
tuning_result->assign(params.begin(), params.end());
......@@ -297,24 +286,22 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel,
timer->ClearTiming();
uint32_t num_blocks = std::min(
static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[2]);
- (*tuning_result)[3] = num_blocks;
- const uint32_t block_size = gws[2] / num_blocks;
- if (gws[2] % num_blocks > 0) num_blocks++;
+ uint32_t block_size = gws[2] / num_blocks;
+ if (!runtime->IsNonUniformWorkgroupsSupported()) {
+ block_size = RoundUp(block_size, params[2]);
+ }
+ (*tuning_result)[3] = block_size;
+ num_blocks = RoundUpDiv<uint32_t>(internal_gws[2], block_size);
for (uint32_t i = 0; i < num_blocks; ++i) {
- uint32_t gws2 =
- (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
- if (runtime->IsNonUniformWorkgroupsSupported()) {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, 0, i * block_size),
- cl::NDRange(gws[0], gws[1], gws2),
- cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
- } else {
- uint32_t roundup_gws2 = RoundUp(gws2, params[2]);
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, 0, i * block_size),
- cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws2),
- cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
+ uint32_t gws2 = block_size;
+ if (runtime->IsNonUniformWorkgroupsSupported()
+ && (i == num_blocks - 1)) {
+ gws2 = (internal_gws[2] - (i * block_size));
}
+ error = runtime->command_queue().enqueueNDRangeKernel(
+ kernel, cl::NDRange(0, 0, i * block_size),
+ cl::NDRange(internal_gws[0], internal_gws[1], gws2),
+ cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK_CL_SUCCESS(error);
timer->AccumulateTiming();
}
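In the timed path above, the block size is derived from the measurement: `elapse_time / kMaxKernelExeTime + 1` gives the number of blocks needed to keep each enqueue under the time limit, `gws[2] / num_blocks` gives the corresponding block size, and that size is rounded up to a multiple of `params[2]` when non-uniform work-groups are unsupported, so every block still divides evenly into work-groups. Storing the block size itself (rather than the old block count) in `tuning_result[3]` is what lets the non-timed path replay the split with a plain `RoundUpDiv`.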
......@@ -349,16 +336,16 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
uint32_t local_ws[2];
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
- return {{local_ws[0], local_ws[1], 1},
- {local_ws[1], local_ws[0], 1},
- {kwg_size / 4, 4, 1},
- {kwg_size / 16, 16, 1},
- {kwg_size / 32, 32, 1},
- {kwg_size / 64, 64, 1},
- {kwg_size / 128, 128, 1},
- {kwg_size / 256, 256, 1},
- {kwg_size, 1, 1},
- {1, kwg_size, 1}};
+ return {{local_ws[0], local_ws[1], 0},
+ {local_ws[1], local_ws[0], 0},
+ {kwg_size / 4, 4, 0},
+ {kwg_size / 16, 16, 0},
+ {kwg_size / 32, 32, 0},
+ {kwg_size / 64, 64, 0},
+ {kwg_size / 128, 128, 0},
+ {kwg_size / 256, 256, 0},
+ {kwg_size, 1, 0},
+ {1, kwg_size, 0}};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t> &params, Timer *timer,
......@@ -366,44 +353,34 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
MACE_CHECK(params.size() == 3)
<< "Tuning parameters of 2D kernel must be 3d";
cl_int error = CL_SUCCESS;
- std::vector<uint32_t> roundup_gws(2);
+ std::vector<uint32_t> internal_gws(gws, gws+2);
if (!runtime->IsNonUniformWorkgroupsSupported()) {
for (size_t i = 0; i < 2; ++i) {
- roundup_gws[i] = RoundUp(gws[i], params[i]);
+ internal_gws[i] = RoundUp(gws[i], params[i]);
}
}
if (timer == nullptr) {
- uint32_t num_blocks = params[2];
- const uint32_t block_size = gws[1] / num_blocks;
- if (gws[1] % num_blocks > 0) num_blocks++;
+ uint32_t block_size = params[2] == 0 ? internal_gws[1] : params[2];
+ const uint32_t num_blocks = RoundUpDiv<uint32_t>(internal_gws[1],
+ block_size);
for (uint32_t i = 0; i < num_blocks; ++i) {
- uint32_t gws1 =
- (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
- if (runtime->IsNonUniformWorkgroupsSupported()) {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, i * block_size), cl::NDRange(gws[0], gws1),
- cl::NDRange(params[0], params[1]), nullptr, &event);
- } else {
- uint32_t roundup_gws1 = RoundUp(gws1, params[1]);
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, i * block_size),
- cl::NDRange(roundup_gws[0], roundup_gws1),
- cl::NDRange(params[0], params[1]), nullptr, &event);
+ uint32_t gws1 = block_size;
+ if (runtime->IsNonUniformWorkgroupsSupported()
+ && (i == num_blocks - 1)) {
+ gws1 = (internal_gws[1] - (i * block_size));
}
+ error = runtime->command_queue().enqueueNDRangeKernel(
+ kernel, cl::NDRange(0, i * block_size),
+ cl::NDRange(internal_gws[0], gws1),
+ cl::NDRange(params[0], params[1]), nullptr, &event);
MACE_CHECK_CL_SUCCESS(error);
}
} else {
timer->ClearTiming();
- if (runtime->IsNonUniformWorkgroupsSupported()) {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NullRange, cl::NDRange(gws[0], gws[1]),
- cl::NDRange(params[0], params[1]), nullptr, &event);
- } else {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1]),
- cl::NDRange(params[0], params[1]), nullptr, &event);
- }
+ error = runtime->command_queue().enqueueNDRangeKernel(
+ kernel, cl::NullRange, cl::NDRange(internal_gws[0], internal_gws[1]),
+ cl::NDRange(params[0], params[1]), nullptr, &event);
MACE_CHECK_CL_SUCCESS(error);
timer->AccumulateTiming();
tuning_result->assign(params.begin(), params.end());
......@@ -413,24 +390,22 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel,
timer->ClearTiming();
uint32_t num_blocks = std::min(
static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[1]);
- (*tuning_result)[2] = num_blocks;
- const uint32_t block_size = gws[1] / num_blocks;
- if (gws[1] % num_blocks > 0) num_blocks++;
+ uint32_t block_size = gws[1] / num_blocks;
+ if (!runtime->IsNonUniformWorkgroupsSupported()) {
+ block_size = RoundUp(block_size, params[1]);
+ }
+ (*tuning_result)[2] = block_size;
+ num_blocks = RoundUpDiv<uint32_t>(internal_gws[1], block_size);
for (uint32_t i = 0; i < num_blocks; ++i) {
- uint32_t gws1 =
- (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
- if (runtime->IsNonUniformWorkgroupsSupported()) {
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, i * block_size),
- cl::NDRange(gws[0], gws1), cl::NDRange(params[0], params[1]),
- nullptr, &event);
- } else {
- uint32_t roundup_gws1 = RoundUp(gws1, params[1]);
- error = runtime->command_queue().enqueueNDRangeKernel(
- kernel, cl::NDRange(0, i * block_size),
- cl::NDRange(roundup_gws[0], roundup_gws1),
- cl::NDRange(params[0], params[1]), nullptr, &event);
+ uint32_t gws1 = block_size;
+ if (runtime->IsNonUniformWorkgroupsSupported()
+ && (i == num_blocks - 1)) {
+ gws1 = (internal_gws[1] - (i * block_size));
}
+ error = runtime->command_queue().enqueueNDRangeKernel(
+ kernel, cl::NDRange(0, i * block_size),
+ cl::NDRange(internal_gws[0], gws1),
+ cl::NDRange(params[0], params[1]), nullptr, &event);
MACE_CHECK_CL_SUCCESS(error);
timer->AccumulateTiming();
}
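The 2D kernel path mirrors the 3D one: the split runs along `gws[1]`, and the derived block size is persisted as the third tuning parameter, so a stored 0 (the new default in the candidate tables) simply means no splitting.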
......
......@@ -84,7 +84,7 @@ void MatMulFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *A,
kernel_.setArg(idx++, static_cast<int>(height_blocks));
kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(A->dim(2))));
- const std::vector<uint32_t> lws = {kwg_size_ / 64, 64, 1};
+ const std::vector<uint32_t> lws = {kwg_size_ / 64, 64, 0};
std::stringstream ss;
ss << "matmul_opencl_kernel_" << C->dim(0) << "_" << C->dim(1) << "_"
<< C->dim(2) << "_" << C->dim(3);
......
......@@ -100,7 +100,7 @@ void PadFunctor<DeviceType::OPENCL, T>::operator()(
input_shape_ = input->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::string tuning_key =
Concat("pad", output->dim(0), output->dim(1), output->dim(2),
output->dim(3));
......
......@@ -134,7 +134,7 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
};
}
- std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::stringstream ss;
ss << "pooling_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
......
......@@ -99,7 +99,7 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
input_shape_ = input->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::stringstream ss;
ss << "resize_bilinear_opencl_kernel_" << output->dim(0) << "_"
<< output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3);
......
......@@ -81,7 +81,7 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
input_shape_ = logits->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::stringstream ss;
ss << "softmax_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
......
......@@ -105,7 +105,7 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(
space_shape_ = space_tensor->shape();
}
- const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 1};
+ const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::stringstream ss;
ss << kernel_name << "_" << batch_tensor->dim(0) << "_"
<< batch_tensor->dim(1) << "_" << batch_tensor->dim(2) << "_"
......
......@@ -54,7 +54,7 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
std::vector<index_t> output_shape(4);
- std::vector<index_t> filter_shape = {3, 3, input_tensor->dim(3), 1};
+ std::vector<index_t> filter_shape = {3, 3, 1, input_tensor->dim(3)};
std::vector<int> paddings(2);
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
......@@ -101,7 +101,7 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(
input_shape_ = input_tensor->shape();
}
- const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 1};
+ const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 0};
std::stringstream ss;
ss << "winograd_transform_kernel_" << input_tensor->dim(0) << "_"
<< input_tensor->dim(1) << "_" << input_tensor->dim(2) << "_"
......@@ -215,7 +215,7 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(
input_shape_ = input_tensor->shape();
}
- const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 1};
+ const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 0};
std::stringstream ss;
ss << "winograd_inverse_transform_kernel_" << input_tensor->dim(0) << "_"
......
......@@ -559,6 +559,7 @@ class CaffeConverter(object):
paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None)
filter_shape = np.asarray(op.data[0].shape)
filter_shape = filter_shape[[2, 3, 0, 1]] # OIHW -> HWOI
+ input_format = 'NHWC'
output_shape = Shapes.conv_pool_shape(
......
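The last two hunks travel together: the Winograd transform's reference filter shape swaps its trailing dimensions to `{3, 3, 1, input_tensor->dim(3)}`, which appears to follow the converter-side switch to HWOI filter ordering (per the `# OIHW -> HWOI` comment), and the converter now also records `input_format = 'NHWC'` for its shape inference.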