提交 9f50195b 编写于 作者: T typhoonzero

update using cifar10

上级 311d159e
FROM paddlepaddle/paddlecloud-job FROM paddlepaddle/paddlecloud-job
RUN mkdir -p /workspace && mkdir -p /root/.cache/paddle/dataset/flowers/ RUN mkdir -p /workspace
ADD vgg16.py reader.py /workspace/ ADD reader.py /workspace/
COPY 102flowers.tgz imagelabels.mat setid.mat /root/.cache/paddle/dataset/flowers/ RUN python /workspace/reader.py
ADD vgg16.py /workspace/
...@@ -67,4 +67,4 @@ if __name__ == '__main__': ...@@ -67,4 +67,4 @@ if __name__ == '__main__':
# print len(im[0]) # print len(im[0])
#for im in train_reader('test.list'): #for im in train_reader('test.list'):
# print len(im[0]) # print len(im[0])
paddle.dataset.flowers.train() paddle.dataset.cifar.train10()
...@@ -14,13 +14,15 @@ ...@@ -14,13 +14,15 @@
import gzip import gzip
import paddle.v2.dataset.flowers as flowers import paddle.v2.dataset.cifar as cifar
import paddle.v2 as paddle import paddle.v2 as paddle
import reader import reader
import time
DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2. DATA_DIM = 3 * 32 * 32
CLASS_DIM = 102 CLASS_DIM = 10
BATCH_SIZE = 128 BATCH_SIZE = 128
ts = 0
def vgg(input, nums, class_dim): def vgg(input, nums, class_dim):
...@@ -74,6 +76,7 @@ def vgg19(input, class_dim): ...@@ -74,6 +76,7 @@ def vgg19(input, class_dim):
def main(): def main():
global ts
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=False, trainer_count=1)
image = paddle.layer.data( image = paddle.layer.data(
name="image", type=paddle.data_type.dense_vector(DATA_DIM)) name="image", type=paddle.data_type.dense_vector(DATA_DIM))
...@@ -100,13 +103,13 @@ def main(): ...@@ -100,13 +103,13 @@ def main():
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
flowers.train(), cifar.train10(),
# To use other data, replace the above line with: # To use other data, replace the above line with:
# reader.train_reader('train.list'), # reader.train_reader('train.list'),
buf_size=1000), buf_size=1000),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
test_reader = paddle.batch( test_reader = paddle.batch(
flowers.valid(), cifar.test10(),
# To use other data, replace the above line with: # To use other data, replace the above line with:
# reader.test_reader('val.list'), # reader.test_reader('val.list'),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
...@@ -120,10 +123,14 @@ def main(): ...@@ -120,10 +123,14 @@ def main():
# End batch and end pass event handler # End batch and end pass event handler
def event_handler(event): def event_handler(event):
global ts
if isinstance(event, paddle.event.BeginIteration):
ts = time.time()
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1 == 0: if event.batch_id % 1 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % ( print "\nPass %d, Batch %d, Cost %f, %s, spent: %f" % (
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost, event.metrics,
time.time() - ts)
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f: with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
trainer.save_parameter_to_tar(f) trainer.save_parameter_to_tar(f)
...@@ -137,3 +144,4 @@ def main(): ...@@ -137,3 +144,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册