未验证 提交 98145fb6 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update efficientnetb5_imagenet (#2053)

上级 ee6f74b1
......@@ -132,6 +132,11 @@
* 1.1.0
提升预测性能以及易用性
* 1.2.0
移除 Fluid API
- ```shell
$ hub install efficientnetb5_imagenet==1.1.0
$ hub install efficientnetb5_imagenet==1.2.0
```
......@@ -131,6 +131,11 @@
* 1.1.0
Improve the prediction performance and users' experience
* 1.2.0
Remove Fluid API
- ```shell
$ hub install efficientnetb5_imagenet==1.1.0
$ hub install efficientnetb5_imagenet==1.2.0
```
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from collections import OrderedDict
import numpy as np
from PIL import Image
__all__ = ['reader']
DATA_DIM = 224
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def process_image(img):
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
return img
def reader(images=None, paths=None):
"""
Preprocess to yield image.
Args:
images (list[numpy.ndarray]): images data, shape of each is [H, W, C].
paths (list[str]): paths to images.
Yield:
each (collections.OrderedDict): info of original image, preprocessed image.
"""
component = list()
if paths:
for im_path in paths:
each = OrderedDict()
assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
each['org_im_path'] = im_path
each['org_im'] = Image.open(im_path)
each['org_im_width'], each['org_im_height'] = each['org_im'].size
component.append(each)
if images is not None:
assert type(images), "images is a list."
for im in images:
each = OrderedDict()
each['org_im'] = Image.fromarray(im[:, :, ::-1])
each['org_im_path'] = 'ndarray_time={}'.format(round(time.time(), 6) * 1e6)
each['org_im_width'], each['org_im_height'] = each['org_im'].size
component.append(each)
for element in component:
element['image'] = process_image(element['org_im'])
yield element
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import cv2
import numpy as np
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def softmax(x):
if len(x.shape) > 1:
tmp = np.max(x, axis=1)
x -= tmp.reshape((x.shape[0], 1))
x = np.exp(x)
tmp = np.sum(x, axis=1)
x /= tmp.reshape((x.shape[0], 1))
else:
tmp = np.max(x)
x -= tmp
x = np.exp(x)
tmp = np.sum(x)
x /= tmp
return x
def postprocess(data_out, label_list, top_k):
"""
Postprocess output of network, one image at a time.
Args:
data_out (numpy.ndarray): output data of network.
label_list (list): list of label.
top_k (int): Return top k results.
"""
output = []
for result in data_out:
result_i = softmax(result)
output_i = {}
indexs = np.argsort(result_i)[::-1][0:top_k]
for index in indexs:
label = label_list[index].split(',')[0]
output_i[label] = float(result_i[index])
output.append(output_i)
return output
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/brFsZ7qszSY/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8OHx8ZG9nfGVufDB8fHx8MTY2MzA1ODQ1MQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="efficientnetb5_imagenet")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
def test_classification1(self):
results = self.module.classification(paths=['tests/test.jpg'])
data = results[0]
self.assertTrue('Pembroke' in data)
self.assertTrue(data['Pembroke'] > 0.5)
def test_classification2(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')])
data = results[0]
self.assertTrue('Pembroke' in data)
self.assertTrue(data['Pembroke'] > 0.5)
def test_classification3(self):
results = self.module.classification(images=[cv2.imread('tests/test.jpg')], use_gpu=True)
data = results[0]
self.assertTrue('Pembroke' in data)
self.assertTrue(data['Pembroke'] > 0.5)
def test_classification4(self):
self.assertRaises(AssertionError, self.module.classification, paths=['no.jpg'])
def test_classification5(self):
self.assertRaises(TypeError, self.module.classification, images=['tests/test.jpg'])
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册