import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..wdtypes import *
from ..initializers import Initializer, MultipleInitializers
from ..optimizers import MultipleOptimizers
from ..lr_schedulers import MultipleLRScheduler
from ..callbacks import Callback, History, CallbackContainer
from ..metrics import Metric, MultipleMetrics, MetricCallback
from ..transforms import MultipleTransforms
from ..losses import FocalLoss

from .wide import Wide
from .deep_dense import DeepDense
from .deep_text import DeepText
from .deep_image import DeepImage

from tqdm import trange
from sklearn.utils import Bunch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

use_cuda = torch.cuda.is_available()


class WideDeepLoader(Dataset):
    """Dataset that bundles the wide, deep, text and image inputs (and,
    optionally, the target) for a given index."""
    def __init__(self, X_wide:np.ndarray, X_deep:np.ndarray,
        target:Optional[np.ndarray]=None, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, transforms=None):

        self.X_wide = X_wide
        self.X_deep = X_deep
        self.X_text = X_text
        self.X_img  = X_img
        self.transforms = transforms
        self.Y = target

    def __getitem__(self, idx:int):

        X = Bunch(wide=self.X_wide[idx])
        X.deepdense = self.X_deep[idx]
        if self.X_text is not None:
            X.deeptext = self.X_text[idx]
        if self.X_img is not None:
            # scale to [0, 1] before applying any transforms
            xdi = (self.X_img[idx]/255).astype('float32')
            if self.transforms is not None:
                xdi = self.transforms(xdi)
            X.deepimg = xdi
        if self.Y is not None:
            y = self.Y[idx]
            return X, y
        else:
            return X

    def __len__(self):
        return len(self.X_wide)
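
# A minimal sketch of using WideDeepLoader on its own; the array shapes below
# are hypothetical and only for illustration:
#
#   X_wide = np.random.randint(0, 2, (100, 10)).astype('float32')
#   X_deep = np.random.randint(0, 5, (100, 4)).astype('float32')
#   y = np.random.randint(0, 2, 100)
#   ds = WideDeepLoader(X_wide, X_deep, target=y)
#   X, y0 = ds[0]  # X is a Bunch with .wide and .deepdense entries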


class WideDeep(nn.Module):
    """Wide & Deep model: the output of a linear 'wide' component is added
    element-wise to the output of one or more 'deep' components."""

    def __init__(self,
        wide:TorchModel,
        deepdense:TorchModel,
        deeptext:Optional[TorchModel]=None,
        deepimage:Optional[TorchModel]=None):

        super(WideDeep, self).__init__()

        self.wide = wide
        self.deepdense = deepdense
        self.deeptext  = deeptext
        self.deepimage = deepimage

    def forward(self, X:Dict[str,Tensor])->Tensor:
        # the final prediction is the element-wise sum of the components' outputs
        wide_deep = self.wide(X['wide'])
        wide_deep.add_(self.deepdense(X['deepdense']))
        if self.deeptext is not None:
            wide_deep.add_(self.deeptext(X['deeptext']))
        if self.deepimage is not None:
            wide_deep.add_(self.deepimage(X['deepimg']))
        return wide_deep

    def _activation_fn(self, inp:Tensor) -> Tensor:
        if self.method == 'regression':
            return inp
        if self.method == 'logistic':
            return torch.sigmoid(inp)
        if self.method == 'multiclass':
            return F.softmax(inp, dim=1)

    def _loss_fn(self, y_pred:Tensor, y_true:Tensor) -> Tensor:
        if self.focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
        if self.method == 'regression':
            return F.mse_loss(y_pred, y_true.view(-1, 1))
        if self.method == 'logistic':
            return F.binary_cross_entropy(y_pred, y_true.view(-1, 1), weight=self.class_weight)
        if self.method == 'multiclass':
            # y_pred arrives here as softmax probabilities (see _activation_fn),
            # so take their log and use NLL, which is equivalent to cross-entropy
            # on the raw logits; F.cross_entropy itself expects unnormalised logits
            return F.nll_loss(torch.log(y_pred), y_true, weight=self.class_weight)

    def compile(self, method:str,
        initializers:Optional[Dict[str,Initializer]]=None,
        optimizers:Optional[Dict[str,Optimizer]]=None,
        global_optimizer:Optional[Optimizer]=None,
        lr_schedulers:Optional[Dict[str,LRScheduler]]=None,
        global_lr_scheduler:Optional[LRScheduler]=None,
        transforms:Optional[List[Transforms]]=None,
        callbacks:Optional[List[Callback]]=None,
        metrics:Optional[List[Metric]]=None,
        class_weight:Optional[Union[float,List[float],Tuple[float]]]=None,
        focal_loss:bool=False, alpha:float=0.25, gamma:float=1):
        """Set up the model for training: the method (loss/activation) and the
        optional initializers, optimizers, lr schedulers, transforms, callbacks,
        metrics and loss weighting."""
        self.early_stop = False
        self.method = method
        self.focal_loss = focal_loss
        if self.focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([class_weight, 1.-class_weight])
        elif isinstance(class_weight, (list, tuple)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializers(initializers)
            self.initializer.apply(self)

        if optimizers is not None:
            self.optimizer = MultipleOptimizers(optimizers)
            self.optimizer.apply(self)
        elif global_optimizer is not None:
            # allow the optimizer to be passed either as a class or an instance
            if isinstance(global_optimizer, type): global_optimizer = global_optimizer()
            self.optimizer = global_optimizer(self.parameters())
        else:
            self.optimizer = torch.optim.Adam(self.parameters())

        if lr_schedulers is not None:
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            self.lr_scheduler.apply(self.optimizer._optimizers)
            self.lr_scheduler_name = [sc.__class__.__name__.lower() for _,sc in self.lr_scheduler._schedulers.items()]
        elif global_lr_scheduler is not None:
            if isinstance(global_lr_scheduler, type): global_lr_scheduler = global_lr_scheduler()
            self.lr_scheduler = global_lr_scheduler(self.optimizer)
            self.lr_scheduler_name = self.lr_scheduler.__class__.__name__.lower()
        else:
            self.lr_scheduler, self.lr_scheduler_name = None, None

        if transforms is not None:
            self.transforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        self.history = History()
        self.callbacks = [self.history]
        if callbacks is not None:
            self.callbacks += callbacks

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)
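
    # A sketch of a typical compile call; 'method' must be one of 'regression',
    # 'logistic' or 'multiclass', and the optimizer falls back to Adam when none
    # is passed (Wide/DeepDense arguments elided, they depend on the data):
    #
    #   model = WideDeep(wide=Wide(...), deepdense=DeepDense(...))
    #   model.compile(method='logistic')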

    def _training_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):

        X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
        y = target.float() if self.method != 'multiclass' else target
        y = y.cuda() if use_cuda else y

        self.optimizer.zero_grad()
        y_pred = self._activation_fn(self.forward(X))
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss/(batch_idx+1)

        if self.metric is not None:
            acc = self.metric(y_pred, y)
            return acc, avg_loss
        else:
            return None, avg_loss

    def _validation_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):

        with torch.no_grad():
            X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
            y = target.float() if self.method != 'multiclass' else target
            y = y.cuda() if use_cuda else y

            y_pred = self._activation_fn(self.forward(X))
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss/(batch_idx+1)

        if self.metric is not None:
            acc = self.metric(y_pred, y)
            return acc, avg_loss
        else:
            return None, avg_loss

    def _train_val_split(self, X_wide:np.ndarray, X_deep:np.ndarray,
        target:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_train:Optional[Dict[str,
        np.ndarray]]=None, X_val:Optional[Dict[str, np.ndarray]]=None,
        val_split:float=0., seed:int=1):

        if X_val is None and not val_split:
            # no validation set: train on everything
            if X_train is not None:
                X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                if 'X_text' in X_train.keys(): X_text = X_train['X_text']
                if 'X_img' in X_train.keys(): X_img = X_train['X_img']
            X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': target}
            if X_text is not None:
                X_train.update({'X_text': X_text})
            if X_img is not None:
                X_train.update({'X_img': X_img})
            train_set = WideDeepLoader(**X_train, transforms=self.transforms)
            eval_set = None
        else:
            if X_val is not None:
                # a validation set was passed explicitly
                if X_train is None:
                    X_train = {'X_wide':X_wide, 'X_deep': X_deep, 'target': target}
                    if X_text is not None: X_train.update({'X_text': X_text})
                    if X_img is not None:  X_train.update({'X_img': X_img})
            else:
                # build the validation set by splitting off val_split of the data
                if X_train is not None:
                    X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                    if 'X_text' in X_train.keys():
                        X_text = X_train['X_text']
                    if 'X_img' in X_train.keys():
                        X_img = X_train['X_img']
                X_tr_wide, X_val_wide, X_tr_deep, X_val_deep, y_tr, y_val = \
                train_test_split(X_wide, X_deep, target, test_size=val_split, random_state=seed)
                X_train = {'X_wide':X_tr_wide, 'X_deep': X_tr_deep, 'target': y_tr}
                X_val = {'X_wide':X_val_wide, 'X_deep': X_val_deep, 'target': y_val}
                if X_text is not None:
                    X_tr_text, X_val_text = train_test_split(X_text, test_size=val_split, random_state=seed)
                    X_train.update({'X_text': X_tr_text}), X_val.update({'X_text': X_val_text})
                if X_img is not None:
                    X_tr_img, X_val_img = train_test_split(X_img, test_size=val_split, random_state=seed)
                    X_train.update({'X_img': X_tr_img}), X_val.update({'X_img': X_val_img})
            train_set = WideDeepLoader(**X_train, transforms=self.transforms)
            eval_set = WideDeepLoader(**X_val, transforms=self.transforms)
        return train_set, eval_set

    def fit(self, X_wide:np.ndarray, X_deep:np.ndarray, target:np.ndarray,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        n_epochs:int=1, batch_size:int=32, X_train:Optional[Dict[str,
        np.ndarray]]=None, X_val:Optional[Dict[str, np.ndarray]]=None,
        val_split:float=0., seed:int=1, patience:int=10, verbose:int=1):

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(X_wide, X_deep, target, X_text,
            X_img, X_train, X_val, val_split, seed)
        # shuffle the training data each epoch
        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, num_workers=8,
            shuffle=True)
        train_steps = (len(train_loader.dataset) + batch_size - 1) // batch_size
        self.callback_container.on_train_begin({'batch_size': batch_size,
            'train_steps': train_steps, 'n_epochs': n_epochs})

        for epoch in range(n_epochs):
            # train step...
            epoch_logs = {}
            self.callback_container.on_epoch_begin(epoch, epoch_logs)
            self.train_running_loss = 0.
            with trange(train_steps, disable=verbose != 1) as t:
                for batch_idx, (data,target) in zip(t, train_loader):
                    t.set_description('epoch %i' % (epoch+1))
                    acc, train_loss = self._training_step(data, target, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        # without metrics, report the square root of the loss
                        # (i.e. RMSE when the loss is MSE)
                        t.set_postfix(loss=np.sqrt(train_loss))
                    # cyclic schedulers step per batch; all others step per epoch
                    if self.lr_scheduler and 'cycl' in str(self.lr_scheduler_name):
                        self.lr_scheduler.step()
            epoch_logs['train_loss'] = train_loss
            if acc is not None: epoch_logs['train_acc'] = acc

            # eval step...
            if eval_set is not None:
                eval_loader = DataLoader(dataset=eval_set, batch_size=batch_size, num_workers=8,
                    shuffle=False)
                eval_steps = (len(eval_loader.dataset) + batch_size - 1) // batch_size
                self.valid_running_loss = 0.
                with trange(eval_steps, disable=verbose != 1) as v:
                    for i, (data,target) in zip(v, eval_loader):
                        v.set_description('valid')
                        acc, val_loss = self._validation_step(data, target, i)
                        if acc is not None:
                            v.set_postfix(metrics=acc, loss=val_loss)
                        else:
                            v.set_postfix(loss=np.sqrt(val_loss))
                epoch_logs['val_loss'] = val_loss
                if acc is not None: epoch_logs['val_acc'] = acc

            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                break
            if self.lr_scheduler and 'cycl' not in str(self.lr_scheduler_name):
                self.lr_scheduler.step()
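
    # A minimal end-to-end sketch (hypothetical data; the Wide/DeepDense
    # constructor arguments are elided since they depend on preprocessing):
    #
    #   model = WideDeep(wide=Wide(...), deepdense=DeepDense(...))
    #   model.compile(method='logistic')
    #   model.fit(X_wide, X_deep, y, n_epochs=5, batch_size=64, val_split=0.2)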

    def predict(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:

        if X_test is not None:
            test_set = WideDeepLoader(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None:  load_dict.update({'X_img': X_img})
            test_set = WideDeepLoader(**load_dict)

        test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
            shuffle=False)
        test_steps = (len(test_loader.dataset) + test_loader.batch_size - 1) // test_loader.batch_size

        preds_l = []
        with torch.no_grad():
            with trange(test_steps) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    # apply the activation before moving the predictions to numpy
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
            if self.method == "regression":
                return np.vstack(preds_l).squeeze(1)
            if self.method == "logistic":
                preds = np.vstack(preds_l).squeeze(1)
                return (preds > 0.5).astype('int')
            if self.method == "multiclass":
                preds = np.vstack(preds_l)
                return np.argmax(preds, 1)

    def predict_proba(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:

        if X_test is not None:
            test_set = WideDeepLoader(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None:  load_dict.update({'X_img': X_img})
            test_set = WideDeepLoader(**load_dict)

        test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
            shuffle=False)
        test_steps = (len(test_loader.dataset) + test_loader.batch_size - 1) // test_loader.batch_size

        preds_l = []
        with torch.no_grad():
            with trange(test_steps) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
            if self.method == "logistic":
                preds = np.vstack(preds_l).squeeze(1)
                probs = np.zeros([preds.shape[0],2])
                probs[:,0] = 1-preds
                probs[:,1] = preds
                return probs
            if self.method == "multiclass":
                return np.vstack(preds_l)
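
    # For 'logistic' models predict_proba returns an (n, 2) array with columns
    # [P(y=0), P(y=1)], so that, e.g. (hypothetical test arrays):
    #
    #   probs = model.predict_proba(X_wide_te, X_deep_te)
    #   labels = model.predict(X_wide_te, X_deep_te)  # == (probs[:, 1] > 0.5)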

    def get_embeddings(self, col_name:str,
        cat_embed_encoding_dict:Dict[str,Dict[str,int]]) -> Dict[str,np.ndarray]:
        """Return the learned embeddings for the categorical column `col_name`
        as a {category: embedding vector} dictionary."""
        params = list(self.named_parameters())
        emb_layers = [p for p in params if 'emb_layer' in p[0]]
        emb_layer  = [layer for layer in emb_layers if col_name in layer[0]][0]
        embeddings = emb_layer[1].cpu().data.numpy()
        col_label_encoding = cat_embed_encoding_dict[col_name]
        inv_dict = {v:k for k,v in col_label_encoding.items()}
        embeddings_dict = {}
        for idx,value in inv_dict.items():
            embeddings_dict[value] = embeddings[idx]
        return embeddings_dict
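
    # Hypothetical usage, assuming the column 'education' was label-encoded with
    # a {category: index} dict during preprocessing:
    #
    #   encoding = {'education': {'bachelors': 0, 'masters': 1, 'phd': 2}}
    #   emb = model.get_embeddings('education', encoding)
    #   emb['phd']  # np.ndarray with the embedding vector for that category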