From c0a3506f41edd171ee80a3b14b6f7e306a13d081 Mon Sep 17 00:00:00 2001 From: Yuan Shuai Date: Sun, 9 Feb 2020 22:02:24 -0600 Subject: [PATCH] [LITE][OPENCL] Add 3 kernels of ElementwiseAdd/FusionElemenwiseAddAct op with opencl image format (#2844) * [LITE][OPENCL] Add 3 kernels of ElementwiseAdd/FuseElementwiseAdd op. test=develop --- .../cl_kernel/image/elementwise_add_kernel.cl | 67 ++++- .../kernels/opencl/elementwise_add_compute.cc | 157 ++++++++++- lite/kernels/opencl/elementwise_add_compute.h | 26 ++ .../opencl/elementwise_add_compute_test.cc | 246 ++++++++++++++++-- ...sion_elementwise_add_activation_compute.cc | 54 +++- 5 files changed, 513 insertions(+), 37 deletions(-) diff --git a/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl b/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl index a95c6c6897..0d8867e6a7 100644 --- a/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl @@ -14,15 +14,72 @@ limitations under the License. */ #include -__kernel void elementwise_add(__read_only image2d_t input, __read_only image2d_t bias, __write_only image2d_t outputImage) { +__kernel void elementwise_add(__read_only image2d_t input, + __read_only image2d_t bias, + __write_only image2d_t outputImage) { int x = get_global_id(0); int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 coords; + coords.x = x; + coords.y = y; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords); + CL_DTYPE4 output = activation_type4(in + biase); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage,coords,output); + } + +__kernel void channel_add(__read_only image2d_t input, + __read_only image2d_t bias, + __write_only image2d_t outputImage, + int w) { + int x = get_global_id(0); + int y = get_global_id(1); + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; int2 coords; coords.x = x; coords.y = y; - float4 in = read_imagef(input, sampler, coords); - float4 biase = read_imagef(bias, sampler, coords); - float4 output = in + biase; - write_imagef(outputImage,coords,output); + + int2 coords_bias; + coords_bias.x = x % w; + coords_bias.y = 0; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias); + CL_DTYPE4 output = in + (CL_DTYPE4)(biase.x); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); } + +__kernel void width_add(__read_only image2d_t input, + __read_only image2d_t bias, + __write_only image2d_t outputImage, + int w) { + int x = get_global_id(0); + int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int2 coords; + coords.x = x; + coords.y = y; + + int2 coords_bias; + coords_bias.x = x % w; + coords_bias.y = 0; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias); + CL_DTYPE4 output; + + output.x = in.x + biase.x; + output.y = in.y + biase.x; + output.z = in.z + biase.x; + output.w = in.w + biase.x; + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); +} diff --git a/lite/kernels/opencl/elementwise_add_compute.cc b/lite/kernels/opencl/elementwise_add_compute.cc index ad831010f8..72838b7c49 100644 --- a/lite/kernels/opencl/elementwise_add_compute.cc +++ b/lite/kernels/opencl/elementwise_add_compute.cc @@ -23,6 +23,8 @@ namespace lite { namespace kernels { namespace opencl { +/* Buffer */ +#if 0 void ElementwiseAddCompute::PrepareForRun() { auto& context = ctx_->As(); context.cl_context()->AddKernel( @@ -92,6 +94,124 @@ void ElementwiseAddCompute::UpdateParams() { VLOG(4) << "channels: " << channels_; VLOG(4) << "num: " << num_; } +#endif + +/* Image2D */ +void ElementwiseAddImageCompute::PrepareForRun() { + ele_param_ = param_.get_mutable(); + auto* x = ele_param_->X; + auto* y = ele_param_->Y; + auto axis = ele_param_->axis; + + if (y->dims().size() == 4) { + kernel_func_name_ = "elementwise_add"; // y: ImageDefault + } else if (y->dims().size() == 1) { + if (axis == x->dims().size() - 1) { + kernel_func_name_ = "width_add"; // y: ImageDefault + } else if (axis == x->dims().size() - 3) { + kernel_func_name_ = "channel_add"; // y: ImageFolder + } else { + LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis + << ", x->dims().size():" << x->dims().size() + << ", y->dims.size():" << y->dims().size(); + } + } else { + LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis + << ", x->dims().size():" << x->dims().size() + << ", y->dims.size():" << y->dims().size(); + } + VLOG(4) << "kernel_func_name_:" << kernel_func_name_; + + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/elementwise_add_kernel.cl", build_options_); +} + +void ElementwiseAddImageCompute::Run() { + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto* x = ele_param_->X; + auto* y = ele_param_->Y; + auto* out = ele_param_->Out; + auto axis = ele_param_->axis; + + VLOG(4) << "x->target():" << TargetToStr(x->target()); + VLOG(4) << "y->target():" << TargetToStr(y->target()); + VLOG(4) << "out->target():" << TargetToStr(out->target()); + VLOG(4) << "x->dims():" << x->dims(); + VLOG(4) << "y->dims():" << y->dims(); + VLOG(4) << "out->dims():" << out->dims(); + VLOG(4) << "axis:" << axis; + + paddle::lite::CLImageConverterDefault default_convertor; + auto x_img_shape = default_convertor.InitImageDimInfoWith(x->dims()); // w, h + auto x_img_width = x_img_shape[0]; + auto x_img_height = x_img_shape[1]; + auto out_img_shape = + default_convertor.InitImageDimInfoWith(out->dims()); // w, h + auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims()); + + auto* x_img = x->data(); + auto* y_img = y->data(); + auto* out_img = + out->mutable_data(out_img_shape[0], out_img_shape[1]); + + VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height; + VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1]; + VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " " + << out_img_shape[1]; + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + auto y_dims = y->dims(); + if (y_dims.size() == 4) { + cl_int status = kernel.setArg(arg_idx, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *y_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_img); + CL_CHECK_FATAL(status); + } else if (y_dims.size() == 1) { + if (axis == x->dims().size() - 1 || axis == x->dims().size() - 3) { + int tensor_w = x->dims()[x->dims().size() - 1]; + VLOG(4) << "tensor_w:" << tensor_w; + + cl_int status = kernel.setArg(arg_idx, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *y_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(tensor_w)); + CL_CHECK_FATAL(status); + } else { + LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis + << ", x->dims().size():" << x->dims().size() + << ", y->dims.size():" << y->dims().size(); + } + } else { + LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis + << ", x->dims().size():" << x->dims().size() + << ", y->dims.size():" << y->dims().size(); + } + + auto global_work_size = cl::NDRange{static_cast(x_img_width), + static_cast(x_img_height)}; + VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height; + auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_img, event_); +} } // namespace opencl } // namespace kernels @@ -99,9 +219,36 @@ void ElementwiseAddCompute::UpdateParams() { } // namespace paddle namespace ocl = paddle::lite::kernels::opencl; -REGISTER_LITE_KERNEL( - elementwise_add, kOpenCL, kFloat, kNCHW, ocl::ElementwiseAddCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + +// REGISTER_LITE_KERNEL( +// elementwise_add, kOpenCL, kFloat, kNCHW, ocl::ElementwiseAddCompute, def) +// .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindInput("Y", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .Finalize(); + +// TODO(ysh329): Not fix. +// "Y" may from constant value like conv bias (kARM, need do cl_image_converter +// on CPU); +// may from anther branch like "X" (kOpenCL, nothing to do). +// Consider 2 situations have different actions when pass running(pick kernel), +// set target of "Y" as kOpenCL temporarily. +REGISTER_LITE_KERNEL(elementwise_add, + kOpenCL, + kFloat, + kImageDefault, + ocl::ElementwiseAddImageCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("Y", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) .Finalize(); diff --git a/lite/kernels/opencl/elementwise_add_compute.h b/lite/kernels/opencl/elementwise_add_compute.h index bd0398ca3f..efc7f58f44 100644 --- a/lite/kernels/opencl/elementwise_add_compute.h +++ b/lite/kernels/opencl/elementwise_add_compute.h @@ -33,6 +33,10 @@ class ElementwiseAddCompute void Run() override; + std::string doc() const override { + return "ElementwiseAdd using cl::Buffer, kFloat"; + } + protected: void UpdateParams(); @@ -45,6 +49,28 @@ class ElementwiseAddCompute std::shared_ptr event_{new cl::Event}; }; +class ElementwiseAddImageCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseParam; + + void PrepareForRun() override; + + void Run() override; + + std::string doc() const override { + return "ElementwiseAdd using cl::Image2D, kFloat"; + } + + protected: + param_t* ele_param_{nullptr}; + std::string kernel_func_name_{"elementwise_add"}; + std::string build_options_{" -DCL_DTYPE_float"}; + std::shared_ptr event_{new cl::Event}; +}; + } // namespace opencl } // namespace kernels } // namespace lite diff --git a/lite/kernels/opencl/elementwise_add_compute_test.cc b/lite/kernels/opencl/elementwise_add_compute_test.cc index 69df2313bb..06f946bca7 100644 --- a/lite/kernels/opencl/elementwise_add_compute_test.cc +++ b/lite/kernels/opencl/elementwise_add_compute_test.cc @@ -22,6 +22,19 @@ namespace paddle { namespace lite { +template +void fill_data(dtype *x, const int length, int set_value = -1) { + if (set_value == -1) { + for (size_t idx = 0; idx < length; ++idx) { + x[idx] = idx; + } + } else if (set_value != -1) { + for (size_t idx = 0; idx < length; ++idx) { + x[idx] = set_value; + } + } +} + template void elementwise_compute_ref(const dtype *x_data, const dtype *y_data, @@ -46,25 +59,17 @@ void elementwise_compute_ref(const dtype *x_data, for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { num *= x_dims[i]; } + VLOG(4) << "axis:" << axis; + VLOG(4) << "batch:" << batch; + VLOG(4) << "cahnnels:" << channels; + VLOG(4) << "num:" << num; // do elementwise add/sub/max/... - if (elt_type == "add") { - for (int i = 0; i < batch; ++i) { - for (int j = 0; j < channels; ++j) { - int offset = (i * channels + j) * num; - const dtype *din_ptr = x_data + offset; - const dtype diny_data = y_data[j]; - dtype *dout_ptr = out_data + offset; - for (int k = 0; k < num; ++k) { - *dout_ptr = *din_ptr + diny_data; - if (use_relu) { - *dout_ptr = std::max(*dout_ptr, static_cast(0)); - } - dout_ptr++; - din_ptr++; - } - } + if (elt_type == "add" && axis == 1 && y_dims.size() == 1) { + for (int i = 0; i < x_dims.production(); ++i) { + auto w = i % y_dims.production(); + out_data[i] = x_data[i] + y_data[w]; } - } else if (elt_type == "sub") { + } else if (elt_type == "add") { for (int i = 0; i < batch; ++i) { for (int j = 0; j < channels; ++j) { int offset = (i * channels + j) * num; @@ -72,7 +77,7 @@ void elementwise_compute_ref(const dtype *x_data, const dtype diny_data = y_data[j]; dtype *dout_ptr = out_data + offset; for (int k = 0; k < num; ++k) { - *dout_ptr = *din_ptr - diny_data; + *dout_ptr = *din_ptr + diny_data; if (use_relu) { *dout_ptr = std::max(*dout_ptr, static_cast(0)); } @@ -86,7 +91,9 @@ void elementwise_compute_ref(const dtype *x_data, } } -TEST(elementwise_add, compute) { +// buffer +#if 0 +TEST(elementwise_add_buffer, compute) { LOG(INFO) << "to get kernel ..."; auto kernels = KernelRegistry::Global().Create( "elementwise_add", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)); @@ -163,7 +170,7 @@ TEST(elementwise_add, compute) { TargetWrapperCL::Unmap(out_data, mapped_out); } -TEST(fusion_elementwise_add_activation, compute) { +TEST(fusion_elementwise_add_activation_buffer, compute) { LOG(INFO) << "to get kernel ..."; auto kernels = KernelRegistry::Global().Create("fusion_elementwise_add_activation", @@ -243,9 +250,204 @@ TEST(fusion_elementwise_add_activation, compute) { } TargetWrapperCL::Unmap(out_data, mapped_out); } +#endif + +// image +TEST(elementwise_add_image2d_fp32, compute) { + LOG(INFO) << "main steps of test: host -> layout(buf2img on cpu) -> " + "elementwise_add(img) -> " + "layout(img2buf on cpu) " + "-> host"; + + // elementwise_add's 3 kernels selection routing strategy: + // -------------------------------------------------------- + // 1. elementwise_add: Need y_dim.size() == 4 + // 2. elementwise_add (used by fuse_elementwise_activation op): + // Need y_dim.size() == 4 && act_type == "relu" + // 3. width_add: Need y_dim.size() == 1 && x_dim.size() == 4 && axis == + // 3 + // 4. channel_add: Need y_dim.size() == 1 && x_dim.size() == 4 && axis == + // 1 + + // dims + const int n = 1; + const int c = 3; + const int h = 2; + const int w = 2; + + const DDim x_dim = DDim(std::vector{n, c, h, w}); + auto out_dim = x_dim; + // y_dim / axis / relu_flag + std::vector y_dim_v{DDim(std::vector{n, c, h, w}), + DDim(std::vector{n, c, h, w}), + DDim(std::vector{w}), + DDim(std::vector{w})}; + std::vector axis_v{-1, -1, 3, 1}; + std::vector relu_flag_v{false, true, false, false}; + CHECK(y_dim_v.size() == axis_v.size() && axis_v.size() == relu_flag_v.size()) + << "y_dim_v.size() == axis_v.size() == relu_flag_v.size() should be " + "same, and be corresponding " + "one by one"; + + // start loop + for (size_t case_idx = 0; case_idx < y_dim_v.size(); ++case_idx) { + auto y_dim = y_dim_v[case_idx]; + auto axis = axis_v[case_idx]; + auto relu_flag = relu_flag_v[case_idx]; + LOG(INFO) << "================== elementwise_add, case_idx:" << case_idx + 1 + << "/" << y_dim_v.size() << " ==================="; + LOG(INFO) << "x_dim:" << x_dim; + LOG(INFO) << "y_dim:" << y_dim; + LOG(INFO) << "out_dim:" << out_dim; + LOG(INFO) << "axis:" << axis; + LOG(INFO) << "relu_flag:" << relu_flag; + + // tensor + VLOG(4) << "set tensors about op param"; + lite::Tensor eleadd_x, eleadd_y, eleadd_out; + eleadd_x.Resize(x_dim); + eleadd_y.Resize(y_dim); + eleadd_out.Resize(out_dim); + + // initialize tensors + VLOG(4) << "initialize tensors"; + paddle::lite::CLImageConverterDefault default_convertor; + // x + std::vector x_v(x_dim.production()); + fill_data(x_v.data(), x_v.size()); // fill with index value + auto x_img_shape = default_convertor.InitImageDimInfoWith(x_dim); // w, h + auto x_img_w = x_img_shape[0]; + auto x_img_h = x_img_shape[1]; + std::vector x_img_v(x_img_w * x_img_h * 4); // 4: RGBA + default_convertor.NCHWToImage(x_v.data(), x_img_v.data(), x_dim); + eleadd_x.mutable_data(x_img_w, x_img_h, x_img_v.data()); + + // y + std::vector y_v(y_dim.production()); + fill_data(y_v.data(), y_v.size()); // fill with index value + auto y_img_shape = default_convertor.InitImageDimInfoWith(y_dim); // w, h + auto y_img_w = y_img_shape[0]; + auto y_img_h = y_img_shape[1]; + std::vector y_img_v(y_img_shape[0] * y_img_shape[1] * 4); // 4: RGBA + default_convertor.NCHWToImage(y_v.data(), y_img_v.data(), y_dim); + eleadd_y.mutable_data(y_img_w, y_img_h, y_img_v.data()); + + // out + auto out_img_shape = + default_convertor.InitImageDimInfoWith(out_dim); // w, h + auto out_img_w = out_img_shape[0]; + auto out_img_h = out_img_shape[1]; + eleadd_out.mutable_data(out_img_w, out_img_h); + + std::vector out_img_v(out_img_w * out_img_h * 4); + fill_data( + out_img_v.data(), out_img_v.size(), 0); // fill with zero value + + std::vector out_v(out_dim.production()); + + // operator param + operators::FusionElementwiseActivationParam + fuseEleaddParam; // enabled if relu_flag is true + fuseEleaddParam.X = &eleadd_x; + fuseEleaddParam.Y = &eleadd_y; + fuseEleaddParam.Out = &eleadd_out; + fuseEleaddParam.axis = axis; + fuseEleaddParam.act_type = relu_flag ? "relu" : ""; + + operators::ElementwiseParam eleaddParam; + eleaddParam.X = &eleadd_x; + eleaddParam.Y = &eleadd_y; + eleaddParam.Out = &eleadd_out; + eleaddParam.axis = axis; + + auto op_param = relu_flag ? fuseEleaddParam : eleaddParam; + + // set kernel + auto eleadd_img_kernels = + KernelRegistry::Global().Create("elementwise_add", + TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(eleadd_img_kernels.empty()); + + auto eleadd_img_kernel = std::move(eleadd_img_kernels.front()); + VLOG(4) << "get eleadd kernel: " << eleadd_img_kernel->doc(); + + // set context and kernel args + VLOG(4) << "set context and kernel args"; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + eleadd_img_kernel->SetParam(op_param); + std::unique_ptr eleadd_img_context(new KernelContext); + context->As().CopySharedTo( + &(eleadd_img_context->As())); + eleadd_img_kernel->SetContext(std::move(eleadd_img_context)); + + // run kernel + VLOG(4) << "run kernel"; + eleadd_img_kernel->Launch(); + + // download gpu result to cpu + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + TargetWrapperCL::ImgcpySync(out_img_v.data(), + eleadd_out.data(), + out_img_w, + out_img_h, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + default_convertor.ImageToNCHW( + out_img_v.data(), out_v.data(), out_img_shape, out_dim); + + // compute cpu reference + std::unique_ptr out_ref(new float[out_dim.production()]); + elementwise_compute_ref(x_v.data(), + y_v.data(), + out_ref.get(), + x_dim, + y_dim, + op_param.axis, + "add", + relu_flag); + +#if 0 // enable to check value of x and y + for (int eidx = 0; eidx < out_dim.production(); eidx++) { + auto value = out_v[eidx]; + auto ref_value = out_ref.get()[eidx]; + LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / " + << out_dim.production() << ", x_v[" << eidx << "]:" + << x_v[eidx] << ", value[" << eidx << "]:" << value + << ", ref_value[" << eidx << "]:" << ref_value; + } + + for (int i = 0; i < y_v.size(); i++) { + LOG(INFO) << "y_v[" << i << "]:" << y_v[i]; + } +#endif + + for (int eidx = 0; eidx < out_dim.production(); eidx++) { + auto value = out_v[eidx]; + auto ref_value = out_ref.get()[eidx]; + EXPECT_NEAR(value, ref_value, 1e-6); + if (abs(value - ref_value) > 1e-6) { + LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / " + << out_dim.production() << ", value[" << eidx << "]:" << value + << ", ref_value[" << eidx << "]:" << ref_value; + break; + } + } + } +} } // namespace lite } // namespace paddle -USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(fusion_elementwise_add_activation, kOpenCL, kFloat, kNCHW, def); +// USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW, def); +// USE_LITE_KERNEL(fusion_elementwise_add_activation, kOpenCL, kFloat, kNCHW, +// def); + +USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kImageDefault, def); +USE_LITE_KERNEL( + fusion_elementwise_add_activation, kOpenCL, kFloat, kImageDefault, def); diff --git a/lite/kernels/opencl/fusion_elementwise_add_activation_compute.cc b/lite/kernels/opencl/fusion_elementwise_add_activation_compute.cc index ad17575d69..c6e1510efe 100644 --- a/lite/kernels/opencl/fusion_elementwise_add_activation_compute.cc +++ b/lite/kernels/opencl/fusion_elementwise_add_activation_compute.cc @@ -20,6 +20,9 @@ namespace paddle { namespace lite { namespace kernels { namespace opencl { + +/* Buffer */ +#if 0 class FusionElementwiseAddActivationCompute : public ElementwiseAddCompute { public: using param_t = operators::FusionElementwiseActivationParam; @@ -38,19 +41,60 @@ class FusionElementwiseAddActivationCompute : public ElementwiseAddCompute { } } }; +#endif + +class FusionElementwiseAddActivationImageCompute + : public ElementwiseAddImageCompute { + public: + using param_t = operators::FusionElementwiseActivationParam; + + void PrepareForRun() override { + build_options_ += " -DRELU"; + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/elementwise_add_kernel.cl", build_options_); + ele_param_ = param_.get_mutable(); + auto act_t = static_cast(ele_param_)->act_type; + VLOG(4) << "act: " << act_t; + if (act_t != "relu") { + LOG(FATAL) << "Unsupported Activation type: " << act_t; + } + } +}; + } // namespace opencl } // namespace kernels } // namespace lite } // namespace paddle namespace ocl = paddle::lite::kernels::opencl; +// REGISTER_LITE_KERNEL(fusion_elementwise_add_activation, +// kOpenCL, +// kFloat, +// kNCHW, +// ocl::FusionElementwiseAddActivationCompute, +// def) +// .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindInput("Y", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .Finalize(); + REGISTER_LITE_KERNEL(fusion_elementwise_add_activation, kOpenCL, kFloat, - kNCHW, - ocl::FusionElementwiseAddActivationCompute, + kImageDefault, + ocl::FusionElementwiseAddActivationImageCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("Y", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) .Finalize(); -- GitLab