Commit 75c37e12 authored by: J jrzaurin

Fixing style and small adjustment of the CategoricalEmbeddings class

Parent 3532a980
......@@ -498,7 +498,7 @@ class ModelCheckpoint(Callback):
                 )
             )
             if self.wb is not None:
-                self.wb.run.summary["best"] = current
+                self.wb.run.summary["best"] = current  # type: ignore[attr-defined]
             self.best = current
             self.best_epoch = epoch
             self.best_state_dict = self.model.state_dict()
......
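For context, the "# type: ignore[attr-defined]" comment added above silences mypy, which cannot statically verify that the wrapped wb module's run object exposes a summary mapping. A minimal sketch of the underlying Weights & Biases pattern, assuming the wandb package and an active run (the project name and metric value are illustrative):

    import wandb

    # Record the best monitored value on the run summary, mirroring what
    # ModelCheckpoint does with self.wb.run.summary["best"] = current.
    run = wandb.init(project="demo-project")  # illustrative project name
    best_val_loss = 0.314                     # stand-in for the monitored metric
    run.summary["best"] = best_val_loss       # summary keeps one final value per key
    run.finish()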
-from ._base import load_bio_kdd04, load_adult
+from ._base import load_adult, load_bio_kdd04

 __all__ = ["load_bio_kdd04", "load_adult"]
......
 from importlib import resources

 import pandas as pd
......
......@@ -4,6 +4,7 @@ https://github.com/awslabs/autogluon/tree/master/tabular/src/autogluon/tabular/m
"""
import math
import warnings
import torch
from torch import nn
......@@ -19,9 +20,9 @@ class FullEmbeddingDropout(nn.Module):
     def forward(self, X: Tensor) -> Tensor:
         if self.training:
-            mask = X.new().resize_((X.size(1), 1)).bernoulli_(1 - self.dropout).expand_as(
-                X
-            ) / (1 - self.dropout)
+            mask = X.new().resize_((X.size(1), 1)).bernoulli_(
+                1 - self.dropout
+            ).expand_as(X) / (1 - self.dropout)
             return mask * X
         else:
             return X
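To make the dropout variant above concrete: the Bernoulli mask is drawn once per embedding, with shape (n_embeddings, 1), and then broadcast over the batch and embedding dimensions, so entire embedding vectors are zeroed and the survivors are rescaled by 1 / (1 - dropout). A minimal standalone sketch of the same computation (the shapes are illustrative, and new_empty is used here in place of new().resize_()):

    import torch

    p = 0.5
    X = torch.randn(2, 4, 3)  # (batch, n_embeddings, embed_dim)

    # One draw per embedding, shape (4, 1), expanded to X's shape: every
    # embedding row is either all zeros or all 1 / (1 - p).
    mask = X.new_empty((X.size(1), 1)).bernoulli_(1 - p).expand_as(X) / (1 - p)
    print(mask * X)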
......@@ -128,13 +129,16 @@ class CategoricalEmbeddings(nn.Module):
         self.categorical_cols = [ei[0] for ei in embed_input]
         self.cat_idx = [self.column_idx[col] for col in self.categorical_cols]
-        self.bias = (
-            nn.Parameter(torch.Tensor(len(self.categorical_cols), embed_dim))
-            if use_bias
-            else None
-        )
-        if self.bias is not None:
+        if use_bias is not None:
+            self.bias = nn.Parameter(
+                torch.Tensor(len(self.categorical_cols), embed_dim)
+            )
             nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5))
+            if shared_embed:
+                warnings.warn(
+                    "The current implementation of 'SharedEmbeddings' does not use bias",
+                    UserWarning,
+                )

         # Categorical: val + 1 because 0 is reserved for padding/unseen categories.
         if self.shared_embed:
             x = torch.cat(cat_embed, 1)
         else:
             x = self.embed(X[:, self.cat_idx].long())
-        if self.bias is not None:
-            x = x + self.bias.unsqueeze(0)
-        x = self.dropout(x)
-        return x
+        if self.bias is not None:
+            x = x + self.bias.unsqueeze(0)
+        return self.dropout(x)


 class CatAndContEmbeddings(nn.Module):
......
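For reference, the bias reworked in this commit has shape (n_categorical_cols, embed_dim), and the forward pass unsqueezes it to (1, n_cols, embed_dim) so it broadcasts over the batch. A small shape sketch (all sizes illustrative):

    import torch

    x = torch.randn(32, 5, 8)  # embedding lookup output: (batch, n_cols, embed_dim)
    bias = torch.zeros(5, 8)   # one learnable bias vector per categorical column

    # unsqueeze(0) -> (1, 5, 8); broadcast-added across the batch, exactly
    # as CategoricalEmbeddings does before applying dropout.
    out = x + bias.unsqueeze(0)
    assert out.shape == (32, 5, 8)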
-from pytorch_widedeep.datasets import load_bio_kdd04, load_adult
-import pandas as pd
 import numpy as np
+import pandas as pd
 import pytest

+from pytorch_widedeep.datasets import load_adult, load_bio_kdd04


 @pytest.mark.parametrize(
     "as_frame",
......
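The as_frame parameter exercised by this test toggles the loader's return type. A hedged usage sketch, assuming the loaders accept as_frame and return a pandas DataFrame when it is True and a numpy array otherwise:

    import numpy as np
    import pandas as pd

    from pytorch_widedeep.datasets import load_adult

    df = load_adult(as_frame=True)
    assert isinstance(df, pd.DataFrame)

    arr = load_adult(as_frame=False)
    assert isinstance(arr, np.ndarray)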