提交 cf5bb9f1 编写于 作者: W wuzewu

Fix ernie_gen* bug.

上级 a7c06ac9
...@@ -43,7 +43,7 @@ from .model import StackModel ...@@ -43,7 +43,7 @@ from .model import StackModel
type="nlp/text_generation", type="nlp/text_generation",
) )
class ErnieGen(hub.Module): class ErnieGen(hub.Module):
def _initialize(self): def __init__(self):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -59,25 +59,25 @@ class ErnieGen(hub.Module): ...@@ -59,25 +59,25 @@ class ErnieGen(hub.Module):
return self._model return self._model
def finetune( def finetune(
self, self,
train_path, train_path,
dev_path=None, dev_path=None,
save_dir="ernie_gen_result", save_dir="ernie_gen_result",
init_ckpt_path=None, init_ckpt_path=None,
use_gpu=True, use_gpu=True,
max_steps=500, max_steps=500,
batch_size=8, batch_size=8,
max_encode_len=50, max_encode_len=50,
max_decode_len=50, max_decode_len=50,
learning_rate=5e-5, learning_rate=5e-5,
warmup_proportion=0.1, warmup_proportion=0.1,
weight_decay=0.1, weight_decay=0.1,
noise_prob=0, noise_prob=0,
label_smooth=0, label_smooth=0,
beam_width=5, beam_width=5,
length_penalty=1.0, length_penalty=1.0,
log_interval=100, log_interval=100,
save_interval=200, save_interval=200,
): ):
""" """
finetune with the specified dataset. finetune with the specified dataset.
...@@ -109,6 +109,7 @@ class ErnieGen(hub.Module): ...@@ -109,6 +109,7 @@ class ErnieGen(hub.Module):
last_ppl(float): last model ppl. last_ppl(float): last model ppl.
} }
""" """
paddle.disable_static()
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu') paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
if init_ckpt_path is not None: if init_ckpt_path is not None:
...@@ -118,12 +119,13 @@ class ErnieGen(hub.Module): ...@@ -118,12 +119,13 @@ class ErnieGen(hub.Module):
train_dataset = self._load_dataset(train_path) train_dataset = self._load_dataset(train_path)
attn_id = self.tokenizer.vocab['[MASK]'] attn_id = self.tokenizer.vocab['[MASK]']
trans_func = convert_example(tokenizer=self.tokenizer, trans_func = convert_example(
attn_id=attn_id, tokenizer=self.tokenizer,
tgt_type_id=1, attn_id=attn_id,
max_encode_len=max_encode_len, tgt_type_id=1,
max_decode_len=max_decode_len, max_encode_len=max_encode_len,
noise_prob=noise_prob) max_decode_len=max_decode_len,
noise_prob=noise_prob)
train_dataset = train_dataset.map(trans_func) train_dataset = train_dataset.map(trans_func)
train_batch_sampler = paddle.io.BatchSampler(train_dataset, batch_size=batch_size, shuffle=True) train_batch_sampler = paddle.io.BatchSampler(train_dataset, batch_size=batch_size, shuffle=True)
...@@ -137,20 +139,18 @@ class ErnieGen(hub.Module): ...@@ -137,20 +139,18 @@ class ErnieGen(hub.Module):
Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # attn_ids Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # attn_ids
Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # tgt_labels Pad(axis=0, pad_val=self.tokenizer.pad_token_id), # tgt_labels
): after_padding(fn(samples)) ): after_padding(fn(samples))
train_data_loader = DataLoader(dataset=train_dataset, train_data_loader = DataLoader(
batch_sampler=train_batch_sampler, dataset=train_dataset,
collate_fn=batchify_fn, batch_sampler=train_batch_sampler,
num_workers=0, collate_fn=batchify_fn,
return_list=True) num_workers=0,
return_list=True)
if dev_path: if dev_path:
dev_dataset = self._load_dataset(dev_path) dev_dataset = self._load_dataset(dev_path)
dev_dataset = dev_dataset.map(trans_func) dev_dataset = dev_dataset.map(trans_func)
dev_data_loader = DataLoader(dataset=dev_dataset, dev_data_loader = DataLoader(
batch_size=batch_size, dataset=dev_dataset, batch_size=batch_size, collate_fn=batchify_fn, num_workers=0, return_list=True)
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
label_num = self.model.word_emb.weight.shape[0] label_num = self.model.word_emb.weight.shape[0]
train_model = StackModel(self.model) train_model = StackModel(self.model)
...@@ -158,11 +158,12 @@ class ErnieGen(hub.Module): ...@@ -158,11 +158,12 @@ class ErnieGen(hub.Module):
# Generate parameter names needed to perform weight decay. # Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded. # All bias and LayerNorm parameters are excluded.
decay_params = [p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])] decay_params = [p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler, optimizer = paddle.optimizer.AdamW(
parameters=self.model.parameters(), learning_rate=lr_scheduler,
weight_decay=weight_decay, parameters=self.model.parameters(),
grad_clip=nn.ClipGradByGlobalNorm(1.0), weight_decay=weight_decay,
apply_decay_param_fun=lambda x: x in decay_params) grad_clip=nn.ClipGradByGlobalNorm(1.0),
apply_decay_param_fun=lambda x: x in decay_params)
rouge1 = Rouge1() rouge1 = Rouge1()
rouge2 = Rouge2() rouge2 = Rouge2()
...@@ -174,8 +175,8 @@ class ErnieGen(hub.Module): ...@@ -174,8 +175,8 @@ class ErnieGen(hub.Module):
(src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt, (src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
mask_attn_2_srctgtattn, tgt_labels, _) = batch mask_attn_2_srctgtattn, tgt_labels, _) = batch
if label_smooth > 0.: if label_smooth > 0.:
tgt_labels = nn.functional.label_smooth(nn.functional.one_hot(tgt_labels, label_num), tgt_labels = nn.functional.label_smooth(
epsilon=label_smooth) nn.functional.one_hot(tgt_labels, label_num), epsilon=label_smooth)
tgt_pos = paddle.nonzero(attn_ids == attn_id) tgt_pos = paddle.nonzero(attn_ids == attn_id)
loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src, loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src,
...@@ -189,8 +190,8 @@ class ErnieGen(hub.Module): ...@@ -189,8 +190,8 @@ class ErnieGen(hub.Module):
if global_step % log_interval == 0 and paddle.distributed.get_rank() == 0: if global_step % log_interval == 0 and paddle.distributed.get_rank() == 0:
loss_np = loss.numpy() loss_np = loss.numpy()
ppl = np.exp(loss_np) ppl = np.exp(loss_np)
logger.info('[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' % logger.info('[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' % (global_step, max_steps, loss_np,
(global_step, max_steps, loss_np, ppl, lr_scheduler.get_lr())) ppl, lr_scheduler.get_lr()))
if save_dir and global_step % save_interval == 0 and global_step > 0: if save_dir and global_step % save_interval == 0 and global_step > 0:
loss_np = loss.numpy() loss_np = loss.numpy()
ppl = np.exp(loss_np) ppl = np.exp(loss_np)
...@@ -213,8 +214,8 @@ class ErnieGen(hub.Module): ...@@ -213,8 +214,8 @@ class ErnieGen(hub.Module):
if global_step % save_interval != 0: if global_step % save_interval != 0:
loss_np = loss.numpy() loss_np = loss.numpy()
ppl = np.exp(loss_np) ppl = np.exp(loss_np)
logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' % logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' % (global_step, loss_np, ppl,
(global_step, loss_np, ppl, lr_scheduler.get_lr())) lr_scheduler.get_lr()))
if save_dir: if save_dir:
save_name = "step_%s_ppl_%.5f.pdparams" % (global_step, ppl) save_name = "step_%s_ppl_%.5f.pdparams" % (global_step, ppl)
save_path = os.path.join(save_dir, save_name) save_path = os.path.join(save_dir, save_name)
...@@ -304,20 +305,21 @@ class ErnieGen(hub.Module): ...@@ -304,20 +305,21 @@ class ErnieGen(hub.Module):
for data in data_loader: for data in data_loader:
(src_ids, src_tids, src_pids, _, _, _, _, _, _, _, _, raw_tgt_labels) = data # never use target when infer (src_ids, src_tids, src_pids, _, _, _, _, _, _, _, _, raw_tgt_labels) = data # never use target when infer
# Use greedy_search_infilling or beam_search_infilling to get predictions # Use greedy_search_infilling or beam_search_infilling to get predictions
output_ids = beam_search_infilling(model, output_ids = beam_search_infilling(
src_ids, model,
src_tids, src_ids,
eos_id=eos_id, src_tids,
sos_id=sos_id, eos_id=eos_id,
attn_id=attn_id, sos_id=sos_id,
pad_id=pad_id, attn_id=attn_id,
unk_id=unk_id, pad_id=pad_id,
vocab_size=vocab_size, unk_id=unk_id,
max_decode_len=max_decode_len, vocab_size=vocab_size,
max_encode_len=max_encode_len, max_decode_len=max_decode_len,
beam_width=beam_width, max_encode_len=max_encode_len,
length_penalty=length_penalty, beam_width=beam_width,
tgt_type_id=1) length_penalty=length_penalty,
tgt_type_id=1)
for ids in output_ids.tolist(): for ids in output_ids.tolist():
if eos_id in ids: if eos_id in ids:
...@@ -359,10 +361,11 @@ class ErnieGen(hub.Module): ...@@ -359,10 +361,11 @@ class ErnieGen(hub.Module):
if __name__ == "__main__": if __name__ == "__main__":
module = ErnieGen() module = ErnieGen()
result = module.finetune(train_path='test_data/train.txt', result = module.finetune(
dev_path='test_data/dev.txt', train_path='test_data/train.txt',
max_steps=30, dev_path='test_data/dev.txt',
batch_size=2, max_steps=30,
log_interval=10, batch_size=2,
save_interval=20) log_interval=10,
save_interval=20)
module.export(params_path=result['last_save_path'], module_name="ernie_gen_test", author="test") module.export(params_path=result['last_save_path'], module_name="ernie_gen_test", author="test")
...@@ -39,7 +39,7 @@ from ernie_gen_acrostic_poetry.decode import beam_search_infilling ...@@ -39,7 +39,7 @@ from ernie_gen_acrostic_poetry.decode import beam_search_infilling
type="nlp/text_generation", type="nlp/text_generation",
) )
class ErnieGen(hub.NLPPredictionModule): class ErnieGen(hub.NLPPredictionModule):
def _initialize(self, line=4, word=7): def __init__(self, line=4, word=7):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -73,14 +73,16 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -73,14 +73,16 @@ class ErnieGen(hub.NLPPredictionModule):
Returns: Returns:
results(list): the poetry continuations. results(list): the poetry continuations.
""" """
paddle.disable_static()
if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]): if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]):
predicted_data = texts predicted_data = texts
else: else:
raise ValueError("The input texts should be a list with nonempty string elements.") raise ValueError("The input texts should be a list with nonempty string elements.")
for i, text in enumerate(texts): for i, text in enumerate(texts):
if len(text) > self.line: if len(text) > self.line:
logger.warning('The input text: %s, contains more than %i characters, which will be cut off' % logger.warning(
(text, self.line)) 'The input text: %s, contains more than %i characters, which will be cut off' % (text, self.line))
texts[i] = text[:self.line] texts[i] = text[:self.line]
for char in text: for char in text:
...@@ -104,19 +106,20 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -104,19 +106,20 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text = self.tokenizer.encode(text) encode_text = self.tokenizer.encode(text)
src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0) src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0)
src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0) src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0)
output_ids = beam_search_infilling(self.model, output_ids = beam_search_infilling(
src_ids, self.model,
src_sids, src_ids,
eos_id=self.tokenizer.vocab['[SEP]'], src_sids,
sos_id=self.tokenizer.vocab['[CLS]'], eos_id=self.tokenizer.vocab['[SEP]'],
attn_id=self.tokenizer.vocab['[MASK]'], sos_id=self.tokenizer.vocab['[CLS]'],
pad_id=self.tokenizer.vocab['[PAD]'], attn_id=self.tokenizer.vocab['[MASK]'],
unk_id=self.tokenizer.vocab['[UNK]'], pad_id=self.tokenizer.vocab['[PAD]'],
vocab_size=len(self.tokenizer.vocab), unk_id=self.tokenizer.vocab['[UNK]'],
max_decode_len=80, vocab_size=len(self.tokenizer.vocab),
max_encode_len=20, max_decode_len=80,
beam_width=beam_width, max_encode_len=20,
tgt_type_id=1) beam_width=beam_width,
tgt_type_id=1)
output_str = self.rev_lookup(output_ids[0]) output_str = self.rev_lookup(output_ids[0])
for ostr in output_str.tolist(): for ostr in output_str.tolist():
...@@ -130,10 +133,8 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -130,10 +133,8 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Add the command config options Add the command config options
""" """
self.arg_config_group.add_argument('--use_gpu', self.arg_config_group.add_argument(
type=ast.literal_eval, '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
default=False,
help="whether use GPU for prediction")
self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width") self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
...@@ -142,10 +143,11 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -142,10 +143,11 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Run as a command Run as a command
""" """
self.parser = argparse.ArgumentParser(description='Run the %s module.' % self.name, self.parser = argparse.ArgumentParser(
prog='hub run %s' % self.name, description='Run the %s module.' % self.name,
usage='%(prog)s', prog='hub run %s' % self.name,
add_help=True) usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
......
...@@ -39,7 +39,7 @@ from ernie_gen_couplet.decode import beam_search_infilling ...@@ -39,7 +39,7 @@ from ernie_gen_couplet.decode import beam_search_infilling
type="nlp/text_generation", type="nlp/text_generation",
) )
class ErnieGen(hub.NLPPredictionModule): class ErnieGen(hub.NLPPredictionModule):
def _initialize(self): def __init__(self):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns: Returns:
results(list): the right rolls. results(list): the right rolls.
""" """
paddle.disable_static()
if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]): if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]):
predicted_data = texts predicted_data = texts
else: else:
...@@ -93,19 +95,20 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -93,19 +95,20 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text = self.tokenizer.encode(text) encode_text = self.tokenizer.encode(text)
src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0) src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0)
src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0) src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0)
output_ids = beam_search_infilling(self.model, output_ids = beam_search_infilling(
src_ids, self.model,
src_sids, src_ids,
eos_id=self.tokenizer.vocab['[SEP]'], src_sids,
sos_id=self.tokenizer.vocab['[CLS]'], eos_id=self.tokenizer.vocab['[SEP]'],
attn_id=self.tokenizer.vocab['[MASK]'], sos_id=self.tokenizer.vocab['[CLS]'],
pad_id=self.tokenizer.vocab['[PAD]'], attn_id=self.tokenizer.vocab['[MASK]'],
unk_id=self.tokenizer.vocab['[UNK]'], pad_id=self.tokenizer.vocab['[PAD]'],
vocab_size=len(self.tokenizer.vocab), unk_id=self.tokenizer.vocab['[UNK]'],
max_decode_len=20, vocab_size=len(self.tokenizer.vocab),
max_encode_len=20, max_decode_len=20,
beam_width=beam_width, max_encode_len=20,
tgt_type_id=1) beam_width=beam_width,
tgt_type_id=1)
output_str = self.rev_lookup(output_ids[0]) output_str = self.rev_lookup(output_ids[0])
for ostr in output_str.tolist(): for ostr in output_str.tolist():
...@@ -119,10 +122,8 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -119,10 +122,8 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Add the command config options Add the command config options
""" """
self.arg_config_group.add_argument('--use_gpu', self.arg_config_group.add_argument(
type=ast.literal_eval, '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
default=False,
help="whether use GPU for prediction")
self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width") self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
...@@ -131,10 +132,11 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -131,10 +132,11 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Run as a command Run as a command
""" """
self.parser = argparse.ArgumentParser(description='Run the %s module.' % self.name, self.parser = argparse.ArgumentParser(
prog='hub run %s' % self.name, description='Run the %s module.' % self.name,
usage='%(prog)s', prog='hub run %s' % self.name,
add_help=True) usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
......
...@@ -39,7 +39,7 @@ from ernie_gen_lover_words.decode import beam_search_infilling ...@@ -39,7 +39,7 @@ from ernie_gen_lover_words.decode import beam_search_infilling
type="nlp/text_generation", type="nlp/text_generation",
) )
class ErnieGen(hub.NLPPredictionModule): class ErnieGen(hub.NLPPredictionModule):
def _initialize(self): def __init__(self):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns: Returns:
results(list): the poetry continuations. results(list): the poetry continuations.
""" """
paddle.disable_static()
if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]): if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]):
predicted_data = texts predicted_data = texts
else: else:
...@@ -85,19 +87,20 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -85,19 +87,20 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text = self.tokenizer.encode(text) encode_text = self.tokenizer.encode(text)
src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0) src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0)
src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0) src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0)
output_ids = beam_search_infilling(self.model, output_ids = beam_search_infilling(
src_ids, self.model,
src_sids, src_ids,
eos_id=self.tokenizer.vocab['[SEP]'], src_sids,
sos_id=self.tokenizer.vocab['[CLS]'], eos_id=self.tokenizer.vocab['[SEP]'],
attn_id=self.tokenizer.vocab['[MASK]'], sos_id=self.tokenizer.vocab['[CLS]'],
pad_id=self.tokenizer.vocab['[PAD]'], attn_id=self.tokenizer.vocab['[MASK]'],
unk_id=self.tokenizer.vocab['[UNK]'], pad_id=self.tokenizer.vocab['[PAD]'],
vocab_size=len(self.tokenizer.vocab), unk_id=self.tokenizer.vocab['[UNK]'],
max_decode_len=80, vocab_size=len(self.tokenizer.vocab),
max_encode_len=20, max_decode_len=80,
beam_width=beam_width, max_encode_len=20,
tgt_type_id=1) beam_width=beam_width,
tgt_type_id=1)
output_str = self.rev_lookup(output_ids[0]) output_str = self.rev_lookup(output_ids[0])
for ostr in output_str.tolist(): for ostr in output_str.tolist():
...@@ -111,10 +114,8 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -111,10 +114,8 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Add the command config options Add the command config options
""" """
self.arg_config_group.add_argument('--use_gpu', self.arg_config_group.add_argument(
type=ast.literal_eval, '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
default=False,
help="whether use GPU for prediction")
self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width") self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
...@@ -123,10 +124,11 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -123,10 +124,11 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Run as a command Run as a command
""" """
self.parser = argparse.ArgumentParser(description='Run the %s module.' % self.name, self.parser = argparse.ArgumentParser(
prog='hub run %s' % self.name, description='Run the %s module.' % self.name,
usage='%(prog)s', prog='hub run %s' % self.name,
add_help=True) usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
......
...@@ -39,7 +39,7 @@ from ernie_gen_poetry.decode import beam_search_infilling ...@@ -39,7 +39,7 @@ from ernie_gen_poetry.decode import beam_search_infilling
type="nlp/text_generation", type="nlp/text_generation",
) )
class ErnieGen(hub.NLPPredictionModule): class ErnieGen(hub.NLPPredictionModule):
def _initialize(self): def __init__(self):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -67,6 +67,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns: Returns:
results(list): the poetry continuations. results(list): the poetry continuations.
""" """
paddle.disable_static()
if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]): if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]):
predicted_data = texts predicted_data = texts
else: else:
...@@ -102,19 +104,20 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -102,19 +104,20 @@ class ErnieGen(hub.NLPPredictionModule):
encode_text = self.tokenizer.encode(text) encode_text = self.tokenizer.encode(text)
src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0) src_ids = paddle.to_tensor(encode_text['input_ids']).unsqueeze(0)
src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0) src_sids = paddle.to_tensor(encode_text['token_type_ids']).unsqueeze(0)
output_ids = beam_search_infilling(self.model, output_ids = beam_search_infilling(
src_ids, self.model,
src_sids, src_ids,
eos_id=self.tokenizer.vocab['[SEP]'], src_sids,
sos_id=self.tokenizer.vocab['[CLS]'], eos_id=self.tokenizer.vocab['[SEP]'],
attn_id=self.tokenizer.vocab['[MASK]'], sos_id=self.tokenizer.vocab['[CLS]'],
pad_id=self.tokenizer.vocab['[PAD]'], attn_id=self.tokenizer.vocab['[MASK]'],
unk_id=self.tokenizer.vocab['[UNK]'], pad_id=self.tokenizer.vocab['[PAD]'],
vocab_size=len(self.tokenizer.vocab), unk_id=self.tokenizer.vocab['[UNK]'],
max_decode_len=80, vocab_size=len(self.tokenizer.vocab),
max_encode_len=20, max_decode_len=80,
beam_width=beam_width, max_encode_len=20,
tgt_type_id=1) beam_width=beam_width,
tgt_type_id=1)
output_str = self.rev_lookup(output_ids[0]) output_str = self.rev_lookup(output_ids[0])
for ostr in output_str.tolist(): for ostr in output_str.tolist():
...@@ -128,10 +131,8 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -128,10 +131,8 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Add the command config options Add the command config options
""" """
self.arg_config_group.add_argument('--use_gpu', self.arg_config_group.add_argument(
type=ast.literal_eval, '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
default=False,
help="whether use GPU for prediction")
self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width") self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
...@@ -140,10 +141,11 @@ class ErnieGen(hub.NLPPredictionModule): ...@@ -140,10 +141,11 @@ class ErnieGen(hub.NLPPredictionModule):
""" """
Run as a command Run as a command
""" """
self.parser = argparse.ArgumentParser(description='Run the %s module.' % self.name, self.parser = argparse.ArgumentParser(
prog='hub run %s' % self.name, description='Run the %s module.' % self.name,
usage='%(prog)s', prog='hub run %s' % self.name,
add_help=True) usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册