Commit 312239cc authored by Yuan Shuai, committed by GitHub

fix kernel of conv1x1, fc OOM in opencl buffer kernel. test=develop (#3062)

Parent 7788fe08
@@ -91,11 +91,7 @@ void gemm_batch_naive(__global const CL_DTYPE* a,
     c0 += a0 * b0;
   }
-#ifdef RELU
   cur_c[row * N + col] = activation(c0);
-#else
-  cur_c[row * N + col] = c0;
-#endif
 }
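The deleted #ifdef is no longer needed because activation() itself is what the -DRELU build option specializes: with RELU defined it clamps, otherwise it passes the value through. A minimal sketch of such a helper, assumed for illustration rather than copied from the repo's common header:

// OpenCL C, built with e.g. "-DCL_DTYPE_float -DRELU" (see PrepareForRun below)
inline CL_DTYPE activation(CL_DTYPE in) {
#ifdef RELU
  return fmax(in, (CL_DTYPE)0);  // fused ReLU
#else
  return in;                     // no fused activation: pass through
#endif
}

Folding the branch into the helper keeps a single store in the kernel body and lets one kernel source serve both build variants.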
@@ -103,7 +99,7 @@ void gemm_batch_naive(__global const CL_DTYPE* a,
 // a: filter_d
 // b: x_d
 // c: output_d
+#if 0 // TODO(ysh239): cause CL_OUT_OF_HOST_MEMORY on some devices(such as snapdragon 855)
 //#define PRINT_KERNEL
 __kernel
 void gemm_batch(__global const CL_DTYPE* Aptr,
@@ -213,7 +209,7 @@ void gemm_batch(__global const CL_DTYPE* Aptr,
     }
   }
 }
+#endif
 // fc_gemv_naive: keep for check
 // used for fc with M = 1
......
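With the tiled gemm_batch compiled out above, conv1x1 and fc fall back to a one-work-item-per-output-element kernel, which is what lets the host shrink to a plain m x n x batch grid (see GemmBatched below). A sketch of the general shape of such a naive batched GEMM — argument order, bias handling, and whether the filter operand is shared across the batch are assumptions, not the repo's exact signature:

// OpenCL C sketch: one dot product per work-item, no local/private tiling.
__kernel
void gemm_batch_naive(__global const CL_DTYPE* a,  // filter_d: M x K, shared by all batches (assumed)
                      __global const CL_DTYPE* b,  // x_d: batch of K x N matrices
                      __global CL_DTYPE* c,        // output_d: batch of M x N matrices
                      const int M, const int N, const int K) {
  const int row = get_global_id(0);   // 0 .. M-1
  const int col = get_global_id(1);   // 0 .. N-1
  const int bidx = get_global_id(2);  // 0 .. batch_size-1
  __global const CL_DTYPE* cur_b = b + bidx * K * N;
  __global CL_DTYPE* cur_c = c + bidx * M * N;

  // Global size is exactly M x N x batch, so no bounds check is needed.
  CL_DTYPE c0 = 0;
  for (int p = 0; p < K; ++p) {
    c0 += a[row * K + p] * cur_b[p * N + col];
  }
  cur_c[row * N + col] = activation(c0);  // identity or ReLU per build option
}

Nothing here can exhaust per-work-item resources, which is the point of the fallback; the cost is that every work-item re-reads its row of a and column of b straight from global memory.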
@@ -71,7 +71,11 @@ void ConvCompute::PrepareForRun() {
   if (kernel_h == 1 && kernel_w == 1 && stride_h == 1 && stride_w == 1 &&
       zero_pad && no_dilation && pad_equal) {
     // conv2d_1x1
+    /* TODO(ysh329): CL_OUT_OF_MEMORY when use gemm_batched OpenCL kernel,
+                     use gemm_batched_naive instead.
     kernel_func_names_.push_back("gemm_batch");
+    */
+    kernel_func_names_.push_back("gemm_batch_naive");
     kernel_func_paths_.push_back("buffer/fc_kernel.cl");
     if (relu_fused) {
       build_options_.push_back("-DCL_DTYPE_float -DRELU");
@@ -84,7 +88,11 @@ void ConvCompute::PrepareForRun() {
     impl_ = &ConvCompute::Conv2d1x1;
   } else if (pad_equal) {
     kernel_func_names_.push_back("im2col");
+    /* TODO(ysh329): CL_OUT_OF_MEMORY when use gemm_batched OpenCL kernel,
+                     use gemm_batched_naive instead.
     kernel_func_names_.push_back("gemm_batch");
+    */
+    kernel_func_names_.push_back("gemm_batch_naive");
     kernel_func_paths_.push_back("buffer/im2col_kernel.cl");
     kernel_func_paths_.push_back("buffer/fc_kernel.cl");
     build_options_.push_back("-DCL_DTYPE_float");
@@ -258,8 +266,14 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel,
                               const int m,
                               const int n,
                               const int k) {
-  auto global_work_size = cl::NDRange{static_cast<size_t>((m + 7) / 8),
-                                      static_cast<size_t>((n + 3) / 4),
+  /* TODO(ysh329): CL_OUT_OF_MEMORY when use gemm_batch OpenCL kernel,
+                   use gemm_batch_naive instead.
+  auto global_work_size = cl::NDRange{static_cast<size_t>((m + 7) / 8),
+                                      static_cast<size_t>((n + 3) / 4),
+                                      static_cast<size_t>(batch_size)};
+  */
+  auto global_work_size = cl::NDRange{static_cast<size_t>(m),
+                                      static_cast<size_t>(n),
                                       static_cast<size_t>(batch_size)};
   auto local_work_size = cl::NDRange{16, 16};  // cl::NullRange;
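The old (m + 7) / 8 by (n + 3) / 4 grid existed because each work-item of the tiled kernel covered a block of output elements; with one work-item per element the grid is simply m x n x batch. A hedged host-side sketch of the launch using the standard OpenCL C++ wrapper (queue construction and kernel-argument setup elided; the repo's surrounding plumbing differs):

#include <CL/cl.hpp>

// Enqueue the naive batched GEMM: one work-item per output element.
cl_int LaunchGemmBatchNaive(cl::CommandQueue& queue, cl::Kernel& kernel,
                            int m, int n, int batch_size) {
  auto global_work_size = cl::NDRange{static_cast<size_t>(m),
                                      static_cast<size_t>(n),
                                      static_cast<size_t>(batch_size)};
  // NullRange lets the runtime choose the work-group shape; note that a
  // fixed 2-D {16, 16} local size would not match this 3-D global range.
  return queue.enqueueNDRangeKernel(kernel, cl::NullRange, global_work_size,
                                    cl::NullRange);
}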
......
@@ -168,7 +168,7 @@ void PrintData(std::string name,
 // buffer
 // #define PRINT_RESULT
-#define LOOP_TEST
+// #define LOOP_TEST
 TEST(conv2d, compute_conv2d_1x1) {
   // conv2d 1x1 note
   // kernel/filter size = 1x1, group = 1, pad = 0, stride = 1, dilation = 1
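Commenting out LOOP_TEST pins the test to the single shape listed below instead of sweeping many shapes. A hypothetical, self-contained illustration of how such a macro gate typically works (the helper name and loop bounds are invented):

#include <cstdio>
#include <initializer_list>

// Hypothetical stand-in for the real test body.
static void RunConv1x1Test(int bs, int oc, int ih, int iw) {
  std::printf("conv1x1 test: bs=%d oc=%d ih=%d iw=%d\n", bs, oc, ih, iw);
}

int main() {
#ifdef LOOP_TEST
  // Sweep a grid of shapes when the macro is defined.
  for (int bs : {1, 8})
    for (int oc : {32, 64})
      RunConv1x1Test(bs, oc, 112, 112);
#else
  RunConv1x1Test(8, 64, 112, 112);  // single fixed shape keeps CI fast
#endif
  return 0;
}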
@@ -199,7 +199,7 @@ TEST(conv2d, compute_conv2d_1x1) {
   // output_dims:1 64 112 112
   // filter_dims:64 32 1 1
   const bool bias_flag = true;
-  const bool relu_flag = true;
+  const std::string relu_flag = "relu";
   const int batch_size = 8;
   const int oc = 64;
   const int ih = 112;
......
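Switching relu_flag from bool to std::string lets the test name the fused activation instead of merely toggling it. A hypothetical mapping from that string to the kernel build options (the helper name is invented; only -DRELU appears in this diff):

#include <string>

// Hypothetical: derive kernel build options from the test's activation name.
std::string BuildOptionsFor(const std::string& relu_flag) {
  std::string opts = "-DCL_DTYPE_float";
  if (relu_flag == "relu") {
    opts += " -DRELU";  // the only fused activation this diff wires up
  }
  // Other names (e.g. "relu6") could be mapped here later.
  return opts;
}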