未验证 提交 241712a4 编写于 作者: L liu zhengxi 提交者: GitHub

Add static predict (#5147)

* add static predict

* rm position_encoding_init

* fix static predict
上级 9ddc050e
import os
import time
import sys
import argparse
import logging
import numpy as np
import yaml
from attrdict import AttrDict
from pprint import pprint
import paddle
from paddlenlp.transformers import InferTransformerModel
sys.path.append("../")
import reader
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: Holds a single attribute, ``config``, the path
            to the YAML configuration file for the Transformer model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="../configs/transformer.big.yaml",
        type=str,
        help="Path of the config file. ")
    return parser.parse_args()
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence.

    Truncates ``seq`` at the first occurrence of ``eos_idx`` (keeping that
    token), then drops BOS/EOS tokens unless ``output_bos``/``output_eos``
    request keeping them.
    """
    # Position of the first EOS token; fall back to the last index when
    # no EOS is present so the whole sequence is kept.
    cut = len(seq) - 1
    for pos, token in enumerate(seq):
        if token == eos_idx:
            cut = pos
            break
    kept = []
    for token in seq[:cut + 1]:
        if token == bos_idx and not output_bos:
            continue
        if token == eos_idx and not output_eos:
            continue
        kept.append(token)
    return kept
def do_train(args):
    """Run static-graph beam-search inference for the Transformer model.

    Builds an inference program containing ``InferTransformerModel``, loads
    parameters from ``args.init_from_params``, decodes every batch from the
    test loader, and writes the top ``args.n_best`` detokenized hypotheses
    per example to ``args.output_file``.

    Note: despite the name, this function performs prediction only.

    Args:
        args: AttrDict of configuration values loaded from the YAML config
            (vocab sizes, model dims, beam settings, paths, ...).
    """
    paddle.enable_static()
    if args.use_gpu:
        place = paddle.set_device("gpu:0")
    else:
        place = paddle.set_device("cpu")

    # Define data loader
    test_loader, to_tokens = reader.create_infer_loader(args)

    test_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(test_program, startup_program):
        src_word = paddle.static.data(
            name="src_word", shape=[None, None], dtype="int64")

        # Define model
        transformer = InferTransformerModel(
            src_vocab_size=args.src_vocab_size,
            trg_vocab_size=args.trg_vocab_size,
            max_length=args.max_length + 1,
            n_layer=args.n_layer,
            n_head=args.n_head,
            d_model=args.d_model,
            d_inner_hid=args.d_inner_hid,
            dropout=args.dropout,
            weight_sharing=args.weight_sharing,
            bos_id=args.bos_idx,
            eos_id=args.eos_idx,
            beam_size=args.beam_size,
            max_out_len=args.max_out_len)

        finished_seq = transformer(src_word=src_word)

    test_program = test_program.clone(for_test=True)

    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    assert (
        args.init_from_params), "must set init_from_params to load parameters"
    paddle.static.load(test_program,
                       os.path.join(args.init_from_params, "transformer"), exe)
    print("finish initing model from params from %s" % (args.init_from_params))

    # Context manager ensures the output file is closed even if a batch
    # fails mid-loop (the original leaked the handle on any exception).
    with open(args.output_file, "w") as f:
        for data in test_loader:
            finished_sequence, = exe.run(test_program,
                                         feed={'src_word': data[0]},
                                         fetch_list=finished_seq.name)
            # [batch, seq_len, beam] -> [batch, beam, seq_len]
            finished_sequence = finished_sequence.transpose([0, 2, 1])
            for ins in finished_sequence:
                for beam_idx, beam in enumerate(ins):
                    # Only emit the n_best highest-scoring beams.
                    if beam_idx >= args.n_best:
                        break
                    id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
                    word_list = to_tokens(id_list)
                    sequence = " ".join(word_list) + "\n"
                    f.write(sequence)

    paddle.disable_static()
if __name__ == "__main__":
ARGS = parse_args()
yaml_file = ARGS.config
with open(yaml_file, 'rt') as f:
args = AttrDict(yaml.safe_load(f))
pprint(args)
do_train(args)
...@@ -12,7 +12,7 @@ from pprint import pprint ...@@ -12,7 +12,7 @@ from pprint import pprint
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion, position_encoding_init from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion
sys.path.append("../") sys.path.append("../")
import reader import reader
......
...@@ -11,7 +11,7 @@ import paddle ...@@ -11,7 +11,7 @@ import paddle
import paddle.distributed as dist import paddle.distributed as dist
import reader import reader
from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion, position_encoding_init from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion
from paddlenlp.utils.log import logger from paddlenlp.utils.log import logger
......
...@@ -167,7 +167,7 @@ class TransformerBeamSearchDecoder(nn.decode.BeamSearchDecoder): ...@@ -167,7 +167,7 @@ class TransformerBeamSearchDecoder(nn.decode.BeamSearchDecoder):
return c return c
def _split_batch_beams_with_var_dim(self, c): def _split_batch_beams_with_var_dim(self, c):
var_dim_size = c.shape[self.var_dim_in_state] var_dim_size = paddle.shape(c)[self.var_dim_in_state]
c = paddle.reshape( c = paddle.reshape(
c, [-1, self.beam_size] + c, [-1, self.beam_size] +
[int(size) [int(size)
...@@ -374,6 +374,7 @@ class InferTransformerModel(TransformerModel): ...@@ -374,6 +374,7 @@ class InferTransformerModel(TransformerModel):
max_step_num=self.max_out_len, max_step_num=self.max_out_len,
memory=enc_output, memory=enc_output,
trg_src_attn_bias=trg_src_attn_bias, trg_src_attn_bias=trg_src_attn_bias,
static_cache=static_cache) static_cache=static_cache,
is_test=True)
return rs return rs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册