generate_sample.py 3.1 KB
Newer Older
Z
Zhong Hui 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Many thanks for following projects.
# https://github.com/TsinghuaAI/CPM-Generate
# https://github.com/jm12138/CPM-Generate-Paddle

import argparse
import numpy as np

import paddle
from paddlenlp.utils.tools import loadz
from paddlenlp.transformers import GPT2Model, GPT2ForPretraining
from paddlenlp.transformers import GPT2ChineseTokenizer, GPT2Tokenizer
from paddlenlp.utils.log import logger

MODEL_CLASSES = {
    "gpt2-base-cn": (GPT2ForPretraining, GPT2ChineseTokenizer),
    "gpt2-medium-en": (GPT2ForPretraining, GPT2Tokenizer),
}


class Demo:
    def __init__(self, model_name_or_path="gpt2-base-cn"):
        model_class, tokenizer_class = MODEL_CLASSES[model_name_or_path]
        self.tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
        logger.info('Loading the model parameters, please wait...')
        self.model = model_class.from_pretrained(model_name_or_path)
        self.model.eval()
        logger.info('Model loaded.')

    # prediction function
    def predict(self, text, max_len=10):
        ids = self.tokenizer.encode(text)
        input_id = paddle.to_tensor(
            np.array(ids).reshape(1, -1).astype('int64'))
        output, cached_kvs = self.model(input_id, use_cache=True, cache=None)
        nid = int(np.argmax(output[0, -1].numpy()))
        ids.append(nid)
        out = [nid]
        for i in range(max_len):
            input_id = paddle.to_tensor(
                np.array([nid]).reshape(1, -1).astype('int64'))
            output, cached_kvs = self.model(
                input_id, use_cache=True, cache=cached_kvs)
            nid = int(np.argmax(output[0, -1].numpy()))
            ids.append(nid)
            # if nid is '\n', the predicion is over.
            if nid == 3:
                break
            out.append(nid)
        logger.info(text)
        logger.info(self.tokenizer.decode(out))

    # One shot example
    def ask_question(self, question, max_len=10):
        self.predict("问题:中国的首都是哪里?答案:北京。\n问题:%s 答案:" % question, max_len)

    # dictation poetry
    def dictation_poetry(self, front, max_len=10):
        self.predict('''默写古诗: 大漠孤烟直,长河落日圆。\n%s''' % front, max_len)


if __name__ == "__main__":
    demo = Demo("gpt2-base-cn")
    demo.ask_question("百度的厂长是谁?")
    demo.dictation_poetry("举杯邀明月,")
    del demo
    # demo = Demo("gpt2-medium-en")
    # demo.predict("Question: Where is the capital of China? Answer: Beijing. \nQuestion: Who is the CEO of Apple? Answer:", 20)