Commit 5f6c4af3 authored by Yu Yang

Try to read data in mnist

Parent ad93b8f9
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
import paddle.trainer.config_parser
import numpy as np
from mnist_util import read_from_mnist
def init_parameter(network):
@@ -13,6 +16,22 @@ def init_parameter(network):
array[i] = np.random.uniform(-1.0, 1.0)
def generator_to_batch(generator, batch_size):
ret_val = list()
for each_item in generator:
ret_val.append(each_item)
if len(ret_val) == batch_size:
yield ret_val
ret_val = list()
if len(ret_val) != 0:
yield ret_val
def input_order_converter(generator):
for each_item in generator:
yield each_item['pixel'], each_item['label']
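A quick way to see what generator_to_batch above does, independent of Paddle: it groups any iterable into lists of batch_size items and emits a final short batch when the total is not an exact multiple. A minimal standalone check (not part of this commit), reusing the helper defined above:

# Group 10 dummy samples into batches of 4; the last batch holds the remainder.
batches = list(generator_to_batch(range(10), 4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]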
def main():
api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores
config = paddle.trainer.config_parser.parse_config(
@@ -30,10 +49,20 @@ def main():
updater = api.ParameterUpdater.createLocalUpdater(opt_config)
assert isinstance(updater, api.ParameterUpdater)
updater.init(m)
converter = DataProviderConverter(
input_types=[dp.dense_vector(784), dp.integer_value(10)])
train_file = './data/raw_data/train'
m.start()
for _ in xrange(100):
updater.startPass()
train_data_generator = input_order_converter(
read_from_mnist(train_file))
for data_batch in generator_to_batch(train_data_generator, 128):
inArgs = converter(data_batch)
updater.finishPass()
......
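For reference (not part of the commit), the new helpers chain together without the trainer loop as follows, assuming the imports and the converter/train_file variables defined in this script: read_from_mnist yields per-sample dicts, input_order_converter reorders them into (pixel, label) tuples matching input_types, and DataProviderConverter packs a whole batch into swig Arguments.

# Sketch only -- uses the helpers and the converter defined in this file.
samples = input_order_converter(read_from_mnist(train_file))
first_batch = next(generator_to_batch(samples, 128))  # list of 128 (pixel, label) tuples
in_args = converter(first_batch)  # paddle Arguments holding one batch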
from paddle.trainer.PyDataProvider2 import *
import numpy
from mnist_util import read_from_mnist
# Define a py data provider
@@ -8,27 +8,5 @@ import numpy
'label': integer_value(10)},
cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename): # settings is not used currently.
imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte"
f = open(imgf, "rb")
l = open(labelf, "rb")
f.read(16)
l.read(8)
# Define number of samples for train/test
if "train" in filename:
n = 60000
else:
n = 10000
images = numpy.fromfile(
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
for i in xrange(n):
yield {"pixel": images[i, :], 'label': labels[i]}
f.close()
l.close()
for each in read_from_mnist(filename):
yield each
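Side note, not part of the commit: the explicit for/yield loop above is the Python 2 idiom for delegating to another generator. On Python 3.3+ the same provider body could be written with yield from, e.g. (decorator omitted):

def process(settings, filename):
    yield from read_from_mnist(filename)  # Python 3 only; equivalent to the loop above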
import numpy
__all__ = ['read_from_mnist']
def read_from_mnist(filename):
imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte"
f = open(imgf, "rb")
l = open(labelf, "rb")
f.read(16)
l.read(8)
# Define number of samples for train/test
if "train" in filename:
n = 60000
else:
n = 10000
images = numpy.fromfile(
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
for i in xrange(n):
yield {"pixel": images[i, :], 'label': labels[i]}
f.close()
l.close()
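read_from_mnist hardcodes the sample count (60000 for train, 10000 for test) and skips the 16-byte image header and 8-byte label header without parsing them. Those headers actually carry the count: the IDX image header is four big-endian int32s (magic 2051, number of images, rows, cols). A hedged alternative sketch, not how this commit does it, that reads the count instead of hardcoding it:

import struct

def read_mnist_count(filename):
    # Parse the 16-byte IDX3 image header: magic, count, rows, cols (big-endian int32s).
    with open(filename + "-images-idx3-ubyte", "rb") as f:
        magic, n, rows, cols = struct.unpack(">IIII", f.read(16))
    assert magic == 2051 and rows == 28 and cols == 28
    return n

# e.g. read_mnist_count('./data/raw_data/train') should return 60000.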