Commit 3532a980 authored by J jrzaurin

Fixed #60 and #61. Placing dropout in the right position and applying FullEmbeddingDropout only in training
Parent 4de81ea2
@@ -11,7 +11,7 @@ __pycache__*
 Untitled*.ipynb
 
 # data related dirs
-# data/
+tmp_data/
 model_weights/
 tmp_dir/
 weights/
@@ -18,10 +18,13 @@ class FullEmbeddingDropout(nn.Module):
         self.dropout = dropout
 
     def forward(self, X: Tensor) -> Tensor:
-        mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as(
-            X
-        ) / (1 - self.dropout)
-        return mask * X
+        if self.training:
+            mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as(
+                X
+            ) / (1 - self.dropout)
+            return mask * X
+        else:
+            return X
 
 
 DropoutLayers = Union[nn.Dropout, FullEmbeddingDropout]
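What the hunk above changes: `FullEmbeddingDropout` now mirrors the semantics of `nn.Dropout`, masking only while the module is in training mode and acting as the identity under `model.eval()`. Below is a minimal, self-contained sketch of the patched module as shown in the hunk, plus a hypothetical usage demo; the tensor shape and dropout value are illustrative assumptions, not taken from the repo.

```python
import torch
from torch import Tensor, nn


class FullEmbeddingDropout(nn.Module):
    # Drops whole embedding columns at once (not individual activations),
    # rescaling survivors by 1 / (1 - dropout) to preserve the expected value.
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = dropout

    def forward(self, X: Tensor) -> Tensor:
        if self.training:
            mask = X.new().resize_((X.size(1), 1)).bernoulli_(
                1 - self.dropout
            ).expand_as(X) / (1 - self.dropout)
            return mask * X
        else:
            return X  # the fix: identity at inference time


drop = FullEmbeddingDropout(dropout=0.5)
X = torch.ones(2, 4, 8)  # assumed shape: (batch, n_embeddings, embed_dim)

drop.train()
print(torch.equal(drop(X), X))  # False: every entry is zeroed or scaled by 2

drop.eval()
print(torch.equal(drop(X), X))  # True: eval mode now returns the input untouched
```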
@@ -167,12 +170,11 @@ class CategoricalEmbeddings(nn.Module):
             x = torch.cat(cat_embed, 1)
         else:
             x = self.embed(X[:, self.cat_idx].long())
-        x = self.dropout(x)
 
         if self.bias is not None:
             x = x + self.bias.unsqueeze(0)
 
-        return x
+        return self.dropout(x)
 
 
 class CatAndContEmbeddings(nn.Module):
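Why the second change matters: before the fix, dropout ran before the optional bias was added, so positions zeroed by dropout were immediately shifted back to the bias value and the bias contribution was never regularised; returning `self.dropout(x)` after the bias addition drops the full embedding-plus-bias contribution. A rough sketch of the difference, using a plain `nn.Dropout` as a stand-in and an assumed per-column bias shaped like the one in the layer above:

```python
import torch
from torch import nn

torch.manual_seed(0)

x = torch.randn(2, 3, 4)   # assumed: (batch, n_cat_cols, embed_dim)
bias = torch.randn(3, 4)   # hypothetical learnable bias, one row per column
drop = nn.Dropout(p=0.5)
drop.train()

# Old ordering (pre-fix): dropped positions still leak the bias term.
torch.manual_seed(1)
old = drop(x) + bias.unsqueeze(0)

# New ordering (post-fix): dropout zeroes embedding + bias together.
torch.manual_seed(1)
new = drop(x + bias.unsqueeze(0))

print((old == 0).any().item())  # almost surely False: zeros got shifted by the bias
print((new == 0).any().item())  # almost surely True: dropped positions are exactly 0
```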