提交 0fdea60f 编写于 作者: 刘托

Merge branch 'gpu-lws-bug' into 'master'

Fix opencl default lws calculation bug to support low-end SOCs.

See merge request !596
...@@ -56,7 +56,7 @@ MaceStatus OpenCLAllocator::New(size_t nbytes, void **result) const { ...@@ -56,7 +56,7 @@ MaceStatus OpenCLAllocator::New(size_t nbytes, void **result) const {
nbytes, nullptr, &error); nbytes, nullptr, &error);
if (error != CL_SUCCESS) { if (error != CL_SUCCESS) {
LOG(WARNING) << "Allocate OpenCL Buffer with " LOG(WARNING) << "Allocate OpenCL Buffer with "
<< nbytes << " bytes failed because of" << nbytes << " bytes failed because of "
<< OpenCLErrorToString(error); << OpenCLErrorToString(error);
delete buffer; delete buffer;
*result = nullptr; *result = nullptr;
......
...@@ -371,7 +371,8 @@ OpenCLRuntime::OpenCLRuntime(): ...@@ -371,7 +371,8 @@ OpenCLRuntime::OpenCLRuntime():
} }
cl_int err; cl_int err;
if (gpu_type_ == GPUType::QUALCOMM_ADRENO) { if (gpu_type_ == GPUType::QUALCOMM_ADRENO
&& opencl_version_ == OpenCLVersion::CL_VER_2_0) {
std::vector<cl_context_properties> context_properties; std::vector<cl_context_properties> context_properties;
context_properties.reserve(5); context_properties.reserve(5);
GetAdrenoContextProperties(&context_properties, GetAdrenoContextProperties(&context_properties,
...@@ -698,7 +699,7 @@ uint64_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) { ...@@ -698,7 +699,7 @@ uint64_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) {
bool OpenCLRuntime::IsNonUniformWorkgroupsSupported() const { bool OpenCLRuntime::IsNonUniformWorkgroupsSupported() const {
return (gpu_type_ == GPUType::QUALCOMM_ADRENO && return (gpu_type_ == GPUType::QUALCOMM_ADRENO &&
opencl_version_ == "2.0"); opencl_version_ == OpenCLVersion::CL_VER_2_0);
} }
GPUType OpenCLRuntime::gpu_type() const { GPUType OpenCLRuntime::gpu_type() const {
...@@ -709,13 +710,24 @@ const std::string OpenCLRuntime::platform_info() const { ...@@ -709,13 +710,24 @@ const std::string OpenCLRuntime::platform_info() const {
return platform_info_; return platform_info_;
} }
const std::string OpenCLRuntime::ParseDeviceVersion( OpenCLVersion OpenCLRuntime::ParseDeviceVersion(
const std::string &device_version) { const std::string &device_version) {
// OpenCL Device version string format: // OpenCL Device version string format:
// OpenCL<space><major_version.minor_version><space> // OpenCL<space><major_version.minor_version><space>
// <vendor-specific information> // <vendor-specific information>
auto words = Split(device_version, ' '); auto words = Split(device_version, ' ');
return words[1]; if (words[1] == "2.0") {
return OpenCLVersion::CL_VER_2_0;
} else if (words[1] == "1.2") {
return OpenCLVersion::CL_VER_1_2;
} else if (words[1] == "1.1") {
return OpenCLVersion::CL_VER_1_1;
} else if (words[1] == "1.0") {
return OpenCLVersion::CL_VER_1_0;
} else {
LOG(FATAL) << "Do not support OpenCL version: " << words[1];
return OpenCLVersion::CL_VER_1_0;
}
} }
bool OpenCLRuntime::IsOutOfRangeCheckEnabled() const { bool OpenCLRuntime::IsOutOfRangeCheckEnabled() const {
......
...@@ -38,6 +38,13 @@ enum GPUType { ...@@ -38,6 +38,13 @@ enum GPUType {
UNKNOWN, UNKNOWN,
}; };
enum OpenCLVersion {
CL_VER_1_0,
CL_VER_1_1,
CL_VER_1_2,
CL_VER_2_0,
};
const std::string OpenCLErrorToString(cl_int error); const std::string OpenCLErrorToString(cl_int error);
...@@ -113,7 +120,7 @@ class OpenCLRuntime { ...@@ -113,7 +120,7 @@ class OpenCLRuntime {
const std::string &built_program_key, const std::string &built_program_key,
const std::string &build_options_str, const std::string &build_options_str,
cl::Program *program); cl::Program *program);
const std::string ParseDeviceVersion(const std::string &device_version); OpenCLVersion ParseDeviceVersion(const std::string &device_version);
private: private:
std::unique_ptr<KVStorage> precompiled_binary_storage_; std::unique_ptr<KVStorage> precompiled_binary_storage_;
...@@ -127,7 +134,7 @@ class OpenCLRuntime { ...@@ -127,7 +134,7 @@ class OpenCLRuntime {
std::map<std::string, cl::Program> built_program_map_; std::map<std::string, cl::Program> built_program_map_;
std::mutex program_build_mutex_; std::mutex program_build_mutex_;
std::string platform_info_; std::string platform_info_;
std::string opencl_version_; OpenCLVersion opencl_version_;
std::string precompiled_binary_platform_info_; std::string precompiled_binary_platform_info_;
std::string cached_binary_platform_info_; std::string cached_binary_platform_info_;
bool out_of_range_check_; bool out_of_range_check_;
......
...@@ -118,10 +118,7 @@ MaceStatus BatchNormFunctor<DeviceType::GPU, T>::operator()( ...@@ -118,10 +118,7 @@ MaceStatus BatchNormFunctor<DeviceType::GPU, T>::operator()(
input_shape_ = input->shape(); input_shape_ = input->shape();
} }
std::vector<uint32_t> lws(4, 0); const std::vector<uint32_t> lws = Default3DLocalWS(gws, kwg_size_);
lws[1] = std::min<uint32_t>(gws[1], kwg_size_);
lws[0] = std::min<uint32_t>(4, kwg_size_ / lws[1]);
lws[2] = std::min<uint32_t>(gws[2], kwg_size_ / (lws[1] * lws[0]));
std::string tuning_key = std::string tuning_key =
Concat("batch_norm_opencl_kernel", activation_, output->dim(0), Concat("batch_norm_opencl_kernel", activation_, output->dim(0),
output->dim(1), output->dim(2), output->dim(3), folded_constant_); output->dim(1), output->dim(2), output->dim(3), folded_constant_);
......
...@@ -25,11 +25,11 @@ namespace { ...@@ -25,11 +25,11 @@ namespace {
std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
std::vector<uint32_t> lws(4, 0); std::vector<uint32_t> lws(4, 0);
uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
uint32_t base = cache_size / kBaseGPUMemCacheSize; uint32_t base = std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
lws[0] = std::min<uint32_t>(base, kwg_size / lws[1]); lws[0] = std::min<uint32_t>(base, kwg_size / lws[1]);
const uint32_t lws_size = lws[0] * lws[1]; const uint32_t lws_size = lws[0] * lws[1];
lws[2] = std::min<uint32_t>(base, kwg_size / lws_size); lws[2] = std::max<uint32_t>(std::min<uint32_t>(base, kwg_size / lws_size), 1);
return lws; return lws;
} }
......
...@@ -80,8 +80,8 @@ MaceStatus Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input, ...@@ -80,8 +80,8 @@ MaceStatus Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
std::vector<index_t> *input_shape, Tensor *output, StatsFuture *future, std::vector<index_t> *input_shape, Tensor *output, StatsFuture *future,
uint32_t *kwg_size, std::unique_ptr<BufferBase> *kernel_error); uint32_t *kwg_size, std::unique_ptr<BufferBase> *kernel_error);
// Selection matrix: kernel_size x stride_size // Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5] = { static const Conv2dOpenclFunction selector[3] = {
Conv2dOpenclK1x1, nullptr, Conv2dOpenclK3x3, nullptr, nullptr}; Conv2dOpenclK1x1, nullptr, Conv2dOpenclK3x3};
index_t kernel_h = filter->dim(2); index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3); index_t kernel_w = filter->dim(3);
...@@ -113,7 +113,7 @@ MaceStatus Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input, ...@@ -113,7 +113,7 @@ MaceStatus Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
&output_image_shape); &output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape)); MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
if (kernel_h == kernel_w && kernel_h <= 5 && if (kernel_h == kernel_w && kernel_h <= 3 &&
selector[kernel_h - 1] != nullptr) { selector[kernel_h - 1] != nullptr) {
auto conv2d_func = selector[kernel_h - 1]; auto conv2d_func = selector[kernel_h - 1];
return conv2d_func( return conv2d_func(
......
...@@ -29,7 +29,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { ...@@ -29,7 +29,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
std::vector<uint32_t> lws(4, 0); std::vector<uint32_t> lws(4, 0);
uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
uint32_t compute_units = OpenCLRuntime::Global()->device_compute_units(); uint32_t compute_units = OpenCLRuntime::Global()->device_compute_units();
uint32_t base = cache_size / kBaseGPUMemCacheSize; const uint32_t base =
std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
if (lws[1] >= base) { if (lws[1] >= base) {
lws[0] = std::min<uint32_t>(gws[0], base); lws[0] = std::min<uint32_t>(gws[0], base);
...@@ -48,7 +49,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { ...@@ -48,7 +49,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
if (lws[2] == 0) { if (lws[2] == 0) {
lws[2] = std::min<uint32_t>(gws[2], base); lws[2] = std::min<uint32_t>(gws[2], base);
} }
lws[2] = std::min<uint32_t>(lws[2], kwg_size / lws_size); lws[2] = std::max<uint32_t>(std::min<uint32_t>(lws[2], kwg_size / lws_size),
1);
return lws; return lws;
} }
......
...@@ -30,7 +30,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { ...@@ -30,7 +30,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
uint32_t compute_units = std::max<uint32_t>( uint32_t compute_units = std::max<uint32_t>(
OpenCLRuntime::Global()->device_compute_units() / 2, 1); OpenCLRuntime::Global()->device_compute_units() / 2, 1);
const uint32_t base = const uint32_t base =
std::min<uint32_t>(cache_size / kBaseGPUMemCacheSize, 4); std::max<uint32_t>(
std::min<uint32_t>(cache_size / kBaseGPUMemCacheSize, 4), 1);
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
lws[0] = lws[0] =
std::min<uint32_t>(std::min<uint32_t>(gws[0], base), kwg_size / lws[1]); std::min<uint32_t>(std::min<uint32_t>(gws[0], base), kwg_size / lws[1]);
...@@ -42,7 +43,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { ...@@ -42,7 +43,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
if (lws[2] == 0) { if (lws[2] == 0) {
lws[2] = std::min<uint32_t>(gws[2], base); lws[2] = std::min<uint32_t>(gws[2], base);
} }
lws[2] = std::min<uint32_t>(lws[2], kwg_size / lws_size); lws[2] = std::max<uint32_t>(std::min<uint32_t>(lws[2], kwg_size / lws_size),
1);
return lws; return lws;
} }
......
...@@ -32,7 +32,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, ...@@ -32,7 +32,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws,
std::vector<uint32_t> lws(4, 0); std::vector<uint32_t> lws(4, 0);
uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
uint32_t compute_units = OpenCLRuntime::Global()->device_compute_units(); uint32_t compute_units = OpenCLRuntime::Global()->device_compute_units();
uint32_t base = cache_size / kBaseGPUMemCacheSize; const uint32_t base =
std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
lws[0] = gws[0] / 4; lws[0] = gws[0] / 4;
if (lws[0] == 0) { if (lws[0] == 0) {
...@@ -51,7 +52,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, ...@@ -51,7 +52,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws,
lws[2] = base; lws[2] = base;
} }
} }
lws[2] = std::min<uint32_t>(lws[2], kwg_size / lws_size); lws[2] = std::max<uint32_t>(std::min<uint32_t>(lws[2], kwg_size / lws_size),
1);
return lws; return lws;
} }
......
...@@ -144,7 +144,7 @@ MaceStatus Deconv2dOpencl(cl::Kernel *kernel, ...@@ -144,7 +144,7 @@ MaceStatus Deconv2dOpencl(cl::Kernel *kernel,
*prev_input_shape = input->shape(); *prev_input_shape = input->shape();
} }
const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 0}; const std::vector<uint32_t> lws = Default3DLocalWS(gws, *kwg_size);
std::string tuning_key = std::string tuning_key =
Concat("deconv2d_opencl_kernel_", activation, output->dim(0), Concat("deconv2d_opencl_kernel_", activation, output->dim(0),
output->dim(1), output->dim(2), output->dim(3)); output->dim(1), output->dim(2), output->dim(3));
......
...@@ -27,24 +27,26 @@ const uint32_t kernel_cache_size = (4 + 4 + 1) * 4 * 4; ...@@ -27,24 +27,26 @@ const uint32_t kernel_cache_size = (4 + 4 + 1) * 4 * 4;
std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
std::vector<uint32_t> lws(4, 0); std::vector<uint32_t> lws(4, 0);
uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
uint32_t min_lws0 = cache_size / kBaseGPUMemCacheSize; uint32_t base = cache_size / kBaseGPUMemCacheSize;
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
if (lws[1] >= min_lws0) { if (lws[1] >= base) {
lws[0] = std::min<uint32_t>(gws[0], min_lws0); lws[0] = std::min<uint32_t>(gws[0], base);
} else { } else {
lws[0] = std::min<uint32_t>(gws[0] / 8, kwg_size / lws[1]); lws[0] = std::min<uint32_t>(gws[0] / 8, kwg_size / lws[1]);
if (lws[0] < min_lws0) { if (lws[0] < base) {
lws[0] = std::min<uint32_t>(std::max<uint32_t>(gws[0] / 4, min_lws0), lws[0] = std::min<uint32_t>(std::max<uint32_t>(gws[0] / 4, base),
kwg_size / lws[1]); kwg_size / lws[1]);
} }
} }
lws[0] = std::max<uint32_t>(lws[0], 1);
const uint32_t lws_size = lws[0] * lws[1]; const uint32_t lws_size = lws[0] * lws[1];
lws[2] = std::min<uint32_t>((cache_size / kernel_cache_size / lws_size) * 4, lws[2] = std::min<uint32_t>((cache_size / kernel_cache_size / lws_size) * 4,
gws[2]); gws[2]);
if (lws[2] == 0) { if (lws[2] == 0) {
lws[2] = gws[2]; lws[2] = gws[2];
} }
lws[2] = std::min<uint32_t>(lws[2], kwg_size / lws_size); lws[2] = std::max<uint32_t>(std::min<uint32_t>(lws[2], kwg_size / lws_size),
1);
return lws; return lws;
} }
......
...@@ -252,7 +252,8 @@ std::vector<uint32_t> Default3DLocalWS(const uint32_t *gws, ...@@ -252,7 +252,8 @@ std::vector<uint32_t> Default3DLocalWS(const uint32_t *gws,
lws[2] = lws[2] =
std::min<uint32_t>(std::min<uint32_t>(gws[2], base), kwg_size / lws[1]); std::min<uint32_t>(std::min<uint32_t>(gws[2], base), kwg_size / lws[1]);
const uint32_t lws_size = lws[1] * lws[2]; const uint32_t lws_size = lws[1] * lws[2];
lws[0] = std::min<uint32_t>(base, kwg_size / lws_size); lws[0] = std::max<uint32_t>(std::min<uint32_t>(base, kwg_size / lws_size),
1);
return lws; return lws;
} }
......
...@@ -26,7 +26,7 @@ namespace { ...@@ -26,7 +26,7 @@ namespace {
std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
std::vector<uint32_t> lws(4, 0); std::vector<uint32_t> lws(4, 0);
uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
uint32_t base = cache_size / kBaseGPUMemCacheSize; uint32_t base = std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
lws[2] = lws[2] =
std::min<uint32_t>(std::min<uint32_t>(gws[2], base), kwg_size / lws[1]); std::min<uint32_t>(std::min<uint32_t>(gws[2], base), kwg_size / lws[1]);
...@@ -35,7 +35,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { ...@@ -35,7 +35,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
if (lws[0] == 0) { if (lws[0] == 0) {
lws[0] = gws[0]; lws[0] = gws[0];
} }
lws[0] = std::min<uint32_t>(lws[0], kwg_size / lws_size); lws[0] = std::max<uint32_t>(std::min<uint32_t>(lws[0], kwg_size / lws_size),
1);
return lws; return lws;
} }
......
...@@ -26,7 +26,7 @@ namespace { ...@@ -26,7 +26,7 @@ namespace {
std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
std::vector<uint32_t> lws(4, 0); std::vector<uint32_t> lws(4, 0);
uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
uint32_t base = cache_size / kBaseGPUMemCacheSize; uint32_t base = std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
if (lws[1] >= base) { if (lws[1] >= base) {
lws[0] = std::min<uint32_t>(gws[0], base); lws[0] = std::min<uint32_t>(gws[0], base);
...@@ -42,7 +42,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { ...@@ -42,7 +42,8 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
if (lws[2] == 0) { if (lws[2] == 0) {
lws[2] = gws[2]; lws[2] = gws[2];
} }
lws[2] = std::min<uint32_t>(lws[2], kwg_size / lws_size); lws[2] = std::max<uint32_t>(std::min<uint32_t>(lws[2], kwg_size / lws_size),
1);
return lws; return lws;
} }
......
...@@ -26,7 +26,7 @@ namespace { ...@@ -26,7 +26,7 @@ namespace {
std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size(); uint64_t cache_size = OpenCLRuntime::Global()->device_global_mem_cache_size();
uint32_t base = cache_size / kBaseGPUMemCacheSize; uint32_t base = std::max<uint32_t>(cache_size / kBaseGPUMemCacheSize, 1);
std::vector<uint32_t> lws(4, 0); std::vector<uint32_t> lws(4, 0);
lws[1] = std::min<uint32_t>(gws[1], kwg_size); lws[1] = std::min<uint32_t>(gws[1], kwg_size);
if (gws[0] < base) { if (gws[0] < base) {
...@@ -35,7 +35,9 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) { ...@@ -35,7 +35,9 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
lws[0] = gws[0] / base; lws[0] = gws[0] / base;
} }
lws[0] = std::min<uint32_t>(lws[0], kwg_size / lws[1]); lws[0] = std::min<uint32_t>(lws[0], kwg_size / lws[1]);
lws[2] = std::min<uint32_t>(gws[2], kwg_size / (lws[0] * lws[1])); lws[2] = std::max<uint32_t>(std::min<uint32_t>(gws[2],
kwg_size / (lws[0] * lws[1])),
1);
return lws; return lws;
} }
......
...@@ -136,7 +136,7 @@ TEST(BufferToImageTest, WeightWidthMedium) { ...@@ -136,7 +136,7 @@ TEST(BufferToImageTest, WeightWidthMedium) {
TEST(BufferToImageTest, WeightWidthLarge) { TEST(BufferToImageTest, WeightWidthLarge) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_WIDTH, TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_WIDTH,
{64, 128, 11, 13}); {64, 64, 11, 13});
} }
TEST(BufferToImageTest, WeightHeightSmall) { TEST(BufferToImageTest, WeightHeightSmall) {
...@@ -151,7 +151,7 @@ TEST(BufferToImageTest, WeightHeightMedium) { ...@@ -151,7 +151,7 @@ TEST(BufferToImageTest, WeightHeightMedium) {
TEST(BufferToImageTest, WeightHeightLarge) { TEST(BufferToImageTest, WeightHeightLarge) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_HEIGHT, TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_HEIGHT,
{64, 32, 11, 13}); {64, 16, 11, 13});
} }
namespace { namespace {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册