未验证 提交 76410577 编写于 作者: H hong19860320 提交者: GitHub

[CI] Enable CI for Huawei kirin NPU, Rockchip NPU and MediaTek APU (#4408)

上级 d5e7e73e
......@@ -38,34 +38,31 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND NOT LITE_ON_TINY_PUBLISH)
endif()
if (WITH_TESTING)
set(LITE_URL_FOR_UNITTESTS "http://paddle-inference-dist.bj.bcebos.com/PaddleLite/models_and_data_for_unittests")
# models
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1_int16.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1_int16.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "MobileNetV1_quant.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "transformer_with_mask_fp32.tar.gz")
endif()
if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "mobilenet_v1_int8_for_mediatek_apu.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "mobilenet_v1_int8_for_rockchip_npu.tar.gz")
else()
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "step_rnn.tar.gz")
set(LITE_URL_FOR_UNITTESTS "http://paddle-inference-dist.bj.bcebos.com/PaddleLite/models_and_data_for_unittests")
# models
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "resnet50.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "bert.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ernie.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "GoogLeNet.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "VGG19.tar.gz")
endif()
# data
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ILSVRC2012_small.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "bert_data.tar.gz")
endif()
endif()
# ----------------------------- PUBLISH -----------------------------
......
......@@ -6,5 +6,5 @@ endif()
lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
if((NOT LITE_WITH_OPENCL) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${rknpu_kernels} ${mlu_kernels} ${bm_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${rknpu_kernels} ${mlu_kernels} ${bm_kernels} ${npu_kernels} ${apu_kernels} ${huawei_ascend_npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
if(LITE_WITH_ARM)
lite_cc_test(test_transformer_with_mask_fp32_arm SRCS test_transformer_with_mask_fp32_arm.cc
DEPS ${lite_model_test_DEPS} paddle_api_full
ARM_DEPS ${arm_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/transformer_with_mask_fp32 SERIAL)
if(WITH_TESTING)
add_dependencies(test_transformer_with_mask_fp32_arm extern_lite_download_transformer_with_mask_fp32_tar_gz)
function(lite_cc_test_with_model_and_data TARGET)
if(NOT WITH_TESTING)
return()
endif()
endif()
function(xpu_x86_without_xtcl_test TARGET MODEL DATA)
if(${DATA} STREQUAL "")
lite_cc_test(${TARGET} SRCS ${TARGET}.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/${MODEL})
else()
set(options "")
set(oneValueArgs MODEL DATA CONFIG ARGS)
set(multiValueArgs "")
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(ARGS "")
if(DEFINED args_MODEL)
set(ARGS "${ARGS} --model_dir=${LITE_MODEL_DIR}/${args_MODEL}")
endif()
if(DEFINED args_DATA)
set(ARGS "${ARGS} --data_dir=${LITE_MODEL_DIR}/${args_DATA}")
endif()
if(DEFINED args_CONFIG)
set(ARGS "${ARGS} --config_dir=${LITE_MODEL_DIR}/${args_CONFIG}")
endif()
if(DEFINED args_ARGS)
set(ARGS "${ARGS} ${args_ARGS}")
endif()
lite_cc_test(${TARGET} SRCS ${TARGET}.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/${MODEL} --data_dir=${LITE_MODEL_DIR}/${DATA})
DEPS ${lite_model_test_DEPS} paddle_api_full
ARM_DEPS ${arm_kernels}
X86_DEPS ${x86_kernels}
NPU_DEPS ${npu_kernels} ${npu_bridges}
HUAWEI_ASCEND_NPU_DEPS ${huawei_ascend_npu_kernels} ${huawei_ascend_npu_bridges}
XPU_DEPS ${xpu_kernels} ${xpu_bridges}
APU_DEPS ${apu_kernels} ${apu_bridges}
RKNPU_DEPS ${rknpu_kernels} ${rknpu_bridges}
BM_DEPS ${bm_kernels} ${bm_bridges}
MLU_DEPS ${mlu_kernels} ${mlu_bridges}
ARGS ${ARGS} SERIAL)
if(DEFINED args_MODEL)
add_dependencies(${TARGET} extern_lite_download_${args_MODEL}_tar_gz)
endif()
if(WITH_TESTING)
add_dependencies(${TARGET} extern_lite_download_${MODEL}_tar_gz)
if(NOT ${DATA} STREQUAL "")
add_dependencies(${TARGET} extern_lite_download_${DATA}_tar_gz)
if(DEFINED args_DATA)
add_dependencies(${TARGET} extern_lite_download_${args_DATA}_tar_gz)
endif()
if(DEFINED args_CONFIG)
add_dependencies(${TARGET} extern_lite_download_${args_CONFIG}_tar_gz)
endif()
endfunction()
if(LITE_WITH_ARM)
lite_cc_test_with_model_and_data(test_transformer_with_mask_fp32_arm MODEL transformer_with_mask_fp32 ARGS)
endif()
if(LITE_WITH_NPU)
lite_cc_test_with_model_and_data(test_mobilenetv1_fp32_huawei_kirin_npu MODEL mobilenet_v1 DATA ILSVRC2012_small)
lite_cc_test_with_model_and_data(test_mobilenetv2_fp32_huawei_kirin_npu MODEL mobilenet_v2_relu DATA ILSVRC2012_small)
lite_cc_test_with_model_and_data(test_resnet50_fp32_huawei_kirin_npu MODEL resnet50 DATA ILSVRC2012_small)
endif()
if(LITE_WITH_XPU AND NOT LITE_WITH_XTCL)
xpu_x86_without_xtcl_test(test_resnet50_fp32_xpu resnet50 ILSVRC2012_small)
xpu_x86_without_xtcl_test(test_googlenet_fp32_xpu GoogLeNet ILSVRC2012_small)
xpu_x86_without_xtcl_test(test_vgg19_fp32_xpu VGG19 ILSVRC2012_small)
xpu_x86_without_xtcl_test(test_ernie_fp32_xpu ernie bert_data)
xpu_x86_without_xtcl_test(test_bert_fp32_xpu bert bert_data)
lite_cc_test_with_model_and_data(test_resnet50_fp32_xpu MODEL resnet50 DATA ILSVRC2012_small)
lite_cc_test_with_model_and_data(test_googlenet_fp32_xpu MODEL GoogLeNet DATA ILSVRC2012_small)
lite_cc_test_with_model_and_data(test_vgg19_fp32_xpu MODEL VGG19 DATA ILSVRC2012_small)
lite_cc_test_with_model_and_data(test_ernie_fp32_xpu MODEL ernie DATA bert_data)
lite_cc_test_with_model_and_data(test_bert_fp32_xpu MODEL bert DATA bert_data)
endif()
if(LITE_WITH_RKNPU)
lite_cc_test(test_mobilenetv1_int8_rknpu SRCS test_mobilenetv1_int8_rknpu.cc
DEPS ${lite_model_test_DEPS} paddle_api_full
RKNPU_DEPS ${rknpu_kernels} ${rknpu_bridges}
ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
lite_cc_test_with_model_and_data(test_mobilenetv1_int8_rockchip_npu MODEL mobilenet_v1_int8_for_rockchip_npu DATA ILSVRC2012_small)
endif()
if(LITE_WITH_APU)
lite_cc_test(test_mobilenetv1_int8_apu SRCS test_mobilenetv1_int8_apu.cc
DEPS ${lite_model_test_DEPS} paddle_api_full
APU_DEPS ${apu_kernels} ${apu_bridges}
ARGS --model_dir=${LITE_MODEL_DIR}/MobilenetV1_full_quant SERIAL)
lite_cc_test_with_model_and_data(test_mobilenetv1_int8_mediatek_apu MODEL mobilenet_v1_int8_for_mediatek_apu DATA ILSVRC2012_small)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/ILSVRC2012_utility.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(data_dir, "", "data dir");
DEFINE_int32(iteration, 100, "iteration times to run");
DEFINE_int32(batch, 1, "batch of image");
DEFINE_int32(channel, 3, "image channel");
namespace paddle {
namespace lite {
TEST(MobileNetV1, test_mobilenetv1_fp32_huawei_kirin_npu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}});
auto predictor = lite_api::CreatePaddlePredictor(config);
std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
std::vector<int> input_shape{
FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
int input_size = 1;
for (auto i : input_shape) {
input_size *= i;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
for (int j = 0; j < input_size; j++) {
data[j] = 0.f;
}
predictor->Run();
}
std::vector<std::vector<float>> out_rets;
out_rets.resize(FLAGS_iteration);
double cost_time = 0;
for (size_t i = 0; i < raw_data.size(); ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
double start = GetCurrentUS();
predictor->Run();
cost_time += GetCurrentUS() - start;
auto output_tensor = predictor->GetOutput(0);
auto output_shape = output_tensor->shape();
auto output_data = output_tensor->data<float>();
ASSERT_EQ(output_shape.size(), 2UL);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 1000);
int output_size = output_shape[0] * output_shape[1];
out_rets[i].resize(output_size);
memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
ASSERT_GE(out_accuracy, 0.57f);
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
using namespace paddle::lite_api; // NOLINT
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
inline int64_t ShapeProduction(std::vector<int64_t> shape) {
int64_t s = 1;
for (int64_t dim : shape) {
s *= dim;
}
return s;
}
int main(int argc, char** argv) {
if (argc < 2) {
std::cerr << "[ERROR] usage: ./" << argv[0]
<< " model_dir [thread_num] [warmup_times] [repeat_times] "
"[input_data_path] [output_data_path]"
<< std::endl;
return -1;
}
std::string model_dir = argv[1];
int thread_num = 1;
if (argc > 2) {
thread_num = atoi(argv[2]);
}
int warmup_times = 5;
if (argc > 3) {
warmup_times = atoi(argv[3]);
}
int repeat_times = 10;
if (argc > 4) {
repeat_times = atoi(argv[4]);
}
std::string input_data_path;
if (argc > 5) {
input_data_path = argv[5];
}
std::string output_data_path;
if (argc > 6) {
output_data_path = argv[6];
}
paddle::lite_api::CxxConfig config;
config.set_model_dir(model_dir);
config.set_threads(thread_num);
config.set_power_mode(paddle::lite_api::LITE_POWER_HIGH);
config.set_valid_places(
{paddle::lite_api::Place{
TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW)},
paddle::lite_api::Place{
TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNCHW)},
paddle::lite_api::Place{
TARGET(kAPU), PRECISION(kInt8), DATALAYOUT(kNCHW)}});
auto predictor = paddle::lite_api::CreatePaddlePredictor(config);
std::unique_ptr<paddle::lite_api::Tensor> input_tensor(
std::move(predictor->GetInput(0)));
input_tensor->Resize({1, 3, 224, 224});
auto input_data = input_tensor->mutable_data<float>();
auto input_size = ShapeProduction(input_tensor->shape());
// test loop
int total_imgs = 500;
float test_num = 0;
float top1_num = 0;
float top5_num = 0;
int output_len = 1000;
std::vector<int> index(1000);
bool debug = true; // false;
int show_step = 500;
for (int i = 0; i < total_imgs; i++) {
// set input
std::string filename = input_data_path + "/" + std::to_string(i);
std::ifstream fs(filename, std::ifstream::binary);
if (!fs.is_open()) {
std::cout << "open input file fail.";
}
auto input_data_tmp = input_data;
for (int i = 0; i < input_size; ++i) {
fs.read(reinterpret_cast<char*>(input_data_tmp), sizeof(*input_data_tmp));
input_data_tmp++;
}
int label = 0;
fs.read(reinterpret_cast<char*>(&label), sizeof(label));
fs.close();
if (debug && i % show_step == 0) {
std::cout << "input data:" << std::endl;
std::cout << input_data[0] << " " << input_data[10] << " "
<< input_data[input_size - 1] << std::endl;
std::cout << "label:" << label << std::endl;
}
// run
predictor->Run();
auto output0 = predictor->GetOutput(0);
auto output0_data = output0->data<float>();
// get output
std::iota(index.begin(), index.end(), 0);
std::stable_sort(
index.begin(), index.end(), [output0_data](size_t i1, size_t i2) {
return output0_data[i1] > output0_data[i2];
});
test_num++;
if (label == index[0]) {
top1_num++;
}
for (int i = 0; i < 5; i++) {
if (label == index[i]) {
top5_num++;
}
}
if (debug && i % show_step == 0) {
std::cout << index[0] << " " << index[1] << " " << index[2] << " "
<< index[3] << " " << index[4] << std::endl;
std::cout << output0_data[index[0]] << " " << output0_data[index[1]]
<< " " << output0_data[index[2]] << " "
<< output0_data[index[3]] << " " << output0_data[index[4]]
<< std::endl;
std::cout << output0_data[630] << std::endl;
}
if (i % show_step == 0) {
std::cout << "step " << i << "; top1 acc:" << top1_num / test_num
<< "; top5 acc:" << top5_num / test_num << std::endl;
}
}
std::cout << "final result:" << std::endl;
std::cout << "top1 acc:" << top1_num / test_num << std::endl;
std::cout << "top5 acc:" << top5_num / test_num << std::endl;
return 0;
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/ILSVRC2012_utility.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(data_dir, "", "data dir");
DEFINE_int32(iteration, 100, "iteration times to run");
DEFINE_int32(batch, 1, "batch of image");
DEFINE_int32(channel, 3, "image channel");
namespace paddle {
namespace lite {
TEST(MobileNetV1, test_mobilenetv1_int8_mediatek_apu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kInt8)},
lite_api::Place{TARGET(kAPU), PRECISION(kInt8)}});
auto predictor = lite_api::CreatePaddlePredictor(config);
std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
std::vector<int> input_shape{
FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
int input_size = 1;
for (auto i : input_shape) {
input_size *= i;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
for (int j = 0; j < input_size; j++) {
data[j] = 0.f;
}
predictor->Run();
}
std::vector<std::vector<float>> out_rets;
out_rets.resize(FLAGS_iteration);
double cost_time = 0;
for (size_t i = 0; i < raw_data.size(); ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
double start = GetCurrentUS();
predictor->Run();
cost_time += GetCurrentUS() - start;
auto output_tensor = predictor->GetOutput(0);
auto output_shape = output_tensor->shape();
auto output_data = output_tensor->data<float>();
ASSERT_EQ(output_shape.size(), 2UL);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 1000);
int output_size = output_shape[0] * output_shape[1];
out_rets[i].resize(output_size);
memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
ASSERT_GE(out_accuracy, 0.55f);
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/time.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
inline int64_t ShapeProduction(std::vector<int64_t> shape) {
int64_t s = 1;
for (int64_t dim : shape) {
s *= dim;
}
return s;
}
int main(int argc, char** argv) {
if (argc < 2) {
std::cerr << "[ERROR] usage: ./" << argv[0]
<< " model_dir [thread_num] [warmup_times] [repeat_times] "
"[input_data_path] [output_data_path]"
<< std::endl;
return -1;
}
std::string model_dir = argv[1];
int thread_num = 1;
if (argc > 2) {
thread_num = atoi(argv[2]);
}
int warmup_times = 5;
if (argc > 3) {
warmup_times = atoi(argv[3]);
}
int repeat_times = 10;
if (argc > 4) {
repeat_times = atoi(argv[4]);
}
std::string input_data_path;
if (argc > 5) {
input_data_path = argv[5];
}
std::string output_data_path;
if (argc > 6) {
output_data_path = argv[6];
}
paddle::lite_api::CxxConfig config;
config.set_model_dir(model_dir);
config.set_threads(thread_num);
config.set_power_mode(paddle::lite_api::LITE_POWER_HIGH);
config.set_valid_places(
{paddle::lite_api::Place{
TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW)},
paddle::lite_api::Place{
TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNCHW)},
paddle::lite_api::Place{
TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNCHW)},
paddle::lite_api::Place{
TARGET(kRKNPU), PRECISION(kInt8), DATALAYOUT(kNCHW)}});
auto predictor = paddle::lite_api::CreatePaddlePredictor(config);
std::unique_ptr<paddle::lite_api::Tensor> input_tensor(
std::move(predictor->GetInput(0)));
input_tensor->Resize({1, 3, 224, 224});
auto input_data = input_tensor->mutable_data<float>();
auto input_size = ShapeProduction(input_tensor->shape());
if (input_data_path.empty()) {
for (int i = 0; i < input_size; i++) {
input_data[i] = 1;
}
} else {
std::fstream fs(input_data_path, std::ios::in);
if (!fs.is_open()) {
std::cerr << "open input data file failed." << std::endl;
return -1;
}
for (int i = 0; i < input_size; i++) {
fs >> input_data[i];
}
}
for (int i = 0; i < warmup_times; ++i) {
predictor->Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < repeat_times; ++i) {
predictor->Run();
}
std::cout << "Model: " << model_dir << ", threads num " << thread_num
<< ", warmup times: " << warmup_times
<< ", repeat times: " << repeat_times << ", spend "
<< (GetCurrentUS() - start) / repeat_times / 1000.0
<< " ms in average." << std::endl;
std::unique_ptr<const paddle::lite_api::Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
auto output_data = output_tensor->data<float>();
auto output_size = ShapeProduction(output_tensor->shape());
std::cout << "output data:";
for (int i = 0; i < output_size; i += 100) {
std::cout << "[" << i << "] " << output_data[i] << std::endl;
}
return 0;
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/ILSVRC2012_utility.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(data_dir, "", "data dir");
DEFINE_int32(iteration, 100, "iteration times to run");
DEFINE_int32(batch, 1, "batch of image");
DEFINE_int32(channel, 3, "image channel");
namespace paddle {
namespace lite {
TEST(MobileNetV1, test_mobilenetv1_int8_rockchip_apu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kInt8)},
lite_api::Place{TARGET(kRKNPU), PRECISION(kInt8)}});
auto predictor = lite_api::CreatePaddlePredictor(config);
std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
std::vector<int> input_shape{
FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
int input_size = 1;
for (auto i : input_shape) {
input_size *= i;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
for (int j = 0; j < input_size; j++) {
data[j] = 0.f;
}
predictor->Run();
}
std::vector<std::vector<float>> out_rets;
out_rets.resize(FLAGS_iteration);
double cost_time = 0;
for (size_t i = 0; i < raw_data.size(); ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
double start = GetCurrentUS();
predictor->Run();
cost_time += GetCurrentUS() - start;
auto output_tensor = predictor->GetOutput(0);
auto output_shape = output_tensor->shape();
auto output_data = output_tensor->data<float>();
ASSERT_EQ(output_shape.size(), 2UL);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 1000);
int output_size = output_shape[0] * output_shape[1];
out_rets[i].resize(output_size);
memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
ASSERT_GE(out_accuracy, 0.52f);
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/ILSVRC2012_utility.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(data_dir, "", "data dir");
DEFINE_int32(iteration, 100, "iteration times to run");
DEFINE_int32(batch, 1, "batch of image");
DEFINE_int32(channel, 3, "image channel");
namespace paddle {
namespace lite {
TEST(MobileNetV2, test_mobilenetv2_fp32_huawei_kirin_npu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}});
auto predictor = lite_api::CreatePaddlePredictor(config);
std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
std::vector<int> input_shape{
FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
int input_size = 1;
for (auto i : input_shape) {
input_size *= i;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
for (int j = 0; j < input_size; j++) {
data[j] = 0.f;
}
predictor->Run();
}
std::vector<std::vector<float>> out_rets;
out_rets.resize(FLAGS_iteration);
double cost_time = 0;
for (size_t i = 0; i < raw_data.size(); ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
double start = GetCurrentUS();
predictor->Run();
cost_time += GetCurrentUS() - start;
auto output_tensor = predictor->GetOutput(0);
auto output_shape = output_tensor->shape();
auto output_data = output_tensor->data<float>();
ASSERT_EQ(output_shape.size(), 2UL);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 1000);
int output_size = output_shape[0] * output_shape[1];
out_rets[i].resize(output_size);
memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
ASSERT_GE(out_accuracy, 0.57f);
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/ILSVRC2012_utility.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(data_dir, "", "data dir");
DEFINE_int32(iteration, 100, "iteration times to run");
DEFINE_int32(batch, 1, "batch of image");
DEFINE_int32(channel, 3, "image channel");
namespace paddle {
namespace lite {
TEST(ResNet50, test_resnet50_fp32_huawei_kirin_npu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}});
auto predictor = lite_api::CreatePaddlePredictor(config);
std::string raw_data_dir = FLAGS_data_dir + std::string("/raw_data");
std::vector<int> input_shape{
FLAGS_batch, FLAGS_channel, FLAGS_im_width, FLAGS_im_height};
auto raw_data = ReadRawData(raw_data_dir, input_shape, FLAGS_iteration);
int input_size = 1;
for (auto i : input_shape) {
input_size *= i;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
for (int j = 0; j < input_size; j++) {
data[j] = 0.f;
}
predictor->Run();
}
std::vector<std::vector<float>> out_rets;
out_rets.resize(FLAGS_iteration);
double cost_time = 0;
for (size_t i = 0; i < raw_data.size(); ++i) {
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(
std::vector<int64_t>(input_shape.begin(), input_shape.end()));
auto* data = input_tensor->mutable_data<float>();
memcpy(data, raw_data[i].data(), sizeof(float) * input_size);
double start = GetCurrentUS();
predictor->Run();
cost_time += GetCurrentUS() - start;
auto output_tensor = predictor->GetOutput(0);
auto output_shape = output_tensor->shape();
auto output_data = output_tensor->data<float>();
ASSERT_EQ(output_shape.size(), 2UL);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 1000);
int output_size = output_shape[0] * output_shape[1];
out_rets[i].resize(output_size);
memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", batch: " << FLAGS_batch
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
std::string labels_dir = FLAGS_data_dir + std::string("/labels.txt");
float out_accuracy = CalOutAccuracy(out_rets, labels_dir);
ASSERT_GE(out_accuracy, 0.64f);
}
} // namespace lite
} // namespace paddle
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册