diff --git a/.gitignore b/.gitignore
index 4bd79b689655353a81da39b6c63731260e2903b3..c4533b1b20b4a47c88152d6cac93cd4f06782579 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ __pycache__*
 Untitled*.ipynb
 
 # data related dirs
-# data/
+tmp_data/
 model_weights/
 tmp_dir/
 weights/
diff --git a/pytorch_widedeep/models/transformers/_embeddings_layers.py b/pytorch_widedeep/models/transformers/_embeddings_layers.py
index 2ce7da83e0323bb095cd453be9b15950925097b9..51c54890feea4d263daf4eeb3861d07bfa2c5288 100644
--- a/pytorch_widedeep/models/transformers/_embeddings_layers.py
+++ b/pytorch_widedeep/models/transformers/_embeddings_layers.py
@@ -18,10 +18,13 @@ class FullEmbeddingDropout(nn.Module):
         self.dropout = dropout
 
     def forward(self, X: Tensor) -> Tensor:
-        mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as(
-            X
-        ) / (1 - self.dropout)
-        return mask * X
+        if self.training:
+            mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as(
+                X
+            ) / (1 - self.dropout)
+            return mask * X
+        else:
+            return X
 
 
 DropoutLayers = Union[nn.Dropout, FullEmbeddingDropout]
@@ -167,12 +170,11 @@ class CategoricalEmbeddings(nn.Module):
             x = torch.cat(cat_embed, 1)
         else:
             x = self.embed(X[:, self.cat_idx].long())
-            x = self.dropout(x)
 
         if self.bias is not None:
             x = x + self.bias.unsqueeze(0)
 
-        return x
+        return self.dropout(x)
 
 
 class CatAndContEmbeddings(nn.Module):
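
A minimal, self-contained sketch (not part of the patch itself) of how `FullEmbeddingDropout` behaves after this change: the column-wise mask is only drawn and applied in training mode, while `eval()` turns the layer into an identity. Only the `forward` body comes from the diff; the `__init__` signature and the demo tensors below are assumptions for illustration. The `CategoricalEmbeddings` change follows the same idea: dropout is now applied once, after the optional bias, in both branches.

```python
import torch
from torch import nn, Tensor


class FullEmbeddingDropout(nn.Module):
    def __init__(self, dropout: float):
        # Assumed constructor: the diff only shows the `self.dropout = dropout` line.
        super().__init__()
        self.dropout = dropout

    def forward(self, X: Tensor) -> Tensor:
        if self.training:
            # Drop whole embedding columns (dim 1) and rescale the survivors,
            # mirroring the patched forward pass.
            mask = X.new().resize_((X.size(1), 1)).bernoulli_(
                1 - self.dropout
            ).expand_as(X) / (1 - self.dropout)
            return mask * X
        else:
            return X


emb_dropout = FullEmbeddingDropout(dropout=0.5)
X = torch.randn(8, 4, 16)  # hypothetical (batch, n_cat_cols, embed_dim) input

emb_dropout.train()
out = emb_dropout(X)  # some columns zeroed, the rest scaled by 1 / (1 - dropout)

emb_dropout.eval()
assert torch.equal(emb_dropout(X), X)  # identity at inference time
```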