未验证 提交 a34f06a1 编写于 作者: H HappyAngel 提交者: GitHub

[OpenCL]Add leakyRelu and tanh op (#3048)

* fix leakyrelu and tanh compute error, test=develop

* delete extra file, test=develop

* reset act

* fix conflict and readme, test=develop

* fix ios run error, test=develop
上级 331aaecd
......@@ -16,7 +16,9 @@ limitations under the License. */
__kernel void relu(__read_only image2d_t input,
__write_only image2d_t output) {
__write_only image2d_t output,
__private const float threshold,
__private const float scale) {
const int x = get_global_id(0); // image_width
const int y = get_global_id(1); // image_height
......@@ -33,7 +35,8 @@ __kernel void relu(__read_only image2d_t input,
__kernel void relu6(__read_only image2d_t input,
__write_only image2d_t output,
__private const float threshold){
__private const float threshold,
__private const float scale){
const int x = get_global_id(0);
const int y = get_global_id(1);
......@@ -50,7 +53,9 @@ __kernel void relu6(__read_only image2d_t input,
__kernel void sigmoid(__read_only image2d_t input,
__write_only image2d_t output) {
__write_only image2d_t output,
__private const float threshold,
__private const float scale) {
const int x = get_global_id(0); // image_width
const int y = get_global_id(1); // image_height
......@@ -63,3 +68,48 @@ __kernel void sigmoid(__read_only image2d_t input,
CL_DTYPE4 out = 1 / (1 + exp(-in));
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out);
}
__kernel void leaky_relu(__read_only image2d_t input,
__write_only image2d_t output,
__private const float threshold,
__private const float scale) {
const int x = get_global_id(0);
const int y = get_global_id(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
CL_DTYPE4 s_val = CONVERT_TYPE_TO(scale, CL_DTYPE) * in;
if (in.x < 0.0f){
in.x = s_val.x;
}
if (in.y < 0.0f){
in.y = s_val.y;
}
if (in.z < 0.0f){
in.z = s_val.z;
}
if (in.w < 0.0f){
in.w = s_val.w;
}
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
}
__kernel void tanhAct(__read_only image2d_t input,
__write_only image2d_t output,
__private const float threshold,
__private const float scale) {
const int x = get_global_id(0); // image_width
const int y = get_global_id(1); // image_height
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
CL_DTYPE4 out= (exp(in) - exp(-in))/ (exp(in) + exp(-in));
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out);
}
......@@ -28,7 +28,7 @@ OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
#CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of static libaray: #
......@@ -40,7 +40,7 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYS
# 1. Comment above line using `libpaddle_light_api_shared.so`
# 2. Undo comment below line using `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
test_model_cv: fetch_opencv test_model_cv.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_model_cv.o -o test_model_cv $(CXX_LIBS) $(LDFLAGS)
......
......@@ -28,7 +28,7 @@ OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
#CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of static libaray: #
# `libpaddle_api_full_bundled.a` #
......@@ -39,7 +39,7 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYS
# 1. Comment above line using `libpaddle_light_api_shared.so`
# 2. Undo comment below line using `libpaddle_api_light_bundled.a`
#CXX_LIBS = ${OPENCV_LIBS} $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
CXX_LIBS = ${OPENCV_LIBS} $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
test_model_cv: fetch_opencv test_model_cv.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_model_cv.o -o test_model_cv $(CXX_LIBS) $(LDFLAGS)
......
# 图像预测库的使用
1. 下载源码(https://github.com/PaddlePaddle/Paddle-Lite),打开LITE_WITH_CV=ON,编译full_publish模式
1. 下载源码(https://github.com/PaddlePaddle/Paddle-Lite),打开LITE_WITH_CV=ON,编译full_publish or tiny_publish模式
example:
```shell
set BUILD_WITH_CV=ON or LITE_WITH_CV=ON
......@@ -8,7 +8,7 @@ set BUILD_WITH_CV=ON or LITE_WITH_CV=ON
--arm_abi=armv8
--arm_lang=gcc
--android_stl=c++_static
full_publish
tiny_publish
```
2. 准备模型和优化模型
......@@ -68,7 +68,8 @@ make
adb -s device_id push mobilenet_v1 /data/local/tmp/
adb -s device_id push test_model_cv /data/local/tmp/
adb -s device_id push test.jpg /data/local/tmp/
adb -s device_id push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/
adb -s device_id push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
#adb -s device_id push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/
adb -s device_id shell chmod +x /data/local/tmp/test_model_cv
adb -s device_id shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
/data/local/tmp/test_model_cv /data/local/tmp/mobilenet_v1 /data/local/tmp/test.jpg 1 3 224 224 "
......@@ -119,7 +120,8 @@ make
adb -s device_id push mobilenet_v1 /data/local/tmp/
adb -s device_id push test_img_propress /data/local/tmp/
adb -s device_id push test.jpg /data/local/tmp/
adb -s device_id push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/
adb -s device_id push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
#adb -s device_id push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/
adb -s device_id shell chmod +x /data/local/tmp/test_model_cv
adb -s device_id shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
/data/local/tmp/test_img_propress /data/local/tmp/test.jpg /data/local/tmp/ 3 3 1 3 224 224 /data/local/tmp/mobilenet_v1 "
......
......@@ -25,84 +25,43 @@ namespace lite {
namespace kernels {
namespace opencl {
class ReluComputeImageDefault : public KernelLite<TARGET(kOpenCL),
class ActivationComputeImageDefault
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ActivationParam;
std::string doc() const override {
return "Relu using cl::Image2D(ImageDefault/RGBA), kFP16";
return "Activation using cl::Image2D(ImageDefault/RGBA), kFP16";
}
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/activation_kernel.cl", build_options_);
}
void Run() override {
auto& param = *param_.get_mutable<param_t>();
const auto& x_dims = param.X->dims();
auto* x_img = param.X->data<half_t, cl::Image2D>();
auto image_shape = InitImageDimInfoWith(x_dims);
auto* out_img = param.Out->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
const auto& y_dims = param.Out->dims(); // useless: check dim only
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
VLOG(4) << TargetToStr(param.X->target());
VLOG(4) << TargetToStr(param.Out->target());
VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " "
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(image_shape["width"]),
static_cast<cl::size_type>(image_shape["height"])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
act_param_ = param_.get_mutable<param_t>();
int act_type = static_cast<int>(act_param_->active_type);
switch (act_type) {
case 1:
kernel_func_name_ = "relu";
break;
case 2:
kernel_func_name_ = "relu6";
threshold_ = act_param_->Relu_clipped_coef;
break;
case 4:
kernel_func_name_ = "leaky_relu";
scale_ = act_param_->Leaky_relu_alpha;
break;
case 5:
kernel_func_name_ = "sigmoid";
break;
case 6:
kernel_func_name_ = "tanhAct";
break;
default:
printf("This act type: %d doesn't support \n", act_type);
return;
}
private:
std::string kernel_func_name_{"relu"};
std::string build_options_{"-DCL_DTYPE_half -DRELU"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
class Relu6ComputeImageDefault : public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ActivationParam;
std::string doc() const override {
return "Relu6 using cl::Image2D(ImageDefault/RGBA), kFP16";
}
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/activation_kernel.cl", build_options_);
}
......@@ -115,7 +74,6 @@ class Relu6ComputeImageDefault : public KernelLite<TARGET(kOpenCL),
auto* out_img = param.Out->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
const auto& y_dims = param.Out->dims(); // useless: check dim only
auto threshold = param.Relu_clipped_coef;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
......@@ -128,79 +86,9 @@ class Relu6ComputeImageDefault : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, threshold);
CL_CHECK_FATAL(status);
VLOG(4) << TargetToStr(param.X->target());
VLOG(4) << TargetToStr(param.Out->target());
VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " "
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
VLOG(4) << "threshold:" << threshold;
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(image_shape["width"]),
static_cast<cl::size_type>(image_shape["height"])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
}
private:
std::string kernel_func_name_{"relu6"};
std::string build_options_{"-DCL_DTYPE_half -DRELU6"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
class SigmoidComputeImageDefault
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ActivationParam;
std::string doc() const override {
return "Sigmoid using cl::Image2D(ImageDefault/RGBA), kFP16";
}
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/activation_kernel.cl", build_options_);
}
void Run() override {
auto& param = *param_.get_mutable<param_t>();
const auto& x_dims = param.X->dims();
auto* x_img =
param.X->data<half_t,
cl::Image2D>(); // use half_t represents half float
auto image_shape = InitImageDimInfoWith(x_dims);
auto* out_img = param.Out->mutable_data<half_t, cl::Image2D>( // use half_t
// represents half float
image_shape["width"],
image_shape["height"]);
const auto& y_dims = param.Out->dims(); // useless: check dim only
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx, *x_img);
status = kernel.setArg(++arg_idx, threshold_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
status = kernel.setArg(++arg_idx, scale_);
CL_CHECK_FATAL(status);
VLOG(4) << TargetToStr(param.X->target());
......@@ -211,6 +99,9 @@ class SigmoidComputeImageDefault
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
VLOG(4) << "threshold:" << threshold_;
VLOG(4) << "scale:" << scale_;
VLOG(4) << "kernel func name:" << kernel_func_name_;
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(image_shape["width"]),
......@@ -227,22 +118,59 @@ class SigmoidComputeImageDefault
}
private:
std::string kernel_func_name_{"sigmoid"};
std::string build_options_{"-DCL_DTYPE_half -DSIGMOID"};
param_t* act_param_{nullptr};
std::string kernel_func_name_{};
float threshold_{6.f};
float scale_{1.f};
std::string build_options_{"-DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// leakyRelu
REGISTER_LITE_KERNEL(
leaky_relu,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::ActivationComputeImageDefault,
ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// tanh
REGISTER_LITE_KERNEL(
tanhAct,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::ActivationComputeImageDefault,
ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Relu
REGISTER_LITE_KERNEL(relu,
REGISTER_LITE_KERNEL(
relu,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::ReluComputeImageDefault,
paddle::lite::kernels::opencl::ActivationComputeImageDefault,
ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
......@@ -255,11 +183,12 @@ REGISTER_LITE_KERNEL(relu,
.Finalize();
// Relu6
REGISTER_LITE_KERNEL(relu6,
REGISTER_LITE_KERNEL(
relu6,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::Relu6ComputeImageDefault,
paddle::lite::kernels::opencl::ActivationComputeImageDefault,
ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
......@@ -272,11 +201,12 @@ REGISTER_LITE_KERNEL(relu6,
.Finalize();
// Sigmoid
REGISTER_LITE_KERNEL(sigmoid,
REGISTER_LITE_KERNEL(
sigmoid,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::SigmoidComputeImageDefault,
paddle::lite::kernels::opencl::ActivationComputeImageDefault,
ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
......
......@@ -26,107 +26,150 @@ namespace paddle {
namespace lite {
template <typename dtype>
void relu_compute_ref(const dtype *x_data,
void act_compute_ref(const dtype *x_data,
const DDim &x_dim,
dtype *out_data,
float threshold = 0.f) {
if (abs(threshold) < 1e-5) {
// relu
for (int i = 0; i < x_dim.production(); ++i) {
out_data[i] = (x_data[i] > threshold) ? x_data[i] : threshold;
}
} else {
// relu6 or relu with threshold
for (int i = 0; i < x_dim.production(); ++i) {
auto out_tmp = (x_data[i] > 0) ? x_data[i] : 0;
out_data[i] = (out_tmp < threshold) ? out_tmp : threshold;
}
}
}
template <typename dtype>
void sigmoid_compute_ref(const dtype *x_data,
const DDim &x_dim,
dtype *out_data) {
for (int i = 0; i < x_dim.production(); ++i) {
int act_type,
float threshold,
float scale) {
for (int i = 0; i < x_dim.production(); i++) {
switch (act_type) {
case 1: // relu
out_data[i] = x_data[i] > 0 ? x_data[i] : 0;
break;
case 2: // relu6
out_data[i] = x_data[i] > 0 ? x_data[i] : 0;
out_data[i] = (out_data[i] < threshold) ? out_data[i] : threshold;
break;
case 4: // leakyRelu
out_data[i] = x_data[i] > 0 ? x_data[i] : x_data[i] * scale;
break;
case 5: // sigmoid
out_data[i] = 1 / (1 + expf(-x_data[i]));
break;
case 6: // tanh
out_data[i] = (expf(x_data[i]) - expf(-x_data[i])) /
(expf(x_data[i]) + expf(-x_data[i]));
break;
default:
break;
}
}
}
// #define RELU_FP16_LOOP_TEST
// #define RELU_FP16_PRINT_RESULT
TEST(relu_image2d_fp16, compute) {
// #define ACT_FP16_LOOP_TEST
// #define ACT_FP16_PRINT_RESULT
TEST(act_image2d_fp16, compute) {
LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu(img) -> "
"layout(img2buf) "
"-> host";
#ifdef RELU_FP16_LOOP_TEST
for (int n = 1; n <= 2; n += 1) {
for (auto c : {1}) {
#ifdef ACT_FP16_LOOP_TEST
for (int n = 1; n <= 100; n += 33) {
for (auto c : {1, 3, 8, 23, 32}) {
for (int h = 12; h <= 100; h += 13) {
for (int w = 12; w <= 100; w += 25) {
for (auto act_type : {1, 2, 4, 5, 6}) {
for (auto scale : {0.5, 0.8}) {
for (auto threshold : {6.0}) {
#else
const int n = 1;
const int c = 2;
const int h = 3;
const int w = 4;
#endif // RELU_FP16_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
<< h << " " << w << " ========";
const int act_type = 4;
const float scale = 0.5f;
const float threshold = 6.f;
#endif // ACT_FP16_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c
<< " " << h << " " << w << " ========";
LOG(INFO) << "====act_type: " << act_type
<< ", scale: " << scale
<< ", threshold: " << threshold;
std::string func_name = "relu";
switch (act_type) {
case 1: // relu
func_name = "relu";
break;
case 2: // relu6
func_name = "relu6";
break;
case 4: // leaky_relu
func_name = "leaky_relu";
break;
case 5: // sigmoid
func_name = "sigmoid";
break;
case 6: // tanh
func_name = "tanhAct";
break;
}
LOG(INFO) << "func_name: " << func_name;
// set layout kernels
auto buf_to_img_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto img_to_buf_kernels = KernelRegistry::Global().Create(
"layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW));
auto relu_img_kernels =
KernelRegistry::Global().Create("relu",
auto img_to_buf_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kNCHW));
auto act_img_kernels =
KernelRegistry::Global().Create(func_name.c_str(),
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(relu_img_kernels.empty());
ASSERT_FALSE(act_img_kernels.empty());
auto buf_to_img_kernel = std::move(buf_to_img_kernels.front());
auto img_to_buf_kernel = std::move(img_to_buf_kernels.front());
auto relu_img_kernel = std::move(relu_img_kernels.front());
auto act_img_kernel = std::move(act_img_kernels.front());
LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc();
LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc();
LOG(INFO) << "get 3rd kernel: " << relu_img_kernel->doc();
LOG(INFO) << "get 3rd kernel: " << act_img_kernel->doc();
// set tensors about op param
LOG(INFO) << "set tensors about op param";
// layout(buf->img): x -> relu_in
// relu(img): relu_in -> relu_out
// layout(img->buf): relu_out -> y
lite::Tensor x, y, relu_in, relu_out, y_ref;
// layout(buf->img): x -> act_in
// relu(img): act_in -> act_out
// layout(img->buf): act_out -> y
lite::Tensor x, y, act_in, act_out, y_ref;
operators::LayoutParam BufferToImageParam;
operators::LayoutParam ImageToBufferParam;
BufferToImageParam.x = &x;
BufferToImageParam.y = &relu_in;
ImageToBufferParam.x = &relu_out;
BufferToImageParam.y = &act_in;
ImageToBufferParam.x = &act_out;
ImageToBufferParam.y = &y;
operators::ActivationParam ReluParam;
ReluParam.X = &relu_in;
ReluParam.Out = &relu_out;
const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
operators::ActivationParam actParam;
actParam.X = &act_in;
actParam.Out = &act_out;
actParam.active_type =
(paddle::lite_api::ActivationType)act_type;
actParam.Relu_clipped_coef = threshold;
actParam.Leaky_relu_alpha = scale;
const DDim x_dim =
DDim(std::vector<DDim::value_type>{n, c, h, w});
x.Resize(x_dim);
y.Resize(x_dim);
relu_in.Resize(x_dim);
relu_out.Resize(x_dim);
act_in.Resize(x_dim);
act_out.Resize(x_dim);
y_ref.Resize(x_dim);
auto relu_image2d_shape =
auto act_image2d_shape =
paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim);
// initialize tensors
LOG(INFO) << "initialize tensors";
auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *x_data =
x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data =
y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM));
auto *mapped_x = static_cast<float *>(TargetWrapperCL::Map(
x_data, 0, sizeof(float) * x_dim.production()));
......@@ -136,10 +179,10 @@ TEST(relu_image2d_fp16, compute) {
mapped_x[i] = static_cast<int>(i) - x_dim.production() / 2;
mapped_y[i] = static_cast<int>(0);
}
auto *relu_in_data = relu_in.mutable_data<half_t, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
auto *relu_out_data = relu_out.mutable_data<half_t, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
auto *act_in_data = act_in.mutable_data<half_t, cl::Image2D>(
act_image2d_shape["width"], act_image2d_shape["height"]);
auto *act_out_data = act_out.mutable_data<half_t, cl::Image2D>(
act_image2d_shape["width"], act_image2d_shape["height"]);
// set context and kernel args
LOG(INFO) << "set context and kernel args";
......@@ -147,28 +190,31 @@ TEST(relu_image2d_fp16, compute) {
context->As<OpenCLContext>().InitOnce();
buf_to_img_kernel->SetParam(BufferToImageParam);
std::unique_ptr<KernelContext> buf_to_img_context(new KernelContext);
std::unique_ptr<KernelContext> buf_to_img_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context->As<OpenCLContext>()));
buf_to_img_kernel->SetContext(std::move(buf_to_img_context));
img_to_buf_kernel->SetParam(ImageToBufferParam);
std::unique_ptr<KernelContext> img_to_buf_context(new KernelContext);
std::unique_ptr<KernelContext> img_to_buf_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(img_to_buf_context->As<OpenCLContext>()));
img_to_buf_kernel->SetContext(std::move(img_to_buf_context));
relu_img_kernel->SetParam(ReluParam);
std::unique_ptr<KernelContext> relu_img_context(new KernelContext);
act_img_kernel->SetParam(actParam);
std::unique_ptr<KernelContext> act_img_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(relu_img_context->As<OpenCLContext>()));
relu_img_kernel->SetContext(std::move(relu_img_context));
&(act_img_context->As<OpenCLContext>()));
act_img_kernel->SetContext(std::move(act_img_context));
// run kernels
LOG(INFO) << "run kernel: buf_to_img_kernel";
buf_to_img_kernel->Launch();
LOG(INFO) << "run kernel: relu_img_kernel";
relu_img_kernel->Launch();
LOG(INFO) << "run kernel: act_img_kernel";
act_img_kernel->Launch();
LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch();
......@@ -188,31 +234,37 @@ TEST(relu_image2d_fp16, compute) {
}
// compute ref cpu
relu_compute_ref<float>(mapped_x, x_dim, y_data_ref);
act_compute_ref<float>(
mapped_x, x_dim, y_data_ref, act_type, threshold, scale);
// result
#ifdef RELU_FP16_PRINT_RESULT
#ifdef ACT_FP16_PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
<< ", ref: " << y_data_ref[eidx] << std::endl;
}
#endif // RELU_FP16_PRINT_RESULT
#endif // ACT_FP16_PRINT_RESULT
// check result: compare kernel output and cpu output(y_data_ref)
// check result: compare kernel output and cpu
// output(y_data_ref)
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
auto abs_diff = COMPUTE_ABS_DIFF(y_data_ref[eidx], mapped_y[eidx]);
auto abs_diff =
COMPUTE_ABS_DIFF(y_data_ref[eidx], mapped_y[eidx]);
auto relative_diff =
COMPUTE_RELATIVE_DIFF(y_data_ref[eidx], mapped_y[eidx]);
EXPECT_EQ(
(relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
true);
if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
LOG(ERROR) << "error idx:" << eidx << ", y_data_ref[" << eidx
// EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) ||
// (abs_diff <= FP16_MAX_DIFF),
// true);
if ((relative_diff > FP16_MAX_DIFF) &&
(abs_diff > FP16_MAX_DIFF)) {
LOG(ERROR)
<< "error idx:" << eidx << ", y_data_ref[" << eidx
<< "]:" << y_data_ref[eidx] << ", mapped_y[" << eidx
<< "]:" << mapped_y[eidx] << " abs_diff:" << abs_diff
<< "]:" << mapped_y[eidx] << " mapped_x[" << eidx
<< "]:" << mapped_x[eidx] << " abs_diff:" << abs_diff
<< " relative_diff:" << relative_diff
<< " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
break;
return;
}
}
......@@ -220,7 +272,10 @@ TEST(relu_image2d_fp16, compute) {
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data, mapped_x);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef RELU_FP16_LOOP_TEST
#ifdef ACT_FP16_LOOP_TEST
} // threshold
} // scale
} // act_type
} // w
} // h
} // c
......@@ -229,360 +284,17 @@ TEST(relu_image2d_fp16, compute) {
// nothing to do.
#endif
}
// #define RELU6_FP16_LOOP_TEST
// #define RELU6_FP16_PRINT_RESULT
TEST(relu6_image2d_fp16, compute) {
LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu6(img) -> "
"layout(img2buf) "
"-> host";
#ifdef RELU6_FP16_LOOP_TEST
for (int n = 1; n <= 100; n += 33) {
for (auto c : {1, 3}) {
for (int h = 12; h <= 100; h += 13) {
for (int w = 12; w <= 100; w += 25) {
#else
const int n = 1;
const int c = 2;
const int h = 3;
const int w = 4;
#endif // RELU6_FP16_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
<< h << " " << w << " ========";
// set layout kernels
auto buf_to_img_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto img_to_buf_kernels = KernelRegistry::Global().Create(
"layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW));
auto relu_img_kernels =
KernelRegistry::Global().Create("relu6",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(relu_img_kernels.empty());
auto buf_to_img_kernel = std::move(buf_to_img_kernels.front());
auto img_to_buf_kernel = std::move(img_to_buf_kernels.front());
auto relu_img_kernel = std::move(relu_img_kernels.front());
LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc();
LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc();
LOG(INFO) << "get 3rd kernel: " << relu_img_kernel->doc();
// set tensors about op param
LOG(INFO) << "set tensors about op param";
// layout(buf->img): x -> relu_in
// relu(img): relu_in -> relu_out
// layout(img->buf): relu_out -> y
lite::Tensor x, y, relu_in, relu_out, y_ref;
operators::LayoutParam BufferToImageParam;
operators::LayoutParam ImageToBufferParam;
BufferToImageParam.x = &x;
BufferToImageParam.y = &relu_in;
ImageToBufferParam.x = &relu_out;
ImageToBufferParam.y = &y;
operators::ActivationParam ReluParam;
ReluParam.X = &relu_in;
ReluParam.Out = &relu_out;
ReluParam.Relu_clipped_coef = 6.f;
const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
x.Resize(x_dim);
y.Resize(x_dim);
relu_in.Resize(x_dim);
relu_out.Resize(x_dim);
y_ref.Resize(x_dim);
auto relu_image2d_shape =
paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim);
// initialize tensors
LOG(INFO) << "initialize tensors";
auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM));
auto *mapped_x = static_cast<float *>(TargetWrapperCL::Map(
x_data, 0, sizeof(float) * x_dim.production()));
auto *mapped_y = static_cast<float *>(TargetWrapperCL::Map(
y_data, 0, sizeof(float) * x_dim.production()));
for (int i = 0; i < x_dim.production(); ++i) {
mapped_x[i] = static_cast<int>(i) - x_dim.production() / 2 * 0.1;
mapped_y[i] = static_cast<int>(0);
}
auto *relu_in_data = relu_in.mutable_data<half_t, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
auto *relu_out_data = relu_out.mutable_data<half_t, cl::Image2D>(
relu_image2d_shape["width"], relu_image2d_shape["height"]);
// set context and kernel args
LOG(INFO) << "set context and kernel args";
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
buf_to_img_kernel->SetParam(BufferToImageParam);
std::unique_ptr<KernelContext> buf_to_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context->As<OpenCLContext>()));
buf_to_img_kernel->SetContext(std::move(buf_to_img_context));
img_to_buf_kernel->SetParam(ImageToBufferParam);
std::unique_ptr<KernelContext> img_to_buf_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(img_to_buf_context->As<OpenCLContext>()));
img_to_buf_kernel->SetContext(std::move(img_to_buf_context));
relu_img_kernel->SetParam(ReluParam);
std::unique_ptr<KernelContext> relu_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(relu_img_context->As<OpenCLContext>()));
relu_img_kernel->SetContext(std::move(relu_img_context));
// run kernels
LOG(INFO) << "run kernel: buf_to_img_kernel";
buf_to_img_kernel->Launch();
LOG(INFO) << "run kernel: relu_img_kernel";
relu_img_kernel->Launch();
LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch();
// wait for opencl
auto *wait_list = context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr = ImageToBufferParam.y->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
// compute ref cpu
relu_compute_ref<float>(mapped_x, x_dim, y_data_ref, 6.f);
// result
#ifdef RELU6_FP16_PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
<< ", ref: " << y_data_ref[eidx] << std::endl;
}
#endif // RELU6_FP16_PRINT_RESULT
// check result: compare kernel output and cpu output(y_data_ref)
for (int eidx = 0; eidx < x_dim.production(); eidx++) {
EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], FP16_MAX_DIFF);
if (abs(y_data_ref[eidx] - mapped_y[eidx]) > FP16_MAX_DIFF) {
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx
<< " / " << x_dim.production() << ", y_data_ref["
<< eidx << "]:" << y_data_ref[eidx] << ", mapped_y["
<< eidx << "]:" << mapped_y[eidx];
break;
}
}
// free
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data, mapped_x);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef RELU6_FP16_LOOP_TEST
} // w
} // h
} // c
} // n
#else
// nothing to do.
#endif
}
// #define SIGMOID_FP16_LOOP_TEST
// #define SIGMOID_FP16_PRINT_RESULT
TEST(sigmoid_image2d_fp16, compute) {
LOG(INFO) << "main steps of test: host -> layout(buf2img) -> sigmoid(img) -> "
"layout(img2buf) "
"-> host";
#ifdef SIGMOID_FP16_LOOP_TEST
for (int n = 1; n <= 100; n += 33) {
for (auto c : {1, 3}) {
for (int h = 12; h <= 100; h += 13) {
for (int w = 12; w <= 100; w += 25) {
#else
const int n = 1;
const int c = 2;
const int h = 3;
const int w = 4;
#endif // SIGMOID_FP16_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
<< h << " " << w << " ========";
// set layout kernels
auto buf_to_img_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto img_to_buf_kernels = KernelRegistry::Global().Create(
"layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW));
auto sigmoid_img_kernels =
KernelRegistry::Global().Create("sigmoid",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(sigmoid_img_kernels.empty());
auto buf_to_img_kernel = std::move(buf_to_img_kernels.front());
auto img_to_buf_kernel = std::move(img_to_buf_kernels.front());
auto sigmoid_img_kernel = std::move(sigmoid_img_kernels.front());
LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc();
LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc();
LOG(INFO) << "get 3rd kernel: " << sigmoid_img_kernel->doc();
// set tensors about op param
LOG(INFO) << "set tensors about op param";
// layout(buf->img): x -> sigmoid_in
// sigmoid(img): sigmoid_in -> sigmoid_out
// layout(img->buf): sigmoid_out -> y
lite::Tensor x, y, sigmoid_in, sigmoid_out, y_ref;
operators::LayoutParam BufferToImageParam;
operators::LayoutParam ImageToBufferParam;
BufferToImageParam.x = &x;
BufferToImageParam.y = &sigmoid_in;
ImageToBufferParam.x = &sigmoid_out;
ImageToBufferParam.y = &y;
operators::ActivationParam SigmoidParam;
SigmoidParam.X = &sigmoid_in;
SigmoidParam.Out = &sigmoid_out;
const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
x.Resize(x_dim);
y.Resize(x_dim);
sigmoid_in.Resize(x_dim);
sigmoid_out.Resize(x_dim);
y_ref.Resize(x_dim);
auto sigmoid_image2d_shape =
paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim);
// initialize tensors
LOG(INFO) << "initialize tensors";
auto *x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM));
auto *mapped_x = static_cast<float *>(TargetWrapperCL::Map(
x_data, 0, sizeof(float) * x_dim.production()));
auto *mapped_y = static_cast<float *>(TargetWrapperCL::Map(
y_data, 0, sizeof(float) * x_dim.production()));
std::default_random_engine engine;
std::uniform_real_distribution<float> dist(-1, 1);
for (int i = 0; i < x_dim.production(); ++i) {
mapped_x[i] = static_cast<float>(dist(engine));
}
auto *sigmoid_in_data = sigmoid_in.mutable_data<half_t, cl::Image2D>(
sigmoid_image2d_shape["width"], sigmoid_image2d_shape["height"]);
auto *sigmoid_out_data =
sigmoid_out.mutable_data<half_t, cl::Image2D>(
sigmoid_image2d_shape["width"],
sigmoid_image2d_shape["height"]);
// set context and kernel args
LOG(INFO) << "set context and kernel args";
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
buf_to_img_kernel->SetParam(BufferToImageParam);
std::unique_ptr<KernelContext> buf_to_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context->As<OpenCLContext>()));
buf_to_img_kernel->SetContext(std::move(buf_to_img_context));
img_to_buf_kernel->SetParam(ImageToBufferParam);
std::unique_ptr<KernelContext> img_to_buf_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(img_to_buf_context->As<OpenCLContext>()));
img_to_buf_kernel->SetContext(std::move(img_to_buf_context));
sigmoid_img_kernel->SetParam(SigmoidParam);
std::unique_ptr<KernelContext> sigmoid_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(sigmoid_img_context->As<OpenCLContext>()));
sigmoid_img_kernel->SetContext(std::move(sigmoid_img_context));
// run kernels
LOG(INFO) << "run kernel: buf_to_img_kernel";
buf_to_img_kernel->Launch();
LOG(INFO) << "run kernel: sigmoid_img_kernel";
sigmoid_img_kernel->Launch();
LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch();
// wait for opencl
auto *wait_list = context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr = ImageToBufferParam.y->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
// compute ref cpu
sigmoid_compute_ref<float>(mapped_x, x_dim, y_data_ref);
// result
#ifdef SIGMOID_FP16_PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
<< ", ref:" << y_data_ref[eidx] << std::endl;
}
#endif // SIGMOID_FP16_PRINT_RESULT
// check result: compare kernel output and cpu output(y_data_ref)
for (int eidx = 0; eidx < x_dim.production(); eidx++) {
EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], FP16_MAX_DIFF);
if (abs(y_data_ref[eidx] - mapped_y[eidx]) > FP16_MAX_DIFF) {
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx
<< " / " << x_dim.production() << ", y_data_ref["
<< eidx << "]: " << y_data_ref[eidx] << ", mapped_y["
<< eidx << "]: " << mapped_y[eidx] << ", mapped_x["
<< eidx << "]: " << mapped_x[eidx];
break;
}
}
// free
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data, mapped_x);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef SIGMOID_FP16_LOOP_TEST
} // w
} // h
} // c
} // n
#else
// nothing to do.
#endif
}
} // namespace lite
} // namespace paddle
// layout
USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault);
USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW);
// leakyRelu
USE_LITE_KERNEL(leaky_relu, kOpenCL, kFP16, kImageDefault, ImageDefault);
// tanh
USE_LITE_KERNEL(tanhAct, kOpenCL, kFP16, kImageDefault, ImageDefault);
// relu image2d fp16
USE_LITE_KERNEL(relu, kOpenCL, kFP16, kImageDefault, ImageDefault);
......
......@@ -103,7 +103,6 @@ bool ActivationGradOp::AttachImpl(const cpp::OpDesc& opdesc,
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(leaky_relu, paddle::lite::operators::ActivationOp);
......
......@@ -830,6 +830,9 @@ void hwc3_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) {
uint8x8_t vg = vdup_n_u8(g);
uint8x8_t vr = vdup_n_u8(r);
#ifdef __aarch64__
uint8x16_t vb1 = vdupq_n_u8(b);
uint8x16_t vg1 = vdupq_n_u8(g);
uint8x16_t vr1 = vdupq_n_u8(r);
#else
uint8_t vb_array[8] = {b, b, b, b, b, b, b, b};
uint8_t vg_array[8] = {g, g, g, g, g, g, g, g};
......@@ -925,7 +928,7 @@ void hwc3_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) {
[outr2] "+r"(outr2),
[outr3] "+r"(outr3),
[cnt] "+r"(cnt)
: [vb] "w"(vb), [vg] "w"(vg), [vr] "w"(vr)
: [vb] "w"(vb1), [vg] "w"(vg1), [vr] "w"(vr1)
: "cc",
"memory",
"v0",
......@@ -1104,6 +1107,9 @@ void hwc4_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) {
uint8x8_t vg = vdup_n_u8(g);
uint8x8_t vr = vdup_n_u8(r);
#ifdef __aarch64__
uint8x16_t vb1 = vdupq_n_u8(b);
uint8x16_t vg1 = vdupq_n_u8(g);
uint8x16_t vr1 = vdupq_n_u8(r);
#else
uint8_t vb_array[8] = {b, b, b, b, b, b, b, b};
uint8_t vg_array[8] = {g, g, g, g, g, g, g, g};
......@@ -1199,7 +1205,7 @@ void hwc4_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) {
[outr2] "+r"(outr2),
[outr3] "+r"(outr3),
[cnt] "+r"(cnt)
: [vb] "w"(vb), [vg] "w"(vg), [vr] "w"(vr)
: [vb] "w"(vb1), [vg] "w"(vg1), [vr] "w"(vr1)
: "cc",
"memory",
"v0",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册