API ‘rnn’ raise NotImplementedError
Created by: Akeepers
我参考paddle官网seq2seq例子 实现了一个seq2seq model
其中的encoder被改为了ernie ,decoder的代码保持不动,仅修改:
- layers.rnn中的initial_states为None:原始代码中的decoder 中的initial_states 为encoder最后一个cell的states,由于采用ernie作为encoder,所以将decoder中的 initial_states设为None
从API文档来看,rnn op的initial_states可以为None,但会报错
代码如下:
def decoder(encoder_output,
encoder_output_proj,
encoder_padding_mask,
trg=None,
is_train=True):
"""Decoder: GRU with Attention"""
decoder_cell = DecoderCell(hidden_size=decoder_size)
trg_embeder = lambda x: fluid.embedding(input=x,
size=[target_dict_size, hidden_dim],
dtype="float32",
param_attr=fluid.ParamAttr(
name="trg_emb_table"))
output_layer = lambda x: layers.fc(x,
size=target_dict_size,
num_flatten_dims=len(x.shape) - 1,
param_attr=fluid.ParamAttr(name=
"output_w"))
if is_train:
decoder_output, _ = layers.rnn(
cell=decoder_cell,
inputs=trg_embeder(trg),
initial_states=None,
time_major=False,
encoder_output=encoder_output,
encoder_output_proj=encoder_output_proj,
encoder_padding_mask=encoder_padding_mask)
decoder_output = output_layer(decoder_output)
else:
encoder_output = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_output, beam_size)
encoder_output_proj = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_output_proj, beam_size)
encoder_padding_mask = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
encoder_padding_mask, beam_size)
beam_search_decoder = layers.BeamSearchDecoder(
cell=decoder_cell,
start_token=bos_id,
end_token=eos_id,
beam_size=beam_size,
embedding_fn=trg_embeder,
output_fn=output_layer)
decoder_output, _ = layers.dynamic_decode(
decoder=beam_search_decoder,
inits=None,
max_step_num=max_length,
output_time_major=False,
encoder_output=encoder_output,
encoder_output_proj=encoder_output_proj,
encoder_padding_mask=encoder_padding_mask)
return decoder_output
Error:
Traceback (most recent call last):
File "run_seq2seq.py", line 136, in <module>
main(args)
File "run_seq2seq.py", line 128, in main
train(args)
File "run_seq2seq.py", line 61, in train
logits = model_func(args, inputs, ernie_config, is_train=True)
File "/home/yangpan/projects/paper_recurrence/PLMEE/ERNIE/ernie/finetune/seq2seq.py", line 208, in model_func
is_train=is_train)
File "/home/yangpan/projects/paper_recurrence/PLMEE/ERNIE/ernie/finetune/seq2seq.py", line 156, in decoder
encoder_padding_mask=encoder_padding_mask)
File "/home/yangpan/anaconda3/envs/paddle-py2.7/lib/python2.7/site-packages/paddle/fluid/layers/rnn.py", line 428, in rnn
initial_states = cell.get_initial_states(batch_ref=inputs)
File "/home/yangpan/anaconda3/envs/paddle-py2.7/lib/python2.7/site-packages/paddle/fluid/layers/rnn.py", line 116, in get_initial_states
states_shapes = self.state_shape if shape is None else shape
File "/home/yangpan/anaconda3/envs/paddle-py2.7/lib/python2.7/site-packages/paddle/fluid/layers/rnn.py", line 150, in state_shape
raise NotImplementedError
NotImplementedError