未验证 提交 3e471d8e 编写于 作者: H houj04 提交者: GitHub

resnet50_vd_animals support npu and xpu (#1579)

上级 5d6cab99
......@@ -9,7 +9,9 @@ import os
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
......@@ -47,26 +49,53 @@ class ResNet50vdAnimals(hub.Module):
im_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3)
return im_std
def _get_device_id(self, places):
    """
    Read a device id from an environment variable.

    Args:
        places (str): Name of the environment variable to read,
            e.g. "CUDA_VISIBLE_DEVICES".

    Returns:
        int: The first device id listed in the variable, or -1 when the
            variable is not set or its first entry is not an integer.
    """
    try:
        visible_devices = os.environ[places]
        # The variable may hold a comma-separated list such as "0,1";
        # only the first entry is used as the device id.
        return int(visible_devices.split(",")[0])
    except (KeyError, TypeError, ValueError):
        # KeyError: variable not set; ValueError: entry is not an integer;
        # TypeError: `places` is not a valid environment variable name.
        return -1
def _set_config(self):
    """
    predictor config setting

    Always creates `self.cpu_predictor`. In addition, creates
    `self.npu_predictor`, `self.gpu_predictor` and `self.xpu_predictor`
    for each device whose environment variable (FLAGS_selected_npus,
    CUDA_VISIBLE_DEVICES, XPU_VISIBLE_DEVICES) yields a device id other
    than -1 from `self._get_device_id`. Attributes for devices that are
    not configured are left unset.
    """
    model_dir = self.default_pretrained_model_path

    # create default cpu predictor
    cpu_config = Config(model_dir)
    cpu_config.disable_glog_info()
    cpu_config.disable_gpu()
    self.cpu_predictor = create_predictor(cpu_config)

    # create predictors using various types of devices

    # npu
    npu_id = self._get_device_id("FLAGS_selected_npus")
    if npu_id != -1:
        npu_config = Config(model_dir)
        npu_config.disable_glog_info()
        npu_config.enable_npu(device_id=npu_id)
        self.npu_predictor = create_predictor(npu_config)

    # gpu
    gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
    if gpu_id != -1:
        gpu_config = Config(model_dir)
        gpu_config.disable_glog_info()
        # NOTE(review): gpu_id is the value read from CUDA_VISIBLE_DEVICES;
        # confirm whether the runtime expects an index relative to the
        # visible devices instead.
        gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
        self.gpu_predictor = create_predictor(gpu_config)

    # xpu
    xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
    if xpu_id != -1:
        xpu_config = Config(model_dir)
        xpu_config.disable_glog_info()
        # NOTE(review): xpu_id only acts as an on/off switch here; the
        # device id itself is not forwarded to the config - confirm this
        # is intended.
        xpu_config.enable_xpu(100)
        self.xpu_predictor = create_predictor(xpu_config)
def context(self, trainable=True, pretrained=True):
"""context for transfer learning.
......@@ -116,7 +145,7 @@ class ResNet50vdAnimals(hub.Module):
param.trainable = trainable
return inputs, outputs, context_prog
def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1):
def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1, use_device=None):
"""
API for image classification.
......@@ -126,18 +155,30 @@ class ResNet50vdAnimals(hub.Module):
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
top_k (int): Return top k results.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The classfication results.
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
# real predictor to use
if use_device is not None:
if use_device == "cpu":
predictor = self.cpu_predictor
elif use_device == "xpu":
predictor = self.xpu_predictor
elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
all_data = list()
for yield_data in reader(images, paths):
......@@ -157,10 +198,16 @@ class ResNet50vdAnimals(hub.Module):
pass
# feed batch image
batch_image = np.array([data['image'] for data in batch_data])
batch_image = PaddleTensor(batch_image.copy())
predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run(
[batch_image])
out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle(input_names[0])
input_tensor.reshape(batch_image.shape)
input_tensor.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
predictor_output = output_handle.copy_to_cpu()
out = postprocess(data_out=predictor_output, label_list=self.label_list, top_k=top_k)
res += out
return res
......@@ -174,14 +221,13 @@ class ResNet50vdAnimals(hub.Module):
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
fluid.io.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
......@@ -197,28 +243,36 @@ class ResNet50vdAnimals(hub.Module):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.classification(paths=[args.input_path], batch_size=args.batch_size, use_gpu=args.use_gpu)
results = self.classification(paths=[args.input_path],
batch_size=args.batch_size,
use_gpu=args.use_gpu,
top_k=args.top_k,
use_device=args.use_device)
return results
def add_module_config_arg(self):
    """
    Add the command config options.

    Registers --use_gpu, --batch_size, --top_k and --use_device on
    `self.arg_config_group`. Each option is registered exactly once,
    because argparse raises an error when the same option string is
    added twice.
    """
    config_group = self.arg_config_group
    config_group.add_argument('--use_gpu',
                              type=ast.literal_eval,
                              default=False,
                              help="whether use GPU or not.")
    config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
    config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
    # No default is given, so the parsed value is None unless the flag is passed.
    config_group.add_argument('--use_device',
                              choices=["cpu", "gpu", "xpu", "npu"],
                              help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册