# wide_deep.py
import numpy as np
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..wdtypes import *
from ..initializers import Initializer, MultipleInitializers
from ..optimizers import MultipleOptimizers
from ..lr_schedulers import MultipleLRScheduler
from ..callbacks import Callback, History, CallbackContainer
from ..metrics import Metric, MultipleMetrics, MetricCallback
from ..transforms import MultipleTransforms
from ..losses import FocalLoss

from .wide import Wide
from .deep_dense import DeepDense
from .deep_text import DeepText
from .deep_image import DeepImage

from tqdm import tqdm,trange
from sklearn.utils import Bunch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

use_cuda = torch.cuda.is_available()


import pdb


class WideDeepLoader(Dataset):
    """Dataset that wraps the different data components of a Wide & Deep model.

    Parameters
    ----------
    X_wide: np.ndarray
        wide-side input.
    X_deep: np.ndarray
        deep-dense-side input.
    target: np.ndarray, Optional
        target values. FIX: now optional (default None) so the same loader
        can be built at predict time, when no target exists; __getitem__
        already handled the None case but the signature required it.
    X_text: np.ndarray, Optional
        text-side input.
    X_img: np.ndarray, Optional
        image-side input.
    transforms: Optional
        composed torchvision-style transforms applied to the images.
        NOTE(review): expected to expose a ``.transforms`` list attribute —
        confirm against MultipleTransforms.
    """
    def __init__(self, X_wide:np.ndarray, X_deep:np.ndarray,
        target:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        transforms:Optional[Any]=None):

        self.X_wide = X_wide
        self.X_deep = X_deep
        self.X_text = X_text
        self.X_img  = X_img
        self.transforms = transforms
        # individual transform class names, used in __getitem__ to decide
        # how the image must be pre-processed before applying them
        if self.transforms:
            self.transforms_names = [tr.__class__.__name__ for tr in self.transforms.transforms]
        else:
            self.transforms_names = []
        self.Y = target

    def __getitem__(self, idx:int):
        # each model component is returned under its own key
        X = Bunch(wide=self.X_wide[idx])
        X.deepdense = self.X_deep[idx]
        if self.X_text is not None:
            X.deeptext = self.X_text[idx]
        if self.X_img is not None:
            xdi = self.X_img[idx]
            # normalize dtype: ints become uint8, floats become float32
            if 'int' in str(xdi.dtype) and 'uint8' != str(xdi.dtype): xdi = xdi.astype('uint8')
            if 'float' in str(xdi.dtype) and 'float32' != str(xdi.dtype): xdi = xdi.astype('float32')
            if not self.transforms or 'ToTensor' not in self.transforms_names:
                # no ToTensor transform: channel-first and scale manually
                xdi = xdi.transpose(2,0,1)
                if 'int' in str(xdi.dtype): xdi = (xdi/xdi.max()).astype('float32')
            if 'ToTensor' in self.transforms_names: xdi = self.transforms(xdi)
            elif self.transforms: xdi = self.transforms(torch.Tensor(xdi))
            X.deepimage = xdi
        if self.Y is not None:
            y = self.Y[idx]
            return X, y
        else:
            return X

    def __len__(self):
        return len(self.X_wide)


class WideDeep(nn.Module):
    """Collector module assembling the full Wide & Deep architecture.

    ``wide`` and ``deepdense`` are mandatory components; ``deeptext`` and
    ``deepimage`` are optional branches.
    """

    def __init__(self,
        wide:TorchModel,
        deepdense:TorchModel,
        deeptext:Optional[TorchModel]=None,
        deepimage:Optional[TorchModel]=None):

        super(WideDeep, self).__init__()
        # register every sub-model as a child module so that parameters(),
        # cuda(), etc. see them
        self.wide      = wide
        self.deepdense = deepdense
        self.deeptext  = deeptext
        self.deepimage = deepimage

    def forward(self, X:List[Dict[str,Tensor]])->Tensor:
        """Sum the outputs of all (present) sub-models for the batch ``X``.

        ``X`` maps component names ('wide', 'deepdense', 'deeptext',
        'deepimage') to their input tensors.
        """
        # start from the wide output and accumulate the deep branches in place
        out = self.wide(X['wide'])
        out.add_(self.deepdense(X['deepdense']))
        for branch, key in ((self.deeptext, 'deeptext'), (self.deepimage, 'deepimage')):
            if branch is not None:
                out.add_(branch(X[key]))
        return out

    def compile(self,method:str,
        initializers:Optional[Dict[str,Initializer]]=None,
        optimizers:Optional[Dict[str,Optimizer]]=None,
        global_optimizer:Optional[Optimizer]=None,
        lr_schedulers:Optional[Dict[str,LRScheduler]]=None,
        global_lr_scheduler:Optional[LRScheduler]=None,
        transforms:Optional[List[Transforms]]=None,
        callbacks:Optional[List[Callback]]=None,
        metrics:Optional[List[Metric]]=None,
        class_weight:Optional[Union[float,List[float],Tuple[float]]]=None,
        focal_loss:bool=False, alpha:float=0.25, gamma:float=1):
        """Configure the model for training.

        Parameters
        ----------
        method: str
            'regression', 'logistic' or 'multiclass'; selects the output
            activation and the loss function.
        initializers: Dict, Optional
            per-submodel weight initializers.
        optimizers: Dict, Optional
            per-submodel optimizers. Mutually exclusive with
            ``global_optimizer``; if neither is given, Adam is used.
        global_optimizer: Optional
            a callable (e.g. partial of a torch.optim class) that receives
            the model parameters and returns one optimizer for everything.
        lr_schedulers / global_lr_scheduler: Optional
            per-submodel schedulers or one global scheduler factory.
        transforms: List, Optional
            torchvision-style transforms applied to the images.
        callbacks: List, Optional
            training callbacks; History is always prepended.
        metrics: List, Optional
            metrics computed during training/validation.
        class_weight: float, List or Tuple, Optional
            class weights for the classification losses; a single float w
            is expanded to [w, 1-w].
        focal_loss, alpha, gamma
            whether to use the Focal Loss and its parameters.
        """
        self.early_stop = False
        self.method = method
        self.focal_loss = focal_loss
        if self.focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([class_weight, 1.-class_weight])
        elif isinstance(class_weight, (list, tuple)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializers(initializers)
            self.initializer.apply(self)

        if optimizers is not None:
            self.optimizer = MultipleOptimizers(optimizers)
            self.optimizer.apply(self)
        elif global_optimizer is not None:
            # FIX: removed the dead `isinstance(global_optimizer, type)`
            # branch that instantiated the optimizer without parameters and
            # was immediately overwritten anyway
            self.optimizer = global_optimizer(self.parameters())
        else:
            self.optimizer = torch.optim.Adam(self.parameters())

        if lr_schedulers is not None:
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            self.lr_scheduler.apply(self.optimizer._optimizers)
            # FIX: test each scheduler class name for the 'cycl' substring.
            # The previous `'cycl' in [names]` was an exact-membership test
            # and could never be True (e.g. 'cycliclr' != 'cycl')
            scheduler_names = [sc.__class__.__name__.lower()
                for _, sc in self.lr_scheduler._schedulers.items()]
            self.cyclic = any('cycl' in sn for sn in scheduler_names)
        elif global_lr_scheduler is not None:
            # FIX: removed the dead isinstance check that (incorrectly)
            # inspected global_optimizer instead of global_lr_scheduler
            self.lr_scheduler = global_lr_scheduler(self.optimizer)
            self.cyclic = 'cycl' in self.lr_scheduler.__class__.__name__.lower()
        else:
            self.lr_scheduler = None

        if transforms is not None:
            self.transforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # History is always tracked; user callbacks come after it
        self.history = History()
        self.callbacks = [self.history]
        if callbacks is not None:
            self.callbacks += callbacks

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

    def _activation_fn(self, inp:Tensor) -> Tensor:
        """Map raw model output to prediction space according to
        ``self.method``: identity for regression, sigmoid for (binary)
        logistic, softmax over dim 1 for multiclass."""
        if self.method == 'regression':
            return inp
        if self.method == 'logistic':
            return torch.sigmoid(inp)
        if self.method == 'multiclass':
            return F.softmax(inp, dim=1)
        # FIX: fail loudly on an unknown method instead of silently
        # returning None
        raise ValueError("method must be 'regression', 'logistic' or 'multiclass'")
    def _loss_fn(self, y_pred:Tensor, y_true:Tensor) -> Tensor:
        """Compute the loss for ``self.method`` (or the Focal Loss when
        ``self.focal_loss`` is set)."""
        if self.focal_loss:
            # FIX: FocalLoss is the imported class, not an attribute of
            # self — `self.FocalLoss` raised AttributeError
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
        if self.method == 'regression':
            return F.mse_loss(y_pred, y_true.view(-1, 1))
        if self.method == 'logistic':
            return F.binary_cross_entropy(y_pred, y_true.view(-1, 1), weight=self.class_weight)
        if self.method == 'multiclass':
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

    def _training_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
        """Run one optimization step on a batch; return ``(metric, avg_loss)``
        where avg_loss is the epoch running average up to this batch."""
        X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
        # classification targets stay integer for cross_entropy
        y = target.float() if self.method != 'multiclass' else target
        y = y.cuda() if use_cuda else y

        self.optimizer.zero_grad()
        y_pred = self._activation_fn(self.forward(X))
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        # accumulate the running loss; caller supplies the 0-based batch index
        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss/(batch_idx+1)

        if self.metric is not None:
            return self.metric(y_pred, y), avg_loss
        return None, avg_loss

    def _validation_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
        """Evaluate one batch without gradients; return ``(metric, avg_loss)``
        where avg_loss is the validation running average up to this batch."""
        with torch.no_grad():
            # FIX: dict iteration uses .items(), not .item() — the original
            # crashed on the CUDA path
            X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
            y = target.float() if self.method != 'multiclass' else target
            y = y.cuda() if use_cuda else y

            y_pred = self._activation_fn(self.forward(X))
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss/(batch_idx+1)

        if self.metric is not None:
            acc = self.metric(y_pred, y)
            return acc, avg_loss
        else:
            return None, avg_loss

    def _lr_scheduler_step(self, step_location:str):
        """Step the learning-rate scheduler(s) at the right moment.

        Cyclic schedulers are stepped once per batch ('on_batch_end');
        every other scheduler is stepped once per epoch ('on_epoch_end').
        When a MultipleLRScheduler holds a mix of both, each member is
        stepped at its corresponding location.
        """
        if self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler' and self.cyclic:
            # mixed case: per batch, step only the cyclic members...
            # NOTE(review): these keys look like sub-model names (e.g.
            # 'wide'), while compile() detects cyclic-ness from scheduler
            # class names — confirm 'cycl' can actually appear in these keys
            if step_location == 'on_batch_end':
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' in scheduler_name.lower(): scheduler.step()
            # ...and per epoch, step only the non-cyclic members
            elif step_location == 'on_epoch_end':
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' not in scheduler_name.lower(): scheduler.step()
        elif self.cyclic:
            # single cyclic scheduler: batch-level stepping only
            if step_location == 'on_batch_end': self.lr_scheduler.step()
            else: pass
        elif self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler':
            # multiple non-cyclic schedulers: epoch-level stepping only
            if step_location == 'on_epoch_end': self.lr_scheduler.step()
            else: pass
        # single non-cyclic scheduler: epoch-level stepping only
        elif step_location == 'on_epoch_end': self.lr_scheduler.step()
        else: pass

    def _train_val_split(self, X_wide:Optional[np.ndarray]=None,
        X_deep:Optional[np.ndarray]=None, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None, target:Optional[np.ndarray]=None,
        seed:int=1):
        """Build the train (and optionally validation) WideDeepLoader datasets.

        The data may come as individual arrays or packed in the X_train /
        X_val dictionaries. A validation set is produced either from an
        explicit X_val dict or by splitting the training data with
        ``val_split``. Returns ``(train_set, eval_set)`` where eval_set is
        None when no validation is requested.
        """
        if X_val is None and val_split is None:
            # no validation requested: everything goes into the train set
            if X_train is not None:
                X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                if 'X_text' in X_train.keys(): X_text = X_train['X_text']
                if 'X_img' in X_train.keys(): X_img = X_train['X_img']
            X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': target}
            # FIX: add the optional components only when present; the former
            # try/except around dict.update could never raise and always
            # inserted the keys (possibly with None values)
            if X_text is not None: X_train.update({'X_text': X_text})
            if X_img is not None: X_train.update({'X_img': X_img})
            train_set = WideDeepLoader(**X_train, transforms=self.transforms)
            eval_set = None
        else:
            if X_val is not None:
                # explicit validation set: only assemble X_train if needed
                if X_train is None:
                    X_train = {'X_wide':X_wide, 'X_deep': X_deep, 'target': target}
                    if X_text is not None: X_train.update({'X_text': X_text})
                    if X_img is not None:  X_train.update({'X_img': X_img})
            else:
                # split the training data according to val_split
                if X_train is not None:
                    X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                    if 'X_text' in X_train.keys():
                        X_text = X_train['X_text']
                    if 'X_img' in X_train.keys():
                        X_img = X_train['X_img']
                X_tr_wide, X_val_wide, X_tr_deep, X_val_deep, y_tr, y_val = train_test_split(
                    X_wide, X_deep, target, test_size=val_split, random_state=seed)
                X_train = {'X_wide':X_tr_wide, 'X_deep': X_tr_deep, 'target': y_tr}
                X_val = {'X_wide':X_val_wide, 'X_deep': X_val_deep, 'target': y_val}
                # FIX: explicit None checks instead of bare try/except that
                # also swallowed genuine splitting errors
                if X_text is not None:
                    X_tr_text, X_val_text = train_test_split(X_text, test_size=val_split, random_state=seed)
                    X_train.update({'X_text': X_tr_text})
                    X_val.update({'X_text': X_val_text})
                if X_img is not None:
                    X_tr_img, X_val_img = train_test_split(X_img, test_size=val_split, random_state=seed)
                    X_train.update({'X_img': X_tr_img})
                    X_val.update({'X_img': X_val_img})
            train_set = WideDeepLoader(**X_train, transforms=self.transforms)
            eval_set = WideDeepLoader(**X_val, transforms=self.transforms)
        return train_set, eval_set

    def fit(self, X_wide:Optional[np.ndarray]=None, X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None, target:Optional[np.ndarray]=None,
        n_epochs:int=1, batch_size:int=32, patience:int=10, seed:int=1,
        verbose:int=1):
        """Run the training loop.

        Data is supplied either as individual arrays (X_wide, X_deep and
        target at minimum) or packed in the X_train / X_val dictionaries.
        Validation runs when X_val or val_split is given. Callbacks are
        notified at train/epoch boundaries, and schedulers are stepped per
        batch or per epoch via _lr_scheduler_step.
        """
        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
                "training data is missing. Either a dictionary (X_train) with "
                "the training data or at least 3 arrays (X_wide, X_deep, "
                "target) must be passed to the fit method")

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(X_wide, X_deep, X_text, X_img,
            X_train, X_val, val_split, target, seed)
        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, num_workers=8)
        train_steps = (len(train_loader.dataset) // batch_size) + 1
        self.callback_container.on_train_begin({'batch_size': batch_size,
            'train_steps': train_steps, 'n_epochs': n_epochs})

        for epoch in range(n_epochs):
            epoch_logs = {}
            self.callback_container.on_epoch_begin(epoch+1, epoch_logs)

            # training phase
            self.train_running_loss = 0.
            with trange(train_steps, disable=verbose != 1) as t:
                for batch_idx, (data, target) in zip(t, train_loader):
                    t.set_description('epoch %i' % (epoch+1))
                    acc, train_loss = self._training_step(data, target, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        t.set_postfix(loss=np.sqrt(train_loss))
                    # cyclic schedulers step once per batch
                    if self.lr_scheduler: self._lr_scheduler_step(step_location="on_batch_end")
            epoch_logs['train_loss'] = train_loss
            if acc is not None: epoch_logs['train_acc'] = acc

            # validation phase (only when an eval set exists)
            if eval_set is not None:
                eval_loader = DataLoader(dataset=eval_set, batch_size=batch_size, num_workers=8,
                    shuffle=False)
                eval_steps = (len(eval_loader.dataset) // batch_size) + 1
                self.valid_running_loss = 0.
                with trange(eval_steps, disable=verbose != 1) as v:
                    for i, (data, target) in zip(v, eval_loader):
                        v.set_description('valid')
                        acc, val_loss = self._validation_step(data, target, i)
                        if acc is not None:
                            v.set_postfix(metrics=acc, loss=val_loss)
                        else:
                            v.set_postfix(loss=np.sqrt(val_loss))
                epoch_logs['val_loss'] = val_loss
                if acc is not None: epoch_logs['val_acc'] = acc

            self.callback_container.on_epoch_end(epoch+1, epoch_logs)
            if self.early_stop:
                break
            # non-cyclic schedulers step once per epoch
            if self.lr_scheduler: self._lr_scheduler_step(step_location="on_epoch_end")

    def predict(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
        """Return predictions for the given inputs: raw values for
        regression, hard class labels for logistic/multiclass."""
        if X_test is not None:
            test_set = WideDeepLoader(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None:  load_dict.update({'X_img': X_img})
            test_set = WideDeepLoader(**load_dict)

        test_loader = torch.utils.data.DataLoader(dataset=test_set,
            batch_size=self.batch_size, shuffle=False)
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1

        preds_l = []
        with torch.no_grad():
            with trange(test_steps) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    # FIX: apply the activation to the tensor FIRST, then
                    # move to numpy — the original called sigmoid/softmax on
                    # an ndarray, which fails
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
            if self.method == "regression":
                return np.vstack(preds_l).squeeze(1)
            if self.method == "logistic":
                preds = np.vstack(preds_l).squeeze(1)
                return (preds > 0.5).astype('int')
            if self.method == "multiclass":
                preds = np.vstack(preds_l)
                return np.argmax(preds, 1)

    def predict_proba(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
        """Return class probabilities: an (n, 2) array for logistic, an
        (n, n_classes) array for multiclass."""
        if X_test is not None:
            test_set = WideDeepLoader(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None:  load_dict.update({'X_img': X_img})
            # FIX: build the loader inside the else branch — the original
            # built it unconditionally, so the X_test path raised NameError
            # on load_dict and discarded its own loader
            test_set = WideDeepLoader(**load_dict)

        test_loader = torch.utils.data.DataLoader(dataset=test_set,
            batch_size=self.batch_size, shuffle=False)
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1

        preds_l = []
        with torch.no_grad():
            with trange(test_steps) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    # FIX: apply the activation to the tensor FIRST, then
                    # move to numpy (same bug as in predict)
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
            if self.method == "logistic":
                preds = np.vstack(preds_l).squeeze(1)
                # expand the positive-class probability into two columns
                probs = np.zeros([preds.shape[0],2])
                probs[:,0] = 1-preds
                probs[:,1] = preds
                return probs
            if self.method == "multiclass":
                return np.vstack(preds_l)

    def get_embeddings(self, col_name:str,
        cat_embed_encoding_dict:Dict[str,Dict[str,int]]) -> Dict[str,np.ndarray]:
        """Return the learned embeddings for the categorical column
        ``col_name`` as a dict mapping each original category value to its
        embedding vector.

        ``cat_embed_encoding_dict`` maps column names to their
        {category value -> encoded index} label encodings.
        """
        # locate the embedding layer whose parameter name contains col_name
        emb_layers = [p for p in list(self.named_parameters()) if 'emb_layer' in p[0]]
        matching = [layer for layer in emb_layers if col_name in layer[0]]
        embeddings = matching[0][1].cpu().data.numpy()
        # invert the label encoding so embedding rows can be keyed by the
        # original category value
        inv_dict = {idx: cat for cat, idx in cat_embed_encoding_dict[col_name].items()}
        return {cat: embeddings[idx] for idx, cat in inv_dict.items()}