Language model inference raises an error
Created by: xiangyubo
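1) PaddlePaddle version: 1.5
Inference details: I want to use the official language model to compute the perplexity (ppl) of a single sentence. Model: https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/language_model. For training I chose the encoder_static model. Training and saving work fine, and I export the model with the save_inference_model interface. At inference time the model loads through load_inference_model without problems, but running the prediction fails with a dimension mismatch error.
The input data layers used when saving are declared as follows: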
x = layers.data(
    name="x",
    shape=[1, config.num_steps, 1],
    dtype='int64',
    append_batch_size=False)
y = layers.data(
    name="y",
    shape=[1 * config.num_steps, 1],
    dtype='int64',
    append_batch_size=False)
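Because append_batch_size=False, no batch dimension is prepended and the shapes above are taken exactly as written. A minimal sketch of what that implies for the feed arrays is below; `declared_feed_shapes` and `mismatches` are hypothetical helpers, not part of PaddlePaddle, and the value 20 for num_steps is only an assumed example.

```python
def declared_feed_shapes(num_steps, batch_size=1):
    """Shapes implied by the two layers.data calls above (hypothetical helper)."""
    return {
        "x": (batch_size, num_steps, 1),
        "y": (batch_size * num_steps, 1),
    }


def mismatches(declared, actual):
    """Return the feed names whose actual shape differs from the declared one."""
    return [name for name in declared if tuple(actual[name]) != declared[name]]


shapes = declared_feed_shapes(num_steps=20)  # 20 is an assumed value
print(shapes)
# A 7-token input would not match a graph frozen with num_steps=20.
print(mismatches(shapes, {"x": (1, 7, 1), "y": (7, 1)}))
```

Comparing the arrays against these shapes before calling exe.run is a cheap way to see which input differs.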
I made a few small changes to lm_model.py so that it returns the predicted words, which makes it possible to compute ppl. The saving procedure is:
projection, last_hidden, last_cell = lm_model(
    config.hidden_size,
    config.vocab_size,
    1,
    num_layers=config.num_layers,
    num_steps=config.num_steps,
    init_scale=config.init_scale,
    dropout=config.dropout,
    rnn_model=config.rnn_model,
    x=x)
loss = layers.softmax_with_cross_entropy(logits=projection, label=y, soft_label=False)
loss = layers.reshape(loss, shape=[-1, config.num_steps], inplace=True)
loss = layers.reduce_mean(loss, dim=[0])
loss = layers.reduce_sum(loss)
ppl = layers.exp(loss)
freeze_program = fluid.default_main_program()
fluid.io.load_persistables(exe, args.save_model_dir, freeze_program)
freeze_program = freeze_program.clone(for_test=True)
print("freeze out: {}, prediction layout: {}".format(args.save_freeze_dir, ppl))
fluid.io.save_inference_model(args.save_freeze_dir, ['x', 'y'], ppl, exe, freeze_program)
print("freeze success")
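To make the arithmetic of the loss-to-ppl steps above easier to check by hand, here is a plain-Python sketch that mirrors the reshape, reduce_mean, reduce_sum, exp sequence on a flat list of per-token losses. The function names are hypothetical. For comparison it also computes the usual per-token perplexity, which divides the summed loss by the number of tokens; whether the saved graph is meant to do that division is an assumption on my part, not something the code above states.

```python
import math


def graph_ppl(token_losses, num_steps):
    """Mirror reshape([-1, num_steps]) -> reduce_mean(dim=0) -> reduce_sum -> exp."""
    rows = [token_losses[i:i + num_steps]
            for i in range(0, len(token_losses), num_steps)]
    # Mean over the batch dimension, one value per time step.
    mean_per_step = [sum(col) / len(rows) for col in zip(*rows)]
    return math.exp(sum(mean_per_step))


def per_token_ppl(token_losses):
    """Conventional perplexity: exp of the mean per-token cross entropy."""
    return math.exp(sum(token_losses) / len(token_losses))


losses = [math.log(2.0)] * 4  # toy values: every token has loss ln 2
print(graph_ppl(losses, num_steps=4))
print(per_token_ppl(losses))
```

With these toy losses the first value comes out at roughly 16 and the second at roughly 2, so the two definitions differ by the division by num_steps.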
I followed the approach at line 213 of train.py to read the data and reshape it, but inference then raises an error. The inference code is:
import numpy as np


def infer_reader(fields):
    # Map each word to its id; unknown words fall back to the last vocab id.
    unk_id = len(vocab) - 1
    wids = [vocab[x] if x in vocab else unk_id for x in fields.split(" ")]
    # wids = np.array(wids)
    # Inputs are all tokens but the last; targets are shifted by one.
    return wids[:-1], wids[1:]


def infer(fields, feeder):
    x, y = infer_reader(fields)
    x = np.array(x)
    x = x.reshape((-1, len(x), 1))
    y = np.array(y)
    y = y.reshape((-1, 1))
    print(x.shape)
    print(y.shape)
    # Although data here is a list, it corresponds to a single sample, so it needs to be wrapped in a tuple.
    # A one-element tuple needs a comma after the first element, like this: (a,)
    ppl = exe.run(
        inference_program,
        feed={feed_target_names[0]: x, feed_target_names[1]: y},
        fetch_list=fetch_targets,
        return_numpy=False)
    ppl = np.array(ppl[0])
    return ppl
The error message says that the concat dimensions do not match. Any help would be appreciated.
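One thing that may be worth checking is whether the sentence length lines up with the num_steps the model was frozen with. The sketch below reuses the reader logic on a made-up vocabulary; `fits_fixed_graph` and `toy_vocab` are hypothetical names, and whether a length mismatch is the actual cause of the concat error is only a guess.

```python
def infer_reader(sentence, vocab):
    """Map words to ids and build the shifted (input, target) pair."""
    unk_id = len(vocab) - 1  # same unknown-word convention as the code above
    wids = [vocab.get(word, unk_id) for word in sentence.split(" ")]
    return wids[:-1], wids[1:]


def fits_fixed_graph(sentence, vocab, num_steps):
    """True only if the sentence yields exactly num_steps inputs and targets."""
    x, y = infer_reader(sentence, vocab)
    return len(x) == num_steps and len(y) == num_steps


toy_vocab = {"the": 0, "cat": 1, "sat": 2, "<unk>": 3}  # made-up vocabulary
x_ids, y_ids = infer_reader("the cat sat down", toy_vocab)
print(x_ids, y_ids)
print(fits_fixed_graph("the cat sat down", toy_vocab, num_steps=3))
print(fits_fixed_graph("the cat sat down", toy_vocab, num_steps=20))
```

A sentence of n words produces n - 1 input ids, so the reshaped x has shape (1, n - 1, 1), which matches the declared [1, num_steps, 1] only when n - 1 equals num_steps.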