Commit 502de688 authored by: J Javier

adding function to use different optimizers for wide and deep parts

Parent 912e565b
......@@ -55,6 +55,12 @@ if __name__ == '__main__':
dropout,
encoding_dict,
n_class)
# if using different optimizers for the wide and deep sides:
# optimizer={'wide': ['name', lr, momentum], 'deep': ['name', lr, momentum]}
# for example:
# optimizer={'wide': ['SGD', 0.001, 0.1], 'deep': ['Adam', 0.001]}
# and
# model.compile(method=method, optimizer=optimizer)
model.compile(method=method)
if use_cuda:
model = model.cuda()
......
......@@ -8,14 +8,35 @@ import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
use_cuda = torch.cuda.is_available()
class MultipleOptimizer(object):
"""Helper to use multiple optimizers as one.
Parameters:
----------
opts: List
List with the optimizer instances (torch.optim objects) to use
"""
def __init__(self, opts):
self.optimizers = opts
def zero_grad(self):
for op in self.optimizers:
op.zero_grad()
def step(self):
for op in self.optimizers:
op.step()
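# A minimal usage sketch (illustrative only, not part of this commit);
# p_wide and p_deep are hypothetical parameter groups:
#
#   opt = MultipleOptimizer([optim.SGD(p_wide, lr=0.001, momentum=0.1),
#                            optim.Adam(p_deep, lr=0.001)])
#   opt.zero_grad()
#   loss.backward()
#   opt.step()  # steps both wrapped optimizers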
class WideDeepLoader(Dataset):
"""Helper to facilitate loading the data to the pytorch models.
Parameters:
--------
----------
data: namedtuple with 3 elements - (wide_input_data, deep_inp_data, target)
"""
def __init__(self, data):
......@@ -45,16 +66,23 @@ class WideDeep(nn.Module):
Parameters:
--------
wide_dim (int) : dim of the wide-side input tensor
embeddings_input (tuple): 3-elements tuple with the embeddings "set-up" -
(col_name, unique_values, embeddings dim)
continuous_cols (list) : list with the name of the continuum columns
deep_column_idx (dict) : dictionary where the keys are column names and the values
their corresponding index in the deep-side input tensor
hidden_layers (list) : list with the number of units per hidden layer
encoding_dict (dict) : dictionary with the label-encode mapping
n_class (int) : number of classes. Defaults to 1 if logistic or regression
dropout (float)
wide_dim: Int
dim of the wide-side input tensor
embeddings_input: Tuple
3-element tuple with the embeddings "set-up" - (col_name,
unique_values, embeddings dim)
continuous_cols: List
list with the names of the continuous columns
deep_column_idx: Dict
dictionary where the keys are column names and the values their
corresponding index in the deep-side input tensor
hidden_layers: List
list with the number of units per hidden layer
encoding_dict: Dict
dictionary with the label-encode mapping
n_class: Int
number of classes. Defaults to 1 if logistic or regression
dropout: Float
"""
def __init__(self,
......@@ -95,29 +123,77 @@ class WideDeep(nn.Module):
self.output = nn.Linear(self.hidden_layers[-1]+self.wide_dim, self.n_class)
def compile(self, method="logistic", optimizer="Adam", learning_rate=0.001, momentum=0.0):
"""Wrapper to set the activation, loss and the optimizer.
Parameters:
----------
method (str) : regression, logistic or multiclass
optimizer (str): SGD, Adam, or RMSprop
@staticmethod
def set_optimizer(model_params, optimizer, learning_rate, momentum=0.0):
"""
if method == 'regression':
self.activation, self.criterion = None, F.mse_loss
if method == 'logistic':
self.activation, self.criterion = torch.sigmoid, F.binary_cross_entropy
if method == 'multiclass':
self.activation, self.criterion = F.softmax, F.cross_entropy
Simple helper so we can set the optimizers with a string, which will
be convenient later. Add more parameters if you need them.
"""
if optimizer == "Adam":
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
return torch.optim.Adam(model_params, lr=learning_rate)
if optimizer == "Adagrad":
return torch.optim.Adagrad(model_params, lr=learning_rate)
if optimizer == "RMSprop":
self.optimizer = torch.optim.RMSprop(self.parameters(), lr=learning_rate)
return torch.optim.RMSprop(model_params, lr=learning_rate, momentum=momentum)
if optimizer == "SGD":
self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate, momentum=momentum)
return torch.optim.SGD(model_params, lr=learning_rate, momentum=momentum)
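# Illustrative call (an assumption for clarity, not in the original diff):
#   opt = WideDeep.set_optimizer(model.parameters(), "RMSprop", 0.001, momentum=0.9)
# would return a torch.optim.RMSprop instance bound to the given parameters.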
@staticmethod
def set_method(method):
"""
Simple helper so we can set the method with a string, which will
be convenient later.
"""
if method == 'regression':
return None, F.mse_loss
if method == 'logistic':
return torch.sigmoid, F.binary_cross_entropy
if method == 'multiclass':
return F.softmax, F.cross_entropy
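# For example (sketch only): set_method('multiclass') returns
# (F.softmax, F.cross_entropy), i.e. softmax as the activation and
# cross-entropy as the training criterion.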
def compile(self, method="logistic", optimizer="Adam", learning_rate=0.001, momentum=0.0):
"""Wrapper to set the activation, loss and the optimizer.
Parameters:
----------
method: String
'regression', 'logistic' or 'multiclass'
optimizer: String or Dict
if string one of the following: 'SGD', 'Adam', or 'RMSprop'
if Dictionary must contain two elements for the wide and deep
parts respectively with keys 'wide' and 'deep'. E.g.
optimizer = {'wide': ['SGD', 0.001, 0.3], 'deep': ['Adam', 0.001]}
"""
self.method = method
self.activation, self.criterion = self.set_method(method)
if type(optimizer) is dict:
params = list(self.parameters())
# last two sets of parameters are the weights and bias of the last
# linear layer
last_linear_weights = params[-2]
# by construction, the last wide_dim columns of these weights
# correspond to the wide side and will use one optimizer
wide_params = [nn.Parameter(last_linear_weights[:, -self.wide_dim:])]
# The weights from the deep side and the bias will use the other
# optimizer
deep_weights = last_linear_weights[:, :-self.wide_dim]
deep_params = params[:-2] + [nn.Parameter(deep_weights)] + [params[-1]]
# Very inelegant, but will do for now
if len(optimizer['wide'])>2:
wide_opt = self.set_optimizer(wide_params, optimizer['wide'][0], optimizer['wide'][1], optimizer['wide'][2])
else:
wide_opt = self.set_optimizer(wide_params, optimizer['wide'][0], optimizer['wide'][1])
if len(optimizer['deep'])>2:
deep_opt = self.set_optimizer(deep_params, optimizer['deep'][0], optimizer['deep'][1], optimizer['deep'][2])
else:
deep_opt = self.set_optimizer(deep_params, optimizer['deep'][0], optimizer['deep'][1])
self.optimizer = MultipleOptimizer([wide_opt, deep_opt])
elif type(optimizer) is str:
self.optimizer = self.set_optimizer(self.parameters(), optimizer, learning_rate, momentum)
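# Shape sketch (added for clarity, not part of this commit): self.output is
# nn.Linear(hidden_layers[-1] + wide_dim, n_class), so its weight matrix has
# shape (n_class, hidden_layers[-1] + wide_dim). The slice [:, -wide_dim:]
# selects the columns multiplying the wide input (trained by the 'wide'
# optimizer), while [:, :-wide_dim], the bias and all earlier layers are
# trained by the 'deep' optimizer.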
def forward(self, X_w, X_d):
......@@ -192,7 +268,7 @@ class WideDeep(nn.Module):
X_w, X_d, y = X_w.cuda(), X_d.cuda(), y.cuda()
self.optimizer.zero_grad()
y_pred = net(X_w, X_d) # [batch_size, 1]
y_pred = net(X_w, X_d)
loss = None
if(self.criterion == F.cross_entropy):
loss = self.criterion(y_pred, y) #[batch_size, 1]
......