From 33754671be9b6fbdfa062ae6dc669a7426e7b1f6 Mon Sep 17 00:00:00 2001
From: yaoxuefeng6
Date: Wed, 16 Sep 2020 16:52:34 +0800
Subject: [PATCH] change example to 2.0 apis

---
 .../distributed/fleet/dataset/dataset.py   | 25 ++++++++++++-------
 .../fluid/tests/unittests/test_dataset.py  | 13 ++++++----
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/python/paddle/distributed/fleet/dataset/dataset.py b/python/paddle/distributed/fleet/dataset/dataset.py
index 80c5107e5b0..ce14909f2ec 100644
--- a/python/paddle/distributed/fleet/dataset/dataset.py
+++ b/python/paddle/distributed/fleet/dataset/dataset.py
@@ -458,11 +458,16 @@ class InMemoryDataset(DatasetBase):
                 ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
             dataset.load_into_memory()
-            exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
-            ) else fluid.CUDAPlace(0))
-            exe.run(fluid.default_startup_program())
-            exe.train_from_dataset(fluid.default_main_program(),
-                                   dataset)
+            paddle.enable_static()
+
+            place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
+            exe = paddle.static.Executor(place)
+            startup_program = paddle.static.Program()
+            main_program = paddle.static.Program()
+            exe.run(startup_program)
+
+            exe.train_from_dataset(main_program, dataset)
+
             os.remove("./test_queue_dataset_run_a.txt")
             os.remove("./test_queue_dataset_run_b.txt")
         """
@@ -789,9 +794,11 @@ class InMemoryDataset(DatasetBase):
             dataset.set_filelist(filelist)
             dataset.load_into_memory()
             dataset.global_shuffle(fleet)
-            exe = fluid.Executor(fluid.CPUPlace())
-            exe.run(fluid.default_startup_program())
-            exe.train_from_dataset(fluid.default_main_program(), dataset)
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            startup_program = paddle.static.Program()
+            main_program = paddle.static.Program()
+            exe.run(startup_program)
+            exe.train_from_dataset(main_program, dataset)
             dataset.release_memory()
         """
@@ -924,7 +931,7 @@ class InMemoryDataset(DatasetBase):
 class QueueDataset(DatasetBase):
     """
     :api_attr: Static Graph
-    
+
     QueueDataset, it will process data streamly.

     Examples:
diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py
index f923e2fa933..c17454c69b5 100644
--- a/python/paddle/fluid/tests/unittests/test_dataset.py
+++ b/python/paddle/fluid/tests/unittests/test_dataset.py
@@ -185,20 +185,23 @@ class TestDataset(unittest.TestCase):
             use_var=slots_vars)
         dataset.set_filelist([filename1, filename2])
         dataset.load_into_memory()
+        paddle.enable_static()
+
+        exe = paddle.static.Executor(paddle.CPUPlace())
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
-        exe = fluid.Executor(fluid.CPUPlace())
-        exe.run(fluid.default_startup_program())
+        exe.run(startup_program)
         if self.use_data_loader:
             data_loader = fluid.io.DataLoader.from_dataset(dataset,
                                                            fluid.cpu_places(),
                                                            self.drop_last)
             for i in range(self.epoch_num):
                 for data in data_loader():
-                    exe.run(fluid.default_main_program(), feed=data)
+                    exe.run(main_program, feed=data)
         else:
             for i in range(self.epoch_num):
                 try:
-                    exe.train_from_dataset(fluid.default_main_program(),
-                                           dataset)
+                    exe.train_from_dataset(main_program, dataset)
                 except Exception as e:
                     self.assertTrue(False)
--
GitLab
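
Note: read together, the 2.0-style calls introduced by this patch follow the pattern sketched below. This is a minimal illustration, not code taken from the repository: make_in_memory_dataset is a hypothetical stand-in for however the dataset object is constructed and configured at this revision, and the slot variables and file names are placeholders. Only paddle.enable_static(), paddle.static.Program/Executor, set_filelist(), load_into_memory(), and train_from_dataset() mirror calls that appear in the hunks above.

    import paddle

    paddle.enable_static()

    # Build the network in explicit programs instead of relying on
    # fluid.default_main_program() / fluid.default_startup_program().
    startup_program = paddle.static.Program()
    main_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, startup_program):
        # Placeholder slot variables for illustration; real examples derive
        # these from the dataset's use_var configuration.
        slot_a = paddle.static.data(name="slot_a", shape=[None, 1], dtype="int64", lod_level=1)
        slot_b = paddle.static.data(name="slot_b", shape=[None, 1], dtype="int64", lod_level=1)

    place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    # Hypothetical helper: stands in for creating and configuring an
    # InMemoryDataset (batch size, pipe command, use_var, ...), which is
    # outside the scope of this patch.
    dataset = make_in_memory_dataset(use_var=[slot_a, slot_b])
    dataset.set_filelist(["train_data_a.txt", "train_data_b.txt"])  # placeholder file names
    dataset.load_into_memory()

    # Train directly from the in-memory dataset under the program that
    # defines the network.
    exe.train_from_dataset(main_program, dataset)

The key difference from the 1.x examples is that the startup and main programs are created explicitly and passed to run()/train_from_dataset() rather than fetched via fluid.default_startup_program()/fluid.default_main_program().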