From 241712a41cd39ca1bf82bc0df034e081e53657fe Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Tue, 5 Jan 2021 20:15:18 +0800 Subject: [PATCH] Add static predict (#5147) * add static predict * rm position_encoding_init * fix static predict --- .../benchmark/transformer/static/predict.py | 121 ++++++++++++++++++ .../benchmark/transformer/static/train.py | 2 +- .../machine_translation/transformer/train.py | 2 +- .../transformers/transformer/modeling.py | 5 +- 4 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 PaddleNLP/benchmark/transformer/static/predict.py diff --git a/PaddleNLP/benchmark/transformer/static/predict.py b/PaddleNLP/benchmark/transformer/static/predict.py new file mode 100644 index 00000000..27ccf1d6 --- /dev/null +++ b/PaddleNLP/benchmark/transformer/static/predict.py @@ -0,0 +1,121 @@ +import os +import time +import sys + +import argparse +import logging +import numpy as np +import yaml +from attrdict import AttrDict +from pprint import pprint + +import paddle + +from paddlenlp.transformers import InferTransformerModel + +sys.path.append("../") +import reader + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="../configs/transformer.big.yaml", + type=str, + help="Path of the config file. ") + args = parser.parse_args() + return args + + +def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False): + """ + Post-process the decoded sequence. + """ + eos_pos = len(seq) - 1 + for i, idx in enumerate(seq): + if idx == eos_idx: + eos_pos = i + break + seq = [ + idx for idx in seq[:eos_pos + 1] + if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx) + ] + return seq + + +def do_train(args): + paddle.enable_static() + if args.use_gpu: + place = paddle.set_device("gpu:0") + else: + place = paddle.set_device("cpu") + + # Define data loader + test_loader, to_tokens = reader.create_infer_loader(args) + + test_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(test_program, startup_program): + src_word = paddle.static.data( + name="src_word", shape=[None, None], dtype="int64") + + # Define model + transformer = InferTransformerModel( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + n_layer=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + beam_size=args.beam_size, + max_out_len=args.max_out_len) + + finished_seq = transformer(src_word=src_word) + + test_program = test_program.clone(for_test=True) + + exe = paddle.static.Executor(place) + exe.run(startup_program) + + assert ( + args.init_from_params), "must set init_from_params to load parameters" + paddle.static.load(test_program, + os.path.join(args.init_from_params, "transformer"), exe) + print("finish initing model from params from %s" % (args.init_from_params)) + + f = open(args.output_file, "w") + for data in test_loader: + finished_sequence, = exe.run(test_program, + feed={'src_word': data[0]}, + fetch_list=finished_seq.name) + finished_sequence = finished_sequence.transpose([0, 2, 1]) + for ins in finished_sequence: + for beam_idx, beam in enumerate(ins): + if beam_idx >= args.n_best: + break + id_list = post_process_seq(beam, args.bos_idx, args.eos_idx) + word_list = to_tokens(id_list) + sequence = " ".join(word_list) + "\n" + f.write(sequence) + + paddle.disable_static() + + +if __name__ == "__main__": + ARGS = parse_args() + yaml_file = ARGS.config + with open(yaml_file, 'rt') as f: + args = AttrDict(yaml.safe_load(f)) + pprint(args) + + do_train(args) diff --git a/PaddleNLP/benchmark/transformer/static/train.py b/PaddleNLP/benchmark/transformer/static/train.py index 1b40bc01..8736d906 100644 --- a/PaddleNLP/benchmark/transformer/static/train.py +++ b/PaddleNLP/benchmark/transformer/static/train.py @@ -12,7 +12,7 @@ from pprint import pprint import paddle import paddle.distributed as dist -from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion, position_encoding_init +from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion sys.path.append("../") import reader diff --git a/PaddleNLP/examples/machine_translation/transformer/train.py b/PaddleNLP/examples/machine_translation/transformer/train.py index a8c9849d..0137968f 100644 --- a/PaddleNLP/examples/machine_translation/transformer/train.py +++ b/PaddleNLP/examples/machine_translation/transformer/train.py @@ -11,7 +11,7 @@ import paddle import paddle.distributed as dist import reader -from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion, position_encoding_init +from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion from paddlenlp.utils.log import logger diff --git a/PaddleNLP/paddlenlp/transformers/transformer/modeling.py b/PaddleNLP/paddlenlp/transformers/transformer/modeling.py index c69c2405..890189a3 100644 --- a/PaddleNLP/paddlenlp/transformers/transformer/modeling.py +++ b/PaddleNLP/paddlenlp/transformers/transformer/modeling.py @@ -167,7 +167,7 @@ class TransformerBeamSearchDecoder(nn.decode.BeamSearchDecoder): return c def _split_batch_beams_with_var_dim(self, c): - var_dim_size = c.shape[self.var_dim_in_state] + var_dim_size = paddle.shape(c)[self.var_dim_in_state] c = paddle.reshape( c, [-1, self.beam_size] + [int(size) @@ -374,6 +374,7 @@ class InferTransformerModel(TransformerModel): max_step_num=self.max_out_len, memory=enc_output, trg_src_attn_bias=trg_src_attn_bias, - static_cache=static_cache) + static_cache=static_cache, + is_test=True) return rs -- GitLab