未验证 提交 b0d82b3d 编写于 作者: H houj04 提交者: GitHub

lac support npu and xpu (#1613)

上级 e420428b
...@@ -13,7 +13,10 @@ import six ...@@ -13,7 +13,10 @@ import six
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddle.inference import Config
from paddle.inference import create_predictor
import paddlehub as hub import paddlehub as hub
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common.paddle_helper import add_vars_prefix from paddlehub.common.paddle_helper import add_vars_prefix
...@@ -62,26 +65,86 @@ class LAC(hub.Module): ...@@ -62,26 +65,86 @@ class LAC(hub.Module):
self._set_config() self._set_config()
def _get_device_id(self, places):
    """
    Read a device id from an environment variable

    Args:
         places(str): name of the environment variable to read, e.g. "CUDA_VISIBLE_DEVICES"

    Returns:
         device_id(int): the first id listed in the variable, or -1 if the variable is
                         unset or its first entry is not an integer
    """
    raw_value = os.environ.get(places)
    if raw_value is None:
        return -1
    # The variable may hold a comma separated list such as "0,1"; only its
    # first entry is used, so such a value no longer counts as "not set".
    first_entry = raw_value.split(",")[0]
    try:
        return int(first_entry)
    except ValueError:
        return -1
def _set_config(self):
    """
    predictor config setting

    self.cpu_predictor is always created. self.npu_predictor, self.gpu_predictor and
    self.xpu_predictor are only created when FLAGS_selected_npus, CUDA_VISIBLE_DEVICES
    or XPU_VISIBLE_DEVICES respectively yields a device id other than -1.
    """

    def new_config():
        # every predictor loads the same model and silences glog output
        config = Config(self.pretrained_model_path)
        config.disable_glog_info()
        return config

    # create default cpu predictor
    cpu_config = new_config()
    cpu_config.disable_gpu()
    self.cpu_predictor = create_predictor(cpu_config)

    # create predictors using various types of devices

    # npu
    npu_id = self._get_device_id("FLAGS_selected_npus")
    if npu_id != -1:
        npu_config = new_config()
        npu_config.enable_npu(device_id=npu_id)
        self.npu_predictor = create_predictor(npu_config)

    # gpu
    gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
    if gpu_id != -1:
        gpu_config = new_config()
        # NOTE(review): CUDA_VISIBLE_DEVICES normally renumbers the visible devices
        # from 0, so passing its value as device_id may not address the intended
        # card - confirm against the Paddle Inference documentation.
        gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=gpu_id)
        self.gpu_predictor = create_predictor(gpu_config)

    # xpu
    xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
    if xpu_id != -1:
        xpu_config = new_config()
        # NOTE(review): xpu_id only gates creation and is not handed to the config;
        # the meaning of the argument 100 is not visible here - TODO confirm.
        xpu_config.enable_xpu(100)
        self.xpu_predictor = create_predictor(xpu_config)
def _internal_predict(self, predictor, texts):
    """
    Feed the texts(list) to the given predictor and run the prediction

    Args:
         predictor: predictor object used to run the model
         texts(list): texts

    Returns:
         result: handle of the predictor's first output
    """
    # flatten the ids of all texts and record the cumulative offsets (lod)
    offsets = [0]
    flat_ids = []
    for text in texts:
        ids = word_to_ids(text, self.word2id_dict, self.word_replace_dict, oov_id=self.oov_id)
        flat_ids.extend(ids)
        offsets.append(offsets[-1] + len(ids))

    # fill the first input of the predictor: data first, then shape and lod
    first_input = predictor.get_input_names()[0]
    feed = predictor.get_input_handle(first_input)
    feed.copy_from_cpu(np.array(flat_ids).astype('int64'))
    feed.reshape([offsets[-1], 1])
    feed.set_lod([offsets])

    # real predict
    predictor.run()

    first_output = predictor.get_output_names()[0]
    return predictor.get_output_handle(first_output)
def context(self, trainable=False): def context(self, trainable=False):
""" """
...@@ -167,26 +230,6 @@ class LAC(hub.Module): ...@@ -167,26 +230,6 @@ class LAC(hub.Module):
texts = unicode_texts texts = unicode_texts
return texts return texts
def texts2tensor(self, texts):
    """
    Pack the texts(list) into a single PaddleTensor

    Args:
         texts(list): texts

    Returns:
         tensor(PaddleTensor): tensor with texts data
    """
    # flatten the ids of all texts and record the cumulative offsets (lod)
    offsets = [0]
    flat_ids = []
    for text in texts:
        ids = word_to_ids(text, self.word2id_dict, self.word_replace_dict, oov_id=self.oov_id)
        flat_ids.extend(ids)
        offsets.append(offsets[-1] + len(ids))

    packed = PaddleTensor(np.array(flat_ids).astype('int64'))
    packed.name = "words"
    packed.lod = [offsets]
    packed.shape = [offsets[-1], 1]
    return packed
def _get_index(self, data_list, item=""): def _get_index(self, data_list, item=""):
""" """
find all indexes of item in data_list find all indexes of item in data_list
...@@ -198,7 +241,7 @@ class LAC(hub.Module): ...@@ -198,7 +241,7 @@ class LAC(hub.Module):
return res return res
@serving @serving
def cut(self, text, use_gpu=False, batch_size=1, return_tag=True): def cut(self, text, use_gpu=False, batch_size=1, return_tag=True, use_device=None):
""" """
The main function that segments an entire text that contains The main function that segments an entire text that contains
Chinese characters into separated words. Chinese characters into separated words.
...@@ -207,20 +250,32 @@ class LAC(hub.Module): ...@@ -207,20 +250,32 @@ class LAC(hub.Module):
use_gpu(bool): whether use gpu to predict or not use_gpu(bool): whether use gpu to predict or not
batch_size(int): the program deals once with one batch batch_size(int): the program deals once with one batch
return_tag: Whether to get tag or not. return_tag: Whether to get tag or not.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
results(dict or list): The word segmentation result of the input text, whose key is 'word', if text is a list. results(dict or list): The word segmentation result of the input text, whose key is 'word', if text is a list.
If text is a str, the word segmentation result (list) is obtained. If text is a str, the word segmentation result (list) is obtained.
""" """
if use_gpu:
try: # real predictor to use
_places = os.environ["CUDA_VISIBLE_DEVICES"] if use_device is not None:
int(_places[0]) if use_device == "cpu":
except: predictor = self.cpu_predictor
raise RuntimeError( elif use_device == "xpu":
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id." predictor = self.xpu_predictor
) elif use_device == "npu":
predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
if isinstance(text, list) and len(text) != 0: if isinstance(text, list) and len(text) != 0:
...@@ -240,13 +295,8 @@ class LAC(hub.Module): ...@@ -240,13 +295,8 @@ class LAC(hub.Module):
batch_data = predicted_data[start_idx:] batch_data = predicted_data[start_idx:]
start_idx = start_idx + batch_size start_idx = start_idx + batch_size
tensor_words = self.texts2tensor(batch_data) batch_out = self._internal_predict(predictor, batch_data)
batch_result = parse_result(batch_data, batch_out, self.id2label_dict, interventer=self.custom)
if use_gpu:
batch_out = self.gpu_predictor.run([tensor_words])
else:
batch_out = self.cpu_predictor.run([tensor_words])
batch_result = parse_result(batch_data, batch_out[0], self.id2label_dict, interventer=self.custom)
results += batch_result results += batch_result
for index in empty_str_indexes: for index in empty_str_indexes:
...@@ -259,13 +309,8 @@ class LAC(hub.Module): ...@@ -259,13 +309,8 @@ class LAC(hub.Module):
return results return results
elif isinstance(text, str) and text != "": elif isinstance(text, str) and text != "":
tensor_words = self.texts2tensor([text]) batch_out = self._internal_predict(predictor, [text])
batch_result = parse_result([text], batch_out, self.id2label_dict, interventer=self.custom)
if use_gpu:
batch_out = self.gpu_predictor.run([tensor_words])
else:
batch_out = self.cpu_predictor.run([tensor_words])
batch_result = parse_result([text], batch_out[0], self.id2label_dict, interventer=self.custom)
return batch_result[0]['word'] return batch_result[0]['word']
elif text == "": elif text == "":
...@@ -273,7 +318,7 @@ class LAC(hub.Module): ...@@ -273,7 +318,7 @@ class LAC(hub.Module):
else: else:
raise TypeError("The input data is inconsistent with expectations.") raise TypeError("The input data is inconsistent with expectations.")
def lexical_analysis(self, texts=[], data={}, use_gpu=False, batch_size=1, return_tag=True): def lexical_analysis(self, texts=[], data={}, use_gpu=False, batch_size=1, return_tag=True, use_device=None):
""" """
Get the word segmentation results with the texts as input Get the word segmentation results with the texts as input
...@@ -283,19 +328,30 @@ class LAC(hub.Module): ...@@ -283,19 +328,30 @@ class LAC(hub.Module):
use_gpu(bool): whether use gpu to predict or not use_gpu(bool): whether use gpu to predict or not
batch_size(int): the program deals once with one batch batch_size(int): the program deals once with one batch
return_tag: Whether to get tag or not. return_tag: Whether to get tag or not.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
results(list): the word segmentation results results(list): the word segmentation results
""" """
if use_gpu: # real predictor to use
try: if use_device is not None:
_places = os.environ["CUDA_VISIBLE_DEVICES"] if use_device == "cpu":
int(_places[0]) predictor = self.cpu_predictor
except: elif use_device == "xpu":
raise RuntimeError( predictor = self.xpu_predictor
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id." elif use_device == "npu":
) predictor = self.npu_predictor
elif use_device == "gpu":
predictor = self.gpu_predictor
else:
raise Exception("Unsupported device: " + use_device)
else:
# use_device is not set, therefore follow use_gpu
if use_gpu:
predictor = self.gpu_predictor
else:
predictor = self.cpu_predictor
if texts != [] and isinstance(texts, list) and data == {}: if texts != [] and isinstance(texts, list) and data == {}:
predicted_data = texts predicted_data = texts
...@@ -320,13 +376,8 @@ class LAC(hub.Module): ...@@ -320,13 +376,8 @@ class LAC(hub.Module):
batch_data = predicted_data[start_idx:] batch_data = predicted_data[start_idx:]
start_idx = start_idx + batch_size start_idx = start_idx + batch_size
tensor_words = self.texts2tensor(batch_data) batch_out = self._internal_predict(predictor, batch_data)
batch_result = parse_result(batch_data, batch_out, self.id2label_dict, interventer=self.custom)
if use_gpu:
batch_out = self.gpu_predictor.run([tensor_words])
else:
batch_out = self.cpu_predictor.run([tensor_words])
batch_result = parse_result(batch_data, batch_out[0], self.id2label_dict, interventer=self.custom)
results += batch_result results += batch_result
for index in empty_str_indexes: for index in empty_str_indexes:
...@@ -344,8 +395,10 @@ class LAC(hub.Module): ...@@ -344,8 +395,10 @@ class LAC(hub.Module):
""" """
Run as a command Run as a command
""" """
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(description="Run the lac module.",
description="Run the lac module.", prog='hub run lac', usage='%(prog)s', add_help=True) prog='hub run lac',
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
...@@ -365,8 +418,11 @@ class LAC(hub.Module): ...@@ -365,8 +418,11 @@ class LAC(hub.Module):
if args.user_dict: if args.user_dict:
self.set_user_dict(args.user_dict) self.set_user_dict(args.user_dict)
results = self.lexical_analysis( results = self.lexical_analysis(texts=input_data,
texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size, return_tag=args.return_tag) use_gpu=args.use_gpu,
batch_size=args.batch_size,
return_tag=args.return_tag,
use_device=args.use_device)
return results return results
...@@ -388,17 +444,23 @@ class LAC(hub.Module): ...@@ -388,17 +444,23 @@ class LAC(hub.Module):
""" """
Add the command config options Add the command config options
""" """
self.arg_config_group.add_argument( self.arg_config_group.add_argument('--use_gpu',
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not") type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument('--batch_size', type=int, default=1, help="batch size for prediction") self.arg_config_group.add_argument('--batch_size', type=int, default=1, help="batch size for prediction")
self.arg_config_group.add_argument( self.arg_config_group.add_argument('--user_dict',
'--user_dict', type=str,
type=str, default=None,
default=None, help="customized dictionary for intervening the word segmentation result")
help="customized dictionary for intervening the word segmentation result") self.arg_config_group.add_argument('--return_tag',
self.arg_config_group.add_argument( type=ast.literal_eval,
'--return_tag', type=ast.literal_eval, default=True, help="whether return tags of results or not") default=True,
help="whether return tags of results or not")
self.arg_config_group.add_argument('--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
...@@ -251,8 +251,8 @@ def word_to_ids(words, word2id_dict, word_replace_dict, oov_id=None): ...@@ -251,8 +251,8 @@ def word_to_ids(words, word2id_dict, word_replace_dict, oov_id=None):
def parse_result(lines, crf_decode, id2label_dict, interventer=None): def parse_result(lines, crf_decode, id2label_dict, interventer=None):
"""Convert model's output tensor into string and tags """ """Convert model's output tensor into string and tags """
offset_list = crf_decode.lod[0] offset_list = crf_decode.lod()[0]
crf_decode = crf_decode.as_ndarray() crf_decode = crf_decode.copy_to_cpu()
batch_size = len(offset_list) - 1 batch_size = len(offset_list) - 1
batch_out = [] batch_out = []
for sent_index in range(batch_size): for sent_index in range(batch_size):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册