diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py
index be6d8ff250c2694f57a980cfb734c892b6c96fc0..dbcd29f7e2cdd80fc1b2f20f0b0b047c27160a83 100644
--- a/paddlehub/__init__.py
+++ b/paddlehub/__init__.py
@@ -64,4 +64,4 @@ from .finetune.strategy import CombinedStrategy
 from .autofinetune.evaluator import report_final_result
-from .module.nlp_module import BERTModule
+from .module.nlp_module import NLPPredictionModule, TransformerModule
diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py
index ec20133523f15267d28476889b7555fb9c803340..17bfa6bfd0adca13c240ee39285db5d8e0062c63 100644
--- a/paddlehub/module/module.py
+++ b/paddlehub/module/module.py
@@ -1,4 +1,4 @@
-#coding:utf-8
+# coding:utf-8
 # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
@@ -135,19 +135,29 @@ class Module(object):
         if "_is_initialize" in self.__dict__ and self._is_initialize:
             return
-        mod = self.__class__.__module__ + "." + self.__class__.__name__
-        if mod in _module_runnable_func:
-            _run_func_name = _module_runnable_func[mod]
-            self._run_func = getattr(self, _run_func_name)
-        else:
-            self._run_func = None
-        self._serving_func_name = _module_serving_func.get(mod, None)
-        self._code_version = "v2"
+        _run_func_name = self._get_func_name(self.__class__,
+                                             _module_runnable_func)
+        # _get_func_name may return None if no @runnable method is registered.
+        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
+        self._serving_func_name = self._get_func_name(self.__class__,
+                                                      _module_serving_func)
         self._directory = directory
         self._initialize(**kwargs)
         self._is_initialize = True
         self._code_version = "v2"
+
+    def _get_func_name(self, current_cls, module_func_dict):
+        mod = current_cls.__module__ + "." + current_cls.__name__
+        if mod in module_func_dict:
+            _func_name = module_func_dict[mod]
+            return _func_name
+        # Search every base class, not just the first one, for a match.
+        for base_class in current_cls.__bases__:
+            _func_name = self._get_func_name(base_class, module_func_dict)
+            if _func_name is not None:
+                return _func_name
+        return None
+
     @classmethod
     def init_with_name(cls, name, version=None, **kwargs):
         fp_lock = open(os.path.join(CACHE_HOME, name), "a")
diff --git a/paddlehub/module/nlp_module.py b/paddlehub/module/nlp_module.py
index d162c938abfb0bdf5da1e1e81c9caef3bcc0978a..632dbfbc8dc4d8f6c52fdd869755035b0f240a9b 100644
--- a/paddlehub/module/nlp_module.py
+++ b/paddlehub/module/nlp_module.py
@@ -1,4 +1,4 @@
-#coding:utf-8
+# coding:utf-8
 # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
@@ -17,15 +17,198 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import ast
+import json
 import os
 import re
+import six
 
-import paddlehub as hub
+import numpy as np
 import paddle.fluid as fluid
-from paddlehub import logger
+from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+
+import paddlehub as hub
+from paddlehub.common.logger import logger
+from paddlehub.common.utils import sys_stdin_encoding
+from paddlehub.io.parser import txt_parser
+from paddlehub.module.module import runnable
+
+
+class DataFormatError(Exception):
+    def __init__(self, *args):
+        self.args = args
+
+
+class NLPBaseModule(hub.Module):
+    def _initialize(self):
+        """
+        Initialize with the necessary elements.
+        This method must be overridden.
+        """
+        raise NotImplementedError()
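
The `_get_func_name` helper added to `module.py` above replaces the old single-class registry lookup with a recursive walk over base classes, so a subclass inherits `@runnable`/`@serving` registrations made on an ancestor — this is what lets modules built on `NLPPredictionModule` reuse its `run_cmd`. A minimal standalone sketch of that resolution logic; the registry contents below are made up for the demo (PaddleHub's real tables are `_module_runnable_func` and `_module_serving_func`):

```python
# Registry keyed by "<module>.<classname>"; this entry is illustrative only.
_registry = {"__main__.Base": "run_cmd"}


def get_func_name(cls, registry):
    """Mirror of Module._get_func_name: check cls, then recurse into bases."""
    key = cls.__module__ + "." + cls.__name__
    if key in registry:
        return registry[key]
    for base in cls.__bases__:  # object.__bases__ == (), so recursion stops
        name = get_func_name(base, registry)
        if name is not None:
            return name
    return None


class Base(object):
    def run_cmd(self, argvs):
        return "ran %s" % argvs


class Child(Base):  # not registered itself, resolves through Base
    pass


print(get_func_name(Child, _registry))  # -> "run_cmd"
```
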
+
+    def get_vocab_path(self):
+        """
+        Get the path to the vocabulary file that was used for pretraining.
+
+        Returns:
+            self.vocab_path(str): the path to the vocabulary file
+        """
+        return self.vocab_path
+
+
+class NLPPredictionModule(NLPBaseModule):
+    def _set_config(self):
+        """
+        Set up the CPU predictor and, when CUDA is available, a GPU predictor.
+        """
+        cpu_config = AnalysisConfig(self.pretrained_model_path)
+        cpu_config.disable_glog_info()
+        cpu_config.disable_gpu()
+        self.cpu_predictor = create_paddle_predictor(cpu_config)
+
+        try:
+            _places = os.environ["CUDA_VISIBLE_DEVICES"]
+            int(_places[0])
+            use_gpu = True
+        except Exception:
+            # CUDA_VISIBLE_DEVICES is unset or malformed; fall back to CPU.
+            use_gpu = False
+        if use_gpu:
+            gpu_config = AnalysisConfig(self.pretrained_model_path)
+            gpu_config.disable_glog_info()
+            gpu_config.enable_use_gpu(
+                memory_pool_init_size_mb=500, device_id=0)
+            self.gpu_predictor = create_paddle_predictor(gpu_config)
+
+    def texts2tensor(self, texts):
+        """
+        Transform the input texts into a PaddleTensor.
+
+        Args:
+            texts(list): each element is a dict that must have a 'processed'
+                key whose value is a list of word ids, such as
+                texts = [{'processed': [23, 89, 43, 906]}]
+
+        Returns:
+            tensor(PaddleTensor): tensor with the texts data
+        """
+        lod = [0]
+        data = []
+        for i, text in enumerate(texts):
+            data += text['processed']
+            lod.append(len(text['processed']) + lod[i])
+        tensor = PaddleTensor(np.array(data).astype('int64'))
+        tensor.name = "words"
+        tensor.lod = [lod]
+        tensor.shape = [lod[-1], 1]
+        return tensor
+
+    def to_unicode(self, texts):
+        """
+        Convert each element of texts(list) from str to unicode in Python 2.7.
+
+        Args:
+            texts(list): each element's type is str in Python 2.7
+
+        Returns:
+            texts(list): each element's type is unicode in Python 2.7
+        """
+        if six.PY2:
+            unicode_texts = []
+            for text in texts:
+                if not isinstance(text, six.text_type):
+                    unicode_texts.append(text.decode(sys_stdin_encoding()))
+                else:
+                    unicode_texts.append(text)
+            texts = unicode_texts
+        return texts
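
The `texts2tensor` helper above flattens variable-length samples into one id buffer and records the sample boundaries as LoD (level-of-detail) offsets. A quick sketch of just that offset arithmetic, runnable without the Paddle runtime:

```python
# Same offset construction as texts2tensor: lod[i + 1] marks where sample i
# ends in the flattened buffer, so sample i spans data[lod[i]:lod[i + 1]].
texts = [{'processed': [23, 89, 43, 906]}, {'processed': [5, 7]}]

lod = [0]
data = []
for i, text in enumerate(texts):
    data += text['processed']
    lod.append(len(text['processed']) + lod[i])

print(data)  # [23, 89, 43, 906, 5, 7]
print(lod)   # [0, 4, 6] -> sample 0 is data[0:4], sample 1 is data[4:6]
```
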
+
+    @runnable
+    def run_cmd(self, argvs):
+        """
+        Run as a command.
+        """
+        self.parser = argparse.ArgumentParser(
+            description='Run the %s module.' % self.module_name,
+            prog='hub run %s' % self.module_name,
+            usage='%(prog)s',
+            add_help=True)
+
+        self.arg_input_group = self.parser.add_argument_group(
+            title="Input options", description="Input data. Required")
+        self.arg_config_group = self.parser.add_argument_group(
+            title="Config options",
+            description=
+            "Run configuration for controlling module behavior, not required.")
+
+        self.add_module_config_arg()
+        self.add_module_input_arg()
+
+        args = self.parser.parse_args(argvs)
+
+        try:
+            input_data = self.check_input_data(args)
+        except (DataFormatError, RuntimeError):
+            self.parser.print_help()
+            return None
+        results = self.predict(
+            texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)
+
+        return results
 
-class _BERTEmbeddingTask(hub.BaseTask):
+    def add_module_config_arg(self):
+        """
+        Add the command config options.
+        """
+        self.arg_config_group.add_argument(
+            '--use_gpu',
+            type=ast.literal_eval,
+            default=False,
+            help="whether to use GPU for prediction")
+
+        self.arg_config_group.add_argument(
+            '--batch_size',
+            type=int,
+            default=1,
+            help="batch size for prediction")
+
+    def add_module_input_arg(self):
+        """
+        Add the command input options.
+        """
+        self.arg_input_group.add_argument(
+            '--input_file',
+            type=str,
+            default=None,
+            help="file containing the input data")
+        self.arg_input_group.add_argument(
+            '--input_text', type=str, default=None, help="text to predict")
+
+    def check_input_data(self, args):
+        input_data = []
+        if args.input_file:
+            if not os.path.exists(args.input_file):
+                print("File %s does not exist." % args.input_file)
+                raise RuntimeError
+            else:
+                input_data = txt_parser.parse(args.input_file, use_strip=True)
+        elif args.input_text:
+            if args.input_text.strip() != '':
+                if six.PY2:
+                    input_data = [
+                        args.input_text.decode(sys_stdin_encoding())
+                    ]
+                else:
+                    input_data = [args.input_text]
+            else:
+                print(
+                    "ERROR: The input data is inconsistent with expectations.")
+
+        if input_data == []:
+            print("ERROR: The input data is inconsistent with expectations.")
+            raise DataFormatError
+
+        return input_data
+
+
+class _TransformerEmbeddingTask(hub.BaseTask):
     def __init__(self,
                  pooled_feature,
                  seq_feature,
@@ -33,7 +216,7 @@ class _BERTEmbeddingTask(hub.BaseTask):
                  feed_list,
                  data_reader,
                  config=None):
         main_program = pooled_feature.block.program
-        super(_BERTEmbeddingTask, self).__init__(
+        super(_TransformerEmbeddingTask, self).__init__(
             main_program=main_program,
             data_reader=data_reader,
             feed_list=feed_list,
@@ -57,21 +240,10 @@ class _BERTEmbeddingTask(hub.BaseTask):
         return results
 
-class BERTModule(hub.Module):
-    def _initialize(self):
-        """
-        Must override this method.
-
-        some member variables are required, others are optional.
-        """
-        # required config
-        self.MAX_SEQ_LEN = None
-        self.params_path = None
-        self.vocab_path = None
-        # optional config
-        self.spm_path = None
-        self.word_dict_path = None
-        raise NotImplementedError
+class TransformerModule(NLPBaseModule):
+    """
+    Transformer module base class, used by BERT, ERNIE, RoBERTa and so on.
+    """
 
     def init_pretraining_params(self, exe, pretraining_params_path,
                                 main_program):
@@ -157,7 +329,6 @@ class TransformerModule(NLPBaseModule):
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
-        exe.run(startup_program)
 
         self.init_pretraining_params(
             exe, self.params_path, main_program=startup_program)
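
The removed `BERTModule._initialize` documented the member variables a concrete module is expected to set (required: `MAX_SEQ_LEN`, `params_path`, `vocab_path`; optional: `spm_path`, `word_dict_path`); under the new hierarchy that contract comes from `NLPBaseModule._initialize` raising `NotImplementedError`. A hypothetical subclass might look like the sketch below — the class name, asset paths, and the use of the module's `directory` property are illustrative assumptions, not part of this PR:

```python
import os

import paddlehub as hub


class MyTransformerModule(hub.TransformerModule):
    """Illustrative TransformerModule subclass, not a real PaddleHub module."""

    def _initialize(self):
        # required config (the fields the old BERTModule listed)
        self.MAX_SEQ_LEN = 128
        self.params_path = os.path.join(self.directory, "assets", "params")
        self.vocab_path = os.path.join(self.directory, "assets", "vocab.txt")
        # optional config
        self.spm_path = None
        self.word_dict_path = None
```
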
@@ -176,7 +347,7 @@ class TransformerModule(NLPBaseModule):
     def get_embedding(self, texts, use_gpu=False, batch_size=1):
         """
         Get pooled_output and sequence_output for the input texts.
-        Warnings: this method depends on Paddle Inference Library, it may not work properly in PaddlePaddle < 1.6.2.
+        Warning: this method depends on the Paddle Inference Library; it may not work properly in PaddlePaddle <= 1.6.2.
         Args:
             texts (list): each element is a text sample, and each sample includes text_a and text_b, where text_b can be omitted.
@@ -220,7 +391,7 @@ class TransformerModule(NLPBaseModule):
             batch_size=batch_size)
 
         self.emb_job = {}
-        self.emb_job["task"] = _BERTEmbeddingTask(
+        self.emb_job["task"] = _TransformerEmbeddingTask(
             pooled_feature=pooled_feature,
             seq_feature=seq_feature,
             feed_list=feed_list,
@@ -233,9 +404,6 @@ class TransformerModule(NLPBaseModule):
         return self.emb_job["task"].predict(
             data=texts, return_result=True, accelerate_mode=True)
 
-    def get_vocab_path(self):
-        return self.vocab_path
-
     def get_spm_path(self):
         if hasattr(self, "spm_path"):
             return self.spm_path
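
Taken together, any `TransformerModule` subclass now exposes `get_embedding` directly. A hedged usage sketch — the module name `ernie` is an assumed example, and the exact shape of each result depends on the underlying model:

```python
import paddlehub as hub

# "ernie" is an assumed example name; any TransformerModule-based module works.
module = hub.Module(name="ernie")

# Each sample is [text_a] or [text_a, text_b]; text_b can be omitted.
texts = [["Today is a nice day."],
         ["This is the premise.", "This is the hypothesis."]]

results = module.get_embedding(texts=texts, use_gpu=False, batch_size=2)
# results carries the pooled_output and sequence_output for each sample.
print(len(results))
```
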