From b0d82b3de8ffaa03d4c177ddf07eef5dbc5cb6c0 Mon Sep 17 00:00:00 2001
From: houj04 <35131887+houj04@users.noreply.github.com>
Date: Mon, 13 Sep 2021 19:13:37 +0800
Subject: [PATCH] lac support npu and xpu (#1613)

---
 modules/text/lexical_analysis/lac/module.py    | 234 +++++++++++-------
 .../text/lexical_analysis/lac/processor.py     |   4 +-
 2 files changed, 150 insertions(+), 88 deletions(-)

diff --git a/modules/text/lexical_analysis/lac/module.py b/modules/text/lexical_analysis/lac/module.py
index fb460ba5..40136fe6 100644
--- a/modules/text/lexical_analysis/lac/module.py
+++ b/modules/text/lexical_analysis/lac/module.py
@@ -13,7 +13,10 @@
 import six
 import numpy as np
 import paddle.fluid as fluid
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
 import paddlehub as hub
 from paddlehub.common.logger import logger
 from paddlehub.common.paddle_helper import add_vars_prefix
@@ -62,26 +65,86 @@ class LAC(hub.Module):
 
         self._set_config()
 
+    def _get_device_id(self, places):
+        try:
+            places = os.environ[places]
+            id = int(places)
+        except Exception:
+            id = -1
+        return id
+
     def _set_config(self):
         """
         predictor config setting
         """
-        cpu_config = AnalysisConfig(self.pretrained_model_path)
+
+        # create default cpu predictor
+        cpu_config = Config(self.pretrained_model_path)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
-
-        try:
-            _places = os.environ["CUDA_VISIBLE_DEVICES"]
-            int(_places[0])
-            use_gpu = True
-        except:
-            use_gpu = False
-        if use_gpu:
-            gpu_config = AnalysisConfig(self.pretrained_model_path)
+        self.cpu_predictor = create_predictor(cpu_config)
+
+        # create predictors using various types of devices
+
+        # npu
+        npu_id = self._get_device_id("FLAGS_selected_npus")
+        if npu_id != -1:
+            # use npu
+            npu_config = Config(self.pretrained_model_path)
+            npu_config.disable_glog_info()
+            npu_config.enable_npu(device_id=npu_id)
+            self.npu_predictor = create_predictor(npu_config)
+
+        # gpu
+        gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
+        if gpu_id != -1:
+            # use gpu
+            gpu_config = Config(self.pretrained_model_path)
             gpu_config.disable_glog_info()
-            gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
+            gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=gpu_id)
+            self.gpu_predictor = create_predictor(gpu_config)
+
+        # xpu
+        xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
+        if xpu_id != -1:
+            # use xpu
+            xpu_config = Config(self.pretrained_model_path)
+            xpu_config.disable_glog_info()
+            xpu_config.enable_xpu(100)
+            self.xpu_predictor = create_predictor(xpu_config)
+
+    def _internal_predict(self, predictor, texts):
+        """
+        Transform the texts(list) to Tensor and then do the real prediction
+        Args:
+             texts(list): texts
+        Returns:
+             result(PaddleInferTensor): predict output
+        """
+
+        # texts to data and lod
+        lod = [0]
+        data = []
+        for i, text in enumerate(texts):
+            text_inds = word_to_ids(text, self.word2id_dict, self.word_replace_dict, oov_id=self.oov_id)
+            data += text_inds
+            lod.append(len(text_inds) + lod[i])
+
+        # get predictor tensor
+        input_names = predictor.get_input_names()
+        input_tensor = predictor.get_input_handle(input_names[0])
+
+        # set data, shape and lod
+        input_tensor.copy_from_cpu(np.array(data).astype('int64'))
+        input_tensor.reshape([lod[-1], 1])
+        input_tensor.set_lod([lod])
+
+        # real predict
+        predictor.run()
+        output_names = predictor.get_output_names()
+        output_handle = predictor.get_output_handle(output_names[0])
+
+        return output_handle
 
     def context(self, trainable=False):
         """
@@ -167,26 +230,6 @@ class LAC(hub.Module):
             texts = unicode_texts
         return texts
 
-    def texts2tensor(self, texts):
-        """
-        Tranform the texts(list) to PaddleTensor
-        Args:
-             texts(list): texts
-        Returns:
-             tensor(PaddleTensor): tensor with texts data
-        """
-        lod = [0]
-        data = []
-        for i, text in enumerate(texts):
-            text_inds = word_to_ids(text, self.word2id_dict, self.word_replace_dict, oov_id=self.oov_id)
-            data += text_inds
-            lod.append(len(text_inds) + lod[i])
-        tensor = PaddleTensor(np.array(data).astype('int64'))
-        tensor.name = "words"
-        tensor.lod = [lod]
-        tensor.shape = [lod[-1], 1]
-        return tensor
-
     def _get_index(self, data_list, item=""):
         """
         find all indexes of item in data_list
@@ -198,7 +241,7 @@ class LAC(hub.Module):
         return res
 
     @serving
-    def cut(self, text, use_gpu=False, batch_size=1, return_tag=True):
+    def cut(self, text, use_gpu=False, batch_size=1, return_tag=True, use_device=None):
         """
         The main function that segments an entire text that contains
         Chinese characters into separated words.
@@ -207,20 +250,32 @@ class LAC(hub.Module):
             use_gpu(bool): whether use gpu to predict or not
             batch_size(int): the program deals once with one batch
             return_tag: Whether to get tag or not.
+            use_device (str): use cpu, gpu, xpu or npu; overrides the use_gpu flag.
 
         Returns:
             results(dict or list): The word segmentation result of the input text, whose key is 'word', if text is a list.
                 If text is a str, the word segmentation result (list) is obtained.
 
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
+
+        # select the predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         if isinstance(text, list) and len(text) != 0:
 
@@ -240,13 +295,8 @@ class LAC(hub.Module):
                 batch_data = predicted_data[start_idx:]
             start_idx = start_idx + batch_size
 
-            tensor_words = self.texts2tensor(batch_data)
-
-            if use_gpu:
-                batch_out = self.gpu_predictor.run([tensor_words])
-            else:
-                batch_out = self.cpu_predictor.run([tensor_words])
-            batch_result = parse_result(batch_data, batch_out[0], self.id2label_dict, interventer=self.custom)
+            batch_out = self._internal_predict(predictor, batch_data)
+            batch_result = parse_result(batch_data, batch_out, self.id2label_dict, interventer=self.custom)
             results += batch_result
 
         for index in empty_str_indexes:
@@ -259,13 +309,8 @@ class LAC(hub.Module):
             return results
 
         elif isinstance(text, str) and text != "":
-            tensor_words = self.texts2tensor([text])
-
-            if use_gpu:
-                batch_out = self.gpu_predictor.run([tensor_words])
-            else:
-                batch_out = self.cpu_predictor.run([tensor_words])
-            batch_result = parse_result([text], batch_out[0], self.id2label_dict, interventer=self.custom)
+            batch_out = self._internal_predict(predictor, [text])
+            batch_result = parse_result([text], batch_out, self.id2label_dict, interventer=self.custom)
 
             return batch_result[0]['word']
         elif text == "":
@@ -273,7 +318,7 @@ class LAC(hub.Module):
         else:
             raise TypeError("The input data is inconsistent with expectations.")
 
-    def lexical_analysis(self, texts=[], data={}, use_gpu=False, batch_size=1, return_tag=True):
+    def lexical_analysis(self, texts=[], data={}, use_gpu=False, batch_size=1, return_tag=True, use_device=None):
         """
         Get the word segmentation results with the texts as input
 
@@ -283,19 +328,30 @@ class LAC(hub.Module):
             use_gpu(bool): whether use gpu to predict or not
             batch_size(int): the program deals once with one batch
             return_tag: Whether to get tag or not.
+            use_device (str): use cpu, gpu, xpu or npu; overrides the use_gpu flag.
 
         Returns:
             results(list): the word segmentation results
 
         """
-        if use_gpu:
-            try:
-                _places = os.environ["CUDA_VISIBLE_DEVICES"]
-                int(_places[0])
-            except:
-                raise RuntimeError(
-                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
-                )
+        # select the predictor to use
+        if use_device is not None:
+            if use_device == "cpu":
+                predictor = self.cpu_predictor
+            elif use_device == "xpu":
+                predictor = self.xpu_predictor
+            elif use_device == "npu":
+                predictor = self.npu_predictor
+            elif use_device == "gpu":
+                predictor = self.gpu_predictor
+            else:
+                raise Exception("Unsupported device: " + use_device)
+        else:
+            # use_device is not set, therefore follow use_gpu
+            if use_gpu:
+                predictor = self.gpu_predictor
+            else:
+                predictor = self.cpu_predictor
 
         if texts != [] and isinstance(texts, list) and data == {}:
             predicted_data = texts
@@ -320,13 +376,8 @@ class LAC(hub.Module):
                 batch_data = predicted_data[start_idx:]
             start_idx = start_idx + batch_size
 
-            tensor_words = self.texts2tensor(batch_data)
-
-            if use_gpu:
-                batch_out = self.gpu_predictor.run([tensor_words])
-            else:
-                batch_out = self.cpu_predictor.run([tensor_words])
-            batch_result = parse_result(batch_data, batch_out[0], self.id2label_dict, interventer=self.custom)
+            batch_out = self._internal_predict(predictor, batch_data)
+            batch_result = parse_result(batch_data, batch_out, self.id2label_dict, interventer=self.custom)
             results += batch_result
 
         for index in empty_str_indexes:
@@ -344,8 +395,10 @@ class LAC(hub.Module):
         """
         Run as a command
         """
-        self.parser = argparse.ArgumentParser(
-            description="Run the lac module.", prog='hub run lac', usage='%(prog)s', add_help=True)
+        self.parser = argparse.ArgumentParser(description="Run the lac module.",
+                                              prog='hub run lac',
+                                              usage='%(prog)s',
+                                              add_help=True)
 
         self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
         self.arg_config_group = self.parser.add_argument_group(
@@ -365,8 +418,11 @@ class LAC(hub.Module):
         if args.user_dict:
             self.set_user_dict(args.user_dict)
 
-        results = self.lexical_analysis(
-            texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size, return_tag=args.return_tag)
+        results = self.lexical_analysis(texts=input_data,
+                                        use_gpu=args.use_gpu,
+                                        batch_size=args.batch_size,
+                                        return_tag=args.return_tag,
+                                        use_device=args.use_device)
 
         return results
 
@@ -388,17 +444,23 @@ class LAC(hub.Module):
         """
         Add the command config options
         """
-        self.arg_config_group.add_argument(
-            '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
+        self.arg_config_group.add_argument('--use_gpu',
+                                           type=ast.literal_eval,
+                                           default=False,
+                                           help="whether use GPU or not")
 
         self.arg_config_group.add_argument('--batch_size', type=int, default=1, help="batch size for prediction")
-        self.arg_config_group.add_argument(
-            '--user_dict',
-            type=str,
-            default=None,
-            help="customized dictionary for intervening the word segmentation result")
-        self.arg_config_group.add_argument(
-            '--return_tag', type=ast.literal_eval, default=True, help="whether return tags of results or not")
+        self.arg_config_group.add_argument('--user_dict',
+                                           type=str,
+                                           default=None,
+                                           help="customized dictionary for intervening the word segmentation result")
+        self.arg_config_group.add_argument('--return_tag',
+                                           type=ast.literal_eval,
+                                           default=True,
+                                           help="whether return tags of results or not")
+        self.arg_config_group.add_argument('--use_device',
+                                           choices=["cpu", "gpu", "xpu", "npu"],
+                                           help="use cpu, gpu, xpu or npu. Overrides the use_gpu flag.")
 
     def add_module_input_arg(self):
         """
diff --git a/modules/text/lexical_analysis/lac/processor.py b/modules/text/lexical_analysis/lac/processor.py
index 6ad9d661..1521182e 100644
--- a/modules/text/lexical_analysis/lac/processor.py
+++ b/modules/text/lexical_analysis/lac/processor.py
@@ -251,8 +251,8 @@ def word_to_ids(words, word2id_dict, word_replace_dict, oov_id=None):
 def parse_result(lines, crf_decode, id2label_dict, interventer=None):
     """Convert model's output tensor into string and tags
     """
-    offset_list = crf_decode.lod[0]
-    crf_decode = crf_decode.as_ndarray()
+    offset_list = crf_decode.lod()[0]
+    crf_decode = crf_decode.copy_to_cpu()
     batch_size = len(offset_list) - 1
     batch_out = []
    for sent_index in range(batch_size):
-- 
GitLab
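
A minimal usage sketch for the new use_device switch follows. It is an illustration under stated assumptions, not part of the patch: a PaddlePaddle build with the matching NPU/XPU/GPU support, and the corresponding environment variable (FLAGS_selected_npus, XPU_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES) exported as a single integer device id before the module is created, since _set_config() only builds a device predictor when _get_device_id() can parse that variable. The sample sentences are arbitrary.

    # Usage sketch; the NPU build and single visible device are assumptions.
    import os
    os.environ["FLAGS_selected_npus"] = "0"  # must be set before hub.Module() runs

    import paddlehub as hub

    lac = hub.Module(name="lac")

    # use_device overrides use_gpu; omit it to fall back to the use_gpu flag
    print(lac.cut(text="今天是个好日子", use_device="npu"))

    # batch interface with the same switch
    print(lac.lexical_analysis(texts=["今天是个好日子", "天气预报说今天要下雨"],
                               batch_size=2,
                               use_device="npu"))

Because the predictors are created once in _set_config() at load time, passing a use_device whose environment variable was unset when the module was instantiated refers to a predictor attribute that was never created; the device environment therefore has to be prepared before loading the module, not at call time.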