From 71c4933a021b09f39c6b322a30beb37d1ddf5fb3 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Thu, 17 May 2018 16:23:45 +0800 Subject: [PATCH] Use independent recordio file name --- .../fluid/tests/unittests/test_parallel_executor.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index 6dc016487fd..f1525253c8b 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -20,6 +20,9 @@ import paddle import paddle.dataset.mnist as mnist import paddle.dataset.wmt16 as wmt16 +MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio" +WMT16_RECORDIO_FILE = "./wmt16_test_pe.recordio" + def simple_fc_net(use_feed): if use_feed: @@ -27,7 +30,7 @@ def simple_fc_net(use_feed): label = fluid.layers.data(name='label', shape=[1], dtype='int64') else: reader = fluid.layers.open_files( - filenames=['./mnist.recordio'], + filenames=[MNIST_RECORDIO_FILE], shapes=[[-1, 784], [-1, 1]], lod_levels=[0, 0], dtypes=['float32', 'int64'], @@ -55,7 +58,7 @@ def fc_with_batchnorm(use_feed): label = fluid.layers.data(name='label', shape=[1], dtype='int64') else: reader = fluid.layers.open_files( - filenames=['mnist.recordio'], + filenames=[MNIST_RECORDIO_FILE], shapes=[[-1, 784], [-1, 1]], lod_levels=[0, 0], dtypes=['float32', 'int64'], @@ -287,7 +290,7 @@ class TestMNIST(TestParallelExecutorBase): ], place=fluid.CPUPlace()) fluid.recordio_writer.convert_reader_to_recordio_file( - './mnist.recordio', reader, feeder) + MNIST_RECORDIO_FILE, reader, feeder) def check_simple_fc_convergence(self, balance_parameter_opt_between_cards): self.check_network_convergence(simple_fc_net) @@ -536,7 +539,7 @@ class TestTransformer(TestParallelExecutorBase): batch_size=transformer_model.batch_size) with fluid.recordio_writer.create_recordio_writer( - "./wmt16.recordio") as writer: + WMT16_RECORDIO_FILE) as writer: for batch in reader(): for tensor in prepare_batch_input( batch, ModelHyperParams.src_pad_idx, -- GitLab