diff --git a/demo/image_classification/api_v2_train.py b/demo/image_classification/api_v2_train.py index 0b4dc4d92982680a06cd95a58d3af0975e25b690..94bf0b5db48c5046ab02fea4994ce2aae09c641e 100644 --- a/demo/image_classification/api_v2_train.py +++ b/demo/image_classification/api_v2_train.py @@ -12,27 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License +import sys +import paddle.v2 as paddle from api_v2_vgg import vgg_bn_drop from api_v2_resnet import resnet_cifar10 -import paddle.v2 as paddle +# End batch and end pass event handler def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id, - event.cost) + print "\nPass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) + else: + sys.stdout.write('.') + sys.stdout.flush() + if isinstance(event, paddle.event.EndPass): + result = trainer.test( + reader=paddle.reader.batched( + paddle.dataset.cifar.test10(), batch_size=128), + reader_dict={'image': 0, + 'label': 1}) + print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) def main(): datadim = 3 * 32 * 32 classdim = 10 + # PaddlePaddle init paddle.init(use_gpu=True, trainer_count=1) image = paddle.layer.data( name="image", type=paddle.data_type.dense_vector(datadim)) + # Add neural network config # option 1. resnet net = resnet_cifar10(image, depth=32) # option 2. vgg @@ -46,8 +60,10 @@ def main(): name="label", type=paddle.data_type.integer_value(classdim)) cost = paddle.layer.classification_cost(input=out, label=lbl) + # Create parameters parameters = paddle.parameters.create(cost) + # Create optimizer momentum_optimizer = paddle.optimizer.Momentum( momentum=0.9, regularization=paddle.optimizer.L2Regularization(rate=0.0002 * 128), @@ -57,6 +73,7 @@ def main(): learning_rate_schedule='discexp', batch_size=128) + # Create trainer trainer = paddle.trainer.SGD(cost=cost, parameters=parameters, update_equation=momentum_optimizer)