未验证 提交 af2ad8d8 编写于 作者: R RedContritio 提交者: GitHub

fix error sample code in static.nn.loss.nce (#53588)

上级 8d340ee1
......@@ -110,7 +110,8 @@ def nce(
param_attr='embed', is_sparse=True)
embs.append(emb)
embs = paddle.concat(x=embs, axis=1)
embs = paddle.concat(x=embs, axis=1) # concat from 4 * [(-1, 1, 32)] to (-1, 4, 32)
embs = paddle.reshape(x=embs, shape=(-1, 4 * 32)) # reshape to (batch_size = -1, dim = 4*32)
loss = paddle.static.nn.nce(input=embs, label=words[label_word],
num_total_classes=dict_size, param_attr='nce.w_0',
bias_attr='nce.b_0')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册