# wide_deep.py

import numpy as np
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..wdtypes import *

from ..initializers import Initializer, MultipleInitializer
from ..callbacks import Callback, History, CallbackContainer
from ..metrics import Metric, MultipleMetrics, MetricCallback
from ..losses import FocalLoss

from ._wd_dataset import WideDeepDataset
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_lr_scheduler import MultipleLRScheduler
from ._multiple_transforms import MultipleTransforms

from .deep_dense import dense_layer

from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
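    """Collector class that assembles the ``wide``, ``deepdense`` and the
    optional ``deeptext``, ``deepimage`` and ``deephead`` components into a
    single Wide & Deep model, and exposes ``compile``, ``fit``, ``predict``,
    ``predict_proba`` and ``get_embeddings`` methods on top of ``nn.Module``.

    If no ``deephead`` is passed, either a dense head is built from
    ``head_layers`` on top of the concatenated deep components, or each deep
    component is connected directly to the output neuron(s) through its own
    ``nn.Linear`` layer.
    """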

    def __init__(self,
        wide:nn.Module,
        deepdense:nn.Module,
        output_dim:int=1,
        deeptext:Optional[nn.Module]=None,
        deepimage:Optional[nn.Module]=None,
        deephead:Optional[nn.Module]=None,
        head_layers:Optional[List]=None,
        head_dropout:Optional[List]=None,
        head_batchnorm:Optional[bool]=None):

        super(WideDeep, self).__init__()

        # The main 5 components of the wide and deep ensemble
        self.wide = wide
        self.deepdense = deepdense
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deephead is None:
            if head_layers is not None:
                input_dim = self.deepdense.output_dim
                if self.deeptext is not None: input_dim += self.deeptext.output_dim
                if self.deepimage is not None: input_dim += self.deepimage.output_dim
                head_layers = [input_dim] + head_layers
                if not head_dropout: head_dropout = [0.] * (len(head_layers)-1)
                self.deephead = nn.Sequential()
                for i in range(1, len(head_layers)):
                    self.deephead.add_module(
                        'head_layer_{}'.format(i-1),
                        dense_layer(head_layers[i-1], head_layers[i], head_dropout[i-1], head_batchnorm))
                self.deephead.add_module('head_out', nn.Linear(head_layers[-1], output_dim))
            else:
                self.deepdense = nn.Sequential(
                    self.deepdense,
                    nn.Linear(self.deepdense.output_dim, output_dim))
                if self.deeptext is not None:
                    self.deeptext = nn.Sequential(
                        self.deeptext,
                        nn.Linear(self.deeptext.output_dim, output_dim))
                if self.deepimage is not None:
                    self.deepimage = nn.Sequential(
                        self.deepimage,
                        nn.Linear(self.deepimage.output_dim, output_dim))

    def forward(self, X:Dict[str,Tensor])->Tensor:
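        """Add the wide output and the deep component outputs (optionally
        passed through ``deephead``) into the output neuron(s)."""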
        # Wide output: direct connection to the output neuron(s)
        out = self.wide(X['wide'])

        # Deep output: either connected directly to the output neuron(s) or
        # passed through a head first
        if self.deephead:
            deepside = self.deepdense(X['deepdense'])
            if self.deeptext is not None:
                deepside = torch.cat([deepside, self.deeptext(X['deeptext'])], dim=1)
            if self.deepimage is not None:
                deepside = torch.cat([deepside, self.deepimage(X['deepimage'])], dim=1)
            deepside_out = self.deephead(deepside)
            return out.add_(deepside_out)
        else:
            out.add_(self.deepdense(X['deepdense']))
            if self.deeptext is not None:
                out.add_(self.deeptext(X['deeptext']))
            if self.deepimage is not None:
                out.add_(self.deepimage(X['deepimage']))
            return out

    def compile(self,
        method:str,
        initializers:Optional[Dict[str,Initializer]]=None,
        optimizers:Optional[Dict[str,Optimizer]]=None,
        global_optimizer:Optional[Optimizer]=None,
        lr_schedulers:Optional[Dict[str,LRScheduler]]=None,
        global_lr_scheduler:Optional[LRScheduler]=None,
        transforms:Optional[List[Transforms]]=None,
        callbacks:Optional[List[Callback]]=None,
        metrics:Optional[List[Metric]]=None,
        class_weight:Optional[Union[float,List[float],Tuple[float]]]=None,
        with_focal_loss:bool=False,
        alpha:float=0.25,
        gamma:float=1,
        verbose=1):
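        """Set up everything needed to train the model.

        ``method`` must be one of 'regression', 'logistic' or 'multiclass' and
        determines the activation and the loss used during training.
        Initializers, optimizers and lr schedulers can be passed per component
        (as dicts) or globally; if no optimizer is given, Adam with default
        parameters is used. Callbacks and metrics are collected together with
        the ``History`` callback, and ``class_weight``/``with_focal_loss``
        configure the loss function.
        """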
        self.verbose = verbose
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([class_weight, 1.-class_weight])
        elif isinstance(class_weight, (list, tuple)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        if optimizers is not None:
            self.optimizer = MultipleOptimizer(optimizers)
        elif global_optimizer is not None:
            self.optimizer = global_optimizer
        else:
            self.optimizer = torch.optim.Adam(self.parameters())

        if lr_schedulers is not None:
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            scheduler_names = [sc.__class__.__name__.lower() for _, sc in self.lr_scheduler._schedulers.items()]
            self.cyclic = any(['cycl' in sn for sn in scheduler_names])
        elif global_lr_scheduler is not None:
            self.lr_scheduler = global_lr_scheduler
            self.cyclic = 'cycl' in self.lr_scheduler.__class__.__name__.lower()
        else:
            self.lr_scheduler, self.cyclic = None, False

        if transforms is not None:
            self.transforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        self.history = History()
        self.callbacks = [self.history]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type): callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

    def _activation_fn(self, inp:Tensor) -> Tensor:
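        """Return the activation matching ``self.method``: identity for
        regression, sigmoid for logistic, softmax for multiclass."""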
        if self.method == 'regression':
            return inp
        if self.method == 'logistic':
            return torch.sigmoid(inp)
        if self.method == 'multiclass':
            return F.softmax(inp, dim=1)

    def _loss_fn(self, y_pred:Tensor, y_true:Tensor) -> Tensor:
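        """Return the training loss: focal loss if ``with_focal_loss`` was
        set, otherwise MSE, (weighted) binary cross-entropy or (weighted)
        cross-entropy depending on ``self.method``."""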
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
        if self.method == 'regression':
            return F.mse_loss(y_pred, y_true.view(-1, 1))
        if self.method == 'logistic':
            return F.binary_cross_entropy(y_pred, y_true.view(-1, 1), weight=self.class_weight)
        if self.method == 'multiclass':
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

    def _training_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
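        """Run forward and backward passes on one training batch and return
        the metric results (or None) and the running average training loss."""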
        X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
        y = target.float() if self.method != 'multiclass' else target
        y = y.cuda() if use_cuda else y

        self.optimizer.zero_grad()
        y_pred = self._activation_fn(self.forward(X))
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss/(batch_idx+1)

        if self.metric is not None:
            acc = self.metric(y_pred, y)
            return acc, avg_loss
        else:
            return None, avg_loss

    def _validation_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
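        """Evaluate one validation batch under ``torch.no_grad()`` and return
        the metric results (or None) and the running average validation loss."""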
        with torch.no_grad():
            X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
            y = target.float() if self.method != 'multiclass' else target
            y = y.cuda() if use_cuda else y

            y_pred = self._activation_fn(self.forward(X))
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss/(batch_idx+1)

        if self.metric is not None:
            acc = self.metric(y_pred, y)
            return acc, avg_loss
        else:
            return None, avg_loss

    def _lr_scheduler_step(self, step_location:str):
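        """Step the lr scheduler(s): cyclic schedulers are stepped at the end
        of every batch, all others at the end of every epoch, for both a
        single scheduler and a ``MultipleLRScheduler``."""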

        if self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler' and self.cyclic:
            if step_location == 'on_batch_end':
                for model_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' in scheduler.__class__.__name__.lower(): scheduler.step()
            elif step_location == 'on_epoch_end':
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' not in scheduler.__class__.__name__.lower(): scheduler.step()
        elif self.cyclic:
            if step_location == 'on_batch_end': self.lr_scheduler.step()
            else: pass
        elif self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler':
            if step_location == 'on_epoch_end': self.lr_scheduler.step()
            else: pass
        elif step_location == 'on_epoch_end': self.lr_scheduler.step()
        else: pass

    def _train_val_split(self,
        X_wide:Optional[np.ndarray]=None,
        X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None,
        target:Optional[np.ndarray]=None,
        seed:int=1):
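        """Build the training and (optional) validation ``WideDeepDataset``.

        If neither ``X_val`` nor ``val_split`` is passed, only a training set
        is built and ``eval_set`` is None. If ``X_val`` is passed it is used
        as given; if only ``val_split`` is passed, the training arrays are
        split with ``train_test_split`` using ``seed`` as random state.
        """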
        # No evaluation set
        if X_val is None and val_split is None:
            if X_train is not None:
                X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                if 'X_text' in X_train.keys(): X_text = X_train['X_text']
                if 'X_img' in X_train.keys(): X_img = X_train['X_img']
            X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': target}
            if X_text is not None: X_train.update({'X_text': X_text})
            if X_img is not None:  X_train.update({'X_img': X_img})
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)
            eval_set = None
        else:
            # An evaluation set will be used: either X_val or val_split is not None
            if X_val is not None:
                if X_train is None:
                    X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': target}
                    if X_text is not None: X_train.update({'X_text': X_text})
                    if X_img is not None:  X_train.update({'X_img': X_img})
            else:
                if X_train is not None:
                    X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                    if 'X_text' in X_train.keys(): X_text = X_train['X_text']
                    if 'X_img' in X_train.keys(): X_img = X_train['X_img']
                X_tr_wide, X_val_wide, X_tr_deep, X_val_deep, y_tr, y_val = train_test_split(X_wide,
                    X_deep, target, test_size=val_split, random_state=seed)
                X_train = {'X_wide': X_tr_wide, 'X_deep': X_tr_deep, 'target': y_tr}
                X_val = {'X_wide': X_val_wide, 'X_deep': X_val_deep, 'target': y_val}
                if X_text is not None:
                    X_tr_text, X_val_text = train_test_split(X_text, test_size=val_split,
                        random_state=seed)
                    X_train.update({'X_text': X_tr_text})
                    X_val.update({'X_text': X_val_text})
                if X_img is not None:
                    X_tr_img, X_val_img = train_test_split(X_img, test_size=val_split,
                        random_state=seed)
                    X_train.update({'X_img': X_tr_img})
                    X_val.update({'X_img': X_val_img})
            # Train and validation dictionaries have been built
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)
        return train_set, eval_set

    def fit(self,
        X_wide:Optional[np.ndarray]=None,
        X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None,
        target:Optional[np.ndarray]=None,
        n_epochs:int=1,
        batch_size:int=32,
        patience:int=10,
        seed:int=1):
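        """Train the model for ``n_epochs``.

        The training data can be passed either as separate arrays (``X_wide``,
        ``X_deep`` and ``target``, plus optionally ``X_text`` and ``X_img``) or
        as an ``X_train`` dictionary, with an optional validation set given via
        ``X_val`` or ``val_split``. Each epoch runs a training loop (stepping
        cyclic lr schedulers per batch) and, if a validation set exists, an
        evaluation loop, firing the compiled callbacks along the way.
        """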
        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
                "training data is missing. Either a dictionary (X_train) with "
                "the training data or at least 3 arrays (X_wide, X_deep, "
                "target) must be passed to the fit method")

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(X_wide, X_deep, X_text, X_img,
            X_train, X_val, val_split, target, seed)
        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, num_workers=8)
        train_steps = (len(train_loader.dataset) // batch_size) + 1
        self.callback_container.on_train_begin({'batch_size': batch_size,
            'train_steps': train_steps, 'n_epochs': n_epochs})

        for epoch in range(n_epochs):
            # train step...
            epoch_logs = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
            self.train_running_loss = 0.
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, (data, target) in zip(t, train_loader):
                    t.set_description('epoch %i' % (epoch+1))
                    acc, train_loss = self._training_step(data, target, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        t.set_postfix(loss=np.sqrt(train_loss))
                    if self.lr_scheduler: self._lr_scheduler_step(step_location='on_batch_end')
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs['train_loss'] = train_loss
            if acc is not None: epoch_logs['train_acc'] = acc['acc']
            # eval step...
            if eval_set is not None:
                eval_loader = DataLoader(dataset=eval_set, batch_size=batch_size, num_workers=8,
                    shuffle=False)
                eval_steps = (len(eval_loader.dataset) // batch_size) + 1
                self.valid_running_loss = 0.
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, target) in zip(v, eval_loader):
                        v.set_description('valid')
                        acc, val_loss = self._validation_step(data, target, i)
                        if acc is not None:
                            v.set_postfix(metrics=acc, loss=val_loss)
                        else:
                            v.set_postfix(loss=np.sqrt(val_loss))
                epoch_logs['val_loss'] = val_loss
                if acc is not None: epoch_logs['val_acc'] = acc['acc']
            if self.lr_scheduler: self._lr_scheduler_step(step_location='on_epoch_end')
            # log and check if early_stop...
            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                self.callback_container.on_train_end(epoch)
                break
            self.callback_container.on_train_end(epoch)

    def predict(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
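        """Return predictions as a numpy array: raw outputs for regression,
        0/1 labels (0.5 threshold) for logistic, and the argmax class for
        multiclass. Data can be passed as separate arrays or as an ``X_test``
        dictionary."""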

        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None:  load_dict.update({'X_img': X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = torch.utils.data.DataLoader(dataset=test_set,
            batch_size=self.batch_size, shuffle=False)
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1

        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
            if self.method == "regression":
                return np.vstack(preds_l).squeeze(1)
            if self.method == "logistic":
                preds = np.vstack(preds_l).squeeze(1)
                return (preds > 0.5).astype('int')
            if self.method == "multiclass":
                preds = np.vstack(preds_l)
                return np.argmax(preds, 1)

    def predict_proba(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
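        """Return class probabilities as a numpy array: an ``(n_samples, 2)``
        array for logistic, or the softmax outputs for multiclass. Data can be
        passed as separate arrays or as an ``X_test`` dictionary."""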

        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None:  load_dict.update({'X_img': X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = torch.utils.data.DataLoader(dataset=test_set,
            batch_size=self.batch_size, shuffle=False)
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1

        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
            if self.method == "logistic":
                preds = np.vstack(preds_l).squeeze(1)
                probs = np.zeros([preds.shape[0],2])
                probs[:,0] = 1-preds
                probs[:,1] = preds
                return probs
            if self.method == "multiclass":
                return np.vstack(preds_l)

    def get_embeddings(self, col_name:str,
        cat_encoding_dict:Dict[str,Dict[str,int]]) -> Dict[str,np.ndarray]:
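        """Return a dictionary mapping each category of ``col_name`` to its
        learned embedding vector, using ``cat_encoding_dict`` (the
        category-to-index encodings) to label the rows of the embedding
        matrix."""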
        for n,p in self.named_parameters():
            if 'embed_layers' in n and col_name in n:
                embed_mtx = p.cpu().data.numpy()
        encoding_dict = cat_encoding_dict[col_name]
        inv_encoding_dict = {v:k for k,v in encoding_dict.items()}
        cat_embed_dict = {}
        for idx,value in inv_encoding_dict.items():
            cat_embed_dict[value] = embed_mtx[idx]
        return cat_embed_dict
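

# A minimal usage sketch (illustrative only). ``wide_component`` is assumed to
# be an ``nn.Module`` mapping the wide input to the output neuron(s), and
# ``deepdense_component`` a deep component exposing an ``output_dim``
# attribute, as expected by ``__init__``; all names below are placeholders,
# not part of this module.
#
#   model = WideDeep(wide=wide_component, deepdense=deepdense_component)
#   model.compile(method='logistic')
#   model.fit(X_wide=X_wide, X_deep=X_deep, target=y, n_epochs=5, batch_size=64)
#   preds = model.predict(X_wide=X_wide, X_deep=X_deep)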