Unverified · Commit 5a7a841b authored by littletomatodonkey, committed by GitHub

Merge pull request #303 from littletomatodonkey/sta/add_lite_demo

Add lite demo
......@@ -51,8 +51,7 @@ public:
 class ResizeImg {
 public:
-  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
-                   float &ratio_h, float &ratio_w);
+  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len);
 };
 } // namespace PaddleClas
\ No newline at end of file
......@@ -47,15 +47,12 @@ void Classifier::LoadModel(const std::string &model_dir) {
 }
 void Classifier::Run(cv::Mat &img) {
-  float ratio_h{};
-  float ratio_w{};
   cv::Mat srcimg;
   cv::Mat resize_img;
   img.copyTo(srcimg);
-  this->resize_op_.Run(img, resize_img, this->resize_short_size_, ratio_h,
-                       ratio_w);
+  this->resize_op_.Run(img, resize_img, this->resize_short_size_);
   this->crop_op_.Run(resize_img, this->crop_size_);
   this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
......
......@@ -25,6 +25,7 @@
 #include <cstring>
 #include <fstream>
 #include <math.h>
+#include <numeric>
 #include <include/preprocess_op.h>
......@@ -68,25 +69,22 @@ void CenterCropImg::Run(cv::Mat &img, const int crop_size) {
   img = img(rect);
 }
-void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
-                    float &ratio_h, float &ratio_w) {
+void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
+                    int resize_short_size) {
   int w = img.cols;
   int h = img.rows;
   float ratio = 1.f;
   if (h < w) {
-    ratio = float(max_size_len) / float(h);
+    ratio = float(resize_short_size) / float(h);
   } else {
-    ratio = float(max_size_len) / float(w);
+    ratio = float(resize_short_size) / float(w);
   }
-  int resize_h = int(float(h) * ratio);
-  int resize_w = int(float(w) * ratio);
+  int resize_h = round(float(h) * ratio);
+  int resize_w = round(float(w) * ratio);
   cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
-  ratio_h = float(resize_h) / float(h);
-  ratio_w = float(resize_w) / float(w);
 }
 } // namespace PaddleClas
\ No newline at end of file
ARM_ABI = arm8
export ARM_ABI
include ../Makefile.def
LITE_ROOT=../../../
THIRD_PARTY_DIR=${LITE_ROOT}/third_party
OPENCV_VERSION=opencv4.1.0
OPENCV_LIBS = ${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \
${THIRD_PARTY_DIR}/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a
OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of the static libraries:                     #
#   `libpaddle_api_full_bundled.a`                            #
#   `libpaddle_api_light_bundled.a`                           #
###############################################################
# Note: the shared library is used by default.                #
###############################################################
# 1. Comment out the line above that uses `libpaddle_light_api_shared.so`
# 2. Uncomment the line below that uses `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
clas_system: fetch_opencv clas_system.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) clas_system.o -o clas_system $(CXX_LIBS) $(LDFLAGS)
clas_system.o: image_classfication.cpp
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o clas_system.o -c image_classfication.cpp
fetch_opencv:
@ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR}
@ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \
(echo "fetch opencv libs" && \
wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz)
@ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \
tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR}
.PHONY: clean
clean:
rm -f clas_system.o
rm -f clas_system
clas_model_file ./MobileNetV3_large_x1_0.nb
label_path ./imagenet1k_label_list.txt
resize_short_size 256
crop_size 224
visualize 0
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle_api.h" // NOLINT
#include <arm_neon.h>
#include <chrono>
#include <fstream>
#include <iostream>
#include <math.h>
#include <opencv2/opencv.hpp>
#include <sys/time.h>
#include <vector>
using namespace paddle::lite_api; // NOLINT
using namespace std;
struct RESULT {
std::string class_name;
int class_id;
float score;
};
std::vector<RESULT> PostProcess(const float *output_data, int output_size,
const std::vector<std::string> &word_labels,
cv::Mat &output_image) {
const int TOPK = 5;
int max_indices[TOPK];
double max_scores[TOPK];
for (int i = 0; i < TOPK; i++) {
max_indices[i] = 0;
max_scores[i] = 0;
}
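  // Insert each score into the top-K arrays, which are kept in descending
  // order; the add/subtract pairs below swap (index, max_indices[j]) and
  // (score, max_scores[j]) without a temporary variable.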
for (int i = 0; i < output_size; i++) {
float score = output_data[i];
int index = i;
for (int j = 0; j < TOPK; j++) {
if (score > max_scores[j]) {
index += max_indices[j];
max_indices[j] = index - max_indices[j];
index -= max_indices[j];
score += max_scores[j];
max_scores[j] = score - max_scores[j];
score -= max_scores[j];
}
}
}
std::vector<RESULT> results(TOPK);
for (int i = 0; i < results.size(); i++) {
results[i].class_name = "Unknown";
if (max_indices[i] >= 0 && max_indices[i] < word_labels.size()) {
results[i].class_name = word_labels[max_indices[i]];
}
results[i].score = max_scores[i];
results[i].class_id = max_indices[i];
cv::putText(output_image,
"Top" + std::to_string(i + 1) + "." + results[i].class_name +
":" + std::to_string(results[i].score),
cv::Point2d(5, i * 18 + 20), cv::FONT_HERSHEY_PLAIN, 1,
cv::Scalar(51, 255, 255));
}
return results;
}
// Fill the output buffer with (x - mean) * scale and convert the layout from NHWC to NCHW, using NEON to speed it up
void NeonMeanScale(const float *din, float *dout, int size,
const std::vector<float> mean,
const std::vector<float> scale) {
if (mean.size() != 3 || scale.size() != 3) {
std::cerr << "[ERROR] mean or scale size must equal to 3\n";
exit(1);
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
float32x4_t vmean1 = vdupq_n_f32(mean[1]);
float32x4_t vmean2 = vdupq_n_f32(mean[2]);
float32x4_t vscale0 = vdupq_n_f32(scale[0]);
float32x4_t vscale1 = vdupq_n_f32(scale[1]);
float32x4_t vscale2 = vdupq_n_f32(scale[2]);
float *dout_c0 = dout;
float *dout_c1 = dout + size;
float *dout_c2 = dout + size * 2;
int i = 0;
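  // NEON path: de-interleave and normalize 4 RGB pixels (12 floats) per
  // iteration; the scalar loop below handles the remaining pixels.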
for (; i < size - 3; i += 4) {
float32x4x3_t vin3 = vld3q_f32(din);
float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
vst1q_f32(dout_c0, vs0);
vst1q_f32(dout_c1, vs1);
vst1q_f32(dout_c2, vs2);
din += 12;
dout_c0 += 4;
dout_c1 += 4;
dout_c2 += 4;
}
for (; i < size; i++) {
*(dout_c0++) = (*(din++) - mean[0]) * scale[0];
*(dout_c1++) = (*(din++) - mean[1]) * scale[1];
*(dout_c2++) = (*(din++) - mean[2]) * scale[2];
}
}
cv::Mat ResizeImage(const cv::Mat &img, const int &resize_short_size) {
int w = img.cols;
int h = img.rows;
cv::Mat resize_img;
float ratio = 1.f;
if (h < w) {
ratio = float(resize_short_size) / float(h);
} else {
ratio = float(resize_short_size) / float(w);
}
int resize_h = round(float(h) * ratio);
int resize_w = round(float(w) * ratio);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
return resize_img;
}
cv::Mat CenterCropImg(const cv::Mat &img, const int &crop_size) {
int resize_w = img.cols;
int resize_h = img.rows;
int w_start = int((resize_w - crop_size) / 2);
int h_start = int((resize_h - crop_size) / 2);
cv::Rect rect(w_start, h_start, crop_size, crop_size);
cv::Mat crop_img = img(rect);
return crop_img;
}
std::vector<RESULT>
RunClasModel(std::shared_ptr<PaddlePredictor> predictor, const cv::Mat &img,
const std::map<std::string, std::string> &config,
const std::vector<std::string> &word_labels) {
  // Read preprocessing hyperparameters from the config
int resize_short_size = stoi(config.at("resize_short_size"));
int crop_size = stoi(config.at("crop_size"));
int visualize = stoi(config.at("visualize"));
cv::Mat resize_image = ResizeImage(img, resize_short_size);
cv::Mat crop_image = CenterCropImg(resize_image, crop_size);
cv::Mat img_fp;
double e = 1.0 / 255.0;
crop_image.convertTo(img_fp, CV_32FC3, e);
// Prepare input data from image
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
input_tensor->Resize({1, 3, img_fp.rows, img_fp.cols});
auto *data0 = input_tensor->mutable_data<float>();
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
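  // img_fp holds interleaved float RGB (HWC); NeonMeanScale normalizes it and
  // writes planar CHW data directly into the input tensor.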
const float *dimg = reinterpret_cast<const float *>(img_fp.data);
NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols, mean, scale);
// Run predictor
predictor->Run();
// Get output and post process
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
auto *output_data = output_tensor->data<float>();
int output_size = 1;
for (auto dim : output_tensor->shape()) {
output_size *= dim;
}
cv::Mat output_image;
auto results =
PostProcess(output_data, output_size, word_labels, output_image);
if (visualize) {
std::string output_image_path = "./clas_result.png";
cv::imwrite(output_image_path, output_image);
std::cout << "save output image into " << output_image_path << std::endl;
}
return results;
}
std::shared_ptr<PaddlePredictor> LoadModel(std::string model_file) {
MobileConfig config;
config.set_model_from_file(model_file);
std::shared_ptr<PaddlePredictor> predictor =
CreatePaddlePredictor<MobileConfig>(config);
return predictor;
}
std::vector<std::string> split(const std::string &str,
const std::string &delim) {
std::vector<std::string> res;
if ("" == str)
return res;
char *strs = new char[str.length() + 1];
std::strcpy(strs, str.c_str());
char *d = new char[delim.length() + 1];
std::strcpy(d, delim.c_str());
char *p = std::strtok(strs, d);
while (p) {
string s = p;
res.push_back(s);
p = std::strtok(NULL, d);
}
  delete[] strs;
  delete[] d;
  return res;
}
std::vector<std::string> ReadDict(std::string path) {
std::ifstream in(path);
std::string filename;
std::string line;
std::vector<std::string> m_vec;
if (in) {
while (getline(in, line)) {
m_vec.push_back(line);
}
} else {
std::cout << "no such file" << std::endl;
}
return m_vec;
}
std::map<std::string, std::string> LoadConfigTxt(std::string config_path) {
auto config = ReadDict(config_path);
std::map<std::string, std::string> dict;
for (int i = 0; i < config.size(); i++) {
std::vector<std::string> res = split(config[i], " ");
dict[res[0]] = res[1];
}
return dict;
}
void PrintConfig(const std::map<std::string, std::string> &config) {
std::cout << "=======PaddleClas lite demo config======" << std::endl;
for (auto iter = config.begin(); iter != config.end(); iter++) {
std::cout << iter->first << " : " << iter->second << std::endl;
}
std::cout << "=======End of PaddleClas lite demo config======" << std::endl;
}
std::vector<std::string> LoadLabels(const std::string &path) {
std::ifstream file;
std::vector<std::string> labels;
file.open(path);
  std::string line;
  // Read line by line; stopping at EOF avoids pushing a trailing empty label.
  while (std::getline(file, line)) {
    // Drop the leading class id (everything before the first space).
    std::string::size_type pos = line.find(" ");
    if (pos != std::string::npos) {
      line = line.substr(pos);
    }
    labels.push_back(line);
  }
file.clear();
file.close();
return labels;
}
int main(int argc, char **argv) {
if (argc < 3) {
std::cerr << "[ERROR] usage: " << argv[0] << " config_path img_path\n";
exit(1);
}
std::string config_path = argv[1];
std::string img_path = argv[2];
// load config
auto config = LoadConfigTxt(config_path);
PrintConfig(config);
std::string clas_model_file = config.at("clas_model_file");
std::string label_path = config.at("label_path");
// Load Labels
std::vector<std::string> word_labels = LoadLabels(label_path);
auto clas_predictor = LoadModel(clas_model_file);
auto start = std::chrono::system_clock::now();
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
std::vector<RESULT> results =
RunClasModel(clas_predictor, srcimg, config, word_labels);
std::cout << "===clas result for image: " << img_path << "===" << std::endl;
for (int i = 0; i < results.size(); i++) {
std::cout << "\t"
<< "Top-" << i + 1 << ", class_id: " << results[i].class_id
<< ", class_name: " << results[i].class_name
<< ", score: " << results[i].score << std::endl;
}
auto end = std::chrono::system_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
std::cout << "Cost "
<< double(duration.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den
<< " s" << std::endl;
return 0;
}
#!/bin/bash
if [ $# != 1 ] ; then
echo "USAGE: $0 your_inference_lite_lib_path"
exit 1;
fi
mkdir -p $1/demo/cxx/clas/debug/
cp ../../ppcls/utils/imagenet1k_label_list.txt $1/demo/cxx/clas/debug/
cp -r ./* $1/demo/cxx/clas/
cp ./config.txt $1/demo/cxx/clas/debug/
cp ./imgs/tabby_cat.jpg $1/demo/cxx/clas/debug/
echo "Prepare Done"
# On-Device Deployment
This tutorial walks through the steps for deploying PaddleClas classification models on mobile devices with [Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite).
Paddle Lite is PaddlePaddle's lightweight inference engine. It provides efficient inference on mobile phones and IoT devices, integrates a wide range of hardware across platforms, and offers a lightweight solution for on-device deployment. If you just want to benchmark inference speed, refer to the [Paddle-Lite mobile benchmark tutorial](../../docs/zh_CN/extension/paddle_mobile_inference.md).
## 1. Environment Preparation
### Prerequisites
- A computer, for compiling Paddle Lite
- An Android phone (armv7 or armv8)
### 1.1 Prepare the Cross-Compilation Environment
The cross-compilation environment is used to build Paddle Lite and the PaddleClas C++ demo.
Several development environments are supported; refer to the corresponding documentation for the compilation workflow of each:
1. [Docker](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#docker)
2. [Linux](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#linux)
3. [macOS](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#mac-os)
### 1.2 Prepare the Prediction Library
There are two ways to obtain the prediction library:
1. [Recommended] Download it directly; the download links are as follows:
|Platform|Prediction library download link|
|-|-|
|Android|[arm7](https://paddlelite-data.bj.bcebos.com/Release/2.6.1/Android/inference_lite_lib.android.armv7.gcc.c++_static.with_extra.CV_ON.tar.gz) / [arm8](https://paddlelite-data.bj.bcebos.com/Release/2.6.1/Android/inference_lite_lib.android.armv8.gcc.c++_static.with_extra.CV_ON.tar.gz)|
|iOS|[arm7](https://paddlelite-data.bj.bcebos.com/Release/2.6.1/iOS/inference_lite_lib.ios.armv7.with_extra.CV_ON.tar.gz) / [arm8](https://paddlelite-data.bj.bcebos.com/Release/2.6.1/iOS/inference_lite_lib.ios64.armv8.with_extra.CV_ON.tar.gz)|
Notes: 1. If you download the prediction library from the Paddle-Lite [official documentation](https://paddle-lite.readthedocs.io/zh/latest/user_guides/release_lib.html#android-toolchain-gcc), be sure to choose the download link with `with_extra=ON, with_cv=ON`. 2. If a quantized model is deployed on the device, it is recommended to build the prediction library from the Paddle-Lite develop branch.
2. Build Paddle-Lite from source to obtain the prediction library. The build steps are as follows:
```
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
# When building from source, the develop branch is recommended
git checkout develop
./lite/tools/build_android.sh --arch=armv8 --with_cv=ON --with_extra=ON
```
Note: when building Paddle-Lite to obtain the prediction library, the two options `--with_cv=ON --with_extra=ON` must be enabled. `--arch` specifies the arm version, armv8 here; see the [documentation](https://paddle-lite.readthedocs.io/zh/latest/user_guides/Compile/Android.html#id2) for more build options.
If you download the prediction library directly and extract it, you get the `inference_lite_lib.android.armv8/` folder; if you build Paddle-Lite yourself, the prediction library is located under the
`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/` folder.
The directory structure of the prediction library is as follows:
```
inference_lite_lib.android.armv8/
|-- cxx                                        C++ prediction library and headers
|   |-- include                                C++ headers
|   |   |-- paddle_api.h
|   |   |-- paddle_image_preprocess.h
|   |   |-- paddle_lite_factory_helper.h
|   |   |-- paddle_place.h
|   |   |-- paddle_use_kernels.h
|   |   |-- paddle_use_ops.h
|   |   `-- paddle_use_passes.h
|   `-- lib                                    C++ prediction library
|       |-- libpaddle_api_light_bundled.a      C++ static library
|       `-- libpaddle_light_api_shared.so      C++ shared library
|-- java                                       Java prediction library
|   |-- jar
|   |   `-- PaddlePredictor.jar
|   |-- so
|   |   `-- libpaddle_lite_jni.so
|   `-- src
|-- demo                                       C++ and Java demo code
|   |-- cxx                                    C++ demos
|   `-- java                                   Java demos
```
## 2. Run the Demo
### 2.1 Model Optimization
Paddle-Lite provides a variety of strategies to automatically optimize the original model, including quantization, subgraph fusion, hybrid scheduling, and kernel selection. The Paddle-Lite opt tool applies these optimizations to an inference model automatically; the optimized model is lighter and runs faster. There are two ways to optimize an inference model.
**Note**: if you already have a model file ending in `.nb`, you can skip this step.
#### 2.1.1 [Recommended] Install paddlelite via pip and convert the model
* Install paddlelite for Python:
```shell
pip install paddlelite
```
The `paddle_lite_opt` tool can then be used to convert the inference model. Some of its options are listed below:
|Option|Description|
|-|-|
|--model_dir|Path of the PaddlePaddle model to be optimized (non-combined format)|
|--model_file|Path of the network structure file of the PaddlePaddle model to be optimized (combined format)|
|--param_file|Path of the weights file of the PaddlePaddle model to be optimized (combined format)|
|--optimize_out_type|Output model type; protobuf and naive_buffer are currently supported, where naive_buffer is a more lightweight serialization/deserialization format. Set this option to naive_buffer if you run inference on mobile. Defaults to protobuf|
|--optimize_out|Output path of the optimized model|
|--valid_targets|Backends on which the model can run, arm by default. Currently x86, arm, opencl, npu and xpu are supported; multiple backends can be specified (separated by spaces) and the Model Optimize Tool will automatically choose the best one. For Huawei NPU support (the DaVinci-architecture NPU in the Kirin 810/990 SoC), set it to npu, arm|
|--record_tailoring_info|Set to true when using the tailor-library-by-model feature, so that the kernels and ops contained in the optimized model are recorded. Defaults to false|
`--model_file` is the path of the inference model's model file and `--param_file` is the path of its params file; `--optimize_out` specifies the name of the output file (do not append the `.nb` suffix). Running `paddle_lite_opt` on the command line with no arguments prints all options and their descriptions. A typical invocation is shown below.
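The following sketch is illustrative only: it converts a hypothetical combined-format inference model (placeholder paths `./your_inference_model/model` and `./your_inference_model/params`) into an arm `.nb` model, using the options described in the table above.
```shell
# Illustrative: convert a combined-format inference model for the arm backend.
# Replace the model/params paths and the output name with your own.
paddle_lite_opt \
    --model_file=./your_inference_model/model \
    --param_file=./your_inference_model/params \
    --optimize_out_type=naive_buffer \
    --valid_targets=arm \
    --optimize_out=./your_model_name
```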
#### 2.1.2 Build the opt tool from Paddle-Lite source
Model optimization requires the Paddle-Lite opt executable, which can be obtained by building the Paddle-Lite source code:
```
# No need to clone Paddle-Lite again if it was already cloned while preparing the environment
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
git checkout develop
# Start the build
./lite/tools/build.sh build_optimize_tool
```
After the build finishes, the opt binary is located under `build.opt/lite/api/`. Its options and usage can be displayed as follows:
```
cd build.opt/lite/api/
./opt
```
The usage and options of `opt` are exactly the same as those of `paddle_lite_opt` described above.
#### 2.1.3 Conversion Example
The following example uses the PaddleClas `MobileNetV3_large_x1_0` model to show how `paddle_lite_opt` takes you from a pretrained model to an inference model and finally to a Paddle-Lite optimized model.
```shell
# Enter the PaddleClas root directory
cd PaddleClas_root_path
# Download and extract the pretrained model
wget https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar
tar -xf MobileNetV3_large_x1_0_pretrained.tar
# Export the pretrained model as an inference model
python tools/export_model.py -m MobileNetV3_large_x1_0 -p ./MobileNetV3_large_x1_0_pretrained/ -o ./MobileNetV3_large_x1_0_inference/
# Convert the inference model into a Paddle-Lite optimized model
paddle_lite_opt --model_file=./MobileNetV3_large_x1_0_inference/model --param_file=./MobileNetV3_large_x1_0_inference/params --optimize_out=./MobileNetV3_large_x1_0
```
This produces a file named `MobileNetV3_large_x1_0.nb` in the current directory.
<a name="2.2与手机联调"></a>
### 2.2 Run and Debug on the Phone
Some preparation is needed first.
1. Prepare an armv8 Android phone. If the prediction library and opt tool were built for armv7, use an armv7 phone instead and change `ARM_ABI = arm7` in the Makefile.
2. Enable USB debugging on the phone, select file transfer mode, and connect it to the computer.
3. Install the adb tool on the computer for debugging:
3.1. Install ADB on macOS:
```
brew cask install android-platform-tools
```
3.2. Install ADB on Linux:
```
sudo apt update
sudo apt install -y wget adb
```
3.3. Install ADB on Windows:
On Windows, download and install the adb package from Google's Android developer site: [link](https://developer.android.com/studio)
Open a terminal with the phone connected and run
```
adb devices
```
If a device appears in the output, the installation was successful:
```
List of devices attached
744be294 device
```
4. Prepare the optimized model, the prediction library files, a test image, and the label mapping file.
```shell
cd PaddleClas_root_path
cd deploy/lite/
# Run prepare.sh to collect the prediction library files, the test image and the label file, and place them under demo/cxx/clas in the prediction library
sh prepare.sh /{lite prediction library path}/inference_lite_lib.android.armv8
# Enter the working directory of the lite demo
cd /{lite prediction library path}/inference_lite_lib.android.armv8/
cd demo/cxx/clas/
# Copy the C++ prediction shared library (.so) into the debug folder
cp ../../../cxx/lib/libpaddle_light_api_shared.so ./debug/
```
Prepare a test image, e.g. `PaddleClas/deploy/lite/imgs/tabby_cat.jpg`, and copy it into the `demo/cxx/clas/debug/` folder.
Prepare the model file optimized by the `paddle_lite_opt` tool, e.g. `MobileNetV3_large_x1_0.nb`, and place it in the `demo/cxx/clas/debug/` folder as well.
After these steps, the clas folder has the following layout:
```shell
demo/cxx/clas/
|-- debug/
|   |-- MobileNetV3_large_x1_0.nb          optimized classification model file
|   |-- tabby_cat.jpg                      test image
|   |-- imagenet1k_label_list.txt          label mapping file
|   |-- libpaddle_light_api_shared.so      C++ prediction library
|   |-- config.txt                         classification hyperparameter config
|-- config.txt                             classification hyperparameter config
|-- image_classfication.cpp                image classification source file
|-- Makefile                               build file
```
#### Notes:
1. `imagenet1k_label_list.txt` is the label mapping file for the ImageNet1k dataset; if you use custom classes, replace it with your own label mapping file.
2. `config.txt` contains the classifier's hyperparameters:
```
clas_model_file ./MobileNetV3_large_x1_0.nb  # path of the model file
label_path ./imagenet1k_label_list.txt       # label mapping text file
resize_short_size 256                        # length of the short side after resizing
crop_size 224                                # side length of the center crop used for prediction
visualize 0                                  # whether to visualize; if enabled, an image named clas_result.png is written to the current folder
```
3. Start debugging: once the steps above are done, use adb to push the files to the phone and run the demo:
```
# Build to obtain the executable clas_system
# The clas_system executable is invoked as:
# ./clas_system <config file path> <test image path>
make -j
# Move the built executable into the debug folder
mv clas_system ./debug/
# Push the debug folder to the phone
adb push debug /data/local/tmp/
adb shell
cd /data/local/tmp/debug
export LD_LIBRARY_PATH=/data/local/tmp/debug:$LD_LIBRARY_PATH
./clas_system ./config.txt ./tabby_cat.jpg
```
If you modify the code, rebuild it and push it to the phone again.
The expected output looks like this:
<div align="center">
<img src="./imgs/lite_demo_result.png" width="600">
</div>
## FAQ
Q1: What if I want to switch to another model? Do I have to go through the whole workflow again?
A1: If the steps above already work for you, switching models only requires replacing the `.nb` model file; remember to update the nb file path in the config file and, if necessary, the label mapping file.
Q2: How do I test with a different image?
A2: Replace the test image under the debug folder with the image you want to test and adb push it to the phone, for example:
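A minimal sketch, assuming the demo has already been built and pushed once and the new test image is named `my_image.jpg` (a placeholder):
```shell
# Copy the new image into debug/, push it to the device, and rerun the demo.
cp my_image.jpg ./debug/
adb push ./debug/my_image.jpg /data/local/tmp/debug/
adb shell "cd /data/local/tmp/debug && export LD_LIBRARY_PATH=/data/local/tmp/debug:\$LD_LIBRARY_PATH && ./clas_system ./config.txt ./my_image.jpg"
```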