Commit 93a388a3, authored by Yuan Shuai, committed by GitHub

[LITE][OPENCL] add opencl image2d elementwise_mul. test=develop (#2815)

* [LITE][OPENCL] remove useless code. test=develop

* [LITE][OPENCL] finish 4 kernels and unit tests for the image2d opencl elementwise_mul kernel. test=develop

* [LITE][OPENCL] Fix a small bug in ASSERT. test=develop

* [LITE][OPENCL] Fix bugs in channel_mul_d2 and channel_mul_d4. test=develop
Parent 957ad169
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
// Element-wise multiply when X and Y have identical shapes: both images are
// read at the same coordinate.
__kernel void elementwise_mul(__read_only image2d_t input,
                              __read_only image2d_t bias,
                              __write_only image2d_t outputImage) {
  int x = get_global_id(0);
  int y = get_global_id(1);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  int2 coords;
  coords.x = x;
  coords.y = y;

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords);
  CL_DTYPE4 output = in * biase;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
// Y is a 1-D tensor laid out on a single image row; w is the tensor width,
// and the value at column (x % w) is broadcast to all four channels of the
// X pixel.
__kernel void channel_mul_d1(__read_only image2d_t input,
                             __read_only image2d_t bias,
                             __write_only image2d_t outputImage,
                             int w) {
  int x = get_global_id(0);
  int y = get_global_id(1);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  int2 coords;
  coords.x = x;
  coords.y = y;

  int2 coords_bias;
  coords_bias.x = x % w;
  coords_bias.y = 0;

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
  CL_DTYPE4 output = in * (CL_DTYPE4)(biase.x);
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
// Y is a 2-D tensor [h, w]; the scalar read at bias position (x % w, y % h)
// is broadcast to all four channels of the X pixel.
__kernel void channel_mul_d2(__read_only image2d_t input,
                             __read_only image2d_t bias,
                             __write_only image2d_t outputImage,
                             int w, int h) {
  int x = get_global_id(0);
  int y = get_global_id(1);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  int2 coords;
  coords.x = x;
  coords.y = y;

  int2 coords_bias;
  coords_bias.x = x % w;
  coords_bias.y = y % h;

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
  CL_DTYPE4 output = in * (CL_DTYPE4)(biase.x);
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
// Y is a 4-D per-channel tensor such as [1, C, 1, 1]; x / w recovers the
// channel-block column in the bias image, and the four packed channels are
// multiplied element-wise.
__kernel void channel_mul_d4(__read_only image2d_t input,
                             __read_only image2d_t bias,
                             __write_only image2d_t outputImage,
                             int w) {
  int x = get_global_id(0);
  int y = get_global_id(1);

  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  int2 coords;
  coords.x = x;
  coords.y = y;

  int2 coords_bias;
  coords_bias.x = x / w;
  coords_bias.y = 0;

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
  CL_DTYPE4 output = in * biase;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
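A host-side sketch of the image2d packing these kernels assume may help when reading the coordinate math above. This is my reading of CLImageConverterDefault (four consecutive channels per RGBA pixel, channel blocks along the image width, batches along the height), not something stated in the patch; ImageDimForNCHW is an illustrative helper name.

#include <array>
#include <cstdint>

// Assumed layout (sketch, not part of the patch):
//   image_width  = ceil(C / 4) * W   // channel blocks side by side along x
//   image_height = N * H
std::array<int64_t, 2> ImageDimForNCHW(int64_t n, int64_t c, int64_t h, int64_t w) {
  const int64_t c_blocks = (c + 3) / 4;
  return {c_blocks * w, n * h};  // {image_width, image_height}
}

// Under this layout an image column x decomposes as x = c_block * W + w_idx, so
//   x % W -> spatial w index  (used by channel_mul_d1 to index a [W] bias row)
//   x / W -> channel block    (used by channel_mul_d4 to index a [1, C, 1, 1] bias)
// Example: X = [1, 3, 2, 2] maps to a 2 x 2 image (one channel block, N*H = 2 rows).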
@@ -2,14 +2,15 @@ if ((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT LITE_WITH_OPENCL))
return ()
endif()
-set(cl_kernel_deps op_params cl_runtime cl_context cl_wrapper cl_target_wrapper)
+set(cl_kernel_deps op_params cl_runtime cl_context cl_wrapper cl_target_wrapper cl_image_converter)
add_kernel(fc_opencl OPENCL basic SRCS fc_compute.cc DEPS ${cl_kernel_deps})
add_kernel(mul_opencl OPENCL basic SRCS mul_compute.cc DEPS ${cl_kernel_deps})
add_kernel(elementwise_add_opencl OPENCL basic SRCS elementwise_add_compute.cc DEPS ${cl_kernel_deps})
+add_kernel(elementwise_mul_opencl OPENCL basic SRCS elementwise_mul_compute.cc DEPS ${cl_kernel_deps})
add_kernel(fusion_elementwise_add_activation_opencl
OPENCL basic SRCS fusion_elementwise_add_activation_compute.cc
DEPS elementwise_add_opencl ${cl_kernel_deps})
add_kernel(pool_opencl OPENCL basic SRCS pool_compute.cc DEPS ${cl_kernel_deps})
add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps})
add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps})

@@ -20,16 +21,20 @@ add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps})
add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps})
lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc
DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
+lite_cc_test(test_elementwise_mul_opencl SRCS elementwise_mul_compute_test.cc
+DEPS elementwise_mul_opencl op_registry program context
+ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_pool_opencl SRCS pool_compute_test.cc
-DEPS pool_opencl op_registry program context cl_image_converter
+DEPS pool_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_fc_opencl SRCS fc_compute_test.cc
DEPS fc_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
# TODO(ysh329): comment for buffer-impl mul
#lite_cc_test(test_mul_opencl SRCS mul_compute_test.cc

@@ -37,34 +42,34 @@ lite_cc_test(test_fc_opencl SRCS fc_compute_test.cc
# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_io_copy_compute_opencl SRCS io_copy_compute_test.cc
DEPS io_copy_compute_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
#TODO(ysh329): comment buffer-impl relu
lite_cc_test(test_relu_opencl SRCS relu_compute_test.cc
DEPS relu_opencl layout_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc
-DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter
+DEPS depthwise_conv2d_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_depthwise_conv2d_basic_opencl SRCS depthwise_conv2d_basic_compute_test.cc
-DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter
+DEPS depthwise_conv2d_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
#lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc
-# DEPS conv2d_1x1_opencl cl_image_converter op_registry program context
+# DEPS conv2d_1x1_opencl op_registry program context
# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_reshape_opencl SRCS reshape_compute_test.cc
-DEPS reshape_opencl cl_image_converter op_registry program context
+DEPS reshape_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc
DEPS conv_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc
-DEPS layout_opencl op_registry program context cl_image_converter
+DEPS layout_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/opencl/elementwise_mul_compute.h"
#include <memory>
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/op_registry.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
void ElementwiseMulFloatImageCompute::PrepareForRun() {
ele_param_ = param_.get_mutable<param_t>();
auto* y = ele_param_->Y;
auto y_dims = y->dims();
if (y_dims == ele_param_->X->dims()) {
kernel_func_name_ = "elementwise_mul";
} else if (y_dims.size() == 1) {
kernel_func_name_ = "channel_mul_d1";
} else if (y_dims.size() == 2) {
kernel_func_name_ = "channel_mul_d2";
} else if (y_dims.size() == 4) {
kernel_func_name_ = "channel_mul_d4";
} else {
LOG(FATAL) << "ElementwiseMul not supported y_dims.size():" << y_dims.size()
<< ", x_dims.size():" << ele_param_->X->dims().size();
}
VLOG(4) << "kernel_func_name_:" << kernel_func_name_;
VLOG(4) << "y_dims:" << y_dims;
VLOG(4) << "y_dims.size():" << y_dims.size();
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/elementwise_mul_kernel.cl", build_options_);
}
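For orientation, here is a condensed sketch of the dispatch implemented above, annotated with the Y shapes that the unit test at the end of this commit exercises. SelectElementwiseMulKernel is a hypothetical helper written only for illustration; the kernel itself keeps this logic inline in PrepareForRun.
// Sketch only: mirrors the if/else chain in PrepareForRun.
std::string SelectElementwiseMulKernel(const DDim& x_dims, const DDim& y_dims) {
  if (y_dims == x_dims) return "elementwise_mul";   // e.g. Y = [1, 3, 2, 2]
  if (y_dims.size() == 1) return "channel_mul_d1";  // e.g. Y = [2]
  if (y_dims.size() == 2) return "channel_mul_d2";  // e.g. Y = [2, 2]
  if (y_dims.size() == 4) return "channel_mul_d4";  // e.g. Y = [1, 3, 1, 1]
  LOG(FATAL) << "unsupported y_dims.size(): " << y_dims.size();
  return "";
}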
void ElementwiseMulFloatImageCompute::Run() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* x = ele_param_->X;
auto* y = ele_param_->Y;
auto* out = ele_param_->Out;
VLOG(4) << "x->target():" << TargetToStr(x->target());
VLOG(4) << "y->target():" << TargetToStr(y->target());
VLOG(4) << "out->target():" << TargetToStr(out->target());
VLOG(4) << "x->dims():" << x->dims();
VLOG(4) << "y->dims():" << y->dims();
VLOG(4) << "out->dims():" << out->dims();
paddle::lite::CLImageConverterDefault default_convertor;
auto x_img_shape = default_convertor.InitImageDimInfoWith(x->dims()); // w, h
auto x_img_width = x_img_shape[0];
auto x_img_height = x_img_shape[1];
auto out_img_shape =
default_convertor.InitImageDimInfoWith(out->dims()); // w, h
auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims());
auto* x_img = x->data<float, cl::Image2D>();
auto* y_img = y->data<float, cl::Image2D>();
auto* out_img =
out->mutable_data<float, cl::Image2D>(out_img_shape[0], out_img_shape[1]);
VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height;
VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1];
VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " "
<< out_img_shape[1];
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
auto y_dims = y->dims();
if (y_dims == ele_param_->X->dims()) {
// kernel: elementwise_mul (X and Y have identical dims)
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1 || y_dims.size() == 4) {
auto tensor_w = x->dims()[x->dims().size() - 1];
VLOG(4) << "tensor_w:" << tensor_w;
// kernel: channel_mul_d1 / channel_mul_d4
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w));
CL_CHECK_FATAL(status);
} else if (y_dims.size() == 2) {
auto y_tensor_h = y->dims()[0];
auto y_tensor_w = y->dims()[1];
VLOG(4) << "y_tensor_w:" << y_tensor_w << " y_tensor_h:" << y_tensor_h;
// kernel: channel_mul_d2
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_h));
CL_CHECK_FATAL(status);
} else {
LOG(FATAL) << "ElementwiseMul not supported y_dims.size():"
<< y_dims.size();
}
auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
}
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
namespace ocl = paddle::lite::kernels::opencl;
REGISTER_LITE_KERNEL(elementwise_mul,
kOpenCL,
kFloat,
kImageDefault,
ocl::ElementwiseMulFloatImageCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
class ElementwiseMulFloatImageCompute
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ElementwiseParam;
std::string doc() const override {
return "ElementwiseMul using cl::Image2D(ImageDefault/RGBA), kFP32";
}
void PrepareForRun() override;
void Run() override;
protected:
param_t* ele_param_{nullptr};
std::string kernel_func_name_{"elementwise_mul"};
std::string build_options_{"-DCL_DTYPE_float"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <random>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
// Fill with the element index when set_value == -1, otherwise with set_value.
template <typename dtype>
void fill_data(dtype *x, const int length, int set_value = -1) {
  if (set_value == -1) {
    for (int idx = 0; idx < length; ++idx) {
      x[idx] = idx;
    }
  } else {
    for (int idx = 0; idx < length; ++idx) {
      x[idx] = set_value;
    }
  }
}
template <typename dtype>
void elementwise_compute_ref(const dtype *x_data,
const dtype *y_data,
dtype *out_data,
const DDim &x_dims,
const DDim &y_dims,
int axis,
const std::string elt_type,
bool use_relu = false) {
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
if (x_dims == y_dims || y_dims.size() == 2 || y_dims.size() == 1) {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype *din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype *dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr * diny_data;
if (use_relu) {
*dout_ptr = std::max(*dout_ptr, static_cast<dtype>(0));
}
dout_ptr++;
din_ptr++;
}
}
}
} else if (y_dims.size() == 4) {
// eg: x_dims: [1, 3, 2, 2]
// y_dims: [1, 3, 1, 1]
ASSERT_EQ(y_dims[2], y_dims[3]);
ASSERT_EQ(y_dims[2], 1);
ASSERT_EQ(y_dims[0], 1);
auto y_offset = y_dims.production();
auto x_offset = x_dims.production() / y_offset;
for (auto x = 0; x < x_dims.production(); ++x) {
auto y = x / x_offset;
out_data[x] = x_data[x] * y_data[y];
}
} else {
LOG(FATAL) << "unsupported Elementwise type: " << elt_type << std::endl;
}
}
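To make the y_dims.size() == 4 branch concrete, the following worked example uses the shapes and index-filled data from the test below (computed by hand, for illustration only).
// X = [1, 3, 2, 2] filled with 0..11, Y = [1, 3, 1, 1] filled with 0..2:
//   y_offset = 3, x_offset = 12 / 3 = 4, so out[i] = x[i] * y[i / 4]
//   channel 0: {0, 1, 2, 3}   * 0 -> {0, 0, 0, 0}
//   channel 1: {4, 5, 6, 7}   * 1 -> {4, 5, 6, 7}
//   channel 2: {8, 9, 10, 11} * 2 -> {16, 18, 20, 22}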
// #define PRINT_RESULT
TEST(elemul_image2d_fp32, compute_kernel_elementwise_mul) {
LOG(INFO)
<< "main steps of test: host -> layout(buf2img on cpu) -> elemul(img) -> "
"layout(img2buf on cpu) "
"-> host";
// dims
const int n = 1;
const int c = 3;
const int h = 2;
const int w = 2;
const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
auto out_dim = x_dim;
std::vector<DDim> y_dim_v{DDim(std::vector<DDim::value_type>{n, c, 1, 1}),
DDim(std::vector<DDim::value_type>{n, c, h, w}),
DDim(std::vector<DDim::value_type>{h, w}),
DDim(std::vector<DDim::value_type>{w})};
for (auto y_dim : y_dim_v) {
LOG(INFO) << "================== elementwise_mul ===================";
LOG(INFO) << "x_dim:" << x_dim << "\ty_dim:" << y_dim
<< "\tout_dim:" << out_dim;
// tensor
LOG(INFO) << "set tensors about op param";
lite::Tensor elemul_x, elemul_y, elemul_out;
elemul_x.Resize(x_dim);
elemul_y.Resize(y_dim);
elemul_out.Resize(out_dim);
// initialize tensors
VLOG(4) << "initialize tensors";
paddle::lite::CLImageConverterDefault default_convertor;
// x
std::vector<float> x_v(x_dim.production());
fill_data<float>(x_v.data(), x_v.size()); // fill with index value
auto x_img_shape = default_convertor.InitImageDimInfoWith(x_dim); // w, h
auto x_img_w = x_img_shape[0];
auto x_img_h = x_img_shape[1];
std::vector<float> x_img_v(x_img_w * x_img_h * 4); // 4: RGBA
default_convertor.NCHWToImage(x_v.data(), x_img_v.data(), x_dim);
elemul_x.mutable_data<float, cl::Image2D>(x_img_w, x_img_h, x_img_v.data());
// y
std::vector<float> y_v(y_dim.production());
fill_data<float>(y_v.data(), y_v.size()); // fill with index value
auto y_img_shape = default_convertor.InitImageDimInfoWith(y_dim); // w, h
auto y_img_w = y_img_shape[0];
auto y_img_h = y_img_shape[1];
std::vector<float> y_img_v(y_img_shape[0] * y_img_shape[1] * 4); // 4: RGBA
default_convertor.NCHWToImage(y_v.data(), y_img_v.data(), y_dim);
elemul_y.mutable_data<float, cl::Image2D>(y_img_w, y_img_h, y_img_v.data());
// out
auto out_img_shape =
default_convertor.InitImageDimInfoWith(out_dim); // w, h
auto out_img_w = out_img_shape[0];
auto out_img_h = out_img_shape[1];
elemul_out.mutable_data<float, cl::Image2D>(out_img_w, out_img_h);
std::vector<float> out_img_v(out_img_w * out_img_h * 4);
fill_data<float>(
out_img_v.data(), out_img_v.size(), 0); // fill with zero value
std::vector<float> out_v(out_dim.production());
// operator param
operators::ElementwiseParam elemulParam;
elemulParam.X = &elemul_x;
elemulParam.Y = &elemul_y;
elemulParam.Out = &elemul_out;
elemulParam.axis = -1;
// set kernel
auto elemul_img_kernels =
KernelRegistry::Global().Create("elementwise_mul",
TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(elemul_img_kernels.empty());
auto elemul_img_kernel = std::move(elemul_img_kernels.front());
VLOG(4) << "get elemul kernel: " << elemul_img_kernel->doc();
// set context and kernel args
VLOG(4) << "set context and kernel args";
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
elemul_img_kernel->SetParam(elemulParam);
std::unique_ptr<KernelContext> elemul_img_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(elemul_img_context->As<OpenCLContext>()));
elemul_img_kernel->SetContext(std::move(elemul_img_context));
// run kernel
VLOG(4) << "run kernel";
elemul_img_kernel->Launch();
// download gpu result to cpu
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
TargetWrapperCL::ImgcpySync(out_img_v.data(),
elemul_out.data<float, cl::Image2D>(),
out_img_w,
out_img_h,
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
default_convertor.ImageToNCHW(
out_img_v.data(), out_v.data(), out_img_shape, out_dim);
// compute cpu reference
std::unique_ptr<float[]> out_ref(new float[out_dim.production()]);
elementwise_compute_ref<float>(x_v.data(),
y_v.data(),
out_ref.get(),
x_dim,
y_dim,
elemulParam.axis,
"mul");
#if 0 // enable to check value of x and y
for (int eidx = 0; eidx < out_dim.production(); eidx++) {
auto value = out_v[eidx];
auto ref_value = out_ref.get()[eidx];
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / "
<< out_dim.production() << ", x_v[" << eidx << "]:"
<< x_v[eidx] << ", value[" << eidx << "]:" << value
<< ", ref_value[" << eidx << "]:" << ref_value;
}
for (int i = 0; i < y_v.size(); i++) {
LOG(INFO) << "y_v[" << i << "]:" << y_v[i];
}
#endif
for (int eidx = 0; eidx < out_dim.production(); eidx++) {
auto value = out_v[eidx];
auto ref_value = out_ref.get()[eidx];
EXPECT_NEAR(value, ref_value, 1e-6);
if (std::fabs(value - ref_value) > 1e-6) {
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / "
<< out_dim.production() << ", value[" << eidx << "]:" << value
<< ", ref_value[" << eidx << "]:" << ref_value;
break;
}
}
}
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(elementwise_mul, kOpenCL, kFloat, kImageDefault, def);