🐛 positional encoding buffer

上级 b2b305ff
......@@ -28,7 +28,7 @@ class PositionalEncoding(Module):
super().__init__()
self.dropout = nn.Dropout(dropout_prob)
self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
def __call__(self, x: torch.Tensor):
pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册