未验证 提交 98f37b08 编写于 作者: Y Yuan Shuai 提交者: GitHub

fix kernel of conv1x1, fc OOM in opencl buffer kernel. test=develop (#3062)

上级 8f4db4e0
......@@ -91,11 +91,7 @@ void gemm_batch_naive(__global const CL_DTYPE* a,
c0 += a0 * b0;
}
#ifdef RELU
cur_c[row * N + col] = activation(c0);
#else
cur_c[row * N + col] = c0;
#endif
}
......@@ -103,7 +99,7 @@ void gemm_batch_naive(__global const CL_DTYPE* a,
// a: filter_d
// b: x_d
// c: output_d
#if 0 // TODO(ysh239): cause CL_OUT_OF_HOST_MEMORY on some devices(such as snapdragon 855)
//#define PRINT_KERNEL
__kernel
void gemm_batch(__global const CL_DTYPE* Aptr,
......@@ -213,7 +209,7 @@ void gemm_batch(__global const CL_DTYPE* Aptr,
}
}
}
#endif
// fc_gemv_naive: keep for check
// used for fc with M = 1
......
......@@ -71,7 +71,11 @@ void ConvCompute::PrepareForRun() {
if (kernel_h == 1 && kernel_w == 1 && stride_h == 1 && stride_w == 1 &&
zero_pad && no_dilation && pad_equal) {
// conv2d_1x1
/* TODO(ysh329): CL_OUT_OF_MEMORY when use gemm_batched OpenCL kernel,
use gemm_batched_naive instead.
kernel_func_names_.push_back("gemm_batch");
*/
kernel_func_names_.push_back("gemm_batch_naive");
kernel_func_paths_.push_back("buffer/fc_kernel.cl");
if (relu_fused) {
build_options_.push_back("-DCL_DTYPE_float -DRELU");
......@@ -84,7 +88,11 @@ void ConvCompute::PrepareForRun() {
impl_ = &ConvCompute::Conv2d1x1;
} else if (pad_equal) {
kernel_func_names_.push_back("im2col");
/* TODO(ysh329): CL_OUT_OF_MEMORY when use gemm_batched OpenCL kernel,
use gemm_batched_naive instead.
kernel_func_names_.push_back("gemm_batch");
*/
kernel_func_names_.push_back("gemm_batch_naive");
kernel_func_paths_.push_back("buffer/im2col_kernel.cl");
kernel_func_paths_.push_back("buffer/fc_kernel.cl");
build_options_.push_back("-DCL_DTYPE_float");
......@@ -258,8 +266,14 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel,
const int m,
const int n,
const int k) {
auto global_work_size = cl::NDRange{static_cast<size_t>((m + 7) / 8),
static_cast<size_t>((n + 3) / 4),
/* TODO(ysh329): CL_OUT_OF_MEMORY when use gemm_batch OpenCL kernel,
use gemm_batch_naive instead.
auto global_work_size = cl::NDRange{static_cast<size_t>((m + 7) / 8),
static_cast<size_t>((n + 3) / 4),
static_cast<size_t>(batch_size)};
*/
auto global_work_size = cl::NDRange{static_cast<size_t>(m),
static_cast<size_t>(n),
static_cast<size_t>(batch_size)};
auto local_work_size = cl::NDRange{16, 16}; // cl::NullRange;
......
......@@ -168,7 +168,7 @@ void PrintData(std::string name,
// buffer
// #define PRINT_RESULT
#define LOOP_TEST
// #define LOOP_TEST
TEST(conv2d, compute_conv2d_1x1) {
// conv2d 1x1 note
// kernel/filter size = 1x1, group = 1, pad = 0, stride = 1, dilation = 1
......@@ -199,7 +199,7 @@ TEST(conv2d, compute_conv2d_1x1) {
// output_dims:1 64 112 112
// filter_dims:64 32 1 1
const bool bias_flag = true;
const bool relu_flag = true;
const std::string relu_flag = "relu";
const int batch_size = 8;
const int oc = 64;
const int ih = 112;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册