Unverified commit 08c56a7d, authored by houj04, committed by GitHub

add xpu and npu support for classification and sentiment series. (#1649)

Parent: 0d60bf5a
@@ -9,7 +9,10 @@ import os
 import numpy as np
 import paddle.fluid as fluid
 import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+from paddle.inference import Config
+from paddle.inference import create_predictor
 from paddlehub.module.module import moduleinfo, runnable, serving
 from paddlehub.common.paddle_helper import add_vars_prefix
@@ -48,26 +51,53 @@ class MobileNetV2Animals(hub.Module):
         im_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3)
         return im_std
 
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     def _set_config(self):
         """
        predictor config setting
        """
-        cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        # create default cpu predictor
+        cpu_config = Config(self.default_pretrained_model_path)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)
 
-        try:
-            _places = os.environ["CUDA_VISIBLE_DEVICES"]
-            int(_places[0])
-            use_gpu = True
-        except:
-            use_gpu = False
-        if use_gpu:
-            gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        # create predictors using various types of devices
+
+        # npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            npu_config = Config(self.default_pretrained_model_path)
+            npu_config.disable_glog_info()
+            npu_config.enable_npu(device_id=npu_id)
+            self.npu_predictor = create_predictor(npu_config)
+
+        # gpu
+        gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+        if gpu_id != -1:
+            # use gpu
+            gpu_config = Config(self.default_pretrained_model_path)
             gpu_config.disable_glog_info()
-            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
+            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
+            self.gpu_predictor = create_predictor(gpu_config)
+
+        # xpu
+        xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+        if xpu_id != -1:
+            # use xpu
+            xpu_config = Config(self.default_pretrained_model_path)
+            xpu_config.disable_glog_info()
+            xpu_config.enable_xpu(100)
+            self.xpu_predictor = create_predictor(xpu_config)
 
     def context(self, trainable=True, pretrained=True):
         """context for transfer learning.
@@ -117,7 +147,7 @@ class MobileNetV2Animals(hub.Module):
             param.trainable = trainable
         return inputs, outputs, context_prog
 
-    def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1):
+    def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1, use_device=None):
         """
         API for image classification.
@@ -127,18 +157,29 @@ class MobileNetV2Animals(hub.Module):
             batch_size (int): batch size.
             use_gpu (bool): Whether to use gpu.
             top_k (int): Return top k results.
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
 
         Returns:
             res (list[dict]): The classfication results.
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
+        # real predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         all_data = list()
         for yield_data in reader(images, paths):
@@ -158,10 +199,16 @@ class MobileNetV2Animals(hub.Module):
                     pass
             # feed batch image
             batch_image = np.array([data['image'] for data in batch_data])
-            batch_image = PaddleTensor(batch_image.copy())
-            predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run(
-                [batch_image])
-            out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k)
+
+            input_names = predictor.get_input_names()
+            input_tensor = predictor.get_input_handle(input_names[0])
+            input_tensor.reshape(batch_image.shape)
+            input_tensor.copy_from_cpu(batch_image.copy())
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+            predictor_output = output_handle.copy_to_cpu()
+
+            out = postprocess(data_out=predictor_output, label_list=self.label_list, top_k=top_k)
             res += out
         return res
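Since every device now goes through the same tensor-handle flow above, switching hardware is a single argument at call time. A hedged usage sketch (the hub module name and the image file are assumptions; use_device="npu" only works if FLAGS_selected_npus was set before the module was constructed, so that _set_config created self.npu_predictor):

    import os
    os.environ["FLAGS_selected_npus"] = "0"  # must be set before hub.Module() runs _set_config

    import cv2
    import paddlehub as hub

    module = hub.Module(name="mobilenet_v2_animals")  # assumed hub name for this class
    results = module.classification(
        images=[cv2.imread("test_animal.jpg")],  # assumed local test image
        top_k=3,
        use_device="npu")  # overrides use_gpu
    print(results)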
@@ -209,7 +256,12 @@ class MobileNetV2Animals(hub.Module):
         self.add_module_config_arg()
         self.add_module_input_arg()
         args = self.parser.parse_args(argvs)
-        results = self.classification(paths=[args.input_path], batch_size=args.batch_size, use_gpu=args.use_gpu)
+        results = self.classification(
+            paths=[args.input_path],
+            batch_size=args.batch_size,
+            use_gpu=args.use_gpu,
+            top_k=args.top_k,
+            use_device=args.use_device)
         return results
 
     def add_module_config_arg(self):
@@ -220,6 +272,10 @@ class MobileNetV2Animals(hub.Module):
             '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not.")
         self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
         self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
+        self.arg_config_group.add_argument(
+            '--use_device',
+            choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
 
     def add_module_input_arg(self):
         """
...
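The same choice is exposed on the command line through the new --use_device flag. Invocation sketches (module name and image path are assumptions; the matching selector variable has to be set so the corresponding predictor exists):

    XPU_VISIBLE_DEVICES=0 hub run mobilenet_v2_animals --input_path test_animal.jpg --top_k 3 --use_device xpu
    FLAGS_selected_npus=0 hub run mobilenet_v2_animals --input_path test_animal.jpg --top_k 3 --use_device npu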
@@ -9,7 +9,10 @@ import os
 import numpy as np
 import paddle.fluid as fluid
 import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+from paddle.inference import Config
+from paddle.inference import create_predictor
 from paddlehub.module.module import moduleinfo, runnable, serving
 from paddlehub.common.paddle_helper import add_vars_prefix
@@ -47,26 +50,53 @@ class ResNet50vdDishes(hub.Module):
         im_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3)
         return im_std
 
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     def _set_config(self):
         """
        predictor config setting
        """
-        cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        # create default cpu predictor
+        cpu_config = Config(self.default_pretrained_model_path)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)
 
-        try:
-            _places = os.environ["CUDA_VISIBLE_DEVICES"]
-            int(_places[0])
-            use_gpu = True
-        except:
-            use_gpu = False
-        if use_gpu:
-            gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        # create predictors using various types of devices
+
+        # npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            npu_config = Config(self.default_pretrained_model_path)
+            npu_config.disable_glog_info()
+            npu_config.enable_npu(device_id=npu_id)
+            self.npu_predictor = create_predictor(npu_config)
+
+        # gpu
+        gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+        if gpu_id != -1:
+            # use gpu
+            gpu_config = Config(self.default_pretrained_model_path)
             gpu_config.disable_glog_info()
-            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
+            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
+            self.gpu_predictor = create_predictor(gpu_config)
+
+        # xpu
+        xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+        if xpu_id != -1:
+            # use xpu
+            xpu_config = Config(self.default_pretrained_model_path)
+            xpu_config.disable_glog_info()
+            xpu_config.enable_xpu(100)
+            self.xpu_predictor = create_predictor(xpu_config)
 
     def context(self, trainable=True, pretrained=True):
         """context for transfer learning.
@@ -116,7 +146,7 @@ class ResNet50vdDishes(hub.Module):
             param.trainable = trainable
         return inputs, outputs, context_prog
 
-    def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1):
+    def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1, use_device=None):
         """
         API for image classification.
@@ -126,18 +156,29 @@ class ResNet50vdDishes(hub.Module):
             batch_size (int): batch size.
             use_gpu (bool): Whether to use gpu.
             top_k (int): Return top k results.
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
 
         Returns:
             res (list[dict]): The classfication results.
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
+        # real predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         all_data = list()
         for yield_data in reader(images, paths):
@@ -157,10 +198,16 @@ class ResNet50vdDishes(hub.Module):
                     pass
             # feed batch image
             batch_image = np.array([data['image'] for data in batch_data])
-            batch_image = PaddleTensor(batch_image.copy())
-            predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run(
-                [batch_image])
-            out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k)
+
+            input_names = predictor.get_input_names()
+            input_tensor = predictor.get_input_handle(input_names[0])
+            input_tensor.reshape(batch_image.shape)
+            input_tensor.copy_from_cpu(batch_image.copy())
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+            predictor_output = output_handle.copy_to_cpu()
+
+            out = postprocess(data_out=predictor_output, label_list=self.label_list, top_k=top_k)
             res += out
         return res
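ResNet50vdDishes gets the identical treatment; the default path is also worth noting: with neither use_device nor use_gpu given, selection falls through to the CPU predictor. A usage sketch (module name and image files are assumptions):

    import paddlehub as hub

    module = hub.Module(name="resnet50_vd_dishes")  # assumed hub name for this class
    # no use_device and no use_gpu -> self.cpu_predictor handles the batch
    results = module.classification(paths=["dish1.jpg", "dish2.jpg"], batch_size=2)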
@@ -208,7 +255,12 @@ class ResNet50vdDishes(hub.Module):
         self.add_module_config_arg()
         self.add_module_input_arg()
         args = self.parser.parse_args(argvs)
-        results = self.classification(paths=[args.input_path], batch_size=args.batch_size, use_gpu=args.use_gpu)
+        results = self.classification(
+            paths=[args.input_path],
+            batch_size=args.batch_size,
+            use_gpu=args.use_gpu,
+            top_k=args.top_k,
+            use_device=args.use_device)
         return results
 
     def add_module_config_arg(self):
@@ -219,6 +271,10 @@ class ResNet50vdDishes(hub.Module):
             '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not.")
         self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
         self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
+        self.arg_config_group.add_argument(
+            '--use_device',
+            choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
 
     def add_module_input_arg(self):
         """
...
@@ -9,7 +9,10 @@ import os
 import numpy as np
 import paddle.fluid as fluid
 import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+from paddle.inference import Config
+from paddle.inference import create_predictor
 from paddlehub.module.module import moduleinfo, runnable, serving
 from paddlehub.common.paddle_helper import add_vars_prefix
@@ -48,26 +51,53 @@ class ResNet50vdWildAnimals(hub.Module):
         im_std = np.array([0.229, 0.224, 0.225]).reshape(1, 3)
         return im_std
 
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     def _set_config(self):
         """
        predictor config setting.
        """
-        cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        # create default cpu predictor
+        cpu_config = Config(self.default_pretrained_model_path)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)
 
-        try:
-            _places = os.environ["CUDA_VISIBLE_DEVICES"]
-            int(_places[0])
-            use_gpu = True
-        except:
-            use_gpu = False
-        if use_gpu:
-            gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        # create predictors using various types of devices
+
+        # npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            npu_config = Config(self.default_pretrained_model_path)
+            npu_config.disable_glog_info()
+            npu_config.enable_npu(device_id=npu_id)
+            self.npu_predictor = create_predictor(npu_config)
+
+        # gpu
+        gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+        if gpu_id != -1:
+            # use gpu
+            gpu_config = Config(self.default_pretrained_model_path)
             gpu_config.disable_glog_info()
-            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
+            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
+            self.gpu_predictor = create_predictor(gpu_config)
+
+        # xpu
+        xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+        if xpu_id != -1:
+            # use xpu
+            xpu_config = Config(self.default_pretrained_model_path)
+            xpu_config.disable_glog_info()
+            xpu_config.enable_xpu(100)
+            self.xpu_predictor = create_predictor(xpu_config)
 
     def context(self, trainable=True, pretrained=True):
         """context for transfer learning.
@@ -117,7 +147,7 @@ class ResNet50vdWildAnimals(hub.Module):
             param.trainable = trainable
         return inputs, outputs, context_prog
 
-    def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1):
+    def classification(self, images=None, paths=None, batch_size=1, use_gpu=False, top_k=1, use_device=None):
         """
         API for image classification.
@@ -127,18 +157,29 @@ class ResNet50vdWildAnimals(hub.Module):
             batch_size (int): batch size.
             use_gpu (bool): Whether to use gpu.
             top_k (int): Return top k results.
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
 
         Returns:
             res (list[dict]): The classfication results.
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
+        # real predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         all_data = list()
         for yield_data in reader(images, paths):
@@ -158,10 +199,16 @@ class ResNet50vdWildAnimals(hub.Module):
                     pass
             # feed batch image
             batch_image = np.array([data['image'] for data in batch_data])
-            batch_image = PaddleTensor(batch_image.copy())
-            predictor_output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run(
-                [batch_image])
-            out = postprocess(data_out=predictor_output[0].as_ndarray(), label_list=self.label_list, top_k=top_k)
+
+            input_names = predictor.get_input_names()
+            input_tensor = predictor.get_input_handle(input_names[0])
+            input_tensor.reshape(batch_image.shape)
+            input_tensor.copy_from_cpu(batch_image.copy())
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+            predictor_output = output_handle.copy_to_cpu()
+
+            out = postprocess(data_out=predictor_output, label_list=self.label_list, top_k=top_k)
             res += out
         return res
@@ -209,7 +256,12 @@ class ResNet50vdWildAnimals(hub.Module):
         self.add_module_config_arg()
         self.add_module_input_arg()
         args = self.parser.parse_args(argvs)
-        results = self.classification(paths=[args.input_path], batch_size=args.batch_size, use_gpu=args.use_gpu)
+        results = self.classification(
+            paths=[args.input_path],
+            batch_size=args.batch_size,
+            use_gpu=args.use_gpu,
+            top_k=args.top_k,
+            use_device=args.use_device)
         return results
 
     def add_module_config_arg(self):
@@ -220,6 +272,10 @@ class ResNet50vdWildAnimals(hub.Module):
             '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not.")
         self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
         self.arg_config_group.add_argument('--top_k', type=ast.literal_eval, default=1, help="Return top k results.")
+        self.arg_config_group.add_argument(
+            '--use_device',
+            choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
 
     def add_module_input_arg(self):
         """
...
@@ -151,7 +151,7 @@ class EmotionDetectionTextCNN(hub.NLPPredictionModule):
         return inputs, outputs, main_program
 
     @serving
-    def emotion_classify(self, texts=[], data={}, use_gpu=False, batch_size=1):
+    def emotion_classify(self, texts=[], data={}, use_gpu=False, batch_size=1, use_device=None):
         """
         Get the emotion prediction results results with the texts as input
 
         Args:
@@ -161,15 +161,26 @@ class EmotionDetectionTextCNN(hub.NLPPredictionModule):
             batch_size(int): the program deals once with one batch
 
         Returns:
             results(list): the emotion prediction results
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
+        # real predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         if texts != [] and isinstance(texts, list) and data == {}:
             predicted_data = texts
@@ -189,14 +200,10 @@ class EmotionDetectionTextCNN(hub.NLPPredictionModule):
             else:
                 batch_data = predicted_data[start_idx:]
                 start_idx = start_idx + batch_size
-            processed_results = preprocess(self.word_seg_module, batch_data, self.vocab, use_gpu, batch_size)
-            tensor_words = self.texts2tensor(processed_results)
-
-            if use_gpu:
-                batch_out = self.gpu_predictor.run([tensor_words])
-            else:
-                batch_out = self.cpu_predictor.run([tensor_words])
-            batch_result = postprocess(batch_out[0], processed_results)
+            processed_results = preprocess(self.word_seg_module, batch_data, self.vocab, use_gpu, batch_size,
+                                           use_device)
+            predictor_output = self._internal_predict(predictor, processed_results)
+            batch_result = postprocess(predictor_output, processed_results)
             results += batch_result
         return results
...
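On the NLP side the device choice now travels through preprocessing as well: use_device is forwarded to preprocess, which hands it to the lac word segmentation module, so segmentation and classification run on the same device. A usage sketch (module name is assumed; the sample sentence is illustrative Chinese input):

    import paddlehub as hub

    module = hub.Module(name="emotion_detection_textcnn")  # assumed hub name for this class
    results = module.emotion_classify(texts=["今天超级开心"], use_device="cpu")  # "super happy today"
    print(results)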
@@ -34,10 +34,10 @@ def get_predict_label(probs):
     return label, key
 
-def preprocess(lac, predicted_data, word_dict, use_gpu=False, batch_size=1):
+def preprocess(lac, predicted_data, word_dict, use_gpu=False, batch_size=1, use_device=None):
     result = []
     data_dict = {"text": predicted_data}
-    processed = lac.lexical_analysis(data=data_dict, use_gpu=use_gpu, batch_size=batch_size)
+    processed = lac.lexical_analysis(data=data_dict, use_gpu=use_gpu, batch_size=batch_size, use_device=use_device)
     unk_id = word_dict["<unk>"]
     for index, data in enumerate(processed):
         result_i = {'processed': []}
@@ -54,7 +54,7 @@ def preprocess(lac, predicted_data, word_dict, use_gpu=False, batch_size=1):
 def postprocess(prediction, texts):
     result = []
-    pred = prediction.as_ndarray()
+    pred = prediction.copy_to_cpu()
     for index in range(len(texts)):
         result_i = {}
         result_i['text'] = texts[index]['origin']
...
@@ -153,7 +153,7 @@ class SentaBiLSTM(hub.NLPPredictionModule):
         return inputs, outputs, main_program
 
     @serving
-    def sentiment_classify(self, texts=[], data={}, use_gpu=False, batch_size=1):
+    def sentiment_classify(self, texts=[], data={}, use_gpu=False, batch_size=1, use_device=None):
         """
         Get the sentiment prediction results results with the texts as input
@@ -162,18 +162,29 @@ class SentaBiLSTM(hub.NLPPredictionModule):
             data(dict): key must be 'text', value is the texts to be predicted, if data not texts
             use_gpu(bool): whether use gpu to predict or not
             batch_size(int): the program deals once with one batch
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
 
         Returns:
             results(list): the word segmentation results
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
+        # real predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         if texts != [] and isinstance(texts, list) and data == {}:
             predicted_data = texts
@@ -193,14 +204,10 @@ class SentaBiLSTM(hub.NLPPredictionModule):
                 batch_data = predicted_data[start_idx:]
                 start_idx = start_idx + batch_size
-            processed_results = preprocess(self.word_seg_module, batch_data, self.word_dict, use_gpu, batch_size)
-            tensor_words = self.texts2tensor(processed_results)
-
-            if use_gpu:
-                batch_out = self.gpu_predictor.run([tensor_words])
-            else:
-                batch_out = self.cpu_predictor.run([tensor_words])
-            batch_result = postprocess(batch_out[0], processed_results)
+            processed_results = preprocess(self.word_seg_module, batch_data, self.word_dict, use_gpu, batch_size,
+                                           use_device)
+            predictor_output = self._internal_predict(predictor, processed_results)
+            batch_result = postprocess(predictor_output, processed_results)
             results += batch_result
         return results
...
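SentaBiLSTM routes devices the same way, and the legacy use_gpu flag still works when use_device is omitted. A sketch (module name assumed; the use_gpu path requires CUDA_VISIBLE_DEVICES to hold a single numeric id when the module is constructed):

    import paddlehub as hub

    module = hub.Module(name="senta_bilstm")  # assumed hub name for this class
    results = module.sentiment_classify(texts=["这家餐厅的菜很好吃"], use_gpu=True)  # "this restaurant's food is great"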
@@ -17,14 +17,14 @@ def load_vocab(file_path):
     return vocab
 
-def preprocess(lac, texts, word_dict, use_gpu=False, batch_size=1):
+def preprocess(lac, texts, word_dict, use_gpu=False, batch_size=1, use_device=None):
     """
     firstly, the predicted texts are segmented by lac module
     then, the word segmention results input into senta
     """
     result = []
     input_dict = {'text': texts}
-    processed = lac.lexical_analysis(data=input_dict, use_gpu=use_gpu, batch_size=batch_size)
+    processed = lac.lexical_analysis(data=input_dict, use_gpu=use_gpu, batch_size=batch_size, use_device=use_device)
     unk_id = word_dict["<unk>"]
     for index, data in enumerate(processed):
         result_i = {'processed': []}
@@ -43,7 +43,7 @@ def postprocess(predict_out, texts):
     """
     Convert model's output tensor to sentiment label
     """
-    predict_out = predict_out.as_ndarray()
+    predict_out = predict_out.copy_to_cpu()
     batch_size = len(texts)
     result = []
     for index in range(batch_size):
...
@@ -78,7 +78,7 @@ class PornDetectionLSTM(hub.NLPPredictionModule):
         return inputs, outputs, program
 
     @serving
-    def detection(self, texts=[], data={}, use_gpu=False, batch_size=1):
+    def detection(self, texts=[], data={}, use_gpu=False, batch_size=1, use_device=None):
         """
         Get the porn prediction results results with the texts as input
@@ -87,15 +87,29 @@ class PornDetectionLSTM(hub.NLPPredictionModule):
             data(dict): key must be 'text', value is the texts to be predicted, if data not texts
             use_gpu(bool): whether use gpu to predict or not
             batch_size(int): the program deals once with one batch
+            use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
 
         Returns:
             results(list): the porn prediction results
         """
-        try:
-            _places = os.environ["CUDA_VISIBLE_DEVICES"]
-            int(_places[0])
-        except:
-            use_gpu = False
+        # real predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         if texts != [] and isinstance(texts, list) and data == {}:
             predicted_data = texts
@@ -116,13 +130,8 @@ class PornDetectionLSTM(hub.NLPPredictionModule):
             start_idx = start_idx + batch_size
             processed_results = preprocess(batch_data, self.tokenizer, self.vocab, self.sequence_max_len)
-            tensor_words = self.texts2tensor(processed_results)
-
-            if use_gpu:
-                batch_out = self.gpu_predictor.run([tensor_words])
-            else:
-                batch_out = self.cpu_predictor.run([tensor_words])
-            batch_result = postprocess(batch_out[0], processed_results)
+            predictor_output = self._internal_predict(predictor, processed_results)
+            batch_result = postprocess(predictor_output, processed_results)
             results += batch_result
         return results
...
@@ -52,7 +52,7 @@ def postprocess(predict_out, texts):
     Convert model's output tensor to pornography label
     """
     result = []
-    predict_out = predict_out.as_ndarray()
+    predict_out = predict_out.copy_to_cpu()
     for index in range(len(texts)):
         result_i = {}
         result_i['text'] = texts[index]['origin']
...
@@ -68,6 +68,10 @@ class RunCommand:
         arg_config_group.add_argument(
             '--use_gpu', type=ast.literal_eval, default=False, help='whether use GPU for prediction')
         arg_config_group.add_argument('--batch_size', type=int, default=1, help='batch size for prediction')
+        arg_config_group.add_argument(
+            '--use_device',
+            choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
 
         module_type = module.type.lower()
         if module_type.startswith('cv'):
@@ -83,4 +87,8 @@ class RunCommand:
         input_data = {key: [args.input_path] if module_type.startswith('cv') else [args.input_text]}
         return module(
-            sign_name=module.default_signature, data=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)
+            sign_name=module.default_signature,
+            data=input_data,
+            use_gpu=args.use_gpu,
+            batch_size=args.batch_size,
+            use_device=args.use_device)
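For old-style modules driven through hub run, the flag is parsed here and forwarded into ModuleV1.__call__ below, which maps it to a paddle place. A command sketch (treating lac as such a module is an assumption):

    hub run lac --input_text "今天天气真好" --use_device npu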
@@ -167,7 +167,13 @@ class ModuleV1(object):
             program.global_block().var(feed_dict[tensor_name].name).desc.set_shape(seq_tensor_shape)
 
     @paddle_utils.run_in_static_mode
-    def __call__(self, sign_name: str, data: dict, use_gpu: bool = False, batch_size: int = 1, **kwargs):
+    def __call__(self,
+                 sign_name: str,
+                 data: dict,
+                 use_gpu: bool = False,
+                 batch_size: int = 1,
+                 use_device: str = None,
+                 **kwargs):
         '''Call the specified signature function for prediction.'''
 
         def _get_reader_and_feeder(data_format, data, place):
@@ -188,7 +194,18 @@ class ModuleV1(object):
         with paddle.static.program_guard(program):
             result = []
             index = 0
-            place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
+
+            if use_device is not None:
+                if use_device == "xpu":
+                    place = paddle.XPUPlace(0)
+                elif use_device == "npu":
+                    place = paddle.NPUPlace(0)
+                elif use_device == "gpu":
+                    place = paddle.CUDAPlace(0)
+                else:
+                    place = paddle.CPUPlace()
+            else:
+                place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
+
             exe = paddle.static.Executor(place=place)
             data = self.processor.preprocess(sign_name=sign_name, data_dict=data)
...
@@ -31,6 +31,9 @@ from paddlehub.module.module import runnable, RunModule
 from paddlehub.utils.parser import txt_parser
 from paddlehub.utils.utils import sys_stdin_encoding
 
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
 
 class DataFormatError(Exception):
     def __init__(self, *args):
@@ -48,24 +51,53 @@ class NLPBaseModule(RunModule):
 
 class NLPPredictionModule(NLPBaseModule):
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except:
+            id = -1
+        return id
+
     def _set_config(self):
-        '''predictor config setting'''
-        cpu_config = paddle.fluid.core.AnalysisConfig(self.pretrained_model_path)
+        """
+        predictor config setting
+        """
+        # create default cpu predictor
+        cpu_config = Config(self.pretrained_model_path)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
-        self.cpu_predictor = paddle.fluid.core.create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)
 
-        try:
-            _places = os.environ['CUDA_VISIBLE_DEVICES']
-            int(_places[0])
-            use_gpu = True
-        except:
-            use_gpu = False
-        if use_gpu:
-            gpu_config = paddle.fluid.core.AnalysisConfig(self.pretrained_model_path)
+        # create predictors using various types of devices
+
+        # npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            npu_config = Config(self.pretrained_model_path)
+            npu_config.disable_glog_info()
+            npu_config.enable_npu(device_id=npu_id)
+            self.npu_predictor = create_predictor(npu_config)
+
+        # gpu
+        gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+        if gpu_id != -1:
+            # use gpu
+            gpu_config = Config(self.pretrained_model_path)
             gpu_config.disable_glog_info()
-            gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
-            self.gpu_predictor = paddle.fluid.core.create_paddle_predictor(gpu_config)
+            gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=gpu_id)
+            self.gpu_predictor = create_predictor(gpu_config)
+
+        # xpu
+        xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+        if xpu_id != -1:
+            # use xpu
+            xpu_config = Config(self.pretrained_model_path)
+            xpu_config.disable_glog_info()
+            xpu_config.enable_xpu(100)
+            self.xpu_predictor = create_predictor(xpu_config)
 
     def texts2tensor(self, texts: List[dict]) -> paddle.Tensor:
         '''
@@ -87,6 +119,29 @@ class NLPPredictionModule(NLPBaseModule):
         tensor.shape = [lod[-1], 1]
         return tensor
 
+    def _internal_predict(self, predictor, texts):
+        lod = [0]
+        data = []
+        for i, text in enumerate(texts):
+            data += text['processed']
+            lod.append(len(text['processed']) + lod[i])
+
+        # get predictor tensor
+        input_names = predictor.get_input_names()
+        input_tensor = predictor.get_input_handle(input_names[0])
+
+        # set data, shape and lod
+        input_tensor.copy_from_cpu(np.array(data).astype('int64'))
+        input_tensor.reshape([lod[-1], 1])
+        input_tensor.set_lod([lod])
+
+        # real predict
+        predictor.run()
+        output_names = predictor.get_output_names()
+        output_handle = predictor.get_output_handle(output_names[0])
+
+        return output_handle
+
     def to_unicode(self, texts: str) -> Text:
         '''
         Convert each element's type(str) of texts(list) to unicode in python2.7
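The _internal_predict helper added above packs a batch of variable-length token id sequences into a single LoD tensor: data flattens all ids into one column, and lod records where each text's rows begin and end. A worked sketch of just the packing step (token ids are invented for illustration):

    texts = [{'processed': [3, 8, 2]}, {'processed': [5, 9]}]
    lod = [0]
    data = []
    for i, text in enumerate(texts):
        data += text['processed']
        lod.append(len(text['processed']) + lod[i])
    # data == [3, 8, 2, 5, 9]  -> reshaped to [lod[-1], 1] == [5, 1]
    # lod  == [0, 3, 5]        -> rows 0:3 belong to text 0, rows 3:5 to text 1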
@@ -129,7 +184,8 @@ class NLPPredictionModule(NLPBaseModule):
             self.parser.print_help()
             return None
 
-        results = self.predict(texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)
+        results = self.predict(
+            texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size, use_device=args.use_device)
         return results
@@ -139,6 +195,10 @@ class NLPPredictionModule(NLPBaseModule):
             '--use_gpu', type=ast.literal_eval, default=False, help='whether use GPU for prediction')
         self.arg_config_group.add_argument('--batch_size', type=int, default=1, help='batch size for prediction')
+        self.arg_config_group.add_argument(
+            '--use_device',
+            choices=["cpu", "gpu", "xpu", "npu"],
+            help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
 
     def add_module_input_arg(self):
         '''Add the command input options'''
...