Commit 95f142b1 authored by Xin Pan

resolve conflict

test=develop
Parent abdd9411
@@ -59,12 +59,13 @@ class TestFetchAndFeed(unittest.TestCase):
         exe = fluid.Executor(place)
         exe.run(startup)
 
-        pe = fluid.ParallelExecutor(
-            use_cuda=use_cuda, loss_name=loss.name, main_program=main_program)
-        run_parallel_exe(main_program, pe, use_cuda, data, label, loss)
+        train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
+            loss_name=loss.name)
+
+        run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
 
-    def run_parallel_exe_with_fetch(self, main, pe, use_cuda, data, label,
-                                    loss):
+    def run_parallel_exe_with_fetch(self, compiled_program, exe, use_cuda, data,
+                                    label, loss):
         def get_data(batch_size=8):
             np.random.seed(5)
             while True:
@@ -79,7 +80,7 @@ class TestFetchAndFeed(unittest.TestCase):
         # conv2d_1.b_0@GRAD. Those variables should not be pruned.
         # fluid.memory_optimize(main)
         fetch_list = []
-        all_vars = main.global_block().vars
+        all_vars = compiled_program._program.global_block().vars
         for k, v in all_vars.items():
             if ('tmp' not in k) and (
@@ -90,14 +91,18 @@ class TestFetchAndFeed(unittest.TestCase):
         for batch_id, img_label in enumerate(get_data()):
             img, l = img_label
             train_inputs = {data.name: img, label.name: l}
-            ret = pe.run(fetch_list, feed=train_inputs, return_numpy=True)
+            ret = exe.run(compiled_program,
+                          fetch_list=fetch_list,
+                          feed=train_inputs,
+                          return_numpy=True)
             for i in range(len(fetch_list)):
                 assert not math.isnan(np.sum(ret[i])) and \
                        not math.isinf(np.sum(ret[i]))
             if batch_id == 2:
                 break
 
-    def run_parallel_exe_with_feed(self, main, pe, use_cuda, data, label, loss):
+    def run_parallel_exe_with_feed(self, compiled_program, exe, use_cuda, data,
+                                   label, loss):
         def get_data(batch_size=8):
             np.random.seed(5)
             while True:
@@ -115,7 +120,9 @@ class TestFetchAndFeed(unittest.TestCase):
         reader = feeder.decorate_reader(get_data, multi_devices=True)
         for batch_id, data in enumerate(reader()):
-            loss_np = exe.run(train_cp, feed=data, fetch_list=[loss.name])[0]
+            loss_np = exe.run(compiled_program,
+                              feed=data,
+                              fetch_list=[loss.name])[0]
             print(batch_id, loss_np)
             if batch_id == 2:
                 break
...
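
For context, this diff replaces direct use of fluid.ParallelExecutor with a compiler.CompiledProgram built via with_data_parallel() and run through the ordinary fluid.Executor. Below is a minimal, self-contained sketch of that usage pattern, assuming the Paddle 1.x fluid API; the tiny network, the names such as img and prediction, and the single random batch are illustrative and not taken from the test.

import os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import compiler

# Data-parallel CPU execution reads CPU_NUM from the environment; use one device here.
os.environ['CPU_NUM'] = '1'

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=img, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
    fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)

# Parallelism now lives in the compiled program rather than in a separate
# ParallelExecutor object; only the loss name is required here.
train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name)

feed = {
    'image': np.random.random((8, 784)).astype('float32'),
    'label': np.random.randint(0, 10, size=(8, 1)).astype('int64'),
}
# The compiled program is passed to exe.run() like a regular Program, with
# feed and fetch_list handled by the plain Executor.
loss_np = exe.run(train_cp, feed=feed, fetch_list=[loss.name])[0]
print(loss_np)

The same exe.run() call works whether it is given the raw main_program or its data-parallel compilation, which is why the test above can pass exe plus a compiled program where it previously needed a dedicated ParallelExecutor.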