提交 e68761ac 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5522 prelu biasadd support fp16 in opencl

Merge pull request !5522 from liuzhongkai/fp16
#pragma OPENCL EXTENSION cl_arm_printf : enable #pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define C4NUM 4
#define SLICES 4
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) #define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define FLT4 float4
#define READ_FLT4 read_imagef
#define WRITE_FLT4 write_imagef
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, __kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
__global float *alpha, const int dim) { __read_only image2d_t alpha, const int dim) {
int C = input_shape.w; // channel size int C = input_shape.w; // channel size
int Y = get_global_id(0); // height id int Y = get_global_id(0); // height id
int X = get_global_id(1); // weight id int X = get_global_id(1); // weight id
for (int num = 0; num < UP_DIV(C, SLICES); ++num) { for (int num = 0; num < UP_DIV(C, C4NUM); ++num) {
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, C4NUM) + num, Y)); // NHWC4: H WC
FLT4 tmp; FLT4 tmp = in_c4;
int index = 0; int index = 0;
if (dim == 2) { if (dim == 2) {
index = X * 4; index = X;
} else { } else {
index = num * 4; index = num;
} }
tmp.x = in_c4.x + alpha[index]; tmp += READ_IMAGE(alpha, smp_zero, (int2)(index, 0));
tmp.y = in_c4.y + alpha[index + 1]; WRITE_IMAGE(output, (int2)(X * UP_DIV(C, C4NUM) + num, Y), tmp); // NHWC4: H WC
tmp.z = in_c4.z + alpha[index + 2];
tmp.w = in_c4.w + alpha[index + 3];
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
} }
} }
#pragma OPENCL EXTENSION cl_arm_printf : enable #pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define SLICES 4 #define SLICES 4
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) #define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
__global float *alpha, const int dim) { __read_only image2d_t alpha, const int dim) {
int C = input_shape.w; // channel size int C = input_shape.w; // channel size
int Y = get_global_id(0); // height id int Y = get_global_id(0); // height id
...@@ -14,16 +13,17 @@ __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, ...@@ -14,16 +13,17 @@ __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output,
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
FLT4 tmp; FLT4 tmp;
if (dim == 1) { if (dim == 1) {
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * (*alpha); FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(0, 0));
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * (*alpha); tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x;
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * (*alpha); tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.x;
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * (*alpha); tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.x;
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.x;
} else { } else {
int index = num * 4; FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(num, 0));
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha[index]; tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x;
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha[index + 1]; tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.y;
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha[index + 2]; tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.z;
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha[index + 3]; tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.w;
} }
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
} }
......
...@@ -39,19 +39,24 @@ void BiasAddOpenCLKernel::InitBuffer() { ...@@ -39,19 +39,24 @@ void BiasAddOpenCLKernel::InitBuffer() {
int C = in_tensors_[1]->shape()[0]; int C = in_tensors_[1]->shape()[0];
int div_ci = UP_DIV(C, C4NUM); int div_ci = UP_DIV(C, C4NUM);
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
BiasAdd_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t))); size_t img_dtype = CL_FLOAT;
BiasAdd_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true)); if (enable_fp16_) {
memset(BiasAdd_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t)); img_dtype = CL_HALF_FLOAT;
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_[1]->Data());
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
BiasAdd_[i] = origin_weight[i];
} }
std::vector<size_t> img_size{size_t(div_ci), 1, img_dtype};
BiasAdd_ = allocator->Malloc(div_ci * C4NUM * fp_size, img_size);
BiasAdd_ = allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true);
memset(BiasAdd_, 0x00, div_ci * C4NUM * fp_size);
memcpy(BiasAdd_, in_tensors_[1]->Data(), C * fp_size);
allocator->UnmapBuffer(BiasAdd_); allocator->UnmapBuffer(BiasAdd_);
} }
int BiasAddOpenCLKernel::Init() { int BiasAddOpenCLKernel::Init() {
in_size_ = in_tensors_[0]->shape().size(); in_size_ = in_tensors_[0]->shape().size();
out_size_ = out_tensors_[0]->shape().size(); out_size_ = out_tensors_[0]->shape().size();
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
enable_fp16_ = ocl_runtime->GetFp16Enable();
fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float);
if (in_size_ != 4 && in_size_ != 2) { if (in_size_ != 4 && in_size_ != 2) {
MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_; MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_;
return RET_ERROR; return RET_ERROR;
...@@ -67,7 +72,6 @@ int BiasAddOpenCLKernel::Init() { ...@@ -67,7 +72,6 @@ int BiasAddOpenCLKernel::Init() {
std::string source = biasadd_source; std::string source = biasadd_source;
std::string program_name = "BiasAdd"; std::string program_name = "BiasAdd";
std::string kernel_name = "BiasAdd"; std::string kernel_name = "BiasAdd";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source); ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
......
...@@ -42,9 +42,11 @@ class BiasAddOpenCLKernel : public OpenCLKernel { ...@@ -42,9 +42,11 @@ class BiasAddOpenCLKernel : public OpenCLKernel {
private: private:
cl::Kernel kernel_; cl::Kernel kernel_;
FLOAT_t *BiasAdd_; void *BiasAdd_;
int in_size_; int in_size_;
int out_size_; int out_size_;
size_t fp_size;
bool enable_fp16_{false};
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
......
...@@ -187,7 +187,7 @@ int Conv2dTransposeOpenCLKernel::Run() { ...@@ -187,7 +187,7 @@ int Conv2dTransposeOpenCLKernel::Run() {
int arg_cnt = 0; int arg_cnt = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, in_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, padWeight_, lite::opencl::MemType::BUF); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, padWeight_, lite::opencl::MemType::BUF);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, bias_, lite::opencl::MemType::BUF); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, bias_);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, out_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, kernel_size); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, kernel_size);
ocl_runtime->SetKernelArg(kernel_, arg_cnt++, stride); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, stride);
......
...@@ -164,7 +164,7 @@ int MatMulOpenCLKernel::Run() { ...@@ -164,7 +164,7 @@ int MatMulOpenCLKernel::Run() {
int arg_count = 0; int arg_count = 0;
ocl_runtime->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF); ocl_runtime->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF);
ocl_runtime->SetKernelArg(kernel_, arg_count++, bias_, lite::opencl::MemType::BUF); ocl_runtime->SetKernelArg(kernel_, arg_count++, bias_);
ocl_runtime->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCI); ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCI);
ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCO); ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCO);
......
...@@ -36,15 +36,16 @@ namespace mindspore::kernel { ...@@ -36,15 +36,16 @@ namespace mindspore::kernel {
void PReluOpenCLKernel::InitBuffer() { void PReluOpenCLKernel::InitBuffer() {
int C = in_tensors_[1]->shape()[0]; int C = in_tensors_[1]->shape()[0];
int div_ci = UP_DIV(C, C4NUM); int div_ci = UP_DIV(C, C4NUM);
std::cout << div_ci << std::endl;
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
PReluWeight_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t))); size_t img_dtype = CL_FLOAT;
PReluWeight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true)); if (enable_fp16_) {
memset(PReluWeight_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t)); img_dtype = CL_HALF_FLOAT;
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_[1]->Data());
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
PReluWeight_[i] = origin_weight[i];
} }
std::vector<size_t> img_size{size_t(div_ci), 1, img_dtype};
PReluWeight_ = allocator->Malloc(div_ci * C4NUM * fp_size, img_size);
PReluWeight_ = allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true);
memset(PReluWeight_, 0x00, div_ci * C4NUM * fp_size);
memcpy(PReluWeight_, in_tensors_[1]->Data(), C * fp_size);
allocator->UnmapBuffer(PReluWeight_); allocator->UnmapBuffer(PReluWeight_);
} }
...@@ -61,14 +62,14 @@ int PReluOpenCLKernel::Init() { ...@@ -61,14 +62,14 @@ int PReluOpenCLKernel::Init() {
<< C_Weight << " and your input channel size is " << C; << C_Weight << " and your input channel size is " << C;
return RET_ERROR; return RET_ERROR;
} }
if (C_Weight != 1) {
InitBuffer();
}
std::set<std::string> build_options; std::set<std::string> build_options;
std::string source = prelu_source; std::string source = prelu_source;
std::string program_name = "PRelu"; std::string program_name = "PRelu";
std::string kernel_name = "PRelu"; std::string kernel_name = "PRelu";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
enable_fp16_ = ocl_runtime->GetFp16Enable();
fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float);
InitBuffer();
ocl_runtime->LoadSource(program_name, source); ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
in_ori_format_ = in_tensors_[0]->GetFormat(); in_ori_format_ = in_tensors_[0]->GetFormat();
...@@ -92,11 +93,7 @@ int PReluOpenCLKernel::Run() { ...@@ -92,11 +93,7 @@ int PReluOpenCLKernel::Run() {
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
if (in_tensors_[1]->shape()[0] == 1) { ocl_runtime->SetKernelArg(kernel_, arg_idx++, PReluWeight_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<float *>(in_tensors_[1]->Data()));
} else {
ocl_runtime->SetKernelArg(kernel_, arg_idx++, PReluWeight_);
}
ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<int>(in_tensors_[1]->shape()[0])); ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<int>(in_tensors_[1]->shape()[0]));
std::vector<size_t> local = {1, 1}; std::vector<size_t> local = {1, 1};
......
...@@ -40,7 +40,9 @@ class PReluOpenCLKernel : public OpenCLKernel { ...@@ -40,7 +40,9 @@ class PReluOpenCLKernel : public OpenCLKernel {
private: private:
cl::Kernel kernel_; cl::Kernel kernel_;
FLOAT_t *PReluWeight_; void *PReluWeight_;
size_t fp_size;
bool enable_fp16_{false};
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
......
...@@ -35,21 +35,22 @@ void LoadDataBiasAdd(void *dst, size_t dst_size, const std::string &file_path) { ...@@ -35,21 +35,22 @@ void LoadDataBiasAdd(void *dst, size_t dst_size, const std::string &file_path) {
if (file_path.empty()) { if (file_path.empty()) {
memset(dst, 0x00, dst_size); memset(dst, 0x00, dst_size);
} else { } else {
auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size);
memcpy(dst, src_data, dst_size); memcpy(dst, src_data, dst_size);
} }
} }
template <typename T>
void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
size_t output_size = output_tensor->ElementsNum(); size_t output_size = output_tensor->ElementsNum();
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); auto output_data = reinterpret_cast<T *>(output_tensor->Data());
auto expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
constexpr float atol = 0.0002; constexpr float atol = 0.0002;
for (int i = 0; i < output_tensor->ElementsNum(); ++i) { for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
if (std::fabs(output_data[i] - expect_data[i]) > atol) { if (std::fabs(output_data[i] - expect_data[i]) > atol) {
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); printf("error at idx[%d] expect=%f output=%f\n\n\n", i, expect_data[i], output_data[i]);
return; return;
} }
} }
...@@ -58,8 +59,10 @@ void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &s ...@@ -58,8 +59,10 @@ void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &s
printf("compare success!\n\n\n"); printf("compare success!\n\n\n");
} }
void printf_tensor_BiasAdd(mindspore::lite::tensor::Tensor *in_data, int size) { template <typename T>
auto input_data = reinterpret_cast<float *>(in_data->Data()); void printf_tensor_BiasAdd(const std::string log, mindspore::lite::tensor::Tensor *in_data, int size) {
MS_LOG(INFO) << log;
auto input_data = reinterpret_cast<T *>(in_data->Data());
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
printf("%f ", input_data[i]); printf("%f ", input_data[i]);
} }
...@@ -67,15 +70,6 @@ void printf_tensor_BiasAdd(mindspore::lite::tensor::Tensor *in_data, int size) { ...@@ -67,15 +70,6 @@ void printf_tensor_BiasAdd(mindspore::lite::tensor::Tensor *in_data, int size) {
MS_LOG(INFO) << "Print tensor done"; MS_LOG(INFO) << "Print tensor done";
} }
void printf_float_BiasAdd(float *data, int num = 0) {
float *temp = data;
for (int i = 0; i < num; ++i) {
std::cout << *temp << " ";
temp++;
}
std::cout << std::endl;
}
TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) {
std::string in_file = "/data/local/tmp/in_data.bin"; std::string in_file = "/data/local/tmp/in_data.bin";
std::string weight_file = "/data/local/tmp/weight_data.bin"; std::string weight_file = "/data/local/tmp/weight_data.bin";
...@@ -83,29 +77,34 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { ...@@ -83,29 +77,34 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) {
MS_LOG(INFO) << "BiasAdd Begin test:"; MS_LOG(INFO) << "BiasAdd Begin test:";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init(); ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator(); auto data_type = kNumberTypeFloat16;
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
MS_LOG(INFO) << "BiasAdd init tensors.";
std::vector<int> input_shape = {1, 9}; std::vector<int> input_shape = {1, 9};
std::vector<int> output_shape = {1, 9}; std::vector<int> output_shape = {1, 9};
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode; auto tensor_type = schema::NodeType_ValueNode;
auto *input_tensor = schema::Format type;
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type); int weight_shape = 0;
if (input_shape.size() == 4) {
weight_shape = input_shape[3];
type = schema::Format_NHWC;
} else {
weight_shape = input_shape[1];
type = schema::Format_NC;
}
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, type, tensor_type);
if (input_tensor == nullptr) { if (input_tensor == nullptr) {
MS_LOG(ERROR) << "new input tensor error!"; MS_LOG(ERROR) << "new input tensor error!";
return; return;
} }
auto *output_tensor = auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, type, tensor_type);
new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NC, tensor_type);
if (output_tensor == nullptr) { if (output_tensor == nullptr) {
MS_LOG(ERROR) << "new output tensor error!"; MS_LOG(ERROR) << "new output tensor error!";
delete input_tensor; delete input_tensor;
return; return;
} }
auto *weight_tensor = new (std::nothrow) auto *weight_tensor = new (std::nothrow)
lite::tensor::Tensor(data_type, std::vector<int>{input_shape[1]}, schema::Format_NHWC, tensor_type); lite::tensor::Tensor(data_type, std::vector<int>{weight_shape}, schema::Format_NHWC, tensor_type);
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "new weight tensor error!"; MS_LOG(ERROR) << "new weight tensor error!";
delete output_tensor; delete output_tensor;
...@@ -114,14 +113,18 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { ...@@ -114,14 +113,18 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) {
} }
std::vector<lite::tensor::Tensor *> inputs{input_tensor, weight_tensor}; std::vector<lite::tensor::Tensor *> inputs{input_tensor, weight_tensor};
std::vector<lite::tensor::Tensor *> outputs{output_tensor}; std::vector<lite::tensor::Tensor *> outputs{output_tensor};
auto allocator = ocl_runtime->GetAllocator();
inputs[0]->MallocData(allocator); inputs[0]->MallocData(allocator);
inputs[1]->MallocData(allocator); inputs[1]->MallocData(allocator);
LoadDataBiasAdd(input_tensor->Data(), input_tensor->Size(), in_file); LoadDataBiasAdd(input_tensor->Data(), input_tensor->Size(), in_file);
MS_LOG(INFO) << "BiasAdd==================input data================";
printf_tensor_BiasAdd(inputs[0], input_tensor->ElementsNum());
LoadDataBiasAdd(weight_tensor->Data(), weight_tensor->Size(), weight_file); LoadDataBiasAdd(weight_tensor->Data(), weight_tensor->Size(), weight_file);
MS_LOG(INFO) << "BiasAdd==================weight data================"; if (ocl_runtime->GetFp16Enable()) {
printf_tensor_BiasAdd(inputs[1], weight_tensor->ElementsNum()); printf_tensor_BiasAdd<float16_t>("BiasAdd:FP16--input data", inputs[0], input_tensor->ElementsNum());
printf_tensor_BiasAdd<float16_t>("BiasAdd:FP16--weight data", inputs[1], weight_tensor->ElementsNum());
} else {
printf_tensor_BiasAdd<float>("BiasAdd:FP32--input data", inputs[0], input_tensor->ElementsNum());
printf_tensor_BiasAdd<float>("BiasAdd:FP32--weight data", inputs[1], weight_tensor->ElementsNum());
}
auto *param = new (std::nothrow) OpParameter(); auto *param = new (std::nothrow) OpParameter();
if (param == nullptr) { if (param == nullptr) {
...@@ -189,9 +192,13 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { ...@@ -189,9 +192,13 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) {
return; return;
} }
MS_LOG(INFO) << "BiasAdd==================output data================"; if (ocl_runtime->GetFp16Enable()) {
printf_tensor_BiasAdd(outputs[0], output_tensor->ElementsNum()); printf_tensor_BiasAdd<float16_t>("BiasAdd:FP16--output data", outputs[0], output_tensor->ElementsNum());
CompareOutBiasAdd(output_tensor, standard_answer_file); CompareOutBiasAdd<float16_t>(output_tensor, standard_answer_file);
} else {
printf_tensor_BiasAdd<float>("BiasAdd:FP32--output data", outputs[0], output_tensor->ElementsNum());
CompareOutBiasAdd<float>(output_tensor, standard_answer_file);
}
delete input_tensor; delete input_tensor;
delete weight_tensor; delete weight_tensor;
delete output_tensor; delete output_tensor;
......
...@@ -37,15 +37,16 @@ void LoadDataPRelu(void *dst, size_t dst_size, const std::string &file_path) { ...@@ -37,15 +37,16 @@ void LoadDataPRelu(void *dst, size_t dst_size, const std::string &file_path) {
if (file_path.empty()) { if (file_path.empty()) {
memset(dst, 0x00, dst_size); memset(dst, 0x00, dst_size);
} else { } else {
auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size);
memcpy(dst, src_data, dst_size); memcpy(dst, src_data, dst_size);
} }
} }
template <typename T>
void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
auto *output_data = reinterpret_cast<float *>(output_tensor->Data()); auto *output_data = reinterpret_cast<T *>(output_tensor->Data());
size_t output_size = output_tensor->Size(); size_t output_size = output_tensor->Size();
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); auto expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
constexpr float atol = 0.0002; constexpr float atol = 0.0002;
for (int i = 0; i < output_tensor->ElementsNum(); ++i) { for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
if (std::fabs(output_data[i] - expect_data[i]) > atol) { if (std::fabs(output_data[i] - expect_data[i]) > atol) {
...@@ -60,6 +61,17 @@ void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &sta ...@@ -60,6 +61,17 @@ void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &sta
printf("compare success!\n\n\n"); printf("compare success!\n\n\n");
} }
template <typename T>
void printf_tensor_Prelu(const std::string &log, mindspore::lite::tensor::Tensor *in_data, int size) {
MS_LOG(INFO) << log;
auto input_data = reinterpret_cast<T *>(in_data->Data());
for (int i = 0; i < size; ++i) {
printf("%f ", input_data[i]);
}
printf("\n");
MS_LOG(INFO) << "Print tensor done";
}
TEST_F(TestPReluOpenCL, PReluFp32_dim4) { TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
std::string in_file = "/data/local/tmp/in_data.bin"; std::string in_file = "/data/local/tmp/in_data.bin";
std::string weight_file = "/data/local/tmp/weight_data.bin"; std::string weight_file = "/data/local/tmp/weight_data.bin";
...@@ -71,16 +83,14 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { ...@@ -71,16 +83,14 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
MS_LOG(INFO) << "Init tensors."; MS_LOG(INFO) << "Init tensors.";
std::vector<int> input_shape = {1, 4, 3, 9}; std::vector<int> input_shape = {1, 4, 3, 9};
auto data_type = kNumberTypeFloat16;
auto data_type = kNumberTypeFloat32; ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
auto tensor_type = schema::NodeType_ValueNode; auto tensor_type = schema::NodeType_ValueNode;
auto input_tensor = auto input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
if (input_tensor == nullptr) { if (input_tensor == nullptr) {
MS_LOG(ERROR) << "new input_tensor error!"; MS_LOG(ERROR) << "new input_tensor error!";
return; return;
} }
auto output_tensor = auto output_tensor =
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
if (output_tensor == nullptr) { if (output_tensor == nullptr) {
...@@ -88,9 +98,8 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { ...@@ -88,9 +98,8 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
delete input_tensor; delete input_tensor;
return; return;
} }
auto weight_tensor = new (std::nothrow)
auto weight_tensor = lite::tensor::Tensor(data_type, std::vector<int>{input_shape[3]}, schema::Format_NHWC, tensor_type);
new (std::nothrow) lite::tensor::Tensor(data_type, std::vector<int>{9}, schema::Format_NHWC, tensor_type);
if (weight_tensor == nullptr) { if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "new weight_tensor error"; MS_LOG(ERROR) << "new weight_tensor error";
delete input_tensor; delete input_tensor;
...@@ -99,20 +108,20 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { ...@@ -99,20 +108,20 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
} }
std::vector<lite::tensor::Tensor *> inputs{input_tensor, weight_tensor}; std::vector<lite::tensor::Tensor *> inputs{input_tensor, weight_tensor};
std::vector<lite::tensor::Tensor *> outputs{output_tensor}; std::vector<lite::tensor::Tensor *> outputs{output_tensor};
// freamework to do!!! allocate memory by hand
inputs[0]->MallocData(allocator); inputs[0]->MallocData(allocator);
inputs[1]->MallocData(allocator); inputs[1]->MallocData(allocator);
MS_LOG(INFO) << "initialize input data"; MS_LOG(INFO) << "initialize input data";
LoadDataPRelu(input_tensor->Data(), input_tensor->Size(), in_file); LoadDataPRelu(input_tensor->Data(), input_tensor->Size(), in_file);
LoadDataPRelu(weight_tensor->Data(), weight_tensor->Size(), weight_file); LoadDataPRelu(weight_tensor->Data(), weight_tensor->Size(), weight_file);
auto weight_data = reinterpret_cast<float *>(weight_tensor->Data()); if (ocl_runtime->GetFp16Enable()) {
PrintData("Weight data", weight_data, inputs[1]->ElementsNum()); printf_tensor_Prelu<float16_t>("PRELU:FP16--input data", input_tensor, inputs[0]->ElementsNum());
auto *input_data = reinterpret_cast<float *>(inputs[0]->Data()); printf_tensor_Prelu<float16_t>("PRELU:FP16--weight data", weight_tensor, weight_tensor->ElementsNum());
PrintData("PRelu input data", input_data, inputs[0]->ElementsNum()); } else {
std::cout << inputs[0]->ElementsNum() << std::endl; printf_tensor_Prelu<float>("PRELU:FP32--input data", input_tensor, inputs[0]->ElementsNum());
std::cout << "--------------------------------------------" << std::endl; printf_tensor_Prelu<float>("PRELU:FP32--weight data", weight_tensor, inputs[1]->ElementsNum());
}
auto param = new (std::nothrow) PReluParameter(); auto param = new (std::nothrow) PReluParameter();
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "new PreluParameter error"; MS_LOG(ERROR) << "new PreluParameter error";
...@@ -173,10 +182,13 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { ...@@ -173,10 +182,13 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
return; return;
} }
MS_LOG(INFO) << "PRelu==================output data================"; if (ocl_runtime->GetFp16Enable()) {
auto *output_data = reinterpret_cast<float *>(outputs[0]->Data()); printf_tensor_Prelu<float16_t>("PRelu:FP16--output_data", output_tensor, outputs[0]->ElementsNum());
PrintData("output_data", output_data, outputs[0]->ElementsC4Num()); CompareOutPRelu<float16_t>(output_tensor, standard_answer_file);
CompareOutPRelu(output_tensor, standard_answer_file); } else {
printf_tensor_Prelu<float>("PRelu:FP32--output_data", output_tensor, outputs[0]->ElementsNum());
CompareOutPRelu<float>(output_tensor, standard_answer_file);
}
delete input_tensor; delete input_tensor;
delete output_tensor; delete output_tensor;
delete weight_tensor; delete weight_tensor;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册