未验证 提交 f427a987 编写于 作者: 提交者: GitHub

add paddle lite doc (#5434)

* add paddle lite doc
上级 b4357dbb
# Mobilenet_v3 在 ARM CPU 上部署示例
# 目录
- [1 获取 inference model]()
- [2 准备模型转换工具并生成 Paddle Lite 的部署模型]()
- [3 以 arm v8 、Android 系统为例进行部署]()
- [4 推理结果正确性验证]()
### 1 获取 inference model
提供以下两种方式获取 inference model
- 直接下载(推荐):[inference model](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_infer.tar)
- 通过预训练模型获取
首先获取[预训练模型](https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_pretrained.pdparams),在 ```models/tutorials/mobilenetv3_prod/Step6/tools``` 文件夹下提供了工具 export_model.py ,可以将预训练模型输出 为inference model ,运行如下命令即可获取 inference model。
```
# 假设当前在 models/tutorials/mobilenetv3_prod/Step6 目录下
python ./tools/export_model.py --pretrained=./mobilenet_v3_small_pretrained.pdparams --save-inference-dir=./mobilenet_v3_small_infer
```
在 mobilenet_v3_small_infer 文件夹下有 inference.pdmodel、inference.pdiparams 和 inference.pdiparams.info 文件。
### 2 准备模型转换工具并生成 Paddle Lite 的部署模型
- python 脚本方式
适用于 ``` python == 3.5\3.6\3.7 ```
首先 pip 安装 Paddle Lite:
```
pip3 install paddlelite==2.10
```
```mobilenet_v3```文件夹下允许如下命令:
```
python export_lite_model.py --model-file=./mobilenet_v3_small_infer/inference.pdmodel --param-file=./mobilenet_v3_small_infer/inference.pdiparams --optimize-out=./mobilenet_v3_small
```
在当前文件夹下会生成mobilenet_v3_small.nb文件。
- 终端命令方式
模型转换工具[opt_linux](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.10/opt_linux)[opt_mac](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.10/opt_mac)。或者参考[文档](https://paddle-lite.readthedocs.io/zh/develop/user_guides/model_optimize_tool.html)编译您的模型转换工具,使用如下命令转换可以转换 inference model 到 Paddle Lite 的 nb 模型:
```
./opt --model_file=./mobilenet_v3_small_infer/inference.pdmodel --param_file=./mobilenet_v3_small_infer/inference.pdiparams --optimize_out=./mobilenet_v3_small
```
在当前文件夹下会生成mobilenet_v3_small.nb文件。
注:在 mac 上运行 opt_mac 可能会有如下错误:
<div align="center">
<img src="../../images/Paddle-Lite/pic1.png" width=400">
</div>
需要搜索安全性与隐私,点击通用,点击仍然允许,即可。
<div align="center">
<img src="../../images/Paddle-Lite/pic2.png" width=500">
</div>
### 3 以 arm v8 、Android 系统为例进行部署
- 准备编译环境
```
gcc、g++(推荐版本为 8.2.0)
git、make、wget、python、adb
Java Environment
CMake(请使用 3.10 版本,其他版本的 Cmake 可能有兼容性问题,导致编译不通过)
Android NDK(支持 ndk-r17c 及之后的所有 NDK 版本, 注意从 ndk-r18 开始,NDK 交叉编译工具仅支持 Clang, 不支持 GCC)
```
- 环境安装命令
以 Ubuntu 为例介绍安装命令。注意需要 root 用户权限执行如下命令。mac 环境下编译 Android 库参考[Android 源码编译](https://paddle-lite.readthedocs.io/zh/develop/source_compile/macos_compile_android.html),Windows 下暂不支持编译 Android 版本库。
```
# 1. 安装 gcc g++ git make wget python unzip adb curl 等基础软件
apt update
apt-get install -y --no-install-recommends \
gcc g++ git make wget python unzip adb curl
# 2. 安装 jdk
apt-get install -y default-jdk
# 3. 安装 CMake,以下命令以 3.10.3 版本为例(其他版本的 Cmake 可能有兼容性问题,导致编译不通过,建议用这个版本)
wget -c https://mms-res.cdn.bcebos.com/cmake-3.10.3-Linux-x86_64.tar.gz && \
tar xzf cmake-3.10.3-Linux-x86_64.tar.gz && \
mv cmake-3.10.3-Linux-x86_64 /opt/cmake-3.10 &&
ln -s /opt/cmake-3.10/bin/cmake /usr/bin/cmake && \
ln -s /opt/cmake-3.10/bin/ccmake /usr/bin/ccmake
# 4. 下载 linux-x86_64 版本的 Android NDK,以下命令以 r17c 版本为例,其他版本步骤类似。
cd /tmp && curl -O https://dl.google.com/android/repository/android-ndk-r17c-linux-x86_64.zip
cd /opt && unzip /tmp/android-ndk-r17c-linux-x86_64.zip
# 5. 添加环境变量 NDK_ROOT 指向 Android NDK 的安装路径
echo "export NDK_ROOT=/opt/android-ndk-r17c" >> ~/.bashrc
source ~/.bashrc
```
- 获取预测库
可以使用下面两种方式获得预测库。
(1) 使用预编译包
推荐使用 Paddle Lite 仓库提供的 [release库](https://github.com/PaddlePaddle/Paddle-Lite/releases/tag/v2.10),在网页最下边选取要使用的库(注意本教程需要用 static 的库),例如这个[预编译库](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.10/inference_lite_lib.android.armv8.clang.c++_static.tar.gz)
```
mv inference_lite_lib.android.armv8.clang.c++_static.tar.gz inference_lite_lib.android.armv8.tar.gz
tar -xvzf inference_lite_lib.android.armv8.tar.gz
```
即可获取编译好的库。注意,即使获取编译好的库依然要进行上述**环境安装**的步骤,因为下面编译 demo 时候会用到。
(2) 编译预测库
运行编译脚本之前,请先检查系统环境变量 ``NDK_ROOT`` 指向正确的 Android NDK 安装路径。
之后可以下载并构建 Paddle Lite 编译包。
```
# 1. 检查环境变量 `NDK_ROOT` 指向正确的 Android NDK 安装路径
echo $NDK_ROOT
# 2. 下载 Paddle Lite 源码并切换到发布分支,如 release/v2.10
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite && git checkout release/v2.10
# (可选) 删除 third-party 目录,编译脚本会自动从国内 CDN 下载第三方库文件
# rm -rf third-party
# 3. 编译 Paddle Lite Android 预测库
./lite/tools/build_android.sh
```
如果按 ``./lite/tools/build_android.sh`` 中的默认参数执行,成功后会在 ``Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8`` 生成 Paddle Lite 编译包,文件目录如下。
```
inference_lite_lib.android.armv8/
├── cxx C++ 预测库和头文件
│ ├── include C++ 头文件
│ │ ├── paddle_api.h
│ │ ├── paddle_image_preprocess.h
│ │ ├── paddle_lite_factory_helper.h
│ │ ├── paddle_place.h
│ │ ├── paddle_use_kernels.h
│ │ ├── paddle_use_ops.h
│ │ └── paddle_use_passes.h
│ └── lib C++ 预测库
│ ├── libpaddle_api_light_bundled.a C++ 静态库
│ └── libpaddle_light_api_shared.so C++ 动态库
├── java Java 预测库
│ ├── jar
│ │ └── PaddlePredictor.jar Java JAR 包
│ ├── so
│ │ └── libpaddle_lite_jni.so Java JNI 动态链接库
│ └── src
└── demo C++ 和 Java 示例代码
├── cxx C++ 预测库示例
└── java Java 预测库示例
```
- 编译运行示例
将编译好的预测库放在当前目录下 mobilenet_v3 文件夹下,并准备好用于测试的[图片](../../images/demo.jpg),和 [label](./mobilenet_v3/imagenet1k_label_list.txt)[config](./mobilenet_v3/config.txt) 。最后文件夹如下所示:
```
mobilenet_v3/ 示例文件夹
├── inference_lite_lib.android.armv8/ Paddle Lite C++ 预测库和头文件
├── Makefile 编译相关
├── mobilenet_v3_small.nb 优化后的模型
├── mobilenet_v3.cc C++ 示例代码
├── demo.jpg 示例图片
├── imagenet1k_label_list.txt 示例label(用于后处理)
└── config.txt 示例config(用于前处理)
```
在 mobilenet_v3 文件夹下运行
```bash
make
```
会进行编译过程,注意编译过程会下载 opencv 第三方库,需要连接网络。编译完成后会生成 mobilenet_v3可执行文件。
注意 Makefile 中第4行:
```
LITE_ROOT=./inference_lite_lib.android.armv8
```
中的 ```LITE_ROOT```需要改成您的预测库的文件夹名。
- 在 Android 手机上部署
连接一台开启了**USB调试功能**的手机,运行
```
adb devices
```
可以看到有输出
```
List of devices attached
1ddcf602 device
```
- 在手机上运行 mobilenet_v3 demo。
```bash
#################################
# 假设当前位于 mobilenet_v3 目录下 #
#################################
# prepare enviroment on phone
adb shell mkdir -p /data/local/tmp/arm_cpu/
# push executable binary, library to device
adb push mobilenet_v3 /data/local/tmp/arm_cpu/
adb shell chmod +x /data/local/tmp/arm_cpu/mobilenet_v3
adb push inference_lite_lib.android.armv8/cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/arm_cpu/
# push model with optimized(opt) to device
adb push ./mobilenet_v3_small.nb /data/local/tmp/arm_cpu/
# push config and label and pictures to device
adb push ./config.txt /data/local/tmp/arm_cpu/
adb push ./imagenet1k_label_list.txt /data/local/tmp/arm_cpu/
adb push ./demo.jpg /data/local/tmp/arm_cpu/
# run demo on device
adb shell "export LD_LIBRARY_PATH=/data/local/tmp/arm_cpu/; \
/data/local/tmp/arm_cpu/mobilenet_v3 \
/data/local/tmp/arm_cpu/config.txt \
/data/local/tmp/arm_cpu/demo.jpg"
```
得到以下输出:
```
===clas result for image: /data/local/tmp/arm_cpu/demo.jpg===
Top-1, class_id: 8, class_name: hen, score: 0.901639
Top-2, class_id: 7, class_name: cock, score: 0.0970001
Top-3, class_id: 86, class_name: partridge, score: 0.000225853
Top-4, class_id: 80, class_name: black grouse, score: 0.0001647
Top-5, class_id: 21, class_name: kite, score: 0.000128394
```
代表在 Android 手机上推理部署完成。
### 4 验证推理结果正确性
`models/tutorials/mobilenetv3_prod/Step6`目录下运行如下命令:
```
python tools/predict.py --pretrained=./mobilenet_v3_small_paddle_pretrained.pdparams --img-path=images/demo.jpg
```
最终输出结果为 ```class_id: 8, prob: 0.9091238975524902``` ,表示预测的类别ID是```8```,置信度为```0.909```
与Paddle Lite预测结果一致。输出结果微小差距的原因是 Paddle Lite 所用 ```opencv``` 和 训练所用 ```PIL```库前处理方式有微小差别。
ARM_ABI = arm8
export ARM_ABI
LITE_ROOT=./inference_lite_lib.android.armv8
include ${LITE_ROOT}/demo/cxx/Makefile.def
THIRD_PARTY_DIR=${LITE_ROOT}/third_party
OPENCV_VERSION=opencv4.1.0
OPENCV_LIBS = ${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a
OPENCV_INCLUDE = -I${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of static libaray: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_light_api_shared.so`
# 2. Undo comment below line using `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
mobilenet_v3: fetch_opencv mobilenet_v3.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenet_v3.o -o mobilenet_v3 $(CXX_LIBS) $(LDFLAGS)
mobilenet_v3.o: mobilenet_v3.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobilenet_v3.o -c mobilenet_v3.cc
fetch_opencv:
@ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR}
@ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \
(echo "fetch opencv libs" && \
wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz)
@ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \
tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR}
.PHONY: clean
clean:
rm -f mobilenet_v3.o
rm -f mobilenet_v3
clas_model_file /data/local/tmp/arm_cpu/mobilenet_v3_small.nb
label_path /data/local/tmp/arm_cpu/imagenet1k_label_list.txt
resize_short_size 256
crop_size 224
visualize 0
enable_benchmark 0
from paddlelite.lite import *
def get_args(add_help=True):
import argparse
parser = argparse.ArgumentParser(
description='Paddle Lite Optimize model', add_help=add_help)
parser.add_argument('--model-dir', default='mobilenet_v3_small', help='model dir')
parser.add_argument('--model-file', default='', help='model file')
parser.add_argument('--param-file', default='', help='param file')
parser.add_argument('--target', default='arm', help='arm or opencl or X86')
parser.add_argument('--model-type', default='naive_buffer', help='save model type')
parser.add_argument('--optimize-out', default='mobilenet_v3_small', help='save model type')
args = parser.parse_args()
return args
def export(args):
opt=Opt()
opt.set_model_file(args.model_file)
opt.set_param_file(args.param_file)
opt.set_valid_places(args.target)
opt.set_model_type(args.model_type)
opt.set_optimize_out(args.optimize_out)
opt.run()
if __name__ == "__main__":
args = get_args()
export(args)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle_api.h" // NOLINT
#include <arm_neon.h>
#include <chrono>
#include <fstream>
#include <iostream>
#include <math.h>
#include <opencv2/opencv.hpp>
#include <sys/time.h>
#include <vector>
using namespace paddle::lite_api; // NOLINT
using namespace std;
struct RESULT {
std::string class_name;
int class_id;
float score;
};
std::vector<RESULT> PostProcess(const float *output_data, int output_size,
const std::vector<std::string> &word_labels,
cv::Mat &output_image) {
const int TOPK = 5;
int max_indices[TOPK];
double max_scores[TOPK];
for (int i = 0; i < TOPK; i++) {
max_indices[i] = 0;
max_scores[i] = 0;
}
for (int i = 0; i < output_size; i++) {
float score = output_data[i];
int index = i;
for (int j = 0; j < TOPK; j++) {
if (score > max_scores[j]) {
index += max_indices[j];
max_indices[j] = index - max_indices[j];
index -= max_indices[j];
score += max_scores[j];
max_scores[j] = score - max_scores[j];
score -= max_scores[j];
}
}
}
std::vector<RESULT> results(TOPK);
for (int i = 0; i < results.size(); i++) {
results[i].class_name = "Unknown";
if (max_indices[i] >= 0 && max_indices[i] < word_labels.size()) {
results[i].class_name = word_labels[max_indices[i]];
}
results[i].score = max_scores[i];
results[i].class_id = max_indices[i];
cv::putText(output_image,
"Top" + std::to_string(i + 1) + "." + results[i].class_name +
":" + std::to_string(results[i].score),
cv::Point2d(5, i * 18 + 20), cv::FONT_HERSHEY_PLAIN, 1,
cv::Scalar(51, 255, 255));
}
return results;
}
// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up
void NeonMeanScale(const float *din, float *dout, int size,
const std::vector<float> mean,
const std::vector<float> scale) {
if (mean.size() != 3 || scale.size() != 3) {
std::cerr << "[ERROR] mean or scale size must equal to 3\n";
exit(1);
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
float32x4_t vmean1 = vdupq_n_f32(mean[1]);
float32x4_t vmean2 = vdupq_n_f32(mean[2]);
float32x4_t vscale0 = vdupq_n_f32(scale[0]);
float32x4_t vscale1 = vdupq_n_f32(scale[1]);
float32x4_t vscale2 = vdupq_n_f32(scale[2]);
float *dout_c0 = dout;
float *dout_c1 = dout + size;
float *dout_c2 = dout + size * 2;
int i = 0;
for (; i < size - 3; i += 4) {
float32x4x3_t vin3 = vld3q_f32(din);
float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
vst1q_f32(dout_c0, vs0);
vst1q_f32(dout_c1, vs1);
vst1q_f32(dout_c2, vs2);
din += 12;
dout_c0 += 4;
dout_c1 += 4;
dout_c2 += 4;
}
for (; i < size; i++) {
*(dout_c0++) = (*(din++) - mean[0]) * scale[0];
*(dout_c1++) = (*(din++) - mean[1]) * scale[1];
*(dout_c2++) = (*(din++) - mean[2]) * scale[2];
}
}
cv::Mat ResizeImage(const cv::Mat &img, const int &resize_short_size) {
int w = img.cols;
int h = img.rows;
cv::Mat resize_img;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
int resize_h = round(float(h) * ratio);
int resize_w = round(float(w) * ratio);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
return resize_img;
}
cv::Mat CenterCropImg(const cv::Mat &img, const int &crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
int w_end = w_start + crop_size;
int h_end = h_start + crop_size;
cv::Rect rect(w_start, h_start, w_end, h_end);
cv::Mat crop_img = img(rect);
return crop_img;
}
std::vector<RESULT>
RunClasModel(std::shared_ptr<PaddlePredictor> predictor, const cv::Mat &img,
const std::map<std::string, std::string> &config,
const std::vector<std::string> &word_labels, double &cost_time) {
// Read img
int resize_short_size = stoi(config.at("resize_short_size"));
int crop_size = stoi(config.at("crop_size"));
int visualize = stoi(config.at("visualize"));
cv::Mat resize_image = ResizeImage(img, resize_short_size);
cv::Mat crop_image = CenterCropImg(resize_image, crop_size);
cv::Mat img_fp;
double e = 1.0 / 255.0;
crop_image.convertTo(img_fp, CV_32FC3, e);
// Prepare input data from image
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
input_tensor->Resize({1, 3, img_fp.rows, img_fp.cols});
auto *data0 = input_tensor->mutable_data<float>();
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
const float *dimg = reinterpret_cast<const float *>(img_fp.data);
NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols, mean, scale);
auto start = std::chrono::system_clock::now();
// Run predictor
predictor->Run();
// Get output and post process
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
auto *output_data = output_tensor->data<float>();
auto end = std::chrono::system_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
cost_time = double(duration.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den;
int output_size = 1;
for (auto dim : output_tensor->shape()) {
output_size *= dim;
}
cv::Mat output_image;
auto results =
PostProcess(output_data, output_size, word_labels, output_image);
if (visualize) {
std::string output_image_path = "./clas_result.png";
cv::imwrite(output_image_path, output_image);
std::cout << "save output image into " << output_image_path << std::endl;
}
return results;
}
std::shared_ptr<PaddlePredictor> LoadModel(std::string model_file) {
MobileConfig config;
config.set_model_from_file(model_file);
std::shared_ptr<PaddlePredictor> predictor =
CreatePaddlePredictor<MobileConfig>(config);
return predictor;
}
std::vector<std::string> split(const std::string &str,
const std::string &delim) {
std::vector<std::string> res;
if ("" == str)
return res;
char *strs = new char[str.length() + 1];
std::strcpy(strs, str.c_str());
char *d = new char[delim.length() + 1];
std::strcpy(d, delim.c_str());
char *p = std::strtok(strs, d);
while (p) {
string s = p;
res.push_back(s);
p = std::strtok(NULL, d);
}
return res;
}
std::vector<std::string> ReadDict(std::string path) {
std::ifstream in(path);
std::string filename;
std::string line;
std::vector<std::string> m_vec;
if (in) {
while (getline(in, line)) {
m_vec.push_back(line);
}
} else {
std::cout << "no such file" << std::endl;
}
return m_vec;
}
std::map<std::string, std::string> LoadConfigTxt(std::string config_path) {
auto config = ReadDict(config_path);
std::map<std::string, std::string> dict;
for (int i = 0; i < config.size(); i++) {
std::vector<std::string> res = split(config[i], " ");
dict[res[0]] = res[1];
}
return dict;
}
void PrintConfig(const std::map<std::string, std::string> &config) {
std::cout << "=======PaddleClas lite demo config======" << std::endl;
for (auto iter = config.begin(); iter != config.end(); iter++) {
std::cout << iter->first << " : " << iter->second << std::endl;
}
std::cout << "=======End of PaddleClas lite demo config======" << std::endl;
}
std::vector<std::string> LoadLabels(const std::string &path) {
std::ifstream file;
std::vector<std::string> labels;
file.open(path);
while (file) {
std::string line;
std::getline(file, line);
std::string::size_type pos = line.find(" ");
if (pos != std::string::npos) {
line = line.substr(pos);
}
labels.push_back(line);
}
file.clear();
file.close();
return labels;
}
int main(int argc, char **argv) {
if (argc < 3) {
std::cerr << "[ERROR] usage: " << argv[0] << " config_path img_path\n";
exit(1);
}
std::string config_path = argv[1];
std::string img_path = argv[2];
// load config
auto config = LoadConfigTxt(config_path);
PrintConfig(config);
double elapsed_time = 0.0;
int warmup_iter = 10;
bool enable_benchmark = bool(stoi(config.at("enable_benchmark")));
int total_cnt = enable_benchmark ? 1000 : 1;
std::string clas_model_file = config.at("clas_model_file");
std::string label_path = config.at("label_path");
// Load Labels
std::vector<std::string> word_labels = LoadLabels(label_path);
auto clas_predictor = LoadModel(clas_model_file);
for (int j = 0; j < total_cnt; ++j) {
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
double run_time = 0;
std::vector<RESULT> results =
RunClasModel(clas_predictor, srcimg, config, word_labels, run_time);
std::cout << "===clas result for image: " << img_path << "===" << std::endl;
for (int i = 0; i < results.size(); i++) {
std::cout << "\t"
<< "Top-" << i + 1 << ", class_id: " << results[i].class_id
<< ", class_name: " << results[i].class_name
<< ", score: " << results[i].score << std::endl;
}
if (j >= warmup_iter) {
elapsed_time += run_time;
std::cout << "Current image path: " << img_path << std::endl;
std::cout << "Current time cost: " << run_time << " s, "
<< "average time cost in all: "
<< elapsed_time / (j + 1 - warmup_iter) << " s." << std::endl;
} else {
std::cout << "Current time cost: " << run_time << " s." << std::endl;
}
}
return 0;
}
# Lite ARM CPU 推理开发文档
# Paddle Lite arm cpu 推理开发文档
# 目录
- [1. 简介](#1---)
- [2. Lite ARM CPU 基础推理开发文档](#2---)
- [3. Lite ARM CPU 高级推理开发文档](#3---)
- [4. FAQ](#4---)
- [1. 简介](#1)
- [2. 使用 Paddle Lite 在 ARM CPU 上的部署流程](#2)
- [2.1 准备推理数据与环境 ](#2.1)
- [2.2 准备推理模型 ](#2.1)
- [2.3 准备推理所需代码](#2.2)
- [2.4 开发数据预处理程序](#2.3)
- [2.5 开发推理程序](#2.4)
- [2.6 开发推理结果后处理程序](#2.5)
- [2.7 验证推理结果正确性](#2.6)
- [3. FAQ](#3)
- [3.1 通用问题](#3.1)
## 1. 简介
在 ARM CPU 上部署需要使用 [Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite) 进行部署,Paddle Lite 是一个轻量级、灵活性强、易于扩展的高性能的深度学习预测框架,它可以支持诸如 ARM、OpenCL 、NPU 等等多种终端,同时拥有强大的图优化及预测加速能力。如果您希望将 Paddle Lite 框架集成到自己的项目中,那么只需要如下几步简单操作即可。
<div align="center">
<img src="../images/paddleliteworkflow.png" width=600">
</div>
图中的2、7是核验点,需要核验结果正确性。
## 2.使用 Paddle Lite 在 ARM CPU 上的部署流程
### 2.1 准备推理数据与环境
- 推理环境
开发机器:一台开发机,可以是 x86 linux 或者 Mac 设备。开发机器上需要安装开发环境。
推理设备:一台 ARM CPU 设备,可以连接到开发机上。开发板的系统可以是 Android 或 Armlinux。
开发机上安装开发环境以及对推理设备的配置参考[mobilenet_v3开发实战](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu)中的**准备开发环境****在 Android 手机上部署**章节。
- 推理数据
一张可用于推理的[图片](../../mobilenetv3_prod/Step6/images/demo.jpg)和用于前处理的[配置文件](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu/mobilenet_v3/config.txt)(可选,和前处理有关)以及用于推理结果后处理相关的 [label](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu/mobilenet_v3/imagenet1k_label_list.txt) 文件(可选,跟后处理有关)。
### 2.2 准备推理模型
- 准备 inference model
Paddle Lite 框架直接支持[ PaddlePaddle ](https://www.paddlepaddle.org.cn/)深度学习框架产出的模型。在 PaddlePaddle 静态图模式下,使用`save_inference_model`这个 API 保存预测模型,Paddle Lite 对此类预测模型已经做了充分支持;在 PaddlePaddle 动态图模式下,使用`paddle.jit.save`这个 API 保存预测模型,Paddle Lite 可以支持绝大部分此类预测模型了。
- 使用 opt 工具优化模型
Paddle Lite 框架拥有优秀的加速、优化策略及实现,包含量化、子图融合、Kernel 优选等优化手段。优化后的模型更轻量级,耗费资源更少,并且执行速度也更快。
这些优化通过 Paddle Lite 提供的 opt 工具实现。opt 工具还可以统计并打印出模型中的算子信息,并判断不同硬件平台下 Paddle Lite 的支持情况。
导出 inference model 和使用 opt 工具优化参考[mobilenet_v3开发实战](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu)中的**获取 inference model****生成 Paddle Lite 部署模型**章节,注意本步骤需要核验是否有```xxx.nb```模型生成。
### 2.3 准备推理所需代码
- Paddle Lite 预测库
Paddle Lite 提供了 `Android/IOS/ArmLinux/Windows/MacOS/Ubuntu` 平台的官方 Release 预测库下载,我们优先推荐您直接下载 [Paddle Lite 预编译库](https://github.com/PaddlePaddle/Paddle-Lite/releases/tag/v2.10)。您也可以根据目标平台选择对应的 [源码编译方法](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html)。Paddle Lite 提供了源码编译脚本,位于 `lite/tools/` 文件夹下,只需要进行必要的环境准备之后即可运行。
- 用户的推理应用程序,例如mobilenet_v3.cc
- Makefile用于编译应用程序
至此已经准备好部署所需的全部文件。以[mobilenet_v3开发实战](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu)中的 mobilenet_v3 文件夹为例展示:
```
mobilenet_v3/ 示例文件夹
├── inference_lite_lib.android.armv8/ Paddle Lite C++ 预测库和头文件
├── Makefile 编译相关
├── Makefile.def 编译相关
├── mobilenet_v3_small.nb 优化后的模型
├── mobilenet_v3.cc C++ 示例代码
├── demo.jpg 示例图片
├── imagenet1k_label_list.txt 示例label(用于后处理)
└── config.txt 示例config(用于前处理)
```
### 2.4 开发数据预处理程序
Paddle Lite 推理框架的输入不能直接是图片,所以需要对图片进行预处理,预处理过程一般包括 `opencv 读取``resize``crop``归一化`等操作,之后才能变成最后输入给推理框架的数据。预处理参考 [mobilenet_v3开发实战](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu/mobilenet_v3) 中的mobilenet_v3.cc 文件。
### 2.5 开发推理程序
使用 Paddle Lite 的 `API` 只需简单五步即可完成预测:
1. 声明 `MobileConfig` ,设置第二步优化后的模型文件路径,或选择从内存中加载模型
2. 创建 `Predictor` ,调用 `CreatePaddlePredictor` 接口,一行代码即可完成引擎初始化
3. 准备输入,通过 `predictor->GetInput(i)` 获取输入变量,并为其指定输入大小和输入值
4. 执行预测,只需要运行 `predictor->Run()` 一行代码,即可使用 Paddle Lite 框架完成预测执行
5. 获得输出,使用 `predictor->GetOutput(i)` 获取输出变量,并通过 `data<T>` 取得输出值
在此提供简单示例:
```c++
#include <iostream>
// 引入 C++ API
#include "paddle_lite/paddle_api.h"
#include "paddle_lite/paddle_use_ops.h"
#include "paddle_lite/paddle_use_kernels.h"
// 1. 设置 MobileConfig
MobileConfig config;
config.set_model_from_file(<modelPath>); // 设置 NaiveBuffer 格式模型路径
config.set_power_mode(LITE_POWER_NO_BIND); // 设置 CPU 运行模式
config.set_threads(4); // 设置工作线程数
// 2. 创建 PaddlePredictor
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<MobileConfig>(config);
// 3. 设置输入数据,可以在这里进行您的前处理,比如用opencv读取图片等。这里为全一输入。
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
input_tensor->Resize({1, 3, 224, 224});
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
data[i] = 1;
}
//其他前处理
// 4. 执行预测
predictor->run();
// 5. 获取输出数据
std::unique_ptr<const Tensor> output_tensor(std::move(predictor->GetOutput(0)));
std::cout << "Output shape " << output_tensor->shape()[1] << std::endl;
for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
std::cout << "Output[" << i << "]: " << output_tensor->data<float>()[i]
<< std::endl;
}
//后处理
```
### 2.6 开发推理结果后处理程序
后处理主要处理的是Paddle Lite 推理框架的输出 `tensor`, 包括选取哪个 `tensor` 以及根据 `label` 文件进行获得预测的类别,后处理参考 [mobilenet_v3开发实战](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu/mobilenet_v3) 中的mobilenet_v3.cc 文件。
### 2.7 验证推理结果正确性
Paddle Lite 的推理结果,需要和训练框架的预测结果对比是否一致。注意此过程中需要首先保证前处理和后处理与训练代码是一致的。具体可以参考 [mobilenet_v3开发实战](../../mobilenetv3_prod/Step6/deploy/lite_infer_cpp_arm_cpu)
## 3. FAQ
### 3.1 通用问题
如果您在使用过程中遇到任何问题,可以参考 [Paddle Lite 文档](https://paddle-lite.readthedocs.io/zh/latest/index.html) ,还可以在[这里](https://github.com/PaddlePaddle/Paddle-Lite/issues)提 issue 给我们,我们会高优跟进。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册