diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt
index 228b09bcff8a30869d7828a2a5a71fa0cb802292..b4635a48d9c259b8897785092c7502e7fa40f90c 100755
--- a/lite/CMakeLists.txt
+++ b/lite/CMakeLists.txt
@@ -51,11 +51,18 @@ if (WITH_TESTING)
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz")
-    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
     lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "step_rnn.tar.gz")
-    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "bert.tar.gz")
-    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "ernie.tar.gz")
+
+    set(LITE_URL_FOR_UNITTESTS "http://paddle-inference-dist.bj.bcebos.com/PaddleLite/models_and_data_for_unittests")
+    # models
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "resnet50.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "bert.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ernie.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "GoogLeNet.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "VGG19.tar.gz")
+    # data
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ILSVRC2012_small.tar.gz")
   endif()
 endif()
diff --git a/lite/tests/api/CMakeLists.txt b/lite/tests/api/CMakeLists.txt
index be9e7192b7d0d1009a4f48fb7033bdbdfd0c4f10..42fd8189dc2baae92edeaf7020b18ed6c07cc187 100644
--- a/lite/tests/api/CMakeLists.txt
+++ b/lite/tests/api/CMakeLists.txt
@@ -1,52 +1,45 @@
 if(LITE_WITH_ARM)
   lite_cc_test(test_transformer_with_mask_fp32_arm SRCS test_transformer_with_mask_fp32_arm.cc
-      DEPS ${lite_model_test_DEPS} paddle_api_full
-      ARM_DEPS ${arm_kernels}
-      ARGS --model_dir=${LITE_MODEL_DIR}/transformer_with_mask_fp32 SERIAL)
-   if(WITH_TESTING)
-      add_dependencies(test_transformer_with_mask_fp32_arm extern_lite_download_transformer_with_mask_fp32_tar_gz)
-   endif()
+    DEPS ${lite_model_test_DEPS} paddle_api_full
+    ARM_DEPS ${arm_kernels}
+    ARGS --model_dir=${LITE_MODEL_DIR}/transformer_with_mask_fp32 SERIAL)
+  if(WITH_TESTING)
+    add_dependencies(test_transformer_with_mask_fp32_arm extern_lite_download_transformer_with_mask_fp32_tar_gz)
+  endif()
 endif()
 
-if(LITE_WITH_XPU AND NOT LITE_WITH_XTCL)
-  lite_cc_test(test_resnet50_fp32_xpu SRCS test_resnet50_fp32_xpu.cc
-      DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-      ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
-  lite_cc_test(test_ernie_fp32_xpu SRCS test_ernie_fp32_xpu.cc
-      DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-      ARGS --model_dir=${LITE_MODEL_DIR}/ernie)
-  lite_cc_test(test_bert_fp32_xpu SRCS test_bert_fp32_xpu.cc
-      DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-      ARGS --model_dir=${LITE_MODEL_DIR}/bert)
+function(xpu_x86_without_xtcl_test TARGET MODEL DATA)
+  lite_cc_test(${TARGET} SRCS ${TARGET}.cc
+    DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
+    ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
+    ARGS --model_dir=${LITE_MODEL_DIR}/${MODEL}
+         --data_dir=${LITE_MODEL_DIR}/${DATA})
   if(WITH_TESTING)
-    add_dependencies(test_resnet50_fp32_xpu extern_lite_download_resnet50_tar_gz)
-    add_dependencies(test_ernie_fp32_xpu extern_lite_download_ernie_tar_gz)
-    add_dependencies(test_bert_fp32_xpu extern_lite_download_bert_tar_gz)
+    add_dependencies(${TARGET} extern_lite_download_${MODEL}_tar_gz)
+    if(NOT ${DATA} STREQUAL "")
+      add_dependencies(${TARGET} extern_lite_download_${DATA}_tar_gz)
+    endif()
   endif()
-  # TODO(miaotianxiang): enable later
-  #lite_cc_test(test_fpr_fp32_xpu SRCS test_fpr_fp32_xpu.cc
-    #DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-    #${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-    #ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
-  #lite_cc_test(test_mmdnn_fp32_xpu SRCS test_mmdnn_fp32_xpu.cc
-    #DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-    #${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-    #ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
+endfunction()
+
+if(LITE_WITH_XPU AND NOT LITE_WITH_XTCL)
+  xpu_x86_without_xtcl_test(test_resnet50_fp32_xpu resnet50 ILSVRC2012_small)
+  xpu_x86_without_xtcl_test(test_googlenet_fp32_xpu GoogLeNet ILSVRC2012_small)
+  xpu_x86_without_xtcl_test(test_vgg19_fp32_xpu VGG19 ILSVRC2012_small)
+  xpu_x86_without_xtcl_test(test_ernie_fp32_xpu ernie "")
+  xpu_x86_without_xtcl_test(test_bert_fp32_xpu bert "")
 endif()
 
 if(LITE_WITH_RKNPU)
   lite_cc_test(test_mobilenetv1_int8_rknpu SRCS test_mobilenetv1_int8_rknpu.cc
-      DEPS ${lite_model_test_DEPS} paddle_api_full
-      RKNPU_DEPS ${rknpu_kernels} ${rknpu_bridges}
-      ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
+    DEPS ${lite_model_test_DEPS} paddle_api_full
+    RKNPU_DEPS ${rknpu_kernels} ${rknpu_bridges}
+    ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
 endif()
 
 if(LITE_WITH_APU)
   lite_cc_test(test_mobilenetv1_int8_apu SRCS test_mobilenetv1_int8_apu.cc
-      DEPS ${lite_model_test_DEPS} paddle_api_full
-      APU_DEPS ${apu_kernels} ${apu_bridges}
-      ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
+    DEPS ${lite_model_test_DEPS} paddle_api_full
+    APU_DEPS ${apu_kernels} ${apu_bridges}
+    ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
 endif()
diff --git a/lite/tests/api/ILSVRC2012_utility.h b/lite/tests/api/ILSVRC2012_utility.h
new file mode 100644
index 0000000000000000000000000000000000000000..a8cf478cf35e72224018172d44ce9f42e8c06603
--- /dev/null
+++ b/lite/tests/api/ILSVRC2012_utility.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <fstream>
+#include <string>
+#include <vector>
+#include "lite/utils/cp_logging.h"
+#include "lite/utils/io.h"
+#include "lite/utils/string.h"
+
+namespace paddle {
+namespace lite {
+
+// Reads `iteration` batches of raw input data. Each sample is a headerless
+// binary file holding one image as packed values of type T; the files live
+// under `raw_data_dir` and are named by their 1-based sample index.
+template <typename T>
+std::vector<std::vector<T>> ReadRawData(
+    const std::string& raw_data_dir,
+    const std::vector<int>& input_shape = {1, 3, 224, 224},
+    int iteration = 100) {
+  std::vector<std::vector<T>> raw_data;
+
+  int image_size = 1;
+  for (size_t i = 1; i < input_shape.size(); i++) {
+    image_size *= input_shape[i];
+  }
+  int input_size = image_size * input_shape[0];
+
+  for (int i = 0; i < iteration; i++) {
+    std::vector<T> one_iter_raw_data;
+    one_iter_raw_data.resize(input_size);
+    T* data = &(one_iter_raw_data.at(0));
+    for (int j = 0; j < input_shape[0]; j++) {
+      std::string raw_data_file_dir =
+          raw_data_dir + std::string("/") +
+          std::to_string(i * input_shape[0] + j + 1);
+      std::ifstream fin(raw_data_file_dir, std::ios::in | std::ios::binary);
+      CHECK(fin.is_open()) << "failed to open file " << raw_data_file_dir;
+      fin.seekg(0, std::ios::end);
+      int file_size = fin.tellg();
+      fin.seekg(0, std::ios::beg);
+      CHECK_EQ(file_size, image_size * sizeof(T) / sizeof(char));
+      fin.read(reinterpret_cast<char*>(data), file_size);
+      fin.close();
+      data += image_size;
+    }
+    raw_data.emplace_back(one_iter_raw_data);
+  }
+
+  return raw_data;
+}
+
+// Computes top-1 accuracy: the argmax of each output vector is compared with
+// the label on the corresponding line of the labels file ("<name> <label>").
+float CalOutAccuracy(const std::vector<std::vector<float>>& out_rets,
+                     const std::string& labels_dir) {
+  int right_num = 0;
+
+  auto label_lines = ReadLines(labels_dir);
+  for (size_t i = 0; i < out_rets.size(); i++) {
+    int label = std::stoi(Split(label_lines[i], " ")[1]);
+
+    auto out = out_rets[i];
+    auto largest = std::max_element(out.begin(), out.end());
+    int out_top1 = std::distance(out.begin(), largest);
+
+    right_num += (out_top1 == label);
+  }
+
+  return static_cast<float>(right_num) / static_cast<float>(out_rets.size());
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/api/test_googlenet_fp32_xpu.cc b/lite/tests/api/test_googlenet_fp32_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..de5979d0b93956ef11ab7f0488527ec876ed580c
--- /dev/null
+++ b/lite/tests/api/test_googlenet_fp32_xpu.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/tests/api/ILSVRC2012_utility.h"
+#include "lite/utils/cp_logging.h"
+
+DEFINE_string(data_dir, "", "data dir");
+DEFINE_int32(iteration, 100, "iteration times to run");
+DEFINE_int32(batch, 1, "batch of image");
+DEFINE_int32(channel, 3, "image channel");
+
+namespace paddle {
+namespace lite {
+
+TEST(GoogLeNet, test_googlenet_fp32_xpu) {
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  config.set_xpu_workspace_l3_size_per_thread();
+  auto predictor = lite_api::CreatePaddlePredictor(config);
+
+  std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
+  std::vector<int> input_shape{
+      FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
+  auto raw_data =
+      ReadRawData<float>(raw_data_dir, input_shape, FLAGS_iteration);
+
+  int input_size = 1;
+  for (auto i : input_shape) {
+    input_size *= i;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    for (int j = 0; j < input_size; j++) {
+      data[j] = 0.f;
+    }
+    predictor->Run();
+  }
+
+  std::vector<std::vector<float>> out_rets;
+  out_rets.resize(FLAGS_iteration);
+  double cost_time = 0;
+  for (size_t i = 0; i < raw_data.size(); ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
+
+    double start = GetCurrentUS();
+    predictor->Run();
+    cost_time += GetCurrentUS() - start;
+
+    auto output_tensor = predictor->GetOutput(0);
+    auto output_shape = output_tensor->shape();
+    auto output_data = output_tensor->data<float>();
+    ASSERT_EQ(output_shape.size(), 2UL);
+    ASSERT_EQ(output_shape[0], 1);
+    ASSERT_EQ(output_shape[1], 1000);
+
+    int output_size = output_shape[0] * output_shape[1];
+    out_rets[i].resize(output_size);
+    memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
+            << ", iteration: " << FLAGS_iteration << ", spend "
+            << cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
+
+  std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
+  float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
+  ASSERT_GT(out_accuracy, 0.57f);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/api/test_resnet50_fp32_xpu.cc b/lite/tests/api/test_resnet50_fp32_xpu.cc
index 40414e270a7679411e813fc468d0a86ff5680766..795a8fe5c8965d3f9f6116af47a27be763ecf549 100644
--- a/lite/tests/api/test_resnet50_fp32_xpu.cc
+++ b/lite/tests/api/test_resnet50_fp32_xpu.cc
@@ -21,8 +21,14 @@
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
 #include "lite/api/test_helper.h"
+#include "lite/tests/api/ILSVRC2012_utility.h"
 #include "lite/utils/cp_logging.h"
 
+DEFINE_string(data_dir, "", "data dir");
+DEFINE_int32(iteration, 100, "iteration times to run");
+DEFINE_int32(batch, 1, "batch of image");
+DEFINE_int32(channel, 3, "image channel");
+
 namespace paddle {
 namespace lite {
 
@@ -35,52 +41,62 @@ TEST(Resnet50, test_resnet50_fp32_xpu) {
   config.set_xpu_workspace_l3_size_per_thread();
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
-  auto input_tensor = predictor->GetInput(0);
-  std::vector<int64_t> input_shape{1, 3, 224, 224};
-  input_tensor->Resize(input_shape);
-  auto* data = input_tensor->mutable_data<float>();
-  int input_num = 1;
-  for (size_t i = 0; i < input_shape.size(); ++i) {
-    input_num *= input_shape[i];
-  }
-  for (int i = 0; i < input_num; i++) {
-    data[i] = 1;
+  std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
+  std::vector<int> input_shape{
+      FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
+  auto raw_data =
+      ReadRawData<float>(raw_data_dir, input_shape, FLAGS_iteration);
+
+  int input_size = 1;
+  for (auto i : input_shape) {
+    input_size *= i;
   }
 
   for (int i = 0; i < FLAGS_warmup; ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    for (int j = 0; j < input_size; j++) {
+      data[j] = 0.f;
+    }
     predictor->Run();
   }
 
-  auto start = GetCurrentUS();
-  for (int i = 0; i < FLAGS_repeats; ++i) {
+  std::vector<std::vector<float>> out_rets;
+  out_rets.resize(FLAGS_iteration);
+  double cost_time = 0;
+  for (size_t i = 0; i < raw_data.size(); ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
+
+    double start = GetCurrentUS();
     predictor->Run();
+    cost_time += GetCurrentUS() - start;
+
+    auto output_tensor = predictor->GetOutput(0);
+    auto output_shape = output_tensor->shape();
+    auto output_data = output_tensor->data<float>();
+    ASSERT_EQ(output_shape.size(), 2UL);
+    ASSERT_EQ(output_shape[0], 1);
+    ASSERT_EQ(output_shape[1], 1000);
+
+    int output_size = output_shape[0] * output_shape[1];
+    out_rets[i].resize(output_size);
+    memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
   }
 
   LOG(INFO) << "================== Speed Report ===================";
   LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
-            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
-            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
-            << " ms in average.";
-
-  std::vector<std::vector<float>> results;
-  results.emplace_back(std::vector<float>(
-      {0.000268651, 0.000174053, 0.000213181, 0.000396771, 0.000591516,
-       0.00018169, 0.000289721, 0.000855934, 0.000732185, 9.2055e-05,
-       0.000220664, 0.00235289, 0.00571265, 0.00357688, 0.00129667,
-       0.000465392, 0.000143775, 0.000211628, 0.000617144, 0.000265033}));
-  auto out = predictor->GetOutput(0);
-  ASSERT_EQ(out->shape().size(), 2);
-  ASSERT_EQ(out->shape()[0], 1);
-  ASSERT_EQ(out->shape()[1], 1000);
+            << ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
+            << ", iteration: " << FLAGS_iteration << ", spend "
+            << cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
 
-  int step = 50;
-  for (size_t i = 0; i < results.size(); ++i) {
-    for (size_t j = 0; j < results[i].size(); ++j) {
-      EXPECT_NEAR(out->data<float>()[j * step + (out->shape()[1] * i)],
-                  results[i][j],
-                  1e-5);
-    }
-  }
+  std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
+  float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
+  ASSERT_GT(out_accuracy, 0.6f);
 }
 
 }  // namespace lite
diff --git a/lite/tests/api/test_vgg19_fp32_xpu.cc b/lite/tests/api/test_vgg19_fp32_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..71c086dda9f561f9932123de5c20f48979ec9dc0
--- /dev/null
+++ b/lite/tests/api/test_vgg19_fp32_xpu.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/tests/api/ILSVRC2012_utility.h"
+#include "lite/utils/cp_logging.h"
+
+DEFINE_string(data_dir, "", "data dir");
+DEFINE_int32(iteration, 100, "iteration times to run");
+DEFINE_int32(batch, 1, "batch of image");
+DEFINE_int32(channel, 3, "image channel");
+
+namespace paddle {
+namespace lite {
+
+TEST(VGG19, test_vgg19_fp32_xpu) {
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  config.set_xpu_workspace_l3_size_per_thread();
+  auto predictor = lite_api::CreatePaddlePredictor(config);
+
+  std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
+  std::vector<int> input_shape{
+      FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
+  auto raw_data =
+      ReadRawData<float>(raw_data_dir, input_shape, FLAGS_iteration);
+
+  int input_size = 1;
+  for (auto i : input_shape) {
+    input_size *= i;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    for (int j = 0; j < input_size; j++) {
+      data[j] = 0.f;
+    }
+    predictor->Run();
+  }
+
+  std::vector<std::vector<float>> out_rets;
+  out_rets.resize(FLAGS_iteration);
+  double cost_time = 0;
+  for (size_t i = 0; i < raw_data.size(); ++i) {
+    auto input_tensor = predictor->GetInput(0);
+    input_tensor->Resize(
+        std::vector<int64_t>(input_shape.begin(), input_shape.end()));
+    auto* data = input_tensor->mutable_data<float>();
+    memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
+
+    double start = GetCurrentUS();
+    predictor->Run();
+    cost_time += GetCurrentUS() - start;
+
+    auto output_tensor = predictor->GetOutput(0);
+    auto output_shape = output_tensor->shape();
+    auto output_data = output_tensor->data<float>();
+    ASSERT_EQ(output_shape.size(), 2UL);
+    ASSERT_EQ(output_shape[0], 1);
+    ASSERT_EQ(output_shape[1], 1000);
+
+    int output_size = output_shape[0] * output_shape[1];
+    out_rets[i].resize(output_size);
+    memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
+            << ", iteration: " << FLAGS_iteration << ", spend "
+            << cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
+
+  std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
+  float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
+  ASSERT_GT(out_accuracy, 0.56f);
+}
+
+}  // namespace lite
+}  // namespace paddle
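
Note on the data layout the three ILSVRC2012 accuracy tests consume, inferred from ReadRawData and CalOutAccuracy above (the concrete file name, label value, and fill value below are illustrative, not taken from the ILSVRC2012_small package): each sample is a headerless binary file of batch * channel * height * width float32 values under <data_dir>/raw_data/, named by its 1-based index, and <data_dir>/labels.txt holds one "<image_name> <label>" line per sample. A minimal standalone C++ sketch that writes one such sample:

#include <fstream>
#include <vector>

int main() {
  const int channel = 3, height = 224, width = 224;
  // One preprocessed CHW image; real data would hold normalized pixel values.
  std::vector<float> image(channel * height * width, 0.5f);

  // raw_data/1, raw_data/2, ...: raw float32 tensors with no header. The file
  // size must equal image_size * sizeof(float); ReadRawData enforces this with
  // CHECK_EQ before reading. (Assumes ./raw_data already exists.)
  std::ofstream fout("raw_data/1", std::ios::out | std::ios::binary);
  fout.write(reinterpret_cast<const char*>(image.data()),
             image.size() * sizeof(float));
  fout.close();

  // labels.txt: CalOutAccuracy parses the second whitespace-separated field
  // of line i as the ground-truth class id for sample i.
  std::ofstream flabel("labels.txt");
  flabel << "ILSVRC2012_val_00000001 65\n";
  return 0;
}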