提交 873e0342 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5676 activation support NC4HW4 in opencl

Merge pull request !5676 from liuzhongkai/NC4HW4
......@@ -4,20 +4,26 @@
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
__read_only image2d_t alpha, const int dim) {
int C = input_shape.w; // channel size
__read_only image2d_t alpha, const int data_type) {
int H = input_shape.y;
int C = input_shape.w; // channel size
C = UP_DIV(C, C4NUM);
if ((C == 0 || H == 0) && data_type != 1) {
return;
}
int Y = get_global_id(0); // height id
int X = get_global_id(1); // weight id
for (int num = 0; num < UP_DIV(C, C4NUM); ++num) {
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, C4NUM) + num, Y)); // NHWC4: H WC
FLT4 tmp = in_c4;
int index = 0;
if (dim == 2) {
index = X;
} else {
index = num;
}
tmp += READ_IMAGE(alpha, smp_zero, (int2)(index, 0));
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, C4NUM) + num, Y), tmp); // NHWC4: H WC
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
FLT4 tmp = in_c4;
int index = 0;
if (data_type == 1) { // NC
index = X;
} else if (data_type == 2) { // NHWC4
index = X % C;
} else { // NC4HW4
index = Y / H;
}
tmp += READ_IMAGE(alpha, smp_zero, (int2)(index, 0));
WRITE_IMAGE(output, (int2)(X, Y), tmp);
}
......@@ -4,27 +4,38 @@
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
__read_only image2d_t alpha, const int dim) {
__read_only image2d_t alpha, const int data_type, const int bias_dim) {
int H = input_shape.y;
int C = input_shape.w; // channel size
C = UP_DIV(C, SLICES);
if (C == 0 || H == 0) {
return;
}
int Y = get_global_id(0); // height id
int X = get_global_id(1); // weight id
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
FLT4 tmp;
if (dim == 1) {
FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(0, 0));
tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * weight.x;
tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * weight.x;
tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * weight.x;
tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * weight.x;
} else {
FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(num, 0));
tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * weight.x;
tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * weight.y;
tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * weight.z;
tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * weight.w;
}
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
FLT4 tmp;
int index = 0;
if (data_type == 1) { // NHWC4
index = X % C;
} else if (data_type == 2) { // NC4HW4
index = Y / H;
} else {
return;
}
if (bias_dim == 1) {
index = 0;
}
FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(index, 0));
FLT4 bias = weight;
if (bias_dim == 1) {
bias.y = weight.x;
bias.z = weight.x;
bias.w = weight.x;
}
tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * bias.x;
tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * bias.y;
tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * bias.z;
tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * bias.w;
WRITE_IMAGE(output, (int2)(X, Y), tmp);
}
......@@ -77,20 +77,10 @@ int ActivationOpenClKernel::Init() {
std::set<std::string> build_options;
ocl_runtime->LoadSource(Program_Kernel[type_][0], source);
ocl_runtime->BuildKernel(kernel_, Program_Kernel[type_][0], Program_Kernel[type_][1], build_options);
std::map<int, schema::Format> format{{4, schema::Format_NHWC4}, {2, schema::Format_NC4}};
if (format.count(out_size_) == 0) {
MS_LOG(ERROR) << "Not found output tensor format";
return RET_ERROR;
}
in_ori_format_ = in_tensors_[0]->GetFormat();
out_ori_format_ = out_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(format[in_size_]);
out_tensors_[0]->SetFormat(format[out_size_]);
if (in_size_ == 2) {
in_ori_format_ = schema::Format_NC4;
out_ori_format_ = schema::Format_NC4;
}
in_tensors_[0]->SetFormat(op_format_);
out_tensors_[0]->SetFormat(op_format_);
MS_LOG(DEBUG) << op_parameter_->name_ << " init Done!";
return RET_OK;
}
......@@ -121,11 +111,15 @@ cl_int4 ActivationOpenClKernel::GetImg2dShape() {
for (int i = 0; i < in_size_; ++i) {
img2d_shape.s[i + 4 - in_size_] = in_tensors_[0]->shape()[i];
}
if (in_size_ == 2) {
if (op_format_ == schema::Format_NC4) {
img2d_shape.s[1] = img2d_shape.s[2];
img2d_shape.s[2] = UP_DIV(img2d_shape.s[3], C4NUM);
img2d_shape.s[3] = C4NUM;
}
if (op_format_ == schema::Format_NC4HW4) {
img2d_shape.s[1] = UP_DIV(img2d_shape.s[3], C4NUM) * img2d_shape.s[1]; // UP(c / 4) * H
img2d_shape.s[3] = C4NUM;
}
return img2d_shape;
}
......
......@@ -54,6 +54,9 @@ void BiasAddOpenCLKernel::InitBuffer() {
int BiasAddOpenCLKernel::Init() {
in_size_ = in_tensors_[0]->shape().size();
out_size_ = out_tensors_[0]->shape().size();
for (int i = 0; i < in_size_; ++i) {
input_shape_.s[i + 4 - in_size_] = in_tensors_[0]->shape()[i];
}
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
enable_fp16_ = ocl_runtime->GetFp16Enable();
fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
......@@ -77,33 +80,26 @@ int BiasAddOpenCLKernel::Init() {
in_ori_format_ = in_tensors_[0]->GetFormat();
out_ori_format_ = out_tensors_[0]->GetFormat();
std::map<int, schema::Format> format{{4, schema::Format_NHWC4}, {2, schema::Format_NC4}};
if (format.count(out_size_) == 0) {
MS_LOG(ERROR) << "Not found output tensor format";
return RET_ERROR;
}
in_tensors_[0]->SetFormat(format[in_size_]);
out_tensors_[0]->SetFormat(format[out_size_]);
if (in_size_ == 2) {
in_ori_format_ = format[in_size_];
out_ori_format_ = format[out_size_];
}
in_tensors_[0]->SetFormat(op_format_);
out_tensors_[0]->SetFormat(op_format_);
MS_LOG(DEBUG) << program_name << " Init Done!";
return RET_OK;
}
int BiasAddOpenCLKernel::Run() {
cl_int4 input_shape = GetImg2dShape();
cl_int4 global_size = GetGlobalshape();
MS_LOG(DEBUG) << op_parameter_->name_ << " Running!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
int arg_idx = 0;
std::map<schema::Format, int> data_type{
{schema::Format_NC4, 1}, {schema::Format_NHWC4, 2}, {schema::Format_NC4HW4, 3}};
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, BiasAdd_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_size_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, data_type[op_format_]);
std::vector<size_t> local = {1, 1};
std::vector<size_t> global = {static_cast<size_t>(input_shape.s[1]), static_cast<size_t>(input_shape.s[2])};
std::vector<size_t> global = {static_cast<size_t>(global_size.s[1]), static_cast<size_t>(global_size.s[2])};
auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error.";
......@@ -112,29 +108,29 @@ int BiasAddOpenCLKernel::Run() {
return RET_OK;
}
cl_int4 BiasAddOpenCLKernel::GetImg2dShape() {
cl_int4 img2d_shape = {0, 0, 0, 0};
for (int i = 0; i < in_size_; ++i) {
img2d_shape.s[i + 4 - in_size_] = in_tensors_[0]->shape()[i];
cl_int4 BiasAddOpenCLKernel::GetGlobalshape() {
cl_int4 global_shape = input_shape_;
if (op_format_ == schema::Format_NC4) {
global_shape.s[1] = global_shape.s[2];
global_shape.s[2] = UP_DIV(global_shape.s[3], C4NUM);
}
if (in_size_ == 2) {
img2d_shape.s[1] = img2d_shape.s[2];
img2d_shape.s[2] = UP_DIV(img2d_shape.s[3], C4NUM);
img2d_shape.s[3] = C4NUM;
if (op_format_ == schema::Format_NC4HW4) {
global_shape.s[1] = UP_DIV(global_shape.s[3], C4NUM) * global_shape.s[1]; // c / 4 * H
}
return img2d_shape;
if (op_format_ == schema::Format_NHWC4) {
global_shape.s[2] = UP_DIV(global_shape.s[3], C4NUM) * global_shape.s[2];
}
return global_shape;
}
int BiasAddOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
cl_int4 img_shape = GetImg2dShape();
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
cl_int4 img_shape = GetGlobalshape();
size_t img_dtype = CL_FLOAT;
#endif
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
img_size->clear();
img_size->push_back(img_shape.s[2] * UP_DIV(img_shape.s[3], C4NUM));
img_size->push_back(img_shape.s[2]);
img_size->push_back(img_shape.s[1]);
img_size->push_back(img_dtype);
return RET_OK;
......
......@@ -38,7 +38,7 @@ class BiasAddOpenCLKernel : public OpenCLKernel {
int Run() override;
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
void InitBuffer();
cl_int4 GetImg2dShape();
cl_int4 GetGlobalshape();
private:
cl::Kernel kernel_;
......@@ -46,6 +46,7 @@ class BiasAddOpenCLKernel : public OpenCLKernel {
int in_size_;
int out_size_;
size_t fp_size;
cl_int4 input_shape_;
bool enable_fp16_{false};
};
......
......@@ -18,6 +18,7 @@
#include <set>
#include <vector>
#include <map>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
......@@ -62,6 +63,9 @@ int PReluOpenCLKernel::Init() {
<< C_Weight << " and your input channel size is " << C;
return RET_ERROR;
}
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
input_shape_.s[i] = in_tensors_[0]->shape()[i];
}
std::set<std::string> build_options;
std::string source = prelu_source;
std::string program_name = "PRelu";
......@@ -73,31 +77,26 @@ int PReluOpenCLKernel::Init() {
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
in_tensors_[0]->SetFormat(op_format_);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_tensors_[0]->SetFormat(op_format_);
MS_LOG(DEBUG) << program_name << " init Done!";
return RET_OK;
}
int PReluOpenCLKernel::Run() {
MS_LOG(DEBUG) << op_parameter_->name_ << " Running!";
int N = in_tensors_[0]->shape()[0];
int H = in_tensors_[0]->shape()[1];
int W = in_tensors_[0]->shape()[2];
int C = in_tensors_[0]->shape()[3];
cl_int4 input_shape = {N, H, W, C};
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
std::map<schema::Format, int> data_type{{schema::Format_NHWC4, 1}, {schema::Format_NC4HW4, 2}};
int arg_idx = 0;
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, PReluWeight_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, data_type[op_format_]);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<int>(in_tensors_[1]->shape()[0]));
std::vector<size_t> local = {1, 1};
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
std::vector<size_t> global = {static_cast<size_t>(global_shape_.s[1]), static_cast<size_t>(global_shape_.s[2])};
auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error.";
......@@ -107,19 +106,22 @@ int PReluOpenCLKernel::Run() {
}
int PReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
int H = in_tensors_[0]->shape()[1];
int W = in_tensors_[0]->shape()[2];
int C = in_tensors_[0]->shape()[3];
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
global_shape_ = input_shape_;
if (op_format_ == schema::Format_NC4HW4) {
global_shape_.s[1] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[1];
} else if (op_format_ == schema::Format_NHWC4) {
global_shape_.s[2] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[2];
} else {
MS_LOG(ERROR) << "op_format_:" << op_format_ << " is do not support!";
return RET_ERROR;
}
img_size->clear();
img_size->push_back(W * UP_DIV(C, C4NUM));
img_size->push_back(H);
img_size->push_back(global_shape_.s[2]);
img_size->push_back(global_shape_.s[1]);
img_size->push_back(img_dtype);
return RET_OK;
}
......@@ -128,7 +130,7 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::tensor::Ten
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) {
if (inputs.size() == 0) {
if (inputs.empty()) {
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
return nullptr;
}
......
......@@ -41,6 +41,8 @@ class PReluOpenCLKernel : public OpenCLKernel {
private:
cl::Kernel kernel_;
void *PReluWeight_;
cl_int4 input_shape_;
cl_int4 global_shape_;
size_t fp_size;
bool enable_fp16_{false};
};
......
......@@ -51,7 +51,7 @@ void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard
auto *output_data = reinterpret_cast<T *>(output_tensor->Data());
size_t output_size = output_tensor->Size();
auto expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
constexpr float atol = 0.0002;
constexpr float atol = 0.001;
for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]);
......@@ -88,10 +88,8 @@ TEST_F(TestActivationOpenCL, ReluFp_dim4) {
bool enable_fp16 = ocl_runtime->GetFp16Enable();
MS_LOG(INFO) << "Init tensors.";
std::vector<int> input_shape = {1, 9};
schema::Format format = schema::Format_NHWC;
if (input_shape.size() == 2) {
format = schema::Format_NC;
}
schema::Format format = schema::Format_NC;
schema::Format op_format = schema::Format_NC4;
auto tensor_type = schema::NodeType_ValueNode;
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
if (input_tensor == nullptr) {
......@@ -124,6 +122,7 @@ TEST_F(TestActivationOpenCL, ReluFp_dim4) {
param->type_ = ActivationType_RELU;
auto *kernel =
new (std::nothrow) kernel::ActivationOpenClKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
kernel->SetFormatType(op_format);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Kernel:Relu create fail.";
delete param;
......@@ -194,17 +193,15 @@ TEST_F(TestActivationOpenCL, Relu6Fp_dim4) {
std::string out_file = "/data/local/tmp/relu6.bin";
MS_LOG(INFO) << "Relu6 Begin test!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto data_type = kNumberTypeFloat32;
auto data_type = kNumberTypeFloat16;
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
bool enable_fp16 = ocl_runtime->GetFp16Enable();
ocl_runtime->Init();
MS_LOG(INFO) << "Init tensors.";
std::vector<int> input_shape = {1, 9};
schema::Format format = schema::Format_NHWC;
if (input_shape.size() == 2) {
format = schema::Format_NC;
}
schema::Format format = schema::Format_NC;
schema::Format op_format = schema::Format_NC4;
auto tensor_type = schema::NodeType_ValueNode;
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
if (input_tensor == nullptr) {
......@@ -246,6 +243,7 @@ TEST_F(TestActivationOpenCL, Relu6Fp_dim4) {
delete output_tensor;
return;
}
kernel->SetFormatType(op_format);
auto ret = kernel->Init();
if (ret != RET_OK) {
delete param;
......@@ -311,16 +309,14 @@ TEST_F(TestActivationOpenCL, SigmoidFp_dim4) {
MS_LOG(INFO) << "Sigmoid Begin test!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto data_type = kNumberTypeFloat16;
auto data_type = kNumberTypeFloat32;
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
bool enable_fp16 = ocl_runtime->GetFp16Enable();
MS_LOG(INFO) << "Init tensors.";
std::vector<int> input_shape = {1, 9};
schema::Format format = schema::Format_NHWC;
if (input_shape.size() == 2) {
format = schema::Format_NC;
}
schema::Format format = schema::Format_NC;
schema::Format op_format = schema::Format_NC4;
auto tensor_type = schema::NodeType_ValueNode;
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
if (input_tensor == nullptr) {
......@@ -362,6 +358,7 @@ TEST_F(TestActivationOpenCL, SigmoidFp_dim4) {
delete output_tensor;
return;
}
kernel->SetFormatType(op_format);
auto ret = kernel->Init();
if (ret != RET_OK) {
delete param;
......@@ -427,17 +424,15 @@ TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) {
MS_LOG(INFO) << "Leaky relu Begin test!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto data_type = kNumberTypeFloat32;
auto data_type = kNumberTypeFloat16; // need modify
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
bool enable_fp16 = ocl_runtime->GetFp16Enable();
MS_LOG(INFO) << "Init tensors.";
std::vector<int> input_shape = {1, 9};
std::vector<int> input_shape = {1, 9}; // need modify
auto tensor_type = schema::NodeType_ValueNode;
schema::Format format = schema::Format_NHWC;
if (input_shape.size() == 2) {
format = schema::Format_NC;
}
schema::Format format = schema::Format_NC; // need modify
schema::Format op_format = schema::Format_NC4; // need modify
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
if (input_tensor == nullptr) {
MS_LOG(ERROR) << "new input tensor error!";
......@@ -479,6 +474,7 @@ TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) {
delete output_tensor;
return;
}
kernel->SetFormatType(op_format);
auto ret = kernel->Init();
if (ret != RET_OK) {
delete param;
......
......@@ -77,20 +77,18 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) {
MS_LOG(INFO) << "BiasAdd Begin test:";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto data_type = kNumberTypeFloat16;
auto data_type = kNumberTypeFloat16; // need modify
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
std::vector<int> input_shape = {1, 9};
std::vector<int> output_shape = {1, 9};
std::vector<int> input_shape = {1, 9}; // need modify
std::vector<int> output_shape = {1, 9}; // need modify
auto tensor_type = schema::NodeType_ValueNode;
schema::Format type;
schema::Format type = schema::Format_NC; // need modify
schema::Format op_format = schema::Format_NC4; // need modify
int weight_shape = 0;
if (input_shape.size() == 4) {
weight_shape = input_shape[3];
type = schema::Format_NHWC;
} else {
weight_shape = input_shape[1];
type = schema::Format_NC;
}
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, type, tensor_type);
if (input_tensor == nullptr) {
......@@ -144,7 +142,7 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) {
delete param;
return;
}
biasadd_kernel->SetFormatType(op_format);
auto ret = biasadd_kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "biasadd kernel init error.";
......
......@@ -85,14 +85,15 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
std::vector<int> input_shape = {1, 4, 3, 9};
auto data_type = kNumberTypeFloat16;
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
schema::Format format = schema::Format_NHWC;
schema::Format op_format = schema::Format_NC4HW4;
auto tensor_type = schema::NodeType_ValueNode;
auto input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
auto input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
if (input_tensor == nullptr) {
MS_LOG(ERROR) << "new input_tensor error!";
return;
}
auto output_tensor =
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
auto output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
if (output_tensor == nullptr) {
MS_LOG(ERROR) << "new output_tensor error";
delete input_tensor;
......@@ -140,6 +141,7 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
delete param;
return;
}
prelu_kernel->SetFormatType(op_format);
auto ret = prelu_kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init prelu kernel error";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册