提交 8df0d2fd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5759 [MS][LITE][Develop] concat ops support nc4hw4 format

Merge pull request !5759 from pengyongrong/op_format_toNC4HW4
......@@ -2,10 +2,10 @@
#define INT4 int4
#define INT2 int2
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void batch_normalization(__read_only image2d_t input, __read_only image2d_t scale,
__read_only image2d_t offset, __read_only image2d_t mean,
__read_only image2d_t variance, __write_only image2d_t output, const INT4 input_shape,
float epsilon) {
__kernel void Batch_normalization_NHWC4(__read_only image2d_t input, __read_only image2d_t scale,
__read_only image2d_t offset, __read_only image2d_t mean,
__read_only image2d_t variance, __write_only image2d_t output,
const INT4 input_shape, float epsilon) {
int X = get_global_id(0); // H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // C/4
......@@ -25,3 +25,27 @@ __kernel void batch_normalization(__read_only image2d_t input, __read_only image
result.w = result_scale.w * ((result.w - result_mean.w) / sqrt(result_var.w + epsilon)) + result_offset.w;
WRITE_IMAGE(output, (int2)((Y)*input_shape.w + Z, (X)), result);
}
// Batch normalization for NC4HW4-laid-out images:
//   y = scale * (x - mean) / sqrt(variance + epsilon) + offset
// One work-item normalizes one FLT4 pixel (a quad of 4 channels).
// input_shape.{y,z,w} = (H, W, C/4); .x is presumably the batch dim — unused here.
__kernel void Batch_normalization_NC4HW4(__read_only image2d_t input, __read_only image2d_t scale,
                                         __read_only image2d_t offset, __read_only image2d_t mean,
                                         __read_only image2d_t variance, __write_only image2d_t output,
                                         const INT4 input_shape, float epsilon) {
  int h = get_global_id(0);   // H
  int w = get_global_id(1);   // W
  int c4 = get_global_id(2);  // C/4
  // Guard against the padded global work size.
  if (h >= input_shape.y || w >= input_shape.z || c4 >= input_shape.w) {
    return;
  }
  // NC4HW4 image addressing: x = W, y = c4 * H + h. Same coordinate is used
  // for the read and the write, so the transform is done in place.
  int2 pos = (int2)((w), (c4 * input_shape.y + h));
  FLT4 px = READ_IMAGE(input, smp_none, pos);
  // Per-channel statistics live one FLT4 per channel quad in column 0.
  FLT4 mu = READ_IMAGE(mean, smp_none, (int2)((0), (c4)));
  FLT4 var = READ_IMAGE(variance, smp_none, (int2)((0), (c4)));
  FLT4 gamma = READ_IMAGE(scale, smp_none, (int2)((0), (c4)));
  FLT4 beta = READ_IMAGE(offset, smp_none, (int2)((0), (c4)));
  // Normalize each of the four packed channels independently. Kept as scalar
  // lane arithmetic (not a single vector expression) so that the float/half
  // promotion of `lane + epsilon` matches the original exactly.
  FLT4 out_val;
  out_val.x = gamma.x * ((px.x - mu.x) / sqrt(var.x + epsilon)) + beta.x;
  out_val.y = gamma.y * ((px.y - mu.y) / sqrt(var.y + epsilon)) + beta.y;
  out_val.z = gamma.z * ((px.z - mu.z) / sqrt(var.z + epsilon)) + beta.z;
  out_val.w = gamma.w * ((px.w - mu.w) / sqrt(var.w + epsilon)) + beta.w;
  WRITE_IMAGE(output, pos, out_val);
}
......@@ -14,7 +14,6 @@
* limitations under the License.
*/
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
......@@ -35,7 +34,7 @@ int BatchNormOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height();
} else {
im_dst_y = out_tensors_[0]->Height() * CO4;
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
}
size_t img_dtype = CL_FLOAT;
......@@ -50,17 +49,29 @@ int BatchNormOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz
return RET_OK;
}
// Initializes the batch-norm OpenCL kernel: validates the requested layout,
// records/overrides tensor formats, selects the kernel variant matching the
// layout, and builds the OpenCL program.
// Returns RET_OK on success, RET_ERROR for an unsupported layout.
// NOTE(review): the diff as scraped declared kernel_name/program_name twice and
// set the tensor formats twice (op_format_ vs hard-coded NHWC4) — the stale
// pre-commit lines are removed here so the function compiles and the format
// actually follows op_format_.
int BatchNormOpenCLKernel::Init() {
  auto in_format = op_format_;
  // Only the NHWC4 and NC4HW4 variants exist in the .cl source.
  if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
    MS_LOG(ERROR) << "input format(" << in_format << ") "
                  << "format not support!";
    return RET_ERROR;
  }
  // Remember the original formats (restored elsewhere after execution) and
  // switch the tensors to the layout this kernel was compiled for.
  in_ori_format_ = in_tensors_[0]->GetFormat();
  in_tensors_[0]->SetFormat(op_format_);
  out_ori_format_ = out_tensors_[0]->GetFormat();
  out_tensors_[0]->SetFormat(op_format_);
  // Pick the kernel entry point matching the layout, e.g.
  // "Batch_normalization_NC4HW4".
  std::string kernel_name = "Batch_normalization";
  if (in_format == schema::Format_NC4HW4) {
    kernel_name += "_NC4HW4";
  } else if (in_format == schema::Format_NHWC4) {
    kernel_name += "_NHWC4";
  }
  std::set<std::string> build_options;
  std::string source = batchnorm_source;
  std::string program_name = "Batch_normalization";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->LoadSource(program_name, source);
  ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
  return RET_OK;
}
......
......@@ -197,7 +197,7 @@ TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) {
return;
}
auto *output_tensor =
new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type);
new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
if (output_tensor == nullptr) {
MS_LOG(INFO) << " init tensor failed ";
delete tensor_data;
......
......@@ -180,8 +180,8 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
MS_LOG(INFO) << " init tensors ";
constexpr int INPUT_NUM = 3;
std::array<std::vector<int>, INPUT_NUM> input_shapes = {
std::vector<int>{1, 2, 4, 8}, std::vector<int>{1, 2, 4, 8}, std::vector<int>{1, 2, 4, 8}};
std::vector<int> output_shape = {3, 2, 4, 8};
std::vector<int>{1, 16, 256, 80}, std::vector<int>{1, 16, 256, 80}, std::vector<int>{1, 16, 256, 80}};
std::vector<int> output_shape = {1, 48, 256, 80};
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode;
std::vector<lite::tensor::Tensor *> inputs;
......@@ -217,7 +217,7 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
}
return;
}
param->axis_ = 0;
param->axis_ = 1;
auto *concat_kernel =
new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (concat_kernel == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册