From e26f220df81546d360e7759b3d96a2aa27d06ffc Mon Sep 17 00:00:00 2001 From: wangyang59 Date: Tue, 11 Oct 2016 21:29:52 -0700 Subject: [PATCH] Mnist demo (#162) * added mnist demo * modified .gitignore for .project files * normalize pixel in mnist_provider.py and set use_gpu=0 --- .gitignore | 4 ++- demo/mnist/.gitignore | 6 ++++ demo/mnist/data/generate_list.py | 21 +++++++++++++ demo/mnist/data/get_mnist_data.sh | 22 +++++++++++++ demo/mnist/mnist_provider.py | 33 ++++++++++++++++++++ demo/mnist/train.sh | 31 ++++++++++++++++++ demo/mnist/vgg_16_mnist.py | 52 +++++++++++++++++++++++++++++++ 7 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 demo/mnist/.gitignore create mode 100644 demo/mnist/data/generate_list.py create mode 100644 demo/mnist/data/get_mnist_data.sh create mode 100644 demo/mnist/mnist_provider.py create mode 100755 demo/mnist/train.sh create mode 100644 demo/mnist/vgg_16_mnist.py diff --git a/.gitignore b/.gitignore index 7e21ba0b750..65ba217de37 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ build/ *.user .vscode -.idea \ No newline at end of file +.idea +.project +.pydevproject diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore new file mode 100644 index 00000000000..810910fd5ca --- /dev/null +++ b/demo/mnist/.gitignore @@ -0,0 +1,6 @@ +data/raw_data +data/*.list +mnist_vgg_model +plot.png +train.log +*pyc diff --git a/demo/mnist/data/generate_list.py b/demo/mnist/data/generate_list.py new file mode 100644 index 00000000000..1b929048b4d --- /dev/null +++ b/demo/mnist/data/generate_list.py @@ -0,0 +1,21 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +o = open("./" + "train.list", "w") +o.write("./data/raw_data/train" +"\n") +o.close() + +o = open("./" + "test.list", "w") +o.write("./data/raw_data/t10k" +"\n") +o.close() \ No newline at end of file diff --git a/demo/mnist/data/get_mnist_data.sh b/demo/mnist/data/get_mnist_data.sh new file mode 100644 index 00000000000..c3ef9944504 --- /dev/null +++ b/demo/mnist/data/get_mnist_data.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env sh +# This scripts downloads the mnist data and unzips it. + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +rm -rf "$DIR/raw_data" +mkdir "$DIR/raw_data" +cd "$DIR/raw_data" + +echo "Downloading..." + +for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte +do + if [ ! -e $fname ]; then + wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz + gunzip ${fname}.gz + fi +done + +cd $DIR +rm -f *.list +python generate_list.py + diff --git a/demo/mnist/mnist_provider.py b/demo/mnist/mnist_provider.py new file mode 100644 index 00000000000..0f14ded2dce --- /dev/null +++ b/demo/mnist/mnist_provider.py @@ -0,0 +1,33 @@ +from paddle.trainer.PyDataProvider2 import * + + +# Define a py data provider +@provider(input_types=[ + dense_vector(28 * 28), + integer_value(10) +]) +def process(settings, filename): # settings is not used currently. + imgf = filename + "-images-idx3-ubyte" + labelf = filename + "-labels-idx1-ubyte" + f = open(imgf, "rb") + l = open(labelf, "rb") + + f.read(16) + l.read(8) + + # Define number of samples for train/test + if "train" in filename: + n = 60000 + else: + n = 10000 + + for i in range(n): + label = ord(l.read(1)) + pixels = [] + for j in range(28*28): + pixels.append(float(ord(f.read(1))) / 255.0) + yield { "pixel": pixels, 'label': label } + + f.close() + l.close() + \ No newline at end of file diff --git a/demo/mnist/train.sh b/demo/mnist/train.sh new file mode 100755 index 00000000000..084b32ac390 --- /dev/null +++ b/demo/mnist/train.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +config=vgg_16_mnist.py +output=./mnist_vgg_model +log=train.log + +paddle train \ +--config=$config \ +--dot_period=10 \ +--log_period=100 \ +--test_all_data_in_one_period=1 \ +--use_gpu=0 \ +--trainer_count=1 \ +--num_passes=100 \ +--save_dir=$output \ +2>&1 | tee $log + +python -m paddle.utils.plotcurve -i $log > plot.png diff --git a/demo/mnist/vgg_16_mnist.py b/demo/mnist/vgg_16_mnist.py new file mode 100644 index 00000000000..ad0a4de3215 --- /dev/null +++ b/demo/mnist/vgg_16_mnist.py @@ -0,0 +1,52 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +is_predict = get_config_arg("is_predict", bool, False) + +####################Data Configuration ################## + + +if not is_predict: + data_dir='./data/' + define_py_data_sources2(train_list= data_dir + 'train.list', + test_list= data_dir + 'test.list', + module='mnist_provider', + obj='process') + +######################Algorithm Configuration ############# +settings( + batch_size = 128, + learning_rate = 0.1 / 128.0, + learning_method = MomentumOptimizer(0.9), + regularization = L2Regularization(0.0005 * 128) +) + +#######################Network Configuration ############# + +data_size=1*28*28 +label_size=10 +img = data_layer(name='pixel', size=data_size) + +# small_vgg is predined in trainer_config_helpers.network +predict = small_vgg(input_image=img, + num_channels=1, + num_classes=label_size) + +if not is_predict: + lbl = data_layer(name="label", size=label_size) + outputs(classification_cost(input=predict, label=lbl)) +else: + outputs(predict) -- GitLab