提交 71c4933a 编写于 作者: Y yuyang18

Use independent recordio file name

上级 7ebb2469
...@@ -20,6 +20,9 @@ import paddle ...@@ -20,6 +20,9 @@ import paddle
import paddle.dataset.mnist as mnist import paddle.dataset.mnist as mnist
import paddle.dataset.wmt16 as wmt16 import paddle.dataset.wmt16 as wmt16
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
WMT16_RECORDIO_FILE = "./wmt16_test_pe.recordio"
def simple_fc_net(use_feed): def simple_fc_net(use_feed):
if use_feed: if use_feed:
...@@ -27,7 +30,7 @@ def simple_fc_net(use_feed): ...@@ -27,7 +30,7 @@ def simple_fc_net(use_feed):
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else: else:
reader = fluid.layers.open_files( reader = fluid.layers.open_files(
filenames=['./mnist.recordio'], filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'],
...@@ -55,7 +58,7 @@ def fc_with_batchnorm(use_feed): ...@@ -55,7 +58,7 @@ def fc_with_batchnorm(use_feed):
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else: else:
reader = fluid.layers.open_files( reader = fluid.layers.open_files(
filenames=['mnist.recordio'], filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'],
...@@ -287,7 +290,7 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -287,7 +290,7 @@ class TestMNIST(TestParallelExecutorBase):
], ],
place=fluid.CPUPlace()) place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file( fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder) MNIST_RECORDIO_FILE, reader, feeder)
def check_simple_fc_convergence(self, balance_parameter_opt_between_cards): def check_simple_fc_convergence(self, balance_parameter_opt_between_cards):
self.check_network_convergence(simple_fc_net) self.check_network_convergence(simple_fc_net)
...@@ -536,7 +539,7 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -536,7 +539,7 @@ class TestTransformer(TestParallelExecutorBase):
batch_size=transformer_model.batch_size) batch_size=transformer_model.batch_size)
with fluid.recordio_writer.create_recordio_writer( with fluid.recordio_writer.create_recordio_writer(
"./wmt16.recordio") as writer: WMT16_RECORDIO_FILE) as writer:
for batch in reader(): for batch in reader():
for tensor in prepare_batch_input( for tensor in prepare_batch_input(
batch, ModelHyperParams.src_pad_idx, batch, ModelHyperParams.src_pad_idx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册