提交 6262bc32 编写于 作者: 开心的小妮's avatar 开心的小妮

[LITE][OPENCL] Make comptue kernel hold cl::kernel. test=develop

上级 293d2d38
......@@ -77,14 +77,10 @@ class ActivationComputeImageDefault
#endif
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/activation_kernel.cl",
build_options_,
time_stamp_);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/activation_kernel.cl",
build_options_,
time_stamp_);
}
void ReInitWhenNeeded() override {
......@@ -118,15 +114,14 @@ class ActivationComputeImageDefault
auto* out_img = act_param_->Out->mutable_data<half_t, cl::Image2D>(
out_img_shape_[0], out_img_shape_[1]);
auto kernel = kernel_;
cl_int status;
status = kernel.setArg(0, *x_img);
status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *out_img);
status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, threshold_);
status = kernel_->setArg(2, threshold_);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, scale_);
status = kernel_->setArg(3, scale_);
CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG
......@@ -148,7 +143,7 @@ class ActivationComputeImageDefault
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size_,
cl::NullRange,
......@@ -168,7 +163,7 @@ class ActivationComputeImageDefault
std::string kernel_func_name_{};
float threshold_{6.f};
float scale_{1.f};
cl::Kernel kernel_;
std::shared_ptr<cl::Kernel> kernel_;
bool first_epoch_for_reinit_{true};
cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
......
......@@ -40,10 +40,10 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
kernel_func_name_ = "concat_mul";
}
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel(kernel_func_name_,
"image/concat_kernel.cl",
build_options_,
time_stamp_);
kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/concat_kernel.cl",
build_options_,
time_stamp_);
auto axis = concat_param_->axis;
auto inputs = concat_param_->x;
......@@ -118,8 +118,6 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto inputs = param.x;
int arg_idx = 0;
......@@ -164,31 +162,30 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
<< (image_shape["height"]);
#endif
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int out_w = x_dims[x_dims.size() - 1];
int out_c = x_dims[1];
if (inputs.size() == 2) {
auto* x_buf0 = inputs[0]->data<half_t, cl::Image2D>();
auto* x_buf1 = inputs[1]->data<half_t, cl::Image2D>();
cl_int status = kernel.setArg(arg_idx, *x_buf0);
cl_int status = kernel_->setArg(arg_idx, *x_buf0);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_buf1);
status = kernel_->setArg(++arg_idx, *x_buf1);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
status = kernel_->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag_);
status = kernel_->setArg(++arg_idx, flag_);
CL_CHECK_FATAL(status);
status =
kernel.setArg(++arg_idx, static_cast<int>(inputs[0]->dims()[axis_]));
status = kernel_->setArg(++arg_idx,
static_cast<int>(inputs[0]->dims()[axis_]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_c);
status = kernel_->setArg(++arg_idx, out_c);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w);
status = kernel_->setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_);
status = kernel_->setArg(++arg_idx, width_);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size,
cl::NullRange,
......@@ -213,25 +210,25 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(image_shape["width"] /
in_dims[in_dims.size() - 1]),
static_cast<cl::size_type>(image_shape["height"])};
cl_int status = kernel.setArg(arg_idx, *x_buf);
cl_int status = kernel_->setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
status = kernel_->setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag_);
status = kernel_->setArg(++arg_idx, flag_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start);
status = kernel_->setArg(++arg_idx, start);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_c);
status = kernel_->setArg(++arg_idx, out_c);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w);
status = kernel_->setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_w);
status = kernel_->setArg(++arg_idx, in_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_);
status = kernel_->setArg(++arg_idx, width_);
CL_CHECK_FATAL(status);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size,
cl::NullRange,
......@@ -255,6 +252,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
std::string build_options_{" -DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
};
} // namespace opencl
......
......@@ -71,7 +71,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
int default_w_blk_ = 1;
int default_nh_blk_ = 1;
cl::Kernel kernel_;
std::shared_ptr<cl::Kernel> kernel_;
cl::NDRange local_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
bool use_lws_{true};
......
......@@ -59,14 +59,11 @@ void ElementwiseAddImageCompute::ReInitWhenNeeded() {
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/elementwise_add_kernel.cl",
build_options_,
time_stamp_);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
kernel_ =
context.cl_context()->CreateKernel(kernel_func_name_,
"image/elementwise_add_kernel.cl",
build_options_,
time_stamp_);
// compute image shape
paddle::lite::CLImageConverterDefault default_convertor;
......@@ -118,13 +115,12 @@ void ElementwiseAddImageCompute::Run() {
#endif
cl_int status;
auto kernel = kernel_;
if (y_dims.size() == 4) {
status = kernel.setArg(0, *x_img);
status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1) {
if (axis == x_dims.size() - 1 || axis == x_dims.size() - 3) {
......@@ -132,13 +128,13 @@ void ElementwiseAddImageCompute::Run() {
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "tensor_w:" << tensor_w;
#endif
status = kernel.setArg(0, *x_img);
status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w);
status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else {
LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis
......@@ -154,7 +150,7 @@ void ElementwiseAddImageCompute::Run() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size_,
cl::NullRange,
......
......@@ -60,7 +60,7 @@ class ElementwiseAddImageCompute
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
bool first_epoch_for_reinit_{true};
cl::Kernel kernel_;
std::shared_ptr<cl::Kernel> kernel_;
cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::shared_ptr<cl::Event> event_{new cl::Event};
......
......@@ -71,10 +71,11 @@ class ElementwiseMulImageCompute
VLOG(4) << "bias_dims.size():" << bias_dims.size();
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/elementwise_mul_kernel.cl",
build_options_,
time_stamp_);
kernel_ =
context.cl_context()->CreateKernel(kernel_func_name_,
"image/elementwise_mul_kernel.cl",
build_options_,
time_stamp_);
}
void Run() override {
......@@ -115,66 +116,61 @@ class ElementwiseMulImageCompute
<< out_img_shape[1];
#endif
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
auto bias_dims = y->dims();
auto x_dims = x->dims();
if (bias_dims == x_dims) {
// kernel_func_name_ = "elementwise_mul";
cl_int status = kernel.setArg(0, *x_img);
cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status);
} else {
const int bias_dim_size = bias_dims.size();
if (bias_dim_size == 1) {
// kernel_func_name_ = "channel_mul_d1";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w);
status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else if (bias_dim_size == 2) {
// kernel_func_name_ = "channel_mul_d2";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w);
status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else if (bias_dim_size == 3) {
// kernel_func_name_ = "channel_mul_d3";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w);
status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else if (bias_dim_size == 4) {
// kernel_func_name_ = "channel_mul_d4";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
cl_int status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
status = kernel_->setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
status = kernel_->setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w);
status = kernel_->setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else {
LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
......@@ -186,7 +182,7 @@ class ElementwiseMulImageCompute
cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size,
cl::NullRange,
......@@ -205,6 +201,7 @@ class ElementwiseMulImageCompute
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
};
} // namespace opencl
......
......@@ -75,13 +75,10 @@ class FcCompute
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"buffer/fc_kernel.cl",
build_options_,
time_stamp_);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"buffer/fc_kernel.cl",
build_options_,
time_stamp_);
// compute global work size
GetGlobalWorkSize();
......@@ -106,25 +103,25 @@ class FcCompute
auto kernel = kernel_;
cl_int status;
status = kernel.setArg(0, *x_buf);
status = kernel_->setArg(0, *x_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *w_buf);
status = kernel_->setArg(1, *w_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *bias_buf);
status = kernel_->setArg(2, *bias_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, *out_buf);
status = kernel_->setArg(3, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(4, static_cast<const int>(m_));
status = kernel_->setArg(4, static_cast<const int>(m_));
CL_CHECK_FATAL(status);
status = kernel.setArg(5, static_cast<const int>(n_));
status = kernel_->setArg(5, static_cast<const int>(n_));
CL_CHECK_FATAL(status);
status = kernel.setArg(6, static_cast<const int>(k_));
status = kernel_->setArg(6, static_cast<const int>(k_));
CL_CHECK_FATAL(status);
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel.get()),
cl::NullRange,
global_work_size_,
cl::NullRange,
......@@ -143,7 +140,7 @@ class FcCompute
bool first_epoch_for_reinit_{true};
DDim last_x_dims_;
cl::NDRange global_work_size_;
cl::Kernel kernel_;
std::shared_ptr<cl::Kernel> kernel_;
std::shared_ptr<cl::Event> event_{new cl::Event};
};
......
......@@ -31,10 +31,11 @@ class FusionElementwiseAddActivationImageCompute
void PrepareForRun() override {
build_options_ += " -DRELU";
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/elementwise_add_kernel.cl",
build_options_,
time_stamp_);
kernel_ =
context.cl_context()->CreateKernel(kernel_func_name_,
"image/elementwise_add_kernel.cl",
build_options_,
time_stamp_);
ele_param_ = param_.get_mutable<param_t>();
auto act_t = static_cast<param_t*>(ele_param_)->act_type;
VLOG(4) << "act: " << act_t;
......
......@@ -38,10 +38,11 @@ class NearestInterpComputeImageDefault
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/nearest_interp_kernel.cl",
build_options_,
time_stamp_);
kernel_ =
context.cl_context()->CreateKernel(kernel_func_name_,
"image/nearest_interp_kernel.cl",
build_options_,
time_stamp_);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
}
......@@ -67,26 +68,23 @@ class NearestInterpComputeImageDefault
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_img);
cl_int status;
status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const float>(scale_h));
status = kernel_->setArg(2, static_cast<const float>(scale_h));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const float>(scale_w));
status = kernel_->setArg(3, static_cast<const float>(scale_w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims_h));
status = kernel_->setArg(4, static_cast<const int>(in_dims_h));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims_h));
status = kernel_->setArg(5, static_cast<const int>(out_dims_h));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims_w));
status = kernel_->setArg(6, static_cast<const int>(in_dims_w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims_w));
status = kernel_->setArg(7, static_cast<const int>(out_dims_w));
CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG
......@@ -110,7 +108,7 @@ class NearestInterpComputeImageDefault
static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size,
cl::NullRange,
......@@ -125,6 +123,7 @@ class NearestInterpComputeImageDefault
std::string build_options_{" -DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
};
} // namespace opencl
......
......@@ -46,7 +46,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
}
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_ = context.cl_context()->CreateKernel(
kernel_func_name_, "image/pool_kernel.cl", build_options_, time_stamp_);
}
......@@ -111,10 +111,6 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
out_image_shape["width"], out_image_shape["height"]);
// VLOG(4) << "out_image" << out_img;
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (out_dims[1] + 3) / 4;
int w = out_dims[3];
int nh = out_dims[0] * out_dims[2];
......@@ -124,34 +120,33 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
<< " " << nh << " ";
#endif
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, *x_img);
status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[2]));
status = kernel_->setArg(2, static_cast<const int>(in_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(in_dims[3]));
status = kernel_->setArg(3, static_cast<const int>(in_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[2]));
status = kernel_->setArg(4, static_cast<const int>(out_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(out_dims[3]));
status = kernel_->setArg(5, static_cast<const int>(out_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[0]));
status = kernel_->setArg(6, static_cast<const int>(ksize[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(ksize[1]));
status = kernel_->setArg(7, static_cast<const int>(ksize[1]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
status = kernel_->setArg(8, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[1]));
status = kernel_->setArg(9, static_cast<const int>(strides[1]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[2]));
status = kernel_->setArg(10, static_cast<const int>(paddings[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
status = kernel_->setArg(11, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size,
cl::NullRange,
......@@ -162,6 +157,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
}
private:
std::shared_ptr<cl::Kernel> kernel_;
std::string kernel_func_name_{"pool_"};
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
......
......@@ -36,10 +36,10 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel(kernel_func_name_,
"image/reshape_kernel.cl",
build_options_,
time_stamp_);
kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/reshape_kernel.cl",
build_options_,
time_stamp_);
}
void Run() override {
......@@ -111,42 +111,38 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << TargetToStr(x->target());
VLOG(4) << TargetToStr(param.output->target());
#endif
int arg_idx = 0;
cl_int status;
status = kernel.setArg(arg_idx, *x_image);
status = kernel_->setArg(0, *x_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_image);
status = kernel_->setArg(1, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_C);
status = kernel_->setArg(2, out_C);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_H);
status = kernel_->setArg(3, out_H);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_W);
status = kernel_->setArg(4, out_W);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_W);
status = kernel_->setArg(5, in_W);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_H);
status = kernel_->setArg(6, in_H);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride0);
status = kernel_->setArg(7, in_Stride0);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride1);
status = kernel_->setArg(8, in_Stride1);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, in_Stride2);
status = kernel_->setArg(9, in_Stride2);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride0);
status = kernel_->setArg(10, out_Stride0);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride1);
status = kernel_->setArg(11, out_Stride1);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_Stride2);
status = kernel_->setArg(12, out_Stride2);
CL_CHECK_FATAL(status);
auto global_work_size =
......@@ -155,7 +151,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
static_cast<size_t>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size,
cl::NullRange,
......@@ -170,6 +166,7 @@ class ReshapeComputeFloatImage : public KernelLite<TARGET(kOpenCL),
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Kernel> kernel_;
};
} // namespace opencl
......
......@@ -37,15 +37,11 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/scale_kernel.cl",
build_options_,
time_stamp_);
kernel_ = context.cl_context()->CreateKernel(kernel_func_name_,
"image/scale_kernel.cl",
build_options_,
time_stamp_);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
}
void ReInitWhenNeeded() override {
......@@ -82,19 +78,18 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto kernel = kernel_;
cl_int status;
status = kernel.setArg(0, *x_img);
status = kernel_->setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *out_img);
status = kernel_->setArg(1, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, scale);
status = kernel_->setArg(2, scale);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, bias);
status = kernel_->setArg(3, bias);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
*(kernel_.get()),
cl::NullRange,
global_work_size_,
cl::NullRange,
......@@ -111,7 +106,7 @@ class ScaleComputeImage2D : public KernelLite<TARGET(kOpenCL),
std::shared_ptr<cl::Event> event_{new cl::Event};
param_t* scale_param_{nullptr};
cl::Kernel kernel_;
std::shared_ptr<cl::Kernel> kernel_;
bool first_epoch_for_reinit_{true};
DDim last_x_dims_;
DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册