From e42c3854fa6c6683d52c19d65ca2f5e3833996c4 Mon Sep 17 00:00:00 2001 From: zhuoyuan Date: Wed, 8 Feb 2017 15:49:19 -0800 Subject: [PATCH] follow helin's comments --- demo/mnist/light_mnist.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/demo/mnist/light_mnist.py b/demo/mnist/light_mnist.py index 4e70159981..d796a3cc06 100644 --- a/demo/mnist/light_mnist.py +++ b/demo/mnist/light_mnist.py @@ -1,3 +1,17 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from paddle.trainer_config_helpers import * is_predict = get_config_arg("is_predict", bool, False) @@ -13,11 +27,6 @@ if not is_predict: obj='process') ######################Algorithm Configuration ############# -# settings( -# batch_size=128, -# learning_rate=0.1 / 128.0, -# learning_method=MomentumOptimizer(0.9), -# regularization=L2Regularization(0.0005 * 128)) settings(batch_size=50, learning_rate=0.001, learning_method=AdamOptimizer()) #######################Network Configuration ############# @@ -26,11 +35,10 @@ data_size = 1 * 28 * 28 label_size = 10 img = data_layer(name='pixel', size=data_size) -# small_vgg is predined in trainer_config_helpers.network -# predict = small_vgg(input_image=img, num_channels=1, num_classes=label_size) - - # light cnn +# A shallower cnn model: [CNN, BN, ReLU, Max-Pooling] x4 + FC x1 +# Easier to train for mnist dataset and quite efficient +# Final performance is close to deeper ones on tasks such as digital and character classification def light_cnn(input_image, num_channels, num_classes): def __light__(ipt, num_filter=128, -- GitLab