“c7d3273d635a867be34326b4f2f09b0d786ece54”上不存在“python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py”
提交 33754671 编写于 作者: Y yaoxuefeng6

change example to 2.0 apis

上级 7c953b34
...@@ -458,11 +458,16 @@ class InMemoryDataset(DatasetBase): ...@@ -458,11 +458,16 @@ class InMemoryDataset(DatasetBase):
["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"]) ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
dataset.load_into_memory() dataset.load_into_memory()
exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda( paddle.enable_static()
) else fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program()) place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
exe.train_from_dataset(fluid.default_main_program(), exe = paddle.static.Executor(place)
dataset) startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)
os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt") os.remove("./test_queue_dataset_run_b.txt")
""" """
...@@ -789,9 +794,11 @@ class InMemoryDataset(DatasetBase): ...@@ -789,9 +794,11 @@ class InMemoryDataset(DatasetBase):
dataset.set_filelist(filelist) dataset.set_filelist(filelist)
dataset.load_into_memory() dataset.load_into_memory()
dataset.global_shuffle(fleet) dataset.global_shuffle(fleet)
exe = fluid.Executor(fluid.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(fluid.default_startup_program()) startup_program = paddle.static.Program()
exe.train_from_dataset(fluid.default_main_program(), dataset) main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)
dataset.release_memory() dataset.release_memory()
""" """
...@@ -924,7 +931,7 @@ class InMemoryDataset(DatasetBase): ...@@ -924,7 +931,7 @@ class InMemoryDataset(DatasetBase):
class QueueDataset(DatasetBase): class QueueDataset(DatasetBase):
""" """
:api_attr: Static Graph :api_attr: Static Graph
QueueDataset, it will process data streamly. QueueDataset, it will process data streamly.
Examples: Examples:
......
...@@ -185,20 +185,24 @@ class TestDataset(unittest.TestCase): ...@@ -185,20 +185,24 @@ class TestDataset(unittest.TestCase):
use_var=slots_vars) use_var=slots_vars)
dataset.set_filelist([filename1, filename2]) dataset.set_filelist([filename1, filename2])
dataset.load_into_memory() dataset.load_into_memory()
paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(startup_program)
if self.use_data_loader: if self.use_data_loader:
data_loader = fluid.io.DataLoader.from_dataset(dataset, data_loader = fluid.io.DataLoader.from_dataset(dataset,
fluid.cpu_places(), fluid.cpu_places(),
self.drop_last) self.drop_last)
for i in range(self.epoch_num): for i in range(self.epoch_num):
for data in data_loader(): for data in data_loader():
exe.run(fluid.default_main_program(), feed=data) exe.run(main_program, feed=data)
else: else:
for i in range(self.epoch_num): for i in range(self.epoch_num):
try: try:
exe.train_from_dataset(fluid.default_main_program(), exe.train_from_dataset(fluid.main_program, dataset)
dataset)
except Exception as e: except Exception as e:
self.assertTrue(False) self.assertTrue(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册