提交 9562d42a 编写于 作者: xiebaiyuan's avatar xiebaiyuan

[LITE][OPENCL]use shared_ptr with cl::kernel , init cl::event when use ,test=develop

上级 c8918d89
...@@ -68,16 +68,16 @@ void CLContext::AddKernel(const std::string &kernel_name, ...@@ -68,16 +68,16 @@ void CLContext::AddKernel(const std::string &kernel_name,
kernel_offset_[kernel_key.str()] = kernels_.size() - 1; kernel_offset_[kernel_key.str()] = kernels_.size() - 1;
} }
cl::Kernel &CLContext::GetKernel(const int index) { std::shared_ptr<cl::Kernel> &CLContext::GetKernel(const int index) {
VLOG(3) << " --- kernel count: " << kernels_.size() << " --- "; VLOG(3) << " --- kernel count: " << kernels_.size() << " --- ";
CHECK(static_cast<size_t>(index) < kernels_.size()) CHECK(static_cast<size_t>(index) < kernels_.size())
<< "The index must be less than the size of kernels."; << "The index must be less than the size of kernels.";
CHECK(kernels_[index] != nullptr) CHECK(kernels_[index] != nullptr)
<< "The target kernel pointer cannot be null."; << "The target kernel pointer cannot be null.";
return *(kernels_[index]); return kernels_[index];
} }
cl::Kernel &CLContext::GetKernel(const std::string &name) { std::shared_ptr<cl::Kernel> &CLContext::GetKernel(const std::string &name) {
auto it = kernel_offset_.find(name); auto it = kernel_offset_.find(name);
CHECK(it != kernel_offset_.end()) << "Cannot find the kernel function: " CHECK(it != kernel_offset_.end()) << "Cannot find the kernel function: "
<< name; << name;
......
...@@ -54,9 +54,9 @@ class CLContext { ...@@ -54,9 +54,9 @@ class CLContext {
const std::string &options = "", const std::string &options = "",
const std::string &time_stamp = ""); const std::string &time_stamp = "");
cl::Kernel &GetKernel(const int index); std::shared_ptr<cl::Kernel> &GetKernel(const int index);
cl::Kernel &GetKernel(const std::string &name); std::shared_ptr<cl::Kernel> &GetKernel(const std::string &name);
cl::NDRange DefaultWorkSize(const CLImage &image); cl::NDRange DefaultWorkSize(const CLImage &image);
......
...@@ -54,16 +54,16 @@ class ReluCompute ...@@ -54,16 +54,16 @@ class ReluCompute
VLOG(4) << TargetToStr(param.Out->target()); VLOG(4) << TargetToStr(param.Out->target());
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_buf); cl_int status = kernel->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, (const int)count); status = kernel->setArg(++arg_idx, (const int)count);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{count}; auto global_work_size = cl::NDRange{count};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -112,16 +112,16 @@ class SigmoidCompute ...@@ -112,16 +112,16 @@ class SigmoidCompute
VLOG(4) << TargetToStr(param.Out->target()); VLOG(4) << TargetToStr(param.Out->target());
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_buf); cl_int status = kernel->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, (const int)count); status = kernel->setArg(++arg_idx, (const int)count);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{count}; auto global_work_size = cl::NDRange{count};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -84,7 +84,7 @@ class ActivationComputeImageDefault ...@@ -84,7 +84,7 @@ class ActivationComputeImageDefault
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
} }
void ReInitWhenNeeded() override { void ReInitWhenNeeded() override {
...@@ -117,16 +117,20 @@ class ActivationComputeImageDefault ...@@ -117,16 +117,20 @@ class ActivationComputeImageDefault
auto* x_img = act_param_->X->data<half_t, cl::Image2D>(); auto* x_img = act_param_->X->data<half_t, cl::Image2D>();
auto* out_img = act_param_->Out->mutable_data<half_t, cl::Image2D>( auto* out_img = act_param_->Out->mutable_data<half_t, cl::Image2D>(
out_img_shape_[0], out_img_shape_[1]); out_img_shape_[0], out_img_shape_[1]);
auto& context = ctx_->As<OpenCLContext>();
auto kernel = kernel_; CHECK(context.cl_context() != nullptr);
std::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
;
cl_int status; cl_int status;
status = kernel.setArg(0, *x_img); status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *out_img); status = kernel->setArg(1, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, threshold_); status = kernel->setArg(2, threshold_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, scale_); status = kernel->setArg(3, scale_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -145,10 +149,8 @@ class ActivationComputeImageDefault ...@@ -145,10 +149,8 @@ class ActivationComputeImageDefault
VLOG(4) << "kernel func name:" << kernel_func_name_; VLOG(4) << "kernel func name:" << kernel_func_name_;
#endif #endif
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -168,7 +170,7 @@ class ActivationComputeImageDefault ...@@ -168,7 +170,7 @@ class ActivationComputeImageDefault
std::string kernel_func_name_{}; std::string kernel_func_name_{};
float threshold_{6.f}; float threshold_{6.f};
float scale_{1.f}; float scale_{1.f};
cl::Kernel kernel_; cl::Kernel kernel;
bool first_epoch_for_reinit_{true}; bool first_epoch_for_reinit_{true};
cl::NDRange global_work_size_ = cl::NDRange{ cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
......
...@@ -118,23 +118,23 @@ class BilinearInterpImageCompute ...@@ -118,23 +118,23 @@ class BilinearInterpImageCompute
VLOG(4) << "default_work_size: " << default_work_size[0] << ", " VLOG(4) << "default_work_size: " << default_work_size[0] << ", "
<< default_work_size[1] << ", " << default_work_size[2]; << default_work_size[1] << ", " << default_work_size[2];
#endif #endif
cl_int status = kernel.setArg(arg_idx++, *x_img); cl_int status = kernel->setArg(arg_idx++, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *out_img); status = kernel->setArg(arg_idx++, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, scale_h); status = kernel->setArg(arg_idx++, scale_h);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, scale_w); status = kernel->setArg(arg_idx++, scale_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, align_delta); status = kernel->setArg(arg_idx++, align_delta);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_h); status = kernel->setArg(arg_idx++, in_h);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_w); status = kernel->setArg(arg_idx++, in_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_h); status = kernel->setArg(arg_idx++, out_h);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_w); status = kernel->setArg(arg_idx++, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = auto global_work_size =
...@@ -143,7 +143,7 @@ class BilinearInterpImageCompute ...@@ -143,7 +143,7 @@ class BilinearInterpImageCompute
static_cast<cl::size_type>(default_work_size[2])}; static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -104,24 +104,24 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -104,24 +104,24 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
<< default_work_size[1] << ", " << default_work_size[2]; << default_work_size[1] << ", " << default_work_size[2];
#endif #endif
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx++, *prior_box_image); cl_int status = kernel->setArg(arg_idx++, *prior_box_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *prior_box_var_image); status = kernel->setArg(arg_idx++, *prior_box_var_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *target_box_image); status = kernel->setArg(arg_idx++, *target_box_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *out_buf); status = kernel->setArg(arg_idx++, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_C); status = kernel->setArg(arg_idx++, out_C);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_H); status = kernel->setArg(arg_idx++, out_H);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(default_work_size[0]), cl::NDRange{static_cast<cl::size_type>(default_work_size[0]),
static_cast<cl::size_type>(default_work_size[2])}; static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -103,28 +103,28 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL), ...@@ -103,28 +103,28 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
auto axis0 = inputs[0]->dims()[axis_]; auto axis0 = inputs[0]->dims()[axis_];
int total0 = axis0 * post_size_; int total0 = axis0 * post_size_;
int total1 = (axis_size_ - axis0) * post_size_; int total1 = (axis_size_ - axis0) * post_size_;
cl_int status = kernel.setArg(arg_idx, *x_buf0); cl_int status = kernel->setArg(arg_idx, *x_buf0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_buf1); status = kernel->setArg(++arg_idx, *x_buf1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<int>(axis0)); status = kernel->setArg(++arg_idx, static_cast<int>(axis0));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, axis_size_); status = kernel->setArg(++arg_idx, axis_size_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, pre_size_); status = kernel->setArg(++arg_idx, pre_size_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, post_size_); status = kernel->setArg(++arg_idx, post_size_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total); status = kernel->setArg(++arg_idx, total);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total0); status = kernel->setArg(++arg_idx, total0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total1); status = kernel->setArg(++arg_idx, total1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -140,24 +140,24 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL), ...@@ -140,24 +140,24 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
auto* x_buf = inputs[i]->data<float, cl::Buffer>(); auto* x_buf = inputs[i]->data<float, cl::Buffer>();
global_work_size = cl::NDRange{static_cast<size_t>(size)}; global_work_size = cl::NDRange{static_cast<size_t>(size)};
int total0 = size * post_size_; int total0 = size * post_size_;
cl_int status = kernel.setArg(arg_idx, *x_buf); cl_int status = kernel->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<int>(size)); status = kernel->setArg(++arg_idx, static_cast<int>(size));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, pre_size_); status = kernel->setArg(++arg_idx, pre_size_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, post_size_); status = kernel->setArg(++arg_idx, post_size_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start); status = kernel->setArg(++arg_idx, start);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total); status = kernel->setArg(++arg_idx, total);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total0); status = kernel->setArg(++arg_idx, total0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -170,25 +170,25 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -170,25 +170,25 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
if (inputs.size() == 2) { if (inputs.size() == 2) {
auto* x_buf0 = inputs[0]->data<half_t, cl::Image2D>(); auto* x_buf0 = inputs[0]->data<half_t, cl::Image2D>();
auto* x_buf1 = inputs[1]->data<half_t, cl::Image2D>(); auto* x_buf1 = inputs[1]->data<half_t, cl::Image2D>();
cl_int status = kernel.setArg(arg_idx, *x_buf0); cl_int status = kernel->setArg(arg_idx, *x_buf0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_buf1); status = kernel->setArg(++arg_idx, *x_buf1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag_); status = kernel->setArg(++arg_idx, flag_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = status =
kernel.setArg(++arg_idx, static_cast<int>(inputs[0]->dims()[axis_])); kernel->setArg(++arg_idx, static_cast<int>(inputs[0]->dims()[axis_]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_c); status = kernel->setArg(++arg_idx, out_c);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w); status = kernel->setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_); status = kernel->setArg(++arg_idx, width_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -213,25 +213,25 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -213,25 +213,25 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(image_shape["width"] / static_cast<cl::size_type>(image_shape["width"] /
in_dims[in_dims.size() - 1]), in_dims[in_dims.size() - 1]),
static_cast<cl::size_type>(image_shape["height"])}; static_cast<cl::size_type>(image_shape["height"])};
cl_int status = kernel.setArg(arg_idx, *x_buf); cl_int status = kernel->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag_); status = kernel->setArg(++arg_idx, flag_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start); status = kernel->setArg(++arg_idx, start);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_c); status = kernel->setArg(++arg_idx, out_c);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w); status = kernel->setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_w); status = kernel->setArg(++arg_idx, in_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_); status = kernel->setArg(++arg_idx, width_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -283,25 +283,25 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel, ...@@ -283,25 +283,25 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel,
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, *filter_d); status = kernel->setArg(arg_idx, *filter_d);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_d); status = kernel->setArg(++arg_idx, *x_d);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *bias_d); status = kernel->setArg(++arg_idx, *bias_d);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_d); status = kernel->setArg(++arg_idx, *output_d);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, m); status = kernel->setArg(++arg_idx, m);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, n); status = kernel->setArg(++arg_idx, n);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, k); status = kernel->setArg(++arg_idx, k);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, batch_size); status = kernel->setArg(++arg_idx, batch_size);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
local_work_size, local_work_size,
......
...@@ -71,7 +71,6 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -71,7 +71,6 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
int default_w_blk_ = 1; int default_w_blk_ = 1;
int default_nh_blk_ = 1; int default_nh_blk_ = 1;
cl::Kernel kernel_;
cl::NDRange local_work_size_ = cl::NDRange{ cl::NDRange local_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
bool use_lws_{true}; bool use_lws_{true};
......
...@@ -75,41 +75,41 @@ class DepthwiseConv2dCompute ...@@ -75,41 +75,41 @@ class DepthwiseConv2dCompute
cl_int status; cl_int status;
auto numel = output_dims.production(); auto numel = output_dims.production();
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(numel)); status = kernel->setArg(arg_idx, static_cast<const int>(numel));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_buf); status = kernel->setArg(++arg_idx, *input_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3])); status = kernel->setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(output_dims[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3])); status = kernel->setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(filter_dims[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(filter_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(filter_dims[3])); status = kernel->setArg(++arg_idx, static_cast<const int>(filter_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(strides[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(paddings[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_buf); status = kernel->setArg(++arg_idx, *output_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_buf); status = kernel->setArg(++arg_idx, *filter_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *bias_buf); status = kernel->setArg(++arg_idx, *bias_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange(static_cast<size_t>(numel)); auto global_work_size = cl::NDRange(static_cast<size_t>(numel));
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -70,13 +70,13 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -70,13 +70,13 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL),
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, *x_img); status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w); status = kernel->setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dropout_prob); status = kernel->setArg(++arg_idx, dropout_prob);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
const std::vector<size_t>& default_work_size = const std::vector<size_t>& default_work_size =
...@@ -90,7 +90,7 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -90,7 +90,7 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size.data()[2])}; static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -49,22 +49,22 @@ void ElementwiseAddCompute::Run() { ...@@ -49,22 +49,22 @@ void ElementwiseAddCompute::Run() {
VLOG(4) << TargetToStr(ele_param_->Out->target()); VLOG(4) << TargetToStr(ele_param_->Out->target());
#endif #endif
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_buf); cl_int status = kernel->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_buf); status = kernel->setArg(++arg_idx, *y_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, (const int)batch_); status = kernel->setArg(++arg_idx, (const int)batch_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, (const int)channels_); status = kernel->setArg(++arg_idx, (const int)channels_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, (const int)num_); status = kernel->setArg(++arg_idx, (const int)num_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{channels_, batch_}; auto global_work_size = cl::NDRange{channels_, batch_};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -66,7 +66,7 @@ void ElementwiseAddImageCompute::ReInitWhenNeeded() { ...@@ -66,7 +66,7 @@ void ElementwiseAddImageCompute::ReInitWhenNeeded() {
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
// compute image shape // compute image shape
paddle::lite::CLImageConverterDefault default_convertor; paddle::lite::CLImageConverterDefault default_convertor;
...@@ -90,6 +90,8 @@ void ElementwiseAddImageCompute::GetGlobalWorkSize() { ...@@ -90,6 +90,8 @@ void ElementwiseAddImageCompute::GetGlobalWorkSize() {
} }
void ElementwiseAddImageCompute::Run() { void ElementwiseAddImageCompute::Run() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* x = ele_param_->X; auto* x = ele_param_->X;
auto* y = ele_param_->Y; auto* y = ele_param_->Y;
auto* out = ele_param_->Out; auto* out = ele_param_->Out;
...@@ -118,13 +120,16 @@ void ElementwiseAddImageCompute::Run() { ...@@ -118,13 +120,16 @@ void ElementwiseAddImageCompute::Run() {
#endif #endif
cl_int status; cl_int status;
auto kernel = kernel_; std::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
if (y_dims.size() == 4) { if (y_dims.size() == 4) {
status = kernel.setArg(0, *x_img); status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1) { } else if (y_dims.size() == 1) {
if (axis == x_dims.size() - 1 || axis == x_dims.size() - 3) { if (axis == x_dims.size() - 1 || axis == x_dims.size() - 3) {
...@@ -132,13 +137,13 @@ void ElementwiseAddImageCompute::Run() { ...@@ -132,13 +137,13 @@ void ElementwiseAddImageCompute::Run() {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "tensor_w:" << tensor_w; VLOG(4) << "tensor_w:" << tensor_w;
#endif #endif
status = kernel.setArg(0, *x_img); status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis
...@@ -151,10 +156,8 @@ void ElementwiseAddImageCompute::Run() { ...@@ -151,10 +156,8 @@ void ElementwiseAddImageCompute::Run() {
<< ", y->dims.size():" << y_dims.size(); << ", y->dims.size():" << y_dims.size();
} }
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
......
...@@ -96,51 +96,51 @@ void ElementwiseMulFloatImageCompute::Run() { ...@@ -96,51 +96,51 @@ void ElementwiseMulFloatImageCompute::Run() {
auto x_dims = x->dims(); auto x_dims = x->dims();
if (y_dims == x_dims) { if (y_dims == x_dims) {
// kernel: elementwise_mul(channel_mul_d4) // kernel: elementwise_mul(channel_mul_d4)
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img); status = kernel->setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1 || y_dims.size() == 4) { } else if (y_dims.size() == 1 || y_dims.size() == 4) {
auto tensor_w = x_dims[x_dims.size() - 1]; auto tensor_w = x_dims[x_dims.size() - 1];
VLOG(4) << "tensor_w:" << tensor_w; VLOG(4) << "tensor_w:" << tensor_w;
// kernel: channel_mul_d1 / channel_mul_d4 // kernel: channel_mul_d1 / channel_mul_d4
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img); status = kernel->setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w)); status = kernel->setArg(++arg_idx, static_cast<const int>(tensor_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (y_dims.size() == 2) { } else if (y_dims.size() == 2) {
if (x_dims[0] == y_dims[0] && x_dims[1] == y_dims[1]) { if (x_dims[0] == y_dims[0] && x_dims[1] == y_dims[1]) {
auto tensor_w = x_dims[x_dims.size() - 1]; auto tensor_w = x_dims[x_dims.size() - 1];
VLOG(4) << "tensor_w:" << tensor_w; VLOG(4) << "tensor_w:" << tensor_w;
// kernel: channel_mul_d2_nc // kernel: channel_mul_d2_nc
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img); status = kernel->setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w)); status = kernel->setArg(++arg_idx, static_cast<const int>(tensor_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
auto y_tensor_h = y->dims()[0]; auto y_tensor_h = y->dims()[0];
auto y_tensor_w = y->dims()[1]; auto y_tensor_w = y->dims()[1];
VLOG(4) << "y_tensor_w:" << y_tensor_w << " y_tensor_h:" << y_tensor_h; VLOG(4) << "y_tensor_w:" << y_tensor_w << " y_tensor_h:" << y_tensor_h;
// kernel: channel_mul_d2_hw // kernel: channel_mul_d2_hw
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img); status = kernel->setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_w)); status = kernel->setArg(++arg_idx, static_cast<const int>(y_tensor_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_h)); status = kernel->setArg(++arg_idx, static_cast<const int>(y_tensor_h));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} }
} else { } else {
...@@ -151,7 +151,7 @@ void ElementwiseMulFloatImageCompute::Run() { ...@@ -151,7 +151,7 @@ void ElementwiseMulFloatImageCompute::Run() {
auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width), auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)}; static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -124,57 +124,57 @@ class ElementwiseMulImageCompute ...@@ -124,57 +124,57 @@ class ElementwiseMulImageCompute
if (bias_dims == x_dims) { if (bias_dims == x_dims) {
// kernel_func_name_ = "elementwise_mul"; // kernel_func_name_ = "elementwise_mul";
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
const int bias_dim_size = bias_dims.size(); const int bias_dim_size = bias_dims.size();
if (bias_dim_size == 1) { if (bias_dim_size == 1) {
// kernel_func_name_ = "channel_mul_d1"; // kernel_func_name_ = "channel_mul_d1";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (bias_dim_size == 2) { } else if (bias_dim_size == 2) {
// kernel_func_name_ = "channel_mul_d2"; // kernel_func_name_ = "channel_mul_d2";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (bias_dim_size == 3) { } else if (bias_dim_size == 3) {
// kernel_func_name_ = "channel_mul_d3"; // kernel_func_name_ = "channel_mul_d3";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (bias_dim_size == 4) { } else if (bias_dim_size == 4) {
// kernel_func_name_ = "channel_mul_d4"; // kernel_func_name_ = "channel_mul_d4";
const int tensor_w = x_dims[x_dims.size() - 1]; const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img); cl_int status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img); status = kernel->setArg(1, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w); status = kernel->setArg(3, tensor_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
...@@ -186,7 +186,7 @@ class ElementwiseMulImageCompute ...@@ -186,7 +186,7 @@ class ElementwiseMulImageCompute
cl::NDRange{static_cast<cl::size_type>(x_img_width), cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)}; static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -101,11 +101,11 @@ void ElementwiseSubImageCompute::Run() { ...@@ -101,11 +101,11 @@ void ElementwiseSubImageCompute::Run() {
int arg_idx = 0; int arg_idx = 0;
auto y_dims = y->dims(); auto y_dims = y->dims();
if (y_dims.size() == 4) { if (y_dims.size() == 4) {
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img); status = kernel->setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1) { } else if (y_dims.size() == 1) {
if (axis == x->dims().size() - 1 || axis == x->dims().size() - 3) { if (axis == x->dims().size() - 1 || axis == x->dims().size() - 3) {
...@@ -113,13 +113,13 @@ void ElementwiseSubImageCompute::Run() { ...@@ -113,13 +113,13 @@ void ElementwiseSubImageCompute::Run() {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "tensor_w:" << tensor_w; VLOG(4) << "tensor_w:" << tensor_w;
#endif #endif
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img); status = kernel->setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w)); status = kernel->setArg(++arg_idx, static_cast<const int>(tensor_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
} else { } else {
LOG(FATAL) << "ElementwiseSubImage doesn't support axis:" << axis LOG(FATAL) << "ElementwiseSubImage doesn't support axis:" << axis
...@@ -139,7 +139,7 @@ void ElementwiseSubImageCompute::Run() { ...@@ -139,7 +139,7 @@ void ElementwiseSubImageCompute::Run() {
#endif #endif
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -81,7 +81,7 @@ class FcCompute ...@@ -81,7 +81,7 @@ class FcCompute
time_stamp_); time_stamp_);
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
// compute global work size // compute global work size
GetGlobalWorkSize(); GetGlobalWorkSize();
...@@ -103,28 +103,30 @@ class FcCompute ...@@ -103,28 +103,30 @@ class FcCompute
auto* bias_buf = fc_param_->bias->data<float, cl::Buffer>(); auto* bias_buf = fc_param_->bias->data<float, cl::Buffer>();
auto* out_buf = auto* out_buf =
fc_param_->output->mutable_data<float, cl::Buffer>(TARGET(kOpenCL)); fc_param_->output->mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto& context = ctx_->As<OpenCLContext>();
auto kernel = kernel_; CHECK(context.cl_context() != nullptr);
std::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
;
cl_int status; cl_int status;
status = kernel.setArg(0, *x_buf); status = kernel->setArg(0, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *w_buf); status = kernel->setArg(1, *w_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *bias_buf); status = kernel->setArg(2, *bias_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, *out_buf); status = kernel->setArg(3, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(4, static_cast<const int>(m_)); status = kernel->setArg(4, static_cast<const int>(m_));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(5, static_cast<const int>(n_)); status = kernel->setArg(5, static_cast<const int>(n_));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(6, static_cast<const int>(k_)); status = kernel->setArg(6, static_cast<const int>(k_));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -143,7 +145,7 @@ class FcCompute ...@@ -143,7 +145,7 @@ class FcCompute
bool first_epoch_for_reinit_{true}; bool first_epoch_for_reinit_{true};
DDim last_x_dims_; DDim last_x_dims_;
cl::NDRange global_work_size_; cl::NDRange global_work_size_;
cl::Kernel kernel_; cl::Kernel kernel;
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
}; };
......
...@@ -48,7 +48,7 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -48,7 +48,7 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str(); VLOG(4) << "kernel_key: " << kernel_key.str();
} }
...@@ -116,22 +116,24 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -116,22 +116,24 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
#endif #endif
cl_int status; cl_int status;
auto kernel = kernel_; auto& context = ctx_->As<OpenCLContext>();
status = kernel.setArg(0, *x_img); CHECK(context.cl_context() != nullptr);
std::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *grid_img); status = kernel->setArg(1, *grid_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img); status = kernel->setArg(2, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, out_height); status = kernel->setArg(3, out_height);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(4, out_width); status = kernel->setArg(4, out_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -148,7 +150,6 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -148,7 +150,6 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
DDim out_img_shape_ = DDim(std::vector<DDim::value_type>( DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
{static_cast<DDim::value_type>(1), static_cast<DDim::value_type>(1)})); {static_cast<DDim::value_type>(1), static_cast<DDim::value_type>(1)}));
std::string kernel_func_name_{"grid_sampler"}; std::string kernel_func_name_{"grid_sampler"};
cl::Kernel kernel_;
cl::NDRange global_work_size_ = cl::NDRange{ cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
......
...@@ -120,25 +120,25 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -120,25 +120,25 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
kernel_key << kernel_func_name_ << build_options_ << time_stamp_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
cl_int status = kernel.setArg(0, out_w); cl_int status = kernel->setArg(0, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, out_h); status = kernel->setArg(1, out_h);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, out_c_group); status = kernel->setArg(2, out_c_group);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, lws1); status = kernel->setArg(3, lws1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(4, lws2); status = kernel->setArg(4, lws2);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(5, epsilon); status = kernel->setArg(5, epsilon);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(6, *x_img); status = kernel->setArg(6, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(7, *out_img); status = kernel->setArg(7, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
local_work_size, local_work_size,
...@@ -244,23 +244,23 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -244,23 +244,23 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
auto* bias_img = bias_image_.data<half_t, cl::Image2D>(); auto* bias_img = bias_image_.data<half_t, cl::Image2D>();
float epsilon = instance_norm_param_->epsilon; float epsilon = instance_norm_param_->epsilon;
cl_int status = kernel.setArg(arg_idx++, *x_img); cl_int status = kernel->setArg(arg_idx++, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *out_img); status = kernel->setArg(arg_idx++, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *scale_img); status = kernel->setArg(arg_idx++, *scale_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *bias_img); status = kernel->setArg(arg_idx++, *bias_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, epsilon); status = kernel->setArg(arg_idx++, epsilon);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_h); status = kernel->setArg(arg_idx++, in_h);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_w); status = kernel->setArg(arg_idx++, in_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
local_work_size, local_work_size,
......
...@@ -99,21 +99,21 @@ class LayoutComputeBufferChwToImageDefault ...@@ -99,21 +99,21 @@ class LayoutComputeBufferChwToImageDefault
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_data); cl_int status = kernel->setArg(arg_idx, *x_data);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_data); status = kernel->setArg(++arg_idx, *y_data);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_H)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_H));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_W)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_W));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_C)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_C));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride0)); status = kernel->setArg(++arg_idx, static_cast<const int>(Stride0));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride1)); status = kernel->setArg(++arg_idx, static_cast<const int>(Stride1));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride2)); status = kernel->setArg(++arg_idx, static_cast<const int>(Stride2));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3]
...@@ -123,7 +123,7 @@ class LayoutComputeBufferChwToImageDefault ...@@ -123,7 +123,7 @@ class LayoutComputeBufferChwToImageDefault
static_cast<cl::size_type>(new_dims[3]), static_cast<cl::size_type>(new_dims[3]),
static_cast<cl::size_type>(new_dims[0] * new_dims[2])}; static_cast<cl::size_type>(new_dims[0] * new_dims[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -205,21 +205,21 @@ class LayoutComputeImageDefaultToBufferChw ...@@ -205,21 +205,21 @@ class LayoutComputeImageDefaultToBufferChw
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_data); cl_int status = kernel->setArg(arg_idx, *x_data);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_width)); status = kernel->setArg(++arg_idx, static_cast<const int>(in_width));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_height)); status = kernel->setArg(++arg_idx, static_cast<const int>(in_height));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_data); status = kernel->setArg(++arg_idx, *y_data);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(size_ch)); status = kernel->setArg(++arg_idx, static_cast<const int>(size_ch));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(size_block)); status = kernel->setArg(++arg_idx, static_cast<const int>(size_block));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(size_batch)); status = kernel->setArg(++arg_idx, static_cast<const int>(size_batch));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(C)); status = kernel->setArg(++arg_idx, static_cast<const int>(C));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3]
...@@ -230,7 +230,7 @@ class LayoutComputeImageDefaultToBufferChw ...@@ -230,7 +230,7 @@ class LayoutComputeImageDefaultToBufferChw
static_cast<cl::size_type>(new_dims[3]), static_cast<cl::size_type>(new_dims[3]),
static_cast<cl::size_type>(new_dims[0] * new_dims[2])}; static_cast<cl::size_type>(new_dims[0] * new_dims[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
...@@ -300,21 +300,21 @@ class LayoutComputeBufferChwToImage2DNw ...@@ -300,21 +300,21 @@ class LayoutComputeBufferChwToImage2DNw
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_data); cl_int status = kernel->setArg(arg_idx, *x_data);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_data); status = kernel->setArg(++arg_idx, *y_data);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_H)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_H));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_W)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_W));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_N)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_N));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride0)); status = kernel->setArg(++arg_idx, static_cast<const int>(Stride0));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride1)); status = kernel->setArg(++arg_idx, static_cast<const int>(Stride1));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride2)); status = kernel->setArg(++arg_idx, static_cast<const int>(Stride2));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
VLOG(2) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " " VLOG(2) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " "
...@@ -324,7 +324,7 @@ class LayoutComputeBufferChwToImage2DNw ...@@ -324,7 +324,7 @@ class LayoutComputeBufferChwToImage2DNw
static_cast<cl::size_type>(out_W), // w static_cast<cl::size_type>(out_W), // w
static_cast<cl::size_type>(out_C * out_H)}; // ch static_cast<cl::size_type>(out_C * out_H)}; // ch
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -106,21 +106,21 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -106,21 +106,21 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL),
VLOG(4) << "default_work_size: " << default_work_size[0] << ", " VLOG(4) << "default_work_size: " << default_work_size[0] << ", "
<< default_work_size[1] << ", " << default_work_size[3]; << default_work_size[1] << ", " << default_work_size[3];
#endif #endif
cl_int status = kernel.setArg(arg_idx++, *x_img); cl_int status = kernel->setArg(arg_idx++, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *out_img); status = kernel->setArg(arg_idx++, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_channel); status = kernel->setArg(arg_idx++, out_channel);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_width); status = kernel->setArg(arg_idx++, out_width);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, n_); status = kernel->setArg(arg_idx++, n_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, k_); status = kernel->setArg(arg_idx++, k_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, alpha_); status = kernel->setArg(arg_idx++, alpha_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, beta_); status = kernel->setArg(arg_idx++, beta_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = auto global_work_size =
...@@ -129,7 +129,7 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -129,7 +129,7 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size[2])}; static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -76,23 +76,23 @@ class MulCompute ...@@ -76,23 +76,23 @@ class MulCompute
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, *x_buf); status = kernel->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_buf); status = kernel->setArg(++arg_idx, *y_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf); status = kernel->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, m_); status = kernel->setArg(++arg_idx, m_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, n_); status = kernel->setArg(++arg_idx, n_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, k_); status = kernel->setArg(++arg_idx, k_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{static_cast<size_t>((m_ + 3) / 4), auto global_work_size = cl::NDRange{static_cast<size_t>((m_ + 3) / 4),
static_cast<size_t>((n_ + 3) / 4)}; static_cast<size_t>((n_ + 3) / 4)};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -72,21 +72,21 @@ class NearestInterpComputeImageDefault ...@@ -72,21 +72,21 @@ class NearestInterpComputeImageDefault
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_img); cl_int status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const float>(scale_h)); status = kernel->setArg(++arg_idx, static_cast<const float>(scale_h));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const float>(scale_w)); status = kernel->setArg(++arg_idx, static_cast<const float>(scale_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims_h)); status = kernel->setArg(++arg_idx, static_cast<const int>(in_dims_h));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims_h)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_dims_h));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims_w)); status = kernel->setArg(++arg_idx, static_cast<const int>(in_dims_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims_w)); status = kernel->setArg(++arg_idx, static_cast<const int>(out_dims_w));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
...@@ -110,7 +110,7 @@ class NearestInterpComputeImageDefault ...@@ -110,7 +110,7 @@ class NearestInterpComputeImageDefault
static_cast<cl::size_type>(default_work_size.data()[1]), static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])}; static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -114,27 +114,27 @@ class Pad2dCompute : public KernelLite<TARGET(kOpenCL), ...@@ -114,27 +114,27 @@ class Pad2dCompute : public KernelLite<TARGET(kOpenCL),
int pad_w1 = pad2d_param_->paddings[3]; int pad_w1 = pad2d_param_->paddings[3];
float pad_value = pad2d_param_->pad_value; float pad_value = pad2d_param_->pad_value;
cl_int status = kernel.setArg(arg_idx++, *x_img); cl_int status = kernel->setArg(arg_idx++, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *out_img); status = kernel->setArg(arg_idx++, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_h); status = kernel->setArg(arg_idx++, in_h);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_w); status = kernel->setArg(arg_idx++, in_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_h); status = kernel->setArg(arg_idx++, out_h);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_w); status = kernel->setArg(arg_idx++, out_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_h0); status = kernel->setArg(arg_idx++, pad_h0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_h1); status = kernel->setArg(arg_idx++, pad_h1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_w0); status = kernel->setArg(arg_idx++, pad_w0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_w1); status = kernel->setArg(arg_idx++, pad_w1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_value); status = kernel->setArg(arg_idx++, pad_value);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = auto global_work_size =
...@@ -143,7 +143,7 @@ class Pad2dCompute : public KernelLite<TARGET(kOpenCL), ...@@ -143,7 +143,7 @@ class Pad2dCompute : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size[2])}; static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -76,37 +76,37 @@ class PoolCompute ...@@ -76,37 +76,37 @@ class PoolCompute
cl_int status; cl_int status;
auto numel = out_dims.production(); auto numel = out_dims.production();
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(numel)); status = kernel->setArg(arg_idx, static_cast<const int>(numel));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_buf); status = kernel->setArg(++arg_idx, *input_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(in_dims[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(in_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[3])); status = kernel->setArg(++arg_idx, static_cast<const int>(in_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(out_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[3])); status = kernel->setArg(++arg_idx, static_cast<const int>(out_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(ksize[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(ksize[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(strides[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(paddings[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_buf); status = kernel->setArg(++arg_idx, *output_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange(static_cast<size_t>(numel)); auto global_work_size = cl::NDRange(static_cast<size_t>(numel));
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -125,33 +125,33 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -125,33 +125,33 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
#endif #endif
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, *x_img); status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(in_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[3])); status = kernel->setArg(++arg_idx, static_cast<const int>(in_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(out_dims[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[3])); status = kernel->setArg(++arg_idx, static_cast<const int>(out_dims[3]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(ksize[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(ksize[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[1])); status = kernel->setArg(++arg_idx, static_cast<const int>(strides[1]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[2])); status = kernel->setArg(++arg_idx, static_cast<const int>(paddings[2]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0])); status = kernel->setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -122,31 +122,31 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL), ...@@ -122,31 +122,31 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
int arg_idx = 0; int arg_idx = 0;
cl_int status; cl_int status;
status = kernel.setArg(arg_idx, *x_image); status = kernel->setArg(arg_idx, *x_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_image); status = kernel->setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_C); status = kernel->setArg(++arg_idx, out_C);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_H); status = kernel->setArg(++arg_idx, out_H);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_W); status = kernel->setArg(++arg_idx, out_W);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_W); status = kernel->setArg(++arg_idx, in_W);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_H); status = kernel->setArg(++arg_idx, in_H);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride0); status = kernel->setArg(++arg_idx, in_Stride0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride1); status = kernel->setArg(++arg_idx, in_Stride1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride2); status = kernel->setArg(++arg_idx, in_Stride2);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride0); status = kernel->setArg(++arg_idx, out_Stride0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride1); status = kernel->setArg(++arg_idx, out_Stride1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride2); status = kernel->setArg(++arg_idx, out_Stride2);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = auto global_work_size =
...@@ -155,7 +155,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL), ...@@ -155,7 +155,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
static_cast<size_t>(default_work_size.data()[2])}; static_cast<size_t>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
...@@ -45,7 +45,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -45,7 +45,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
} }
void ReInitWhenNeeded() override { void ReInitWhenNeeded() override {
...@@ -82,19 +82,22 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -82,19 +82,22 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
auto kernel = kernel_; std::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
;
cl_int status; cl_int status;
status = kernel.setArg(0, *x_img); status = kernel->setArg(0, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(1, *out_img); status = kernel->setArg(1, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(2, scale); status = kernel->setArg(2, scale);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(3, bias); status = kernel->setArg(3, bias);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
...@@ -111,7 +114,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -111,7 +114,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
param_t* scale_param_{nullptr}; param_t* scale_param_{nullptr};
cl::Kernel kernel_; cl::Kernel kernel;
bool first_epoch_for_reinit_{true}; bool first_epoch_for_reinit_{true};
DDim last_x_dims_; DDim last_x_dims_;
DDim out_img_shape_ = DDim(std::vector<DDim::value_type>( DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
......
...@@ -75,15 +75,15 @@ class SliceComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -75,15 +75,15 @@ class SliceComputeImage2D : public KernelLite<TARGET(kOpenCL),
cl_int status; cl_int status;
int arg_idx = 0; int arg_idx = 0;
status = kernel.setArg(arg_idx, *x_img); status = kernel->setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img); status = kernel->setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start); status = kernel->setArg(++arg_idx, start);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, end); status = kernel->setArg(++arg_idx, end);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dim_w); status = kernel->setArg(++arg_idx, dim_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
const std::vector<size_t>& default_work_size = const std::vector<size_t>& default_work_size =
...@@ -97,7 +97,7 @@ class SliceComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -97,7 +97,7 @@ class SliceComputeImage2D : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size.data()[2])}; static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, *kernel.get(),
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册