未验证 提交 6d00e673 编写于 作者: C chenziheng 提交者: GitHub

PaddleClas-whl (#536)

* paddleclas whl

* fix whl

* Add whl.md

* fix whl

* fix paddleclas and whl.md

* fix paddleclas_whl

* paddleclas_whl

* paddleclas_whl

* fix paddleclas_whl

* fix paddleclas_whl

* fix paddleclas_whl

* fix paddleclas_whl
上级 21dc73f7
include LICENSE.txt
include README.md
recursive-include tools/infer utils.py predict.py
recursive-include ppcls/utils imagenet1k_label_list.txt
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['PaddleClas']
from .paddleclas import PaddleClas
# paddleclas package
## Get started quickly
### install package
install by pypi
```bash
pip install paddleclas==2.0.0rc1
```
build own whl package and install
```bash
python3 setup.py bdist_wheel
pip3 install dist/paddleclas-x.x.x-py3-none-any.whl
```
### 1. Quick Start
* Assign `image_file='docs/images/whl/demo.jpg'`, Use inference model that Paddle provides `model_name='ResNet50'`
**Here is demo.jpg**
![](../images/whl/demo.jpg)
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False,use_tensorrt=False)
image_file='docs/images/whl/demo.jpg'
result=clas.predict(image_file)
print(result)
```
```
>>> result
[{'filename': '/Users/mac/Downloads/PaddleClas/docs/images/whl/demo.jpg', 'class_ids': [8], 'scores': [0.9796774], 'label_names': ['hen']}]
```
* Using command line interactive programming
```bash
paddleclas --model_name='ResNet50' --image_file='docs/images/whl/demo.jpg'
```
```
>>> result
**********/Users/mac/Downloads/PaddleClas/docs/images/whl/demo.jpg**********
[{'filename': '/Users/mac/Downloads/PaddleClas/docs/images/whl/demo.jpg', 'class_ids': [8], 'scores': [0.9796774], 'label_names': ['hen']}]
```
### 2. Definition of Parameters
* model_name(str): model's name. If not assigning `model_file`and`params_file`, you can assign this param. If using inference model based on ImageNet1k provided by Paddle, set as default='ResNet50'.
* image_file(str): image's path. Support assigning single local image, internet image and folder containing series of images. Also Support numpy.ndarray.
* use_gpu(bool): Whether to use GPU or not, defalut=False。
* use_tensorrt(bool): whether to open tensorrt or not. Using it can greatly promote predict preformance, default=False.
* resize_short(int): resize the minima between height and width into resize_short(int), default=256
* resize(int): resize image into resize(int), default=224.
* normalize(bool): whether normalize image or not, default=True.
* batch_size(int): batch number, default=1.
* model_file(str): path of inference.pdmodel. If not assign this param,you need assign `model_name` for downloading.
* params_file(str): path of inference.pdiparams. If not assign this param,you need assign `model_name` for downloading.
* ir_optim(bool): whether enable IR optimization or not, default=True.
* gpu_mem(int): GPU memory usages,default=8000。
* enable_profile(bool): whether enable profile or not,default=False.
* top_k(int): Assign top_k, default=1.
* enable_mkldnn(bool): whether enable MKLDNN or not, default=False.
* cpu_num_threads(int): Assign number of cpu threads, default=10.
* label_name_path(str): Assign path of label_name_dict you use. If using your own training model, you can assign this param. If using inference model based on ImageNet1k provided by Paddle, you may not assign this param.Defaults take ImageNet1k's label name.
* pre_label_image(bool): whether prelabel or not, default=False.
* pre_label_out_idr(str): If prelabeling, the path of output.
### 3. Different Usages of Codes
**We provide two ways to use: 1. Python interative programming 2. Bash command line programming**
* check `help` information
```bash
paddleclas -h
```
* Use user-specified model, you need to assign model's path `model_file` and parameters's path`params_file`
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_file='user-specified model path',
params_file='parmas path', use_gpu=False, use_tensorrt=False)
image_file = ''
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_file='user-specified model path' --params_file='parmas path' --image_file='image path'
```
* Use inference model which PaddlePaddle provides to predict, you need to choose one of model when initializing PaddleClas to assign `model_name`. You may not assign `model_file` , and the model you chosen will be download in `BASE_INFERENCE_MODEL_DIR` ,which will be saved in folder named by `model_name`,avoiding overlay different inference model.
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False)
image_file = ''
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path'
```
* You can assign input as format`np.ndarray` which has been preprocessed `--image_file=np.ndarray`.
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False)
image_file =np.ndarray # image_file 可指定为前缀是https的网络图片,也可指定为本地图片
result=clas.predict(image_file)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file=np.ndarray
```
* You can assign `image_file` as a folder path containing series of images, also can assign `top_k`.
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False,top_k=5)
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path' --top_k=5
```
* You can assign `--pre_label_image=True`, `--pre_label_out_idr= './output_pre_label/'`.Then images will be copied into folder named by top-1 class_id.
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False,top_k=5, pre_label_image=True,pre_label_out_idr='./output_pre_label/')
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path' --top_k=5 --pre_label_image=True --pre_label_out_idr='./output_pre_label/'
```
* You can assign `--label_name_path` as your own label_dict_file, format should be as(class_id<space>class_name<\n>).
```
0 tench, Tinca tinca
1 goldfish, Carassius auratus
2 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
......
```
* If you use inference model that Paddle provides, you do not need assign `label_name_path`. Program will take `ppcls/utils/imagenet1k_label_list.txt` as defaults. If you hope using your own training model, you can provide `label_name_path` outputing 'label_name' and scores, otherwise no 'label_name' in output information.
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_file= './inference.pdmodel',params_file = './inference.pdiparams',label_name_path='./ppcls/utils/imagenet1k_label_list.txt',use_gpu=False)
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_file= './inference.pdmodel' --params_file = './inference.pdiparams' --image_file='image path' --label_name_path='./ppcls/utils/imagenet1k_label_list.txt'
```
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False)
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path'
```
# paddleclas package使用说明
## 快速上手
### 安装whl包
pip安装
```bash
pip install paddleclas=2.0.0rc1
```
本地构建并安装
```bash
python3 setup.py bdist_wheel
pip3 install dist/paddleclas-x.x.x-py3-none-any.whl # x.x.x是paddleclas的版本号
```
### 1. 快速开始
* 指定`image_file='docs/images/whl/demo.jpg'`,使用Paddle提供的inference model,`model_name='ResNet50'`, 使用图片`docs/images/whl/demo.jpg`
**下图是使用的demo图片**
![](../images/whl/demo.jpg)
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False,use_tensorrt=False)
image_file='docs/images/whl/demo.jpg'
result=clas.predict(image_file)
print(result)
```
```
>>> result
[{'filename': '/Users/mac/Downloads/PaddleClas/docs/images/whl/demo.jpg', 'class_ids': [8], 'scores': [0.9796774], 'label_names': ['hen']}]
```
* 使用命令行式交互方法。直接获得结果。
```bash
paddleclas --model_name='ResNet50' --image_file='docs/images/whl/demo.jpg'
```
```
>>> result
**********/Users/mac/Downloads/PaddleClas/docs/images/whl/demo.jpg**********
[{'filename': '/Users/mac/Downloads/PaddleClas/docs/images/whl/demo.jpg', 'class_ids': [8], 'scores': [0.9796774], 'label_names': ['hen']}]
```
### 2. 参数解释
* model_name(str): 模型名称,没有指定自定义的model_file和params_file时,可以指定该参数,使用PaddleClas提供的基于ImageNet1k的inference model,默认值为ResNet50。
* image_file(str): 图像地址,支持指定单一图像的路径或图像的网址进行预测,支持指定包含图像的文件夹路径,支持经过预处理的np.ndarray形式输入。
* use_gpu(bool): 是否使用GPU,如果使用,指定为True。默认为False。
* use_tensorrt(bool): 是否开启TensorRT预测,可提升GPU预测性能,需要使用带TensorRT的预测库。当使用TensorRT推理加速,指定为True。默认为False。
* resize_short(int): 将图像的高宽二者中小的值,调整到指定的resize_short值,大的值按比例放大。默认为256。
* resize(int): 将图像裁剪到指定的resize值大小,默认224。
* normalize(bool): 是否对图像数据归一化,默认True。
* batch_size(int): 预测时每个batch的样本数,默认为1。
* model_file(str): 模型.pdmodel的路径,若不指定该参数,需要指定model_name,获得下载的模型。
* params_file(str): 模型参数.pdiparams的路径,若不与model_file指定,则需要指定model_name,以获得下载的模型。
* ir_optim(bool): 是否开启IR优化,默认为True。
* gpu_mem(int): 使用的GPU显存大小,默认为8000。
* enable_profile(bool): 是否开启profile功能,默认False。
* top_k(int): 指定的topk,预测的前k个类别和对应的分类概率,默认为1。
* enable_mkldnn(bool): 是否开启MKLDNN,默认False。
* cpu_num_threads(int): 指定cpu线程数,默认设置为10。
* label_name_path(str): 指定一个表示所有的label name的文件路径。当用户使用自己训练的模型,可指定这一参数,打印结果时可以显示图像对应的类名称。若用户使用Paddle提供的inference model,则可不指定该参数,使用imagenet1k的label_name,默认为空字符串。
* pre_label_image(bool): 是否需要进行预标注。
* pre_label_out_idr(str): 进行预标注后,输出结果的文件路径,默认为None。
### 3. 代码使用方法
**提供两种使用方式:1、python交互式编程。2、bash命令行式编程**
* 查看帮助信息
###### bash
```bash
paddleclas -h
```
* 用户使用自己指定的模型,需要指定模型路径参数`model_file`和参数`params_file`
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_file='user-specified model path',
params_file='parmas path', use_gpu=False, use_tensorrt=False)
image_file = '' # image_file 可指定为前缀是https的网络图片,也可指定为本地图片
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_file='user-specified model path' --params_file='parmas path' --image_file='image path'
```
* 用户使用PaddlePaddle训练好的inference model来预测,用户需要使用,初始化打印的模型的其中一个,并指定给`model_name`
用户可以不指定`model_file`,模型会自动下载到当前目录,并保存在以`model_name`命名的文件夹中,避免下载不同模型的覆盖问题。
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False)
image_file = '' # image_file 可指定为前缀是https的网络图片,也可指定为本地图片
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path'
```
* 用户可以使用经过预处理的np.ndarray格式`--image_file=np.ndarray`
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False)
image_file =np.ndarray # image_file 可指定为前缀是https的网络图片,也可指定为本地图片
result=clas.predict(image_file)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file=np.ndarray
```
* 用户可以将`image_file`指定为包含图片的文件夹路径,可以指定`top_k`参数
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False,top_k=5)
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path' --top_k=5
```
* 用户可以指定`--pre_label_image=True`, `--pre_label_out_idr= './output_pre_label/'`,将图片复制到,以其top1对应的类别命名的文件夹中。
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False, use_tensorrt=False,top_k=5, pre_label_image=True,pre_label_out_idr='./output_pre_label/')
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path' --top_k=5 --pre_label_image=True --pre_label_out_idr='./output_pre_label/'
```
* 用户可以指定`--label_name_path`,作为用户自己训练模型的`label_dict_file`,格式应为(class_id<space>class_name<\n>)
```
0 tench, Tinca tinca
1 goldfish, Carassius auratus
2 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
......
```
* 用户如果使用Paddle提供的inference model,则不需要提供`label_name_path`,会默认使用`ppcls/utils/imagenet1k_label_list.txt`
如果用户希望使用自己的模型,则可以提供`label_name_path`,将label_name与结果一并输出。如果不提供将不会输出label_name信息。
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_file= './inference.pdmodel',params_file = './inference.pdiparams',label_name_path='./ppcls/utils/imagenet1k_label_list.txt',use_gpu=False)
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_file= './inference.pdmodel' --params_file = './inference.pdiparams' --image_file='image path' --label_name_path='./ppcls/utils/imagenet1k_label_list.txt'
```
###### python
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50',use_gpu=False)
image_file = '' # it can be image_file folder path which contains all of images you want to predict.
result=clas.predict(image_file)
print(result)
```
###### bash
```bash
paddleclas --model_name='ResNet50' --image_file='image path'
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import numpy as np
import tarfile
import requests
from tqdm import tqdm
import tools.infer.utils as utils
import shutil
__all__ = ['PaddleClas']
BASE_DIR = os.path.expanduser("~/.paddleclas/")
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, 'inference_model')
BASE_IMAGES_DIR = os.path.join(BASE_DIR, 'images')
model_names = {
'Xception71', 'SE_ResNeXt101_32x4d', 'ShuffleNetV2_x0_5', 'ResNet34',
'ShuffleNetV2_x2_0', 'ResNeXt101_32x4d', 'HRNet_W48_C_ssld',
'ResNeSt50_fast_1s1x64d', 'MobileNetV2_x2_0', 'MobileNetV3_large_x1_0',
'Fix_ResNeXt101_32x48d_wsl', 'MobileNetV2_ssld', 'ResNeXt101_vd_64x4d',
'ResNet34_vd_ssld', 'MobileNetV3_small_x1_0', 'VGG11',
'ResNeXt50_vd_32x4d', 'MobileNetV3_large_x1_25',
'MobileNetV3_large_x1_0_ssld', 'MobileNetV2_x0_75',
'MobileNetV3_small_x0_35', 'MobileNetV1_x0_75', 'MobileNetV1_ssld',
'ResNeXt50_32x4d', 'GhostNet_x1_3_ssld', 'Res2Net101_vd_26w_4s',
'ResNet152', 'Xception65', 'EfficientNetB0', 'ResNet152_vd', 'HRNet_W18_C',
'Res2Net50_14w_8s', 'ShuffleNetV2_x0_25', 'HRNet_W64_C',
'Res2Net50_vd_26w_4s_ssld', 'HRNet_W18_C_ssld', 'ResNet18_vd',
'ResNeXt101_32x16d_wsl', 'SE_ResNeXt50_32x4d', 'SqueezeNet1_1',
'SENet154_vd', 'SqueezeNet1_0', 'GhostNet_x1_0', 'ResNet50_vc', 'DPN98',
'HRNet_W48_C', 'DenseNet264', 'SE_ResNet34_vd', 'HRNet_W44_C',
'MobileNetV3_small_x1_25', 'MobileNetV1_x0_5', 'ResNet200_vd', 'VGG13',
'EfficientNetB3', 'EfficientNetB2', 'ShuffleNetV2_x0_33',
'MobileNetV3_small_x0_75', 'ResNeXt152_vd_32x4d', 'ResNeXt101_32x32d_wsl',
'ResNet18', 'MobileNetV3_large_x0_35', 'Res2Net50_26w_4s',
'MobileNetV2_x0_5', 'EfficientNetB0_small', 'ResNet101_vd_ssld',
'EfficientNetB6', 'EfficientNetB1', 'EfficientNetB7', 'ResNeSt50',
'ShuffleNetV2_x1_0', 'MobileNetV3_small_x1_0_ssld', 'InceptionV4',
'GhostNet_x0_5', 'SE_HRNet_W64_C_ssld', 'ResNet50_ACNet_deploy',
'Xception41', 'ResNet50', 'Res2Net200_vd_26w_4s_ssld',
'Xception41_deeplab', 'SE_ResNet18_vd', 'SE_ResNeXt50_vd_32x4d',
'HRNet_W30_C', 'HRNet_W40_C', 'VGG19', 'Res2Net200_vd_26w_4s',
'ResNeXt101_32x8d_wsl', 'ResNet50_vd', 'ResNeXt152_64x4d', 'DarkNet53',
'ResNet50_vd_ssld', 'ResNeXt101_64x4d', 'MobileNetV1_x0_25',
'Xception65_deeplab', 'AlexNet', 'ResNet101', 'DenseNet121',
'ResNet50_vd_v2', 'Res2Net50_vd_26w_4s', 'ResNeXt101_32x48d_wsl',
'MobileNetV3_large_x0_5', 'MobileNetV2_x0_25', 'DPN92', 'ResNet101_vd',
'MobileNetV2_x1_5', 'DPN131', 'ResNeXt50_vd_64x4d', 'ShuffleNetV2_x1_5',
'ResNet34_vd', 'MobileNetV1', 'ResNeXt152_vd_64x4d', 'DPN107', 'VGG16',
'ResNeXt50_64x4d', 'RegNetX_4GF', 'DenseNet161', 'GhostNet_x1_3',
'HRNet_W32_C', 'Fix_ResNet50_vd_ssld_v2', 'Res2Net101_vd_26w_4s_ssld',
'DenseNet201', 'DPN68', 'EfficientNetB4', 'ResNeXt152_32x4d',
'InceptionV3', 'ShuffleNetV2_swish', 'GoogLeNet', 'ResNet50_vd_ssld_v2',
'SE_ResNet50_vd', 'MobileNetV2', 'ResNeXt101_vd_32x4d',
'MobileNetV3_large_x0_75', 'MobileNetV3_small_x0_5', 'DenseNet169',
'EfficientNetB5'
}
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
raise Exception("Something went wrong while downloading models")
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def save_prelabel_results(class_id, input_filepath, output_idr):
output_dir = os.path.join(output_idr, str(class_id))
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
shutil.copy(input_filepath, output_dir)
def load_label_name_dict(path):
result = {}
if not os.path.exists(path):
print(
'Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!'
)
else:
for line in open(path, 'r'):
partition = line.split('\n')[0].partition(' ')
try:
result[int(partition[0])] = str(partition[-1])
except:
result = {}
break
return result
def parse_args(mMain=True, add_help=True):
import argparse
def str2bool(v):
return v.lower() in ("true", "t", "1")
if mMain == True:
# general params
parser = argparse.ArgumentParser(add_help=add_help)
parser.add_argument("--model_name", type=str)
parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("--use_gpu", type=str2bool, default=False)
# params for preprocess
parser.add_argument("--resize_short", type=int, default=256)
parser.add_argument("--resize", type=int, default=224)
parser.add_argument("--normalize", type=str2bool, default=True)
parser.add_argument("-b", "--batch_size", type=int, default=1)
# params for predict
parser.add_argument(
"--model_file", type=str, default='') ## inference.pdmodel
parser.add_argument(
"--params_file", type=str, default='') ## inference.pdiparams
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--enable_profile", type=str2bool, default=False)
parser.add_argument("--top_k", type=int, default=1)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--enable_benchmark", type=str2bool, default=False)
parser.add_argument("--cpu_num_threads", type=int, default=10)
parser.add_argument("--hubserving", type=str2bool, default=False)
# parameters for pre-label the images
parser.add_argument("--label_name_path", type=str, default='')
parser.add_argument(
"--pre_label_image",
type=str2bool,
default=False,
help="Whether to pre-label the images using the loaded weights")
parser.add_argument("--pre_label_out_idr", type=str, default=None)
return parser.parse_args()
else:
return argparse.Namespace(
model_name='',
image_file='',
use_gpu=False,
use_fp16=False,
use_tensorrt=False,
resize_short=256,
resize=224,
normalize=True,
batch_size=1,
model_file='',
params_file='',
ir_optim=True,
gpu_mem=8000,
enable_profile=False,
top_k=1,
enable_mkldnn=False,
enable_benchmark=False,
cpu_num_threads=10,
hubserving=False,
label_name_path='',
pre_label_image=False,
pre_label_out_idr=None)
class PaddleClas(object):
print('Inference models that Paddle provides are listed as follows:\n\n{}'.
format(model_names), '\n')
def __init__(self, **kwargs):
process_params = parse_args(mMain=False, add_help=False)
process_params.__dict__.update(**kwargs)
if not os.path.exists(process_params.model_file):
if process_params.model_name is None:
raise Exception(
'Please input model name that you want to use!')
if process_params.model_name in model_names:
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar'.format(
process_params.model_name)
if not os.path.exists(
os.path.join(BASE_INFERENCE_MODEL_DIR,
process_params.model_name)):
os.makedirs(
os.path.join(BASE_INFERENCE_MODEL_DIR,
process_params.model_name))
download_path = os.path.join(BASE_INFERENCE_MODEL_DIR,
process_params.model_name)
maybe_download(model_storage_directory=download_path, url=url)
process_params.model_file = os.path.join(download_path,
'inference.pdmodel')
process_params.params_file = os.path.join(
download_path, 'inference.pdiparams')
process_params.label_name_path = os.path.join(
__dir__, 'ppcls/utils/imagenet1k_label_list.txt')
else:
raise Exception(
'If you want to use your own model, Please input model_file as model path!'
)
else:
print('Using user-specified model and params!')
print("process params are as follows: \n{}".format(process_params))
self.label_name_dict = load_label_name_dict(
process_params.label_name_path)
self.args = process_params
self.predictor = utils.create_paddle_predictor(process_params)
def predict(self, img):
"""
predict label of img with paddleclas
Args:
img: input image for clas, support single image , internet url, folder path containing series of images
Returns:
dict:{image_name: "", class_id: [], scores: [], label_names: []},if label name path == None,label_names will be empty.
"""
assert isinstance(img, (str, np.ndarray))
input_names = self.predictor.get_input_names()
input_tensor = self.predictor.get_input_handle(input_names[0])
output_names = self.predictor.get_output_names()
output_tensor = self.predictor.get_output_handle(output_names[0])
if isinstance(img, str):
# download internet image
if img.startswith('http'):
if not os.path.exists(BASE_IMAGES_DIR):
os.makedirs(BASE_IMAGES_DIR)
image_path = os.path.join(BASE_IMAGES_DIR, 'tmp.jpg')
download_with_progressbar(img, image_path)
print("Current using image from Internet:{}, renamed as: {}".
format(img, image_path))
img = image_path
image_list = utils.get_image_list(img)
else:
if isinstance(img, np.ndarray):
image_list = [img]
else:
print('Please input legal image!')
total_result = []
for filename in image_list:
if isinstance(filename, str):
image = cv2.imread(filename)[:, :, ::-1]
assert image is not None, "Error in loading image: {}".format(
filename)
inputs = utils.preprocess(image, self.args)
inputs = np.expand_dims(
inputs, axis=0).repeat(
1, axis=0).copy()
else:
inputs = filename
input_tensor.copy_from_cpu(inputs)
self.predictor.run()
outputs = output_tensor.copy_to_cpu()
classes, scores = utils.postprocess(outputs, self.args)
label_names = []
if len(self.label_name_dict) != 0:
label_names = [self.label_name_dict[c] for c in classes]
result = {
"filename": filename if isinstance(filename, str) else 'image',
"class_ids": classes.tolist(),
"scores": scores.tolist(),
"label_names": label_names,
}
total_result.append(result)
if self.args.pre_label_image:
save_prelabel_results(classes[0], filename,
self.args.pre_label_out_idr)
print("\tSaving prelabel results in {}".format(
os.path.join(self.args.pre_label_out_idr, str(classes[
0]))))
return total_result
def main():
# for cmd
args = parse_args(mMain=True)
clas_engine = PaddleClas(**(args.__dict__))
print('{}{}{}'.format('*' * 10, args.image_file, '*' * 10))
result = clas_engine.predict(args.image_file)
if result is not None:
print(result)
if __name__ == '__main__':
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup
from io import open
with open('requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
def readme():
with open('docs/en/whl_en.md', encoding="utf-8-sig") as f:
README = f.read()
return README
setup(
name='paddleclas',
packages=['paddleclas'],
package_dir={'paddleclas': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddleclas= paddleclas.paddleclas:main"]},
version='0.0.0',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome Image Classification toolkits based on PaddlePaddle ',
long_description=readme(),
long_description_content_type='text/markdown',
url='https://github.com/PaddlePaddle/PaddleClas',
download_url='https://github.com/PaddlePaddle/PaddleClas.git',
keywords=[
'A treasure chest for image classification powered by PaddlePaddle.'
],
classifiers=[
'Intended Audience :: Developers', 'Operating System :: OS Independent',
'Natural Language :: Chinese (Simplified)',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
], )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册