Commit d44bcd93 authored by J jrzaurin

fixed code to adequately perform train/val split internally

Parent 6ca5eeab
import numpy as np
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -38,6 +39,9 @@ class WideDeepLoader(Dataset):
         self.X_text = X_text
         self.X_img = X_img
         self.transforms = transforms
+        if self.transforms:
+            self.transforms_names = [tr.__class__.__name__ for tr in self.transforms.transforms]
+        else: self.transforms_names = []
         self.Y = target

     def __getitem__(self, idx:int):
@@ -47,10 +51,15 @@ class WideDeepLoader(Dataset):
         if self.X_text is not None:
             X.deeptext = self.X_text[idx]
         if self.X_img is not None:
-            xdi = (self.X_img[idx]/255).astype('float32')
-            if self.transforms is not None:
-                xdi = self.transforms(xdi)
-            X.deepimg = xdi
+            xdi = self.X_img[idx]
+            if 'int' in str(xdi.dtype) and 'uint8' != str(xdi.dtype): xdi = xdi.astype('uint8')
+            if 'float' in str(xdi.dtype) and 'float32' != str(xdi.dtype): xdi = xdi.astype('float32')
+            if not self.transforms or 'ToTensor' not in self.transforms_names:
+                xdi = xdi.transpose(2,0,1)
+                if 'int' in str(xdi.dtype): xdi = (xdi/xdi.max()).astype('float32')
+            if 'ToTensor' in self.transforms_names: xdi = self.transforms(xdi)
+            elif self.transforms: xdi = self.transforms(torch.Tensor(xdi))
+            X.deepimage = xdi
         if self.Y is not None:
             y = self.Y[idx]
             return X, y
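The rewritten image branch normalizes dtypes before any transform runs: integer arrays become uint8, other floats become float32, and the HWC-to-CHW transpose plus scaling into [0, 1] is only done by hand when no ToTensor transform is present, since ToTensor performs both itself. A minimal standalone sketch of that branching, assuming a torchvision-style Compose whose .transforms attribute lists its steps:

import numpy as np
import torch
from torchvision import transforms as T

composed = T.Compose([T.ToTensor()])
names = [tr.__class__.__name__ for tr in composed.transforms]

img = np.random.randint(0, 256, (224, 224, 3)).astype('uint8')  # HWC uint8

if 'ToTensor' not in names:
    img = img.transpose(2, 0, 1)                   # HWC -> CHW by hand
    if 'int' in str(img.dtype):
        img = (img / img.max()).astype('float32')  # scale by the image max
    out = composed(torch.Tensor(img)) if names else torch.Tensor(img)
else:
    out = composed(img)   # ToTensor does HWC->CHW and /255 scaling itself
print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32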
@@ -70,7 +79,6 @@ class WideDeep(nn.Module):
                  deepimage:Optional[TorchModel]=None):
         super(WideDeep, self).__init__()
         self.wide = wide
         self.deepdense = deepdense
         self.deeptext = deeptext
@@ -82,27 +90,9 @@ class WideDeep(nn.Module):
         if self.deeptext is not None:
             wide_deep.add_(self.deeptext(X['deeptext']))
         if self.deepimage is not None:
-            wide_deep.add_(self.deepimg(X['deepimg']))
+            wide_deep.add_(self.deepimage(X['deepimage']))
         return wide_deep
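In forward, the component outputs are fused additively: each submodel produces a tensor of the same shape and Tensor.add_ accumulates them in place into a single prediction. A small sketch of that pattern with dummy stand-in tensors:

import torch

wide_out = torch.randn(4, 1)       # stand-ins for the submodel outputs
deepdense_out = torch.randn(4, 1)
deeptext_out = torch.randn(4, 1)

wide_deep = wide_out.clone()
wide_deep.add_(deepdense_out)      # in-place: wide_deep += deepdense_out
wide_deep.add_(deeptext_out)
print(wide_deep.shape)             # torch.Size([4, 1])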
-    def _activation_fn(self, inp:Tensor) -> Tensor:
-        if self.method == 'regression':
-            return inp
-        if self.method == 'logistic':
-            return torch.sigmoid(inp)
-        if self.method == 'multiclass':
-            return F.softmax(inp, dim=1)
-
-    def _loss_fn(self, y_pred:Tensor, y_true:Tensor) -> Tensor:
-        if self.focal_loss:
-            return self.FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
-        if self.method == 'regression':
-            return F.mse_loss(y_pred, y_true.view(-1, 1))
-        if self.method == 'logistic':
-            return F.binary_cross_entropy(y_pred, y_true.view(-1, 1), weight=self.class_weight)
-        if self.method == 'multiclass':
-            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

     def compile(self, method:str,
                 initializers:Optional[Dict[str,Initializer]]=None,
                 optimizers:Optional[Dict[str,Optimizer]]=None,
@@ -144,13 +134,16 @@ class WideDeep(nn.Module):
         if lr_schedulers is not None:
             self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
             self.lr_scheduler.apply(self.optimizer._optimizers)
-            self.lr_scheduler_name = [sc.__class__.__name__.lower() for _,sc in self.lr_scheduler._schedulers.items()]
+            if 'cycl' in [sc.__class__.__name__.lower() for _,sc in self.lr_scheduler._schedulers.items()]:
+                self.cyclic = True
+            else: self.cyclic = False
         elif global_lr_scheduler is not None:
-            if isinstance(global_optimizer, type): self.lr_scheduler = global_lr_scheduler()
             self.lr_scheduler = global_lr_scheduler(self.optimizer)
-            self.lr_scheduler_name = self.lr_scheduler.__class__.__name__.lower()
+            if 'cycl' in self.lr_scheduler.__class__.__name__.lower(): self.cyclic = True
+            else: self.cyclic = False
         else:
-            self.lr_scheduler, self.lr_scheduler_name = None, None
+            self.lr_scheduler = None

         if transforms is not None:
             self.transforms = MultipleTransforms(transforms)()
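compile now records a self.cyclic flag instead of the old lr_scheduler_name strings: any scheduler whose class name contains 'cycl' is treated as batch-level. A sketch of that name test, assuming stock torch schedulers such as OneCycleLR:

import torch

model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, total_steps=100)
cyclic = 'cycl' in sched.__class__.__name__.lower()  # 'onecyclelr' -> True
print(cyclic)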
@@ -171,6 +164,24 @@ class WideDeep(nn.Module):
         self.callback_container = CallbackContainer(self.callbacks)
         self.callback_container.set_model(self)

+    def _activation_fn(self, inp:Tensor) -> Tensor:
+        if self.method == 'regression':
+            return inp
+        if self.method == 'logistic':
+            return torch.sigmoid(inp)
+        if self.method == 'multiclass':
+            return F.softmax(inp, dim=1)
+
+    def _loss_fn(self, y_pred:Tensor, y_true:Tensor) -> Tensor:
+        if self.focal_loss:
+            return self.FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
+        if self.method == 'regression':
+            return F.mse_loss(y_pred, y_true.view(-1, 1))
+        if self.method == 'logistic':
+            return F.binary_cross_entropy(y_pred, y_true.view(-1, 1), weight=self.class_weight)
+        if self.method == 'multiclass':
+            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)
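The method string fixes both the output activation and the loss: 'logistic' applies a sigmoid because F.binary_cross_entropy expects probabilities, while F.cross_entropy applies log-softmax internally and F.mse_loss works on the raw output. A short standalone sketch of the 'logistic' pairing:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 1)                 # raw model output
y_true = torch.randint(0, 2, (8,)).float()

y_pred = torch.sigmoid(logits)             # _activation_fn for 'logistic'
loss = F.binary_cross_entropy(y_pred, y_true.view(-1, 1))  # _loss_fn
print(loss.item())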

     def _training_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
         X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
@@ -210,22 +221,42 @@ class WideDeep(nn.Module):
         else:
             return None, avg_loss

-    def _train_val_split(self, X_wide:np.ndarray, X_deep:np.ndarray,
-        target:np.ndarray, X_text:Optional[np.ndarray]=None,
-        X_img:Optional[np.ndarray]=None, X_train:Optional[Dict[str,
-        np.ndarray]]=None, X_val:Optional[Dict[str, np.ndarray]]=None,
-        val_split:float=0., seed:int=1):
+    def _lr_scheduler_step(self, step_location:str):
+        if self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler' and self.cyclic:
+            if step_location == 'on_batch_end':
+                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():
+                    if 'cycl' in scheduler_name.lower(): scheduler.step()
+            elif step_location == 'on_epoch_end':
+                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():
+                    if 'cycl' not in scheduler_name.lower(): scheduler.step()
+        elif self.cyclic:
+            if step_location == 'on_batch_end': self.lr_scheduler.step()
+            else: pass
+        elif self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler':
+            if step_location == 'on_epoch_end': self.lr_scheduler.step()
+            else: pass
+        elif step_location == 'on_epoch_end': self.lr_scheduler.step()
+        else: pass
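_lr_scheduler_step routes each scheduler to the right hook: cyclical schedulers step once per batch, everything else once per epoch, and a MultipleLRScheduler may hold a mix of both. A sketch of the two cadences with stock torch schedulers, using one optimizer per scheduler as the library does for its model components:

import torch

model = torch.nn.Linear(10, 1)
opt1 = torch.optim.SGD(model.parameters(), lr=0.1)
opt2 = torch.optim.SGD(model.parameters(), lr=0.1)
cyclic = torch.optim.lr_scheduler.CyclicLR(opt1, base_lr=0.01, max_lr=0.1)
step = torch.optim.lr_scheduler.StepLR(opt2, step_size=2, gamma=0.5)

n_epochs, n_batches = 3, 5
for epoch in range(n_epochs):
    for batch in range(n_batches):
        opt1.step(); opt2.step()
        cyclic.step()   # 'cycl' in class name -> stepped at on_batch_end
    step.step()         # otherwise            -> stepped at on_epoch_end
    print(epoch, opt1.param_groups[0]['lr'], opt2.param_groups[0]['lr'])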
+    def _train_val_split(self, X_wide:Optional[np.ndarray]=None,
+        X_deep:Optional[np.ndarray]=None, X_text:Optional[np.ndarray]=None,
+        X_img:Optional[np.ndarray]=None,
+        X_train:Optional[Dict[str,np.ndarray]]=None,
+        X_val:Optional[Dict[str,np.ndarray]]=None,
+        val_split:Optional[float]=None, target:Optional[np.ndarray]=None,
+        seed:int=1):

         if X_val is None and val_split is None:
             if X_train is not None:
-                X_wide, X_deep, target = X_train['wide'], X_train['deepdense'], X_train['target']
-                if 'deeptext' in X_train.keys(): X_text = X['deeptext']
-                if 'deepimage' in X_train.keys(): X_img = X['deepimage']
-            X_train={'wide': X_wide, 'deepdense': X_deep, 'target': target}
-            if 'X_text' in locals():
-                X_train.update({'X_text': X_text})
-            if 'X_img' in locals():
-                X_train.update({'X_img': X_img})
+                X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
+                if 'X_text' in X_train.keys(): X_text = X_train['X_text']
+                if 'X_img' in X_train.keys(): X_img = X_train['X_img']
+            X_train={'X_wide': X_wide, 'X_deep': X_deep, 'target': target}
+            try: X_train.update({'X_text': X_text})
+            except: pass
+            try: X_train.update({'X_img': X_img})
+            except: pass
             train_set = WideDeepLoader(**X_train, transforms=self.transforms)
             eval_set = None
         else:
@@ -236,34 +267,43 @@ class WideDeep(nn.Module):
                 if X_img is not None: X_train.update({'X_img': X_img})
             else:
                 if X_train is not None:
-                    X_wide, X_deep, target = X_train['wide'], X_train['deepdense'], X_train['target']
-                    if 'deeptext' in X_train.keys():
-                        X_text = X['deeptext']
-                    if 'deepimage' in X_train.keys():
-                        X_img = X['deepimage']
-                X_tr_wide, X_val_wide, X_tr_deep, X_val_deep, y_tr, y_val = \
-                    train_test_split(X_wide, X_deep, target, test_size=val_split, random_state=seed)
+                    X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
+                    if 'X_text' in X_train.keys():
+                        X_text = X_train['X_text']
+                    if 'X_img' in X_train.keys():
+                        X_img = X_train['X_img']
+                X_tr_wide, X_val_wide, X_tr_deep, X_val_deep, y_tr, y_val = train_test_split(X_wide, X_deep, target, test_size=val_split, random_state=seed)
                 X_train = {'X_wide':X_tr_wide, 'X_deep': X_tr_deep, 'target': y_tr}
                 X_val = {'X_wide':X_val_wide, 'X_deep': X_val_deep, 'target': y_val}
-                if 'X_text' is locals():
+                try:
                     X_tr_text, X_val_text = train_test_split(X_text, test_size=val_split, random_state=seed)
                     X_train.update({'X_text': X_tr_text}), X_val.update({'X_text': X_val_text})
-                if 'X_image' is locals():
+                except: pass
+                try:
                     X_tr_img, X_val_img = train_test_split(X_img, test_size=val_split, random_state=seed)
-                    X_train.update({'X_img': X_tr_imgt}), X_val.update({'X_img': X_val_img})
+                    X_train.update({'X_img': X_tr_img}), X_val.update({'X_img': X_val_img})
+                except: pass
             train_set = WideDeepLoader(**X_train, transforms=self.transforms)
             eval_set = WideDeepLoader(**X_val, transforms=self.transforms)
         return train_set, eval_set
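_train_val_split relies on the fact that sklearn's train_test_split can split several arrays in one call, and that separate calls with the same random_state over the same number of rows pick identical indices, so the optional text and image arrays stay row-aligned with the wide and deep ones. A standalone sketch of that property:

import numpy as np
from sklearn.model_selection import train_test_split

X_wide = np.random.rand(100, 5)
X_deep = np.random.rand(100, 8)
target = np.random.randint(0, 2, 100)
X_text = np.random.randint(0, 50, (100, 10))

X_tr_w, X_val_w, X_tr_d, X_val_d, y_tr, y_val = train_test_split(
    X_wide, X_deep, target, test_size=0.2, random_state=1)
# a second call with the same seed selects the same rows
X_tr_t, X_val_t = train_test_split(X_text, test_size=0.2, random_state=1)
assert len(X_tr_w) == len(X_tr_t) == len(y_tr) == 80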
-    def fit(self, X_wide:np.ndarray, X_deep:np.ndarray, target:np.ndarray,
+    def fit(self, X_wide:Optional[np.ndarray]=None, X_deep:Optional[np.ndarray]=None,
         X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
-        n_epochs:int=1, batch_size:int=32, X_train:Optional[Dict[str,
-        np.ndarray]]=None, X_val:Optional[Dict[str, np.ndarray]]=None,
-        val_split:float=0., seed:int=1, patience:int=10, verbose:int=1):
+        X_train:Optional[Dict[str,np.ndarray]]=None,
+        X_val:Optional[Dict[str,np.ndarray]]=None,
+        val_split:Optional[float]=None, target:Optional[np.ndarray]=None,
+        n_epochs:int=1, batch_size:int=32, patience:int=10, seed:int=1,
+        verbose:int=1):

+        if X_train is None and (X_wide is None or X_deep is None or target is None):
+            raise ValueError(
+                "training data is missing. Either a dictionary (X_train) with "
+                "the training data or at least 3 arrays (X_wide, X_deep, "
+                "target) must be passed to the fit method")
         self.batch_size = batch_size
-        train_set, eval_set = self._train_val_split(X_wide, X_deep, target, X_text,
-            X_img, X_train, X_val, val_split, seed)
+        train_set, eval_set = self._train_val_split(X_wide, X_deep, X_text, X_img,
+            X_train, X_val, val_split, target, seed)
         train_loader = DataLoader(dataset=train_set, batch_size=batch_size, num_workers=8)
         train_steps = (len(train_loader.dataset) // batch_size) + 1
         self.callback_container.on_train_begin({'batch_size': batch_size,
@@ -272,7 +312,7 @@ class WideDeep(nn.Module):
         for epoch in range(n_epochs):
             # train step...
             epoch_logs = {}
-            self.callback_container.on_epoch_begin(epoch, epoch_logs)
+            self.callback_container.on_epoch_begin(epoch+1, epoch_logs)
             self.train_running_loss = 0.
             with trange(train_steps, disable=verbose != 1) as t:
                 for batch_idx, (data,target) in zip(t, train_loader):
@@ -282,8 +322,7 @@ class WideDeep(nn.Module):
                         t.set_postfix(metrics=acc, loss=train_loss)
                     else:
                         t.set_postfix(loss=np.sqrt(train_loss))
-                    if self.lr_scheduler and 'cycl' in self.lr_scheduler_name:
-                        self.lr_scheduler.step()
+                    if self.lr_scheduler: self._lr_scheduler_step(step_location="on_batch_end")
             epoch_logs['train_loss'] = train_loss
             if acc is not None: epoch_logs['train_acc'] = acc
@@ -304,11 +343,10 @@ class WideDeep(nn.Module):
                 epoch_logs['val_loss'] = val_loss
                 if acc is not None: epoch_logs['val_acc'] = acc
-            self.callback_container.on_epoch_end(epoch, epoch_logs)
+            self.callback_container.on_epoch_end(epoch+1, epoch_logs)
             if self.early_stop:
                 break
-            if self.lr_scheduler and 'cycl' not in self.lr_scheduler_name:
-                self.lr_scheduler.step()
+            if self.lr_scheduler: self._lr_scheduler_step(step_location="on_epoch_end")
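The epoch loop drives a tqdm progress bar and now defers all scheduler stepping to _lr_scheduler_step at the two hook points. A minimal sketch of the trange/set_postfix pattern used above, with a dummy running loss:

import numpy as np
from tqdm import trange

train_steps, verbose = 10, 1
with trange(train_steps, disable=verbose != 1) as t:
    for batch_idx in t:
        train_loss = 1.0 / (batch_idx + 1)      # dummy running loss
        t.set_postfix(loss=np.sqrt(train_loss))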

     def predict(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
                 X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
......