From cf437a89920e66d9193b5386c8223d8089d4f6c2 Mon Sep 17 00:00:00 2001 From: zhuoyuan Date: Tue, 7 Feb 2017 16:11:05 -0800 Subject: [PATCH] modified cnn of mnist light --- demo/mnist/light_mnist.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/demo/mnist/light_mnist.py b/demo/mnist/light_mnist.py index 466cc2e2f5..4e70159981 100644 --- a/demo/mnist/light_mnist.py +++ b/demo/mnist/light_mnist.py @@ -18,10 +18,7 @@ if not is_predict: # learning_rate=0.1 / 128.0, # learning_method=MomentumOptimizer(0.9), # regularization=L2Regularization(0.0005 * 128)) -settings( - batch_size=50, - learning_rate=0.001, - learning_method=AdamOptimizer()) +settings(batch_size=50, learning_rate=0.001, learning_method=AdamOptimizer()) #######################Network Configuration ############# @@ -32,9 +29,15 @@ img = data_layer(name='pixel', size=data_size) # small_vgg is predined in trainer_config_helpers.network # predict = small_vgg(input_image=img, num_channels=1, num_classes=label_size) + # light cnn def light_cnn(input_image, num_channels, num_classes): - def __light__(ipt, num_filter=128, times=1, conv_filter_size=3, dropouts=0, num_channels_=None): + def __light__(ipt, + num_filter=128, + times=1, + conv_filter_size=3, + dropouts=0, + num_channels_=None): return img_conv_group( input=ipt, num_channels=num_channels_, @@ -53,9 +56,10 @@ def light_cnn(input_image, num_channels, num_classes): tmp = __light__(tmp, num_filter=128) tmp = __light__(tmp, num_filter=128, conv_filter_size=1) - tmp = fc_layer(input=tmp, size = num_classes, act=SoftmaxActivation()) + tmp = fc_layer(input=tmp, size=num_classes, act=SoftmaxActivation()) return tmp + predict = light_cnn(input_image=img, num_channels=1, num_classes=label_size) if not is_predict: -- GitLab