提交 c4ed9a58 编写于 作者: J jrzaurin

refactored code to a class with fit and transform methods

上级 98014867
......@@ -26,11 +26,12 @@ class PrepareWide(object):
return df, crossed_colnames
def fit(self, df:pd.DataFrame)->np.ndarray:
df_wide = df.copy()[self.wide_cols]
if self.crossed_cols is not None:
df_wide, crossed_colnames = self.cross_cols(df_wide)
self.wide_crossed_cols = self.wide_cols + crossed_colnames
else:
self.wide_crossed_cols = self.wide_cols
if self.already_dummies:
X_oh_1 = self.df_wide[self.already_dummies].values
......@@ -41,7 +42,6 @@ class PrepareWide(object):
return self.one_hot_enc.fit_transform(df_wide[self.wide_crossed_cols])
def transform(self, df:pd.DataFrame)->np.ndarray:
df_wide = df.copy()[self.wide_cols]
if self.crossed_cols is not None:
df_wide, _ = self.cross_cols(df_wide)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册