提交 9f50195b 编写于 作者: T typhoonzero

update using cifar10

上级 311d159e
FROM paddlepaddle/paddlecloud-job
RUN mkdir -p /workspace && mkdir -p /root/.cache/paddle/dataset/flowers/
ADD vgg16.py reader.py /workspace/
COPY 102flowers.tgz imagelabels.mat setid.mat /root/.cache/paddle/dataset/flowers/
RUN mkdir -p /workspace
ADD reader.py /workspace/
RUN python /workspace/reader.py
ADD vgg16.py /workspace/
......@@ -67,4 +67,4 @@ if __name__ == '__main__':
# print len(im[0])
#for im in train_reader('test.list'):
# print len(im[0])
paddle.dataset.flowers.train()
paddle.dataset.cifar.train10()
......@@ -14,13 +14,15 @@
import gzip
import paddle.v2.dataset.flowers as flowers
import paddle.v2.dataset.cifar as cifar
import paddle.v2 as paddle
import reader
import time
DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2.
CLASS_DIM = 102
DATA_DIM = 3 * 32 * 32
CLASS_DIM = 10
BATCH_SIZE = 128
ts = 0
def vgg(input, nums, class_dim):
......@@ -74,6 +76,7 @@ def vgg19(input, class_dim):
def main():
global ts
paddle.init(use_gpu=False, trainer_count=1)
image = paddle.layer.data(
name="image", type=paddle.data_type.dense_vector(DATA_DIM))
......@@ -100,13 +103,13 @@ def main():
train_reader = paddle.batch(
paddle.reader.shuffle(
flowers.train(),
cifar.train10(),
# To use other data, replace the above line with:
# reader.train_reader('train.list'),
buf_size=1000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
flowers.valid(),
cifar.test10(),
# To use other data, replace the above line with:
# reader.test_reader('val.list'),
batch_size=BATCH_SIZE)
......@@ -120,10 +123,14 @@ def main():
# End batch and end pass event handler
def event_handler(event):
global ts
if isinstance(event, paddle.event.BeginIteration):
ts = time.time()
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
print "\nPass %d, Batch %d, Cost %f, %s, spent: %f" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
time.time() - ts)
if isinstance(event, paddle.event.EndPass):
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
trainer.save_parameter_to_tar(f)
......@@ -137,3 +144,4 @@ def main():
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册