未验证 提交 ae6fcc6a 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

updete resnet50_vd_animals (#2070)

上级 c97550a8
......@@ -168,10 +168,10 @@
初始发布
* 1.0.1
* 1.1.0
移除 fluid api
移除 Fluid API
- ```shell
$ hub install resnet50_vd_animals==1.0.1
$ hub install resnet50_vd_animals==1.1.0
```
......@@ -171,10 +171,10 @@
First release
* 1.0.1
* 1.1.0
Remove fluid api
Remove Fluid API
- ```shell
$ hub install resnet50_vd_animals==1.0.1
$ hub install resnet50_vd_animals==1.1.0
```
......@@ -3,7 +3,6 @@ import os
import time
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
......
......@@ -7,15 +7,12 @@ import ast
import os
import numpy as np
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from resnet50_vd_animals.data_feed import reader
from resnet50_vd_animals.processor import base64_to_cv2
from resnet50_vd_animals.processor import postprocess
import paddlehub as hub
from paddlehub.common.paddle_helper import add_vars_prefix
from .data_feed import reader
from .processor import base64_to_cv2
from .processor import postprocess
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
......@@ -28,10 +25,10 @@ from paddlehub.module.module import serving
author_email="",
summary="ResNet50vd is a image classfication model, this module is trained with Baidu's self-built animals dataset.",
version="1.0.1")
class ResNet50vdAnimals(hub.Module):
class ResNet50vdAnimals:
def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "model")
def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "model", "model")
label_file = os.path.join(self.directory, "label_list.txt")
with open(label_file, 'r', encoding='utf-8') as file:
self.label_list = file.read().split("\n")[:-1]
......@@ -65,7 +62,9 @@ class ResNet50vdAnimals(hub.Module):
"""
# create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path)
model = self.default_pretrained_model_path + '.pdmodel'
params = self.default_pretrained_model_path + '.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_predictor(cpu_config)
......@@ -76,7 +75,7 @@ class ResNet50vdAnimals(hub.Module):
npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
npu_config = Config(self.default_pretrained_model_path)
npu_config = Config(model, params)
npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
......@@ -85,7 +84,7 @@ class ResNet50vdAnimals(hub.Module):
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.default_pretrained_model_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config)
......@@ -94,7 +93,7 @@ class ResNet50vdAnimals(hub.Module):
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.default_pretrained_model_path)
xpu_config = Config(model, params)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
......@@ -165,24 +164,6 @@ class ResNet50vdAnimals(hub.Module):
res += out
return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = paddle.CPUPlace()
exe = paddle.Executor(place)
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
"""
......
......@@ -4,9 +4,8 @@ from __future__ import division
from __future__ import print_function
import base64
import cv2
import os
import cv2
import numpy as np
......@@ -18,7 +17,6 @@ def base64_to_cv2(b64str):
def softmax(x):
orig_shape = x.shape
if len(x.shape) > 1:
tmp = np.max(x, axis=1)
x -= tmp.reshape((x.shape[0], 1))
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/brFsZ7qszSY/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8OHx8ZG9nfGVufDB8fHx8MTY2MzA1ODQ1MQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="resnet50_vd_animals")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
def test_classification1(self):
results = self.module.classification(paths=['tests/test.jpg'])
data = results[0]
self.assertTrue('威尔士柯基' in data)
self.assertTrue(data['威尔士柯基'] > 0.5)
def test_classification2(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')])
data = results[0]
self.assertTrue('威尔士柯基' in data)
self.assertTrue(data['威尔士柯基'] > 0.5)
def test_classification3(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')], use_gpu=True)
data = results[0]
self.assertTrue('威尔士柯基' in data)
self.assertTrue(data['威尔士柯基'] > 0.5)
def test_classification4(self):
self.assertRaises(AssertionError, self.module.classification, paths=['no.jpg'])
def test_classification5(self):
self.assertRaises(TypeError, self.module.classification, images=['test.jpg'])
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册