# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import json
import argparse
import os

import numpy as np
import paddle
import paddlehub as hub
from paddlehub.module.module import runnable
from paddlehub.module.nlp_module import DataFormatError
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, serving
from paddlenlp.transformers import ErnieTokenizer, ErnieForGeneration

from ernie_gen_acrostic_poetry.decode import beam_search_infilling


@moduleinfo(
    name="ernie_gen_acrostic_poetry",
    version="1.1.0",
    summary=
    "ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has been fine-tuned for the acrostic poetry generation task.",
    author="adaxiadaxi",
    author_email="",
    type="nlp/text_generation",
)
class ErnieGen(hub.NLPPredictionModule):
    def __init__(self, line=4, word=7):
        """
        initialize with the necessary elements
        """
        if line not in [4, 8]:
            raise ValueError("The line could only be 4 or 8.")
        if word not in [5, 7]:
            raise ValueError("The word could only be 5 or 7.")

        self.line = line
        assets_path = os.path.join(self.directory, "assets")
        gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_acrostic_poetry_L%sW%s.pdparams" % (line, word))
        self.model = ErnieForGeneration.from_pretrained("ernie-1.0")
        model_state = paddle.load(gen_checkpoint_path)
        self.model.set_dict(model_state)
        self.tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
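        # Build an id -> token reverse lookup for decoding; [PAD] and [UNK] are mapped to empty
        # strings so that they disappear from the generated text.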
        self.rev_dict = self.tokenizer.vocab.idx_to_token
        self.rev_dict[self.tokenizer.vocab['[PAD]']] = ''  # replace [PAD]
        self.rev_dict[self.tokenizer.vocab['[UNK]']] = ''  # replace [UNK]
        self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i])

    @serving
    def generate(self, texts, use_gpu=False, beam_width=5):
        """
        Generate acrostic poems whose line heads are the given input characters.

        Args:
             texts(list): the acrostic heads; each string supplies the leading character of
                 each line, so at most `line` characters of it are used.
             use_gpu(bool): whether to use GPU for prediction.
             beam_width(int): the beam search width.

        Returns:
             results(list): the generated poems, one list of beam candidates per input text.
        """
        paddle.disable_static()

        if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]):
            predicted_data = texts
        else:
            raise ValueError("The input texts should be a list with nonempty string elements.")
        for i, text in enumerate(texts):
            if len(text) > self.line:
                logger.warning(
                    'The input text "%s" contains more than %i characters; the extra characters will be cut off.' % (text, self.line))
                texts[i] = text[:self.line]

            for char in text:
                if not '\u4e00' <= char <= '\u9fff':
                    logger.warning(
                        'The input text "%s" contains non-Chinese characters, which may result in unexpected output.' % text)
                    break

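        # Fall back to CPU when no CUDA device is exposed through CUDA_VISIBLE_DEVICES.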
        if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
            use_gpu = False
            logger.warning(
                "use_gpu has been set to False because the environment variable CUDA_VISIBLE_DEVICES is not set, so no GPU is visible to the process."
            )

        paddle.set_device('gpu' if use_gpu else 'cpu')

        self.model.eval()
        results = []
        for text in predicted_data:
            sample_results = []
            encode_text = self.tokenizer.encode(text)
            src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0)
            src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0)
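            # Beam-search decoding with the ERNIE-GEN infilling scheme; the special token ids
            # ([CLS]/[SEP]/[MASK]/[PAD]/[UNK]) are taken from the ERNIE vocabulary.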
            output_ids = beam_search_infilling(
                self.model,
                src_ids,
                src_sids,
                eos_id=self.tokenizer.vocab['[SEP]'],
                sos_id=self.tokenizer.vocab['[CLS]'],
                attn_id=self.tokenizer.vocab['[MASK]'],
                pad_id=self.tokenizer.vocab['[PAD]'],
                unk_id=self.tokenizer.vocab['[UNK]'],
                vocab_size=len(self.tokenizer.vocab),
                max_decode_len=80,
                max_encode_len=20,
                beam_width=beam_width,
                tgt_type_id=1)
            output_str = self.rev_lookup(output_ids[0])

            for ostr in output_str.tolist():
                if '[SEP]' in ostr:
                    ostr = ostr[:ostr.index('[SEP]')]
                sample_results.append("".join(ostr))
            results.append(sample_results)
        return results

    def add_module_config_arg(self):
        """
        Add the command config options
        """
        self.arg_config_group.add_argument(
            '--use_gpu', type=ast.literal_eval, default=False, help="whether to use GPU for prediction")

        self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")

    @runnable
    def run_cmd(self, argvs):
        """
        Run as a command
        """
        self.parser = argparse.ArgumentParser(
            description='Run the %s module.' % self.name,
            prog='hub run %s' % self.name,
            usage='%(prog)s',
            add_help=True)

        self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options", description="Run configuration for controlling module behavior, optional.")

        self.add_module_config_arg()
        self.add_module_input_arg()

        args = self.parser.parse_args(argvs)

        try:
            input_data = self.check_input_data(args)
        except (DataFormatError, RuntimeError):
            self.parser.print_help()
            return None

        results = self.generate(texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)

        return results


if __name__ == "__main__":
    module = ErnieGen()
    for result in module.generate(['夏雨荷', '我喜欢你'], beam_width=5):
        print(result)
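
# Once installed through PaddleHub, the module can also be driven from the command line or run
# as a service. A hedged sketch of both (the --input_text flag assumes the default input
# argument provided by hub.NLPPredictionModule):
#   hub run ernie_gen_acrostic_poetry --input_text="夏雨荷" --use_gpu False --beam_width 5
#   hub serving start -m ernie_gen_acrostic_poetry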