Commit e0a34dae authored by huangxinjing

Adjust the dense layer structure in the Wide&Deep multi-table model

Parent 3259dafa
......@@ -32,6 +32,7 @@ def argparse_init():
parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr.
parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient.
parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient.
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file.
......@@ -83,7 +84,6 @@ class WideDeepConfig():
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.l2_coef = args.l2_coef
self.ftrl_lr = args.ftrl_lr
self.adam_lr = args.adam_lr
......@@ -93,3 +93,4 @@ class WideDeepConfig():
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path
self.dropout_flag = bool(args.dropout_flag)
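The new --dropout_flag option is read as an int on the command line and converted to a bool in WideDeepConfig, so one flag controls dropout for every deep layer. A minimal sketch of that flow, assuming argparse_init() returns the parser built above:

# Minimal sketch (assumption: argparse_init() returns the parser built above).
args = argparse_init().parse_args(["--dropout_flag", "0"])
dropout_flag = bool(args.dropout_flag)   # mirrors self.dropout_flag = bool(args.dropout_flag)
print(dropout_flag)                      # False -> dropout disabled in every DenseLayer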
......@@ -89,9 +89,11 @@ class DenseLayer(nn.Cell):
output_dim,
weight_bias_init,
act_str,
keep_prob=0.7,
keep_prob=0.8,
scale_coef=1.0,
convert_dtype=True):
use_activation=True,
convert_dtype=True,
drop_out=False):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(weight_init, [input_dim, output_dim],
......@@ -101,11 +103,13 @@ class DenseLayer(nn.Cell):
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.dropout = Dropout(keep_prob=0.8)
self.dropout = Dropout(keep_prob=keep_prob)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
self.use_activation = use_activation
self.convert_dtype = convert_dtype
self.drop_out = drop_out
def _init_activation(self, act_str):
act_str = act_str.lower()
......@@ -118,23 +122,26 @@ class DenseLayer(nn.Cell):
return act_func
def construct(self, x):
"""
DenseLayer construct
"""
x = self.act_func(x)
if self.training:
'''
Construct Dense layer
'''
if self.training and self.drop_out:
x = self.dropout(x)
x = self.mul(x, self.scale_coef)
if self.convert_dtype:
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
bias = self.cast(self.bias, mstype.float16)
wx = self.matmul(x, weight)
wx = self.bias_add(wx, bias)
if self.use_activation:
wx = self.act_func(wx)
wx = self.cast(wx, mstype.float32)
else:
wx = self.matmul(x, self.weight)
wx = self.realDiv(wx, self.scale_coef)
output = self.bias_add(wx, self.bias)
return output
wx = self.bias_add(wx, self.bias)
if self.use_activation:
wx = self.act_func(wx)
return wx
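For reference, a rough NumPy sketch of the reordered forward pass (an approximation, not the repository's code; the float16 cast and scale_coef handling are omitted): dropout is applied first, and only during training when drop_out is set, then matmul plus bias, then the optional activation.

import numpy as np

def dense_forward(x, weight, bias, act, training, drop_out, use_activation, keep_prob=0.8):
    # Hypothetical NumPy approximation of the new DenseLayer.construct ordering.
    if training and drop_out:
        mask = (np.random.rand(*x.shape) < keep_prob) / keep_prob  # inverted dropout
        x = x * mask
    wx = x @ weight + bias
    return act(wx) if use_activation else wx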
class WideDeepModel(nn.Cell):
......@@ -211,33 +218,40 @@ class WideDeepModel(nn.Cell):
self.all_dim_list[1],
self.weight_bias_init,
self.deep_layer_act,
drop_out=config.dropout_flag,
convert_dtype=True)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
self.all_dim_list[2],
self.weight_bias_init,
self.deep_layer_act,
drop_out=config.dropout_flag,
convert_dtype=True)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
self.all_dim_list[3],
self.weight_bias_init,
self.deep_layer_act,
drop_out=config.dropout_flag,
convert_dtype=True)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
self.all_dim_list[4],
self.weight_bias_init,
self.deep_layer_act,
drop_out=config.dropout_flag,
convert_dtype=True)
self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
self.all_dim_list[5],
self.weight_bias_init,
self.deep_layer_act,
drop_out=config.dropout_flag,
convert_dtype=True)
self.deep_predict = DenseLayer(self.all_dim_list[5],
self.all_dim_list[6],
self.weight_bias_init,
self.deep_layer_act,
convert_dtype=True)
drop_out=config.dropout_flag,
convert_dtype=True,
use_activation=False)
self.gather_v2 = P.GatherV2()
self.mul = P.Mul()
......
......@@ -96,9 +96,10 @@ def train_and_eval(config):
keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
callback, ckpoint_cb])
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if int(get_rank()) == 0:
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train, callbacks=callback_list)
if __name__ == "__main__":
......
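The checkpoint callback is now appended only on rank 0, so in distributed training a single worker writes checkpoint files. A minimal sketch of that guard, assuming get_rank is available from mindspore.communication.management in this MindSpore version:

# Minimal sketch (assumed imports for this MindSpore version).
from mindspore.communication.management import get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor

def build_callbacks(ds_train, eval_callback, loss_callback, config):
    # Every rank gets the timing/eval/loss callbacks; only rank 0 saves checkpoints.
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                  keep_checkpoint_max=10)
    callbacks = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, loss_callback]
    if get_rank() == 0:
        callbacks.append(ModelCheckpoint(prefix='widedeep_train',
                                         directory=config.ckpt_path,
                                         config=ckptconfig))
    return callbacks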