未验证 提交 3e471d8e 编写于 作者: H houj04 提交者: GitHub

resnet50_vd_animals support npu and xpu (#1579)

上级 5d6cab99
...@@ -9,7 +9,9 @@ import os ...@@ -9,7 +9,9 @@ import os
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddlehub as hub import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix from paddlehub.common.paddle_helper import add_vars_prefix
...@@ -47,26 +49,53 @@ class ResNet50vdAnimals(hub.Module): ...@@ -47,26 +49,53 @@ class ResNet50vdAnimals(hub.Module):
im_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3) im_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3)
return im_std return im_std
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self): def _set_config(self):
""" """
predictor config setting predictor config setting
""" """
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
# create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: # create predictors using various types of devices
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0]) # npu
use_gpu = True npu_id = self._get_device_id("FLAGS_selected_npus")
except: if npu_id != -1:
use_gpu = False # use npu
if use_gpu: npu_config = Config(self.default_pretrained_model_path)
gpu_config = AnalysisConfig(self.default_pretrained_model_path) npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
# gpu
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.default_pretrained_model_path)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
# xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.default_pretrained_model_path)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
def context(self, trainable=True, pretrained=True): def context(self, trainable=True, pretrained=True):
"""context for transfer learning. """context for transfer learning.
...@@ -116,7 +145,7 @@ class ResNet50vdAnimals(hub.Module): ...@@ -116,7 +145,7 @@ class ResNet50vdAnimals(hub.Module):
param.trainable = trainable param.trainable = trainable
return inputs, outputs, context_prog return inputs, outputs, context_prog
def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1): def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1, use_device=None):
""" """
API for image classification. API for image classification.
...@@ -126,18 +155,30 @@ class ResNet50vdAnimals(hub.Module): ...@@ -126,18 +155,30 @@ class ResNet50vdAnimals(hub.Module):
batch_size (int): batch size. batch_size (int): batch size.
use_gpu (bool): Whether to use gpu. use_gpu (bool): Whether to use gpu.
top_k (int): Return top k results. top_k (int): Return top k results.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
res (list[dict]): The classfication results. res (list[dict]): The classfication results.
""" """
if use_gpu:
try: # real predictor to use
_places = os.environ["CUDA_VISIBLE_DEVICES"] if use_device is not None:
int(_places[0]) if use_device == "cpu":
except: predictor = self.cpu_predictor
raise RuntimeError( elif use_device == "xpu":
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id." predictor = self.xpu_predictor
) elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
all_data = list() all_data = list()
for yield_data in reader(images, paths): for yield_data in reader(images, paths):
...@@ -157,10 +198,16 @@ class ResNet50vdAnimals(hub.Module): ...@@ -157,10 +198,16 @@ class ResNet50vdAnimals(hub.Module):
pass pass
# feed batch image # feed batch image
batch_image = np.array([data['image'] for data in batch_data]) batch_image = np.array([data['image'] for data in batch_data])
batch_image = PaddleTensor(batch_image.copy())
predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run( input_names = predictor.get_input_names()
[batch_image]) input_tensor = predictor.get_input_handle(input_names[0])
out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k) input_tensor.reshape(batch_image.shape)
input_tensor.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
predictor_output = output_handle.copy_to_cpu()
out = postprocess(data_out=predictor_output, label_list=self.label_list, top_k=top_k)
res += out res += out
return res return res
...@@ -174,14 +221,13 @@ class ResNet50vdAnimals(hub.Module): ...@@ -174,14 +221,13 @@ class ResNet50vdAnimals(hub.Module):
program, feeded_var_names, target_vars = fluid.io.load_inference_model( program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe) dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model( fluid.io.save_inference_model(dirname=dirname,
dirname=dirname, main_program=program,
main_program=program, executor=exe,
executor=exe, feeded_var_names=feeded_var_names,
feeded_var_names=feeded_var_names, target_vars=target_vars,
target_vars=target_vars, model_filename=model_filename,
model_filename=model_filename, params_filename=params_filename)
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
...@@ -197,28 +243,36 @@ class ResNet50vdAnimals(hub.Module): ...@@ -197,28 +243,36 @@ class ResNet50vdAnimals(hub.Module):
""" """
Run as a command. Run as a command.
""" """
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
description="Run the {} module.".format(self.name), prog='hub run {}'.format(self.name),
prog='hub run {}'.format(self.name), usage='%(prog)s',
usage='%(prog)s', add_help=True)
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.") title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg() self.add_module_config_arg()
self.add_module_input_arg() self.add_module_input_arg()
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.classification(paths=[args.input_path], batch_size=args.batch_size, use_gpu=args.use_gpu) results = self.classification(paths=[args.input_path],
batch_size=args.batch_size,
use_gpu=args.use_gpu,
top_k=args.top_k,
use_device=args.use_device)
return results return results
def add_module_config_arg(self): def add_module_config_arg(self):
""" """
Add the command config options. Add the command config options.
""" """
self.arg_config_group.add_argument( self.arg_config_group.add_argument('--use_gpu',
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not.") type=ast.literal_eval,
default=False,
help="whether use GPU or not.")
self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.") self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.") self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
self.arg_config_group.add_argument('--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册