提交 454d01aa 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] add elementwise_mul kernel with x(nchw), y(nc). test=develop (#2945)

上级 a2d956e1
......@@ -14,7 +14,8 @@ limitations under the License. */
#include <cl_common.h>
__kernel void elementwise_mul(__global image2d_t input, __global image2d_t bias,
__kernel void elementwise_mul(__global image2d_t input,
__global image2d_t bias,
__write_only image2d_t outputImage) {
int x = get_global_id(0);
int y = get_global_id(1);
......@@ -29,8 +30,11 @@ __kernel void elementwise_mul(__global image2d_t input, __global image2d_t bias,
WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
__kernel void channel_mul_d1(__read_only image2d_t input, __read_only image2d_t bias,
__write_only image2d_t outputImage, int w) {
__kernel void channel_mul_d1(__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t outputImage,
int w) {
int x = get_global_id(0);
int y = get_global_id(1);
......@@ -52,8 +56,88 @@ __kernel void channel_mul_d1(__read_only image2d_t input, __read_only image2d_t
WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
__kernel void channel_mul_d2(__read_only image2d_t input, __read_only image2d_t bias,
__write_only image2d_t outputImage, int w, int h) {
// #define DEBUG
__kernel void channel_mul_d2_nc(__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t outputImage,
int w) {
int x = get_global_id(0);
int y = get_global_id(1);
#ifdef DEBUG
printf("x:%d y:%d\n", x, y);
#endif
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 coords;
coords.x = x;
coords.y = y;
CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
int2 coords_bias0 = (int2)(x / w * 4, 0);
int2 coords_bias1 = (int2)(x / w * 4 + 1, 0);
int2 coords_bias2 = (int2)(x / w * 4 + 2, 0);
int2 coords_bias3 = (int2)(x / w * 4 + 3, 0);
CL_DTYPE4 b0 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias0);
CL_DTYPE4 b1 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias1);
CL_DTYPE4 b2 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias2);
CL_DTYPE4 b3 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias3);
CL_DTYPE4 biase = {b0.x, b1.x, b2.x, b3.x};
CL_DTYPE4 output = mad(in, biase, 0);
#ifdef DEBUG
if (x == 0 && y == 0) {
printf("w:%d\n", w);
printf("biase:%.1f %.1f %.1f %.1f\n", biase.x, biase.y, biase.z, biase.w);
printf("output:%.1f %.1f %.1f %.1f\n", output.x, output.y, output.z, output.w);
coords.x = 0;
coords.y = 0;
in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
printf("in(%d,%d):%.2f %.2f %.2f %.2f\n", coords.x, coords.y, in.x, in.y, in.z, in.w);
coords.x = 0;
coords.y = 1;
in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
printf("in(%d,%d):%.2f %.2f %.2f %.2f\n", coords.x, coords.y, in.x, in.y, in.z, in.w);
coords.x = 1;
coords.y = 0;
in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
printf("in(%d,%d):%.2f %.2f %.2f %.2f\n", coords.x, coords.y, in.x, in.y, in.z, in.w);
coords.x = 1;
coords.y = 1;
in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
printf("in(%d,%d):%.2f %.2f %.2f %.2f\n", coords.x, coords.y, in.x, in.y, in.z, in.w);
coords_bias.x = 0;
coords_bias.y = 0;
biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
printf("biase(%d,%d):%.2f %.2f %.2f %.2f\n", coords_bias.x, coords_bias.y, biase.x, biase.y, biase.z, biase.w);
coords_bias.x = 1;
coords_bias.y = 0;
biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
printf("biase(%d,%d):%.2f %.2f %.2f %.2f\n", coords_bias.x, coords_bias.y, biase.x, biase.y, biase.z, biase.w);
coords_bias.x = 2;
coords_bias.y = 0;
biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
printf("biase(%d,%d):%.2f %.2f %.2f %.2f\n", coords_bias.x, coords_bias.y, biase.x, biase.y, biase.z, biase.w);
}
#endif
WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
__kernel void channel_mul_d2_hw(__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t outputImage,
int w,
int h) {
int x = get_global_id(0);
int y = get_global_id(1);
......@@ -75,8 +159,11 @@ __kernel void channel_mul_d2(__read_only image2d_t input, __read_only image2d_t
WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
__kernel void channel_mul_d4(__read_only image2d_t input, __read_only image2d_t bias,
__write_only image2d_t outputImage, int w) {
__kernel void channel_mul_d4(__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t outputImage,
int w) {
int x = get_global_id(0);
int y = get_global_id(1);
......
......@@ -26,13 +26,19 @@ namespace opencl {
void ElementwiseMulFloatImageCompute::PrepareForRun() {
ele_param_ = param_.get_mutable<param_t>();
auto* y = ele_param_->Y;
auto* x = ele_param_->X;
auto y_dims = y->dims();
if (y_dims == ele_param_->X->dims()) {
auto x_dims = x->dims();
if (y_dims == x_dims) {
kernel_func_name_ = "elementwise_mul";
} else if (y_dims.size() == 1) {
kernel_func_name_ = "channel_mul_d1";
} else if (y_dims.size() == 2) {
kernel_func_name_ = "channel_mul_d2";
if (x_dims[0] == y_dims[0] && x_dims[1] == y_dims[1]) {
kernel_func_name_ = "channel_mul_d2_nc";
} else {
kernel_func_name_ = "channel_mul_d2_hw";
}
} else if (y_dims.size() == 4) {
kernel_func_name_ = "channel_mul_d4";
} else {
......@@ -87,7 +93,8 @@ void ElementwiseMulFloatImageCompute::Run() {
int arg_idx = 0;
auto y_dims = y->dims();
if (y_dims == ele_param_->X->dims()) {
auto x_dims = x->dims();
if (y_dims == x_dims) {
// kernel: elementwise_mul(channel_mul_d4)
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
......@@ -96,7 +103,7 @@ void ElementwiseMulFloatImageCompute::Run() {
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1 || y_dims.size() == 4) {
auto tensor_w = x->dims()[x->dims().size() - 1];
auto tensor_w = x_dims[x_dims.size() - 1];
VLOG(4) << "tensor_w:" << tensor_w;
// kernel: channel_mul_d1 / channel_mul_d4
cl_int status = kernel.setArg(arg_idx, *x_img);
......@@ -108,20 +115,34 @@ void ElementwiseMulFloatImageCompute::Run() {
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w));
CL_CHECK_FATAL(status);
} else if (y_dims.size() == 2) {
auto y_tensor_h = y->dims()[0];
auto y_tensor_w = y->dims()[1];
VLOG(4) << "y_tensor_w:" << y_tensor_w << " y_tensor_h:" << y_tensor_h;
// kernel: channel_mul_d2
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_h));
CL_CHECK_FATAL(status);
if (x_dims[0] == y_dims[0] && x_dims[1] == y_dims[1]) {
auto tensor_w = x_dims[x_dims.size() - 1];
VLOG(4) << "tensor_w:" << tensor_w;
// kernel: channel_mul_d2_nc
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w));
CL_CHECK_FATAL(status);
} else {
auto y_tensor_h = y->dims()[0];
auto y_tensor_w = y->dims()[1];
VLOG(4) << "y_tensor_w:" << y_tensor_w << " y_tensor_h:" << y_tensor_h;
// kernel: channel_mul_d2_hw
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_h));
CL_CHECK_FATAL(status);
}
} else {
LOG(FATAL) << "ElementwiseMul not supported y_dims.size():"
<< y_dims.size();
......
......@@ -60,7 +60,23 @@ void elementwise_compute_ref(const dtype *x_data,
num *= x_dims[i];
}
if (x_dims == y_dims || y_dims.size() == 2 || y_dims.size() == 1) {
if (x_dims.size() == 4 && y_dims.size() == 2 && x_dims[0] == y_dims[0] &&
y_dims[1] == y_dims[1]) {
int n = x_dims[0];
int c = x_dims[1];
int h = x_dims[2];
int w = x_dims[3];
// case for x_dims: n,c,h,w
// y_dims: n,c
for (int i = 0; i < n; ++i) {
for (int j = 0; j < c; ++j) {
for (int k = 0; k < h * w; ++k) {
out_data[i * c * h * w + j * h * w + k] =
x_data[i * c * h * w + j * h * w + k] * y_data[j];
}
}
}
} else if (x_dims == y_dims || y_dims.size() == 2 || y_dims.size() == 1) {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
......@@ -103,7 +119,7 @@ TEST(elemul_image2d_fp32, compute_kernel_elemenwise_mul) {
// dims
const int n = 1;
const int c = 3;
const int c = 7;
const int h = 2;
const int w = 2;
......@@ -112,6 +128,7 @@ TEST(elemul_image2d_fp32, compute_kernel_elemenwise_mul) {
std::vector<DDim> y_dim_v{DDim(std::vector<DDim::value_type>{n, c, 1, 1}),
DDim(std::vector<DDim::value_type>{n, c, h, w}),
DDim(std::vector<DDim::value_type>{h, w}),
DDim(std::vector<DDim::value_type>{n, c}),
DDim(std::vector<DDim::value_type>{w})};
for (auto y_dim : y_dim_v) {
LOG(INFO) << "================== elementwise_mul ===================";
......@@ -217,14 +234,14 @@ TEST(elemul_image2d_fp32, compute_kernel_elemenwise_mul) {
elemulParam.axis,
"mul");
#if 0 // enable to check value of x and y
#ifdef PRINT_RESULT // enable to check value of x and y
for (int eidx = 0; eidx < out_dim.production(); eidx++) {
auto value = out_v[eidx];
auto ref_value = out_ref.get()[eidx];
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / "
<< out_dim.production() << ", x_v[" << eidx << "]:"
<< x_v[eidx] << ", value[" << eidx << "]:" << value
<< ", ref_value[" << eidx << "]:" << ref_value;
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / "
<< out_dim.production() << ", x_v[" << eidx << "]:" << x_v[eidx]
<< ", value[" << eidx << "]:" << value << ", ref_value[" << eidx
<< "]:" << ref_value;
}
for (int i = 0; i < y_v.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册