提交 3165f135 编写于 作者: J jrzaurin

adapt for cases with no crossed_columns

上级 96af7926
......@@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
from collections import namedtuple
from itertools import chain
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
......@@ -87,6 +88,7 @@ def prepare_data(df, wide_cols, crossed_cols, embeddings_cols, continuous_cols,
# Extract the target and copy the dataframe so we don't mutate it
# internally.
Y = np.array(df[target])
all_columns = list(set(wide_cols + deep_cols + list(chain(*crossed_cols))))
df_tmp = df.copy()[list(set(wide_cols + deep_cols))]
# Build the crossed columns
......
......@@ -136,11 +136,13 @@ class WideDeep(nn.Module):
# Deep Side
emb = [getattr(self, 'emb_layer_'+col)(X_d[:,self.deep_column_idx[col]].long())
for col,_,_ in self.embeddings_input]
if self.continuous_cols:
cont_idx = [self.deep_column_idx[col] for col in self.continuous_cols]
cont = [X_d[:, cont_idx].float()]
deep_inp = torch.cat(emb+cont, 1)
else:
deep_inp = torch.cat(emb, 1)
cont_idx = [self.deep_column_idx[col] for col in self.continuous_cols]
cont = [X_d[:, cont_idx].float()]
deep_inp = torch.cat(emb+cont, 1)
x_deep = F.relu(self.linear_1(deep_inp))
if self.dropout:
x_deep = self.linear_1_drop(x_deep)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册