From b0c55c28ea7c5bac9e286f096a7f01966d798685 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Wed, 10 May 2023 10:18:24 +0800 Subject: [PATCH] fix error sample code in static.nn.loss.nce (#53588) (#53630) --- python/paddle/static/nn/loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/static/nn/loss.py b/python/paddle/static/nn/loss.py index e0a41c6ec65..7d7912a06fb 100644 --- a/python/paddle/static/nn/loss.py +++ b/python/paddle/static/nn/loss.py @@ -110,7 +110,8 @@ def nce( param_attr='embed', is_sparse=True) embs.append(emb) - embs = paddle.concat(x=embs, axis=1) + embs = paddle.concat(x=embs, axis=1) # concat from 4 * [(-1, 1, 32)] to (-1, 4, 32) + embs = paddle.reshape(x=embs, shape=(-1, 4 * 32)) # reshape to (batch_size = -1, dim = 4*32) loss = paddle.static.nn.nce(input=embs, label=words[label_word], num_total_classes=dict_size, param_attr='nce.w_0', bias_attr='nce.b_0') -- GitLab