提交 65aff91e 编写于 作者: Y Yu Yang

Stash

上级 00de8db1
...@@ -70,27 +70,35 @@ def main(): ...@@ -70,27 +70,35 @@ def main():
parameters=parameters, parameters=parameters,
update_equation=paddle.optimizer.Adam( update_equation=paddle.optimizer.Adam(
learning_rate=1e-4)) learning_rate=1e-4))
reader_dict = {
'user_id': 0,
'gender_id': 1,
'age_id': 2,
'job_id': 3,
'movie_id': 4,
'category_id': 5,
'movie_title': 6,
'score': 7
}
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
print "Pass %d Batch %d Cost %.2f" % ( print "Pass %d Batch %d Cost %.2f" % (
event.pass_id, event.batch_id, event.cost) event.pass_id, event.batch_id, event.cost)
elif isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.movielens.test(), batch_size=256))
print result.cost
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.dataset.movielens.train(), batch_size=256), paddle.reader.shuffle(
paddle.dataset.movielens.train(), buf_size=8192),
batch_size=256),
event_handler=event_handler, event_handler=event_handler,
reader_dict={ reader_dict=reader_dict,
'user_id': 0, num_passes=10)
'gender_id': 1,
'age_id': 2,
'job_id': 3,
'movie_id': 4,
'category_id': 5,
'movie_title': 6,
'score': 7
})
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -123,9 +123,8 @@ class SGD(ITrainer): ...@@ -123,9 +123,8 @@ class SGD(ITrainer):
for each_param in self.__gradient_machine__.getParameters(): for each_param in self.__gradient_machine__.getParameters():
updater.update(each_param) updater.update(each_param)
# Get cost. We use numpy to calculate total cost for this batch. # Get cost. We use numpy to calculate total cost for this batch.
cost_vec = out_args.getSlotValue(0) cost_sum = out_args.sumCosts()
cost_vec = cost_vec.copyToNumpyMat() cost = cost_sum / len(data_batch)
cost = cost_vec.sum() / len(data_batch)
updater.finishBatch(cost) updater.finishBatch(cost)
batch_evaluator.finish() batch_evaluator.finish()
event_handler( event_handler(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册