提交 bd9b609a 编写于 作者: S Superjom

code style with yapf

上级 d718d1e6
...@@ -23,7 +23,6 @@ paddle.init(use_gpu=False, trainer_count=11) ...@@ -23,7 +23,6 @@ paddle.init(use_gpu=False, trainer_count=11)
# ============================================================================== # ==============================================================================
# input layers # input layers
# ============================================================================== # ==============================================================================
dnn_merged_input = layer.data( dnn_merged_input = layer.data(
name='dnn_input', name='dnn_input',
type=paddle.data_type.sparse_binary_vector(data_meta_info['dnn_input'])) type=paddle.data_type.sparse_binary_vector(data_meta_info['dnn_input']))
...@@ -34,11 +33,10 @@ lr_merged_input = layer.data( ...@@ -34,11 +33,10 @@ lr_merged_input = layer.data(
click = paddle.layer.data(name='click', type=dtype.dense_vector(1)) click = paddle.layer.data(name='click', type=dtype.dense_vector(1))
# ============================================================================== # ==============================================================================
# network structure # network structure
# ============================================================================== # ==============================================================================
def build_dnn_submodel(dnn_layer_dims): def build_dnn_submodel(dnn_layer_dims):
dnn_embedding = layer.fc(input=dnn_merged_input, size=dnn_layer_dims[0]) dnn_embedding = layer.fc(input=dnn_merged_input, size=dnn_layer_dims[0])
_input_layer = dnn_embedding _input_layer = dnn_embedding
...@@ -93,10 +91,10 @@ dataset = AvazuDataset(train_data_path, n_records_as_test=test_set_size) ...@@ -93,10 +91,10 @@ dataset = AvazuDataset(train_data_path, n_records_as_test=test_set_size)
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
num_samples = event.batch_id * batch_size
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
logging.warning("Pass %d, Samples %d, Cost %f" % logging.warning("Pass %d, Samples %d, Cost %f" %
(event.pass_id, event.batch_id * batch_size, (event.pass_id, num_samples, event.cost))
event.cost))
if event.batch_id % 1000 == 0: if event.batch_id % 1000 == 0:
result = trainer.test( result = trainer.test(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册