Unverified · Commit b0d82b3d · authored by houj04, committed by GitHub

lac support npu and xpu (#1613)

Parent e420428b
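A minimal usage sketch of the interface this commit adds (the input text and device are illustrative; use_device must name a device whose predictor was actually created, see _set_config below):

import paddlehub as hub

lac = hub.Module(name="lac")
# use_device overrides use_gpu and accepts "cpu", "gpu", "xpu" or "npu"
results = lac.cut(text=["今天是个好日子"], use_device="npu")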
......@@ -13,7 +13,10 @@ import six
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddle.inference import Config
from paddle.inference import create_predictor
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.common.paddle_helper import add_vars_prefix
......@@ -62,26 +65,86 @@ class LAC(hub.Module):
self._set_config()
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
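# Example: with FLAGS_selected_npus=0 set in the environment, the call
# _get_device_id("FLAGS_selected_npus") returns 0; if the variable is
# unset or not parseable as a single int (e.g. "0,1"), it returns -1,
# which _set_config below treats as "this device is not available".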
def _set_config(self):
"""
predictor config setting
"""
cpu_config = AnalysisConfig(self.pretrained_model_path)
# create default cpu predictor
cpu_config = Config(self.pretrained_model_path)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.pretrained_model_path)
self.cpu_predictor = create_predictor(cpu_config)
# create predictors using various types of devices
# npu
npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
npu_config = Config(self.pretrained_model_path)
npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config)
# gpu
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
# use gpu
gpu_config = Config(self.pretrained_model_path)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config)
# xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
xpu_config = Config(self.pretrained_model_path)
xpu_config.disable_glog_info()
xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config)
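# Summary of the environment variables read above (one predictor is
# created per available device):
#   CUDA_VISIBLE_DEVICES=0  -> self.gpu_predictor
#   FLAGS_selected_npus=0   -> self.npu_predictor
#   XPU_VISIBLE_DEVICES=0   -> self.xpu_predictor
# The positional argument to enable_xpu is assumed to be the XPU L3
# workspace size in bytes, so the 100 here keeps it minimal.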
def _internal_predict(self, predictor, texts):
"""
Transform the texts (list) into an input Tensor and then run the real predict
Args:
texts(list): a list of input texts
Returns:
result(PaddleInferTensor): the predictor's output handle
"""
# texts to data and lod
lod = [0]
data = []
for i, text in enumerate(texts):
text_inds = word_to_ids(text, self.word2id_dict, self.word_replace_dict, oov_id=self.oov_id)
data += text_inds
lod.append(len(text_inds) + lod[i])
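# e.g. two texts of 3 and 5 tokens produce data of length 8 and
# lod == [0, 3, 8]; text i occupies data[lod[i]:lod[i + 1]]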
# get predictor tensor
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle(input_names[0])
# set data, shape and lod
input_tensor.copy_from_cpu(np.array(data).astype('int64'))
input_tensor.reshape([lod[-1], 1])
input_tensor.set_lod([lod])
# real predict
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
return output_handle
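# A sketch of how this pairs with parse_result (see the parse_result
# change at the end of this diff):
#   output_handle = self._internal_predict(self.cpu_predictor, texts)
#   results = parse_result(texts, output_handle, self.id2label_dict,
#                          interventer=self.custom)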
def context(self, trainable=False):
"""
......@@ -167,26 +230,6 @@ class LAC(hub.Module):
texts = unicode_texts
return texts
def texts2tensor(self, texts):
"""
Transform the texts(list) to PaddleTensor
Args:
texts(list): texts
Returns:
tensor(PaddleTensor): tensor with texts data
"""
lod = [0]
data = []
for i, text in enumerate(texts):
text_inds = word_to_ids(text, self.word2id_dict, self.word_replace_dict, oov_id=self.oov_id)
data += text_inds
lod.append(len(text_inds) + lod[i])
tensor = PaddleTensor(np.array(data).astype('int64'))
tensor.name = "words"
tensor.lod = [lod]
tensor.shape = [lod[-1], 1]
return tensor
def _get_index(self, data_list, item=""):
"""
find all indexes of item in data_list
......@@ -198,7 +241,7 @@ class LAC(hub.Module):
return res
@serving
def cut(self, text, use_gpu=False, batch_size=1, return_tag=True):
def cut(self, text, use_gpu=False, batch_size=1, return_tag=True, use_device=None):
"""
The main function that segments an entire text containing
Chinese characters into separate words.
......@@ -207,20 +250,32 @@ class LAC(hub.Module):
use_gpu(bool): whether to use gpu for prediction or not
batch_size(int): the number of texts the program deals with in one batch
return_tag(bool): whether to return the tag of each word or not.
use_device(str): "cpu", "gpu", "xpu" or "npu"; overrides the use_gpu flag.
Returns:
results(dict or list): if text is a list, a list of dicts whose key 'word' holds the segmentation result of each input text;
if text is a str, the word segmentation result (a list of words) is returned directly.
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
# real predictor to use
if use_device is not None:
if use_device == "cpu":
predictor = self.cpu_predictor
elif use_device == "xpu":
predictor = self.xpu_predictor
elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
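# Note: requesting a device whose predictor was never created in
# _set_config (e.g. use_device="npu" without FLAGS_selected_npus set)
# raises AttributeError, since self.npu_predictor does not exist then.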
if isinstance(text, list) and len(text) != 0:
......@@ -240,13 +295,8 @@ class LAC(hub.Module):
batch_data = predicted_data[start_idx:]
start_idx = start_idx + batch_size
tensor_words = self.texts2tensor(batch_data)
if use_gpu:
batch_out = self.gpu_predictor.run([tensor_words])
else:
batch_out = self.cpu_predictor.run([tensor_words])
batch_result = parse_result(batch_data, batch_out[0], self.id2label_dict, interventer=self.custom)
batch_out = self._internal_predict(predictor, batch_data)
batch_result = parse_result(batch_data, batch_out, self.id2label_dict, interventer=self.custom)
results += batch_result
for index in empty_str_indexes:
......@@ -259,13 +309,8 @@ class LAC(hub.Module):
return results
elif isinstance(text, str) and text != "":
tensor_words = self.texts2tensor([text])
if use_gpu:
batch_out = self.gpu_predictor.run([tensor_words])
else:
batch_out = self.cpu_predictor.run([tensor_words])
batch_result = parse_result([text], batch_out[0], self.id2label_dict, interventer=self.custom)
batch_out = self._internal_predict(predictor, [text])
batch_result = parse_result([text], batch_out, self.id2label_dict, interventer=self.custom)
return batch_result[0]['word']
elif text == "":
......@@ -273,7 +318,7 @@ class LAC(hub.Module):
else:
raise TypeError("The input data is inconsistent with expectations.")
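# Usage sketch (return types follow the docstring above):
#   lac.cut("今天是个好日子")    -> list of segmented words
#   lac.cut(["今天是个好日子"])  -> list of dicts with key 'word'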
def lexical_analysis(self, texts=[], data={}, use_gpu=False, batch_size=1, return_tag=True):
def lexical_analysis(self, texts=[], data={}, use_gpu=False, batch_size=1, return_tag=True, use_device=None):
"""
Get the word segmentation results with the texts as input
......@@ -283,19 +328,30 @@ class LAC(hub.Module):
use_gpu(bool): whether to use gpu for prediction or not
batch_size(int): the number of texts the program deals with in one batch
return_tag(bool): whether to return the tag of each word or not.
use_device(str): "cpu", "gpu", "xpu" or "npu"; overrides the use_gpu flag.
Returns:
results(list): the word segmentation results
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
# real predictor to use
if use_device is not None:
if use_device == "cpu":
predictor = self.cpu_predictor
elif use_device == "xpu":
predictor = self.xpu_predictor
elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
if texts != [] and isinstance(texts, list) and data == {}:
predicted_data = texts
......@@ -320,13 +376,8 @@ class LAC(hub.Module):
batch_data = predicted_data[start_idx:]
start_idx = start_idx + batch_size
tensor_words = self.texts2tensor(batch_data)
if use_gpu:
batch_out = self.gpu_predictor.run([tensor_words])
else:
batch_out = self.cpu_predictor.run([tensor_words])
batch_result = parse_result(batch_data, batch_out[0], self.id2label_dict, interventer=self.custom)
batch_out = self._internal_predict(predictor, batch_data)
batch_result = parse_result(batch_data, batch_out, self.id2label_dict, interventer=self.custom)
results += batch_result
for index in empty_str_indexes:
......@@ -344,8 +395,10 @@ class LAC(hub.Module):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(
description="Run the lac module.", prog='hub run lac', usage='%(prog)s', add_help=True)
self.parser = argparse.ArgumentParser(description="Run the lac module.",
prog='hub run lac',
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
......@@ -365,8 +418,11 @@ class LAC(hub.Module):
if args.user_dict:
self.set_user_dict(args.user_dict)
results = self.lexical_analysis(
texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size, return_tag=args.return_tag)
results = self.lexical_analysis(texts=input_data,
use_gpu=args.use_gpu,
batch_size=args.batch_size,
return_tag=args.return_tag,
use_device=args.use_device)
return results
......@@ -388,17 +444,23 @@ class LAC(hub.Module):
"""
Add the command config options
"""
self.arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
self.arg_config_group.add_argument('--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument('--batch_size', type=int, default=1, help="batch size for prediction")
self.arg_config_group.add_argument(
'--user_dict',
type=str,
default=None,
help="customized dictionary for intervening the word segmentation result")
self.arg_config_group.add_argument(
'--return_tag', type=ast.literal_eval, default=True, help="whether return tags of results or not")
self.arg_config_group.add_argument('--user_dict',
type=str,
default=None,
help="customized dictionary for intervening the word segmentation result")
self.arg_config_group.add_argument('--return_tag',
type=ast.literal_eval,
default=True,
help="whether return tags of results or not")
self.arg_config_group.add_argument('--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self):
"""
......
......@@ -251,8 +251,8 @@ def word_to_ids(words, word2id_dict, word_replace_dict, oov_id=None):
def parse_result(lines, crf_decode, id2label_dict, interventer=None):
"""Convert model's output tensor into string and tags """
offset_list = crf_decode.lod[0]
crf_decode = crf_decode.as_ndarray()
offset_list = crf_decode.lod()[0]
crf_decode = crf_decode.copy_to_cpu()
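# copy_to_cpu() returns the output as a numpy array; lod()[0] holds the
# per-sentence offsets, so sentence i's labels live in
# crf_decode[offset_list[i]:offset_list[i + 1]]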
batch_size = len(offset_list) - 1
batch_out = []
for sent_index in range(batch_size):
......