Commit de97ec84 authored by mindspore-ci-bot, committed by Gitee

!5743 [MS][LITE][Develop] concat op supports NC4HW4 format

Merge pull request !5743 from pengyongrong/op_format_toNC4HW4
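For orientation, the two memory formats handled by the kernels below lay a tensor out across an OpenCL image2d differently. The following is a minimal host-side sketch (not part of this commit) of how a logical NHWC element (n, h, w, c) is assumed to map to image coordinates under NHWC4 versus NC4HW4, inferred from the kernel indexing and the GetImageSize change further down; the helper names are hypothetical.

```cpp
#include <cstdio>

// Hypothetical helpers illustrating the assumed image2d coordinate mappings.
// C4 is the number of 4-channel slices, i.e. (C + 3) / 4; c4 = c / 4.
struct Coord { int x; int y; };

// NHWC4 (assumed): image width = W * C4, image height = N * H (see GetImageSize).
Coord MapNHWC4(int n, int h, int w, int c4, int H, int C4) {
  return {w * C4 + c4, n * H + h};
}

// NC4HW4 (assumed): image width = W, image height = N * C4 * H (see GetImageSize).
// This matches the kernel's row index
//   out_pos_x = (X / H) * C4 * H + Z * H + X % H,  with X = n * H + h and Z = c4.
Coord MapNC4HW4(int n, int h, int w, int c4, int H, int C4) {
  return {w, (n * C4 + c4) * H + h};
}

int main() {
  const int H = 2, C4 = 2;  // toy shape: N=1, H=2, W=4, C4=2
  Coord a = MapNHWC4(0, 1, 3, 1, H, C4);
  Coord b = MapNC4HW4(0, 1, 3, 1, H, C4);
  printf("NHWC4  -> (x=%d, y=%d)\n", a.x, a.y);   // (7, 1)
  printf("NC4HW4 -> (x=%d, y=%d)\n", b.x, b.y);   // (3, 3)
  return 0;
}
```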
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void Concat(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,
int4 input_shape0, int4 input_shape1, int4 output_shape, const int axis) {
__kernel void Concat2input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 output_shape,
const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
@@ -44,9 +45,9 @@ __kernel void Concat(__read_only image2d_t input0, __read_only image2d_t input1,
}
}
__kernel void Concat3input(__read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2,
int4 output_shape, const int axis) {
__kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
@@ -105,3 +106,144 @@ __kernel void Concat3input(__read_only image2d_t input0, __read_only image2d_t i
}
}
}
__kernel void Concat2input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1,
int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (input_shape0.y == 0 || input_shape1.y == 0 || output_shape.y == 0) {
return;
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 0) {
if (X < (input_shape0.x * input_shape0.y)) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y +
Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
((X - input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else {
if (Z < input_shape0.w) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
(X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
}
}
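Every NC4HW4 branch above follows the same pattern: compute the output row, decide which input the element falls into along the concat axis, then recompute the row against that input's shape. Below is a small CPU sketch of the channel-axis case (axis == 3) for two inputs, assuming channel counts that are multiples of 4 so slice counts concatenate cleanly; the function names are illustrative, not part of the kernel.

```cpp
#include <cassert>

// Row index inside an NC4HW4 image: rows are ordered as (n, c4, h).
// shape = {N, H, W, C4}; X enumerates n * H + h, z is the channel-slice index.
int RowNC4HW4(const int shape[4], int X, int z) {
  return (X / shape[1]) * shape[3] * shape[1] + z * shape[1] + X % shape[1];
}

// For output element (X, y, z) with concatenation along channels (axis == 3),
// pick the source input and its row, mirroring the kernel's last branch.
void SourceForChannelConcat(const int in0[4], const int in1[4],
                            int X, int z, int *which_input, int *in_row) {
  if (z < in0[3]) {                 // slice still belongs to input0
    *which_input = 0;
    *in_row = RowNC4HW4(in0, X, z);
  } else {                          // shift the slice index into input1's range
    *which_input = 1;
    *in_row = RowNC4HW4(in1, X, z - in0[3]);
  }
}

int main() {
  const int in0[4] = {1, 2, 4, 2};  // N, H, W, C4
  const int in1[4] = {1, 2, 4, 1};
  int which = -1, row = -1;
  SourceForChannelConcat(in0, in1, /*X=*/1, /*z=*/2, &which, &row);
  assert(which == 1 && row == RowNC4HW4(in1, 1, 0));
  return 0;
}
```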
__kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) {
return;
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 0) {
if (X < (input_shape0.x * input_shape0.y)) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) {
in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y +
Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) / input_shape2.y) *
input_shape2.w * input_shape2.y +
Z * input_shape2.y +
(X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) % input_shape2.y;
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < input_shape0.y + input_shape1.y) {
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
((X - input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Y < input_shape0.z + input_shape1.z) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else {
if (Z < input_shape0.w) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Z < input_shape0.w + input_shape1.w) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
(X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
}
}
@@ -35,8 +35,8 @@ int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size)
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height() * out_tensors_[0]->Batch();
} else {
im_dst_y = out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width() * out_tensors_[0]->Batch();
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
}
size_t img_dtype = CL_FLOAT;
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
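The hunk above changes the NC4HW4 case so the image height folds batch, height and channel slices together (Batch * Height * CO4) while the width stays plain Width, matching the row index the kernels use. A hedged sketch of that size computation, with a stand-in struct for the tensor since the surrounding class is not shown:

```cpp
#include <cstddef>
#include <utility>

// Illustrative only: stands in for out_tensors_[0] in the real kernel class.
struct TensorDims {
  int batch, height, width, channel;
};

// Returns {im_dst_x, im_dst_y} for the output image, following the diff:
// NHWC4:  x = W * CO4, y = H * N
// NC4HW4: x = W,       y = N * H * CO4
std::pair<size_t, size_t> ImageSize(const TensorDims &t, bool is_nc4hw4) {
  const size_t co4 = (t.channel + 3) / 4;  // same rounding as UP_DIV(C, 4)
  if (!is_nc4hw4) {
    return {t.width * co4, static_cast<size_t>(t.height) * t.batch};
  }
  return {static_cast<size_t>(t.width), t.batch * t.height * co4};
}

int main() {
  TensorDims t{1, 2, 4, 8};
  auto nhwc4 = ImageSize(t, false);   // {8, 2}
  auto nc4hw4 = ImageSize(t, true);   // {4, 4}
  return (nhwc4.first == 8 && nc4hw4.second == 4) ? 0 : 1;
}
```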
@@ -61,30 +61,37 @@ int ConcatOpenCLKernel::Init() {
MS_LOG(ERROR) << " only support axis >= 0 and axis <= 3 ";
return RET_ERROR;
}
auto in_format = op_format_;
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
MS_LOG(ERROR) << "input format(" << in_format << ") "
<< "format not support!";
return RET_ERROR;
}
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(op_format_);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(op_format_);
if (in_tensors_.size() == 2) {
std::set<std::string> build_options;
std::string source = concat_source;
std::string program_name = "Concat";
std::string kernel_name = "Concat";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
if (in_tensors_.size() == 2) {
kernel_name += "2input";
} else if (in_tensors_.size() == 3) {
kernel_name += "3input";
} else {
MS_LOG(ERROR) << " input must be 2 or 3";
return RET_ERROR;
}
if (in_format == schema::Format_NC4HW4) {
kernel_name += "_NC4HW4";
} else if (in_format == schema::Format_NHWC4) {
kernel_name += "_NHWC4";
}
if (in_tensors_.size() == 3) {
std::set<std::string> build_options;
std::string source = concat_source;
std::string program_name = "Concat3input";
std::string kernel_name = "Concat3input";
std::string program_name = "Concat";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
}
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
return RET_OK;
}
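With this refactor, Init() loads a single "Concat" program and derives the kernel name from the input count plus op_format_, which is what makes the new _NC4HW4 kernels reachable. A simplified sketch of just the name-selection logic, with a stand-in enum for schema::Format since only the two supported values matter here:

```cpp
#include <stdexcept>
#include <string>

// Stand-in for schema::Format; only the two supported values matter here.
enum class Format { NHWC4, NC4HW4 };

// Mirrors the kernel-name construction in ConcatOpenCLKernel::Init():
// "Concat" + "<n>input" + "_<format>", for 2 or 3 inputs.
std::string ConcatKernelName(size_t input_count, Format fmt) {
  std::string name = "Concat";
  if (input_count == 2) {
    name += "2input";
  } else if (input_count == 3) {
    name += "3input";
  } else {
    throw std::invalid_argument("input must be 2 or 3");
  }
  name += (fmt == Format::NC4HW4) ? "_NC4HW4" : "_NHWC4";
  return name;
}

int main() {
  // e.g. "Concat2input_NC4HW4", matching the kernels defined in concat.cl
  return ConcatKernelName(2, Format::NC4HW4) == "Concat2input_NC4HW4" ? 0 : 1;
}
```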
@@ -49,6 +49,7 @@ int DepthwiseConv2dOpenCLKernel::Init() {
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
MS_LOG(ERROR) << "input format(" << in_format << ") "
<< "format not support!";
return RET_ERROR;
}
in_tensors_[0]->SetFormat(in_format);
out_tensors_[0]->SetFormat(in_format);
@@ -103,6 +104,7 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
PackNCHWToNC4HW4<float, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else {
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
return RET_ERROR;
}
} else {
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
@@ -112,6 +114,7 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else {
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
return RET_ERROR;
}
}
@@ -168,10 +168,10 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
// get the input from .bin
size_t input1_size, input2_size, input3_size, output_size;
std::string input1Ppath = "./test_data/concat_input1.bin";
std::string input2Ppath = "./test_data/concat_input2.bin";
std::string input3Ppath = "./test_data/concat_input3.bin";
std::string correctOutputPath = "./test_data/concat_output.bin";
std::string input1Ppath = "./test_data/concatfp32_input1.bin";
std::string input2Ppath = "./test_data/concatfp32_input2.bin";
std::string input3Ppath = "./test_data/concatfp32_input3.bin";
std::string correctOutputPath = "./test_data/concatfp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto input_data3 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size));
@@ -180,8 +180,8 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
MS_LOG(INFO) << " init tensors ";
constexpr int INPUT_NUM = 3;
std::array<std::vector<int>, INPUT_NUM> input_shapes = {
std::vector<int>{1, 16, 256, 80}, std::vector<int>{1, 16, 256, 80}, std::vector<int>{1, 16, 256, 80}};
std::vector<int> output_shape = {1, 48, 256, 80};
std::vector<int>{1, 2, 4, 8}, std::vector<int>{1, 2, 4, 8}, std::vector<int>{1, 2, 4, 8}};
std::vector<int> output_shape = {3, 2, 4, 8};
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode;
std::vector<lite::tensor::Tensor *> inputs;
@@ -217,7 +217,7 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
}
return;
}
param->axis_ = 1;
param->axis_ = 0;
auto *concat_kernel =
new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (concat_kernel == nullptr) {