未验证 提交 e6a33c88 编写于 作者: timixy's avatar timixy 提交者: GitHub

Update bert.md (#1054)

上级 2588aa49
...@@ -203,7 +203,7 @@ class MaskLM(nn.Block): ...@@ -203,7 +203,7 @@ class MaskLM(nn.Block):
batch_size = X.shape[0] batch_size = X.shape[0]
batch_idx = np.arange(0, batch_size) batch_idx = np.arange(0, batch_size)
# 假设batch_size=2,num_pred_positions=3 # 假设batch_size=2,num_pred_positions=3
# 那么batch_idx是np.array([0,0,0,1,1]) # 那么batch_idx是np.array([0,0,0,1,1,1])
batch_idx = np.repeat(batch_idx, num_pred_positions) batch_idx = np.repeat(batch_idx, num_pred_positions)
masked_X = X[batch_idx, pred_positions] masked_X = X[batch_idx, pred_positions]
masked_X = masked_X.reshape((batch_size, num_pred_positions, -1)) masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册