From 95f142b18b794625928795ad844ac3fed533f4ad Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Mon, 14 Jan 2019 17:50:39 +0800
Subject: [PATCH] resolve conflict test=develop

---
 .../test_parallel_executor_fetch_feed.py | 25 ++++++++++++-------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
index 507d652e7..ee0941f19 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
@@ -59,12 +59,13 @@ class TestFetchAndFeed(unittest.TestCase):
         exe = fluid.Executor(place)
         exe.run(startup)
 
-        pe = fluid.ParallelExecutor(
-            use_cuda=use_cuda, loss_name=loss.name, main_program=main_program)
-        run_parallel_exe(main_program, pe, use_cuda, data, label, loss)
+        train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
+            loss_name=loss.name)
 
-    def run_parallel_exe_with_fetch(self, main, pe, use_cuda, data, label,
-                                    loss):
+        run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
+
+    def run_parallel_exe_with_fetch(self, compiled_program, exe, use_cuda, data,
+                                    label, loss):
         def get_data(batch_size=8):
             np.random.seed(5)
             while True:
@@ -79,7 +80,7 @@ class TestFetchAndFeed(unittest.TestCase):
         # conv2d_1.b_0@GRAD. Those variables should not be pruned.
         # fluid.memory_optimize(main)
         fetch_list = []
-        all_vars = main.global_block().vars
+        all_vars = compiled_program._program.global_block().vars
 
         for k, v in all_vars.items():
             if ('tmp' not in k) and (
@@ -90,14 +91,18 @@ class TestFetchAndFeed(unittest.TestCase):
         for batch_id, img_label in enumerate(get_data()):
             img, l = img_label
             train_inputs = {data.name: img, label.name: l}
-            ret = pe.run(fetch_list, feed=train_inputs, return_numpy=True)
+            ret = exe.run(compiled_program,
+                          fetch_list=fetch_list,
+                          feed=train_inputs,
+                          return_numpy=True)
             for i in range(len(fetch_list)):
                 assert not math.isnan(np.sum(ret[i])) and \
                     not math.isinf(np.sum(ret[i]))
             if batch_id == 2:
                 break
 
-    def run_parallel_exe_with_feed(self, main, pe, use_cuda, data, label, loss):
+    def run_parallel_exe_with_feed(self, compiled_program, exe, use_cuda, data,
+                                   label, loss):
         def get_data(batch_size=8):
             np.random.seed(5)
             while True:
@@ -115,7 +120,9 @@ class TestFetchAndFeed(unittest.TestCase):
         reader = feeder.decorate_reader(get_data, multi_devices=True)
 
         for batch_id, data in enumerate(reader()):
-            loss_np = exe.run(train_cp, feed=data, fetch_list=[loss.name])[0]
+            loss_np = exe.run(compiled_program,
+                              feed=data,
+                              fetch_list=[loss.name])[0]
             print(batch_id, loss_np)
             if batch_id == 2:
                 break
-- 
GitLab
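
Note (not part of the patch): the change migrates the test from the standalone fluid.ParallelExecutor to the compiler.CompiledProgram(...).with_data_parallel(...) API, running the compiled program through the ordinary Executor via exe.run(compiled_program, ...). The sketch below shows that pattern in isolation, assuming PaddlePaddle 1.x with the fluid API; the toy network, feed names, and random data are illustrative stand-ins, not code from this patch.

    # Minimal sketch of the before/after usage, assuming PaddlePaddle 1.x.
    # The single-layer network is a hypothetical stand-in for the test's CNN.
    import os
    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid import compiler

    os.environ['CPU_NUM'] = '2'  # number of CPU "devices" for data parallelism
    use_cuda = False
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

    main_program = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main_program, startup):
        data = fluid.layers.data(name='image', shape=[784], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        prediction = fluid.layers.fc(input=data, size=10, act='softmax')
        loss = fluid.layers.mean(
            fluid.layers.cross_entropy(input=prediction, label=label))
        fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = fluid.Executor(place)
    exe.run(startup)

    # Before (removed by the patch): a dedicated ParallelExecutor both
    # compiles and runs the program itself:
    #     pe = fluid.ParallelExecutor(
    #         use_cuda=use_cuda, loss_name=loss.name, main_program=main_program)
    #     ret = pe.run(fetch_list, feed=train_inputs, return_numpy=True)

    # After (added by the patch): compile once with data parallelism, then
    # pass the compiled program to the ordinary Executor on every run.
    train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
        loss_name=loss.name)

    feed = {
        'image': np.random.random(size=(8, 784)).astype('float32'),
        'label': np.random.randint(0, 10, size=(8, 1)).astype('int64'),
    }
    loss_np = exe.run(train_cp, feed=feed, fetch_list=[loss.name])[0]
    print(loss_np)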