interaction.py
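"""Interactive command-line demo for PLATO-2 dialogue inference."""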
import json
import argparse
from collections import namedtuple
from termcolor import colored, cprint

import paddle

from utils.args import parse_args, str2bool
from utils import gen_inputs
from readers.nsp_reader import NSPReader
from readers.plato_reader import PlatoReader
from model import Plato2InferModel


def setup_args():
    """Setup arguments."""
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("Model")
    group.add_argument("--init_from_ckpt", type=str, default="")
    group.add_argument("--vocab_size", type=int, default=8001)
    group.add_argument("--latent_type_size", type=int, default=20)
    group.add_argument("--num_layers", type=int, default=24)

    group = parser.add_argument_group("Task")
    group.add_argument("--is_cn", type=str2bool, default=False)

    args, _ = parser.parse_known_args()
    # The NSP reader (used to rank candidate responses) registers its own flags.
    NSPReader.add_cmdline_args(parser)
    args = parse_args(parser)
    # One candidate response is decoded per latent type, so widen the batch.
    args.batch_size *= args.latent_type_size

    # print(json.dumps(args, indent=2))
    return args


def load_params(model, init_from_ckpt):
    """Load pretrained parameters from a checkpoint file into the model."""
    state_dict = paddle.load(init_from_ckpt)
    model.set_state_dict(state_dict)


def interact(args):
    """Inference main function."""
    plato_reader = PlatoReader(args)
    nsp_reader = NSPReader(args)

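    # Hidden size and number of attention heads for the two released
    # PLATO-2 configurations (24-layer and 32-layer).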
    if args.num_layers == 24:
        n_head = 16
        hidden_size = 1024
    elif args.num_layers == 32:
        n_head = 32
        hidden_size = 2048
    else:
        raise ValueError('The pre-trained model only supports 24 or 32 layers, '
                         'but received num_layers=%d.' % args.num_layers)

    model = Plato2InferModel(nsp_reader, args.num_layers, n_head, hidden_size)
    load_params(model, args.init_from_ckpt)
    model.eval()

    Example = namedtuple("Example", ["src", "data_id"])
    context = []
    start_info = "Enter [EXIT] to quit the interaction, [NEXT] to start a new conversation."
    cprint(start_info, "yellow", attrs=["bold"])
    while True:
        user_utt = input(colored("[Human]: ", "red", attrs=["bold"])).strip()
        if user_utt == "[EXIT]":
            break
        elif user_utt == "[NEXT]":
            context = []
            cprint(start_info, "yellow", attrs=["bold"])
        else:
            context.append(user_utt)
            example = Example(src=" [SEP] ".join(context), data_id=0)
            record = plato_reader._convert_example_to_record(
                example, is_infer=True)
            data = plato_reader._pad_batch_records([record], is_infer=True)
            inputs = gen_inputs(data, args.latent_type_size)
            # Take the top-ranked candidate and show it as the bot's reply.
            pred = model(inputs)[0]
            bot_response = pred["response"]
            print(
                colored("[Bot]:", "blue", attrs=["bold"]),
                colored(bot_response, attrs=["bold"]))
            context.append(bot_response)
    return


if __name__ == "__main__":
    args = setup_args()
    interact(args)
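
# Example invocation (the checkpoint path below is illustrative; any
# reader-specific flags, e.g. the vocabulary file, are registered by
# NSPReader.add_cmdline_args and must also be supplied):
#   python interaction.py --init_from_ckpt ./24L.pdparams --num_layers 24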