From 208ca38a204748108d088bc1b6336e2d965dc71d Mon Sep 17 00:00:00 2001 From: wwhu Date: Tue, 13 Jun 2017 19:34:00 +0800 Subject: [PATCH] fix bug for resnet_cifar10 and adjust learning rate --- image_classification/resnet.py | 6 +++--- image_classification/train.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/image_classification/resnet.py b/image_classification/resnet.py index 9c3c46d8..eeed7141 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -85,9 +85,9 @@ def resnet_cifar10(input, depth=32, class_dim=10): nStages = {16, 64, 128} conv1 = conv_bn_layer( input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1) - res1 = layer_warp(basicblock, conv1, 16, n, 1) - res2 = layer_warp(basicblock, res1, 32, n, 2) - res3 = layer_warp(basicblock, res2, 64, n, 2) + res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) + res2 = layer_warp(basicblock, res1, 16, 32, n, 2) + res3 = layer_warp(basicblock, res2, 32, 64, n, 2) pool = paddle.layer.img_pool( input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) out = paddle.layer.fc( diff --git a/image_classification/train.py b/image_classification/train.py index 0a3fdb49..b3de4134 100755 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -31,6 +31,7 @@ def main(): name="label", type=paddle.data_type.integer_value(CLASS_DIM)) extra_layers = None + learning_rate = 0.01 if args.model == 'alexnet': out = alexnet.alexnet(image, class_dim=CLASS_DIM) elif args.model == 'vgg13': @@ -41,6 +42,7 @@ def main(): out = vgg.vgg19(image, class_dim=CLASS_DIM) elif args.model == 'resnet': out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) + learning_rate = 0.1 elif args.model == 'googlenet': out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM) loss1 = paddle.layer.cross_entropy_cost( @@ -61,7 +63,7 @@ def main(): momentum=0.9, regularization=paddle.optimizer.L2Regularization(rate=0.0005 * BATCH_SIZE), - learning_rate=0.001 / BATCH_SIZE, + learning_rate=learning_rate / BATCH_SIZE, learning_rate_decay_a=0.1, learning_rate_decay_b=128000 * 35, learning_rate_schedule="discexp", ) -- GitLab