提交 a88010b8 编写于 作者: P principal87

增加覆盖用例,修复NNRT离线模型用例

Signed-off-by: Nprincipal87 <nierong3@huawei.com>
上级 b156370d
......@@ -41,7 +41,8 @@
"resource/ai/mindspore/ml_headpose_pb2tflite/ml_headpose_pb2tflite_0.input -> /data/test",
"resource/ai/mindspore/ml_headpose_pb2tflite/ml_headpose_pb2tflite_1.input -> /data/test",
"resource/ai/mindspore/ml_headpose_pb2tflite/ml_headpose_pb2tflite_2.input -> /data/test",
"resource/ai/mindspore/ml_headpose_pb2tflite/ml_headpose_pb2tflite0.output -> /data/test"
"resource/ai/mindspore/ml_headpose_pb2tflite/ml_headpose_pb2tflite0.output -> /data/test",
"resource/ai/mindspore/Kirin_model/tinynet.om.ms -> /data/test"
]
},
{
......
......@@ -14,6 +14,7 @@
*/
#include <thread>
#include <random>
#include <inttypes.h>
#include <securec.h>
#include "ohos_common.h"
......@@ -66,11 +67,28 @@ bool IsNNRTAvailable() {
if (desc == nullptr) {
return false;
}
auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc);
if (type != 1) {
return false;
}
OH_AI_DestroyAllNNRTDeviceDescs(&desc);
return true;
}
bool IsNPU() {
size_t num = 0;
auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
if (desc == nullptr) {
return false;
}
auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
const std::string npu_name_prefix = "NPU_";
if (strncmp(npu_name_prefix.c_str(), name, npu_name_prefix.size()) != 0) {
return false;
}
return true;
}
// add nnrt device info
void AddContextDeviceNNRT(OH_AI_ContextHandle context) {
size_t num = 0;
......@@ -178,6 +196,7 @@ void FillInputsData(OH_AI_TensorHandleArray inputs, string model_name, bool is_t
auto imageBuf_nhwc = new char[size1];
PackNCHWToNHWCFp32(imageBuf, imageBuf_nhwc, shape[0], shape[1] * shape[2], shape[3]);
memcpy_s(input_data, size1, imageBuf_nhwc, size1);
delete[] imageBuf_nhwc;
} else {
memcpy_s(input_data, size1, imageBuf, size1);
}
......@@ -223,6 +242,7 @@ void ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, string m
char *graphBuf = ReadFile(graphPath, ptr_size);
ASSERT_NE(graphBuf, nullptr);
ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context);
delete[] graphBuf;
} else {
printf("==========Build model==========\n");
ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context);
......@@ -1933,6 +1953,46 @@ HWTEST(MSLiteTest, OHOS_OfflineModel_0008, Function | MediumTest | Level1) {
OH_AI_ModelDestroy(&model);
}
// 正常场景:离线模型覆盖NPU
HWTEST(MSLiteTest, OHOS_OfflineModel_0009, Function | MediumTest | Level1) {
if (!IsNPU()) {
printf("NNRt is not NPU, skip this test");
return;
}
printf("==========Init Context==========\n");
OH_AI_ContextHandle context = OH_AI_ContextCreate();
ASSERT_NE(context, nullptr);
AddContextDeviceNNRT(context);
printf("==========Create model==========\n");
OH_AI_ModelHandle model = OH_AI_ModelCreate();
ASSERT_NE(model, nullptr);
printf("==========Build model==========\n");
OH_AI_Status ret = OH_AI_ModelBuildFromFile(model, "/data/test/tinynet.om.ms",
OH_AI_MODELTYPE_MINDIR, context);
printf("==========build model return code:%d\n", ret);
ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
printf("==========GetInputs==========\n");
OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
ASSERT_NE(inputs.handle_list, nullptr);
for (size_t i = 0; i < inputs.handle_num; ++i) {
OH_AI_TensorHandle tensor = inputs.handle_list[i];
float *input_data = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor));
size_t element_num = OH_AI_TensorGetElementNum(tensor);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0.0f,1.0f);
for (int z=0;z<element_num;z++) {
input_data[z] = dis(gen);
}
}
printf("==========Model Predict==========\n");
OH_AI_TensorHandleArray outputs;
ret = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
OH_AI_ModelDestroy(&model);
}
// 正常场景:delegate异构,使用低级接口创建nnrt device info,选取第一个NNRT设备
HWTEST(MSLiteTest, OHOS_NNRT_0001, Function | MediumTest | Level1) {
printf("==========Init Context==========\n");
......
......@@ -329,10 +329,14 @@ bool compFp32WithTData(float *actualOutputData, const std::string& expectedDataF
}
printf("\n");
if (isquant) {
return allclose(actualOutputData, expectedOutputData, data_size, rtol, atol,
bool ret = allclose(actualOutputData, expectedOutputData, data_size, rtol, atol,
true);
free(expectedOutputData);
return ret;
}
return allclose(actualOutputData, expectedOutputData, data_size, rtol, atol);
bool ret = allclose(actualOutputData, expectedOutputData, data_size, rtol, atol);
free(expectedOutputData);
return ret;
}
bool compUint8WithTData(uint8_t *actualOutputData, const std::string& expectedDataFile,
......@@ -357,11 +361,15 @@ bool compUint8WithTData(uint8_t *actualOutputData, const std::string& expectedDa
}
printf("\n");
if (isquant) {
return allclose_int8(actualOutputData, expectedOutputData, data_size, rtol,
bool ret = allclose_int8(actualOutputData, expectedOutputData, data_size, rtol,
atol, true);
free(expectedOutputData);
return ret;
}
return allclose_int8(actualOutputData, expectedOutputData, data_size, rtol,
bool ret = allclose_int8(actualOutputData, expectedOutputData, data_size, rtol,
atol);
free(expectedOutputData);
return ret;
}
/*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册