提交 cf437a89 编写于 作者: Z zhuoyuan

modified cnn of mnist light

上级 bc8ba1de
...@@ -18,10 +18,7 @@ if not is_predict: ...@@ -18,10 +18,7 @@ if not is_predict:
# learning_rate=0.1 / 128.0, # learning_rate=0.1 / 128.0,
# learning_method=MomentumOptimizer(0.9), # learning_method=MomentumOptimizer(0.9),
# regularization=L2Regularization(0.0005 * 128)) # regularization=L2Regularization(0.0005 * 128))
settings( settings(batch_size=50, learning_rate=0.001, learning_method=AdamOptimizer())
batch_size=50,
learning_rate=0.001,
learning_method=AdamOptimizer())
#######################Network Configuration ############# #######################Network Configuration #############
...@@ -32,9 +29,15 @@ img = data_layer(name='pixel', size=data_size) ...@@ -32,9 +29,15 @@ img = data_layer(name='pixel', size=data_size)
# small_vgg is predined in trainer_config_helpers.network # small_vgg is predined in trainer_config_helpers.network
# predict = small_vgg(input_image=img, num_channels=1, num_classes=label_size) # predict = small_vgg(input_image=img, num_channels=1, num_classes=label_size)
# light cnn # light cnn
def light_cnn(input_image, num_channels, num_classes): def light_cnn(input_image, num_channels, num_classes):
def __light__(ipt, num_filter=128, times=1, conv_filter_size=3, dropouts=0, num_channels_=None): def __light__(ipt,
num_filter=128,
times=1,
conv_filter_size=3,
dropouts=0,
num_channels_=None):
return img_conv_group( return img_conv_group(
input=ipt, input=ipt,
num_channels=num_channels_, num_channels=num_channels_,
...@@ -53,9 +56,10 @@ def light_cnn(input_image, num_channels, num_classes): ...@@ -53,9 +56,10 @@ def light_cnn(input_image, num_channels, num_classes):
tmp = __light__(tmp, num_filter=128) tmp = __light__(tmp, num_filter=128)
tmp = __light__(tmp, num_filter=128, conv_filter_size=1) tmp = __light__(tmp, num_filter=128, conv_filter_size=1)
tmp = fc_layer(input=tmp, size = num_classes, act=SoftmaxActivation()) tmp = fc_layer(input=tmp, size=num_classes, act=SoftmaxActivation())
return tmp return tmp
predict = light_cnn(input_image=img, num_channels=1, num_classes=label_size) predict = light_cnn(input_image=img, num_channels=1, num_classes=label_size)
if not is_predict: if not is_predict:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册