未验证 提交 ae6fcc6a 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

updete resnet50_vd_animals (#2070)

上级 c97550a8
...@@ -168,10 +168,10 @@ ...@@ -168,10 +168,10 @@
初始发布 初始发布
* 1.0.1 * 1.1.0
移除 fluid api 移除 Fluid API
- ```shell - ```shell
$ hub install resnet50_vd_animals==1.0.1 $ hub install resnet50_vd_animals==1.1.0
``` ```
...@@ -171,10 +171,10 @@ ...@@ -171,10 +171,10 @@
First release First release
* 1.0.1 * 1.1.0
Remove fluid api Remove Fluid API
- ```shell - ```shell
$ hub install resnet50_vd_animals==1.0.1 $ hub install resnet50_vd_animals==1.1.0
``` ```
...@@ -3,7 +3,6 @@ import os ...@@ -3,7 +3,6 @@ import os
import time import time
from collections import OrderedDict from collections import OrderedDict
import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
......
...@@ -7,15 +7,12 @@ import ast ...@@ -7,15 +7,12 @@ import ast
import os import os
import numpy as np import numpy as np
import paddle
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from resnet50_vd_animals.data_feed import reader
from resnet50_vd_animals.processor import base64_to_cv2
from resnet50_vd_animals.processor import postprocess
import paddlehub as hub from .data_feed import reader
from paddlehub.common.paddle_helper import add_vars_prefix from .processor import base64_to_cv2
from .processor import postprocess
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable from paddlehub.module.module import runnable
from paddlehub.module.module import serving from paddlehub.module.module import serving
...@@ -28,10 +25,10 @@ from paddlehub.module.module import serving ...@@ -28,10 +25,10 @@ from paddlehub.module.module import serving
author_email="", author_email="",
summary="ResNet50vd is a image classfication model, this module is trained with Baidu's self-built animals dataset.", summary="ResNet50vd is a image classfication model, this module is trained with Baidu's self-built animals dataset.",
version="1.0.1") version="1.0.1")
class ResNet50vdAnimals(hub.Module): class ResNet50vdAnimals:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "model") self.default_pretrained_model_path = os.path.join(self.directory, "model", "model")
label_file = os.path.join(self.directory, "label_list.txt") label_file = os.path.join(self.directory, "label_list.txt")
with open(label_file, 'r', encoding='utf-8') as file: with open(label_file, 'r', encoding='utf-8') as file:
self.label_list = file.read().split("\n")[:-1] self.label_list = file.read().split("\n")[:-1]
...@@ -65,7 +62,9 @@ class ResNet50vdAnimals(hub.Module): ...@@ -65,7 +62,9 @@ class ResNet50vdAnimals(hub.Module):
""" """
# create default cpu predictor # create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path) model = self.default_pretrained_model_path + '.pdmodel'
params = self.default_pretrained_model_path + '.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
...@@ -76,7 +75,7 @@ class ResNet50vdAnimals(hub.Module): ...@@ -76,7 +75,7 @@ class ResNet50vdAnimals(hub.Module):
npu_id = self._get_device_id("FLAGS_selected_npus") npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1: if npu_id != -1:
# use npu # use npu
npu_config = Config(self.default_pretrained_model_path) npu_config = Config(model, params)
npu_config.disable_glog_info() npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id) npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config) self.npu_predictor = create_predictor(npu_config)
...@@ -85,7 +84,7 @@ class ResNet50vdAnimals(hub.Module): ...@@ -85,7 +84,7 @@ class ResNet50vdAnimals(hub.Module):
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES") gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1: if gpu_id != -1:
# use gpu # use gpu
gpu_config = Config(self.default_pretrained_model_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id) gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
...@@ -94,7 +93,7 @@ class ResNet50vdAnimals(hub.Module): ...@@ -94,7 +93,7 @@ class ResNet50vdAnimals(hub.Module):
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES") xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1: if xpu_id != -1:
# use xpu # use xpu
xpu_config = Config(self.default_pretrained_model_path) xpu_config = Config(model, params)
xpu_config.disable_glog_info() xpu_config.disable_glog_info()
xpu_config.enable_xpu(100) xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config) self.xpu_predictor = create_predictor(xpu_config)
...@@ -165,24 +164,6 @@ class ResNet50vdAnimals(hub.Module): ...@@ -165,24 +164,6 @@ class ResNet50vdAnimals(hub.Module):
res += out res += out
return res return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = paddle.CPUPlace()
exe = paddle.Executor(place)
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
...@@ -4,9 +4,8 @@ from __future__ import division ...@@ -4,9 +4,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import base64 import base64
import cv2
import os
import cv2
import numpy as np import numpy as np
...@@ -18,7 +17,6 @@ def base64_to_cv2(b64str): ...@@ -18,7 +17,6 @@ def base64_to_cv2(b64str):
def softmax(x): def softmax(x):
orig_shape = x.shape
if len(x.shape) > 1: if len(x.shape) > 1:
tmp = np.max(x, axis=1) tmp = np.max(x, axis=1)
x -= tmp.reshape((x.shape[0], 1)) x -= tmp.reshape((x.shape[0], 1))
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/brFsZ7qszSY/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8OHx8ZG9nfGVufDB8fHx8MTY2MzA1ODQ1MQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="resnet50_vd_animals")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
def test_classification1(self):
results = self.module.classification(paths=['tests/test.jpg'])
data = results[0]
self.assertTrue('威尔士柯基' in data)
self.assertTrue(data['威尔士柯基'] > 0.5)
def test_classification2(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')])
data = results[0]
self.assertTrue('威尔士柯基' in data)
self.assertTrue(data['威尔士柯基'] > 0.5)
def test_classification3(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')], use_gpu=True)
data = results[0]
self.assertTrue('威尔士柯基' in data)
self.assertTrue(data['威尔士柯基'] > 0.5)
def test_classification4(self):
self.assertRaises(AssertionError, self.module.classification, paths=['no.jpg'])
def test_classification5(self):
self.assertRaises(TypeError, self.module.classification, images=['test.jpg'])
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册