未验证 提交 66f0b25b 编写于 作者: X xiaogang 提交者: GitHub

[LITE][OPENCL] add slice kernel (#3126)

* feat: add opencl elementwise_sub op & ut
上级 34c29406
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void slice(__read_only image2d_t input, __write_only image2d_t output,
__private const int start, __private const int end,
__private const int dims_w){
const int c = get_global_id(0);
const int w = get_global_id(1);
const int nh = get_global_id(2);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = c * dims_w + w;
output_pos.y = nh;
int2 input_pos;
half4 input_data;
half4 output_data;
if (start % 4 == 0) {
input_pos.x = (4 * c + start) / 4 * dims_w + w;
input_pos.y = nh;
input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler,input_pos);
output_data = input_data;
} else if (start % 4 == 1) {
input_pos.x = (4 * c + start) / 4 * dims_w + w;
input_pos.y = nh;
input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler,input_pos);
output_data.x = input_data.y;
output_data.y = input_data.z;
output_data.z = input_data.w;
input_pos.x = input_pos.x + dims_w;
input_pos.y = nh;
input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler,input_pos);
output_data.w = input_data.x;
} else if (start % 4 == 2) {
input_pos.x = (4 * c + start) / 4 * dims_w + w;
input_pos.y = nh;
input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler,input_pos);
output_data.x = input_data.z;
output_data.y = input_data.w;
input_pos.x = input_pos.x + dims_w;
input_pos.y = nh;
input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler,input_pos);
output_data.z = input_data.x;
output_data.w = input_data.y;
} else if (start % 4 == 3) {
input_pos.x = (4 * c + start) / 4 * dims_w + w;
input_pos.y = nh;
input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler,input_pos);
output_data.x = input_data.w;
input_pos.x = input_pos.x + dims_w;
input_pos.y = nh;
input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler,input_pos);
output_data.y = input_data.x;
output_data.z = input_data.y;
output_data.w = input_data.z;
}
write_imageh(output, output_pos, output_data);
}
......@@ -29,6 +29,7 @@ add_kernel(scale_opencl OPENCL basic SRCS scale_image_compute.cc DEPS ${cl_kerne
add_kernel(grid_sampler_opencl OPENCL basic SRCS grid_sampler_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(lrn_opencl OPENCL basic SRCS lrn_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(bilinear_interp_opencl OPENCL basic SRCS bilinear_interp_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(slice_opencl OPENCL basic SRCS slice_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(instance_norm_opencl OPENCL basic SRCS instance_norm_image_compute.cc DEPS ${cl_kernel_deps})
# extra
......@@ -71,11 +72,9 @@ lite_cc_test(test_layout_image_opencl SRCS layout_image_compute_test.cc
DEPS layout_opencl op_registry program context)
lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_compute_test.cc
DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context)
lite_cc_test(test_elementwise_sub_image_opencl SRCS elementwise_sub_image_compute_test.cc
DEPS elementwise_sub_opencl fusion_elementwise_sub_activation_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
DEPS elementwise_sub_opencl fusion_elementwise_sub_activation_opencl op_registry program context)
lite_cc_test(test_grid_sampler_image_opencl SRCS grid_sampler_image_compute_test.cc
DEPS grid_sampler_opencl op_registry program context)
......@@ -85,7 +84,9 @@ lite_cc_test(test_lrn_image_opencl SRCS lrn_image_compute_test.cc
lite_cc_test(test_bilinear_interp_image_opencl SRCS bilinear_interp_image_compute_test.cc
DEPS bilinear_interp_opencl op_registry program context)
lite_cc_test(test_slice_image_opencl SRCS slice_image_compute_test.cc
DEPS slice_opencl op_registry program context)
lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_test.cc
DEPS instance_norm_opencl op_registry program context)
######################
......
......@@ -46,7 +46,7 @@ void ElementwiseSubImageCompute::PrepareForRun() {
<< ", x->dims().size():" << x->dims().size()
<< ", y->dims.size():" << y->dims().size();
}
VLOG(4) << "kernel_func_name_:" << kernel_func_name_;
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
class SliceComputeImage2D : public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::SliceParam;
std::string doc() const override { return "Slice using cl::Image2D, kFP16"; }
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel(
kernel_func_name_, "image/slice_kernel.cl", build_options_);
}
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
const auto& in_dims = param.X->dims();
auto* x_img = param.X->data<half_t, cl::Image2D>();
auto& out_dims = param.Out->dims();
std::vector<int> axes = param.axes;
std::vector<int32_t> starts = param.starts;
std::vector<int32_t> ends = param.ends;
if (axes.size() > 1 || axes[0] != 1) {
LOG(FATAL) << "opencl slice_image only support channel slice ";
}
int axis = axes[0];
int start = starts[0];
int end = ends[0];
int dim_w = in_dims[axis + 2];
auto out_image_shape = InitImageDimInfoWith(out_dims);
auto* out_img = param.Out->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, end);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dim_w);
CL_CHECK_FATAL(status);
const std::vector<size_t>& default_work_size =
DefaultWorkSize(out_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(default_work_size.data()[0]),
static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
}
private:
std::string kernel_func_name_{"slice"};
std::string build_options_{"-DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(slice,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::SliceComputeImage2D,
image2d)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/test_helper.h"
#define FP16_MAX_DIFF (5e-1)
namespace paddle {
namespace lite {
void slice_channel(const float* input_data,
const DDim& in_dim,
float* output_data,
const int start,
const int end) {
int n = in_dim[0];
int in_n_stride = 1;
for (int i = 1; i < in_dim.size(); ++i) {
in_n_stride *= in_dim[i];
}
int in_c_stride = in_n_stride / in_dim[1];
int mini_batch = end - start;
for (int ni = 0; ni < n; ++ni) {
const float* in_n = input_data + ni * in_n_stride + start * in_c_stride;
float* out_n = output_data + ni * mini_batch * in_c_stride;
memcpy(out_n, in_n, sizeof(float) * mini_batch * in_c_stride);
}
}
TEST(slice_image2d_fp16, compute) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
"slice", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
LOG(INFO) << "get kernel:" << kernel->doc();
lite::Tensor x, out;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = std::vector<int>({1});
param.starts = std::vector<int32_t>({2});
param.ends = std::vector<int32_t>({5});
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
kernel->SetParam(param);
std::unique_ptr<KernelContext> slice_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(slice_context->As<OpenCLContext>()));
kernel->SetContext(std::move(slice_context));
const DDim in_dim = DDim(std::vector<DDim::value_type>{3, 11, 107, 218});
const DDim out_dim = DDim(std::vector<DDim::value_type>{3, 3, 107, 218});
x.Resize(in_dim);
out.Resize(out_dim);
std::default_random_engine engine;
std::uniform_real_distribution<float> dist(-5, 5);
std::vector<float> input_v(3 * 11 * 107 * 218);
for (auto& i : input_v) {
i = dist(engine);
}
LOG(INFO) << "prepare input";
CLImageConverterDefault* default_converter = new CLImageConverterDefault();
DDim image_shape = default_converter->InitImageDimInfoWith(in_dim);
LOG(INFO) << "image_shape = " << image_shape[0] << " " << image_shape[1];
std::vector<half_t> x_image_data(image_shape.production() * 4); // 4 : RGBA
default_converter->NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
auto* x_image = x.mutable_data<half_t, cl::Image2D>(
image_shape[0], image_shape[1], x_image_data.data());
LOG(INFO) << "x_image:" << x_image;
auto* out_image =
out.mutable_data<half_t, cl::Image2D>(image_shape[0], image_shape[1]);
LOG(INFO) << "out_image:" << out_image;
kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.Out->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
std::unique_ptr<float[]> out_ref(new float[out_dim.production()]);
slice_channel(input_v.data(), in_dim, out_ref.get(), 2, 5);
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
half_t* out_image_data = new half_t[image_shape.production() * 4];
TargetWrapperCL::ImgcpySync(out_image_data,
out_image,
image_shape[0],
image_shape[1],
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
float* out_data = new float[image_shape.production() * 4];
default_converter->ImageToNCHW(
out_image_data, out_data, image_shape, out_dim);
for (int i = 0; i < out_dim.production(); i++) {
auto abs_diff = abs(out_data[i] - out_ref[i]);
auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_ref[i]);
EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
true);
if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
LOG(ERROR) << "error idx:" << i << " out_data[" << i
<< "]:" << out_data[i] << " "
"out_ref["
<< i << "]:" << out_ref[i] << " abs_diff:" << abs_diff
<< " relative_diff:" << relative_diff
<< " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
}
}
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(slice, kOpenCL, kFP16, kImageDefault, image2d);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册