diff --git a/demo/image_classification/api_v2_train.py b/demo/image_classification/api_v2_train.py index 94bf0b5db48c5046ab02fea4994ce2aae09c641e..585f61c6fa4c89c8621815a168742429ac236898 100644 --- a/demo/image_classification/api_v2_train.py +++ b/demo/image_classification/api_v2_train.py @@ -18,24 +18,6 @@ from api_v2_vgg import vgg_bn_drop from api_v2_resnet import resnet_cifar10 -# End batch and end pass event handler -def event_handler(event): - if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 100 == 0: - print "\nPass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) - else: - sys.stdout.write('.') - sys.stdout.flush() - if isinstance(event, paddle.event.EndPass): - result = trainer.test( - reader=paddle.reader.batched( - paddle.dataset.cifar.test10(), batch_size=128), - reader_dict={'image': 0, - 'label': 1}) - print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) - - def main(): datadim = 3 * 32 * 32 classdim = 10 @@ -73,6 +55,23 @@ def main(): learning_rate_schedule='discexp', batch_size=128) + # End batch and end pass event handler + def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % 100 == 0: + print "\nPass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) + else: + sys.stdout.write('.') + sys.stdout.flush() + if isinstance(event, paddle.event.EndPass): + result = trainer.test( + reader=paddle.reader.batched( + paddle.dataset.cifar.test10(), batch_size=128), + reader_dict={'image': 0, + 'label': 1}) + print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) + # Create trainer trainer = paddle.trainer.SGD(cost=cost, parameters=parameters,