diff --git a/modules/image/classification/resnet50_vd_wildanimals/README.md b/modules/image/classification/resnet50_vd_wildanimals/README.md
index d857c89b70156dda3891da0994336ca7d5f801fc..755d179b277b79a301047dc2757ae2c6335d2e1e 100644
--- a/modules/image/classification/resnet50_vd_wildanimals/README.md
+++ b/modules/image/classification/resnet50_vd_wildanimals/README.md
@@ -129,6 +129,11 @@
 * 1.0.0
 
   初始发布
+
+* 1.1.0
+
+  移除 Fluid API
+
 - ```shell
-  $ hub install resnet50_vd_wildanimals==1.0.0
+  $ hub install resnet50_vd_wildanimals==1.1.0
   ```
diff --git a/modules/image/classification/resnet50_vd_wildanimals/README_en.md b/modules/image/classification/resnet50_vd_wildanimals/README_en.md
index 9a526d581511fd56250d9b4d4fe490349b367c2a..97294b0d493275079452d37a398ee8d86d111952 100644
--- a/modules/image/classification/resnet50_vd_wildanimals/README_en.md
+++ b/modules/image/classification/resnet50_vd_wildanimals/README_en.md
@@ -129,6 +129,11 @@
 * 1.0.0
 
   First release
+
+* 1.1.0
+
+  Remove Fluid API
+
 - ```shell
-  $ hub install resnet50_vd_wildanimals==1.0.0
+  $ hub install resnet50_vd_wildanimals==1.1.0
   ```
diff --git a/modules/image/classification/resnet50_vd_wildanimals/data_feed.py b/modules/image/classification/resnet50_vd_wildanimals/data_feed.py
index 99a0855fd6a93dbecd081cef312a04a350cfcc50..95ba8337365e0bdf18769864d625b86fec66c17a 100644
--- a/modules/image/classification/resnet50_vd_wildanimals/data_feed.py
+++ b/modules/image/classification/resnet50_vd_wildanimals/data_feed.py
@@ -1,9 +1,7 @@
-# coding=utf-8
 import os
 import time
 from collections import OrderedDict
 
-import cv2
 import numpy as np
 from PIL import Image
 
diff --git a/modules/image/classification/resnet50_vd_wildanimals/module.py b/modules/image/classification/resnet50_vd_wildanimals/module.py
index e3ab6e73b35da2c8ca6d955fd8a864a284018434..d3003772bc9c95b984a9a1db34dceb4abb85a582 100644
--- a/modules/image/classification/resnet50_vd_wildanimals/module.py
+++ b/modules/image/classification/resnet50_vd_wildanimals/module.py
@@ -2,20 +2,20 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import ast
 import argparse
+import ast
 import os
 
 import numpy as np
-import paddle.fluid as fluid
-import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
-from paddlehub.module.module import moduleinfo, runnable, serving
-from paddlehub.common.paddle_helper import add_vars_prefix
+from paddle.inference import Config
+from paddle.inference import create_predictor
 
-from resnet50_vd_wildanimals.processor import postprocess, base64_to_cv2
-from resnet50_vd_wildanimals.data_feed import reader
-from resnet50_vd_wildanimals.resnet_vd import ResNet50_vd
+from .data_feed import reader
+from .processor import base64_to_cv2
+from .processor import postprocess
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.module import runnable
+from paddlehub.module.module import serving
 
 
 @moduleinfo(
@@ -25,10 +25,11 @@ from resnet50_vd_wildanimals.resnet_vd import ResNet50_vd
     author_email="",
     summary=
     "ResNet50vd is a image classfication model, this module is trained with IFAW's self-built wild animals dataset.",
-    version="1.0.0")
-class ResNet50vdWildAnimals(hub.Module):
-    def _initialize(self):
-        self.default_pretrained_model_path = os.path.join(self.directory, "model")
+    version="1.1.0")
+class ResNet50vdWildAnimals:
+
+    def __init__(self):
+        self.default_pretrained_model_path = os.path.join(self.directory, "model", "model")
         label_file = os.path.join(self.directory, "label_list.txt")
         with open(label_file, 'r', encoding='utf-8') as file:
             self.label_list = file.read().split("\n")[:-1]
@@ -52,10 +53,12 @@ class ResNet50vdWildAnimals(hub.Module):
         """
         predictor config setting.
         """
-        cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        model = self.default_pretrained_model_path + '.pdmodel'
+        params = self.default_pretrained_model_path + '.pdiparams'
+        cpu_config = Config(model, params)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)
 
         try:
             _places = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -64,58 +67,10 @@ class ResNet50vdWildAnimals(hub.Module):
         except:
             use_gpu = False
         if use_gpu:
-            gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+            gpu_config = Config(model, params)
             gpu_config.disable_glog_info()
             gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
-
-    def context(self, trainable=True, pretrained=True):
-        """context for transfer learning.
-
-        Args:
-            trainable (bool): Set parameters in program to be trainable.
-            pretrained (bool) : Whether to load pretrained model.
-
-        Returns:
-            inputs (dict): key is 'image', corresponding vaule is image tensor.
-            outputs (dict): key is :
-                'classification', corresponding value is the result of classification.
-                'feature_map', corresponding value is the result of the layer before the fully connected layer.
-            context_prog (fluid.Program): program for transfer learning.
-        """
-        context_prog = fluid.Program()
-        startup_prog = fluid.Program()
-        with fluid.program_guard(context_prog, startup_prog):
-            with fluid.unique_name.guard():
-                image = fluid.layers.data(name="image", shape=[3, 224, 224], dtype="float32")
-                resnet_vd = ResNet50_vd()
-                output, feature_map = resnet_vd.net(input=image, class_dim=len(self.label_list))
-
-                name_prefix = '@HUB_{}@'.format(self.name)
-                inputs = {'image': name_prefix + image.name}
-                outputs = {'classification': name_prefix + output.name, 'feature_map': name_prefix + feature_map.name}
-                add_vars_prefix(context_prog, name_prefix)
-                add_vars_prefix(startup_prog, name_prefix)
-                global_vars = context_prog.global_block().vars
-                inputs = {key: global_vars[value] for key, value in inputs.items()}
-                outputs = {key: global_vars[value] for key, value in outputs.items()}
-
-                place = fluid.CPUPlace()
-                exe = fluid.Executor(place)
-                # pretrained
-                if pretrained:
-
-                    def _if_exist(var):
-                        b = os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
-                        return b
-
-                    fluid.io.load_vars(exe, self.default_pretrained_model_path, context_prog, predicate=_if_exist)
-                else:
-                    exe.run(startup_prog)
-                # trainable
-                for param in context_prog.global_block().iter_parameters():
-                    param.trainable = trainable
-        return inputs, outputs, context_prog
+            self.gpu_predictor = create_predictor(gpu_config)
 
     def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1):
         """
@@ -131,15 +86,6 @@ class ResNet50vdWildAnimals(hub.Module):
         Returns:
             res (list[dict]): The classfication results.
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
-
         all_data = list()
         for yield_data in reader(images, paths):
             all_data.append(yield_data)
@@ -158,32 +104,19 @@ class ResNet50vdWildAnimals(hub.Module):
                 pass
             # feed batch image
             batch_image = np.array([data['image'] for data in batch_data])
-            batch_image = PaddleTensor(batch_image.copy())
-            predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run(
-                [batch_image])
-            out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k)
+
+            predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+            input_handle.copy_from_cpu(batch_image.copy())
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+
+            out = postprocess(data_out=output_handle.copy_to_cpu(), label_list=self.label_list, top_k=top_k)
             res += out
         return res
 
-    def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
-        if combined:
-            model_filename = "__model__" if not model_filename else model_filename
-            params_filename = "__params__" if not params_filename else params_filename
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-
-        program, feeded_var_names, target_vars = fluid.io.load_inference_model(
-            dirname=self.default_pretrained_model_path, executor=exe)
-
-        fluid.io.save_inference_model(
-            dirname=dirname,
-            main_program=program,
-            executor=exe,
-            feeded_var_names=feeded_var_names,
-            target_vars=target_vars,
-            model_filename=model_filename,
-            params_filename=params_filename)
-
     @serving
     def serving_method(self, images, **kwargs):
         """
@@ -198,11 +131,10 @@ class ResNet50vdWildAnimals(hub.Module):
         """
         Run as a command.
         """
-        self.parser = argparse.ArgumentParser(
-            description="Run the {} module.".format(self.name),
-            prog='hub run {}'.format(self.name),
-            usage='%(prog)s',
-            add_help=True)
+        self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
+                                              prog='hub run {}'.format(self.name),
+                                              usage='%(prog)s',
+                                              add_help=True)
         self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
         self.arg_config_group = self.parser.add_argument_group(
             title="Config options", description="Run configuration for controlling module behavior, not required.")
@@ -216,8 +148,10 @@ class ResNet50vdWildAnimals(hub.Module):
         """
         Add the command config options.
         """
-        self.arg_config_group.add_argument(
-            '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not.")
+        self.arg_config_group.add_argument('--use_gpu',
+                                           type=ast.literal_eval,
+                                           default=False,
+                                           help="whether use GPU or not.")
         self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
         self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
diff --git a/modules/image/classification/resnet50_vd_wildanimals/processor.py b/modules/image/classification/resnet50_vd_wildanimals/processor.py
index 6dc49772fceaad183e668cf1b5d170ffbc2086d4..0a352fd81fc3e1c28eed18281906ac9a9c8241b3 100644
--- a/modules/image/classification/resnet50_vd_wildanimals/processor.py
+++ b/modules/image/classification/resnet50_vd_wildanimals/processor.py
@@ -4,9 +4,8 @@ from __future__ import division
 from __future__ import print_function
 
 import base64
-import cv2
-import os
 
+import cv2
 import numpy as np
 
 
@@ -18,7 +17,6 @@ def base64_to_cv2(b64str):
 
 
 def softmax(x):
-    orig_shape = x.shape
     if len(x.shape) > 1:
         tmp = np.max(x, axis=1)
         x -= tmp.reshape((x.shape[0], 1))
diff --git a/modules/image/classification/resnet50_vd_wildanimals/resnet_vd.py b/modules/image/classification/resnet50_vd_wildanimals/resnet_vd.py
deleted file mode 100755
index 3d9a91ca7e8e40cdb54f5bcf9f9f522251e6be86..0000000000000000000000000000000000000000
--- a/modules/image/classification/resnet50_vd_wildanimals/resnet_vd.py
+++ /dev/null
@@ -1,185 +0,0 @@
-#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-
-import paddle
-import paddle.fluid as fluid
-from paddle.fluid.param_attr import ParamAttr
-
-__all__ = ["ResNet", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd", "ResNet200_vd"]
-
-train_parameters = {
-    "input_size": [3, 224, 224],
-    "input_mean": [0.485, 0.456, 0.406],
-    "input_std": [0.229, 0.224, 0.225],
-    "learning_strategy": {
-        "name": "piecewise_decay",
-        "batch_size": 256,
-        "epochs": [30, 60, 90],
-        "steps": [0.1, 0.01, 0.001, 0.0001]
-    }
-}
-
-
-class ResNet():
-    def __init__(self, layers=50, is_3x3=False):
-        self.params = train_parameters
-        self.layers = layers
-        self.is_3x3 = is_3x3
-
-    def net(self, input, class_dim=1000):
-        is_3x3 = self.is_3x3
-        layers = self.layers
-        supported_layers = [50, 101, 152, 200]
-        assert layers in supported_layers, \
-            "supported layers are {} but input layer is {}".format(supported_layers, layers)
-
-        if layers == 50:
-            depth = [3, 4, 6, 3]
-        elif layers == 101:
-            depth = [3, 4, 23, 3]
-        elif layers == 152:
-            depth = [3, 8, 36, 3]
-        elif layers == 200:
-            depth = [3, 12, 48, 3]
-        num_filters = [64, 128, 256, 512]
-        if is_3x3 == False:
-            conv = self.conv_bn_layer(input=input, num_filters=64, filter_size=7, stride=2, act='relu')
-        else:
-            conv = self.conv_bn_layer(input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1')
-            conv = self.conv_bn_layer(input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2')
-            conv = self.conv_bn_layer(input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3')
-
-        conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
-
-        for block in range(len(depth)):
-            for i in range(depth[block]):
-                if layers in [101, 152, 200] and block == 2:
-                    if i == 0:
-                        conv_name = "res" + str(block + 2) + "a"
-                    else:
-                        conv_name = "res" + str(block + 2) + "b" + str(i)
-                else:
-                    conv_name = "res" + str(block + 2) + chr(97 + i)
-                conv = self.bottleneck_block(
-                    input=conv,
-                    num_filters=num_filters[block],
-                    stride=2 if i == 0 and block != 0 else 1,
-                    if_first=block == 0,
-                    name=conv_name)
-
-        pool = fluid.layers.pool2d(input=conv, pool_size=7, pool_type='avg', global_pooling=True)
-        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
-
-        out = fluid.layers.fc(
-            input=pool,
-            size=class_dim,
-            param_attr=fluid.param_attr.ParamAttr(initializer=fluid.initializer.Uniform(-stdv, stdv)))
-
-        return out, pool
-
-    def conv_bn_layer(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
-        conv = fluid.layers.conv2d(
-            input=input,
-            num_filters=num_filters,
-            filter_size=filter_size,
-            stride=stride,
-            padding=(filter_size - 1) // 2,
-            groups=groups,
-            act=None,
-            param_attr=ParamAttr(name=name + "_weights"),
-            bias_attr=False)
-        if name == "conv1":
-            bn_name = "bn_" + name
-        else:
-            bn_name = "bn" + name[3:]
-        return fluid.layers.batch_norm(
-            input=conv,
-            act=act,
-            param_attr=ParamAttr(name=bn_name + '_scale'),
-            bias_attr=ParamAttr(bn_name + '_offset'),
-            moving_mean_name=bn_name + '_mean',
-            moving_variance_name=bn_name + '_variance')
-
-    def conv_bn_layer_new(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None):
-        pool = fluid.layers.pool2d(input=input, pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg')
-
-        conv = fluid.layers.conv2d(
-            input=pool,
-            num_filters=num_filters,
-            filter_size=filter_size,
-            stride=1,
-            padding=(filter_size - 1) // 2,
-            groups=groups,
-            act=None,
-            param_attr=ParamAttr(name=name + "_weights"),
-            bias_attr=False)
-        if name == "conv1":
-            bn_name = "bn_" + name
-        else:
-            bn_name = "bn" + name[3:]
-        return fluid.layers.batch_norm(
-            input=conv,
-            act=act,
-            param_attr=ParamAttr(name=bn_name + '_scale'),
-            bias_attr=ParamAttr(bn_name + '_offset'),
-            moving_mean_name=bn_name + '_mean',
-            moving_variance_name=bn_name + '_variance')
-
-    def shortcut(self, input, ch_out, stride, name, if_first=False):
-        ch_in = input.shape[1]
-        if ch_in != ch_out or stride != 1:
-            if if_first:
-                return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
-            else:
-                return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)
-        else:
-            return input
-
-    def bottleneck_block(self, input, num_filters, stride, name, if_first):
-        conv0 = self.conv_bn_layer(
-            input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
-        conv1 = self.conv_bn_layer(
-            input=conv0, num_filters=num_filters, filter_size=3, stride=stride, act='relu', name=name + "_branch2b")
-        conv2 = self.conv_bn_layer(
-            input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
-
-        short = self.shortcut(input, num_filters * 4, stride, if_first=if_first, name=name + "_branch1")
-
-        return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
-
-
-def ResNet50_vd():
-    model = ResNet(layers=50, is_3x3=True)
-    return model
-
-
-def ResNet101_vd():
-    model = ResNet(layers=101, is_3x3=True)
-    return model
-
-
-def ResNet152_vd():
-    model = ResNet(layers=152, is_3x3=True)
-    return model
-
-
-def ResNet200_vd():
-    model = ResNet(layers=200, is_3x3=True)
-    return model
diff --git a/modules/image/classification/resnet50_vd_wildanimals/test.py b/modules/image/classification/resnet50_vd_wildanimals/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f4e27f318ce2fb40485ad06c2356f08dac289a
--- /dev/null
+++ b/modules/image/classification/resnet50_vd_wildanimals/test.py
@@ -0,0 +1,63 @@
+import os
+import shutil
+import unittest
+
+import cv2
+import requests
+
+import paddlehub as hub
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+
+class TestHubModule(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        img_url = 'https://unsplash.com/photos/J33o16cP0SA/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8Mnx8aXZvcnl8ZW58MHx8fHwxNjY1NTUwNjk4&force=true&w=640'
+        if not os.path.exists('tests'):
+            os.makedirs('tests')
+        response = requests.get(img_url)
+        assert response.status_code == 200, 'Network Error.'
+        with open('tests/test.jpg', 'wb') as f:
+            f.write(response.content)
+        cls.module = hub.Module(name="resnet50_vd_wildanimals")
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        shutil.rmtree('tests')
+        shutil.rmtree('inference')
+
+    def test_classification1(self):
+        results = self.module.classification(paths=['tests/test.jpg'])
+        data = results[0]
+        self.assertTrue('象牙' in data)
+        self.assertTrue(data['象牙'] > 0.2)
+
+    def test_classification2(self):
+        results = self.module.classification(images=[cv2.imread('tests/test.jpg')])
+        data = results[0]
+        self.assertTrue('象牙' in data)
+        self.assertTrue(data['象牙'] > 0.2)
+
+    def test_classification3(self):
+        results = self.module.classification(images=[cv2.imread('tests/test.jpg')], use_gpu=True)
+        data = results[0]
+        self.assertTrue('象牙' in data)
+        self.assertTrue(data['象牙'] > 0.2)
+
+    def test_classification4(self):
+        self.assertRaises(AssertionError, self.module.classification, paths=['no.jpg'])
+
+    def test_classification5(self):
+        self.assertRaises(TypeError, self.module.classification, images=['tests/test.jpg'])
+
+    def test_save_inference_model(self):
+        self.module.save_inference_model('./inference/model')
+
+        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
+        self.assertTrue(os.path.exists('./inference/model.pdiparams'))
+
+
+if __name__ == "__main__":
+    unittest.main()
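Note for reviewers: the module.py hunks above replace the legacy `paddle.fluid` predictor (`AnalysisConfig`, `create_paddle_predictor`, `PaddleTensor`) with the `paddle.inference` API. Below is a minimal sketch of that predictor flow outside the module, assuming a combined inference model exported as `model.pdmodel` / `model.pdiparams` (for example, the files written by `save_inference_model('./inference/model')` in test.py); the model path and the random input are assumptions for illustration, not part of this patch.

```python
import numpy as np
from paddle.inference import Config, create_predictor

# Assumed model location; adjust to wherever the exported files actually live.
config = Config("inference/model.pdmodel", "inference/model.pdiparams")
config.disable_glog_info()
config.disable_gpu()  # CPU-only, as in the module's CPU branch; enable_use_gpu(1000, 0) would select the GPU path
predictor = create_predictor(config)

# One normalized 224x224 RGB image in NCHW layout, as data_feed.reader produces.
batch_image = np.random.rand(1, 3, 224, 224).astype("float32")

input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(batch_image)
predictor.run()

output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
probs = output_handle.copy_to_cpu()  # shape (1, num_classes); postprocess() maps this to labels
```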