Unverified commit aa228ed2, authored by zhupengyang, committed by GitHub

[xpu] update resnet50 ut and add googlenet, vgg19 uts (#4277)

Parent: 50632848
lite/tests/api/CMakeLists.txt:

@@ -51,11 +51,18 @@ if (WITH_TESTING)
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz")
-    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "step_rnn.tar.gz")
-    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "bert.tar.gz")
-    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "ernie.tar.gz")
+
+    set(LITE_URL_FOR_UNITTESTS "http://paddle-inference-dist.bj.bcebos.com/PaddleLite/models_and_data_for_unittests")
+    # models
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "resnet50.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "bert.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ernie.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "GoogLeNet.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "VGG19.tar.gz")
+    # data
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ILSVRC2012_small.tar.gz")
   endif()
 endif()
...
 if(LITE_WITH_ARM)
   lite_cc_test(test_transformer_with_mask_fp32_arm SRCS test_transformer_with_mask_fp32_arm.cc
     DEPS ${lite_model_test_DEPS} paddle_api_full
     ARM_DEPS ${arm_kernels}
     ARGS --model_dir=${LITE_MODEL_DIR}/transformer_with_mask_fp32 SERIAL)
   if(WITH_TESTING)
     add_dependencies(test_transformer_with_mask_fp32_arm extern_lite_download_transformer_with_mask_fp32_tar_gz)
   endif()
 endif()
 
-if(LITE_WITH_XPU AND NOT LITE_WITH_XTCL)
-  lite_cc_test(test_resnet50_fp32_xpu SRCS test_resnet50_fp32_xpu.cc
-    DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-    ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
-  lite_cc_test(test_ernie_fp32_xpu SRCS test_ernie_fp32_xpu.cc
-    DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-    ARGS --model_dir=${LITE_MODEL_DIR}/ernie)
-  lite_cc_test(test_bert_fp32_xpu SRCS test_bert_fp32_xpu.cc
-    DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-    ARGS --model_dir=${LITE_MODEL_DIR}/bert)
-  if(WITH_TESTING)
-    add_dependencies(test_resnet50_fp32_xpu extern_lite_download_resnet50_tar_gz)
-    add_dependencies(test_ernie_fp32_xpu extern_lite_download_ernie_tar_gz)
-    add_dependencies(test_bert_fp32_xpu extern_lite_download_bert_tar_gz)
-  endif()
-  # TODO(miaotianxiang): enable later
-  #lite_cc_test(test_fpr_fp32_xpu SRCS test_fpr_fp32_xpu.cc
-  #DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-  #${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-  #ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
-  #lite_cc_test(test_mmdnn_fp32_xpu SRCS test_mmdnn_fp32_xpu.cc
-  #DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-  #${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-  #ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
+function(xpu_x86_without_xtcl_test TARGET MODEL DATA)
+  lite_cc_test(${TARGET} SRCS ${TARGET}.cc
+    DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
+      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
+    ARGS --model_dir=${LITE_MODEL_DIR}/${MODEL}
+         --data_dir=${LITE_MODEL_DIR}/${DATA})
+  if(WITH_TESTING)
+    add_dependencies(${TARGET} extern_lite_download_${MODEL}_tar_gz)
+    if(NOT ${DATA} STREQUAL "")
+      add_dependencies(${TARGET} extern_lite_download_${DATA}_tar_gz)
+    endif()
+  endif()
+endfunction()
+
+if(LITE_WITH_XPU AND NOT LITE_WITH_XTCL)
+  xpu_x86_without_xtcl_test(test_resnet50_fp32_xpu resnet50 ILSVRC2012_small)
+  xpu_x86_without_xtcl_test(test_googlenet_fp32_xpu GoogLeNet ILSVRC2012_small)
+  xpu_x86_without_xtcl_test(test_vgg19_fp32_xpu VGG19 ILSVRC2012_small)
+  xpu_x86_without_xtcl_test(test_ernie_fp32_xpu ernie "")
+  xpu_x86_without_xtcl_test(test_bert_fp32_xpu bert "")
 endif()
 
 if(LITE_WITH_RKNPU)
   lite_cc_test(test_mobilenetv1_int8_rknpu SRCS test_mobilenetv1_int8_rknpu.cc
     DEPS ${lite_model_test_DEPS} paddle_api_full
     RKNPU_DEPS ${rknpu_kernels} ${rknpu_bridges}
     ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
 endif()
 
 if(LITE_WITH_APU)
   lite_cc_test(test_mobilenetv1_int8_apu SRCS test_mobilenetv1_int8_apu.cc
     DEPS ${lite_model_test_DEPS} paddle_api_full
     APU_DEPS ${apu_kernels} ${apu_bridges}
     ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
 endif()

lite/tests/api/ILSVRC2012_utility.h (new file):
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <fstream>  // std::ifstream is used below; only <iostream> was included
#include <iostream>
#include <string>
#include <vector>

#include "lite/utils/cp_logging.h"
#include "lite/utils/io.h"
#include "lite/utils/string.h"

namespace paddle {
namespace lite {
template <class T = float>
std::vector<std::vector<T>> ReadRawData(
const std::string& raw_data_dir,
const std::vector<int>& input_shape = {1, 3, 224, 224},
int iteration = 100) {
std::vector<std::vector<T>> raw_data;
int image_size = 1;
for (size_t i = 1; i < input_shape.size(); i++) {
image_size *= input_shape[i];
}
int input_size = image_size * input_shape[0];
for (int i = 0; i < iteration; i++) {
std::vector<T> one_iter_raw_data;
one_iter_raw_data.resize(input_size);
T* data = &(one_iter_raw_data.at(0));
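    // Raw inputs are stored one image per file, named "1", "2", ..., so
    // batch item j of iteration i reads file number i * batch_size + j + 1.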
for (int j = 0; j < input_shape[0]; j++) {
std::string raw_data_file_dir =
raw_data_dir + std::string("/") +
std::to_string(i * input_shape[0] + j + 1);
std::ifstream fin(raw_data_file_dir, std::ios::in | std::ios::binary);
CHECK(fin.is_open()) << "failed to open file " << raw_data_file_dir;
fin.seekg(0, std::ios::end);
int file_size = fin.tellg();
fin.seekg(0, std::ios::beg);
CHECK_EQ(file_size, image_size * sizeof(T) / sizeof(char));
fin.read(reinterpret_cast<char*>(data), file_size);
fin.close();
data += image_size;
}
raw_data.emplace_back(one_iter_raw_data);
}
return raw_data;
}

// Returns the fraction of runs whose top-1 prediction matches the label file.
float CalOutAccuracy(const std::vector<std::vector<float>>& out_rets,
                     const std::string& labels_dir) {
  int right_num = 0;
  auto label_lines = ReadLines(labels_dir);
  for (size_t i = 0; i < out_rets.size(); i++) {
    // Each label line is expected to be "<image name> <label>".
    int label = std::stoi(Split(label_lines[i], " ")[1]);
    const auto& out = out_rets[i];
    // Top-1 prediction: index of the largest score.
    auto largest = std::max_element(out.begin(), out.end());
    int out_top1 = std::distance(out.begin(), largest);
    right_num += (out_top1 == label);
  }
  return static_cast<float>(right_num) / static_cast<float>(out_rets.size());
}
} // namespace lite
} // namespace paddle
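Taken together, the two helpers above form a small accuracy harness: ReadRawData loads FLAGS_iteration batches of raw float32 tensors from numbered files, and CalOutAccuracy scores each run's top-1 prediction against labels.txt. A minimal sketch of the intended composition follows; it assumes the ILSVRC2012_small layout described above, and Top1OnSmallImageNet / run_model are illustrative names, not part of the commit:

// Hypothetical usage sketch -- not part of the commit.
#include <string>
#include <vector>
#include "lite/tests/api/ILSVRC2012_utility.h"

// run_model executes one batch and returns the flattened class scores.
template <class RunModelFn>
float Top1OnSmallImageNet(const std::string& data_dir,
                          RunModelFn run_model,
                          int iterations = 100) {
  // batch, channel, height, width -- matches the tests below.
  std::vector<int> shape{1, 3, 224, 224};
  auto batches = paddle::lite::ReadRawData<float>(
      data_dir + "/raw_data", shape, iterations);
  std::vector<std::vector<float>> outputs;
  outputs.reserve(batches.size());
  for (const auto& batch : batches) {
    outputs.push_back(run_model(batch));
  }
  // Fraction of iterations whose arg-max matches the label file.
  return paddle::lite::CalOutAccuracy(outputs, data_dir + "/labels.txt");
}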
lite/tests/api/test_googlenet_fp32_xpu.cc (new file):

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>
#include <gtest/gtest.h>

#include <cstring>  // memcpy
#include <vector>

#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/ILSVRC2012_utility.h"
#include "lite/utils/cp_logging.h"

DEFINE_string(data_dir, "", "data dir");
DEFINE_int32(iteration, 100, "iteration times to run");
DEFINE_int32(batch, 1, "batch of image");
DEFINE_int32(channel, 3, "image channel");

namespace paddle {
namespace lite {

TEST(GoogLeNet, test_googlenet_fp32_xpu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
config.set_xpu_workspace_l3_size_per_thread();
auto predictor = lite_api::CreatePaddlePredictor(config);
std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
std::vector<int> input_shape{
FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
int input_size = 1;
for (auto i : input_shape) {
input_size *= i;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
for (int j = 0; j < input_size; j++) {
data[j] = 0.f;
}
predictor->Run();
}
std::vector<std::vector<float>> out_rets;
out_rets.resize(FLAGS_iteration);
double cost_time = 0;
for (size_t i = 0; i < raw_data.size(); ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
double start = GetCurrentUS();
predictor->Run();
cost_time += GetCurrentUS() - start;
auto output_tensor = predictor->GetOutput(0);
auto output_shape = output_tensor->shape();
auto output_data = output_tensor->data<float>();
ASSERT_EQ(output_shape.size(), 2UL);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 1000);
int output_size = output_shape[0] * output_shape[1];
out_rets[i].resize(output_size);
memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
ASSERT_GT(out_accuracy, 0.57f);
}
} // namespace lite
} // namespace paddle
lite/tests/api/test_resnet50_fp32_xpu.cc:

@@ -21,8 +21,14 @@
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
 #include "lite/api/test_helper.h"
+#include "lite/tests/api/ILSVRC2012_utility.h"
 #include "lite/utils/cp_logging.h"
 
+DEFINE_string(data_dir, "", "data dir");
+DEFINE_int32(iteration, 100, "iteration times to run");
+DEFINE_int32(batch, 1, "batch of image");
+DEFINE_int32(channel, 3, "image channel");
+
 namespace paddle {
 namespace lite {
...
@@ -35,52 +41,62 @@ TEST(Resnet50, test_resnet50_fp32_xpu) {
   config.set_xpu_workspace_l3_size_per_thread();
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
-  auto input_tensor = predictor->GetInput(0);
-  std::vector<int64_t> input_shape{1, 3, 224, 224};
-  input_tensor->Resize(input_shape);
-  auto* data = input_tensor->mutable_data<float>();
-  int input_num = 1;
-  for (size_t i = 0; i < input_shape.size(); ++i) {
-    input_num *= input_shape[i];
-  }
-  for (int i = 0; i < input_num; i++) {
-    data[i] = 1;
+  std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
+  std::vector<int> input_shape{
+      FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
+  auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
+
+  int input_size = 1;
+  for (auto i : input_shape) {
+    input_size *= i;
   }
 
   for (int i = 0; i < FLAGS_warmup; ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    for (int j = 0; j < input_size; j++) {
+      data[j] = 0.f;
+    }
     predictor->Run();
   }
 
-  auto start = GetCurrentUS();
-  for (int i = 0; i < FLAGS_repeats; ++i) {
+  std::vector<std::vector<float>> out_rets;
+  out_rets.resize(FLAGS_iteration);
+  double cost_time = 0;
+  for (size_t i = 0; i < raw_data.size(); ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
+
+    double start = GetCurrentUS();
     predictor->Run();
+    cost_time += GetCurrentUS() - start;
+
+    auto output_tensor = predictor->GetOutput(0);
+    auto output_shape = output_tensor->shape();
+    auto output_data = output_tensor->data<float>();
+    ASSERT_EQ(output_shape.size(), 2UL);
+    ASSERT_EQ(output_shape[0], 1);
+    ASSERT_EQ(output_shape[1], 1000);
+
+    int output_size = output_shape[0] * output_shape[1];
+    out_rets[i].resize(output_size);
+    memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
   }
 
   LOG(INFO) << "================== Speed Report ===================";
   LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
-            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
-            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
-            << " ms in average.";
+            << ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
+            << ", iteration: " << FLAGS_iteration << ", spend "
+            << cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
 
-  std::vector<std::vector<float>> results;
-  results.emplace_back(std::vector<float>(
-      {0.000268651, 0.000174053, 0.000213181, 0.000396771, 0.000591516,
-       0.00018169, 0.000289721, 0.000855934, 0.000732185, 9.2055e-05,
-       0.000220664, 0.00235289, 0.00571265, 0.00357688, 0.00129667,
-       0.000465392, 0.000143775, 0.000211628, 0.000617144, 0.000265033}));
-  auto out = predictor->GetOutput(0);
-  ASSERT_EQ(out->shape().size(), 2);
-  ASSERT_EQ(out->shape()[0], 1);
-  ASSERT_EQ(out->shape()[1], 1000);
-
-  int step = 50;
-  for (size_t i = 0; i < results.size(); ++i) {
-    for (size_t j = 0; j < results[i].size(); ++j) {
-      EXPECT_NEAR(out->data<float>()[j * step + (out->shape()[1] * i)],
-                  results[i][j],
-                  1e-5);
-    }
-  }
+  std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
+  float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
+  ASSERT_GT(out_accuracy, 0.6f);
 }
 
 }  // namespace lite
...
lite/tests/api/test_vgg19_fp32_xpu.cc (new file):

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>
#include <gtest/gtest.h>

#include <cstring>  // memcpy
#include <vector>

#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/ILSVRC2012_utility.h"
#include "lite/utils/cp_logging.h"

DEFINE_string(data_dir, "", "data dir");
DEFINE_int32(iteration, 100, "iteration times to run");
DEFINE_int32(batch, 1, "batch of image");
DEFINE_int32(channel, 3, "image channel");

namespace paddle {
namespace lite {

TEST(VGG19, test_vgg19_fp32_xpu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
config.set_xpu_workspace_l3_size_per_thread();
auto predictor = lite_api::CreatePaddlePredictor(config);
std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
std::vector<int> input_shape{
FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
int input_size = 1;
for (auto i : input_shape) {
input_size *= i;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
for (int j = 0; j < input_size; j++) {
data[j] = 0.f;
}
predictor->Run();
}
std::vector<std::vector<float>> out_rets;
out_rets.resize(FLAGS_iteration);
double cost_time = 0;
for (size_t i = 0; i < raw_data.size(); ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
double start = GetCurrentUS();
predictor->Run();
cost_time += GetCurrentUS() - start;
auto output_tensor = predictor->GetOutput(0);
auto output_shape = output_tensor->shape();
auto output_data = output_tensor->data<float>();
ASSERT_EQ(output_shape.size(), 2UL);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 1000);
int output_size = output_shape[0] * output_shape[1];
out_rets[i].resize(output_size);
memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
ASSERT_GT(out_accuracy, 0.56f);
}
} // namespace lite
} // namespace paddle
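The GoogLeNet, ResNet50, and VGG19 tests above share the same body and differ only in the top-1 accuracy bound (0.57, 0.6, and 0.56, respectively). A hypothetical consolidation sketch of that shared body — RunImageClassificationTest is an illustrative helper, not part of the commit, and it assumes the same includes and DEFINE_* flags as the test files above:

// Hypothetical sketch -- not part of the commit.
namespace paddle {
namespace lite {

void RunImageClassificationTest(float min_top1_accuracy) {
  lite_api::CxxConfig config;
  config.set_model_dir(FLAGS_model_dir);
  config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
                           lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
  config.set_xpu_workspace_l3_size_per_thread();
  auto predictor = lite_api::CreatePaddlePredictor(config);

  std::vector<int> input_shape{
      FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
  auto raw_data =
      ReadRawData(FLAGS_data_dir + "/raw_data", input_shape, FLAGS_iteration);
  int input_size = 1;
  for (auto i : input_shape) input_size *= i;

  std::vector<std::vector<float>> out_rets(FLAGS_iteration);
  for (size_t i = 0; i < raw_data.size(); ++i) {
    auto input_tensor = predictor->GetInput(0);
    input_tensor->Resize(
        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
    memcpy(input_tensor->mutable_data<float>(),
           raw_data[i].data(),
           sizeof(float) * input_size);
    predictor->Run();

    auto output_tensor = predictor->GetOutput(0);
    auto output_shape = output_tensor->shape();
    const float* output_data = output_tensor->data<float>();
    out_rets[i].assign(output_data,
                       output_data + output_shape[0] * output_shape[1]);
  }
  float top1 = CalOutAccuracy(out_rets, FLAGS_data_dir + "/labels.txt");
  ASSERT_GT(top1, min_top1_accuracy);
}

}  // namespace lite
}  // namespace paddle

// Each test body then reduces to a single call, e.g.
//   TEST(GoogLeNet, test_googlenet_fp32_xpu) { RunImageClassificationTest(0.57f); }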