提交 1700cbdd 编写于 作者: Q qijun

pass pre-commit

上级 be45facf
...@@ -44,13 +44,15 @@ def convolutional_neural_network(img): ...@@ -44,13 +44,15 @@ def convolutional_neural_network(img):
input=conv_pool_2, size=10, act=paddle.activation.Softmax()) input=conv_pool_2, size=10, act=paddle.activation.Softmax())
return predict return predict
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=False, trainer_count=1)
# define network topology # define network topology
images = paddle.layer.data( images = paddle.layer.data(
name='pixel', type=paddle.data_type.dense_vector(784)) name='pixel', type=paddle.data_type.dense_vector(784))
label = paddle.layer.data(name='label', type=paddle.data_type.integer_value(10)) label = paddle.layer.data(
name='label', type=paddle.data_type.integer_value(10))
# Here we can build the prediction network in different ways. Please # Here we can build the prediction network in different ways. Please
# choose one by uncomment corresponding line. # choose one by uncomment corresponding line.
...@@ -72,7 +74,6 @@ def main(): ...@@ -72,7 +74,6 @@ def main():
lists = [] lists = []
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
...@@ -81,13 +82,11 @@ def main(): ...@@ -81,13 +82,11 @@ def main():
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.batch( result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
print "Test with Pass %d, Cost %f, %s\n" % (event.pass_id, print "Test with Pass %d, Cost %f, %s\n" % (
result.cost, event.pass_id, result.cost, result.metrics)
result.metrics)
lists.append((event.pass_id, result.cost, lists.append((event.pass_id, result.cost,
result.metrics['classification_error_evaluator'])) result.metrics['classification_error_evaluator']))
trainer.train( trainer.train(
reader=paddle.batch( reader=paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192), paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192),
...@@ -100,5 +99,6 @@ def main(): ...@@ -100,5 +99,6 @@ def main():
print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1]) print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100) print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -36,6 +36,7 @@ def get_usr_combined_features(): ...@@ -36,6 +36,7 @@ def get_usr_combined_features():
act=paddle.activation.Tanh()) act=paddle.activation.Tanh())
return usr_combined_features return usr_combined_features
def get_mov_combined_features(): def get_mov_combined_features():
movie_title_dict = paddle.dataset.movielens.get_movie_title_dict() movie_title_dict = paddle.dataset.movielens.get_movie_title_dict()
mov_id = paddle.layer.data( mov_id = paddle.layer.data(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册