提交 3532a980 编写于 作者: J jrzaurin

Fixed #60 and #61. Placing dropout in the right position and applying...

Fixed #60 and #61. Placing dropout in the right position and applying FullEmbeddingDropout only in training
上级 4de81ea2
......@@ -11,7 +11,7 @@ __pycache__*
Untitled*.ipynb
# data related dirs
# data/
tmp_data/
model_weights/
tmp_dir/
weights/
......
......@@ -18,10 +18,13 @@ class FullEmbeddingDropout(nn.Module):
self.dropout = dropout
def forward(self, X: Tensor) -> Tensor:
mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as(
X
) / (1 - self.dropout)
return mask * X
if self.training:
mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as(
X
) / (1 - self.dropout)
return mask * X
else:
return X
DropoutLayers = Union[nn.Dropout, FullEmbeddingDropout]
......@@ -167,12 +170,11 @@ class CategoricalEmbeddings(nn.Module):
x = torch.cat(cat_embed, 1)
else:
x = self.embed(X[:, self.cat_idx].long())
x = self.dropout(x)
if self.bias is not None:
x = x + self.bias.unsqueeze(0)
return x
return self.dropout(x)
class CatAndContEmbeddings(nn.Module):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册