提交 6262bc32 编写于 作者: 开心的小妮's avatar 开心的小妮

[LITE][OPENCL] Make compute kernel hold cl::Kernel. test=develop

上级 293d2d38
...@@ -77,14 +77,10 @@ class ActivationComputeImageDefault ...@@ -77,14 +77,10 @@ class ActivationComputeImageDefault
#endif #endif
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_, kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/activation_kernel.cl", "image/activation_kernel.cl",
build_options_, build_options_,
time_stamp_); time_stamp_);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
} }
void ReInitWhenNeeded() override { void ReInitWhenNeeded() override {
...@@ -118,15 +114,14 @@ class ActivationComputeImageDefault ...@@ -118,15 +114,14 @@ class ActivationComputeImageDefault
auto* out_img = act_param_->Out->mutable_data<half_t, cl::Image2D>( auto* out_img = act_param_->Out->mutable_data<half_t, cl::Image2D>(
out_img_shape_[0], out_img_shape_[1]); out_img_shape_[0], out_img_shape_[1]);
auto kernel = kernel_;
cl_int status; cl_int status;
status = kernel.setArg(0, *x_img); status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *out_img); status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, threshold_); status = kernel_->setArg(2, threshold_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, scale_); status = kernel_->setArg(3, scale_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -148,7 +143,7 @@ class ActivationComputeImageDefault ...@@ -148,7 +143,7 @@ class ActivationComputeImageDefault
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -168,7 +163,7 @@ class ActivationComputeImageDefault ...@@ -168,7 +163,7 @@ class ActivationComputeImageDefault
std::string kernel_func_name_{}; std::string kernel_func_name_{};
float threshold_{6.f}; float threshold_{6.f};
float scale_{1.f}; float scale_{1.f};
cl::Kernel kernel_; std::shared_ptr<cl::Kernel> kernel_;
bool first_epoch_for_reinit_{true}; bool first_epoch_for_reinit_{true};
cl::NDRange global_work_size_ = cl::NDRange{ cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
......
...@@ -40,10 +40,10 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -40,10 +40,10 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
kernel_func_name_ = "concat_mul"; kernel_func_name_ = "concat_mul";
} }
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel(kernel_func_name_, kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/concat_kernel.cl", "image/concat_kernel.cl",
build_options_, build_options_,
time_stamp_); time_stamp_);
auto axis = concat_param_->axis; auto axis = concat_param_->axis;
auto inputs = concat_param_->x; auto inputs = concat_param_->x;
...@@ -118,8 +118,6 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -118,8 +118,6 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto inputs = param.x; auto inputs = param.x;
int arg_idx = 0; int arg_idx = 0;
...@@ -164,31 +162,30 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -164,31 +162,30 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
<< (image_shape["height"]); << (image_shape["height"]);
#endif #endif
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int out_w = x_dims[x_dims.size() - 1]; int out_w = x_dims[x_dims.size() - 1];
int out_c = x_dims[1]; int out_c = x_dims[1];
if (inputs.size() == 2) { if (inputs.size() == 2) {
auto* x_buf0 = inputs[0]->data<half_t, cl::Image2D>(); auto* x_buf0 = inputs[0]->data<half_t, cl::Image2D>();
auto* x_buf1 = inputs[1]->data<half_t, cl::Image2D>(); auto* x_buf1 = inputs[1]->data<half_t, cl::Image2D>();
cl_int status = kernel.setArg(arg_idx, *x_buf0); cl_int status = kernel_->setArg(arg_idx, *x_buf0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_buf1); status = kernel_->setArg(++arg_idx, *x_buf1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel_->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag_); status = kernel_->setArg(++arg_idx, flag_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = status = kernel_->setArg(++arg_idx,
kernel.setArg(++arg_idx, static_cast<int>(inputs[0]->dims()[axis_])); static_cast<int>(inputs[0]->dims()[axis_]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_c); status = kernel_->setArg(++arg_idx, out_c);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w); status = kernel_->setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_); status = kernel_->setArg(++arg_idx, width_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -213,25 +210,25 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -213,25 +210,25 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(image_shape["width"] / static_cast<cl::size_type>(image_shape["width"] /
in_dims[in_dims.size() - 1]), in_dims[in_dims.size() - 1]),
static_cast<cl::size_type>(image_shape["height"])}; static_cast<cl::size_type>(image_shape["height"])};
cl_int status = kernel.setArg(arg_idx, *x_buf); cl_int status = kernel_->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel_->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag_); status = kernel_->setArg(++arg_idx, flag_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start); status = kernel_->setArg(++arg_idx, start);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_c); status = kernel_->setArg(++arg_idx, out_c);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w); status = kernel_->setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_w); status = kernel_->setArg(++arg_idx, in_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_); status = kernel_->setArg(++arg_idx, width_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -255,6 +252,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -255,6 +252,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
std::string build_options_{" -DCL_DTYPE_half"}; std::string build_options_{" -DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
}; };
} // namespace opencl } // namespace opencl
......
...@@ -368,25 +368,17 @@ void ConvImageCompute::PrepareForRun() { ...@@ -368,25 +368,17 @@ void ConvImageCompute::PrepareForRun() {
build_options_.push_back(build_options_single); build_options_.push_back(build_options_single);
for (size_t i = 0; i < kernel_func_names_.size(); i++) { kernel_ = context.cl_context()->CreateKernel(kernel_func_names_[0],
context.cl_context()->AddKernel(kernel_func_names_[i], kernel_func_paths_[0],
kernel_func_paths_[i], build_options_[0],
build_options_[i], time_stamp_);
time_stamp_);
}
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << "," VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}"; << global_work_size_[1] << "," << global_work_size_[2] << "}";
std::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0] << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
size_t max_work_group_size = 0; size_t max_work_group_size = 0;
kernel_.getWorkGroupInfo<size_t>(CLRuntime::Global()->device(), kernel_->getWorkGroupInfo<size_t>(CLRuntime::Global()->device(),
CL_KERNEL_WORK_GROUP_SIZE, CL_KERNEL_WORK_GROUP_SIZE,
&max_work_group_size); &max_work_group_size);
VLOG(4) << "max_work_group_size: " << max_work_group_size; VLOG(4) << "max_work_group_size: " << max_work_group_size;
...@@ -501,49 +493,48 @@ void ConvImageCompute::Conv2d1x1opt(bool is_turn) { ...@@ -501,49 +493,48 @@ void ConvImageCompute::Conv2d1x1opt(bool is_turn) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset); status = kernel_->setArg(++arg_idx, offset);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block); status = kernel_->setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c); status = kernel_->setArg(++arg_idx, input_c);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, default_w_blk_); status = kernel_->setArg(++arg_idx, default_w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
local_work_size_, local_work_size_,
...@@ -649,56 +640,55 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) { ...@@ -649,56 +640,55 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) {
if (has_bias) { if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: "; VLOG(4) << "set bias_image: ";
#endif #endif
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset); status = kernel_->setArg(++arg_idx, offset);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block); status = kernel_->setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_channel); status = kernel_->setArg(++arg_idx, output_channel);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_channel); status = kernel_->setArg(++arg_idx, filter_channel);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width); status = kernel_->setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height); status = kernel_->setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, new_groups); status = kernel_->setArg(++arg_idx, new_groups);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -708,7 +698,7 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) { ...@@ -708,7 +698,7 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) {
#endif #endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -784,48 +774,46 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) { ...@@ -784,48 +774,46 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: "; VLOG(4) << "set bias_image: ";
#endif #endif
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, paddings[0]); status = kernel_->setArg(++arg_idx, paddings[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, batch); status = kernel_->setArg(++arg_idx, batch);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_channel); status = kernel_->setArg(++arg_idx, input_channel);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -835,7 +823,7 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) { ...@@ -835,7 +823,7 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) {
#endif #endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
local_work_size_, local_work_size_,
...@@ -917,46 +905,44 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) { ...@@ -917,46 +905,44 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: "; VLOG(4) << "set bias_image: ";
#endif #endif
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset); status = kernel_->setArg(++arg_idx, offset);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block); status = kernel_->setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -966,7 +952,7 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) { ...@@ -966,7 +952,7 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) {
#endif #endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -1042,50 +1028,49 @@ void ConvImageCompute::Conv2d5x5opt(bool is_turn) { ...@@ -1042,50 +1028,49 @@ void ConvImageCompute::Conv2d5x5opt(bool is_turn) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, paddings[0]); status = kernel_->setArg(++arg_idx, paddings[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, batch); status = kernel_->setArg(++arg_idx, batch);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_channel); status = kernel_->setArg(++arg_idx, input_channel);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
// VLOG(4) << "out_image: " << out_image; // VLOG(4) << "out_image: " << out_image;
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
local_work_size_, local_work_size_,
...@@ -1167,46 +1152,44 @@ void ConvImageCompute::Conv2d7x7(bool is_turn) { ...@@ -1167,46 +1152,44 @@ void ConvImageCompute::Conv2d7x7(bool is_turn) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: "; VLOG(4) << "set bias_image: ";
#endif #endif
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset); status = kernel_->setArg(++arg_idx, offset);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block); status = kernel_->setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -1216,7 +1199,7 @@ void ConvImageCompute::Conv2d7x7(bool is_turn) { ...@@ -1216,7 +1199,7 @@ void ConvImageCompute::Conv2d7x7(bool is_turn) {
#endif #endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -1290,49 +1273,47 @@ void ConvImageCompute::Conv2d7x7opt(bool is_turn) { ...@@ -1290,49 +1273,47 @@ void ConvImageCompute::Conv2d7x7opt(bool is_turn) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, paddings[0]); status = kernel_->setArg(++arg_idx, paddings[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, batch); status = kernel_->setArg(++arg_idx, batch);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_channel); status = kernel_->setArg(++arg_idx, input_channel);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
local_work_size_, local_work_size_,
...@@ -1369,19 +1350,17 @@ void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) { ...@@ -1369,19 +1350,17 @@ void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) {
auto* output_img = param.output->mutable_data<half_t, cl::Image2D>( auto* output_img = param.output->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]); image_shape["width"], image_shape["height"]);
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img); status = kernel_->setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img); status = kernel_->setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
const bool has_bias = param.bias != nullptr; const bool has_bias = param.bias != nullptr;
...@@ -1393,30 +1372,30 @@ void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) { ...@@ -1393,30 +1372,30 @@ void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: "; VLOG(4) << "set bias_image: ";
#endif #endif
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *output_img); status = kernel_->setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0])); status = kernel_->setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0])); status = kernel_->setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0])); status = kernel_->setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[1])); status = kernel_->setArg(++arg_idx, static_cast<const int>(x_dims[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3])); status = kernel_->setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2])); status = kernel_->setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3])); status = kernel_->setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2])); status = kernel_->setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
local_work_size_, local_work_size_,
...@@ -1456,8 +1435,6 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) { ...@@ -1456,8 +1435,6 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) {
auto* output_img = param.output->mutable_data<half_t, cl::Image2D>( auto* output_img = param.output->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]); image_shape["width"], image_shape["height"]);
auto kernel = kernel_;
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "setArg"; VLOG(4) << "setArg";
VLOG(4) << "strides = " << strides[0]; VLOG(4) << "strides = " << strides[0];
...@@ -1472,15 +1449,15 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) { ...@@ -1472,15 +1449,15 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) {
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img); status = kernel_->setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img); status = kernel_->setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
const bool has_bias = param.bias != nullptr; const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias = const bool is_element_wise_bias =
...@@ -1491,30 +1468,30 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) { ...@@ -1491,30 +1468,30 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: "; VLOG(4) << "set bias_image: ";
#endif #endif
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *output_img); status = kernel_->setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0])); status = kernel_->setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(offset)); status = kernel_->setArg(++arg_idx, static_cast<const int>(offset));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0])); status = kernel_->setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(input_c_block)); status = kernel_->setArg(++arg_idx, static_cast<const int>(input_c_block));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3])); status = kernel_->setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2])); status = kernel_->setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3])); status = kernel_->setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2])); status = kernel_->setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -1598,50 +1575,48 @@ void ConvImageCompute::DepthwiseConv2d(bool is_turn) { ...@@ -1598,50 +1575,48 @@ void ConvImageCompute::DepthwiseConv2d(bool is_turn) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>(); bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
} }
auto kernel = kernel_;
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, c_blk_); status = kernel_->setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk_); status = kernel_->setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh_blk_); status = kernel_->setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image); status = kernel_->setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image); status = kernel_->setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
if (has_bias) { if (has_bias) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: "; VLOG(4) << "set bias_image: ";
#endif #endif
status = kernel.setArg(++arg_idx, *bias_image); status = kernel_->setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]); status = kernel_->setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset); status = kernel_->setArg(++arg_idx, offset);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block); status = kernel_->setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]); status = kernel_->setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width); status = kernel_->setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height); status = kernel_->setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width); status = kernel_->setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height); status = kernel_->setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width); status = kernel_->setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height); status = kernel_->setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -1651,7 +1626,7 @@ void ConvImageCompute::DepthwiseConv2d(bool is_turn) { ...@@ -1651,7 +1626,7 @@ void ConvImageCompute::DepthwiseConv2d(bool is_turn) {
#endif #endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
......
...@@ -71,7 +71,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -71,7 +71,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
int default_w_blk_ = 1; int default_w_blk_ = 1;
int default_nh_blk_ = 1; int default_nh_blk_ = 1;
cl::Kernel kernel_; std::shared_ptr<cl::Kernel> kernel_;
cl::NDRange local_work_size_ = cl::NDRange{ cl::NDRange local_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
bool use_lws_{true}; bool use_lws_{true};
......
...@@ -59,14 +59,11 @@ void ElementwiseAddImageCompute::ReInitWhenNeeded() { ...@@ -59,14 +59,11 @@ void ElementwiseAddImageCompute::ReInitWhenNeeded() {
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_, kernel_ =
"image/elementwise_add_kernel.cl", context.cl_context()->CreateKernel(kernel_func_name_,
build_options_, "image/elementwise_add_kernel.cl",
time_stamp_); build_options_,
time_stamp_);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
// compute image shape // compute image shape
paddle::lite::CLImageConverterDefault default_convertor; paddle::lite::CLImageConverterDefault default_convertor;
...@@ -118,13 +115,12 @@ void ElementwiseAddImageCompute::Run() { ...@@ -118,13 +115,12 @@ void ElementwiseAddImageCompute::Run() {
#endif #endif
cl_int status; cl_int status;
auto kernel = kernel_;
if (y_dims.size() == 4) { if (y_dims.size() == 4) {
status = kernel.setArg(0, *x_img); status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1) { } else if (y_dims.size() == 1) {
if (axis == x_dims.size() - 1 || axis == x_dims.size() - 3) { if (axis == x_dims.size() - 1 || axis == x_dims.size() - 3) {
...@@ -132,13 +128,13 @@ void ElementwiseAddImageCompute::Run() { ...@@ -132,13 +128,13 @@ void ElementwiseAddImageCompute::Run() {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "tensor_w:" << tensor_w; VLOG(4) << "tensor_w:" << tensor_w;
#endif #endif
status = kernel.setArg(0, *x_img); status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis
...@@ -154,7 +150,7 @@ void ElementwiseAddImageCompute::Run() { ...@@ -154,7 +150,7 @@ void ElementwiseAddImageCompute::Run() {
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
......
...@@ -60,7 +60,7 @@ class ElementwiseAddImageCompute ...@@ -60,7 +60,7 @@ class ElementwiseAddImageCompute
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
bool first_epoch_for_reinit_{true}; bool first_epoch_for_reinit_{true};
cl::Kernel kernel_; std::shared_ptr<cl::Kernel> kernel_;
cl::NDRange global_work_size_ = cl::NDRange{ cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
......
...@@ -71,10 +71,11 @@ class ElementwiseMulImageCompute ...@@ -71,10 +71,11 @@ class ElementwiseMulImageCompute
VLOG(4) << "bias_dims.size():" << bias_dims.size(); VLOG(4) << "bias_dims.size():" << bias_dims.size();
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_, kernel_ =
"image/elementwise_mul_kernel.cl", context.cl_context()->CreateKernel(kernel_func_name_,
build_options_, "image/elementwise_mul_kernel.cl",
time_stamp_); build_options_,
time_stamp_);
} }
void Run() override { void Run() override {
...@@ -115,66 +116,61 @@ class ElementwiseMulImageCompute ...@@ -115,66 +116,61 @@ class ElementwiseMulImageCompute
<< out_img_shape[1]; << out_img_shape[1];
#endif #endif
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
auto bias_dims = y->dims(); auto bias_dims = y->dims();
auto x_dims = x->dims(); auto x_dims = x->dims();
if (bias_dims == x_dims) { if (bias_dims == x_dims) {
// kernel_func_name_ = "elementwise_mul"; // kernel_func_name_ = "elementwise_mul";
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
const int bias_dim_size = bias_dims.size(); const int bias_dim_size = bias_dims.size();
if (bias_dim_size == 1) { if (bias_dim_size == 1) {
// kernel_func_name_ = "channel_mul_d1"; // kernel_func_name_ = "channel_mul_d1";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (bias_dim_size == 2) { } else if (bias_dim_size == 2) {
// kernel_func_name_ = "channel_mul_d2"; // kernel_func_name_ = "channel_mul_d2";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (bias_dim_size == 3) { } else if (bias_dim_size == 3) {
// kernel_func_name_ = "channel_mul_d3"; // kernel_func_name_ = "channel_mul_d3";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (bias_dim_size == 4) { } else if (bias_dim_size == 4) {
// kernel_func_name_ = "channel_mul_d4"; // kernel_func_name_ = "channel_mul_d4";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
...@@ -186,7 +182,7 @@ class ElementwiseMulImageCompute ...@@ -186,7 +182,7 @@ class ElementwiseMulImageCompute
cl::NDRange{static_cast<cl::size_type>(x_img_width), cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)}; static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -205,6 +201,7 @@ class ElementwiseMulImageCompute ...@@ -205,6 +201,7 @@ class ElementwiseMulImageCompute
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
}; };
} // namespace opencl } // namespace opencl
......
...@@ -75,13 +75,10 @@ class FcCompute ...@@ -75,13 +75,10 @@ class FcCompute
} }
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_, kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"buffer/fc_kernel.cl", "buffer/fc_kernel.cl",
build_options_, build_options_,
time_stamp_); time_stamp_);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
// compute global work size // compute global work size
GetGlobalWorkSize(); GetGlobalWorkSize();
...@@ -106,25 +103,25 @@ class FcCompute ...@@ -106,25 +103,25 @@ class FcCompute
auto kernel = kernel_; auto kernel = kernel_;
cl_int status; cl_int status;
status = kernel.setArg(0, *x_buf); status = kernel_->setArg(0, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *w_buf); status = kernel_->setArg(1, *w_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *bias_buf); status = kernel_->setArg(2, *bias_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, *out_buf); status = kernel_->setArg(3, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(4, static_cast<const int>(m_)); status = kernel_->setArg(4, static_cast<const int>(m_));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(5, static_cast<const int>(n_)); status = kernel_->setArg(5, static_cast<const int>(n_));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(6, static_cast<const int>(k_)); status = kernel_->setArg(6, static_cast<const int>(k_));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -143,7 +140,7 @@ class FcCompute ...@@ -143,7 +140,7 @@ class FcCompute
bool first_epoch_for_reinit_{true}; bool first_epoch_for_reinit_{true};
DDim last_x_dims_; DDim last_x_dims_;
cl::NDRange global_work_size_; cl::NDRange global_work_size_;
cl::Kernel kernel_; std::shared_ptr<cl::Kernel> kernel_;
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
}; };
......
...@@ -31,10 +31,11 @@ class FusionElementwiseAddActivationImageCompute ...@@ -31,10 +31,11 @@ class FusionElementwiseAddActivationImageCompute
void PrepareForRun() override { void PrepareForRun() override {
build_options_ += " -DRELU"; build_options_ += " -DRELU";
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_, kernel_ =
"image/elementwise_add_kernel.cl", context.cl_context()->CreateKernel(kernel_func_name_,
build_options_, "image/elementwise_add_kernel.cl",
time_stamp_); build_options_,
time_stamp_);
ele_param_ = param_.get_mutable<param_t>(); ele_param_ = param_.get_mutable<param_t>();
auto act_t = static_cast<param_t*>(ele_param_)->act_type; auto act_t = static_cast<param_t*>(ele_param_)->act_type;
VLOG(4) << "act: " << act_t; VLOG(4) << "act: " << act_t;
......
...@@ -38,10 +38,11 @@ class NearestInterpComputeImageDefault ...@@ -38,10 +38,11 @@ class NearestInterpComputeImageDefault
void PrepareForRun() override { void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_, kernel_ =
"image/nearest_interp_kernel.cl", context.cl_context()->CreateKernel(kernel_func_name_,
build_options_, "image/nearest_interp_kernel.cl",
time_stamp_); build_options_,
time_stamp_);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
} }
...@@ -67,26 +68,23 @@ class NearestInterpComputeImageDefault ...@@ -67,26 +68,23 @@ class NearestInterpComputeImageDefault
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; cl_int status;
cl_int status = kernel.setArg(arg_idx, *x_img); status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const float>(scale_h)); status = kernel_->setArg(2, static_cast<const float>(scale_h));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const float>(scale_w)); status = kernel_->setArg(3, static_cast<const float>(scale_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims_h)); status = kernel_->setArg(4, static_cast<const int>(in_dims_h));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims_h)); status = kernel_->setArg(5, static_cast<const int>(out_dims_h));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims_w)); status = kernel_->setArg(6, static_cast<const int>(in_dims_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims_w)); status = kernel_->setArg(7, static_cast<const int>(out_dims_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -110,7 +108,7 @@ class NearestInterpComputeImageDefault ...@@ -110,7 +108,7 @@ class NearestInterpComputeImageDefault
static_cast<cl::size_type>(default_work_size.data()[1]), static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])}; static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -125,6 +123,7 @@ class NearestInterpComputeImageDefault ...@@ -125,6 +123,7 @@ class NearestInterpComputeImageDefault
std::string build_options_{" -DCL_DTYPE_half"}; std::string build_options_{" -DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
}; };
} // namespace opencl } // namespace opencl
......
...@@ -46,7 +46,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -46,7 +46,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
} }
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( kernel_ = context.cl_context()->CreateKernel(
kernel_func_name_, "image/pool_kernel.cl", build_options_, time_stamp_); kernel_func_name_, "image/pool_kernel.cl", build_options_, time_stamp_);
} }
...@@ -111,10 +111,6 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -111,10 +111,6 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
out_image_shape["width"], out_image_shape["height"]); out_image_shape["width"], out_image_shape["height"]);
// VLOG(4) << "out_image" << out_img; // VLOG(4) << "out_image" << out_img;
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (out_dims[1] + 3) / 4; int c_block = (out_dims[1] + 3) / 4;
int w = out_dims[3]; int w = out_dims[3];
int nh = out_dims[0] * out_dims[2]; int nh = out_dims[0] * out_dims[2];
...@@ -124,34 +120,33 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -124,34 +120,33 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
<< " " << nh << " "; << " " << nh << " ";
#endif #endif
cl_int status; cl_int status;
int arg_idx = 0; status = kernel_->setArg(0, *x_img);
status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[2])); status = kernel_->setArg(2, static_cast<const int>(in_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[3])); status = kernel_->setArg(3, static_cast<const int>(in_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[2])); status = kernel_->setArg(4, static_cast<const int>(out_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[3])); status = kernel_->setArg(5, static_cast<const int>(out_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[0])); status = kernel_->setArg(6, static_cast<const int>(ksize[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[1])); status = kernel_->setArg(7, static_cast<const int>(ksize[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0])); status = kernel_->setArg(8, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[1])); status = kernel_->setArg(9, static_cast<const int>(strides[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[2])); status = kernel_->setArg(10, static_cast<const int>(paddings[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0])); status = kernel_->setArg(11, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -162,6 +157,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -162,6 +157,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
} }
private: private:
std::shared_ptr<cl::Kernel> kernel_;
std::string kernel_func_name_{"pool_"}; std::string kernel_func_name_{"pool_"};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
......
...@@ -36,10 +36,10 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL), ...@@ -36,10 +36,10 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override { void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel(kernel_func_name_, kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/reshape_kernel.cl", "image/reshape_kernel.cl",
build_options_, build_options_,
time_stamp_); time_stamp_);
} }
void Run() override { void Run() override {
...@@ -111,42 +111,38 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL), ...@@ -111,42 +111,38 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << TargetToStr(x->target()); VLOG(4) << TargetToStr(x->target());
VLOG(4) << TargetToStr(param.output->target()); VLOG(4) << TargetToStr(param.output->target());
#endif #endif
int arg_idx = 0;
cl_int status; cl_int status;
status = kernel.setArg(arg_idx, *x_image); status = kernel_->setArg(0, *x_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_image); status = kernel_->setArg(1, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_C); status = kernel_->setArg(2, out_C);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_H); status = kernel_->setArg(3, out_H);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_W); status = kernel_->setArg(4, out_W);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_W); status = kernel_->setArg(5, in_W);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_H); status = kernel_->setArg(6, in_H);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride0); status = kernel_->setArg(7, in_Stride0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride1); status = kernel_->setArg(8, in_Stride1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride2); status = kernel_->setArg(9, in_Stride2);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride0); status = kernel_->setArg(10, out_Stride0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride1); status = kernel_->setArg(11, out_Stride1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride2); status = kernel_->setArg(12, out_Stride2);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = auto global_work_size =
...@@ -155,7 +151,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL), ...@@ -155,7 +151,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
static_cast<size_t>(default_work_size.data()[2])}; static_cast<size_t>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -170,6 +166,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL), ...@@ -170,6 +166,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
}; };
} // namespace opencl } // namespace opencl
......
...@@ -37,15 +37,11 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -37,15 +37,11 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override { void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_, kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/scale_kernel.cl", "image/scale_kernel.cl",
build_options_, build_options_,
time_stamp_); time_stamp_);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
} }
void ReInitWhenNeeded() override { void ReInitWhenNeeded() override {
...@@ -82,19 +78,18 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -82,19 +78,18 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
auto kernel = kernel_;
cl_int status; cl_int status;
status = kernel.setArg(0, *x_img); status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *out_img); status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, scale); status = kernel_->setArg(2, scale);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, bias); status = kernel_->setArg(3, bias);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *(kernel_.get()),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -111,7 +106,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -111,7 +106,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
param_t* scale_param_{nullptr}; param_t* scale_param_{nullptr};
cl::Kernel kernel_; std::shared_ptr<cl::Kernel> kernel_;
bool first_epoch_for_reinit_{true}; bool first_epoch_for_reinit_{true};
DDim last_x_dims_; DDim last_x_dims_;
DDim out_img_shape_ = DDim(std::vector<DDim::value_type>( DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册