Commit 137a3640 authored by Yuan Shuai, committed by GitHub

[LITE][OPENCL] add relu6 opencl kernel. test=develop (#2802)

Parent 4d35e647
@@ -72,7 +72,7 @@ inline CL_DTYPE activation(CL_DTYPE in
CL_DTYPE prelu_alpha
#endif
) {
CL_DTYPE output;
CL_DTYPE output = in;
#ifdef PRELU
output = select(prelu_alpha * in, in, in >= (CL_DTYPE)0);
#endif
@@ -80,6 +80,10 @@ inline CL_DTYPE activation(CL_DTYPE in
#ifdef RELU
output = fmax(in, (CL_DTYPE)0);
#endif
#ifdef RELU6
output = clamp(in, (CL_DTYPE)0, (CL_DTYPE)6);
#endif
return output;
}
@@ -89,7 +93,7 @@ inline CL_DTYPE4 activation_type4(CL_DTYPE4 in
CL_DTYPE4 prelu_alpha
#endif
) {
CL_DTYPE4 output;
CL_DTYPE4 output = in;
#ifdef PRELU
output = select(prelu_alpha * in, in, in >= (CL_DTYPE4)0.0);
#endif
@@ -97,5 +101,9 @@ inline CL_DTYPE4 activation_type4(CL_DTYPE4 in
#ifdef RELU
output = fmax(in, (CL_DTYPE4)0);
#endif
#ifdef RELU6
output = clamp(in, (CL_DTYPE4)0, (CL_DTYPE4)6);
#endif
return output;
}
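Both helpers now initialize output to the input, so builds without any activation macro pass values through unchanged, and each activation is selected at kernel compile time by a -D flag in the build options. A sketch of how cl_common.h presumably maps the data-type token onto the CL_DTYPE/CL_DTYPE4/CL_DTYPE_CHAR names used above (the dispatch itself is outside this diff, so its exact form is an assumption):

// Assumed dispatch in cl_common.h; only the macro names are taken from this diff.
#ifdef CL_DTYPE_float
#define CL_DTYPE float
#define CL_DTYPE4 float4
#define CL_DTYPE_CHAR f   // picks read_imagef / write_imagef
#endif
#ifdef CL_DTYPE_half
#define CL_DTYPE half
#define CL_DTYPE4 half4
#define CL_DTYPE_CHAR h   // picks read_imageh / write_imageh
#endif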
@@ -95,9 +95,7 @@ __kernel void depth_conv2d(__private const int global_size_dim0,
READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
output = activation_type4(output);
#endif
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output);
}
\ No newline at end of file
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void relu6(__read_only image2d_t input,
__write_only image2d_t output,
__private const float threshold){
const int x = get_global_id(0);
const int y = get_global_id(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
in = max((CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), in);
in = min((CL_DTYPE4)(threshold, threshold, threshold, threshold), in);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
}
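Per RGBA lane the kernel computes min(max(x, 0), threshold), i.e. relu6 when threshold is 6. A scalar C++ equivalent, as a sketch of the math only:

#include <algorithm>
// Scalar model of the kernel body above, one image lane at a time.
inline float relu6_ref(float x, float threshold = 6.f) {
  return std::min(std::max(x, 0.f), threshold);  // clamp to [0, threshold]
}

Note that the kernel declares threshold as a plain float regardless of CL_DTYPE, so the host side passes param.Relu_clipped_coef unconverted in both the FP32 and FP16 kernel classes below.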
@@ -70,9 +70,12 @@ void ConvCompute::PrepareForRun() {
kernel_func_names_.push_back("gemm_batch");
kernel_func_paths_.push_back("buffer/fc_kernel.cl");
if (relu_fused) {
build_options_.push_back("-DCL_DTYPE=float -DRELU");
build_options_.push_back("-DCL_DTYPE_float -DRELU");
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_.push_back("-DCL_DTYPE_float -DRELU6");
} else {
build_options_.push_back("-DCL_DTYPE=float");
build_options_.push_back("-DCL_DTYPE_float");
}
impl_ = &ConvCompute::Conv2d1x1;
} else if (pad_equal) {
@@ -80,11 +83,14 @@ void ConvCompute::PrepareForRun() {
kernel_func_names_.push_back("gemm_batch");
kernel_func_paths_.push_back("buffer/im2col_kernel.cl");
kernel_func_paths_.push_back("buffer/fc_kernel.cl");
build_options_.push_back("-DCL_DTYPE=float");
build_options_.push_back("-DCL_DTYPE_float");
if (relu_fused) {
build_options_.push_back("-DCL_DTYPE=float -DRELU");
build_options_.push_back("-DCL_DTYPE_float -DRELU");
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_.push_back("-DCL_DTYPE_float -DRELU6");
} else {
build_options_.push_back("-DCL_DTYPE=float");
build_options_.push_back("-DCL_DTYPE_float");
}
impl_ = &ConvCompute::GemmlikeConv2d;
col_buffer_.reset(new lite::Tensor);
......
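The same three-way branch (RELU fused, kRelu6, or no activation) is duplicated in both conv paths above; a hypothetical helper that captures the pattern, with the function name invented for illustration:

#include <string>
// Illustrative only: maps the fused activation onto the build options
// pushed above. kRelu6 enables the RELU6 macro added to cl_common.h.
std::string ActivationBuildOptions(bool relu_fused,
                                   lite_api::ActivationType act) {
  std::string opts = "-DCL_DTYPE_float";
  if (relu_fused) {
    opts += " -DRELU";
  } else if (act == lite_api::ActivationType::kRelu6) {
    opts += " -DRELU6";
  }
  return opts;
}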
@@ -46,7 +46,7 @@ static void conv_basic(const Dtype1* din,
int pad_w,
int pad_h,
bool flag_bias,
bool flag_relu) {
std::string flag_relu) {
Dtype2 beta = 0;
auto src_data = din;
auto dst_data_ref = dout;
@@ -96,10 +96,15 @@ static void conv_basic(const Dtype1* din,
}
}
}
if (flag_relu) {
if (flag_relu == "relu") {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
} else if (flag_relu == "relu6") {
auto dst_tmp = (dst_data_ref[out_idx] > (Dtype2)0)
? dst_data_ref[out_idx]
: (Dtype2)0;
dst_data_ref[out_idx] = (dst_tmp < 6.f) ? dst_tmp : 6.f;
}
}
}
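flag_relu changed from bool to std::string so the reference can cover "relu", "relu6", and "None". Condensed to a per-element rule, the logic above is (sketch):

#include <algorithm>
#include <string>
// Per-element form of the activation branch in conv_basic.
template <typename T>
T apply_act(T v, const std::string& flag_relu) {
  if (flag_relu == "relu")  return v > T(0) ? v : T(0);
  if (flag_relu == "relu6") return std::min(std::max(v, T(0)), T(6));
  return v;  // "None": pass through
}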
@@ -186,7 +191,7 @@ TEST(conv2d, compute_conv2d_1x1) {
/*int iw = ih;*/ for (int iw = 1; iw < 10; iw += 1) { // iw
for (int ic = 1; ic < 10; ic += 1) { // k
for (bool bias_flag : {true /*, false*/}) {
for (bool relu_flag : {true /*, false*/}) {
for (std::string relu_flag : {"relu" /*, "relu6", "None"*/}) {
#else
// groups:1 stride_h:1 stride_w:1 pad_h:0 pad_w:0 kernel_h:1 kernel_w:1
// x_dims:1 32 112 112
@@ -229,7 +234,16 @@ TEST(conv2d, compute_conv2d_1x1) {
std::vector<int> paddings = {pad, pad, pad, pad};
param.groups = group;
std::vector<int> dilations = {dilation, dilation};
param.fuse_relu = relu_flag;
if (relu_flag == "relu") {
param.fuse_relu = true;
} else if (relu_flag == "None") {
param.fuse_relu = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
}
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.dilations = std::make_shared<std::vector<int>>(dilations);
@@ -390,7 +404,7 @@ TEST(conv2d, compute_conv2d_1x1) {
#undef PRINT_RESULT
// #define PRINT_RESULT
#define LOOP_TEST
// #define LOOP_TEST
TEST(conv2d, compute_conv2d_gemm) {
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
@@ -411,7 +425,7 @@ TEST(conv2d, compute_conv2d_gemm) {
for (int iw = 1; iw < 10; iw += 1) { // iw
for (int ic = 1; ic < 10; ic += 1) { // k
for (bool bias_flag : {true, false}) {
for (bool relu_flag : {true, false}) {
for (std::string relu_flag : {"relu", "relu6", "None"}) {
#else
const int batch_size = 8;
@@ -420,7 +434,8 @@ TEST(conv2d, compute_conv2d_gemm) {
const int iw = 224;
const int ic = 3;
const bool bias_flag = true;
const bool relu_flag = true;
const std::string relu_flag =
"relu6"; // "relu", "relu6", "None"
#endif
const int oh = (ih + 2 * pad - ksize) / stride + 1;
@@ -458,7 +473,16 @@ TEST(conv2d, compute_conv2d_gemm) {
std::vector<int> paddings = {pad, pad, pad, pad};
param.groups = group;
std::vector<int> dilations = {dilation, dilation};
param.fuse_relu = relu_flag;
if (relu_flag == "relu") {
param.fuse_relu = true;
} else if (relu_flag == "None") {
param.fuse_relu = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
}
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.dilations = std::make_shared<std::vector<int>>(dilations);
......
@@ -39,6 +39,9 @@ class DepthwiseConv2dCompute
const auto& param = *param_.get_mutable<param_t>();
if (param.fuse_relu) {
build_options_ += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_ += " -DRELU6";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
@@ -116,7 +119,7 @@ class DepthwiseConv2dCompute
private:
std::string kernel_func_name_{"depthwise_conv2d"};
std::string build_options_{"-DCL_DTYPE=float"};
std::string build_options_{"-DCL_DTYPE_float"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
@@ -135,6 +138,9 @@ class DepthwiseConv2dComputeFP16Image
const auto& param = *param_.get_mutable<param_t>();
if (param.fuse_relu) {
build_options_ += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_ += " -DRELU6";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
@@ -252,6 +258,9 @@ class DepthwiseConv2d3x3s1ComputeFP16Image
const auto& param = *param_.get_mutable<param_t>();
if (param.fuse_relu) {
build_options_ += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_ += " -DRELU6";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
@@ -360,6 +369,9 @@ class DepthwiseConv2dBasicComputeFP32Image
has_bias && param.output->dims() == param.bias->dims();
if (param.fuse_relu) {
build_options_ += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_ += " -DRELU6";
}
if (has_bias) {
build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
......
@@ -220,12 +220,158 @@ class ReluComputeFP16ImageDefault
std::shared_ptr<cl::Event> event_{new cl::Event};
};
class Relu6ComputeFloatImageDefault
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ActivationParam;
std::string doc() const override {
return "Relu6 using cl::Image2D(ImageDefault/RGBA), kFloat";
}
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/relu6_kernel.cl", build_options_);
}
void Run() override {
auto& param = *param_.get_mutable<param_t>();
const auto& x_dims = param.X->dims();
auto* x_buf = param.X->data<float, cl::Image2D>();
auto image_shape = InitImageDimInfoWith(x_dims);
auto* out_buf = param.Out->mutable_data<float, cl::Image2D>(
image_shape["width"], image_shape["height"]);
const auto& y_dims = param.Out->dims(); // useless: check dim only
auto threshold = param.Relu_clipped_coef;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, threshold);
CL_CHECK_FATAL(status);
VLOG(4) << TargetToStr(param.X->target());
VLOG(4) << TargetToStr(param.Out->target());
VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " "
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
VLOG(4) << "threshold:" << threshold;
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(image_shape["width"]),
static_cast<cl::size_type>(image_shape["height"])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
// TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list`
// context.cl_wait_list()->emplace(out_buf, event_);
context.cl_context()->GetCommandQueue().finish();
}
private:
std::string kernel_func_name_{"relu6"};
std::string build_options_{"-DCL_DTYPE_float -DRELU6"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
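The NDRange above launches one work-item per image texel, taken directly from InitImageDimInfoWith. For the ImageDefault/RGBA layout named in doc(), an NCHW tensor plausibly packs four channels per pixel; a sketch of the assumed mapping (the real formula lives in InitImageDimInfoWith, not in this diff):

#include <cstdint>
// Assumed ImageDefault shape: ceil(C/4) RGBA pixels across, N*H rows.
struct ImageShape { int64_t width, height; };
inline ImageShape ImageDefaultShape(int64_t n, int64_t c,
                                    int64_t h, int64_t w) {
  return {((c + 3) / 4) * w, n * h};
}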
class Relu6ComputeFP16ImageDefault
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ActivationParam;
std::string doc() const override {
return "Relu6 using cl::Image2D(ImageDefault/RGBA), kFP16";
}
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/relu6_kernel.cl", build_options_);
}
void Run() override {
auto& param = *param_.get_mutable<param_t>();
const auto& x_dims = param.X->dims();
auto* x_buf = param.X->data<int16_t, cl::Image2D>();
auto image_shape = InitImageDimInfoWith(x_dims);
auto* out_buf = param.Out->mutable_data<int16_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
const auto& y_dims = param.Out->dims(); // useless: check dim only
auto threshold = param.Relu_clipped_coef;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, threshold);
CL_CHECK_FATAL(status);
VLOG(4) << TargetToStr(param.X->target());
VLOG(4) << TargetToStr(param.Out->target());
VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " "
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
VLOG(4) << "threshold:" << threshold;
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(image_shape["width"]),
static_cast<cl::size_type>(image_shape["height"])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
// TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list`
// context.cl_wait_list()->emplace(out_buf, event_);
context.cl_context()->GetCommandQueue().finish();
}
private:
std::string kernel_func_name_{"relu6"};
std::string build_options_{"-DCL_DTYPE_half -DRELU6"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// REGISTER_LITE_KERNEL(relu,
// kOpenCL,
// kFloat,
// kNCHW,
@@ -267,3 +413,38 @@ REGISTER_LITE_KERNEL(relu,
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Relu6
REGISTER_LITE_KERNEL(
relu6,
kOpenCL,
kFloat,
kImageDefault,
paddle::lite::kernels::opencl::Relu6ComputeFloatImageDefault,
ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize();
REGISTER_LITE_KERNEL(
relu6,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::Relu6ComputeFP16ImageDefault,
ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
@@ -23,9 +23,21 @@ namespace paddle {
namespace lite {
template <typename dtype>
void relu_compute_ref(const dtype *x_data, const DDim &x_dim, dtype *out_data) {
for (int i = 0; i < x_dim.production(); ++i) {
out_data[i] = x_data[i] > 0.f ? x_data[i] : 0.f;
void relu_compute_ref(const dtype *x_data,
const DDim &x_dim,
dtype *out_data,
float threshold = 0.f) {
if (abs(threshold) < 1e-5) {
// relu
for (int i = 0; i < x_dim.production(); ++i) {
out_data[i] = (x_data[i] > threshold) ? x_data[i] : threshold;
}
} else {
// relu6 or relu with threshold
for (int i = 0; i < x_dim.production(); ++i) {
auto out_tmp = (x_data[i] > 0) ? x_data[i] : 0;
out_data[i] = (out_tmp < threshold) ? out_tmp : threshold;
}
}
}
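With the default threshold of 0.f the reference degenerates to plain relu; a positive threshold yields the clamped variant. The tests below exercise both forms:

// Call sites used later in this file:
relu_compute_ref<float>(mapped_x, x_dim, y_data_ref);       // relu
relu_compute_ref<float>(mapped_x, x_dim, y_data_ref, 6.f);  // relu6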
@@ -252,7 +264,7 @@ TEST(relu_image2d_fp16, compute) {
"layout(img2buf) "
"-> host";
#ifdef LOOP_TEST
#ifdef RELU_FP16_LOOP_TEST
for (int n = 1; n <= 100; n += 33) {
for (auto c : {1, 3}) {
for (int h = 12; h <= 100; h += 13) {
@@ -262,7 +274,7 @@ TEST(relu_image2d_fp16, compute) {
const int c = 2;
const int h = 3;
const int w = 4;
#endif // LOOP_TEST
#endif // RELU_FP16_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
<< h << " " << w << " ========";
@@ -367,13 +379,13 @@ TEST(relu_image2d_fp16, compute) {
// compute ref cpu
relu_compute_ref<float>(mapped_x, x_dim, y_data_ref);
// result
#ifdef PRINT_RESULT
#ifdef RELU_FP16_PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
<< std::endl;
}
#endif // PRINT_RESULT
#endif // RELU_FP16_PRINT_RESULT
// check result: compare kernel output and cpu output(y_data_ref)
for (int eidx = 0; eidx < x_dim.production(); eidx++) {
@@ -391,7 +403,321 @@ TEST(relu_image2d_fp16, compute) {
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data, mapped_x);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef LOOP_TEST
#ifdef RELU_FP16_LOOP_TEST
} // w
} // h
} // c
} // n
#else
// nothing to do.
#endif
}
// #define RELU6_FP32_LOOP_TEST
// #define RELU6_FP32_PRINT_RESULT
TEST(relu6_image2d_fp32, compute) {
LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu6(img) -> "
"layout(img2buf) "
"-> host";
#ifdef RELU6_FP32_LOOP_TEST
for (int n = 1; n <= 100; n += 33) {
for (auto c : {1, 3}) {
for (int h = 12; h <= 100; h += 13) {
for (int w = 12; w <= 100; w += 25) {
#else
const int n = 1;
const int c = 2;
const int h = 3;
const int w = 4;
#endif // RELU6_FP32_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
<< h << " " << w << " ========";
// set layout kernels
auto buf_to_img_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto img_to_buf_kernels = KernelRegistry::Global().Create(
"layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW));
auto relu_img_kernels =
KernelRegistry::Global().Create("relu6",
TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(img_to_buf_kernels.empty());
ASSERT_FALSE(relu_img_kernels.empty());
auto buf_to_img_kernel = std::move(buf_to_img_kernels.front());
auto img_to_buf_kernel = std::move(img_to_buf_kernels.front());
auto relu_img_kernel = std::move(relu_img_kernels.front());
LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc();
LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc();
LOG(INFO) << "get 3rd kernel: " << relu_img_kernel->doc();
// set tensors about op param
LOG(INFO) << "set tensors about op param";
// layout(buf->img): x -> relu_in
// relu(img): relu_in -> relu_out
// layout(img->buf): relu_out -> y
lite::Tensor x, y, relu_in, relu_out, y_ref;
operators::LayoutParam BufferToImageParam;
operators::LayoutParam ImageToBufferParam;
BufferToImageParam.x = &x;
BufferToImageParam.y = &relu_in;
ImageToBufferParam.x = &relu_out;
ImageToBufferParam.y = &y;
operators::ActivationParam ReluParam;
ReluParam.X = &relu_in;
ReluParam.Out = &relu_out;
ReluParam.Relu_clipped_coef = 6.f;
const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
x.Resize(x_dim);
y.Resize(x_dim);
relu_in.Resize(x_dim);
relu_out.Resize(x_dim);
y_ref.Resize(x_dim);
auto relu_image2d_shape =
paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim);
// initialize tensors
LOG(INFO) << "initialize tensors";
auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM));
auto *mapped_x = static_cast<float *>(TargetWrapperCL::Map(
x_data, 0, sizeof(float) * x_dim.production()));
auto *mapped_y = static_cast<float *>(TargetWrapperCL::Map(
y_data, 0, sizeof(float) * x_dim.production()));
for (int i = 0; i < x_dim.production(); ++i) {
mapped_x[i] = static_cast<int>(i) - x_dim.production() / 2;
mapped_y[i] = static_cast<int>(0);
}
auto *relu_in_data = relu_in.mutable_data<float, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
auto *relu_out_data = relu_out.mutable_data<float, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
// set context and kernel args
LOG(INFO) << "set context and kernel args";
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
buf_to_img_kernel->SetParam(BufferToImageParam);
std::unique_ptr<KernelContext> buf_to_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context->As<OpenCLContext>()));
buf_to_img_kernel->SetContext(std::move(buf_to_img_context));
img_to_buf_kernel->SetParam(ImageToBufferParam);
std::unique_ptr<KernelContext> img_to_buf_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(img_to_buf_context->As<OpenCLContext>()));
img_to_buf_kernel->SetContext(std::move(img_to_buf_context));
relu_img_kernel->SetParam(ReluParam);
std::unique_ptr<KernelContext> relu_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(relu_img_context->As<OpenCLContext>()));
relu_img_kernel->SetContext(std::move(relu_img_context));
// run kernels
LOG(INFO) << "run kernel: buf_to_img_kernel";
buf_to_img_kernel->Launch();
LOG(INFO) << "run kernel: relu_img_kernel";
relu_img_kernel->Launch();
LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch();
// compute ref cpu
relu_compute_ref<float>(mapped_x, x_dim, y_data_ref, 6.f);
// result
#ifdef RELU6_FP32_PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
<< std::endl;
}
#endif // RELU6_FP32_PRINT_RESULT
// check result: compare kernel output and cpu output(y_data_ref)
for (int eidx = 0; eidx < x_dim.production(); eidx++) {
EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-6);
if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-6) {
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx
<< " / " << x_dim.production() << ", y_data_ref["
<< eidx << "]:" << y_data_ref[eidx] << ", mapped_y["
<< eidx << "]:" << mapped_y[eidx];
break;
}
}
// free
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data, mapped_x);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef RELU6_FP32_LOOP_TEST
} // w
} // h
} // c
} // n
#else
// nothing to do.
#endif
}
// #define RELU6_FP16_LOOP_TEST
// #define RELU6_FP16_PRINT_RESULT
TEST(relu6_image2d_fp16, compute) {
LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu6(img) -> "
"layout(img2buf) "
"-> host";
#ifdef RELU6_FP16_LOOP_TEST
for (int n = 1; n <= 100; n += 33) {
for (auto c : {1, 3}) {
for (int h = 12; h <= 100; h += 13) {
for (int w = 12; w <= 100; w += 25) {
#else
const int n = 1;
const int c = 2;
const int h = 3;
const int w = 4;
#endif // RELU6_FP16_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
<< h << " " << w << " ========";
// set layout kernels
auto buf_to_img_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto img_to_buf_kernels = KernelRegistry::Global().Create(
"layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW));
auto relu_img_kernels =
KernelRegistry::Global().Create("relu6",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(img_to_buf_kernels.empty());
ASSERT_FALSE(relu_img_kernels.empty());
auto buf_to_img_kernel = std::move(buf_to_img_kernels.front());
auto img_to_buf_kernel = std::move(img_to_buf_kernels.front());
auto relu_img_kernel = std::move(relu_img_kernels.front());
LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc();
LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc();
LOG(INFO) << "get 3rd kernel: " << relu_img_kernel->doc();
// set tensors about op param
LOG(INFO) << "set tensors about op param";
// layout(buf->img): x -> relu_in
// relu(img): relu_in -> relu_out
// layout(img->buf): relu_out -> y
lite::Tensor x, y, relu_in, relu_out, y_ref;
operators::LayoutParam BufferToImageParam;
operators::LayoutParam ImageToBufferParam;
BufferToImageParam.x = &x;
BufferToImageParam.y = &relu_in;
ImageToBufferParam.x = &relu_out;
ImageToBufferParam.y = &y;
operators::ActivationParam ReluParam;
ReluParam.X = &relu_in;
ReluParam.Out = &relu_out;
ReluParam.Relu_clipped_coef = 6.f;
const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
x.Resize(x_dim);
y.Resize(x_dim);
relu_in.Resize(x_dim);
relu_out.Resize(x_dim);
y_ref.Resize(x_dim);
auto relu_image2d_shape =
paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim);
// initialize tensors
LOG(INFO) << "initialize tensors";
auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM));
auto *mapped_x = static_cast<float *>(TargetWrapperCL::Map(
x_data, 0, sizeof(float) * x_dim.production()));
auto *mapped_y = static_cast<float *>(TargetWrapperCL::Map(
y_data, 0, sizeof(float) * x_dim.production()));
for (int i = 0; i < x_dim.production(); ++i) {
mapped_x[i] = static_cast<int>(i) - x_dim.production() / 2;
mapped_y[i] = static_cast<int>(0);
}
auto *relu_in_data = relu_in.mutable_data<int16_t, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
auto *relu_out_data = relu_out.mutable_data<int16_t, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
// set context and kernel args
LOG(INFO) << "set context and kernel args";
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
buf_to_img_kernel->SetParam(BufferToImageParam);
std::unique_ptr<KernelContext> buf_to_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context->As<OpenCLContext>()));
buf_to_img_kernel->SetContext(std::move(buf_to_img_context));
img_to_buf_kernel->SetParam(ImageToBufferParam);
std::unique_ptr<KernelContext> img_to_buf_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(img_to_buf_context->As<OpenCLContext>()));
img_to_buf_kernel->SetContext(std::move(img_to_buf_context));
relu_img_kernel->SetParam(ReluParam);
std::unique_ptr<KernelContext> relu_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(relu_img_context->As<OpenCLContext>()));
relu_img_kernel->SetContext(std::move(relu_img_context));
// run kernels
LOG(INFO) << "run kernel: buf_to_img_kernel";
buf_to_img_kernel->Launch();
LOG(INFO) << "run kernel: relu_img_kernel";
relu_img_kernel->Launch();
LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch();
// compute ref cpu
relu_compute_ref<float>(mapped_x, x_dim, y_data_ref, 6.f);
// result
#ifdef RELU6_FP16_PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
<< std::endl;
}
#endif // RELU6_FP16_PRINT_RESULT
// check result: compare kernel output and cpu output(y_data_ref)
for (int eidx = 0; eidx < x_dim.production(); eidx++) {
EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-6);
if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-6) {
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx
<< " / " << x_dim.production() << ", y_data_ref["
<< eidx << "]:" << y_data_ref[eidx] << ", mapped_y["
<< eidx << "]:" << mapped_y[eidx];
break;
}
}
// free
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data, mapped_x);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef RELU6_FP16_LOOP_TEST
} // w
} // h
} // c
@@ -414,3 +740,7 @@ USE_LITE_KERNEL(relu, kOpenCL, kFloat, kImageDefault, ImageDefault);
// relu image2d fp16
USE_LITE_KERNEL(relu, kOpenCL, kFP16, kImageDefault, ImageDefault);
// relu6 image2d fp32
USE_LITE_KERNEL(relu6, kOpenCL, kFloat, kImageDefault, ImageDefault);
USE_LITE_KERNEL(relu6, kOpenCL, kFP16, kImageDefault, ImageDefault);