Commit 89ffcec1 authored by Aston Zhang

add padding explanation

Parent f0379049
@@ -316,6 +316,7 @@ def train_embedding(num_epochs):
# pred shape: (batch_size, 1, max_len).
pred = nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))
# mask and label shape: (batch_size, max_len).
# Use the mask to keep padding from affecting the loss computation.
l = (loss(pred.reshape(label.shape), label, mask) *
mask.shape[1] / mask.sum(axis=1))
l.backward()
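The added comment points at why the mask matters: context windows in a batch are padded to the same length, and without masking the padded positions would contribute to the loss. Below is a minimal, self-contained sketch (assuming MXNet's `nd` module and Gluon's `SigmoidBinaryCrossEntropyLoss`, consistent with the surrounding code; the toy embeddings, labels, and mask values are hypothetical) showing how passing the mask as `sample_weight`, together with the `mask.shape[1] / mask.sum(axis=1)` rescaling, keeps padding out of the per-example loss:

```python
from mxnet import nd
from mxnet.gluon import loss as gloss

loss = gloss.SigmoidBinaryCrossEntropyLoss()

batch_size, max_len, embed_size = 2, 4, 3

# Hypothetical embeddings standing in for emb_in and emb_out in the diff:
# emb_in holds one center word per example; emb_out holds max_len
# context/noise words, padded to equal length within the batch.
emb_in = nd.random.normal(shape=(batch_size, 1, embed_size))
emb_out = nd.random.normal(shape=(batch_size, max_len, embed_size))

# batch_dot of (batch_size, 1, embed_size) and (batch_size, embed_size,
# max_len) gives pred of shape (batch_size, 1, max_len).
pred = nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))

label = nd.array([[1, 0, 0, 1], [1, 1, 0, 0]])  # 1 = context, 0 = noise
mask = nd.array([[1, 1, 1, 1], [1, 1, 0, 0]])   # 0 marks padded positions

# sample_weight=mask zeroes the loss at padded positions; rescaling by
# max_len / (number of real positions) then averages over real tokens
# only, so padding does not affect the loss.
l = (loss(pred.reshape(label.shape), label, mask) *
     mask.shape[1] / mask.sum(axis=1))
print(l)  # shape: (batch_size,)
```

Dividing by `mask.sum(axis=1)` rather than `max_len` means an example with more padding is averaged over fewer real positions, so heavily padded examples are not systematically assigned smaller losses.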